import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvLayer(nn.Module):
    """Conv2d with reflection padding, optional normalization, and activation."""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        dilation=1,
        bias=True,
        groups=1,
        norm="in",
        nonlinear="relu",
    ):
        super(ConvLayer, self).__init__()
        # Pad to half the effective (dilated) kernel size, d*(k-1)+1, so that
        # stride-1 convolutions preserve the spatial resolution.
        reflection_padding = (kernel_size + (dilation - 1) * (kernel_size - 1)) // 2
        self.reflection_pad = nn.ReflectionPad2d(reflection_padding)
        self.conv2d = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            groups=groups,
            bias=bias,
            dilation=dilation,
        )
        self.norm = norm
        self.nonlinear = nonlinear

        if norm == "bn":
            self.normalization = nn.BatchNorm2d(out_channels)
        elif norm == "in":
            self.normalization = nn.InstanceNorm2d(out_channels, affine=False)
        else:
            self.normalization = None

        if nonlinear == "relu":
            self.activation = nn.ReLU(inplace=True)
        elif nonlinear == "leakyrelu":
            self.activation = nn.LeakyReLU(0.2)
        elif nonlinear == "PReLU":
            self.activation = nn.PReLU()
        else:
            self.activation = None

    def forward(self, x):
        out = self.conv2d(self.reflection_pad(x))
        if self.normalization is not None:
            out = self.normalization(out)
        if self.activation is not None:
            out = self.activation(out)
        return out
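
# Worked padding example (sizes are illustrative): kernel_size=3, dilation=2
# gives an effective kernel of d*(k-1)+1 = 5, so reflection_padding =
# (3 + 1*2) // 2 = 2, and a stride-1 ConvLayer(32, 64, 3, 1, dilation=2)
# maps (N, 32, H, W) -> (N, 64, H, W).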

class Aggreation(nn.Module):
    """Aggregation block: channel attention followed by a 3x3 fusion conv."""

    def __init__(self, in_channels, out_channels, kernel_size=3):
        super(Aggreation, self).__init__()
        # `SelfAttention` is defined below; the name is only resolved here at
        # instantiation time, so the ordering is valid Python.
        self.attention = SelfAttention(in_channels, k=8, nonlinear="relu")
        self.conv = ConvLayer(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=1,
            dilation=1,
            nonlinear="leakyrelu",
            norm=None,
        )

    def forward(self, x):
        return self.conv(self.attention(x))
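
# Usage sketch (channel counts are illustrative): Aggreation(in_channels=96,
# out_channels=32) re-weights 96 input channels with the attention gate below
# and fuses them down to 32 channels with a 3x3 conv; spatial size is
# preserved throughout.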

class SelfAttention(nn.Module):
    """Channel attention in the squeeze-and-excitation style.

    Despite the name, this is channel gating rather than dot-product
    self-attention: global average pooling squeezes each channel to a scalar,
    and two linear layers with reduction ratio `k` produce per-channel
    sigmoid gates that rescale the input feature map.
    """

    def __init__(self, channels, k, nonlinear="relu"):
        super(SelfAttention, self).__init__()
        self.channels = channels
        self.k = k
        self.nonlinear = nonlinear

        self.linear1 = nn.Linear(channels, channels // k)
        self.linear2 = nn.Linear(channels // k, channels)
        self.global_pooling = nn.AdaptiveAvgPool2d((1, 1))

        if nonlinear == "relu":
            self.activation = nn.ReLU(inplace=True)
        elif nonlinear == "leakyrelu":
            self.activation = nn.LeakyReLU(0.2)
        elif nonlinear == "PReLU":
            self.activation = nn.PReLU()
        else:
            raise ValueError(f"Unsupported nonlinearity: {nonlinear}")

    def attention(self, x):
        N, C, _, _ = x.size()
        # Squeeze: (N, C, H, W) -> (N, C); excite: (N, C) -> (N, C, 1, 1) gates.
        out = torch.flatten(self.global_pooling(x), 1)
        out = self.activation(self.linear1(out))
        out = torch.sigmoid(self.linear2(out)).view(N, C, 1, 1)
        return out.mul(x)

    def forward(self, x):
        return self.attention(x)
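
# Shape trace (illustrative): for SelfAttention(channels=32, k=8) the squeeze
# is (N, 32, H, W) -> (N, 32), the bottleneck is 32 -> 4 -> 32, and the
# sigmoid gates broadcast from (N, 32, 1, 1) to rescale the input.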

class SPP(nn.Module):
    """Spatial pyramid pooling: multi-scale average-pooling branches that are
    upsampled back to the input resolution and fused with an identity branch."""

    def __init__(
        self, in_channels, out_channels, num_layers=4, interpolation_type="bilinear"
    ):
        super(SPP, self).__init__()
        self.conv = nn.ModuleList()
        self.num_layers = num_layers
        self.interpolation_type = interpolation_type

        # One 1x1 conv per pyramid level.
        for _ in range(self.num_layers):
            self.conv.append(
                ConvLayer(
                    in_channels,
                    in_channels,
                    kernel_size=1,
                    stride=1,
                    dilation=1,
                    nonlinear="leakyrelu",
                    norm=None,
                )
            )
        # Fuses the pyramid levels plus the identity branch.
        self.fusion = ConvLayer(
            in_channels * (self.num_layers + 1),
            out_channels,
            kernel_size=3,
            stride=1,
            norm=None,
            nonlinear="leakyrelu",
        )
    def forward(self, x):
        _, _, H, W = x.size()

        out = []
        for level in range(self.num_layers):
            # Downsample by 4, 8, 16, ... via average pooling (the original
            # expression 2 * 2 ** (level + 1) equals 2 ** (level + 2), with
            # zero padding), refine with a 1x1 conv, then upsample back to
            # the input resolution.
            scale = 2 ** (level + 2)
            pooled = F.avg_pool2d(x, kernel_size=scale, stride=scale)
            out.append(
                F.interpolate(
                    self.conv[level](pooled),
                    size=(H, W),
                    mode=self.interpolation_type,
                )
            )
        out.append(x)  # identity branch

        return self.fusion(torch.cat(out, dim=1))
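
# Minimal smoke test (a sketch; batch, channel, and spatial sizes are
# arbitrary): checks that every block preserves spatial resolution. The input
# must be at least 32x32 so the deepest SPP level can pool with stride 32.
if __name__ == "__main__":
    x = torch.randn(2, 32, 64, 64)
    print(ConvLayer(32, 64, kernel_size=3, stride=1)(x).shape)  # (2, 64, 64, 64)
    print(SelfAttention(32, k=8)(x).shape)                      # (2, 32, 64, 64)
    print(Aggreation(32, 32)(x).shape)                          # (2, 32, 64, 64)
    print(SPP(32, 32)(x).shape)                                 # (2, 32, 64, 64)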