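"""EAST scene-text detection model built on a MobileNetV2 feature extractor."""
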
import math

import torch
import torch.nn as nn

from IndicPhotoOCR.detection import east_config as cfg
from IndicPhotoOCR.detection import east_utils as utils
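

# 3x3 convolution -> BatchNorm -> ReLU6 block, used as the stem of the backbone.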
def conv_bn(inp, oup, stride):
    return nn.Sequential(
        nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU6(inplace=True)
    )
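

# MobileNetV2 inverted residual block: (optional) 1x1 expansion -> 3x3 depthwise
# convolution -> 1x1 linear projection, with an identity shortcut when the block
# keeps the spatial resolution (stride 1) and the channel count unchanged.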
class InvertedResidual(nn.Module):
    def __init__(self, inp, oup, stride, expand_ratio):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = round(inp * expand_ratio)
        self.use_res_connect = self.stride == 1 and inp == oup

        if expand_ratio == 1:
            self.conv = nn.Sequential(
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )
        else:
            self.conv = nn.Sequential(
                # pw
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )

    def forward(self, x):
        if self.use_res_connect:
            return x + self.conv(x)
        else:
            return self.conv(x)
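

# Truncated MobileNetV2 used purely as a feature extractor: the classifier head and
# the final 320-channel stage of the standard architecture are omitted, so the
# deepest feature map has 160 channels at 1/32 of the input resolution.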
class MobileNetV2(nn.Module):
    def __init__(self, width_mult=1.):
        super(MobileNetV2, self).__init__()
        block = InvertedResidual
        input_channel = 32
        last_channel = 1280
        # Each row: t = expansion factor, c = output channels,
        # n = number of blocks, s = stride of the first block.
        inverted_residual_setting = [
            # t, c, n, s
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 1],
            [6, 160, 3, 2],
            # [6, 320, 1, 1],
        ]

        # building first layer
        # assert input_size % 32 == 0
        input_channel = int(input_channel * width_mult)
        self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel
        self.features = [conv_bn(3, input_channel, 2)]
        # building inverted residual blocks
        for t, c, n, s in inverted_residual_setting:
            output_channel = int(c * width_mult)
            for i in range(n):
                if i == 0:
                    self.features.append(block(input_channel, output_channel, s, expand_ratio=t))
                else:
                    self.features.append(block(input_channel, output_channel, 1, expand_ratio=t))
                input_channel = output_channel
        # make it nn.Sequential
        self.features = nn.Sequential(*self.features)

        self._initialize_weights()

    def forward(self, x):
        x = self.features(x)
        # x = x.mean(3).mean(2)
        # x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()


def mobilenet(pretrained=True, **kwargs):
    """
    Constructs a MobileNetV2 feature extractor.
    Args:
        pretrained (bool): If True, loads weights pre-trained on ImageNet
            from cfg.pretrained_basemodel_path
    """
    model = MobileNetV2(**kwargs)
    if pretrained:
        model_dict = model.state_dict()
        pretrained_dict = torch.load(cfg.pretrained_basemodel_path, map_location=torch.device('cpu'), weights_only=True)
        # Keep only the weights whose names exist in this truncated model.
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)
        # state_dict = torch.load(cfg.pretrained_basemodel_path)  # add map_location='cpu' if no gpu
        # model.load_state_dict(state_dict)
    return model
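

# EAST detector head on top of the MobileNetV2 backbone. Feature maps from four
# backbone stages are merged top-down (upsample, concatenate, 1x1 + 3x3 conv), and
# the merged 1/4-resolution map predicts a text score map, a 4-channel distance
# map (scaled to [0, 512]) and an angle map in [-pi/4, pi/4].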
class East(nn.Module):
    def __init__(self):
        super(East, self).__init__()
        self.mobilenet = mobilenet(True)
        # self.si for stage i
        self.s1 = nn.Sequential(*list(self.mobilenet.children())[0][0:4])
        self.s2 = nn.Sequential(*list(self.mobilenet.children())[0][4:7])
        self.s3 = nn.Sequential(*list(self.mobilenet.children())[0][7:14])
        self.s4 = nn.Sequential(*list(self.mobilenet.children())[0][14:17])

        self.conv1 = nn.Conv2d(160 + 96, 128, 1)
        self.bn1 = nn.BatchNorm2d(128)
        self.relu1 = nn.ReLU()

        self.conv2 = nn.Conv2d(128, 128, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.relu2 = nn.ReLU()

        self.conv3 = nn.Conv2d(128 + 32, 64, 1)
        self.bn3 = nn.BatchNorm2d(64)
        self.relu3 = nn.ReLU()

        self.conv4 = nn.Conv2d(64, 64, 3, padding=1)
        self.bn4 = nn.BatchNorm2d(64)
        self.relu4 = nn.ReLU()

        self.conv5 = nn.Conv2d(64 + 24, 64, 1)
        self.bn5 = nn.BatchNorm2d(64)
        self.relu5 = nn.ReLU()

        self.conv6 = nn.Conv2d(64, 32, 3, padding=1)
        self.bn6 = nn.BatchNorm2d(32)
        self.relu6 = nn.ReLU()

        self.conv7 = nn.Conv2d(32, 32, 3, padding=1)
        self.bn7 = nn.BatchNorm2d(32)
        self.relu7 = nn.ReLU()

        self.conv8 = nn.Conv2d(32, 1, 1)
        self.sigmoid1 = nn.Sigmoid()
        self.conv9 = nn.Conv2d(32, 4, 1)
        self.sigmoid2 = nn.Sigmoid()
        self.conv10 = nn.Conv2d(32, 1, 1)
        self.sigmoid3 = nn.Sigmoid()
        self.unpool1 = nn.Upsample(scale_factor=2, mode='bilinear')
        self.unpool2 = nn.Upsample(scale_factor=2, mode='bilinear')
        self.unpool3 = nn.Upsample(scale_factor=2, mode='bilinear')

        # utils.init_weights([self.conv1, self.conv2, self.conv3, self.conv4,
        #                     self.conv5, self.conv6, self.conv7, self.conv8,
        #                     self.conv9, self.conv10, self.bn1, self.bn2,
        #                     self.bn3, self.bn4, self.bn5, self.bn6, self.bn7])

    def forward(self, images):
        images = utils.mean_image_subtraction(images)

        f0 = self.s1(images)
        f1 = self.s2(f0)
        f2 = self.s3(f1)
        f3 = self.s4(f2)
        # _, f = self.mobilenet(images)

        h = f3  # bs 160 w/32 h/32
        g = self.unpool1(h)  # bs 160 w/16 h/16
        c = self.conv1(torch.cat((g, f2), 1))
        c = self.bn1(c)
        c = self.relu1(c)

        h = self.conv2(c)  # bs 128 w/16 h/16
        h = self.bn2(h)
        h = self.relu2(h)
        g = self.unpool2(h)  # bs 128 w/8 h/8
        c = self.conv3(torch.cat((g, f1), 1))
        c = self.bn3(c)
        c = self.relu3(c)

        h = self.conv4(c)  # bs 64 w/8 h/8
        h = self.bn4(h)
        h = self.relu4(h)
        g = self.unpool3(h)  # bs 64 w/4 h/4
        c = self.conv5(torch.cat((g, f0), 1))
        c = self.bn5(c)
        c = self.relu5(c)

        h = self.conv6(c)  # bs 32 w/4 h/4
        h = self.bn6(h)
        h = self.relu6(h)
        g = self.conv7(h)  # bs 32 w/4 h/4
        g = self.bn7(g)
        g = self.relu7(g)

        F_score = self.conv8(g)  # bs 1 w/4 h/4
        F_score = self.sigmoid1(F_score)
        geo_map = self.conv9(g)
        geo_map = self.sigmoid2(geo_map) * 512
        angle_map = self.conv10(g)
        angle_map = self.sigmoid3(angle_map)
        angle_map = (angle_map - 0.5) * math.pi / 2

        F_geometry = torch.cat((geo_map, angle_map), 1)  # bs 5 w/4 h/4

        return F_score, F_geometry
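

# Module-level instance; constructing East() here also loads the pretrained
# MobileNetV2 backbone weights from cfg.pretrained_basemodel_path at import time.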
model = East()