import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
import math


def normalize_digraph(A):
    # Symmetrically normalize a batch of adjacency matrices:
    # A_norm = D^-1/2 * A * D^-1/2, where D is the diagonal degree matrix.
    b, n, _ = A.shape
    node_degrees = A.detach().sum(dim=-1)
    degs_inv_sqrt = node_degrees ** -0.5
    norm_degs_matrix = torch.eye(n, device=A.device)
    norm_degs_matrix = norm_degs_matrix.view(1, n, n) * degs_inv_sqrt.view(b, n, 1)
    norm_A = torch.bmm(torch.bmm(norm_degs_matrix, A), norm_degs_matrix)
    return norm_A
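

# --- Usage sketch (added for illustration; the graph size below is an assumption,
# --- not part of the original module) ---
def _example_normalize_digraph():
    # A fully connected 4-node graph: every node has degree 4, so the symmetric
    # normalization D^-1/2 * A * D^-1/2 scales every edge weight to 1/4.
    adj = torch.ones(1, 4, 4)
    norm = normalize_digraph(adj)
    assert torch.allclose(norm, torch.full((1, 4, 4), 0.25))
    return norm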


class GNN(nn.Module):
    def __init__(self, in_channels, num_classes, neighbor_num=4, metric='dots'):
        super(GNN, self).__init__()
        # in_channels: dimension of each node feature
        # num_classes: number of nodes in the graph
        # neighbor_num: K in the paper; each node is connected to its top-K nearest neighbors
        # metric: similarity metric used by the FGG module to build the dynamic graph
        # node update: X' = ReLU(X + BN(U(X) + A * V(X)))
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.relu = nn.ReLU()
        self.metric = metric
        self.neighbor_num = neighbor_num

        # network
        self.U = nn.Linear(self.in_channels, self.in_channels)
        self.V = nn.Linear(self.in_channels, self.in_channels)
        self.bnv = nn.BatchNorm1d(num_classes)

        # init
        self.U.weight.data.normal_(0, math.sqrt(2. / self.in_channels))
        self.V.weight.data.normal_(0, math.sqrt(2. / self.in_channels))
        self.bnv.weight.data.fill_(1)
        self.bnv.bias.data.zero_()
    def forward(self, x):
        b, n, c = x.shape

        # build the dynamic graph: connect each node to its top-K most similar nodes
        if self.metric == 'dots':
            si = x.detach()
            si = torch.einsum('b i j , b j k -> b i k', si, si.transpose(1, 2))
            threshold = si.topk(k=self.neighbor_num, dim=-1, largest=True)[0][:, :, -1].view(b, n, 1)
            adj = (si >= threshold).float()
        elif self.metric == 'cosine':
            si = x.detach()
            si = F.normalize(si, p=2, dim=-1)
            si = torch.einsum('b i j , b j k -> b i k', si, si.transpose(1, 2))
            threshold = si.topk(k=self.neighbor_num, dim=-1, largest=True)[0][:, :, -1].view(b, n, 1)
            adj = (si >= threshold).float()
        elif self.metric == 'l1':
            si = x.detach().repeat(1, n, 1).view(b, n, n, c)
            si = torch.abs(si.transpose(1, 2) - si)
            si = si.sum(dim=-1)
            threshold = si.topk(k=self.neighbor_num, dim=-1, largest=False)[0][:, :, -1].view(b, n, 1)
            adj = (si <= threshold).float()
        else:
            raise ValueError(f"wrong metric: {self.metric}")

        # GNN process: X' = ReLU(X + BN(U(X) + A * V(X)))
        A = normalize_digraph(adj)
        aggregate = torch.einsum('b i j, b j k->b i k', A, self.V(x))
        x = self.relu(x + self.bnv(aggregate + self.U(x)))
        return x
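

# --- Usage sketch (added for illustration; the channel and node counts are assumptions,
# --- not part of the original module) ---
def _example_gnn():
    # One GNN layer over 12 node features of dimension 64, batch size 2.
    # The dynamic graph keeps the top-4 most similar neighbors per node ('dots' metric),
    # and the output keeps the input shape (b, n, c).
    gnn = GNN(in_channels=64, num_classes=12, neighbor_num=4, metric='dots')
    x = torch.randn(2, 12, 64)
    out = gnn(x)
    assert out.shape == (2, 12, 64)
    return out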


class Head(nn.Module):
    def __init__(self, in_channels, num_classes, neighbor_num=4, metric='dots'):
        super(Head, self).__init__()
        self.in_channels = in_channels
        self.num_classes = num_classes
        # AFG: one independent linear layer per class to produce its node feature
        class_linear_layers = []
        for i in range(self.num_classes):
            layer = nn.Linear(self.in_channels, self.in_channels)
            class_linear_layers += [layer]
        self.class_linears = nn.ModuleList(class_linear_layers)
        # FGG: relation-aware GNN over the per-class node features
        self.gnn = GNN(self.in_channels, self.num_classes, neighbor_num=neighbor_num, metric=metric)
        # one trainable similarity vector per class, used for the final scores
        self.sc = nn.Parameter(torch.zeros(self.num_classes, self.in_channels))
        self.relu = nn.ReLU()
        nn.init.xavier_uniform_(self.sc)
    def forward(self, x):
        # AFG: per-class linear projections -> node features of shape (b, num_classes, c)
        f_u = []
        for i, layer in enumerate(self.class_linears):
            f_u.append(layer(x).unsqueeze(1))
        f_u = torch.cat(f_u, dim=1)
        # f_v = f_u.mean(dim=-2)

        # FGG: refine the node features with the dynamic-graph GNN
        f_v = self.gnn(f_u)
        # f_v = self.gnn(f_v)

        # per-class score: cosine similarity between each node feature and its class vector sc
        b, n, c = f_v.shape
        sc = self.sc
        sc = self.relu(sc)
        sc = F.normalize(sc, p=2, dim=-1)
        cl = F.normalize(f_v, p=2, dim=-1)
        cl = (cl * sc.view(1, n, c)).sum(dim=-1)
        return cl
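

if __name__ == '__main__':
    # Smoke test (added for illustration): feed a batch of global image features through
    # the Head. in_channels=512 and num_classes=12 are assumed values; the original
    # training code may use different sizes.
    head = Head(in_channels=512, num_classes=12, neighbor_num=4, metric='dots')
    feats = torch.randn(2, 512)   # (batch, in_channels) feature vector per image
    scores = head(feats)          # (batch, num_classes) per-class similarity scores
    print(scores.shape)           # torch.Size([2, 12])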