Spaces:
Running
Running
| """ | |
| This code is adapted from: | |
| https://github.com/THU-MIG/RepViT | |
| """ | |
| import torch.nn as nn | |
| import torch | |
| from torch.nn.init import constant_ | |
| def _make_divisible(v, divisor, min_value=None): | |
| """ | |
| This function is taken from the original tf repo. | |
| It ensures that all layers have a channel number that is divisible by 8 | |
| It can be seen here: | |
| https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py | |
| :param v: | |
| :param divisor: | |
| :param min_value: | |
| :return: | |
| """ | |
| if min_value is None: | |
| min_value = divisor | |
| new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) | |
| # Make sure that round down does not go down by more than 10%. | |
| if new_v < 0.9 * v: | |
| new_v += divisor | |
| return new_v | |
def make_divisible(v, divisor=8, min_value=None, round_limit=0.9):
    """Round ``v`` to the nearest multiple of ``divisor`` with a tunable limit.

    :param v: value to round (used here for channel counts).
    :param divisor: rounding granularity.
    :param min_value: lower bound for the result; any falsy value
        (``None`` or ``0``) falls back to ``divisor``.
    :param round_limit: if the rounded value is below ``round_limit * v``,
        one extra ``divisor`` is added. ``0.0`` disables the correction for
        non-negative ``v``.
    :return: the rounded value.
    """
    lower_bound = min_value if min_value else divisor
    nearest = int(v + divisor / 2) // divisor * divisor
    rounded = max(lower_bound, nearest)
    # Guard against rounding down by more than the permitted fraction.
    if rounded < round_limit * v:
        return rounded + divisor
    return rounded
class SEModule(nn.Module):
    """Squeeze-and-Excitation channel gate.

    The input is averaged over dims 2 and 3, passed through two 1x1
    convolutions with an activation in between, squashed with a sigmoid and
    multiplied back onto the input.

    Args:
        channels: number of input (and output) channels.
        rd_ratio: reduction ratio used to derive the bottleneck width when
            ``rd_channels`` is not given (default: 1/16).
        rd_channels: explicit bottleneck width; when falsy it is computed as
            ``make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.0)``.
        rd_divisor: divisor passed to ``make_divisible`` (default: 8).
        act_layer: zero-argument factory for the activation between the two
            convolutions (default: ``nn.ReLU``).
    """

    def __init__(
        self,
        channels,
        rd_ratio=1.0 / 16,
        rd_channels=None,
        rd_divisor=8,
        act_layer=nn.ReLU,
    ):
        super().__init__()
        if not rd_channels:
            rd_channels = make_divisible(
                channels * rd_ratio, rd_divisor, round_limit=0.0)
        # Reduce -> activate -> expand, all as 1x1 convolutions.
        self.fc1 = nn.Conv2d(channels, rd_channels, kernel_size=1, bias=True)
        self.act = act_layer()
        self.fc2 = nn.Conv2d(rd_channels, channels, kernel_size=1, bias=True)

    def forward(self, x):
        """Return ``x`` scaled per channel by the learned sigmoid gate."""
        squeezed = x.mean((2, 3), keepdim=True)
        gate = self.fc2(self.act(self.fc1(squeezed)))
        return x * torch.sigmoid(gate)
class Conv2D_BN(nn.Sequential):
    """A bias-free ``nn.Conv2d`` (registered as ``c``) followed by
    ``nn.BatchNorm2d`` (registered as ``bn``).

    Args:
        a: input channels.
        b: output channels.
        ks: kernel size.
        stride: convolution stride.
        pad: convolution padding.
        dilation: convolution dilation.
        groups: convolution groups.
        bn_weight_init: constant used to initialise the BatchNorm weight;
            the BatchNorm bias is always initialised to 0.
        resolution: not used by this implementation; kept in the signature.
    """

    def __init__(
        self,
        a,
        b,
        ks=1,
        stride=1,
        pad=0,
        dilation=1,
        groups=1,
        bn_weight_init=1,
        resolution=-10000,
    ):
        super().__init__()
        conv = nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False)
        norm = nn.BatchNorm2d(b)
        constant_(norm.weight, bn_weight_init)
        constant_(norm.bias, 0)
        self.add_module('c', conv)
        self.add_module('bn', norm)

    def fuse(self):
        """Fold the BatchNorm running statistics into a single biased conv.

        Returns:
            A new ``nn.Conv2d`` whose weight and bias reproduce ``bn(c(x))``
            computed with the BatchNorm running mean/variance. ``self`` is
            left unchanged.
        """
        conv, norm = self._modules.values()
        std = (norm.running_var + norm.eps)**0.5
        # Per-output-channel scale applied to every kernel element.
        fused_weight = conv.weight * (norm.weight / std)[:, None, None, None]
        fused_bias = norm.bias - norm.running_mean * norm.weight / std
        fused = nn.Conv2d(
            fused_weight.size(1) * conv.groups,
            fused_weight.size(0),
            fused_weight.shape[2:],
            stride=conv.stride,
            padding=conv.padding,
            dilation=conv.dilation,
            groups=conv.groups,
            device=conv.weight.device,
        )
        fused.weight.data.copy_(fused_weight)
        fused.bias.data.copy_(fused_bias)
        return fused
