import math
import os
import time
from pathlib import Path
from typing import List, Optional, Tuple

import cv2
import numpy as np
import onnxruntime
from transformers import TrOCRProcessor, VisionEncoderDecoderConfig


class TextRecognition:
    """
    ONNX-based text recognition class using TrOCR for handwritten text recognition.

    Processes text line images through an encoder-decoder architecture,
    supporting batch processing and CUDA acceleration.

    Args:
        model_path: Path to the model directory containing "encoder_model.onnx",
            "decoder_model.onnx" and the Hugging Face config/processor files
        device: Device identifier (default: 'cuda:0'). 'cuda' or 'cuda:N' runs the
            models on the CUDA execution provider for GPU N (with a CPU fallback);
            any other value (e.g. 'cpu') uses the CPU execution provider only.
        batch_size: Number of lines to process in parallel (default: 10); must be > 0
        img_height: Target height for input images (default: 192)
        img_width: Target width for input images (default: 1024)
        max_length: Maximum number of decoding steps per batch (default: 128)

    Raises:
        FileNotFoundError: If model_path or one of the ONNX model files does not exist
        ValueError: If batch_size is not positive or device is malformed
        RuntimeError: If the processor or the ONNX models fail to load
    """

    def __init__(self, model_path: str, device: str = 'cuda:0', batch_size: int = 10,
                 img_height: int = 192, img_width: int = 1024, max_length: int = 128):
        self.model_path = model_path
        self.device = device
        self.batch_size = batch_size
        self.img_height = img_height
        self.img_width = img_width
        self.max_length = max_length

        # Validate model path
        if not os.path.exists(self.model_path):
            raise FileNotFoundError(f"Model path does not exist: {model_path}")

        # get_text_lines uses batch_size as a range() step: 0 would raise there and a
        # negative value would silently process nothing, so reject it up front.
        if self.batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        self.init_processor()
        self.init_recognition_model()

    def init_processor(self) -> None:
        """
        Initialize the TrOCR processor with custom image dimensions.

        The processor is configured to resize every input image to
        (img_height, img_width).

        Raises:
            RuntimeError: If processor initialization fails
        """
        try:
            self.processor = TrOCRProcessor.from_pretrained(
                str(self.model_path),
                use_fast=True,
                do_resize=True,
                size={
                    'height': self.img_height,
                    'width': self.img_width
                }
            )
            print(f"✓ Processor loaded with custom image size: {self.img_height}x{self.img_width}")
        except Exception as e:
            raise RuntimeError(f'Failed to initialize processor: {e}') from e

    def _build_providers(self) -> list:
        """
        Translate ``self.device`` into an ONNX Runtime provider list.

        'cuda' / 'cuda:N' -> CUDA provider on GPU N (0 when omitted) followed by the
        CPU provider as fallback; any other value -> CPU provider only.

        Raises:
            ValueError: If a 'cuda:' device has a non-numeric index
        """
        device = str(self.device).strip().lower()
        if not device.startswith('cuda'):
            return ['CPUExecutionProvider']

        _, _, index = device.partition(':')
        if index and not index.isdigit():
            raise ValueError(f"Invalid CUDA device identifier: {self.device}")
        device_id = int(index) if index else 0
        return [
            ('CUDAExecutionProvider', {'device_id': device_id}),
            'CPUExecutionProvider',
        ]

    def init_recognition_model(self) -> None:
        """
        Initialize the ONNX encoder and decoder models with optimized settings.

        Also loads the VisionEncoderDecoderConfig from model_path into
        ``self.config`` (used for the special token ids during generation).

        Raises:
            FileNotFoundError: If model files are not found
            ValueError: If the device identifier is malformed
            RuntimeError: If model loading fails
        """
        encoder_path = os.path.join(self.model_path, "encoder_model.onnx")
        decoder_path = os.path.join(self.model_path, "decoder_model.onnx")

        if not os.path.exists(encoder_path):
            raise FileNotFoundError(f"Encoder model not found: {encoder_path}")
        if not os.path.exists(decoder_path):
            raise FileNotFoundError(f"Decoder model not found: {decoder_path}")

        # Session options for better performance
        sess_options = onnxruntime.SessionOptions()
        sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        sess_options.intra_op_num_threads = 4

        # Derive providers from the requested device instead of always asking for CUDA.
        providers = self._build_providers()

        # Load model config
        self.config = VisionEncoderDecoderConfig.from_pretrained(str(self.model_path))

        try:
            print("Loading encoder...")
            self.encoder = onnxruntime.InferenceSession(
                str(encoder_path),
                sess_options=sess_options,
                providers=providers
            )

            print("Loading decoder...")
            self.decoder = onnxruntime.InferenceSession(
                str(decoder_path),
                sess_options=sess_options,
                providers=providers
            )

            # Report which provider is actually being used (ORT may have fallen back)
            encoder_provider = self.encoder.get_providers()[0]
            decoder_provider = self.decoder.get_providers()[0]
            print(f"✓ Using execution provider: Encoder={encoder_provider}, Decoder={decoder_provider}")
        except Exception as e:
            raise RuntimeError(f'Failed to load recognition models: {e}') from e

    def crop_line(self, image: np.ndarray, polygon: List[List[float]]) -> Optional[np.ndarray]:
        """
        Crop a text line from an image based on polygon coordinates.

        Creates a masked crop where the polygon area contains the original image
        and the background is filled with white pixels. Polygon points lying
        outside the image are clipped to the image bounds.

        Args:
            image: Source image as numpy array (H x W or H x W x C)
            polygon: List of [x, y] coordinate pairs defining the text line region

        Returns:
            Cropped and masked text line image, or None if polygon is invalid
        """
        if len(polygon) == 0:
            print("Warning: Empty polygon")
            return None

        # Convert polygon to integer coordinates
        polygon_array = np.array([[int(pt[0]), int(pt[1])] for pt in polygon], dtype=np.int32)

        # Get bounding rectangle
        rect = cv2.boundingRect(polygon_array)
        x, y, w, h = rect

        # Validate rectangle
        if w <= 0 or h <= 0:
            print(f"Warning: Invalid bounding rect dimensions: {w}x{h}")
            return None

        # Clip the rectangle to the image: a negative start index would otherwise be
        # interpreted by numpy as counting from the end and produce a wrong/empty crop.
        img_h, img_w = image.shape[:2]
        x0, y0 = max(x, 0), max(y, 0)
        x1, y1 = min(x + w, img_w), min(y + h, img_h)

        cropped_image = image[y0:y1, x0:x1]
        if cropped_image.size == 0:
            print(f"Warning: Empty cropped image at rect {rect}")
            return None

        # Rasterize the polygon (relative to the clipped crop origin) into a mask
        mask = np.zeros(cropped_image.shape[:2], dtype=np.uint8)
        polygon_offset = polygon_array - np.array([[x0, y0]], dtype=np.int32)
        cv2.drawContours(mask, [polygon_offset], -1, (255, 255, 255), -1, cv2.LINE_AA)

        # Keep original pixels inside the polygon, white everywhere else
        inside = mask > 0
        result = np.full_like(cropped_image, 255)
        result[inside] = cropped_image[inside]

        return result

    def crop_lines(self, polygons: List[List[List[float]]], image: np.ndarray) -> List[np.ndarray]:
        """
        Crop multiple text lines from an image.

        Polygons whose crop fails are skipped (with a warning), so the returned
        list can be shorter than ``polygons`` and its indices may not line up
        with the input polygons.

        Args:
            polygons: List of polygon coordinate lists
            image: Source image

        Returns:
            List of cropped text line images (excluding any failed crops)
        """
        cropped_lines = []
        for i, polygon in enumerate(polygons):
            cropped_line = self.crop_line(image, polygon)
            if cropped_line is not None:
                cropped_lines.append(cropped_line)
            else:
                print(f"Warning: Failed to crop line {i}")

        return cropped_lines

    def encode(self, pixel_values: np.ndarray) -> np.ndarray:
        """
        Encode image pixel values into hidden states using the vision encoder.

        Args:
            pixel_values: Preprocessed image tensor from TrOCRProcessor
                Shape: (batch_size, channels, height, width)

        Returns:
            Encoder hidden states (first output of the encoder session), used as
            input to the decoder. Presumably (batch_size, sequence_length,
            hidden_size) — depends on the exported ONNX graph.

        Raises:
            RuntimeError: If encoding fails
        """
        try:
            encoder_outputs = self.encoder.run(
                None,
                {"pixel_values": pixel_values}
            )[0]
            return encoder_outputs
        except Exception as e:
            raise RuntimeError(f'Failed to encode input: {e}') from e

    def _special_token_id(self, name: str):
        """
        Look up a special token id (e.g. 'eos_token_id') in the model config.

        The top-level VisionEncoderDecoderConfig is consulted first; if it does not
        define the id, the nested decoder config is used instead.

        Raises:
            ValueError: If neither config defines the requested id
        """
        token_id = getattr(self.config, name, None)
        if token_id is None:
            token_id = getattr(getattr(self.config, 'decoder', None), name, None)
        if token_id is None:
            raise ValueError(f"Model config does not define '{name}'")
        return token_id

    def generate(self, encoder_outputs: np.ndarray, batch_size: int) -> np.ndarray:
        """
        Generate text tokens using greedy autoregressive decoding with early stopping.

        Once a sequence emits EOS, every later position of that sequence is filled
        with PAD while the other sequences keep decoding; the loop stops when all
        sequences have finished or after ``max_length`` steps.

        Args:
            encoder_outputs: Hidden states from the encoder
            batch_size: Number of sequences in the batch

        Returns:
            Generated token IDs of shape (batch_size, 1 + steps): the decoder start
            token, the predicted tokens, the EOS token (if emitted within
            max_length), then PAD for sequences that finished early.

        Raises:
            RuntimeError: If generation fails (including missing special token ids)
        """
        try:
            start_token_id = self._special_token_id('decoder_start_token_id')
            eos_token_id = self._special_token_id('eos_token_id')
            pad_token_id = self._special_token_id('pad_token_id')

            # Initialize decoder input with start tokens
            decoder_input_ids = np.full((batch_size, 1), start_token_id, dtype=np.int64)

            # Track which sequences have already emitted EOS
            finished = np.zeros(batch_size, dtype=bool)

            for _ in range(self.max_length):
                # The full prefix is re-fed every step (no KV cache is used here)
                logits = self.decoder.run(
                    None,
                    {
                        "input_ids": decoder_input_ids,
                        "encoder_hidden_states": encoder_outputs
                    }
                )[0]

                # Greedy choice of the next token for each sequence
                next_tokens = np.argmax(logits[:, -1, :], axis=-1).astype(np.int64, copy=False)

                # Pad only sequences that finished on an EARLIER step, so the EOS
                # token produced on this step is preserved in the output.
                next_tokens = np.where(finished, pad_token_id, next_tokens)
                finished |= np.isin(next_tokens, eos_token_id)

                decoder_input_ids = np.concatenate([decoder_input_ids, next_tokens[:, None]], axis=1)

                # Stop when all sequences have finished
                if finished.all():
                    break

            return decoder_input_ids
        except Exception as e:
            raise RuntimeError(f'Failed to generate output ids: {e}') from e

    def predict_text(self, cropped_lines: List[np.ndarray]) -> List[str]:
        """
        Predict text content from cropped line images (one batch).

        NOTE(review): crops coming from OpenCV are usually BGR while the processor
        presumably expects RGB — confirm the channel order of the source image.

        Args:
            cropped_lines: List of cropped text line images

        Returns:
            List of predicted text strings (empty list for empty input)

        Raises:
            RuntimeError: If prediction fails
        """
        if not cropped_lines:
            return []

        try:
            # Process image with TrOCR processor
            # Use 'pt' (PyTorch) then convert to numpy, as 'np' is not supported by fast processors
            pixel_values = self.processor(cropped_lines, return_tensors="pt").pixel_values
            pixel_values = pixel_values.numpy()
            batch_size = pixel_values.shape[0]

            # Encode images to hidden states
            encoder_hidden_states = self.encode(pixel_values)

            # Generate token sequences
            generated_ids = self.generate(encoder_hidden_states, batch_size)

            # Decode tokens to text (special tokens such as start/EOS/PAD are dropped)
            texts = self.processor.batch_decode(
                generated_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False
            )

            return texts
        except Exception as e:
            raise RuntimeError(f'Failed to predict text: {e}') from e

    def get_text_lines(self, cropped_lines: List[np.ndarray]) -> List[str]:
        """
        Process text lines in batches of ``batch_size`` to bound memory use.

        Args:
            cropped_lines: List of all cropped line images

        Returns:
            List of predicted text strings for all lines, in input order
        """
        generated_text = []

        for start in range(0, len(cropped_lines), self.batch_size):
            batch = cropped_lines[start:start + self.batch_size]
            generated_text.extend(self.predict_text(batch))

        return generated_text

    def process_lines(self, polygons: List[List[List[float]]], image: np.ndarray) -> List[str]:
        """
        Complete pipeline: crop text lines and predict their content.

        Args:
            polygons: List of polygon coordinate lists defining text line regions
            image: Source document image

        Returns:
            List of predicted text strings for each valid line. Lines whose crop
            failed are omitted, so the result may be shorter than ``polygons``.
        """
        # Crop line images from the document
        cropped_lines = self.crop_lines(polygons, image)

        if not cropped_lines:
            print("Warning: No valid cropped lines to process")
            return []

        # Get text predictions for all lines
        return self.get_text_lines(cropped_lines)