# Multicentury-HTR-Demo / onnx_text_recognition.py
# Latest commit eee8423 (MikkoLipsanen, verified): Update code to use 202509_onnx_small model
from transformers import VisionEncoderDecoderConfig
from typing import List, Tuple, Optional
from transformers import TrOCRProcessor
from pathlib import Path
import numpy as np
import onnxruntime
import math
import time
import cv2
import os
class TextRecognition:
    """
    ONNX-based text recognition class using TrOCR for handwritten text recognition.

    Processes text line images through an encoder-decoder architecture, supporting
    batch processing and CUDA acceleration.

    Args:
        model_path: Path to the model directory containing ONNX models and config
        device: Device identifier, e.g. 'cuda:0', 'cuda:1' or 'cpu' (default: 'cuda:0')
        batch_size: Number of lines to process in parallel (default: 10)
        img_height: Target height for input images (default: 192)
        img_width: Target width for input images (default: 1024)
        max_length: Maximum number of decoding steps per line (default: 128)

    Raises:
        ValueError: If batch_size, img_height, img_width or max_length is not positive
        FileNotFoundError: If model_path (or a model file inside it) does not exist
        RuntimeError: If the processor or the ONNX sessions cannot be loaded
    """

    def __init__(self,
                 model_path: str,
                 device: str = 'cuda:0',
                 batch_size: int = 10,
                 img_height: int = 192,
                 img_width: int = 1024,
                 max_length: int = 128):
        # Fail early on nonsensical sizes; e.g. batch_size=0 would otherwise
        # only surface later as an obscure range() error in get_text_lines.
        for name, value in (('batch_size', batch_size),
                            ('img_height', img_height),
                            ('img_width', img_width),
                            ('max_length', max_length)):
            if value <= 0:
                raise ValueError(f"{name} must be a positive integer, got {value}")

        self.model_path = model_path
        self.device = device
        self.batch_size = batch_size
        self.img_height = img_height
        self.img_width = img_width
        self.max_length = max_length

        # Validate model path
        if not os.path.exists(self.model_path):
            raise FileNotFoundError(f"Model path does not exist: {model_path}")

        self.init_processor()
        self.init_recognition_model()

    def init_processor(self) -> None:
        """
        Initialize the TrOCR processor with custom image dimensions.

        The processor is loaded from the model directory and configured to
        resize inputs to (img_height, img_width).

        Raises:
            RuntimeError: If processor initialization fails
        """
        try:
            self.processor = TrOCRProcessor.from_pretrained(
                str(self.model_path),
                use_fast=True,
                do_resize=True,
                size={
                    'height': self.img_height,
                    'width': self.img_width
                }
            )
            print(f"✓ Processor loaded with custom image size: {self.img_height}x{self.img_width}")
        except Exception as e:
            raise RuntimeError(f'Failed to initialize processor: {e}') from e

    def _execution_providers(self) -> list:
        """
        Translate ``self.device`` into an ONNX Runtime provider list.

        'cpu' selects the CPU provider only. Any other value requests the CUDA
        provider first (using the index after ':' as device id, 0 if absent or
        not numeric) and lists the CPU provider after it as a fallback.

        Returns:
            Provider list suitable for ``onnxruntime.InferenceSession(providers=...)``
        """
        device = str(self.device).strip().lower()
        if device.startswith('cpu'):
            return ['CPUExecutionProvider']

        _, separator, index = device.partition(':')
        device_id = int(index) if separator and index.isdigit() else 0
        return [
            ('CUDAExecutionProvider', {'device_id': device_id}),
            'CPUExecutionProvider'
        ]

    def init_recognition_model(self) -> None:
        """
        Initialize the ONNX encoder and decoder models with optimized settings.

        Expects 'encoder_model.onnx' and 'decoder_model.onnx' inside the model
        directory, and loads the VisionEncoderDecoder config from the same place.

        Raises:
            FileNotFoundError: If model files are not found
            RuntimeError: If model loading fails
        """
        encoder_path = os.path.join(self.model_path, "encoder_model.onnx")
        decoder_path = os.path.join(self.model_path, "decoder_model.onnx")

        if not os.path.exists(encoder_path):
            raise FileNotFoundError(f"Encoder model not found: {encoder_path}")
        if not os.path.exists(decoder_path):
            raise FileNotFoundError(f"Decoder model not found: {decoder_path}")

        # Session options for better performance
        sess_options = onnxruntime.SessionOptions()
        sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        sess_options.intra_op_num_threads = 4

        # Honour the requested device instead of always asking for CUDA
        providers = self._execution_providers()

        # Load model config (provides start/EOS/PAD token ids for generation)
        self.config = VisionEncoderDecoderConfig.from_pretrained(str(self.model_path))

        try:
            print("Loading encoder...")
            self.encoder = onnxruntime.InferenceSession(
                encoder_path,
                sess_options=sess_options,
                providers=providers
            )
            print("Loading decoder...")
            self.decoder = onnxruntime.InferenceSession(
                decoder_path,
                sess_options=sess_options,
                providers=providers
            )
            # Report which provider is actually being used
            encoder_provider = self.encoder.get_providers()[0]
            decoder_provider = self.decoder.get_providers()[0]
            print(f"✓ Using execution provider: Encoder={encoder_provider}, Decoder={decoder_provider}")
        except Exception as e:
            raise RuntimeError(f'Failed to load recognition models: {e}') from e

    def crop_line(self, image: np.ndarray, polygon: List[List[float]]) -> Optional[np.ndarray]:
        """
        Crop a text line from an image based on polygon coordinates.

        Creates a masked crop where the polygon area contains the original image
        and the background is filled with white pixels. The crop window is the
        polygon's bounding rectangle clamped to the image bounds.

        Args:
            image: Source image as numpy array
            polygon: List of [x, y] coordinate pairs defining the text line region

        Returns:
            Cropped and masked text line image, or None if polygon is invalid
            (empty, degenerate, or entirely outside the image)
        """
        # An empty polygon would make cv2.boundingRect raise
        if polygon is None or len(polygon) == 0:
            print("Warning: Empty polygon")
            return None

        # Convert polygon to integer coordinates
        polygon_array = np.array([[int(pt[0]), int(pt[1])] for pt in polygon], dtype=np.int32)

        # Get bounding rectangle
        rect = cv2.boundingRect(polygon_array)
        x, y, w, h = rect

        # Validate rectangle
        if w <= 0 or h <= 0:
            print(f"Warning: Invalid bounding rect dimensions: {w}x{h}")
            return None

        # Clamp the window to the image: negative starts would otherwise be
        # interpreted by numpy as indices counted from the end of the axis.
        img_h, img_w = image.shape[:2]
        x0, y0 = max(x, 0), max(y, 0)
        x1, y1 = min(x + w, img_w), min(y + h, img_h)
        if x1 <= x0 or y1 <= y0:
            print(f"Warning: Empty cropped image at rect {rect}")
            return None

        # Crop image to the (clamped) bounding rectangle
        cropped_image = image[y0:y1, x0:x1]

        # Create mask for the polygon region
        mask = np.zeros(cropped_image.shape[:2], dtype=np.uint8)

        # Adjust polygon coordinates relative to the cropped region
        polygon_offset = polygon_array - np.array([[x0, y0]], dtype=np.int32)
        cv2.drawContours(mask, [polygon_offset], -1, (255, 255, 255), -1, cv2.LINE_AA)

        # Extract the polygon region from the cropped image
        masked_region = cv2.bitwise_and(cropped_image, cropped_image, mask=mask)

        # Create white background and zero it out under the mask, so the
        # addition below cannot overflow inside the polygon.
        white_background = np.ones_like(cropped_image, np.uint8) * 255
        cv2.bitwise_not(white_background, white_background, mask=mask)

        # Overlay the masked region on white background
        return white_background + masked_region

    def crop_lines(self, polygons: List[List[List[float]]], image: np.ndarray) -> List[np.ndarray]:
        """
        Crop multiple text lines from an image.

        Args:
            polygons: List of polygon coordinate lists
            image: Source image

        Returns:
            List of cropped text line images (excluding any failed crops).
            Because failed crops are skipped, the result can be shorter than
            ``polygons`` and indices may no longer line up with the input.
        """
        cropped_lines = []
        for i, polygon in enumerate(polygons):
            cropped_line = self.crop_line(image, polygon)
            if cropped_line is None:
                print(f"Warning: Failed to crop line {i}")
                continue
            cropped_lines.append(cropped_line)
        return cropped_lines

    def encode(self, pixel_values: np.ndarray) -> np.ndarray:
        """
        Encode image pixel values into hidden states using the vision encoder.

        Args:
            pixel_values: Preprocessed image tensor from TrOCRProcessor
                Shape: (batch_size, channels, height, width)

        Returns:
            Encoder hidden states for input to the decoder
            Shape: (batch_size, sequence_length, hidden_size)

        Raises:
            RuntimeError: If encoding fails
        """
        try:
            # First (and only used) output of the encoder graph
            return self.encoder.run(
                None,
                {"pixel_values": pixel_values}
            )[0]
        except Exception as e:
            raise RuntimeError(f'Failed to encode input: {e}') from e

    def generate(self, encoder_outputs: np.ndarray, batch_size: int) -> np.ndarray:
        """
        Generate text tokens using greedy autoregressive decoding with early stopping.

        Each step re-runs the decoder on the full sequence generated so far and
        appends the arg-max token. A sequence keeps its EOS token; on every
        later step it receives PAD while the remaining sequences continue.
        Decoding stops once all sequences have produced EOS or after
        ``max_length`` steps.

        Args:
            encoder_outputs: Hidden states from the encoder
                Shape: (batch_size, sequence_length, hidden_size)
            batch_size: Number of sequences in the batch

        Returns:
            Generated token IDs including start and end tokens
            Shape: (batch_size, generated_length)

        Raises:
            RuntimeError: If generation fails
        """
        try:
            eos_token_id = self.config.eos_token_id
            # Some configs define no PAD token; EOS is a safe filler because
            # both are dropped as special tokens when decoding to text.
            pad_token_id = self.config.pad_token_id
            if pad_token_id is None:
                pad_token_id = eos_token_id

            # Initialize decoder input with start tokens
            decoder_input_ids = np.full(
                (batch_size, 1),
                self.config.decoder_start_token_id,
                dtype=np.int64
            )

            # Track which sequences have finished
            finished = np.zeros(batch_size, dtype=bool)

            for _ in range(self.max_length):
                # Run decoder to get next token logits
                decoder_outputs = self.decoder.run(
                    None,
                    {
                        "input_ids": decoder_input_ids,
                        "encoder_hidden_states": encoder_outputs
                    }
                )[0]

                # Get most likely next token for each sequence
                next_tokens = np.argmax(decoder_outputs[:, -1, :], axis=-1)

                # Replace tokens with PAD only for sequences that finished on an
                # EARLIER step, so a freshly generated EOS is kept in the output.
                next_tokens[finished] = pad_token_id

                # Mark sequences that just generated the EOS token
                finished = finished | (next_tokens == eos_token_id)

                # Append new tokens to the sequence
                decoder_input_ids = np.concatenate(
                    [decoder_input_ids, next_tokens.reshape(-1, 1)],
                    axis=1
                )

                # Stop when all sequences have finished
                if np.all(finished):
                    break

            return decoder_input_ids
        except Exception as e:
            raise RuntimeError(f'Failed to generate output ids: {e}') from e

    def predict_text(self, cropped_lines: List[np.ndarray]) -> List[str]:
        """
        Predict text content from cropped line images.

        Args:
            cropped_lines: List of cropped text line images

        Returns:
            List of predicted text strings (empty if no images were given)

        Raises:
            RuntimeError: If prediction fails
        """
        # Nothing to recognise; avoid handing an empty batch to the processor
        if len(cropped_lines) == 0:
            return []

        try:
            # Process image with TrOCR processor
            # Use 'pt' (PyTorch) then convert to numpy, as 'np' is not supported by fast processors
            pixel_values = self.processor(cropped_lines, return_tensors="pt").pixel_values
            pixel_values = pixel_values.numpy()
            batch_size = pixel_values.shape[0]

            # Encode images to hidden states
            encoder_hidden_states = self.encode(pixel_values)

            # Generate token sequences
            generated_ids = self.generate(encoder_hidden_states, batch_size)

            # Decode tokens to text
            return self.processor.batch_decode(
                generated_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False
            )
        except Exception as e:
            raise RuntimeError(f'Failed to predict text: {e}') from e

    def get_text_lines(self, cropped_lines: List[np.ndarray]) -> List[str]:
        """
        Process text lines in batches of ``batch_size`` to manage memory efficiently.

        Args:
            cropped_lines: List of all cropped line images

        Returns:
            List of predicted text strings for all lines, in input order
        """
        generated_text = []
        # Process in batches
        for start in range(0, len(cropped_lines), self.batch_size):
            batch = cropped_lines[start:start + self.batch_size]
            generated_text.extend(self.predict_text(batch))
        return generated_text

    def process_lines(self,
                      polygons: List[List[List[float]]],
                      image: np.ndarray) -> List[str]:
        """
        Complete pipeline: crop text lines and predict their content.

        Args:
            polygons: List of polygon coordinate lists defining text line regions
            image: Source document image

        Returns:
            List of predicted text strings for each valid line. Lines whose
            crop failed are omitted, so the result can be shorter than ``polygons``.
        """
        # Crop line images from the document
        cropped_lines = self.crop_lines(polygons, image)
        if not cropped_lines:
            print("Warning: No valid cropped lines to process")
            return []

        # Get text predictions for all lines
        return self.get_text_lines(cropped_lines)