Spaces:
Runtime error
Runtime error
| import sys | |
| import os | |
| import torch | |
| from PIL import Image | |
| import cv2 | |
| import numpy as np | |
| # from IndicPhotoOCR.detection.east_detector import EASTdetector | |
| from IndicPhotoOCR.script_identification.CLIP_identifier import CLIPidentifier | |
| from IndicPhotoOCR.recognition.parseq_recogniser import PARseqrecogniser | |
| import IndicPhotoOCR.detection.east_config as cfg | |
| from IndicPhotoOCR.detection.textbpn.textbpnpp_detector import TextBPNpp_detector | |
class OCR:
    """Scene-text OCR pipeline: detect text regions, identify script, recognise text.

    The pipeline is assembled from three project components: a
    ``TextBPNpp_detector`` (text detection), a ``CLIPidentifier`` (script
    identification) and a ``PARseqrecogniser`` (text recognition).

    Attributes:
        device: Device string handed to the detector (``"cpu"`` when a CUDA
            device was requested but CUDA is not available).
        verbose: When True, progress messages are printed.
        detector: ``TextBPNpp_detector`` instance.
        recogniser: ``PARseqrecogniser`` instance.
        identifier: ``CLIPidentifier`` instance.
        detections: Full result of the most recent ``detect`` call (only set
            after ``detect`` has been called).
    """

    def __init__(self, device='cuda:0', verbose=False):
        """Build the detector, recogniser and script identifier.

        Args:
            device (str): Device for the detector, e.g. ``"cuda:0"`` or ``"cpu"``.
                A CUDA device is replaced by ``"cpu"`` when CUDA is unavailable.
            verbose (bool): Print progress messages while processing.
        """
        self.device = self._resolve_device(device)
        self.verbose = verbose
        self.detector = TextBPNpp_detector(device=self.device)
        self.recogniser = PARseqrecogniser()
        self.identifier = CLIPidentifier()

    @staticmethod
    def _resolve_device(device):
        """Return ``device``, or ``"cpu"`` if it names CUDA and CUDA is unavailable."""
        if str(device).startswith("cuda") and not torch.cuda.is_available():
            import warnings

            warnings.warn(
                f"CUDA is not available; using 'cpu' instead of '{device}'.",
                stacklevel=3,
            )
            return "cpu"
        return device

    @staticmethod
    def _bbox_bounds(bbox):
        """Return ``(x_min, y_min, x_max, y_max)`` enclosing the points of ``bbox``."""
        x_coords = [point[0] for point in bbox]
        y_coords = [point[1] for point in bbox]
        return min(x_coords), min(y_coords), max(x_coords), max(y_coords)

    def detect(self, image_path):
        """Run the text detector on an image.

        The detector's full result is kept on ``self.detections``.

        Args:
            image_path (str): Path of the image to process.

        Returns:
            The value stored under the ``'detections'`` key of the detector's
            result (iterated elsewhere in this class as one point sequence
            per text region).
        """
        self.detections = self.detector.detect(image_path)
        return self.detections['detections']

    def visualize_detection(self, image_path, detections, save_path=None, show=False):
        """Draw a green axis-aligned rectangle around each detection and save the image.

        Args:
            image_path (str): Path of the image to annotate.
            detections (iterable): Detected regions, each a sequence of
                ``[x, y]`` points.
            save_path (str, optional): Output file path. Defaults to
                ``"test.png"`` in the current working directory.
            show (bool): When True, display the annotated image with
                matplotlib before saving it.

        Raises:
            OSError: If the image cannot be read or the result cannot be written.
        """
        path_to_save = save_path if save_path is not None else "test.png"

        # Create the output directory on demand (a bare filename has none).
        directory = os.path.dirname(path_to_save)
        if directory and not os.path.exists(directory):
            os.makedirs(directory, exist_ok=True)
            print(f"Created directory: {directory}")

        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise OSError(f"Could not read image: {image_path}")

        for box in detections:
            points = np.array(box, np.int32)
            # Plain ints keep OpenCV's point parsing independent of NumPy scalar types.
            top_left = (int(points[:, 0].min()), int(points[:, 1].min()))
            bottom_right = (int(points[:, 0].max()), int(points[:, 1].max()))
            cv2.rectangle(image, top_left, bottom_right, color=(0, 255, 0), thickness=3)

        if show:
            # Imported here: the module does not import matplotlib at top level,
            # and it is only needed for interactive display.
            import matplotlib.pyplot as plt

            plt.figure(figsize=(10, 10))
            # OpenCV images are BGR; matplotlib expects RGB.
            plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
            plt.axis("off")
            plt.show()

        if not cv2.imwrite(path_to_save, image):
            raise OSError(f"Could not write image: {path_to_save}")
        print(f"Image saved at: {path_to_save}")

    def crop_and_identify_script(self, image, bbox):
        """Crop a text area from the image and identify its script.

        The crop is written to a JPEG file under
        ``IndicPhotoOCR/script_identification/images`` (relative to the current
        working directory) because the identifier is given a file path. The
        file is left on disk so the caller can pass it on to ``recognise``;
        removing it is the caller's responsibility.

        Args:
            image (PIL.Image): The full image.
            bbox (sequence): Points of the detected region, each an ``[x, y]``
                pair; the crop is their axis-aligned bounding rectangle.

        Returns:
            tuple: ``(script_lang, cropped_path)`` - the identifier's result and
            the path of the saved crop.
        """
        x_min, y_min, x_max, y_max = self._bbox_bounds(bbox)
        cropped_image = image.crop((x_min, y_min, x_max, y_max))

        # JPEG cannot store alpha or palette modes (e.g. RGBA or P from a PNG).
        if cropped_image.mode not in ("L", "RGB", "CMYK"):
            cropped_image = cropped_image.convert("RGB")

        root_image_dir = "IndicPhotoOCR/script_identification"
        os.makedirs(f"{root_image_dir}/images", exist_ok=True)
        cropped_path = f'{root_image_dir}/images/temp_crop_{x_min}_{y_min}.jpg'
        cropped_image.save(cropped_path)

        if self.verbose:
            print("Identifying script for the cropped area...")
        # "hindi" is passed as the identifier's model name.
        script_lang = self.identifier.identify(cropped_path, "hindi")
        return script_lang, cropped_path

    def recognise(self, cropped_image_path, script_lang):
        """Recognise the text in a cropped image using the identified script.

        Args:
            cropped_image_path (str): Path of the cropped text image.
            script_lang (str): Script/language returned by the identifier.

        Returns:
            The recogniser's result for the crop.
        """
        if self.verbose:
            print("Recognizing text in detected area...")
        # NOTE(review): script_lang is passed as both the first and third
        # argument - presumably model name and language; confirm against
        # PARseqrecogniser.recognise.
        recognized_text = self.recogniser.recognise(
            script_lang, cropped_image_path, script_lang, self.verbose
        )
        return recognized_text

    def ocr(self, image_path):
        """Detect text areas in an image, identify each script and recognise the text.

        Temporary crop files created along the way are removed once each
        region has been processed.

        Args:
            image_path (str): Path of the image to process.

        Returns:
            list: One recognition result per detected region whose script
            identification returned a truthy value, in detection order.
        """
        recognized_words = []
        # The context manager closes the underlying image file when done.
        with Image.open(image_path) as image:
            detections = self.detect(image_path)
            for bbox in detections:
                script_lang, cropped_path = self.crop_and_identify_script(image, bbox)
                try:
                    # Regions without an identified script are skipped.
                    if script_lang:
                        recognized_word = self.recognise(cropped_path, script_lang)
                        recognized_words.append(recognized_word)
                        if self.verbose:
                            print(f"Recognized word: {recognized_word}")
                finally:
                    # Best-effort cleanup: a crop that cannot be removed must
                    # not abort the OCR run.
                    try:
                        os.remove(cropped_path)
                    except OSError:
                        pass
        return recognized_words
if __name__ == '__main__':
    # Example run: OCR a sample image and print the recognised words.
    sample_image_path = 'test_images/image_141.jpg'
    cropped_image_path = 'test_images/cropped_image/image_141_0.jpg'

    # Use the GPU when one is available instead of assuming CUDA exists.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    ocr = OCR(device=device, verbose=False)

    # Other entry points that can be exercised individually:
    #   detections = ocr.detect(sample_image_path)
    #   ocr.visualize_detection(sample_image_path, detections)
    #   recognition = ocr.recognise(cropped_image_path, "hindi")

    recognised_words = ocr.ocr(sample_image_path)
    print(recognised_words)