class Residual(torch.nn.Module):
    """Wrap a module ``m`` as ``x + m(x)`` with optional per-sample drop.

    Args:
        m: the wrapped module; its output must be addable to its input.
        drop: probability of zeroing the branch for a sample during
            training. Surviving branches are rescaled by ``1 / (1 - drop)``.
    """

    def __init__(self, m, drop=0.):
        super().__init__()
        self.m = m
        self.drop = drop

    def forward(self, x):
        if self.training and self.drop > 0:
            # One Bernoulli keep-mask value per sample, broadcast over C/H/W.
            keep = torch.rand(x.size(0), 1, 1, 1, device=x.device).ge_(
                self.drop).div(1 - self.drop).detach()
            return x + self.m(x) * keep
        return x + self.m(x)

    @staticmethod
    def _identity_kernel(weight):
        """Build a 3x3 kernel, shaped like ``weight``, that maps x to x.

        For a conv weight of shape ``(out, in_per_group, 3, 3)`` with
        ``out == in_channels``, output channel ``o`` must read input channel
        ``o``, which sits at position ``o % in_per_group`` inside its group.
        For a depthwise conv (``in_per_group == 1``) this is all ones at the
        centre tap.
        """
        out_channels, in_per_group = weight.shape[:2]
        kernel = torch.zeros(out_channels,
                             in_per_group,
                             1,
                             1,
                             device=weight.device,
                             dtype=weight.dtype)
        idx = torch.arange(out_channels, device=weight.device)
        kernel[idx, idx % in_per_group] = 1
        # Place the 1x1 identity at the centre of a 3x3 kernel.
        return nn.functional.pad(kernel, [1, 1, 1, 1])

    @torch.no_grad()
    def fuse(self):
        """Absorb the skip connection into the wrapped convolution.

        Runs under ``torch.no_grad()`` because the conv weight is a leaf
        parameter that requires grad; updating it in place with autograd
        enabled raises ``RuntimeError``.

        Returns:
            A 3x3 ``nn.Conv2d`` equivalent to ``x + m(x)`` when ``m`` is a
            ``Conv2D_BN`` (a new fused conv) or an ``nn.Conv2d`` (modified in
            place); otherwise ``self`` unchanged.
        """
        if isinstance(self.m, nn.Conv2d):
            m = self.m
            assert (m.groups != m.in_channels)
        elif isinstance(self.m, Conv2D_BN):
            m = self.m.fuse()
            assert (m.groups == m.in_channels)
        else:
            return self
        m.weight += self._identity_kernel(m.weight)
        return m
class RepVGGDW(nn.Module):
    """Depthwise re-parameterisable block.

    Sums a 3x3 depthwise ``Conv2D_BN``, a 1x1 depthwise ``nn.Conv2d`` and
    the input itself, then applies ``nn.BatchNorm2d``.

    Args:
        ed: number of channels (input and output).
    """

    def __init__(self, ed) -> None:
        super().__init__()
        self.conv = Conv2D_BN(ed, ed, 3, 1, 1, groups=ed)
        self.conv1 = nn.Conv2d(ed, ed, 1, 1, 0, groups=ed)
        self.dim = ed
        self.bn = nn.BatchNorm2d(ed)

    def forward(self, x):
        branches = self.conv(x) + self.conv1(x)
        return self.bn(branches + x)

    def fuse(self):
        """Collapse all three branches and the trailing BatchNorm into one conv.

        Returns:
            The ``nn.Conv2d`` produced by ``self.conv.fuse()``, with its
            weight and bias overwritten to include the 1x1 branch, the
            identity branch and the BatchNorm running statistics.
        """
        fused = self.conv.fuse()
        pointwise = self.conv1

        # Embed the 1x1 kernel and the identity at the centre of a 3x3 kernel.
        pointwise_3x3 = nn.functional.pad(pointwise.weight, [1, 1, 1, 1])
        identity_3x3 = nn.functional.pad(
            torch.ones(pointwise_3x3.shape[0],
                       pointwise_3x3.shape[1],
                       1,
                       1,
                       device=pointwise_3x3.device), [1, 1, 1, 1])

        merged_weight = fused.weight + pointwise_3x3 + identity_3x3
        merged_bias = fused.bias + pointwise.bias
        fused.weight.data.copy_(merged_weight)
        fused.bias.data.copy_(merged_bias)

        # Fold the trailing BatchNorm into the merged convolution.
        norm = self.bn
        std = (norm.running_var + norm.eps)**0.5
        folded_weight = fused.weight * (norm.weight / std)[:, None, None, None]
        folded_bias = norm.bias + \
            (fused.bias - norm.running_mean) * norm.weight / std
        fused.weight.data.copy_(folded_weight)
        fused.bias.data.copy_(folded_bias)
        return fused
