import torch
import os
from torch.nn import init
import cv2
import numpy as np
import time
import requests

from IndicPhotoOCR.detection import east_config as cfg
from IndicPhotoOCR.detection import east_preprossing as preprossing
from IndicPhotoOCR.detection import east_locality_aware_nms as locality_aware_nms


# Checkpoint paths and download URLs for the EAST detector
model_info = {
    "east": {
        "paths": [cfg.checkpoint, cfg.pretrained_basemodel_path],
        "urls": [
            "https://github.com/anikde/STocr/releases/download/e0.1.0/epoch_990_checkpoint.pth.tar",
            "https://github.com/anikde/STocr/releases/download/e0.1.0/mobilenet_v2.pth.tar",
        ],
    },
}


class ModelManager:
    def __init__(self):
        # self.root_model_dir = "bharatOCR/detection/"
        pass

    def download_model(self, url, path):
        response = requests.get(url, stream=True)
        if response.status_code == 200:
            with open(path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:  # filter out keep-alive chunks
                        f.write(chunk)
            print(f"Downloaded: {path}")
        else:
            print(f"Failed to download from {url}")

    def ensure_model(self, model_name):
        model_paths = model_info[model_name]["paths"]  # handles multiple paths
        urls = model_info[model_name]["urls"]  # handles multiple URLs
        for model_path, url in zip(model_paths, urls):
            # full_model_path = os.path.join(self.root_model_dir, model_path)
            # Ensure the directory for the model files exists
            os.makedirs(os.path.dirname(os.path.join(*cfg.pretrained_basemodel_path.split("/"))), exist_ok=True)
            if not os.path.exists(model_path):
                print(f"Model not found locally. Downloading {model_name} from {url}...")
                self.download_model(url, model_path)
            else:
                print(f"Model already exists at {model_path}. No need to download.")


# Initialize ModelManager and ensure the EAST checkpoint and base model are downloaded
model_manager = ModelManager()
model_manager.ensure_model("east")

def init_weights(m_list, init_type=cfg.init_type, gain=0.02):
    print("EAST <==> Prepare <==> Init Network '{}' <==> Begin".format(cfg.init_type))
    # apply the chosen initialization to each layer
    for m in m_list:
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')  # good for relu
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find('BatchNorm2d') != -1:
            init.normal_(m.weight.data, 1.0, gain)
            init.constant_(m.bias.data, 0.0)
    print("EAST <==> Prepare <==> Init Network '{}' <==> Done".format(cfg.init_type))

def Loading_checkpoint(model, optimizer, scheduler, filename='checkpoint.pth.tar'):
    """Load the checkpoint at cfg.checkpoint and restore the training state.

    Restores the model, optimizer and scheduler state dicts.

    Keyword Arguments:
        filename {str} -- checkpoint file name (default: {'checkpoint.pth.tar'})

    Returns:
        int -- the epoch to resume training from
    """
    weightpath = os.path.abspath(cfg.checkpoint)
    print("EAST <==> Prepare <==> Loading checkpoint '{}' <==> Begin".format(weightpath))
    checkpoint = torch.load(weightpath)
    start_epoch = checkpoint['epoch'] + 1
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    scheduler.load_state_dict(checkpoint['scheduler'])
    print("EAST <==> Prepare <==> Loading checkpoint '{}' <==> Done".format(weightpath))
    return start_epoch

def save_checkpoint(epoch, model, optimizer, scheduler, filename='checkpoint.pth.tar'):
    """Save the model, optimizer and scheduler state for the given epoch.

    The state is written to cfg.save_model_path as 'epoch_<epoch>_checkpoint.pth.tar'.

    Keyword Arguments:
        filename {str} -- checkpoint file name (default: {'checkpoint.pth.tar'})
    """
    print('EAST <==> Save weight - epoch {} <==> Begin'.format(epoch))
    state = {
        'epoch': epoch,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict()
    }
    weight_dir = cfg.save_model_path
    if not os.path.exists(weight_dir):
        os.mkdir(weight_dir)
    filename = 'epoch_' + str(epoch) + '_checkpoint.pth.tar'
    file_path = os.path.join(weight_dir, filename)
    torch.save(state, file_path)
    print('EAST <==> Save weight - epoch {} <==> Done'.format(epoch))
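
# Usage sketch (hypothetical training loop): resume from the checkpoint referenced by
# cfg.checkpoint and save a new checkpoint every epoch. `model`, `optimizer`, `scheduler`
# and `num_epochs` are assumed to be defined by the caller.
#
#   start_epoch = Loading_checkpoint(model, optimizer, scheduler)
#   for epoch in range(start_epoch, num_epochs):
#       ...  # train one epoch
#       save_checkpoint(epoch, model, optimizer, scheduler)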

class Regularization(torch.nn.Module):
    def __init__(self, model, weight_decay, p=2):
        super(Regularization, self).__init__()
        if weight_decay < 0:
            print("param weight_decay can not be < 0")
            exit(0)
        self.model = model
        self.weight_decay = weight_decay
        self.p = p
        self.weight_list = self.get_weight(model)
        # self.weight_info(self.weight_list)

    def to(self, device):
        self.device = device
        super().to(device)
        return self

    def forward(self, model):
        self.weight_list = self.get_weight(model)  # fetch the latest weights
        reg_loss = self.regularization_loss(self.weight_list, self.weight_decay, p=self.p)
        return reg_loss

    def get_weight(self, model):
        weight_list = []
        for name, param in model.named_parameters():
            if 'weight' in name:
                weight = (name, param)
                weight_list.append(weight)
        return weight_list

    def regularization_loss(self, weight_list, weight_decay, p=2):
        reg_loss = 0
        for name, w in weight_list:
            l2_reg = torch.norm(w, p=p)
            reg_loss = reg_loss + l2_reg
        reg_loss = weight_decay * reg_loss
        return reg_loss

    def weight_info(self, weight_list):
        print("---------------regularization weight---------------")
        for name, w in weight_list:
            print(name)
        print("---------------------------------------------------")

def resize_image(im, max_side_len=2400):
    '''
    Resize the image so both sides are multiples of 32, as required by the network.
    :param im: the image to resize
    :param max_side_len: limit on the max side length to avoid running out of GPU memory
    :return: the resized image and the resize ratios (ratio_h, ratio_w)
    '''
    h, w, _ = im.shape

    resize_w = w
    resize_h = h

    # limit the max side
    """
    if max(resize_h, resize_w) > max_side_len:
        ratio = float(max_side_len) / resize_h if resize_h > resize_w else float(max_side_len) / resize_w
    else:
        ratio = 1.
    resize_h = int(resize_h * ratio)
    resize_w = int(resize_w * ratio)
    """

    resize_h = resize_h if resize_h % 32 == 0 else (resize_h // 32 - 1) * 32
    resize_w = resize_w if resize_w % 32 == 0 else (resize_w // 32 - 1) * 32
    # resize_h, resize_w = 512, 512
    im = cv2.resize(im, (int(resize_w), int(resize_h)))

    ratio_h = resize_h / float(h)
    ratio_w = resize_w / float(w)

    return im, (ratio_h, ratio_w)
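
# Usage sketch: prepare an input image for the detector. "sample.jpg" is a placeholder path.
#
#   im = cv2.imread("sample.jpg")
#   im_resized, (ratio_h, ratio_w) = resize_image(im)
#   # the ratios are later used to map detected boxes back to the original image scale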

def detect(score_map, geo_map, timer, score_map_thresh=0.8, box_thresh=0.1, nms_thres=0.2):
    '''
    Restore text boxes from the score map and geometry map.
    :param score_map:
    :param geo_map:
    :param timer:
    :param score_map_thresh: threshold for the score map
    :param box_thresh: threshold for boxes
    :param nms_thres: threshold for NMS
    :return:
    '''
    # drop the batch dimension of score_map and geo_map
    if len(score_map.shape) == 4:
        score_map = score_map[0, :, :, 0]
        geo_map = geo_map[0, :, :, :]
    # filter the score map
    xy_text = np.argwhere(score_map > score_map_thresh)
    # sort the text boxes via the y axis
    xy_text = xy_text[np.argsort(xy_text[:, 0])]
    # restore
    start = time.time()
    text_box_restored = preprossing.restore_rectangle(xy_text[:, ::-1] * 4,
                                                      geo_map[xy_text[:, 0], xy_text[:, 1], :])  # N*4*2
    # print('{} text boxes before nms'.format(text_box_restored.shape[0]))
    boxes = np.zeros((text_box_restored.shape[0], 9), dtype=np.float32)
    boxes[:, :8] = text_box_restored.reshape((-1, 8))
    boxes[:, 8] = score_map[xy_text[:, 0], xy_text[:, 1]]
    timer['restore'] = time.time() - start
    # nms part
    start = time.time()
    boxes = locality_aware_nms.nms_locality(boxes.astype(np.float64), nms_thres)
    timer['nms'] = time.time() - start
    # print(timer['nms'])

    if boxes.shape[0] == 0:
        return None, timer

    # here we filter out some low-score boxes by the average score map; this differs from the original paper
    for i, box in enumerate(boxes):
        mask = np.zeros_like(score_map, dtype=np.uint8)
        cv2.fillPoly(mask, box[:8].reshape((-1, 4, 2)).astype(np.int32) // 4, 1)
        boxes[i, 8] = cv2.mean(score_map, mask)[0]
    boxes = boxes[boxes[:, 8] > box_thresh]

    return boxes, timer
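
# Usage sketch (hypothetical): turn raw network outputs into text boxes. `score` and
# `geometry` are assumed to be the detector's outputs converted to numpy arrays in
# N x H x W x C layout; ratio_w and ratio_h come from resize_image above.
#
#   timer = {'restore': 0, 'nms': 0}
#   boxes, timer = detect(score_map=score, geo_map=geometry, timer=timer)
#   if boxes is not None:
#       boxes = boxes[:, :8].reshape((-1, 4, 2))
#       boxes[:, :, 0] /= ratio_w
#       boxes[:, :, 1] /= ratio_h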

def sort_poly(p):
    # roll the quad so the vertex closest to the origin (top-left) comes first,
    # then keep the ordering whose first edge (p[0] -> p[1]) runs mostly horizontally
    min_axis = np.argmin(np.sum(p, axis=1))
    p = p[[min_axis, (min_axis + 1) % 4, (min_axis + 2) % 4, (min_axis + 3) % 4]]
    if abs(p[0, 0] - p[1, 0]) > abs(p[0, 1] - p[1, 1]):
        return p
    else:
        return p[[0, 3, 2, 1]]

def mean_image_subtraction(images, means=cfg.means):
    '''
    Image normalization by per-channel mean subtraction.
    :param images: tensor of shape bs * channel * h * w (channels-first)
    :param means: per-channel means
    :return: the mean-subtracted images
    '''
    num_channels = images.data.shape[1]
    if len(means) != num_channels:
        raise ValueError('len(means) must match the number of channels')
    for i in range(num_channels):
        images.data[:, i, :, :] -= means[i]
    return images
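
# Usage sketch (assumed preprocessing; the actual pipeline may differ): convert a resized
# image to a channels-first float tensor and subtract the configured channel means.
# `im_resized` is the output of resize_image above.
#
#   im_tensor = torch.from_numpy(im_resized).permute(2, 0, 1).unsqueeze(0).float()
#   im_tensor = mean_image_subtraction(im_tensor, means=cfg.means)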