File size: 13,750 Bytes
eee8423
 
f731714
eee8423
f731714
 
 
eee8423
f731714
 
 
 
eee8423
 
 
 
 
 
 
 
 
 
 
 
 
 
f731714
eee8423
 
 
 
 
 
f731714
eee8423
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f731714
eee8423
 
 
f731714
eee8423
 
 
 
 
 
 
 
 
 
f731714
eee8423
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f731714
eee8423
 
 
 
 
 
 
 
 
 
 
f731714
eee8423
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f731714
eee8423
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29cf47d
eee8423
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f731714
 
eee8423
 
 
 
 
f731714
 
eee8423
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f731714
eee8423
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
from transformers import VisionEncoderDecoderConfig
from typing import List, Tuple, Optional
from transformers import TrOCRProcessor
from pathlib import Path
import numpy as np
import onnxruntime
import math
import time
import cv2
import os

class TextRecognition:
    """
    ONNX-based text recognition class using TrOCR for handwritten text recognition.

    Processes text line images through an encoder-decoder architecture, supporting
    batch processing and CUDA acceleration.

    Args:
        model_path: Path to the model directory containing ``encoder_model.onnx``,
            ``decoder_model.onnx`` and the processor/model config files
        device: Device identifier (default: 'cuda:0'). ``'cpu'`` uses only the CPU
            execution provider; ``'cuda'`` / ``'cuda:N'`` uses CUDA device N with a
            CPU fallback; any other value keeps the default CUDA-then-CPU order.
        batch_size: Number of lines to process in parallel (default: 10)
        img_height: Target height for input images (default: 192)
        img_width: Target width for input images (default: 1024)
        max_length: Maximum number of decoding steps per batch (default: 128)

    Raises:
        FileNotFoundError: If ``model_path`` or one of the ONNX model files is missing
        RuntimeError: If the processor or the ONNX sessions fail to initialize
    """

    def __init__(self,
                 model_path: str,
                 device: str = 'cuda:0',
                 batch_size: int = 10,
                 img_height: int = 192,
                 img_width: int = 1024,
                 max_length: int = 128):
        self.model_path = model_path
        self.device = device
        self.batch_size = batch_size
        self.img_height = img_height
        self.img_width = img_width
        self.max_length = max_length

        # Validate model path
        if not os.path.exists(self.model_path):
            raise FileNotFoundError(f"Model path does not exist: {model_path}")

        self.init_processor()
        self.init_recognition_model()

    def init_processor(self) -> None:
        """
        Initialize the TrOCR processor with custom image dimensions.

        Raises:
            RuntimeError: If processor initialization fails (original error chained)
        """
        try:
            self.processor = TrOCRProcessor.from_pretrained(
                str(self.model_path),
                use_fast=True,
                do_resize=True,
                size={
                    'height': self.img_height,
                    'width': self.img_width
                }
            )
            print(f"✓ Processor loaded with custom image size: {self.img_height}x{self.img_width}")
        except Exception as e:
            raise RuntimeError(f'Failed to initialize processor: {e}') from e

    def _build_providers(self) -> list:
        """
        Translate ``self.device`` into an ONNX Runtime provider list.

        Returns:
            Execution providers in priority order, in the form accepted by
            ``onnxruntime.InferenceSession(providers=...)``.
        """
        device = str(self.device).strip().lower()
        if device == 'cpu':
            return ['CPUExecutionProvider']
        if device.startswith('cuda'):
            # 'cuda' -> device 0, 'cuda:1' -> device 1; a non-numeric suffix falls back to 0.
            _, _, index = device.partition(':')
            device_id = int(index) if index.isdigit() else 0
            return [('CUDAExecutionProvider', {'device_id': device_id}), 'CPUExecutionProvider']
        # Unrecognized identifiers keep the original default ordering.
        return ['CUDAExecutionProvider', 'CPUExecutionProvider']

    def init_recognition_model(self) -> None:
        """
        Initialize the ONNX encoder and decoder sessions and load the model config.

        Raises:
            FileNotFoundError: If model files are not found
            RuntimeError: If creating either ONNX session fails (original error chained)
        """
        encoder_path = os.path.join(self.model_path, "encoder_model.onnx")
        decoder_path = os.path.join(self.model_path, "decoder_model.onnx")

        if not os.path.exists(encoder_path):
            raise FileNotFoundError(f"Encoder model not found: {encoder_path}")
        if not os.path.exists(decoder_path):
            raise FileNotFoundError(f"Decoder model not found: {decoder_path}")

        # Session options for better performance
        sess_options = onnxruntime.SessionOptions()
        sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        sess_options.intra_op_num_threads = 4

        # Honour the requested device instead of always using the default GPU.
        providers = self._build_providers()

        # Load model config (special token ids used by generate())
        self.config = VisionEncoderDecoderConfig.from_pretrained(str(self.model_path))

        try:
            print("Loading encoder...")
            self.encoder = onnxruntime.InferenceSession(
                str(encoder_path),
                sess_options=sess_options,
                providers=providers
            )

            print("Loading decoder...")
            self.decoder = onnxruntime.InferenceSession(
                str(decoder_path),
                sess_options=sess_options,
                providers=providers
            )

            # Report which provider is actually being used (ORT may fall back to CPU)
            encoder_provider = self.encoder.get_providers()[0]
            decoder_provider = self.decoder.get_providers()[0]
            print(f"✓ Using execution provider: Encoder={encoder_provider}, Decoder={decoder_provider}")

        except Exception as e:
            raise RuntimeError(f'Failed to load recognition models: {e}') from e

    def crop_line(self, image: np.ndarray, polygon: List[List[float]]) -> Optional[np.ndarray]:
        """
        Crop a text line from an image based on polygon coordinates.

        Creates a masked crop where the polygon area contains the original image
        and the background is filled with white pixels. The bounding box is
        clipped to the image, so polygons extending past an edge (including
        negative coordinates) yield the visible part of the line.

        Args:
            image: Source image as numpy array
            polygon: List of [x, y] coordinate pairs defining the text line region

        Returns:
            Cropped and masked text line image, or None if the polygon is empty,
            degenerate, or lies entirely outside the image
        """
        if polygon is None or len(polygon) == 0:
            print("Warning: Empty polygon")
            return None

        # Convert polygon to integer coordinates (values are truncated by int())
        polygon_array = np.array([[int(pt[0]), int(pt[1])] for pt in polygon], dtype=np.int32)

        # Get bounding rectangle
        rect = cv2.boundingRect(polygon_array)
        x, y, w, h = rect

        # Validate rectangle
        if w <= 0 or h <= 0:
            print(f"Warning: Invalid bounding rect dimensions: {w}x{h}")
            return None

        # Clip the rectangle to the image. A negative start index would
        # otherwise be read by numpy as an offset from the end of the axis.
        img_h, img_w = image.shape[:2]
        x0, y0 = max(x, 0), max(y, 0)
        x1, y1 = min(x + w, img_w), min(y + h, img_h)
        cropped_image = image[y0:y1, x0:x1]

        if cropped_image.size == 0:
            print(f"Warning: Empty cropped image at rect {rect}")
            return None

        # Create mask for the polygon region
        mask = np.zeros(cropped_image.shape[:2], dtype=np.uint8)

        # Shift the polygon into crop coordinates (relative to the clipped origin);
        # cv2 clips any vertices that fall outside the mask while filling.
        polygon_offset = polygon_array - np.array([[x0, y0]], dtype=np.int32)
        cv2.drawContours(mask, [polygon_offset], -1, (255, 255, 255), -1, cv2.LINE_AA)

        # Keep only the pixels inside the polygon (zero elsewhere)
        masked_region = cv2.bitwise_and(cropped_image, cropped_image, mask=mask)

        # White outside the polygon, 0 inside it (bitwise_not applied where mask != 0)
        white_background = np.ones_like(cropped_image, np.uint8) * 255
        cv2.bitwise_not(white_background, white_background, mask=mask)

        # The two images are non-zero on disjoint pixels, so adding cannot overflow
        return white_background + masked_region

    def crop_lines(self, polygons: List[List[List[float]]], image: np.ndarray) -> List[np.ndarray]:
        """
        Crop multiple text lines from an image.

        Args:
            polygons: List of polygon coordinate lists
            image: Source image

        Returns:
            List of cropped text line images. Failed crops are skipped (with a
            warning), so the result may be shorter than ``polygons``.
        """
        cropped_lines = []
        for i, polygon in enumerate(polygons):
            cropped_line = self.crop_line(image, polygon)
            if cropped_line is not None:
                cropped_lines.append(cropped_line)
            else:
                print(f"Warning: Failed to crop line {i}")
        return cropped_lines

    def encode(self, pixel_values: np.ndarray) -> np.ndarray:
        """
        Encode image pixel values into hidden states using the vision encoder.

        Args:
            pixel_values: Preprocessed image tensor from TrOCRProcessor
                         Shape: (batch_size, channels, height, width)

        Returns:
            First output of the encoder session, fed to the decoder as
            ``encoder_hidden_states``

        Raises:
            RuntimeError: If encoding fails (original error chained)
        """
        try:
            encoder_outputs = self.encoder.run(
                None,
                {"pixel_values": pixel_values}
            )[0]
            return encoder_outputs
        except Exception as e:
            raise RuntimeError(f'Failed to encode input: {e}') from e

    def _special_token_id(self, name: str) -> Optional[int]:
        """
        Return a special token id from the model config.

        Reads the top-level config attribute first and falls back to
        ``config.decoder`` when the top-level value is missing or None.
        """
        value = getattr(self.config, name, None)
        if value is None:
            value = getattr(getattr(self.config, 'decoder', None), name, None)
        return value

    def generate(self, encoder_outputs: np.ndarray, batch_size: int) -> np.ndarray:
        """
        Generate text tokens using greedy autoregressive decoding with early stopping.

        Each sequence keeps its EOS token; on every later step a finished
        sequence receives PAD while unfinished ones continue. Decoding stops
        once every sequence has produced EOS or after ``self.max_length`` steps.

        Args:
            encoder_outputs: Hidden states from the encoder
            batch_size: Number of sequences in the batch

        Returns:
            Generated token IDs including the start token and (if reached) the
            EOS token. Shape: (batch_size, generated_length), dtype int64, with
            generated_length <= max_length + 1.

        Raises:
            RuntimeError: If generation fails (original error chained)
        """
        try:
            start_id = self._special_token_id('decoder_start_token_id')
            eos_id = self._special_token_id('eos_token_id')
            pad_id = self._special_token_id('pad_token_id')
            if start_id is None or eos_id is None:
                raise ValueError('model config defines no decoder_start_token_id/eos_token_id')
            if pad_id is None:
                # predict_text decodes with skip_special_tokens=True, so EOS
                # works as a filler when the config has no PAD id.
                pad_id = eos_id

            # Initialize decoder input with start tokens
            decoder_input_ids = np.full((batch_size, 1), start_id, dtype=np.int64)

            # Track which sequences have already emitted EOS
            finished = np.zeros(batch_size, dtype=bool)

            for _ in range(self.max_length):
                # Run decoder on the full prefix to get next-token logits
                logits = self.decoder.run(
                    None,
                    {
                        "input_ids": decoder_input_ids,
                        "encoder_hidden_states": encoder_outputs
                    }
                )[0]

                # Greedy choice for the last position of each sequence
                next_tokens = np.argmax(logits[:, -1, :], axis=-1).astype(np.int64)

                # Pad only sequences that finished on an EARLIER step, so the
                # EOS token itself is kept in the output.
                next_tokens = np.where(finished, pad_id, next_tokens)
                finished |= next_tokens == eos_id

                # Append new tokens to the sequence
                decoder_input_ids = np.concatenate([decoder_input_ids, next_tokens[:, None]], axis=1)

                # Stop when all sequences have finished
                if finished.all():
                    break

            return decoder_input_ids

        except Exception as e:
            raise RuntimeError(f'Failed to generate output ids: {e}') from e

    def predict_text(self, cropped_lines: List[np.ndarray]) -> List[str]:
        """
        Predict text content from cropped line images.

        Args:
            cropped_lines: List of cropped text line images

        Returns:
            List of predicted text strings, one per input image

        Raises:
            RuntimeError: If prediction fails (original error chained)
        """
        try:
            # Process image with TrOCR processor
            # Use 'pt' (PyTorch) then convert to numpy, as 'np' is not supported by fast processors
            pixel_values = self.processor(cropped_lines, return_tensors="pt").pixel_values
            pixel_values = pixel_values.numpy()
            batch_size = pixel_values.shape[0]

            # Encode images to hidden states
            encoder_hidden_states = self.encode(pixel_values)

            # Generate token sequences
            generated_ids = self.generate(encoder_hidden_states, batch_size)

            # Decode tokens to text; special tokens (start/EOS/PAD) are dropped
            texts = self.processor.batch_decode(
                generated_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False
            )

            return texts

        except Exception as e:
            raise RuntimeError(f'Failed to predict text: {e}') from e

    def get_text_lines(self, cropped_lines: List[np.ndarray]) -> List[str]:
        """
        Process text lines in batches of ``self.batch_size`` to bound memory use.

        Args:
            cropped_lines: List of all cropped line images

        Returns:
            List of predicted text strings for all lines, in input order
        """
        generated_text = []

        # Process in batches
        for i in range(0, len(cropped_lines), self.batch_size):
            batch = cropped_lines[i:i + self.batch_size]
            texts = self.predict_text(batch)
            generated_text.extend(texts)

        return generated_text

    def process_lines(self,
                      polygons: List[List[List[float]]],
                      image: np.ndarray) -> List[str]:
        """
        Complete pipeline: crop text lines and predict their content.

        Args:
            polygons: List of polygon coordinate lists defining text line regions
            image: Source document image

        Returns:
            List of predicted text strings for each valid line. Polygons that
            fail to crop are skipped, so the list can be shorter than
            ``polygons``; an empty list is returned if no line could be cropped.
        """
        # Crop line images from the document
        cropped_lines = self.crop_lines(polygons, image)

        if not cropped_lines:
            print("Warning: No valid cropped lines to process")
            return []

        # Get text predictions for all lines
        generated_text = self.get_text_lines(cropped_lines)

        return generated_text