class RepViTBlock(nn.Module):
    """One RepViT stage block: a token mixer followed by a channel mixer.

    Args:
        inp: input channels.
        hidden_dim: channel-mixer hidden width; must equal ``2 * inp``.
            The strided variant sizes its hidden layer as ``2 * oup``.
        oup: output channels.
        kernel_size: depthwise kernel size of the strided token mixer.
        stride: ``1`` selects the ``RepVGGDW`` variant (requires
            ``inp == oup``); anything else selects the strided variant.
        use_se: insert an ``SEModule`` (ratio 0.25) in the token mixer.
        use_hs: accepted but has no effect; the activation is always
            ``nn.GELU``.
    """

    def __init__(self, inp, hidden_dim, oup, kernel_size, stride, use_se,
                 use_hs):
        super().__init__()
        self.identity = stride == 1 and inp == oup
        assert hidden_dim == 2 * inp

        if stride == 1:
            assert self.identity
            self.token_mixer = nn.Sequential(
                RepVGGDW(inp),
                SEModule(inp, 0.25) if use_se else nn.Identity(),
            )
            self.channel_mixer = self._channel_mixer(inp, hidden_dim, oup)
        else:
            self.token_mixer = nn.Sequential(
                # Depthwise conv carrying the stride.
                Conv2D_BN(inp,
                          inp,
                          kernel_size,
                          stride, (kernel_size - 1) // 2,
                          groups=inp),
                SEModule(inp, 0.25) if use_se else nn.Identity(),
                # Pointwise projection to the new width.
                Conv2D_BN(inp, oup, ks=1, stride=1, pad=0),
            )
            self.channel_mixer = self._channel_mixer(oup, 2 * oup, oup)

    @staticmethod
    def _channel_mixer(in_dim, mid_dim, out_dim):
        """Residual pointwise expand -> GELU -> pointwise project."""
        return Residual(
            nn.Sequential(
                # pw
                Conv2D_BN(in_dim, mid_dim, 1, 1, 0),
                nn.GELU(),
                # pw-linear; zero-initialised BN weight
                Conv2D_BN(mid_dim, out_dim, 1, 1, 0, bn_weight_init=0),
            ))

    def forward(self, x):
        mixed = self.token_mixer(x)
        return self.channel_mixer(mixed)
