|
|
from typing import List, Tuple, Optional, Dict, Any |
|
|
from shapely.validation import make_valid |
|
|
from shapely.geometry import Polygon |
|
|
from rfdetr import RFDETRSegPreview |
|
|
from collections import defaultdict |
|
|
import numpy as np |
|
|
import cv2 |
|
|
import os |
|
|
|
|
|
from image_processing import ( |
|
|
load_with_torchvision, |
|
|
preprocess_resize_torch_transform, |
|
|
upscale_bbox, |
|
|
upscale_mask_opencv, |
|
|
crop_line |
|
|
) |
|
|
|
|
|
from utils import get_default_region, get_line_regions, order_regions_lines |
|
|
|
|
|
class SegmentImage:
    """
    Document image segmentation for detecting text regions and lines.

    Uses an RFDETR segmentation model to detect text regions and text lines in
    document images. Detected masks are converted to polygons, filtered by
    area, merged when they overlap, and finally passed to the ``utils``
    helpers that attach lines to regions and order them.

    Args:
        model_path: Path to the RFDETR segmentation model weights.
        max_size: Maximum dimension (height or width) used when the image is
            resized for the model (default: 768).
        confidence_threshold: Minimum confidence score for detections
            (default: 0.15, range: 0-1).
        line_percentage_threshold: Minimum polygon area, as a fraction of the
            mask area, for a line to be kept (default: 7e-05, i.e. 0.007%).
        region_percentage_threshold: Minimum polygon area, as a fraction of
            the mask area, for a region to be kept (default: 7e-05).
        line_iou: IoU threshold above which line polygons are merged
            (default: 0.3, range: 0-1).
        region_iou: IoU threshold above which region polygons are merged
            (default: 0.3, range: 0-1).
        line_overlap_threshold: Overlap ratio (intersection / smaller area)
            above which line polygons are merged (default: 0.5, range: 0-1).
        region_overlap_threshold: Overlap ratio above which region polygons
            are merged (default: 0.5, range: 0-1).
        class_id_region: Class ID of regions in the model output (default: 1).
        class_id_line: Class ID of lines in the model output (default: 2).
        min_polygon_points: Minimum number of contour points required for a
            contour to be kept as a polygon (default: 3).

    Raises:
        FileNotFoundError: If ``model_path`` does not exist.
        RuntimeError: If the model cannot be initialized.
    """

    def __init__(self,
                 model_path: str,
                 max_size: int = 768,
                 confidence_threshold: float = 0.15,
                 line_percentage_threshold: float = 7e-05,
                 region_percentage_threshold: float = 7e-05,
                 line_iou: float = 0.3,
                 region_iou: float = 0.3,
                 line_overlap_threshold: float = 0.5,
                 region_overlap_threshold: float = 0.5,
                 class_id_region: int = 1,
                 class_id_line: int = 2,
                 min_polygon_points: int = 3):
        """Store the configuration, check the weights path and load the model."""
        self.model_path = model_path
        self.max_size = max_size
        self.confidence_threshold = confidence_threshold
        self.line_percentage_threshold = line_percentage_threshold
        self.region_percentage_threshold = region_percentage_threshold
        self.line_iou = line_iou
        self.region_iou = region_iou
        self.line_overlap_threshold = line_overlap_threshold
        self.region_overlap_threshold = region_overlap_threshold
        self.class_id_region = class_id_region
        self.class_id_line = class_id_line
        self.min_polygon_points = min_polygon_points

        # Fail early with a clear error instead of a less obvious one from
        # inside the model loader.
        if not os.path.exists(self.model_path):
            raise FileNotFoundError(f"Model path does not exist: {self.model_path}")

        self.init_model()

    def init_model(self) -> None:
        """
        Load and optimize an RFDETR segmentation model for inference.

        Sets ``self.model``.

        Raises:
            RuntimeError: If model loading or optimization fails. The original
                exception is attached as ``__cause__``.
        """
        try:
            self.model = RFDETRSegPreview(pretrain_weights=self.model_path)
            self.model.optimize_for_inference()
            print("✓ Segmentation model initialized successfully")
        except Exception as e:
            raise RuntimeError(f'Failed to initialize segmentation model: {e}') from e

    def validate_polygon(self, polygon: np.ndarray) -> "Optional[Polygon]":
        """
        Build a valid Shapely geometry from polygon coordinates.

        Converts the coordinates to a Shapely ``Polygon`` and, if the result
        is not valid, repairs it with ``make_valid()``.

        Args:
            polygon: Array of polygon coordinates with shape (N, 2).

        Returns:
            A valid Shapely geometry, or None if the input has fewer than 3
            points or Shapely raised while building/repairing it. Note that a
            repaired geometry is whatever ``make_valid()`` returns, which is
            not necessarily a single ``Polygon``.
        """
        # Three points is the geometric minimum for a polygon.
        if len(polygon) < 3:
            return None

        try:
            shapely_polygon = Polygon(polygon)
            if not shapely_polygon.is_valid:
                shapely_polygon = make_valid(shapely_polygon)
            return shapely_polygon
        except Exception as e:
            # Best effort: a single broken detection must not abort the page.
            print(f"Warning: Failed to validate polygon: {e}")
            return None

    @staticmethod
    def _polygon_exteriors(geometry) -> List[np.ndarray]:
        """Return the exterior ring of every direct Polygon part as an int32 array."""
        if geometry.geom_type == 'Polygon':
            parts = [geometry]
        elif geometry.geom_type in ('MultiPolygon', 'GeometryCollection'):
            parts = [part for part in geometry.geoms if part.geom_type == 'Polygon']
        else:
            # Lines, points etc. carry no area and are dropped.
            parts = []
        return [np.array(part.exterior.coords).astype(np.int32) for part in parts]

    def merge_polygons(self,
                       polygons: List[np.ndarray],
                       polygon_iou: float,
                       overlap_threshold: float) -> Tuple[List[np.ndarray], List[int]]:
        """
        Merge overlapping polygons using connected components (union-find).

        Two polygons are linked when their IoU exceeds ``polygon_iou`` or,
        failing that and when ``overlap_threshold`` is positive, when the
        intersection covers more than ``overlap_threshold`` of the smaller
        polygon. Every connected group of linked polygons is replaced by the
        union of its members; polygons without links are returned unchanged.

        Args:
            polygons: List of polygon coordinate arrays, each with shape (N, 2).
            polygon_iou: IoU threshold for merging (0-1).
            overlap_threshold: Overlap ratio threshold for merging (0-1);
                values <= 0 disable this second criterion.

        Returns:
            Tuple of:
                - merged_polygons: List of polygon coordinate arrays. A merged
                  group whose union is multi-part contributes one array per
                  Polygon part.
                - polygon_mapping: For each input polygon, the index of the
                  (first) output polygon produced by its group, or -1 if the
                  input was invalid/empty or its group produced no polygon.
        """
        n = len(polygons)
        if n == 0:
            return [], []

        validated = [self.validate_polygon(p) for p in polygons]
        # A polygon takes part in merging only if it validated to a non-empty
        # geometry.
        usable = [geom is not None and not geom.is_empty for geom in validated]
        areas = [geom.area if ok else 0.0 for geom, ok in zip(validated, usable)]

        parent = list(range(n))

        def find(x: int) -> int:
            """Find the root of x, compressing the path (iterative, so long
            chains cannot exhaust the recursion limit)."""
            root = x
            while parent[root] != root:
                root = parent[root]
            while parent[x] != root:
                next_x = parent[x]
                parent[x] = root
                x = next_x
            return root

        def union(x: int, y: int) -> None:
            """Join the sets containing x and y."""
            px, py = find(x), find(y)
            if px != py:
                parent[px] = py

        for i in range(n):
            if not usable[i]:
                continue
            poly1 = validated[i]

            for j in range(i + 1, n):
                if not usable[j]:
                    continue
                poly2 = validated[j]
                if not poly1.intersects(poly2):
                    continue

                intersection_area = poly1.intersection(poly2).area
                # Inclusion-exclusion gives the union area without building
                # the union geometry for every pair.
                union_area = areas[i] + areas[j] - intersection_area
                iou = intersection_area / union_area if union_area > 0 else 0

                should_merge = iou > polygon_iou

                # Second chance: a small polygon mostly inside a big one has
                # a low IoU but a high overlap ratio.
                if not should_merge and overlap_threshold > 0:
                    smaller_area = min(areas[i], areas[j])
                    overlap_ratio = intersection_area / smaller_area if smaller_area > 0 else 0
                    should_merge = overlap_ratio > overlap_threshold

                if should_merge:
                    union(i, j)

        components = defaultdict(list)
        for i in range(n):
            if usable[i]:
                components[find(i)].append(i)

        merged_polygons: List[np.ndarray] = []
        polygon_mapping = [-1] * n

        for indices in components.values():
            output_idx = len(merged_polygons)

            if len(indices) == 1:
                # Nothing to merge: keep the caller's original coordinates.
                merged_polygons.append(polygons[indices[0]])
            else:
                merged = validated[indices[0]]
                for idx in indices[1:]:
                    merged = merged.union(validated[idx])
                merged_polygons.extend(self._polygon_exteriors(merged))

            # Only record a mapping when this group really produced output;
            # otherwise output_idx would point at another group's polygon or
            # past the end of the list.
            if len(merged_polygons) > output_idx:
                for idx in indices:
                    polygon_mapping[idx] = output_idx

        return merged_polygons, polygon_mapping

    def calculate_polygon_area(self, vertices: np.ndarray) -> float:
        """
        Calculate polygon area using the Shoelace formula.

        The result is the absolute area, so it does not depend on whether the
        vertices run clockwise or counter-clockwise. The formula assumes a
        simple (non-self-intersecting) polygon.

        Args:
            vertices: Array of polygon coordinates with shape (N, 2).

        Returns:
            Area of the polygon in squared coordinate units; 0.0 when fewer
            than 3 vertices are given.
        """
        # Work in float64: contour coordinates are typically int32 and the
        # cross products would overflow silently for large images.
        points = np.asarray(vertices, dtype=np.float64).reshape(-1, 2)
        if len(points) < 3:
            return 0.0

        x = points[:, 0]
        y = points[:, 1]
        # np.roll pairs every vertex with its successor, wrapping the last
        # vertex around to the first.
        cross_sum = np.dot(x, np.roll(y, -1)) - np.dot(y, np.roll(x, -1))
        return float(0.5 * abs(cross_sum))

    def mask_to_polygon_cv2(self,
                            mask: np.ndarray,
                            original_shape: Tuple[int, int]) -> Tuple[List[np.ndarray], np.ndarray]:
        """
        Convert a binary segmentation mask to polygons using OpenCV contours.

        Extracts the external contours of the mask, keeps those with at least
        ``self.min_polygon_points`` points, and scales their coordinates from
        mask size to ``original_shape``. The area of each polygon is measured
        in mask coordinates (before scaling).

        Args:
            mask: 2-D mask as numpy array (bool, or a type castable to uint8
                where non-zero marks foreground).
            original_shape: Tuple of (height, width) of the original image.

        Returns:
            Tuple of:
                - scaled_polygons: List of (N, 2) integer coordinate arrays
                  scaled to the original image size.
                - area_percentages: Array with each polygon's area as a
                  fraction of the mask area.
        """
        mask_height, mask_width = mask.shape[:2]
        mask_area = mask_height * mask_width
        if mask_area == 0:
            # An empty mask has no contours and no defined scale factors.
            return [], np.array([])

        if mask.dtype == bool:
            mask_uint8 = mask.astype(np.uint8) * 255
        else:
            mask_uint8 = mask.astype(np.uint8)

        contours, _ = cv2.findContours(
            mask_uint8,
            cv2.RETR_EXTERNAL,
            cv2.CHAIN_APPROX_SIMPLE
        )

        orig_height, orig_width = original_shape
        scale = np.array([orig_width / mask_width, orig_height / mask_height])

        scaled_polygons = []
        area_percentages = []

        for contour in contours:
            if len(contour) < self.min_polygon_points:
                continue

            # reshape keeps a (N, 2) layout even for one-point contours,
            # which squeeze() would collapse to shape (2,).
            points = contour.reshape(-1, 2)

            area_percentages.append(self.calculate_polygon_area(points) / mask_area)
            scaled_polygons.append(np.round(points * scale).astype(int))

        return scaled_polygons, np.array(area_percentages)

    def process_polygons(self,
                         poly_masks: np.ndarray,
                         image_shape: Tuple[int, int],
                         percentage_threshold: float,
                         overlap_threshold: float,
                         iou_threshold: float) -> Tuple[List[np.ndarray], List[Tuple[int, int, int, int]]]:
        """
        Extract polygons from masks, filter them by area and merge overlaps.

        Args:
            poly_masks: Iterable of binary segmentation masks from the model.
            image_shape: Tuple of (height, width) of the original image.
            percentage_threshold: Polygons whose area fraction is not above
                this value are discarded.
            overlap_threshold: Overlap ratio threshold for merging polygons.
            iou_threshold: IoU threshold for merging polygons.

        Returns:
            Tuple of two lists of equal length:
                - merged_polygons: Polygon coordinate arrays.
                - merged_max_mins: Bounding boxes as (xmin, ymin, xmax, ymax),
                  one per polygon, as numpy scalars.
            Both lists are empty when nothing survives the filtering.
        """
        all_polygons = []
        all_area_percentages = []

        for mask in poly_masks:
            polygons, area_percentages = self.mask_to_polygon_cv2(
                mask=mask,
                original_shape=image_shape
            )
            all_polygons.extend(polygons)
            all_area_percentages.extend(area_percentages)

        filtered_polygons = [
            polygon
            for polygon, area_percentage in zip(all_polygons, all_area_percentages)
            if area_percentage > percentage_threshold
        ]
        if not filtered_polygons:
            return [], []

        merged_polygons, _ = self.merge_polygons(
            filtered_polygons,
            iou_threshold,
            overlap_threshold
        )

        # Build both outputs in lockstep so that polygon k always belongs to
        # bounding box k, even if an empty polygon has to be dropped.
        kept_polygons = []
        merged_max_mins = []
        for poly in merged_polygons:
            if len(poly) == 0:
                continue
            xmin, ymin = np.min(poly, axis=0)
            xmax, ymax = np.max(poly, axis=0)
            kept_polygons.append(poly)
            merged_max_mins.append((xmin, ymin, xmax, ymax))

        return kept_polygons, merged_max_mins

    def get_segmentation(self, image) -> Optional[List[Dict[str, Any]]]:
        """
        Detect and extract ordered text lines and regions from a document image.

        Runs the segmentation model on the image, extracts line and region
        polygons, merges overlapping detections, then hands the result to
        ``get_line_regions`` and ``order_regions_lines`` from ``utils``.

        Args:
            image: Input image. It must expose ``.shape`` with height and
                width as its first two entries and be accepted by
                ``preprocess_resize_torch_transform``.
                NOTE(review): earlier documentation described this as a PIL
                Image, but PIL images have no ``.shape`` -- confirm the
                expected type (and axis order) against the caller.

        Returns:
            The value returned by ``order_regions_lines`` (documented upstream
            as a list of ordered line dictionaries with region associations),
            or None if prediction failed or no lines were detected.
        """
        image_shape = (image.shape[0], image.shape[1])

        preprocessed_image = preprocess_resize_torch_transform(
            image,
            max_size=self.max_size
        )

        try:
            detections = self.model.predict(
                preprocessed_image,
                threshold=self.confidence_threshold
            )
        except Exception as e:
            print(f"Error during segmentation prediction: {e}")
            return None

        # The detections object may carry no mask array at all; treat that
        # like "nothing detected" instead of failing on the indexing below.
        if detections.mask is None:
            print('No text lines detected from image.')
            return None

        line_mask = detections.mask[detections.class_id == self.class_id_line]
        region_mask = detections.mask[detections.class_id == self.class_id_region]

        merged_line_polygons, merged_line_max_mins = self.process_polygons(
            line_mask,
            image_shape,
            self.line_percentage_threshold,
            self.line_overlap_threshold,
            self.line_iou
        )

        merged_region_polygons, merged_region_max_mins = self.process_polygons(
            region_mask,
            image_shape,
            self.region_percentage_threshold,
            self.region_overlap_threshold,
            self.region_iou
        )

        if not merged_line_polygons:
            print('No text lines detected from image.')
            return None

        line_preds = {
            'coords': merged_line_polygons,
            'max_min': merged_line_max_mins
        }

        if merged_region_polygons:
            region_preds = [
                {
                    'coords': region_polygon,
                    'id': str(num),
                    'max_min': region_max_min,
                    'name': 'paragraph',
                    'img_shape': image_shape
                }
                for num, (region_polygon, region_max_min) in enumerate(
                    zip(merged_region_polygons, merged_region_max_mins)
                )
            ]
        else:
            # No regions detected: fall back to the default region provided
            # by utils for this image shape.
            region_preds = get_default_region(image_shape=image_shape)

        lines_connected_to_regions = get_line_regions(
            lines=line_preds,
            regions=region_preds
        )

        ordered_lines = order_regions_lines(
            lines=lines_connected_to_regions,
            regions=region_preds
        )

        return ordered_lines