Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
class ConvLayer(nn.Module):
    """Reflection-padded 2-D convolution with optional norm and activation.

    The input is mirrored at its borders by half of the dilated kernel
    extent before the convolution runs, so the convolution itself uses no
    padding of its own.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Integer side length of the square kernel.
        stride: Stride of the convolution.
        dilation: Spacing between kernel taps.
        bias: Whether the convolution has a learnable bias.
        groups: Number of blocked connections between input and output.
        norm: ``"bn"`` for batch norm, ``"in"`` for non-affine instance
            norm; any other value disables normalization.
        nonlinear: ``"relu"``, ``"leakyrelu"`` (slope 0.2) or ``"PReLU"``;
            any other value disables the activation.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        dilation=1,
        bias=True,
        groups=1,
        norm="in",
        nonlinear="relu",
    ):
        super().__init__()

        # Extent covered by the kernel once dilation gaps are included;
        # half of it is reflected onto every border.
        dilated_extent = kernel_size + (dilation - 1) * (kernel_size - 1)
        self.reflection_pad = nn.ReflectionPad2d(dilated_extent // 2)
        self.conv2d = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
        self.norm = norm
        self.nonlinear = nonlinear

        # Lazy factories so that only the selected module is instantiated.
        norm_factories = {
            "bn": lambda: nn.BatchNorm2d(out_channels),
            "in": lambda: nn.InstanceNorm2d(out_channels, affine=False),
        }
        make_norm = norm_factories.get(norm)
        self.normalization = make_norm() if make_norm is not None else None

        activation_factories = {
            "relu": lambda: nn.ReLU(inplace=True),
            "leakyrelu": lambda: nn.LeakyReLU(0.2),
            "PReLU": lambda: nn.PReLU(),
        }
        make_activation = activation_factories.get(nonlinear)
        self.activation = (
            make_activation() if make_activation is not None else None
        )

    def forward(self, x):
        """Apply pad -> conv -> (norm) -> (activation) to ``x``."""
        features = self.conv2d(self.reflection_pad(x))
        # Normalization and activation are both optional; skip absent ones.
        for stage in (self.normalization, self.activation):
            if stage is not None:
                features = stage(features)
        return features
class Aggreation(nn.Module):
    """Channel attention followed by a LeakyReLU convolution.

    The input is first gated per channel by ``SelfAttention`` (reduction
    factor 8) and then passed through an un-normalized ``ConvLayer`` that
    maps ``in_channels`` to ``out_channels``.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels of the output tensor.
        kernel_size: Kernel size of the trailing convolution.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3):
        super().__init__()
        self.attention = SelfAttention(in_channels, k=8, nonlinear="relu")

        conv_options = {
            "kernel_size": kernel_size,
            "stride": 1,
            "dilation": 1,
            "nonlinear": "leakyrelu",
            "norm": None,
        }
        self.conv = ConvLayer(in_channels, out_channels, **conv_options)

    def forward(self, x):
        """Return the convolution of the attention-gated input."""
        gated = self.attention(x)
        return self.conv(gated)
class SelfAttention(nn.Module):
    """Channel-wise gating computed from globally pooled features.

    Each channel is average-pooled to a single value, passed through a
    two-layer bottleneck (``channels -> channels // k -> channels``) and a
    sigmoid, and the resulting per-channel weight multiplies the input.

    Args:
        channels: Number of channels of the input tensor.
        k: Reduction factor of the bottleneck layer.
        nonlinear: Activation between the two linear layers; one of
            ``"relu"``, ``"leakyrelu"`` (slope 0.2) or ``"PReLU"``.

    Raises:
        ValueError: If ``nonlinear`` is not one of the supported names.
    """

    def __init__(self, channels, k, nonlinear="relu"):
        super().__init__()
        self.channels = channels
        self.k = k
        self.nonlinear = nonlinear

        self.linear1 = nn.Linear(channels, channels // k)
        self.linear2 = nn.Linear(channels // k, channels)
        self.global_pooling = nn.AdaptiveAvgPool2d((1, 1))

        if nonlinear == "relu":
            self.activation = nn.ReLU(inplace=True)
        elif nonlinear == "leakyrelu":
            self.activation = nn.LeakyReLU(0.2)
        elif nonlinear == "PReLU":
            self.activation = nn.PReLU()
        else:
            # Name the offending value so a misconfiguration is diagnosable.
            raise ValueError(
                f"unsupported nonlinear {nonlinear!r}; "
                "expected 'relu', 'leakyrelu' or 'PReLU'"
            )

    def attention(self, x):
        """Scale every channel of the 4-D tensor ``x`` by its learned gate."""
        N, C, H, W = x.size()
        # (N, C, 1, 1) -> (N, C): one descriptor per channel.
        out = torch.flatten(self.global_pooling(x), 1)
        out = self.activation(self.linear1(out))
        # Sigmoid keeps each gate in (0, 1); reshape so it broadcasts over H, W.
        out = torch.sigmoid(self.linear2(out)).view(N, C, 1, 1)
        return out.mul(x)

    def forward(self, x):
        """Return ``x`` with channel attention applied."""
        return self.attention(x)
class SPP(nn.Module):
    """Spatial pyramid pooling with a fusion convolution.

    The input is average-pooled at ``num_layers`` scales (square windows of
    size 4, 8, 16, ... i.e. ``2 ** (level + 2)``), each pooled map is passed
    through its own 1x1 ``ConvLayer``, resized back to the input resolution
    and concatenated with the input along the channel axis. A 3x3
    ``ConvLayer`` then fuses the stack down to ``out_channels``.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels of the fused output.
        num_layers: Number of pyramid levels.
        interpolation_type: ``mode`` handed to ``F.interpolate`` when
            resizing the pooled maps.

    Note:
        Pooling uses no padding, so the input's height and width are
        expected to be at least the largest window, ``2 ** (num_layers + 1)``.
    """

    def __init__(
        self, in_channels, out_channels, num_layers=4, interpolation_type="bilinear"
    ):
        super().__init__()
        # One channel-preserving 1x1 convolution per pyramid level.
        self.conv = nn.ModuleList(
            [
                ConvLayer(
                    in_channels,
                    in_channels,
                    kernel_size=1,
                    stride=1,
                    dilation=1,
                    nonlinear="leakyrelu",
                    norm=None,
                )
                for _ in range(num_layers)
            ]
        )
        self.num_layers = num_layers
        self.interpolation_type = interpolation_type
        # The fusion layer sees every pyramid branch plus the raw input.
        # norm=None states the intent explicitly: no normalization here,
        # in line with the per-level convolutions above.
        self.fusion = ConvLayer(
            in_channels * (num_layers + 1),
            out_channels,
            kernel_size=3,
            stride=1,
            norm=None,
            nonlinear="leakyrelu",
        )

    def forward(self, x):
        """Return the fused multi-scale representation of the 4-D tensor ``x``."""
        _, _, height, width = x.size()
        branches = []
        for level in range(self.num_layers):
            window = 2 * 2 ** (level + 1)
            # Non-overlapping windows; the window size is always even, so
            # no padding is applied.
            pooled = F.avg_pool2d(x, kernel_size=window, stride=window, padding=0)
            branches.append(
                F.interpolate(
                    self.conv[level](pooled),
                    size=(height, width),
                    mode=self.interpolation_type,
                )
            )
        branches.append(x)
        return self.fusion(torch.cat(branches, dim=1))