import gradio as gr
import nltk
from PIL import Image
import os
from IndicPhotoOCR.ocr import OCR  # Ensure OCR class is saved in a file named ocr.py
from IndicPhotoOCR.theme import Seafoam
import numpy as np
import torch
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
)
from IndicTransToolkit import IndicProcessor

DEVICE = "cpu"

# Initialize the OCR object for text detection and recognition
ocr = OCR(device="cpu", verbose=False)
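
# The translate() helper below wraps the AI4Bharat IndicTrans2 checkpoints: it picks the
# English-to-Indic model when `lang` is "english" and the Indic-to-English model for any
# other value (e.g. the script label returned by the OCR). A minimal usage sketch,
# assuming the checkpoints download successfully:
#     hindi_text = translate("The bus leaves at nine.", "english")   # English -> Hindi
#     english_text = translate("बस नौ बजे निकलती है", "hindi")          # Hindi -> English
# Note that the model and tokenizer are reloaded on every call; callers translating many
# strings may want to cache them outside the function.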
def translate(given_str, lang):
    model_name = (
        "ai4bharat/indictrans2-en-indic-1B"
        if lang == "english"
        else "ai4bharat/indictrans2-indic-en-1B"
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=True)
    ip = IndicProcessor(inference=True)
    model = model.to(DEVICE)
    model.eval()
    src_lang, tgt_lang = ("eng_Latn", "hin_Deva") if lang == "english" else ("hin_Deva", "eng_Latn")
    batch = ip.preprocess_batch(
        [given_str],
        src_lang=src_lang,
        tgt_lang=tgt_lang,
    )
    inputs = tokenizer(
        batch,
        truncation=True,
        padding="longest",
        return_tensors="pt",
        return_attention_mask=True,
    ).to(DEVICE)
    with torch.no_grad():
        generated_tokens = model.generate(
            **inputs,
            use_cache=True,
            min_length=0,
            max_length=256,
            num_beams=5,
            num_return_sequences=1,
        )
    # Decode the generated tokens into text
    with tokenizer.as_target_tokenizer():
        generated_tokens = tokenizer.batch_decode(
            generated_tokens.detach().cpu().tolist(),
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
    translation = ip.postprocess_batch(generated_tokens, lang=tgt_lang)[0]
    return translation
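

# detect_para() groups word-level OCR results into lines and paragraphs using simple
# geometric heuristics on the word boxes (centroid distances and gaps measured relative
# to the word height). It expects a dict shaped like the one built in process_image(), e.g.:
#     {
#         "img_0": {"txt": "HELLO", "bbox": [x1, y1, x2, y2]},
#         "img_1": {"txt": "WORLD", "bbox": [x1, y1, x2, y2]},
#     }
# and returns the recognised words re-ordered into reading order, one paragraph per line.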
def detect_para(bbox_dict):
    alpha1 = 0.2
    alpha2 = 0.7
    beta1 = 0.4
    data = bbox_dict
    word_crops = list(data.keys())
    for i in word_crops:
        data[i]["x1"], data[i]["y1"], data[i]["x2"], data[i]["y2"] = data[i]["bbox"]
        data[i]["xc"] = (data[i]["x1"] + data[i]["x2"]) / 2
        data[i]["yc"] = (data[i]["y1"] + data[i]["y2"]) / 2
        data[i]["w"] = data[i]["x2"] - data[i]["x1"]
        data[i]["h"] = data[i]["y2"] - data[i]["y1"]
    patch_info = {}
    while word_crops:
        img_name = word_crops[0].split("_")[0]
        word_crop_collection = [
            word_crop for word_crop in word_crops if word_crop.startswith(img_name)
        ]
        centroids = {}
        lines = []
        img_word_crops = word_crop_collection.copy()
        para = []
        while img_word_crops:
            clusters = []
            para_words_group = [img_word_crops[0]]
            added = [img_word_crops[0]]
            img_word_crops.remove(img_word_crops[0])
            ## determining the paragraph
            while added:
                word_crop = added.pop()
                for i in range(len(img_word_crops)):
                    word_crop_ = img_word_crops[i]
                    if (
                        abs(data[word_crop_]["yc"] - data[word_crop]["yc"])
                        < data[word_crop]["h"] * alpha1
                    ):
                        # Roughly the same line: check the horizontal gap between the words
                        if data[word_crop]["xc"] > data[word_crop_]["xc"]:
                            if (data[word_crop]["x1"] - data[word_crop_]["x2"]) < data[word_crop]["h"] * alpha2:
                                para_words_group.append(word_crop_)
                                added.append(word_crop_)
                        else:
                            if (data[word_crop_]["x1"] - data[word_crop]["x2"]) < data[word_crop]["h"] * alpha2:
                                para_words_group.append(word_crop_)
                                added.append(word_crop_)
                    else:
                        # Different lines: check the vertical gap and horizontal overlap
                        if data[word_crop]["yc"] > data[word_crop_]["yc"]:
                            if (data[word_crop]["y1"] - data[word_crop_]["y2"]) < data[word_crop]["h"] * beta1 and (
                                (
                                    (data[word_crop_]["x1"] < data[word_crop]["x2"])
                                    and (data[word_crop_]["x1"] > data[word_crop]["x1"])
                                )
                                or (
                                    (data[word_crop_]["x2"] < data[word_crop]["x2"])
                                    and (data[word_crop_]["x2"] > data[word_crop]["x1"])
                                )
                                or (
                                    (data[word_crop]["x1"] > data[word_crop_]["x1"])
                                    and (data[word_crop]["x2"] < data[word_crop_]["x2"])
                                )
                            ):
                                para_words_group.append(word_crop_)
                                added.append(word_crop_)
                        else:
                            if (data[word_crop_]["y1"] - data[word_crop]["y2"]) < data[word_crop]["h"] * beta1 and (
                                (
                                    (data[word_crop_]["x1"] < data[word_crop]["x2"])
                                    and (data[word_crop_]["x1"] > data[word_crop]["x1"])
                                )
                                or (
                                    (data[word_crop_]["x2"] < data[word_crop]["x2"])
                                    and (data[word_crop_]["x2"] > data[word_crop]["x1"])
                                )
                                or (
                                    (data[word_crop]["x1"] > data[word_crop_]["x1"])
                                    and (data[word_crop]["x2"] < data[word_crop_]["x2"])
                                )
                            ):
                                para_words_group.append(word_crop_)
                                added.append(word_crop_)
            img_word_crops = [p for p in img_word_crops if p not in para_words_group]
            ## processing for the line
            while para_words_group:
                line_words_group = [para_words_group[0]]
                added = [para_words_group[0]]
                para_words_group.remove(para_words_group[0])
                ## determining the line
                while added:
                    word_crop = added.pop()
                    for i in range(len(para_words_group)):
                        word_crop_ = para_words_group[i]
                        if (
                            abs(data[word_crop_]["yc"] - data[word_crop]["yc"])
                            < data[word_crop]["h"] * alpha1
                        ):
                            if data[word_crop]["xc"] > data[word_crop_]["xc"]:
                                if (data[word_crop]["x1"] - data[word_crop_]["x2"]) < data[word_crop]["h"] * alpha2:
                                    line_words_group.append(word_crop_)
                                    added.append(word_crop_)
                            else:
                                if (data[word_crop_]["x1"] - data[word_crop]["x2"]) < data[word_crop]["h"] * alpha2:
                                    line_words_group.append(word_crop_)
                                    added.append(word_crop_)
                para_words_group = [
                    p for p in para_words_group if p not in line_words_group
                ]
                # Sort the words of the line left to right by centroid x
                xc = [data[word_crop]["xc"] for word_crop in line_words_group]
                idxs = np.argsort(xc)
                line_words_group = [line_words_group[i] for i in idxs]
                x1 = [data[word_crop]["x1"] for word_crop in line_words_group]
                x2 = [data[word_crop]["x2"] for word_crop in line_words_group]
                y1 = [data[word_crop]["y1"] for word_crop in line_words_group]
                y2 = [data[word_crop]["y2"] for word_crop in line_words_group]
                txt_line = [data[word_crop]["txt"] for word_crop in line_words_group]
                txt = " ".join(txt_line)
                x = [x1[0]]
                y1_ = [y1[0]]
                y2_ = [y2[0]]
                l = [len(txt_l) for txt_l in txt_line]
                for i in range(1, len(x1)):
                    x.append((x1[i] + x2[i - 1]) / 2)
                    y1_.append((y1[i] + y1[i - 1]) / 2)
                    y2_.append((y2[i] + y2[i - 1]) / 2)
                x.append(x2[-1])
                y1_.append(y1[-1])
                y2_.append(y2[-1])
                line_info = {
                    "x": x,
                    "y1": y1_,
                    "y2": y2_,
                    "l": l,
                    "txt": txt,
                    "word_crops": line_words_group,
                }
                clusters.append(line_info)
            # Order the lines of the paragraph top to bottom and join their text
            y_ = [clusters[i]["y1"][0] for i in range(len(clusters))]
            idxs = np.argsort(y_)
            clusters_ = [clusters[i] for i in idxs]
            txt = [clusters[i]["txt"] for i in idxs]
            l = [len(t) for t in txt]
            txt = " ".join(txt)
            para_info = {"lines": clusters_, "l": l, "txt": txt}
            para.append(para_info)
        for word_crop in word_crop_collection:
            word_crops.remove(word_crop)
    return "\n".join([para[i]["txt"] for i in range(len(para))])
def process_image(image):
    """
    Processes the uploaded image for text detection and recognition.
    - Detects bounding boxes in the image
    - Draws bounding boxes on the image and identifies the script in each detected area
    - Recognizes text in each cropped region and returns the annotated image and translated text

    Parameters:
        image (PIL.Image): The input image to be processed.

    Returns:
        tuple: A PIL.Image with bounding boxes and a string of translated text.
    """
    # Save the input image temporarily
    image_path = "input_image.jpg"
    image.save(image_path)

    # Detect bounding boxes on the image using OCR
    detections = ocr.detect(image_path)

    # Draw bounding boxes on the image and save it as output
    ocr.visualize_detection(image_path, detections, save_path="output_image.png")

    # Load the annotated image with bounding boxes drawn
    output_image = Image.open("output_image.png")

    # Initialize a dict to hold the recognized text from each detected area
    recognized_texts = {}
    pil_image = Image.open(image_path)
    script_lang = "english"

    # Process each detected bounding box for script identification and text recognition
    for id, bbox in enumerate(detections):
        # Identify the script and crop the image to this region
        script_lang, cropped_path = ocr.crop_and_identify_script(pil_image, bbox)
        x1 = min([bbox[i][0] for i in range(len(bbox))])
        y1 = min([bbox[i][1] for i in range(len(bbox))])
        x2 = max([bbox[i][0] for i in range(len(bbox))])
        y2 = max([bbox[i][1] for i in range(len(bbox))])
        if script_lang:
            recognized_text = ocr.recognise(cropped_path, script_lang)
            recognized_texts[f"img_{id}"] = {"txt": recognized_text, "bbox": [x1, y1, x2, y2]}

    # Reorder the recognised words into paragraphs and translate them
    translated = translate(detect_para(recognized_texts), script_lang)
    return output_image, translated
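

# A quick way to exercise the pipeline outside Gradio (hypothetical local test, assuming a
# readable image at the path below, which is one of the bundled examples):
#     annotated, text = process_image(Image.open("test_images/208.jpg"))
#     print(text)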
# Custom HTML for interface header with logos and alignment
interface_html = """
<div style="text-align: left; padding: 10px;">
    <div style="background-color: white; padding: 10px; display: inline-block;">
        <img src="https://iitj.ac.in/images/logo/Design-of-New-Logo-of-IITJ-2.png" alt="IITJ Logo" style="width: 100px; height: 100px;">
    </div>
    <img src="https://play-lh.googleusercontent.com/_FXSr4xmhPfBykmNJvKvC0GIAVJmOLhFl6RA5fobCjV-8zVSypxX8yb8ka6zu6-4TEft=w240-h480-rw" alt="Bhashini Logo" style="width: 100px; height: 100px; float: right;">
</div>
"""

# Links to the GitHub repository and project page
links_html = """
<div style="text-align: center; padding-top: 20px;">
    <a href="https://github.com/Bhashini-IITJ/visualTranslation" target="_blank" style="margin-right: 20px; font-size: 18px; text-decoration: none;">
        GitHub Repository
    </a>
    <a href="https://vl2g.github.io/projects/visTrans" target="_blank" style="font-size: 18px; text-decoration: none;">
        Project Page
    </a>
</div>
"""

# Custom CSS to style the text box font size
custom_css = """
.custom-textbox textarea {
    font-size: 20px !important;
}
"""

# Create an instance of the Seafoam theme for a consistent visual style
seafoam = Seafoam()

# Define examples for users to try out
examples = [
    ["test_images/208.jpg"],
    ["test_images/1310.jpg"],
]

title = "<h1 style='text-align: center;'>Developed by IITJ</h1>"
# Set up the Gradio Interface with the defined function and customizations
demo = gr.Interface(
    allow_flagging="never",
    fn=process_image,
    inputs=gr.Image(type="pil", image_mode="RGB"),
    outputs=[
        gr.Image(type="pil", label="Detected Bounding Boxes"),
        gr.Textbox(label="Translated Text", elem_classes="custom-textbox"),
    ],
    title="IndicPhotoOCR - Indic Scene Text Recogniser Toolkit",
    description=title + interface_html + links_html,
    theme=seafoam,
    css=custom_css,
    examples=examples,
)

# Server setup and launch configuration
# if __name__ == "__main__":
#     server = "0.0.0.0"  # IP address for the server
#     port = 7865         # Port to run the server on
#     demo.launch(server_name=server, server_port=port)
demo.launch()