import numbers

import einops
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

from .backbone import *
from .blocks import *
class ResidualBlock(nn.Module):
def __init__(self, in_features):
super(ResidualBlock, self).__init__()
self.block = nn.Sequential(
nn.Conv2d(in_features, in_features, 3, padding=1),
nn.LeakyReLU(),
nn.Conv2d(in_features, in_features, 3, padding=1),
)
def forward(self, x):
return x + self.block(x)
def gauss_kernel(channels=3):
    # 5x5 binomial approximation of a Gaussian, normalized to sum to 1 and
    # repeated per channel for depthwise (grouped) convolution.
    kernel = torch.tensor(
        [
            [1.0, 4.0, 6.0, 4.0, 1.0],
            [4.0, 16.0, 24.0, 16.0, 4.0],
            [6.0, 24.0, 36.0, 24.0, 6.0],
            [4.0, 16.0, 24.0, 16.0, 4.0],
            [1.0, 4.0, 6.0, 4.0, 1.0],
        ]
    )
    kernel /= 256.0
    kernel = kernel.repeat(channels, 1, 1, 1)  # (channels, 1, 5, 5)
    return kernel
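
# Illustrative check (hypothetical helper, not used by the model): the kernel
# built above has shape (channels, 1, 5, 5) and each 5x5 slice sums to 1, the
# layout expected by the depthwise conv in LapPyramidConv.conv_gauss below.
def _check_gauss_kernel(channels=3):
    k = gauss_kernel(channels)
    assert k.shape == (channels, 1, 5, 5)
    # each per-channel filter integrates to 1, so the blur preserves brightness
    assert torch.allclose(k.sum(dim=(2, 3)), torch.ones(channels, 1))
    return k
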
class LapPyramidConv(nn.Module):
def __init__(self, num_high=4):
super(LapPyramidConv, self).__init__()
self.num_high = num_high
self.kernel = gauss_kernel()
    def downsample(self, x):
        # Drop every other row and column (stride-2 decimation).
        return x[:, :, ::2, ::2]

    def upsample(self, x):
        # Zero-interleave rows and columns to double the resolution, then
        # smooth with 4 * the Gaussian kernel so brightness is preserved.
cc = torch.cat(
[
x,
torch.zeros(
x.shape[0], x.shape[1], x.shape[2], x.shape[3], device=x.device
),
],
dim=3,
)
cc = cc.view(x.shape[0], x.shape[1], x.shape[2] * 2, x.shape[3])
cc = cc.permute(0, 1, 3, 2)
cc = torch.cat(
[
cc,
torch.zeros(
x.shape[0], x.shape[1], x.shape[3], x.shape[2] * 2, device=x.device
),
],
dim=3,
)
cc = cc.view(x.shape[0], x.shape[1], x.shape[3] * 2, x.shape[2] * 2)
x_up = cc.permute(0, 1, 3, 2)
return self.conv_gauss(x_up, 4 * self.kernel)
def conv_gauss(self, img, kernel):
        # Reflect-pad the last two spatial dimensions (left, right, top, bottom).
        img = torch.nn.functional.pad(img, (2, 2, 2, 2), mode="reflect")
        # Depthwise (grouped) convolution: one Gaussian filter per channel.
out = torch.nn.functional.conv2d(
img, kernel.to(img.device), groups=img.shape[1]
)
return out
def pyramid_decom(self, img):
current = img
pyr = []
for _ in range(self.num_high):
filtered = self.conv_gauss(current, self.kernel)
down = self.downsample(filtered)
up = self.upsample(down)
if up.shape[2] != current.shape[2] or up.shape[3] != current.shape[3]:
up = nn.functional.interpolate(
up, size=(current.shape[2], current.shape[3])
)
diff = current - up
pyr.append(diff)
current = down
pyr.append(current)
return pyr
def pyramid_recons(self, pyr):
image = pyr[-1]
for level in reversed(pyr[:-1]):
up = self.upsample(image)
if up.shape[2] != level.shape[2] or up.shape[3] != level.shape[3]:
up = nn.functional.interpolate(
up, size=(level.shape[2], level.shape[3])
)
image = up + level
return image
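
# Minimal usage sketch (illustrative only; assumes an even-sized input so the
# interpolation fallback is never hit): pyramid_decom returns num_high band-pass
# residuals followed by the low-frequency image, and pyramid_recons should
# invert the decomposition up to floating-point error.
def _demo_lap_pyramid():
    lap = LapPyramidConv(num_high=2)
    x = torch.rand(1, 3, 64, 64)
    pyr = lap.pyramid_decom(x)       # [diff 64x64, diff 32x32, low 16x16]
    recon = lap.pyramid_recons(pyr)  # should match x almost exactly
    print([t.shape for t in pyr], (recon - x).abs().max().item())
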
class TransHigh(nn.Module):
def __init__(self, num_residual_blocks, num_high=3):
super(TransHigh, self).__init__()
self.num_high = num_high
blocks = [nn.Conv2d(9, 64, 3, padding=1), nn.LeakyReLU()]
for _ in range(num_residual_blocks):
blocks += [ResidualBlock(64)]
blocks += [nn.Conv2d(64, 3, 3, padding=1)]
self.model = nn.Sequential(*blocks)
channels = 3
# Stage1
self.block1_1 = ConvLayer(
in_channels=channels,
out_channels=channels,
kernel_size=3,
stride=1,
dilation=2,
norm=None,
nonlinear="leakyrelu",
)
self.block1_2 = ConvLayer(
in_channels=channels,
out_channels=channels,
kernel_size=3,
stride=1,
dilation=4,
norm=None,
nonlinear="leakyrelu",
)
self.aggreation1_rgb = Aggreation(
in_channels=channels * 3, out_channels=channels
)
# Stage2
self.block2_1 = ConvLayer(
in_channels=channels,
out_channels=channels,
kernel_size=3,
stride=1,
dilation=8,
norm=None,
nonlinear="leakyrelu",
)
self.block2_2 = ConvLayer(
in_channels=channels,
out_channels=channels,
kernel_size=3,
stride=1,
dilation=16,
norm=None,
nonlinear="leakyrelu",
)
self.aggreation2_rgb = Aggreation(
in_channels=channels * 3, out_channels=channels
)
# Stage3
self.block3_1 = ConvLayer(
in_channels=channels,
out_channels=channels,
kernel_size=3,
stride=1,
dilation=32,
norm=None,
nonlinear="leakyrelu",
)
self.block3_2 = ConvLayer(
in_channels=channels,
out_channels=channels,
kernel_size=3,
stride=1,
dilation=64,
norm=None,
nonlinear="leakyrelu",
)
self.aggreation3_rgb = Aggreation(
in_channels=channels * 3, out_channels=channels
)
# self.block_3 = NAFNet(middle_blk_num=2, enc_blk_nums=[
# 1,1], dec_blk_nums=[1,1])
self.trans_mask_block_1 = nn.Sequential(
nn.Conv2d(3, 16, 1), nn.LeakyReLU(), nn.Conv2d(16, 3, 1)
)
self.trans_mask_block_2 = nn.Sequential(
nn.Conv2d(3, 16, 1), nn.LeakyReLU(), nn.Conv2d(16, 3, 1)
)
# self.trans_mask_block = NAFNet(
# middle_blk_num=1, enc_blk_nums=[1], dec_blk_nums=[1])
# Stage3
self.spp_img = SPP(
in_channels=channels,
out_channels=channels,
num_layers=4,
interpolation_type="bicubic",
)
self.block4_1 = nn.Conv2d(
in_channels=channels, out_channels=3, kernel_size=1, stride=1
)
def forward(self, x, pyr_original, fake_low):
pyr_result = [fake_low]
mask = self.model(x)
mask = nn.functional.interpolate(
mask, size=(pyr_original[-2].shape[2], pyr_original[-2].shape[3])
)
mask = self.trans_mask_block_1(mask)
result_highfreq = torch.mul(pyr_original[-2], mask) + pyr_original[-2]
# result_highfreq = self.block_3(result_highfreq)
out1_1 = self.block1_1(result_highfreq)
out1_2 = self.block1_2(out1_1)
agg1_rgb = self.aggreation1_rgb(
torch.cat((result_highfreq, out1_1, out1_2), dim=1)
)
pyr_result.append(agg1_rgb)
mask = nn.functional.interpolate(
mask, size=(pyr_original[-3].shape[2], pyr_original[-3].shape[3])
)
mask = self.trans_mask_block_2(mask)
result_highfreq = torch.mul(pyr_original[-3], mask) + pyr_original[-3]
# result_highfreq = self.block_3(result_highfreq)
out2_1 = self.block2_1(result_highfreq)
out2_2 = self.block2_2(out2_1)
agg2_rgb = self.aggreation2_rgb(
torch.cat((result_highfreq, out2_1, out2_2), dim=1)
)
out3_1 = self.block3_1(agg2_rgb)
out3_2 = self.block3_2(out3_1)
agg3_rgb = self.aggreation3_rgb(torch.cat((agg2_rgb, out3_1, out3_2), dim=1))
spp_rgb = self.spp_img(agg3_rgb)
out_rgb = self.block4_1(spp_rgb)
pyr_result.append(out_rgb)
pyr_result.reverse()
return pyr_result
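
# Shape sketch (illustrative helper mirroring Model.forward below; assumes the
# ConvLayer/Aggreation/SPP blocks preserve spatial size, as in the original
# model): the 9-channel input concatenates the finest residual with the
# upsampled low-frequency input and output, and TransHigh returns a full
# pyramid ordered fine-to-coarse, ready for LapPyramidConv.pyramid_recons.
def _demo_trans_high():
    lap = LapPyramidConv(num_high=2)
    trans = TransHigh(3, num_high=2)
    inp = torch.rand(1, 3, 64, 64)
    pyr = lap.pyramid_decom(inp)
    fake_low = torch.rand_like(pyr[-1])
    inp_up = nn.functional.interpolate(pyr[-1], size=pyr[-2].shape[2:])
    out_up = nn.functional.interpolate(fake_low, size=pyr[-2].shape[2:])
    x = torch.cat([pyr[-2], inp_up, out_up], dim=1)  # (1, 9, 32, 32)
    out_pyr = trans(x, pyr, fake_low)
    print([t.shape for t in out_pyr])  # 64x64, 32x32, 16x16 levels
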
# Layer Norm
def to_3d(x):
return rearrange(x, "b c h w -> b (h w) c")
def to_4d(x, h, w):
return rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
class BiasFree_LayerNorm(nn.Module):
def __init__(self, normalized_shape):
super(BiasFree_LayerNorm, self).__init__()
if isinstance(normalized_shape, numbers.Integral):
normalized_shape = (normalized_shape,)
normalized_shape = torch.Size(normalized_shape)
assert len(normalized_shape) == 1
self.weight = nn.Parameter(torch.ones(normalized_shape))
self.normalized_shape = normalized_shape
def forward(self, x):
sigma = x.var(-1, keepdim=True, unbiased=False)
return x / torch.sqrt(sigma + 1e-5) * self.weight
class WithBias_LayerNorm(nn.Module):
def __init__(self, normalized_shape):
super(WithBias_LayerNorm, self).__init__()
if isinstance(normalized_shape, numbers.Integral):
normalized_shape = (normalized_shape,)
normalized_shape = torch.Size(normalized_shape)
assert len(normalized_shape) == 1
self.weight = nn.Parameter(torch.ones(normalized_shape))
self.bias = nn.Parameter(torch.zeros(normalized_shape))
self.normalized_shape = normalized_shape
def forward(self, x):
mu = x.mean(-1, keepdim=True)
sigma = x.var(-1, keepdim=True, unbiased=False)
return (x - mu) / torch.sqrt(sigma + 1e-5) * self.weight + self.bias
class LayerNorm(nn.Module):
def __init__(self, dim, LayerNorm_type):
super(LayerNorm, self).__init__()
if LayerNorm_type == "BiasFree":
self.body = BiasFree_LayerNorm(dim)
else:
self.body = WithBias_LayerNorm(dim)
def forward(self, x):
h, w = x.shape[-2:]
return to_4d(self.body(to_3d(x)), h, w)
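
# Illustrative sketch (not part of the original code): this LayerNorm flattens
# (B, C, H, W) to (B, H*W, C) and normalizes each spatial position across its C
# channels; with freshly initialised weight/bias, the "WithBias" variant agrees
# with torch.nn.functional.layer_norm applied over the channel dimension.
def _demo_layernorm(dim=8):
    ln = LayerNorm(dim, "WithBias")
    x = torch.rand(2, dim, 16, 16)
    ref = F.layer_norm(x.permute(0, 2, 3, 1), (dim,), eps=1e-5).permute(0, 3, 1, 2)
    print(torch.allclose(ln(x), ref, atol=1e-5))  # expected: True
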
# Axis-based Multi-head Self-Attention
class NextAttentionImplZ(nn.Module):
def __init__(self, num_dims, num_heads, bias) -> None:
super().__init__()
self.num_dims = num_dims
self.num_heads = num_heads
self.q1 = nn.Conv2d(num_dims, num_dims * 3, kernel_size=1, bias=bias)
self.q2 = nn.Conv2d(
num_dims * 3,
num_dims * 3,
kernel_size=3,
padding=1,
groups=num_dims * 3,
bias=bias,
)
self.q3 = nn.Conv2d(
num_dims * 3,
num_dims * 3,
kernel_size=3,
padding=1,
groups=num_dims * 3,
bias=bias,
)
self.fac = nn.Parameter(torch.ones(1))
self.fin = nn.Conv2d(num_dims, num_dims, kernel_size=1, bias=bias)
return
def forward(self, x):
# x: [n, c, h, w]
n, c, h, w = x.size()
n_heads, dim_head = self.num_heads, c // self.num_heads
def reshape(x):
return einops.rearrange(
x, "n (nh dh) h w -> (n nh h) w dh", nh=n_heads, dh=dim_head
)
qkv = self.q3(self.q2(self.q1(x)))
q, k, v = map(reshape, qkv.chunk(3, dim=1))
q = F.normalize(q, dim=-1)
k = F.normalize(k, dim=-1)
# fac = dim_head ** -0.5
res = k.transpose(-2, -1)
res = torch.matmul(q, res) * self.fac
res = torch.softmax(res, dim=-1)
res = torch.matmul(res, v)
res = einops.rearrange(
res, "(n nh h) w dh -> n (nh dh) h w", nh=n_heads, dh=dim_head, n=n, h=h
)
res = self.fin(res)
return res
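
# Shape sketch (illustrative only; assumes num_dims is divisible by num_heads):
# queries and keys are rearranged so each image row attends only along the
# width axis, giving a w x w attention matrix per (batch, head, row) instead of
# an (h*w) x (h*w) matrix for full spatial attention.
def _demo_axis_attention():
    attn = NextAttentionImplZ(num_dims=8, num_heads=2, bias=True)
    x = torch.rand(1, 8, 16, 24)
    print(attn(x).shape)  # torch.Size([1, 8, 16, 24])
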
# Axis-based Multi-head Self-Attention (row and col attention)
class NextAttentionZ(nn.Module):
def __init__(self, num_dims, num_heads=1, bias=True) -> None:
super().__init__()
assert num_dims % num_heads == 0
self.num_dims = num_dims
self.num_heads = num_heads
self.row_att = NextAttentionImplZ(num_dims, num_heads, bias)
self.col_att = NextAttentionImplZ(num_dims, num_heads, bias)
return
def forward(self, x: torch.Tensor):
assert len(x.size()) == 4
x = self.row_att(x)
x = x.transpose(-2, -1)
x = self.col_att(x)
x = x.transpose(-2, -1)
return x
# Dual Gated Feed-Forward Network
class FeedForward(nn.Module):
def __init__(self, dim, ffn_expansion_factor, bias):
super(FeedForward, self).__init__()
hidden_features = int(dim * ffn_expansion_factor)
self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias)
self.dwconv = nn.Conv2d(
hidden_features * 2,
hidden_features * 2,
kernel_size=3,
stride=1,
padding=1,
groups=hidden_features * 2,
bias=bias,
)
self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)
def forward(self, x):
x = self.project_in(x)
x1, x2 = self.dwconv(x).chunk(2, dim=1)
x = F.gelu(x2) * x1 + F.gelu(x1) * x2
x = self.project_out(x)
return x
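
# Minimal sketch (illustrative only): the dual-gated FFN expands to
# 2 * hidden_features, splits into two branches after the depthwise conv, and
# gates each branch with the GELU of the other before projecting back to `dim`;
# spatial resolution is unchanged.
def _demo_feedforward():
    ffn = FeedForward(dim=8, ffn_expansion_factor=2.66, bias=False)
    x = torch.rand(1, 8, 16, 16)
    print(ffn(x).shape)  # torch.Size([1, 8, 16, 16])
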
# Axis-based Transformer Block
class TransformerBlock(nn.Module):
def __init__(
self,
dim,
num_heads=1,
ffn_expansion_factor=2.66,
bias=True,
LayerNorm_type="WithBias",
):
super(TransformerBlock, self).__init__()
self.norm1 = LayerNorm(dim, LayerNorm_type)
self.attn = NextAttentionZ(dim, num_heads)
self.norm2 = LayerNorm(dim, LayerNorm_type)
self.ffn = FeedForward(dim, ffn_expansion_factor, bias)
def forward(self, x):
x = x + self.attn(self.norm1(x))
x = x + self.ffn(self.norm2(x))
return x
##########################################################################
# Overlapped image patch embedding with 3x3 Conv
class OverlapPatchEmbed(nn.Module):
def __init__(self, in_c=3, embed_dim=48, bias=False):
super(OverlapPatchEmbed, self).__init__()
self.proj = nn.Conv2d(
in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias
)
def forward(self, x):
x = self.proj(x)
return x
##########################################################################
# Resizing modules
class Downsample(nn.Module):
def __init__(self, n_feat):
super(Downsample, self).__init__()
self.body = nn.Sequential(
nn.Conv2d(
n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False
),
nn.PixelUnshuffle(2),
)
def forward(self, x):
return self.body(x)
class Upsample(nn.Module):
def __init__(self, n_feat):
super(Upsample, self).__init__()
self.body = nn.Sequential(
nn.Conv2d(
n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False
),
nn.PixelShuffle(2),
)
def forward(self, x):
return self.body(x)
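
# Channel/resolution sketch (illustrative only): Downsample halves the spatial
# size and doubles the channels (conv to n_feat // 2, then x4 channels from
# PixelUnshuffle); Upsample is the inverse (conv to 2 * n_feat, then /4 channels
# from PixelShuffle).
def _demo_resizing(n_feat=16):
    x = torch.rand(1, n_feat, 32, 32)
    print(Downsample(n_feat)(x).shape)  # torch.Size([1, 32, 16, 16])
    print(Upsample(n_feat)(x).shape)    # torch.Size([1, 8, 64, 64])
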
# Cross-layer Attention Fusion Block
class LAM_Module_v2(nn.Module):
"""Layer attention module"""
def __init__(self, in_dim, bias=True):
super(LAM_Module_v2, self).__init__()
self.chanel_in = in_dim
self.temperature = nn.Parameter(torch.ones(1))
self.qkv = nn.Conv2d(
self.chanel_in, self.chanel_in * 3, kernel_size=1, bias=bias
)
self.qkv_dwconv = nn.Conv2d(
self.chanel_in * 3,
self.chanel_in * 3,
kernel_size=3,
stride=1,
padding=1,
groups=self.chanel_in * 3,
bias=bias,
)
self.project_out = nn.Conv2d(
self.chanel_in, self.chanel_in, kernel_size=1, bias=bias
)
def forward(self, x):
"""
inputs :
x : input feature maps( B X N X C X H X W)
returns :
out : attention value + input feature
attention: B X N X N
"""
m_batchsize, N, C, height, width = x.size()
x_input = x.view(m_batchsize, N * C, height, width)
qkv = self.qkv_dwconv(self.qkv(x_input))
q, k, v = qkv.chunk(3, dim=1)
q = q.view(m_batchsize, N, -1)
k = k.view(m_batchsize, N, -1)
v = v.view(m_batchsize, N, -1)
q = torch.nn.functional.normalize(q, dim=-1)
k = torch.nn.functional.normalize(k, dim=-1)
attn = (q @ k.transpose(-2, -1)) * self.temperature
attn = attn.softmax(dim=-1)
out_1 = attn @ v
out_1 = out_1.view(m_batchsize, -1, height, width)
out_1 = self.project_out(out_1)
out_1 = out_1.view(m_batchsize, N, C, height, width)
out = out_1 + x
out = out.view(m_batchsize, -1, height, width)
return out
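
# Usage sketch (illustrative only): LAM_Module_v2 attends across the N stacked
# feature maps of shape (B, N, C, H, W) and returns the fused maps flattened to
# (B, N*C, H, W), the layout consumed by the 1x1 conv_fuss layers in Backbone.
def _demo_layer_attention():
    lam = LAM_Module_v2(in_dim=3 * 3)  # N=3 maps with C=3 channels each
    x = torch.rand(1, 3, 3, 32, 32)
    print(lam(x).shape)                # torch.Size([1, 9, 32, 32])
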
##########################################################################
# ---------- LLFormer -----------------------
class Backbone(nn.Module):
def __init__(
self,
inp_channels=3,
out_channels=3,
dim=3,
num_blocks=[1, 2, 4, 8],
num_refinement_blocks=1,
heads=[1, 2, 4, 8],
ffn_expansion_factor=2.66,
bias=False,
LayerNorm_type="WithBias",
attention=True,
):
super(Backbone, self).__init__()
self.patch_embed = OverlapPatchEmbed(inp_channels, dim)
self.encoder_1 = nn.Sequential(
*[
TransformerBlock(
dim=dim,
num_heads=heads[0],
ffn_expansion_factor=ffn_expansion_factor,
bias=bias,
LayerNorm_type=LayerNorm_type,
)
for _ in range(num_blocks[0])
]
)
self.encoder_2 = nn.Sequential(
*[
TransformerBlock(
dim=int(dim),
num_heads=heads[0],
ffn_expansion_factor=ffn_expansion_factor,
bias=bias,
LayerNorm_type=LayerNorm_type,
)
for _ in range(num_blocks[0])
]
)
self.encoder_3 = nn.Sequential(
*[
TransformerBlock(
dim=int(dim),
num_heads=heads[0],
ffn_expansion_factor=ffn_expansion_factor,
bias=bias,
LayerNorm_type=LayerNorm_type,
)
for _ in range(num_blocks[0])
]
)
self.layer_fussion = LAM_Module_v2(in_dim=int(dim * 3))
self.conv_fuss = nn.Conv2d(int(dim * 3), int(dim), kernel_size=1, bias=bias)
# self.latent = nn.Sequential(*[
# TransformerBlock(dim=int(dim), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias,
# LayerNorm_type=LayerNorm_type) for _ in range(num_blocks[0])])
# self.trans_low = NAFNet()
# self.coefficient_1_0 = nn.Parameter(torch.ones(
# (2, int(int(dim)))), requires_grad=attention)
self.latent_1 = nn.Sequential(
*[
TransformerBlock(
dim=int(dim),
num_heads=heads[0],
ffn_expansion_factor=ffn_expansion_factor,
bias=bias,
LayerNorm_type=LayerNorm_type,
)
for _ in range(num_blocks[0])
]
)
"""
self.latent_2 = nn.Sequential(*[
TransformerBlock(dim=int(dim), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias,
LayerNorm_type=LayerNorm_type) for _ in range(num_blocks[0])])
"""
self.trans_low_1 = NAFNet(
middle_blk_num=10, enc_blk_nums=[1, 2, 4], dec_blk_nums=[4, 2, 1]
)
# self.trans_low_2 = NAFNet()
self.coefficient_1_0 = nn.Parameter(
torch.ones((2, int(int(dim)))), requires_grad=attention
)
# self.coefficient_2_0 = nn.Parameter(torch.ones(
# (2, int(int(dim)))), requires_grad=attention)
self.refinement_1 = nn.Sequential(
*[
TransformerBlock(
dim=int(dim),
num_heads=heads[0],
ffn_expansion_factor=ffn_expansion_factor,
bias=bias,
LayerNorm_type=LayerNorm_type,
)
for _ in range(num_refinement_blocks)
]
)
self.refinement_2 = nn.Sequential(
*[
TransformerBlock(
dim=int(dim),
num_heads=heads[0],
ffn_expansion_factor=ffn_expansion_factor,
bias=bias,
LayerNorm_type=LayerNorm_type,
)
for _ in range(num_refinement_blocks)
]
)
self.refinement_3 = nn.Sequential(
*[
TransformerBlock(
dim=int(dim),
num_heads=heads[0],
ffn_expansion_factor=ffn_expansion_factor,
bias=bias,
LayerNorm_type=LayerNorm_type,
)
for _ in range(num_refinement_blocks)
]
)
self.layer_fussion_2 = LAM_Module_v2(in_dim=int(dim * 3))
self.conv_fuss_2 = nn.Conv2d(int(dim * 3), int(dim), kernel_size=1, bias=bias)
self.output = nn.Conv2d(
int(dim), out_channels, kernel_size=3, stride=1, padding=1, bias=bias
)
def forward(self, inp):
inp_enc_encoder1 = self.patch_embed(inp)
out_enc_encoder1 = self.encoder_1(inp_enc_encoder1)
out_enc_encoder2 = self.encoder_2(out_enc_encoder1)
out_enc_encoder3 = self.encoder_3(out_enc_encoder2)
inp_fusion_123 = torch.cat(
[
out_enc_encoder1.unsqueeze(1),
out_enc_encoder2.unsqueeze(1),
out_enc_encoder3.unsqueeze(1),
],
dim=1,
)
out_fusion_123 = self.layer_fussion(inp_fusion_123)
out_fusion_123 = self.conv_fuss(out_fusion_123)
# out_enc = self.trans_low(out_fusion_123)
# out_fusion_123 = self.latent(out_fusion_123)
# out = self.coefficient_1_0[0, :][None, :, None, None] * out_fusion_123 + self.coefficient_1_0[1, :][None, :,None, None] * out_enc
out_enc_1 = self.trans_low_1(out_fusion_123)
out_fusion_123_1 = self.latent_1(out_fusion_123)
out = (
self.coefficient_1_0[0, :][None, :, None, None] * out_fusion_123_1
+ self.coefficient_1_0[1, :][None, :, None, None] * out_enc_1
)
# out_enc_2 = self.trans_low_2(out)
# out_fusion_123_2 = self.latent_2(out)
# out = self.coefficient_2_0[0, :][None, :, None, None] * out_fusion_123_2 + self.coefficient_2_0[1, :][None, :,None, None] * out_enc_2
out_1 = self.refinement_1(out)
out_2 = self.refinement_2(out_1)
out_3 = self.refinement_3(out_2)
inp_fusion = torch.cat(
[out_1.unsqueeze(1), out_2.unsqueeze(1), out_3.unsqueeze(1)], dim=1
)
out_fusion_123 = self.layer_fussion_2(inp_fusion)
out = self.conv_fuss_2(out_fusion_123)
result = self.output(out)
return result
class Model(nn.Module):
def __init__(self, depth=2):
super(Model, self).__init__()
self.backbone = Backbone()
self.lap_pyramid = LapPyramidConv(depth)
self.trans_high = TransHigh(3, num_high=depth)
def forward(self, inp):
pyr_inp = self.lap_pyramid.pyramid_decom(img=inp)
out_low = self.backbone(pyr_inp[-1])
inp_up = nn.functional.interpolate(
pyr_inp[-1], size=(pyr_inp[-2].shape[2], pyr_inp[-2].shape[3])
)
out_up = nn.functional.interpolate(
out_low, size=(pyr_inp[-2].shape[2], pyr_inp[-2].shape[3])
)
high_with_low = torch.cat([pyr_inp[-2], inp_up, out_up], 1)
pyr_inp_trans = self.trans_high(high_with_low, pyr_inp, out_low)
result = self.lap_pyramid.pyramid_recons(pyr_inp_trans)
return result
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tensor = torch.randn(1, 3, 1024, 1024).to(device)
    model = Model().to(device)
    output = model(tensor)
    print(output.shape)