import os
from typing import List, Optional

import cv2
import numpy as np
import onnxruntime
from transformers import TrOCRProcessor, VisionEncoderDecoderConfig


class TextRecognition:
    """
    ONNX-based text recognition class using TrOCR for handwritten text recognition.

    Processes text line images through an encoder-decoder architecture, supporting
    batch processing and CUDA acceleration.

    Args:
        model_path: Path to the model directory containing the ONNX models and config
        device: Device identifier (default: 'cuda:0')
        batch_size: Number of lines to process in parallel (default: 10)
        img_height: Target height for input images (default: 192)
        img_width: Target width for input images (default: 1024)
        max_length: Maximum sequence length for generation (default: 128)
    """

    def __init__(self,
                 model_path: str,
                 device: str = 'cuda:0',
                 batch_size: int = 10,
                 img_height: int = 192,
                 img_width: int = 1024,
                 max_length: int = 128):
        self.model_path = model_path
        # Informational for now; execution providers are configured in
        # init_recognition_model.
        self.device = device
        self.batch_size = batch_size
        self.img_height = img_height
        self.img_width = img_width
        self.max_length = max_length

        # Validate model path
        if not os.path.exists(self.model_path):
            raise FileNotFoundError(f"Model path does not exist: {model_path}")

        self.init_processor()
        self.init_recognition_model()

    def init_processor(self) -> None:
        """
        Initialize the TrOCR processor with custom image dimensions.

        Raises:
            RuntimeError: If processor initialization fails
        """
        try:
            self.processor = TrOCRProcessor.from_pretrained(
                str(self.model_path),
                use_fast=True,
                do_resize=True,
                size={
                    'height': self.img_height,
                    'width': self.img_width
                }
            )
            print(f"✓ Processor loaded with custom image size: {self.img_height}x{self.img_width}")
        except Exception as e:
            raise RuntimeError(f'Failed to initialize processor: {e}') from e

    def init_recognition_model(self) -> None:
        """
        Initialize the ONNX encoder and decoder models with optimized settings.

        Raises:
            FileNotFoundError: If model files are not found
            RuntimeError: If model loading fails
        """
        encoder_path = os.path.join(self.model_path, "encoder_model.onnx")
        decoder_path = os.path.join(self.model_path, "decoder_model.onnx")

        if not os.path.exists(encoder_path):
            raise FileNotFoundError(f"Encoder model not found: {encoder_path}")
        if not os.path.exists(decoder_path):
            raise FileNotFoundError(f"Decoder model not found: {decoder_path}")

        # Session options for better performance
        sess_options = onnxruntime.SessionOptions()
        sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        sess_options.intra_op_num_threads = 4

        # Prefer CUDA when available; onnxruntime falls back to CPU otherwise
        providers = [
            'CUDAExecutionProvider',
            'CPUExecutionProvider'
        ]

        # Load model config
        self.config = VisionEncoderDecoderConfig.from_pretrained(str(self.model_path))

        try:
            print("Loading encoder...")
            self.encoder = onnxruntime.InferenceSession(
                str(encoder_path),
                sess_options=sess_options,
                providers=providers
            )
            print("Loading decoder...")
            self.decoder = onnxruntime.InferenceSession(
                str(decoder_path),
                sess_options=sess_options,
                providers=providers
            )
            # Report which provider is actually being used
            encoder_provider = self.encoder.get_providers()[0]
            decoder_provider = self.decoder.get_providers()[0]
            print(f"✓ Using execution provider: Encoder={encoder_provider}, Decoder={decoder_provider}")
        except Exception as e:
            raise RuntimeError(f'Failed to load recognition models: {e}') from e
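
    # Sketch (assumption, not wired in above): self.device could be honored by
    # passing provider options instead of the bare provider name, e.g.
    #   providers = [('CUDAExecutionProvider', {'device_id': 0}),
    #                'CPUExecutionProvider']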

    def crop_line(self, image: np.ndarray, polygon: List[List[float]]) -> Optional[np.ndarray]:
        """
        Crop a text line from an image based on polygon coordinates.

        Creates a masked crop where the polygon area contains the original image
        and the background is filled with white pixels.

        Args:
            image: Source image as numpy array
            polygon: List of [x, y] coordinate pairs defining the text line region

        Returns:
            Cropped and masked text line image, or None if the polygon is invalid
        """
        # Convert polygon to integer coordinates
        polygon_array = np.array([[int(pt[0]), int(pt[1])] for pt in polygon], dtype=np.int32)

        # Get bounding rectangle
        rect = cv2.boundingRect(polygon_array)
        x, y, w, h = rect

        # Validate rectangle
        if w <= 0 or h <= 0:
            print(f"Warning: Invalid bounding rect dimensions: {w}x{h}")
            return None

        # Crop image to bounding rectangle
        cropped_image = image[y:y + h, x:x + w]
        if cropped_image.size == 0:
            print(f"Warning: Empty cropped image at rect {rect}")
            return None

        # Create mask for the polygon region
        mask = np.zeros(cropped_image.shape[:2], dtype=np.uint8)

        # Adjust polygon coordinates relative to the cropped region
        polygon_offset = polygon_array - np.array([[x, y]])
        cv2.drawContours(mask, [polygon_offset], -1, (255, 255, 255), -1, cv2.LINE_AA)

        # Extract the polygon region from the cropped image
        masked_region = cv2.bitwise_and(cropped_image, cropped_image, mask=mask)

        # Create white background, zeroed inside the polygon
        white_background = np.ones_like(cropped_image, np.uint8) * 255
        cv2.bitwise_not(white_background, white_background, mask=mask)

        # Overlay the masked region on the white background
        result = white_background + masked_region
        return result
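
    # Pixel-level sketch of the masking above: inside the polygon mask=255, so
    # bitwise_not zeroes white_background there while bitwise_and keeps the
    # source pixel; outside, mask=0, so masked_region is 0 and white_background
    # stays 255. The sum is thus the original line on a white page, and it
    # cannot overflow because the two terms are never both nonzero.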

    def crop_lines(self, polygons: List[List[List[float]]], image: np.ndarray) -> List[np.ndarray]:
        """
        Crop multiple text lines from an image.

        Args:
            polygons: List of polygon coordinate lists
            image: Source image

        Returns:
            List of cropped text line images (excluding any failed crops)
        """
        cropped_lines = []
        for i, polygon in enumerate(polygons):
            cropped_line = self.crop_line(image, polygon)
            if cropped_line is not None:
                cropped_lines.append(cropped_line)
            else:
                print(f"Warning: Failed to crop line {i}")
        return cropped_lines

    def encode(self, pixel_values: np.ndarray) -> np.ndarray:
        """
        Encode image pixel values into hidden states using the vision encoder.

        Args:
            pixel_values: Preprocessed image tensor from the TrOCR processor
                Shape: (batch_size, channels, height, width)

        Returns:
            Encoder hidden states for input to the decoder
            Shape: (batch_size, sequence_length, hidden_size)

        Raises:
            RuntimeError: If encoding fails
        """
        try:
            return self.encoder.run(
                None,
                {"pixel_values": pixel_values}
            )[0]
        except Exception as e:
            raise RuntimeError(f'Failed to encode input: {e}') from e

    def generate(self, encoder_outputs: np.ndarray, batch_size: int) -> np.ndarray:
        """
        Generate text tokens using autoregressive greedy decoding with early stopping.

        Implements per-sequence early stopping: sequences that generate EOS tokens
        are padded from then on while the others continue, and the loop exits as
        soon as every sequence has finished.

        Args:
            encoder_outputs: Hidden states from the encoder
                Shape: (batch_size, sequence_length, hidden_size)
            batch_size: Number of sequences in the batch

        Returns:
            Generated token IDs including start and end tokens
            Shape: (batch_size, generated_length)

        Raises:
            RuntimeError: If generation fails
        """
        try:
            # Initialize decoder input with start tokens
            decoder_input_ids = np.full(
                (batch_size, 1),
                self.config.decoder_start_token_id,
                dtype=np.int64
            )

            # Track which sequences have finished
            finished = np.zeros(batch_size, dtype=bool)

            for _ in range(self.max_length):
                # Run decoder to get next-token logits
                decoder_outputs = self.decoder.run(
                    None,
                    {
                        "input_ids": decoder_input_ids,
                        "encoder_hidden_states": encoder_outputs
                    }
                )[0]

                # Greedily pick the most likely next token for each sequence
                next_token_logits = decoder_outputs[:, -1, :]
                next_tokens = np.argmax(next_token_logits, axis=-1)

                # Sequences that finished at an earlier step keep emitting PAD
                next_tokens[finished] = self.config.pad_token_id

                # Mark sequences that just generated EOS as finished, so the
                # EOS token itself is kept in the output
                finished = finished | (next_tokens == self.config.eos_token_id)

                # Append the new tokens to the running sequences
                decoder_input_ids = np.concatenate(
                    [decoder_input_ids, next_tokens.reshape(-1, 1)], axis=1
                )

                # Stop when all sequences have finished
                if np.all(finished):
                    break

            return decoder_input_ids
        except Exception as e:
            raise RuntimeError(f'Failed to generate output ids: {e}') from e
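
    # Note: the loop above re-feeds the full token prefix to the decoder at
    # every step, which matches a decoder exported without past_key_values
    # (no KV cache). If the export also provides a decoder_with_past model,
    # using it would avoid recomputing earlier positions on each step.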

    def predict_text(self, cropped_lines: List[np.ndarray]) -> List[str]:
        """
        Predict text content from cropped line images.

        Args:
            cropped_lines: List of cropped text line images (RGB channel order)

        Returns:
            List of predicted text strings

        Raises:
            RuntimeError: If prediction fails
        """
        try:
            # Preprocess images with the TrOCR processor. Use 'pt' (PyTorch) and
            # convert to numpy, as return_tensors='np' is not supported by fast
            # processors. The processor expects RGB input; images read with
            # cv2.imread are BGR and need cv2.cvtColor first.
            pixel_values = self.processor(cropped_lines, return_tensors="pt").pixel_values
            pixel_values = pixel_values.numpy()
            batch_size = pixel_values.shape[0]

            # Encode images to hidden states
            encoder_hidden_states = self.encode(pixel_values)

            # Generate token sequences
            generated_ids = self.generate(encoder_hidden_states, batch_size)

            # Decode tokens to text
            texts = self.processor.batch_decode(
                generated_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False
            )
            return texts
        except Exception as e:
            raise RuntimeError(f'Failed to predict text: {e}') from e

    def get_text_lines(self, cropped_lines: List[np.ndarray]) -> List[str]:
        """
        Process text lines in batches to keep memory usage bounded.

        Args:
            cropped_lines: List of all cropped line images

        Returns:
            List of predicted text strings for all lines
        """
        generated_text = []

        # Process in batches of self.batch_size
        for i in range(0, len(cropped_lines), self.batch_size):
            batch = cropped_lines[i:i + self.batch_size]
            texts = self.predict_text(batch)
            generated_text.extend(texts)

        return generated_text
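
    # Example: with batch_size=10, 25 cropped lines are processed as batches of
    # 10, 10, and 5, so peak memory is bounded by the largest batch rather than
    # by the total number of lines on the page.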

    def process_lines(self,
                      polygons: List[List[List[float]]],
                      image: np.ndarray) -> List[str]:
        """
        Complete pipeline: crop text lines and predict their content.

        Args:
            polygons: List of polygon coordinate lists defining text line regions
            image: Source document image

        Returns:
            List of predicted text strings for each valid line
        """
        # Crop line images from the document
        cropped_lines = self.crop_lines(polygons, image)
        if not cropped_lines:
            print("Warning: No valid cropped lines to process")
            return []

        # Get text predictions for all lines
        return self.get_text_lines(cropped_lines)
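

# ---------------------------------------------------------------------------
# Usage sketch (illustration only). The model directory, image file, and
# polygon coordinates below are hypothetical placeholders, not assets that
# ship with this code.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    recognizer = TextRecognition(
        model_path="models/trocr-handwritten-onnx",  # hypothetical directory
        batch_size=10,
    )

    image = cv2.imread("sample_page.jpg")  # hypothetical input image
    if image is None:
        raise FileNotFoundError("sample_page.jpg could not be read")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # processor expects RGB

    # One polygon per text line, as [x, y] pairs (hypothetical coordinates)
    polygons = [
        [[50, 100], [900, 100], [900, 160], [50, 160]],
        [[50, 180], [900, 180], [900, 240], [50, 240]],
    ]

    for line in recognizer.process_lines(polygons, image):
        print(line)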