# mrrrme-emotion-ai / model / AU_model.py
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def normalize_digraph(A):
    # Symmetric degree normalization of a batch of adjacency matrices:
    # returns D^-1/2 * A * D^-1/2, where D is the (detached) node-degree matrix.
    b, n, _ = A.shape
    node_degrees = A.detach().sum(dim=-1)
    degs_inv_sqrt = node_degrees ** -0.5
    norm_degs_matrix = torch.eye(n, device=A.device)
    # Broadcast the identity against the per-node inverse-sqrt degrees to build D^-1/2.
    norm_degs_matrix = norm_degs_matrix.view(1, n, n) * degs_inv_sqrt.view(b, n, 1)
    norm_A = torch.bmm(torch.bmm(norm_degs_matrix, A), norm_degs_matrix)
    return norm_A
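
# Usage note (sketch): for a binary or soft adjacency batch of shape (batch, n, n),
# normalize_digraph rescales each edge weight A[i, j] by the inverse square root of the
# degrees of both endpoints, which is what GNN.forward relies on before message passing.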


class GNN(nn.Module):
    def __init__(self, in_channels, num_classes, neighbor_num=4, metric='dots'):
        super(GNN, self).__init__()
        # in_channels: dimension of each node feature
        # num_classes: number of nodes in the graph
        # neighbor_num: K in the paper; the top-K most similar nodes are selected as neighbors
        # metric: similarity metric used by the FGG module to build the dynamic graph
        # Node update: X' = ReLU(X + BN(V(X) + A x U(X)))
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.relu = nn.ReLU()
        self.metric = metric
        self.neighbor_num = neighbor_num

        # network
        self.U = nn.Linear(self.in_channels, self.in_channels)
        self.V = nn.Linear(self.in_channels, self.in_channels)
        self.bnv = nn.BatchNorm1d(num_classes)

        # init
        self.U.weight.data.normal_(0, math.sqrt(2. / self.in_channels))
        self.V.weight.data.normal_(0, math.sqrt(2. / self.in_channels))
        self.bnv.weight.data.fill_(1)
        self.bnv.bias.data.zero_()
    def forward(self, x):
        b, n, c = x.shape

        # Build the dynamic graph: connect each node to its top-K most similar nodes.
        if self.metric == 'dots':
            si = x.detach()
            si = torch.einsum('b i j , b j k -> b i k', si, si.transpose(1, 2))
            threshold = si.topk(k=self.neighbor_num, dim=-1, largest=True)[0][:, :, -1].view(b, n, 1)
            adj = (si >= threshold).float()
        elif self.metric == 'cosine':
            si = x.detach()
            si = F.normalize(si, p=2, dim=-1)
            si = torch.einsum('b i j , b j k -> b i k', si, si.transpose(1, 2))
            threshold = si.topk(k=self.neighbor_num, dim=-1, largest=True)[0][:, :, -1].view(b, n, 1)
            adj = (si >= threshold).float()
        elif self.metric == 'l1':
            si = x.detach().repeat(1, n, 1).view(b, n, n, c)
            si = torch.abs(si.transpose(1, 2) - si)
            si = si.sum(dim=-1)
            # L1 distance: smaller means more similar, so keep the K smallest entries.
            threshold = si.topk(k=self.neighbor_num, dim=-1, largest=False)[0][:, :, -1].view(b, n, 1)
            adj = (si <= threshold).float()
        else:
            raise ValueError("wrong metric: {}".format(self.metric))

        # GNN process: aggregate neighbor features over the normalized graph, then apply
        # the residual update with batch normalization.
        A = normalize_digraph(adj)
        aggregate = torch.einsum('b i j, b j k -> b i k', A, self.V(x))
        x = self.relu(x + self.bnv(aggregate + self.U(x)))
        return x
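
# Shape contract (sketch): GNN.forward expects node features of shape
# (batch, num_classes, in_channels) and returns a tensor of the same shape; the adjacency
# matrix is rebuilt from the detached input features on every forward pass.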


class Head(nn.Module):
    def __init__(self, in_channels, num_classes, neighbor_num=4, metric='dots'):
        super(Head, self).__init__()
        self.in_channels = in_channels
        self.num_classes = num_classes
        # One linear projection per class (the AFG branch in forward).
        class_linear_layers = []
        for i in range(self.num_classes):
            layer = nn.Linear(self.in_channels, self.in_channels)
            class_linear_layers += [layer]
        self.class_linears = nn.ModuleList(class_linear_layers)
        self.gnn = GNN(self.in_channels, self.num_classes, neighbor_num=neighbor_num, metric=metric)
        # Learnable per-class vectors used to score the refined node features.
        self.sc = nn.Parameter(torch.zeros(self.num_classes, self.in_channels))
        self.relu = nn.ReLU()
        nn.init.xavier_uniform_(self.sc)
    def forward(self, x):
        # AFG: project the input feature into one node per class.
        f_u = []
        for i, layer in enumerate(self.class_linears):
            f_u.append(layer(x).unsqueeze(1))
        f_u = torch.cat(f_u, dim=1)

        # FGG: refine the per-class nodes with the graph neural network.
        f_v = self.gnn(f_u)

        # Score each refined node against its (ReLU-ed, L2-normalized) class vector.
        b, n, c = f_v.shape
        sc = self.relu(self.sc)
        sc = F.normalize(sc, p=2, dim=-1)
        cl = F.normalize(f_v, p=2, dim=-1)
        cl = (cl * sc.view(1, n, c)).sum(dim=-1)
        return cl
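

if __name__ == "__main__":
    # Minimal smoke test (a sketch only): the feature dimension (512) and the number of
    # AU classes (12) below are illustrative assumptions, not values taken from this file.
    head = Head(in_channels=512, num_classes=12, neighbor_num=4, metric='dots')
    head.eval()
    dummy_features = torch.randn(2, 512)   # (batch, in_channels) backbone features
    with torch.no_grad():
        au_scores = head(dummy_features)   # (batch, num_classes) per-AU similarity scores
    print(au_scores.shape)                 # expected: torch.Size([2, 12])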