class RepViT(nn.Module):
    """RepViT backbone assembled from a table of block configurations.

    Args:
        cfgs: sequence of rows ``(k, t, c, use_se, use_hs, s)``: kernel
            size, expansion factor, output channels, SE flag, HS flag and
            stride of each ``RepViTBlock``.
        in_channels: channels of the input image.
        out_indices: indices into ``self.features`` (index 0 is the patch
            embedding, index ``i`` is the block built from ``cfgs[i - 1]``)
            whose outputs ``forward`` should return. ``None`` selects the
            single-output path (``forward_rec``).

    Attributes:
        features: ``nn.ModuleList`` of the patch embedding and all blocks.
        out_channels: channel width of each selected feature (a list) when
            ``out_indices`` is given, otherwise the width of the last block.
    """

    def __init__(self, cfgs, in_channels=3, out_indices=None):
        super().__init__()
        # setting of inverted residual blocks
        self.cfgs = cfgs

        # building first layer
        input_channel = self.cfgs[0][2]
        patch_embed = nn.Sequential(
            Conv2D_BN(in_channels, input_channel // 2, 3, 2, 1),
            nn.GELU(),
            Conv2D_BN(input_channel // 2, input_channel, 3, 2, 1),
        )
        layers = [patch_embed]
        # Width actually produced by each entry of ``layers``. Block widths
        # are rounded by ``_make_divisible``, so they may differ from the
        # raw ``c`` in ``cfgs``.
        feature_channels = [input_channel]

        # building inverted residual blocks
        for k, t, c, use_se, use_hs, s in self.cfgs:
            output_channel = _make_divisible(c, 8)
            exp_size = _make_divisible(input_channel * t, 8)
            layers.append(
                RepViTBlock(input_channel, exp_size, output_channel, k, s,
                            use_se, use_hs))
            feature_channels.append(output_channel)
            input_channel = output_channel
        self.features = nn.ModuleList(layers)

        self.out_indices = out_indices
        if out_indices is not None:
            self.out_channels = [feature_channels[i] for i in out_indices]
        else:
            self.out_channels = feature_channels[-1]

    def forward(self, x):
        """Dispatch to the multi-output or single-output path."""
        if self.out_indices is not None:
            return self.forward_det(x)
        return self.forward_rec(x)

    def forward_det(self, x):
        """Return the outputs of the features listed in ``out_indices``."""
        outs = []
        for i, f in enumerate(self.features):
            x = f(x)
            if i in self.out_indices:
                outs.append(x)
        return outs

    def forward_rec(self, x):
        """Run every feature, then average-pool with kernel ``(h, 2)``.

        ``h`` is the size of dim 2 of the final feature map; the pooling
        stride defaults to the kernel size.
        """
        for f in self.features:
            x = f(x)
        h = x.shape[2]
        x = nn.functional.avg_pool2d(x, [h, 2])
        return x
def RepSVTR(in_channels=3):
    """Build the single-output RepSVTR backbone (``RepViT`` without ``out_indices``).

    Args:
        in_channels: channels of the input image.

    Returns:
        A ``RepViT`` with three stages of widths 96, 192 and 384. Stage
        transitions use stride ``(2, 1)``.
    """
    # Row layout: k, t, c, SE, HS, s
    stage_96 = [
        [3, 2, 96, 1, 0, 1],
        [3, 2, 96, 0, 0, 1],
        [3, 2, 96, 0, 0, 1],
    ]
    stage_192 = [
        [3, 2, 192, 0, 1, (2, 1)],
        [3, 2, 192, 1, 1, 1],
        [3, 2, 192, 0, 1, 1],
        [3, 2, 192, 1, 1, 1],
        [3, 2, 192, 0, 1, 1],
        [3, 2, 192, 1, 1, 1],
        [3, 2, 192, 0, 1, 1],
    ]
    stage_384 = [
        [3, 2, 384, 0, 1, (2, 1)],
        [3, 2, 384, 1, 1, 1],
        [3, 2, 384, 0, 1, 1],
    ]
    return RepViT(stage_96 + stage_192 + stage_384, in_channels=in_channels)
def RepSVTR_det(in_channels=3, out_indices=(2, 5, 10, 13)):
    """Build the multi-output RepSVTR backbone.

    Args:
        in_channels: channels of the input image.
        out_indices: indices into the model's ``features`` list whose
            outputs are returned by ``forward``. Index 0 is the patch
            embedding. Passing ``None`` yields the single-output path.

    Returns:
        A ``RepViT`` with stages of widths 48, 96, 192 and 384; each stage
        transition uses stride 2.
    """
    # k, t, c, SE, HS, s
    cfgs = [
        [3, 2, 48, 1, 0, 1],
        [3, 2, 48, 0, 0, 1],
        [3, 2, 96, 0, 0, 2],
        [3, 2, 96, 1, 0, 1],
        [3, 2, 96, 0, 0, 1],
        [3, 2, 192, 0, 1, 2],
        [3, 2, 192, 1, 1, 1],
        [3, 2, 192, 0, 1, 1],
        [3, 2, 192, 1, 1, 1],
        [3, 2, 192, 0, 1, 1],
        [3, 2, 384, 0, 1, 2],
        [3, 2, 384, 1, 1, 1],
        [3, 2, 384, 0, 1, 1],
    ]
    # The model stores ``out_indices`` on itself; hand it a private list so
    # that models never share (or mutate) the default value.
    if out_indices is not None:
        out_indices = list(out_indices)
    return RepViT(cfgs, in_channels=in_channels, out_indices=out_indices)