# yolo_layoutlm / working_yolo_pipeline.py
# heerjtdev's picture
# Update working_yolo_pipeline.py
# 7ce2214 verified
import json
import argparse
import os
import re
import torch
import torch.nn as nn
from TorchCRF import CRF
from transformers import LayoutLMv3TokenizerFast, LayoutLMv3Model, LayoutLMv3Config
from typing import List, Dict, Any, Optional, Union, Tuple
import fitz # PyMuPDF
import numpy as np
import cv2
from ultralytics import YOLO
import glob
import pytesseract
from PIL import Image
from scipy.signal import find_peaks
from scipy.ndimage import gaussian_filter1d
import sys
import io
import base64
import tempfile
import time
import shutil
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics.pairwise import cosine_similarity
# ============================================================================
# --- CONFIGURATION AND CONSTANTS ---
# ============================================================================
# NOTE: Update these paths to match your environment before running!
# Trained YOLO detector weights and the fine-tuned LayoutLMv3 checkpoint.
WEIGHTS_PATH = 'YOLO_MATH/yolo_split_data/runs/detect/math_figure_detector_v3/weights/best.pt'
DEFAULT_LAYOUTLMV3_MODEL_PATH = "97.pth"

# --- Directory configuration ---
OCR_JSON_OUTPUT_DIR = './ocr_json_output_final'
FIGURE_EXTRACTION_DIR = './figure_extraction'  # cropped figure/equation PNGs are written here
TEMP_IMAGE_DIR = './temp_pdf_images'

# --- YOLO detection / box post-processing ---
CONF_THRESHOLD = 0.2                      # minimum confidence passed to YOLO predict()
TARGET_CLASSES = ['figure', 'equation']   # YOLO classes kept; all others are ignored
IOU_MERGE_THRESHOLD = 0.4                 # same-class boxes above this IoU are merged
IOA_SUPPRESSION_THRESHOLD = 0.7           # words covered above this fraction by a component are dropped
LINE_TOLERANCE = 15                       # max y0 gap for two page items to share a reading line

# --- Similarity ---
SIMILARITY_THRESHOLD = 0.10
RESOLUTION_MARGIN = 0.05

# --- Global counters for sequential numbering across the entire PDF ---
GLOBAL_FIGURE_COUNT = 0
GLOBAL_EQUATION_COUNT = 0

# --- LayoutLMv3 BIO labels (class index -> tag) ---
ID_TO_LABEL = {
    0: "O",
    1: "B-QUESTION",
    2: "I-QUESTION",
    3: "B-OPTION",
    4: "I-OPTION",
    5: "B-ANSWER",
    6: "I-ANSWER",
    7: "B-SECTION_HEADING",
    8: "I-SECTION_HEADING",
    9: "B-PASSAGE",
    10: "I-PASSAGE",
}
NUM_LABELS = len(ID_TO_LABEL)
# ============================================================================
# --- PERFORMANCE OPTIMIZATION: OCR CACHE ---
# ============================================================================
class OCRCache:
    """In-memory store of per-page OCR word lists, keyed by PDF path and page number.

    Lets later pipeline stages reuse a Tesseract result instead of OCR-ing the same
    page again. Entries persist until ``clear()`` is called.
    """

    def __init__(self):
        # "<pdf_path>:<page_num>" -> word list supplied by the caller.
        self.cache = {}

    def get_key(self, pdf_path: str, page_num: int) -> str:
        """Return the string key under which a page's words are stored."""
        return "{}:{}".format(pdf_path, page_num)

    def has_ocr(self, pdf_path: str, page_num: int) -> bool:
        """True when words for this page have been stored."""
        key = self.get_key(pdf_path, page_num)
        return key in self.cache

    def get_ocr(self, pdf_path: str, page_num: int) -> Optional[list]:
        """Return the stored words for the page, or None when nothing is cached."""
        key = self.get_key(pdf_path, page_num)
        return self.cache.get(key)

    def set_ocr(self, pdf_path: str, page_num: int, ocr_data: list):
        """Store (or overwrite) the words for the page."""
        key = self.get_key(pdf_path, page_num)
        self.cache[key] = ocr_data

    def clear(self):
        """Drop every cached page."""
        self.cache.clear()


# Module-wide cache shared by the detection and OCR stages; cleared at the start of each PDF run.
_ocr_cache = OCRCache()
# ============================================================================
# --- PHASE 1: YOLO/OCR PREPROCESSING FUNCTIONS ---
# ============================================================================
def calculate_iou(box1, box2):
    """Return the intersection-over-union of two ``(x1, y1, x2, y2)`` boxes.

    Returns 0 when the union area is not positive (e.g. two degenerate boxes).
    """
    ax1, ay1, ax2, ay2 = box1
    bx1, by1, bx2, by2 = box2
    overlap_w = max(0, min(ax2, bx2) - max(ax1, bx1))
    overlap_h = max(0, min(ay2, by2) - max(ay1, by1))
    inter = overlap_w * overlap_h
    area_a = (ax2 - ax1) * (ay2 - ay1)
    area_b = (bx2 - bx1) * (by2 - by1)
    union = float(area_a + area_b - inter)
    if union > 0:
        return inter / union
    return 0
def calculate_ioa(box1, box2):
    """Return the fraction of ``box1``'s area that overlaps ``box2``.

    This is the intersection over the area of ``box1`` only, so it reaches 1.0
    when ``box1`` lies entirely inside ``box2``. Returns 0 when ``box1`` has no
    positive area.
    """
    ax1, ay1, ax2, ay2 = box1
    bx1, by1, bx2, by2 = box2
    area_a = (ax2 - ax1) * (ay2 - ay1)
    overlap_w = max(0, min(ax2, bx2) - max(ax1, bx1))
    overlap_h = max(0, min(ay2, by2) - max(ay1, by1))
    if area_a > 0:
        return overlap_w * overlap_h / area_a
    return 0
def filter_nested_boxes(detections, ioa_threshold=0.80):
    """Drop detections that sit (almost) entirely inside a larger detection.

    Boxes are visited from largest to smallest area. Each kept box suppresses every
    smaller, not-yet-suppressed box whose intersection with it covers more than
    ``ioa_threshold`` of the smaller box's own area. Class labels are not compared,
    so a figure can absorb a nested equation and vice versa.

    Args:
        detections: dicts with a ``'coords'`` entry ``(x1, y1, x2, y2)``.
        ioa_threshold: containment ratio above which the smaller box is removed.

    Returns:
        A new list of the surviving detections, largest area first. Each returned
        dict is a shallow copy of its input dict with an added ``'area'`` key; the
        caller's list order and dicts are left untouched.
    """
    if not detections:
        return []

    # Annotate copies so the caller's list is not reordered and its dicts are not mutated.
    ranked = []
    for det in detections:
        x1, y1, x2, y2 = det['coords']
        ranked.append({**det, 'area': (x2 - x1) * (y2 - y1)})
    # Largest first, so every potential container is processed before its contents.
    ranked.sort(key=lambda d: d['area'], reverse=True)

    suppressed = [False] * len(ranked)
    kept = []
    for i, outer in enumerate(ranked):
        if suppressed[i]:
            continue
        kept.append(outer)
        ox1, oy1, ox2, oy2 = outer['coords']
        for j in range(i + 1, len(ranked)):
            if suppressed[j]:
                continue
            inner = ranked[j]
            if inner['area'] <= 0:
                continue
            ix1, iy1, ix2, iy2 = inner['coords']
            overlap_w = max(0, min(ox2, ix2) - max(ox1, ix1))
            overlap_h = max(0, min(oy2, iy2) - max(oy1, iy1))
            # 'inner' is never larger than 'outer' (sorted), so this is IoA of the smaller box.
            if overlap_w * overlap_h / inner['area'] > ioa_threshold:
                suppressed[j] = True
    return kept
def merge_overlapping_boxes(detections, iou_threshold):
    """Greedily merge same-class detections whose boxes overlap.

    Detections are visited in descending confidence. Each not-yet-absorbed detection
    becomes a seed. Every later, unabsorbed detection of the same class whose IoU with
    the seed's *original* box exceeds ``iou_threshold`` is absorbed. The seed grows to
    the bounding rectangle of itself and everything it absorbed.

    Args:
        detections: dicts with ``'coords'`` ``(x1, y1, x2, y2)``, ``'class'`` and ``'conf'``.
        iou_threshold: IoU above which two same-class boxes are merged.

    Returns:
        A new list of ``{'coords', 'y1', 'class', 'conf'}`` dicts (``conf`` is the seed's),
        highest-confidence seeds first. The input list is not reordered.
    """
    if not detections:
        return []

    # sorted() (stable) instead of list.sort() so the caller's list keeps its order.
    ordered = sorted(detections, key=lambda d: d['conf'], reverse=True)
    absorbed = [False] * len(ordered)
    merged = []
    for i, seed in enumerate(ordered):
        if absorbed[i]:
            continue
        seed_box = seed['coords']
        seed_class = seed['class']
        mx1, my1, mx2, my2 = seed_box
        for j in range(i + 1, len(ordered)):
            other = ordered[j]
            if absorbed[j] or other['class'] != seed_class:
                continue
            other_box = other['coords']
            # IoU is measured against the seed's own box, not the growing union.
            if calculate_iou(seed_box, other_box) > iou_threshold:
                mx1 = min(mx1, other_box[0])
                my1 = min(my1, other_box[1])
                mx2 = max(mx2, other_box[2])
                my2 = max(my2, other_box[3])
                absorbed[j] = True
        merged.append({
            'coords': (mx1, my1, mx2, my2),
            'y1': my1, 'class': seed_class, 'conf': seed['conf'],
        })
    return merged
def merge_yolo_into_word_data(raw_word_data: list, yolo_detections: list, scale_factor: float) -> list:
    """Mask words covered by YOLO detections and add one solid block per detection.

    YOLO boxes are in rendered-image pixels. Dividing by ``scale_factor`` maps them into
    the coordinate space of the ``(text, x1, y1, x2, y2)`` word tuples. A word is dropped
    when its centre lies inside (inclusive) any mapped box. Each box is then appended as
    a pseudo-word ``("BLOCK_<i>", x1, y1, x2, y2)`` so the column detector treats the
    region as occupied.

    Returns ``raw_word_data`` itself (not a copy) when there are no detections.
    """
    if not yolo_detections:
        return raw_word_data

    blocks = []
    for det in yolo_detections:
        bx1, by1, bx2, by2 = det['coords']
        blocks.append((bx1 / scale_factor, by1 / scale_factor,
                       bx2 / scale_factor, by2 / scale_factor))

    def _centre_is_masked(word):
        cx = (word[1] + word[3]) / 2
        cy = (word[2] + word[4]) / 2
        return any(x1 <= cx <= x2 and y1 <= cy <= y2 for x1, y1, x2, y2 in blocks)

    kept = [word for word in raw_word_data if not _centre_is_masked(word)]
    kept.extend((f"BLOCK_{idx}", *box) for idx, box in enumerate(blocks))
    return kept
# ============================================================================
# --- OCR PREPROCESSING AND COLUMN-DETECTION HELPERS ---
# ============================================================================
def preprocess_image_for_ocr(img_np):
    """Binarise an image for Tesseract using Otsu's automatic threshold.

    3-D (colour) input is first converted with ``COLOR_BGR2GRAY``; 2-D input is
    thresholded as-is. Returns a single-channel image whose pixels are 0 or 255.
    """
    if len(img_np.shape) == 3:
        gray = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
    else:
        gray = img_np
    # The threshold argument (0) is ignored with THRESH_OTSU; Otsu picks it from the histogram.
    _otsu_level, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    return binary
def calculate_vertical_gap_coverage(word_data: list, sep_x: int, page_height: float, gutter_width: int = 10) -> float:
    """Return the fraction of the text's vertical extent where a gutter at ``sep_x`` is clear.

    The extent runs from the smallest y1 to the largest y2 over ``word_data``
    (``(text, x1, y1, x2, y2)`` tuples) and is sampled in 1-unit rows. A row is blocked
    when any word that horizontally overlaps the band ``sep_x ± gutter_width / 2``
    covers it. The result is open rows / total rows, in ``[0, 1]``; the acceptance
    threshold is left to the caller.

    ``page_height`` is accepted for interface compatibility but not used.
    Returns 0.0 for empty input or zero text height.
    """
    if not word_data:
        return 0.0

    edges = [w[2] for w in word_data] + [w[4] for w in word_data]
    top, bottom = min(edges), max(edges)
    text_height = bottom - top
    if text_height <= 0:
        return 0.0

    open_rows = np.ones(int(text_height) + 1, dtype=bool)
    row_count = len(open_rows)
    half_gutter = gutter_width / 2
    band_left = sep_x - half_gutter
    band_right = sep_x + half_gutter
    base_row = int(top)

    for _text, x1, y1, x2, y2 in word_data:
        if not (x2 > band_left and x1 < band_right):
            continue  # word does not touch the gutter band
        first = max(0, int(y1) - base_row)
        last = min(row_count, int(y2) - base_row)
        if last > first:
            open_rows[first:last] = False

    return np.sum(open_rows) / row_count
def calculate_x_gutters(word_data: list, params: Dict, page_height: float) -> List[int]:
    """Find vertical column separators (gutters) from the x-extents of a page's words.

    1. Histogram every word's x1 and x2 over ``[0, max_x]``, smooth it with a Gaussian
       and invert it, so sparsely populated x ranges become peaks.
    2. Each peak's left bin edge is a candidate separator X. A candidate is rejected when
       - words straddling it (x1 < X < x2) together span more than
         ``max_bridging_ratio`` of the total text height, or
       - the clear-gap coverage from ``calculate_vertical_gap_coverage`` is below
         ``min_gap_coverage``.

    Args:
        word_data: ``(text, x1, y1, x2, y2)`` tuples.
        params: optional tuning keys (defaults in parentheses): ``cluster_bin_size`` (5),
            ``cluster_smoothing`` (1), ``cluster_min_width`` (20),
            ``cluster_threshold_percentile`` (85), ``max_bridging_ratio`` (0.08),
            ``min_gap_coverage`` (0.80).
        page_height: forwarded to ``calculate_vertical_gap_coverage``.

    Returns:
        Accepted separator x-coordinates (ints) in ascending order. Empty when there is
        no usable text (no words, zero text height, or no positive x extent).
    """
    if not word_data:
        return []

    x_points = [coord for item in word_data for coord in (item[1], item[3])]
    max_x = max(x_points)
    if max_x <= 0:
        # np.histogram would be asked for zero bins and raise ValueError.
        return []

    y_coords = [item[2] for item in word_data] + [item[4] for item in word_data]
    total_text_height = max(y_coords) - min(y_coords)
    if total_text_height <= 0:
        return []

    bin_size = params.get('cluster_bin_size', 5)
    smoothing = params.get('cluster_smoothing', 1)
    min_width = params.get('cluster_min_width', 20)
    threshold_percentile = params.get('cluster_threshold_percentile', 85)
    max_bridging_ratio = params.get('max_bridging_ratio', 0.08)
    min_gap_coverage = params.get('min_gap_coverage', 0.80)

    num_bins = int(np.ceil(max_x / bin_size))
    hist, bin_edges = np.histogram(x_points, bins=num_bins, range=(0, max_x))
    smoothed_hist = gaussian_filter1d(hist.astype(float), sigma=smoothing)
    # Invert so low-density x ranges (candidate gutters) become peaks.
    inverted_signal = np.max(smoothed_hist) - smoothed_hist
    peaks, _ = find_peaks(
        inverted_signal,
        height=np.max(inverted_signal) - np.percentile(smoothed_hist, threshold_percentile),
        # find_peaks rejects distance < 1 (measured in bins).
        distance=max(1, min_width / bin_size),
    )
    if not peaks.size:
        return []

    final_separators = []
    for peak in peaks:
        x_coord = int(bin_edges[peak])

        # CHECK 1: bridging - total height of words that physically straddle the line.
        bridging_height = sum(item[4] - item[2] for item in word_data if item[1] < x_coord < item[3])
        bridging_ratio = bridging_height / total_text_height
        if bridging_ratio > max_bridging_ratio:
            print(f" ❌ Separator X={x_coord} REJECTED: Bridging Ratio {bridging_ratio:.1%} "
                  f"(>{max_bridging_ratio:.0%}) cuts through text.")
            continue

        # CHECK 2: the gap must be clear over at least min_gap_coverage of the text height.
        coverage = calculate_vertical_gap_coverage(word_data, x_coord, page_height, gutter_width=min_width)
        if coverage >= min_gap_coverage:
            final_separators.append(x_coord)
            print(f" -> Separator X={x_coord} ACCEPTED (Coverage: {coverage:.1%}, Bridging: {bridging_ratio:.1%})")
        else:
            print(f" ❌ Separator X={x_coord} REJECTED (Coverage: {coverage:.1%}, Bridging: {bridging_ratio:.1%})")
    return sorted(final_separators)
def get_word_data_for_detection(page: fitz.Page, pdf_path: str, page_num: int,
                                top_margin_percent=0.10, bottom_margin_percent=0.10) -> list:
    """Return ``(text, x1, y1, x2, y2)`` word tuples in PDF points for layout analysis.

    Embedded PDF text is used when the page has any. Otherwise the page is OCR'd once
    with Tesseract (rendered at 4x zoom, Otsu-binarised, ``--oem 3 --psm 6``), with
    coordinates scaled back to PDF points. That result is stored in ``_ocr_cache`` so a
    later stage can reuse it.

    Words starting above ``top_margin_percent`` of the page height, or ending below
    ``1 - bottom_margin_percent`` of it, are discarded. If OCR raises, the error is
    printed and [] is returned (nothing is cached).
    """
    native_words = page.get_text("words")
    if native_words:
        # PyMuPDF order is (x0, y0, x1, y1, text, ...); reorder to (text, x0, y0, x1, y1).
        words = [(w[4], w[0], w[1], w[2], w[3]) for w in native_words]
    elif _ocr_cache.has_ocr(pdf_path, page_num):
        words = _ocr_cache.get_ocr(pdf_path, page_num)
    else:
        try:
            zoom = 4.0
            pix = page.get_pixmap(matrix=fitz.Matrix(zoom, zoom))
            rendered = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.height, pix.width, pix.n)
            if pix.n == 3:
                rendered = cv2.cvtColor(rendered, cv2.COLOR_RGB2BGR)
            elif pix.n == 4:
                rendered = cv2.cvtColor(rendered, cv2.COLOR_RGBA2BGR)
            binary = preprocess_image_for_ocr(rendered)
            ocr = pytesseract.image_to_data(binary, output_type=pytesseract.Output.DICT,
                                            config=r'--oem 3 --psm 6')
            words = []
            for idx in range(len(ocr['level'])):
                token = ocr['text'][idx].strip()
                if not token:
                    continue
                left = ocr['left'][idx]
                top = ocr['top'][idx]
                words.append((
                    token,
                    left / zoom,
                    top / zoom,
                    (left + ocr['width'][idx]) / zoom,
                    (top + ocr['height'][idx]) / zoom,
                ))
            _ocr_cache.set_ocr(pdf_path, page_num, words)
        except Exception as e:
            print(f" ❌ OCR Error in detection phase: {e}")
            return []

    height = page.rect.height
    upper_limit = height * top_margin_percent
    lower_limit = height * (1 - bottom_margin_percent)
    return [w for w in words if w[2] >= upper_limit and w[4] <= lower_limit]
def pixmap_to_numpy(pix: fitz.Pixmap) -> np.ndarray:
    """Convert a PyMuPDF pixmap into a ``uint8`` OpenCV image.

    RGBA (n=4) and RGB (n=3) pixmaps are returned as 3-channel BGR. Any other
    channel count is returned unconverted with shape ``(height, width, n)``.
    """
    flat = np.frombuffer(pix.samples, dtype=np.uint8)
    image = flat.reshape(pix.height, pix.width, pix.n)
    if pix.n == 4:
        return cv2.cvtColor(image, cv2.COLOR_RGBA2BGR)
    if pix.n == 3:
        return cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    return image
def extract_native_words_and_convert(fitz_page, scale_factor: float = 2.0) -> list:
    """Turn a page's embedded PDF words into pipeline OCR records.

    Coordinates are multiplied by ``scale_factor`` (PDF points -> pipeline pixels) and
    truncated to int. Whitespace-only words are skipped. Every record carries a fixed
    confidence of 99.0 because the text comes from the PDF itself, not from OCR.
    """
    records = []
    for x1, y1, x2, y2, word, *_extra in fitz_page.get_text("words"):
        if not word.strip():
            continue
        px1, py1, px2, py2 = (int(v * scale_factor) for v in (x1, y1, x2, y2))
        records.append({
            'type': 'text',
            'word': word,
            'confidence': 99.0,
            'bbox': [px1, py1, px2, py2],
            'y0': py1,
            'x0': px1,
        })
    return records
def _run_yolo_detection(model, image: np.ndarray) -> List[Dict[str, Any]]:
    """Run YOLO on ``image`` and return merged detections whose class is in TARGET_CLASSES.

    Each result is ``{'coords': (x1, y1, x2, y2), 'y1', 'class', 'conf'}`` with integer
    pixel coordinates in ``image`` space.
    """
    results = model.predict(source=image, conf=CONF_THRESHOLD, imgsz=640, verbose=False)
    detections = []
    if results and results[0].boxes:
        for box in results[0].boxes:
            class_name = model.names[int(box.cls[0])]
            if class_name not in TARGET_CLASSES:
                continue
            x1, y1, x2, y2 = box.xyxy[0].cpu().numpy().astype(int)
            detections.append(
                {'coords': (x1, y1, x2, y2), 'y1': y1, 'class': class_name, 'conf': float(box.conf[0])}
            )
    return merge_overlapping_boxes(detections, IOU_MERGE_THRESHOLD)


def _choose_column_separator(word_data: list, page_width: float, page_height: float) -> Optional[int]:
    """Return the detected gutter closest to the page centre, or None for a single column.

    Only gutters inside the central 35%-65% of the page width are considered.
    """
    column_detection_params = {
        'cluster_bin_size': 2, 'cluster_smoothing': 2,
        'cluster_min_width': 10, 'cluster_threshold_percentile': 85,
    }
    separators = calculate_x_gutters(word_data, column_detection_params, page_height)
    if not separators:
        print(" -> Single Column Layout Confirmed.")
        return None
    central_min, central_max = page_width * 0.35, page_width * 0.65
    central_separators = [s for s in separators if central_min <= s <= central_max]
    if not central_separators:
        print(" ⚠️ Gutter found off-center. Ignoring.")
        return None
    center_x = page_width / 2
    chosen = min(central_separators, key=lambda s: abs(s - center_x))
    print(f" ✅ Column Split Confirmed at X={chosen:.1f}")
    return chosen


def _save_page_components(original_img: np.ndarray, detections: list,
                          pdf_name: str, page_num: int) -> List[Dict[str, Any]]:
    """Crop each figure/equation to FIGURE_EXTRACTION_DIR and return placeholder items.

    Placeholders are numbered with the PDF-wide GLOBAL_FIGURE_COUNT / GLOBAL_EQUATION_COUNT
    counters (``FIGUREn`` / ``EQUATIONn``). Boxes that are empty once clipped to the image
    are skipped, because ``cv2.imwrite`` raises on an empty crop.
    """
    global GLOBAL_FIGURE_COUNT, GLOBAL_EQUATION_COUNT
    img_h, img_w = original_img.shape[:2]
    components = []
    for det in detections:
        x1, y1, x2, y2 = det['coords']
        class_name = det['class']
        if class_name not in ('figure', 'equation'):
            continue

        # Clip the crop window so negative coordinates cannot wrap around in the slice.
        cx1, cy1 = max(0, int(x1)), max(0, int(y1))
        cx2, cy2 = min(img_w, int(x2)), min(img_h, int(y2))
        if cx2 <= cx1 or cy2 <= cy1:
            print(f" ⚠️ Skipping empty {class_name} crop at ({x1}, {y1}, {x2}, {y2}) on page {page_num}.")
            continue

        if class_name == 'figure':
            GLOBAL_FIGURE_COUNT += 1
            counter = GLOBAL_FIGURE_COUNT
            component_word = f"FIGURE{counter}"
        else:
            GLOBAL_EQUATION_COUNT += 1
            counter = GLOBAL_EQUATION_COUNT
            component_word = f"EQUATION{counter}"

        component_filename = f"{pdf_name}_page{page_num}_{class_name}{counter}.png"
        cv2.imwrite(os.path.join(FIGURE_EXTRACTION_DIR, component_filename),
                    original_img[cy1:cy2, cx1:cx2])
        components.append({
            'type': class_name, 'word': component_word,
            'bbox': [int(x1), int(y1), int(x2), int(y2)],
            # Vertical midpoint, so a placeholder sorts with the text line it sits beside.
            'y0': int((y1 + y2) // 2), 'x0': int(x1),
        })
    return components


def _cached_words_at_pipeline_scale(cached_word_data: list, scale_factor: float) -> List[Dict[str, Any]]:
    """Convert cached ``(text, x1, y1, x2, y2)`` PDF-point tuples into pipeline OCR records."""
    records = []
    for word_text, x1, y1, x2, y2 in cached_word_data:
        px1, py1 = int(x1 * scale_factor), int(y1 * scale_factor)
        records.append({
            'type': 'text', 'word': word_text, 'confidence': 95.0,
            'bbox': [px1, py1, int(x2 * scale_factor), int(y2 * scale_factor)],
            'y0': py1, 'x0': px1,
        })
    return records


def _tesseract_words_at_pipeline_scale(fitz_page, scale_factor: float,
                                       ocr_zoom: float = 4.0) -> List[Dict[str, Any]]:
    """OCR the page with Tesseract at ``ocr_zoom`` and return records at ``scale_factor``.

    Entries with empty text or a confidence of -1 or lower are dropped.
    """
    rendered = pixmap_to_numpy(fitz_page.get_pixmap(matrix=fitz.Matrix(ocr_zoom, ocr_zoom)))
    binary = preprocess_image_for_ocr(rendered)
    # --oem 3: default LSTM engine; --psm 6: assume a single uniform block of text.
    data = pytesseract.image_to_data(binary, output_type=pytesseract.Output.DICT,
                                     config=r'--oem 3 --psm 6')
    # Tesseract coordinates are at ocr_zoom; the pipeline expects scale_factor.
    adjust = scale_factor / ocr_zoom
    records = []
    for i in range(len(data['level'])):
        text = data['text'][i].strip()
        conf = float(data['conf'][i])
        if not text or not conf > -1:
            continue
        x1 = int(data['left'][i] * adjust)
        y1 = int(data['top'][i] * adjust)
        x2 = x1 + int(data['width'][i] * adjust)
        y2 = y1 + int(data['height'][i] * adjust)
        records.append({
            'type': 'text', 'word': text, 'confidence': conf,
            'bbox': [x1, y1, x2, y2], 'y0': y1, 'x0': x1,
        })
    return records


def _group_items_into_lines(items: list, line_tolerance: float, component_tolerance: float) -> List[list]:
    """Group items into reading-order lines; each line is sorted left to right.

    Items are visited by ascending ``(y0, x0)``. An item joins the first existing line
    whose first item's y0 is within ``line_tolerance``. Because items arrive in ascending
    y0, that first item's y0 is also the line's minimum. Figures and equations get a
    second pass with the looser ``component_tolerance`` before a new line is started.
    """
    lines: List[list] = []
    for item in sorted(items, key=lambda it: (it['y0'], it['x0'])):
        y0 = item['y0']
        tolerances = [line_tolerance]
        if item['type'] in ('equation', 'figure'):
            tolerances.append(component_tolerance)
        for tol in tolerances:
            line = next((ln for ln in lines if abs(ln[0]['y0'] - y0) < tol), None)
            if line is not None:
                line.append(item)
                break
        else:
            lines.append([item])
    for line in lines:
        line.sort(key=lambda it: it['x0'])
    return lines


def preprocess_and_ocr_page(original_img: np.ndarray, model, pdf_path: str,
                            page_num: int, fitz_page: fitz.Page,
                            pdf_name: str) -> Tuple[Optional[List[Dict[str, Any]]], Optional[int]]:
    """Detect, crop, OCR and reading-order-sort one page for the pipeline JSON.

    Steps:
      1. YOLO finds figures/equations in ``original_img``.
      2. Layout words (native text, else cached or fresh Tesseract) are masked by the
         YOLO boxes, so the column detector sees those regions as solid blocks.
      3. A central column gutter is detected, if any.
      4. Each detection is cropped into FIGURE_EXTRACTION_DIR and becomes a
         ``FIGUREn`` / ``EQUATIONn`` placeholder, numbered across the whole PDF.
      5. Page words come from native PDF text, falling back to cached or fresh Tesseract.
      6. Words whose IoA with a component exceeds IOA_SUPPRESSION_THRESHOLD are dropped.
      7. Words and placeholders are grouped into lines and ordered left to right.

    Output bboxes are pixels at 2x PDF scale. ``original_img`` is assumed to be rendered
    at that scale, as ``run_single_pdf_preprocessing`` does.

    Returns:
        ``(items, column_separator_x)``: ``items`` are ``{"word", "bbox", "type"}`` dicts in
        reading order; ``column_separator_x`` is in PDF points or None.
        ``(None, None)`` if ``original_img`` is None.
    """
    if original_img is None:
        print(f" ❌ Invalid image for page {page_num}.")
        return None, None

    # Pipeline-wide pixel scale: original_img and every emitted bbox are at 2x PDF points.
    scale_factor = 2.0

    # --- STEP 1: YOLO detection ---
    start_time_yolo = time.time()
    merged_detections = _run_yolo_detection(model, original_img)
    print(f" [LOG] YOLO found {len(merged_detections)} objects in {time.time() - start_time_yolo:.3f}s.")

    # --- STEP 2: mask layout words with the YOLO boxes ---
    raw_words_for_layout = get_word_data_for_detection(
        fitz_page, pdf_path, page_num,
        top_margin_percent=0.10, bottom_margin_percent=0.10
    )
    masked_word_data = merge_yolo_into_word_data(raw_words_for_layout, merged_detections,
                                                 scale_factor=scale_factor)

    # --- STEP 3: column detection (PDF points) ---
    page_separator_x = _choose_column_separator(masked_word_data, fitz_page.rect.width,
                                                fitz_page.rect.height)

    # --- STEP 4: crop and save figures / equations ---
    component_metadata = _save_page_components(original_img, merged_detections, pdf_name, page_num)

    # --- STEP 5: hybrid OCR (native text, then cached or fresh Tesseract) ---
    raw_ocr_output = []
    try:
        raw_ocr_output = extract_native_words_and_convert(fitz_page, scale_factor=scale_factor)
    except Exception as e:
        print(f" ❌ Native text extraction failed: {e}")
    if not raw_ocr_output:
        if _ocr_cache.has_ocr(pdf_path, page_num):
            print(f" ⚡ Using cached Tesseract OCR for page {page_num}")
            raw_ocr_output = _cached_words_at_pipeline_scale(
                _ocr_cache.get_ocr(pdf_path, page_num), scale_factor)
        else:
            try:
                raw_ocr_output = _tesseract_words_at_pipeline_scale(fitz_page, scale_factor)
            except Exception as e:
                print(f" ❌ Tesseract OCR Error: {e}")

    # --- STEP 6: drop words lying inside a figure/equation, then add the placeholders ---
    items_to_sort = [
        word for word in raw_ocr_output
        if not any(calculate_ioa(word['bbox'], comp['bbox']) > IOA_SUPPRESSION_THRESHOLD
                   for comp in component_metadata)
    ]
    items_to_sort.extend(component_metadata)

    # --- STEP 7: line-based reading order ---
    lines = _group_items_into_lines(items_to_sort, LINE_TOLERANCE, component_tolerance=20)
    final_output = []
    for line in lines:
        for item in line:
            entry = {"word": item["word"], "bbox": item["bbox"], "type": item["type"]}
            if 'tag' in item:
                entry['tag'] = item['tag']
            final_output.append(entry)
    return final_output, page_separator_x
def run_single_pdf_preprocessing(pdf_path: str, preprocessed_json_path: str) -> Optional[str]:
    """Run YOLO + OCR preprocessing over every page of a PDF and save the result as JSON.

    Resets the global figure/equation counters and the shared OCR cache, then renders
    each page at 2x and runs ``preprocess_and_ocr_page`` on it. A list of
    ``{"page_number", "data", "column_separator_x"}`` records is written to
    ``preprocessed_json_path`` (UTF-8). Cropped components go to FIGURE_EXTRACTION_DIR.

    Returns:
        The JSON path on success. None if the PDF is missing or unreadable, no page
        produced data, or the JSON could not be written. A failure on one page is
        printed and that page is skipped.
    """
    global GLOBAL_FIGURE_COUNT, GLOBAL_EQUATION_COUNT
    GLOBAL_FIGURE_COUNT = 0
    GLOBAL_EQUATION_COUNT = 0
    _ocr_cache.clear()

    print("\n" + "=" * 80)
    print("--- 1. STARTING OPTIMIZED YOLO/OCR PREPROCESSING PIPELINE ---")
    print("=" * 80)

    if not os.path.exists(pdf_path):
        print(f"❌ FATAL ERROR: Input PDF not found at {pdf_path}.")
        return None

    output_dir = os.path.dirname(preprocessed_json_path)
    if output_dir:  # os.makedirs('') raises for a bare filename such as "out.json"
        os.makedirs(output_dir, exist_ok=True)
    os.makedirs(FIGURE_EXTRACTION_DIR, exist_ok=True)

    model = YOLO(WEIGHTS_PATH)
    pdf_name = os.path.splitext(os.path.basename(pdf_path))[0]
    try:
        doc = fitz.open(pdf_path)
        print(f"✅ Opened PDF: {pdf_name} ({doc.page_count} pages)")
    except Exception as e:
        print(f"❌ ERROR loading PDF file: {e}")
        return None

    all_pages_data = []
    # Must match the 2.0 pixel scale used throughout preprocess_and_ocr_page.
    render_matrix = fitz.Matrix(2.0, 2.0)
    print("\n[STEP 1.2: ITERATING PAGES - IN-MEMORY PROCESSING]")
    try:
        for page_index in range(doc.page_count):
            page_num = page_index + 1
            print(f" -> Processing Page {page_num}/{doc.page_count}...")
            fitz_page = doc.load_page(page_index)
            try:
                original_img = pixmap_to_numpy(fitz_page.get_pixmap(matrix=render_matrix))
            except Exception as e:
                print(f" ❌ Error converting page {page_num} to image: {e}")
                continue

            try:
                final_output, page_separator_x = preprocess_and_ocr_page(
                    original_img, model, pdf_path, page_num, fitz_page, pdf_name
                )
            except Exception as e:
                # One bad page should not abort the whole document.
                print(f" ❌ Error processing page {page_num}: {e}")
                final_output, page_separator_x = None, None

            if final_output is None:
                print(f" ❌ Skipped page {page_num} due to processing error.")
                continue
            all_pages_data.append({
                "page_number": page_num,
                "data": final_output,
                "column_separator_x": page_separator_x,
            })
    finally:
        doc.close()

    if not all_pages_data:
        print("❌ WARNING: No page data generated. Halting pipeline.")
        return None

    try:
        with open(preprocessed_json_path, 'w', encoding='utf-8') as f:
            json.dump(all_pages_data, f, indent=4)
        print(f"\n ✅ Combined structured OCR JSON saved to: {os.path.basename(preprocessed_json_path)}")
    except Exception as e:
        print(f"❌ ERROR saving combined JSON output: {e}")
        return None

    print("\n" + "=" * 80)
    print(f"--- YOLO/OCR PREPROCESSING COMPLETE ({len(all_pages_data)} pages processed) ---")
    print("=" * 80)
    return preprocessed_json_path
# ============================================================================
# --- PHASE 2: LAYOUTLMV3 INFERENCE FUNCTIONS ---
# ============================================================================
class LayoutLMv3ForTokenClassification(nn.Module):
    """LayoutLMv3 encoder, a linear emission head and a CRF layer for token tagging.

    With ``labels``, ``forward`` returns the negated mean of
    ``self.crf(emissions, labels, mask=mask)``, used as the training loss (presumably the
    CRF log-likelihood, per TorchCRF's convention). Without ``labels`` it returns the
    CRF's ``viterbi_decode`` output.
    """

    def __init__(self, num_labels: int = NUM_LABELS):
        super().__init__()
        self.num_labels = num_labels
        base_checkpoint = "microsoft/layoutlmv3-base"
        config = LayoutLMv3Config.from_pretrained(base_checkpoint, num_labels=num_labels)
        # The attribute names layoutlmv3 / classifier / crf form the state_dict keys that
        # saved checkpoints are loaded into, so they must not be renamed.
        self.layoutlmv3 = LayoutLMv3Model.from_pretrained(base_checkpoint, config=config)
        self.classifier = nn.Linear(config.hidden_size, num_labels)
        self.crf = CRF(num_labels)
        self.init_weights()

    def init_weights(self):
        """Xavier-uniform initialise the classifier weight and zero its bias, if any."""
        nn.init.xavier_uniform_(self.classifier.weight)
        if self.classifier.bias is not None:
            nn.init.zeros_(self.classifier.bias)

    def forward(self, input_ids: torch.Tensor, bbox: torch.Tensor, attention_mask: torch.Tensor, labels: Optional[torch.Tensor] = None):
        hidden_states = self.layoutlmv3(
            input_ids=input_ids, bbox=bbox, attention_mask=attention_mask, return_dict=True
        ).last_hidden_state
        emissions = self.classifier(hidden_states)
        crf_mask = attention_mask.bool()
        if labels is None:
            return self.crf.viterbi_decode(emissions, mask=crf_mask)
        crf_score = self.crf(emissions, labels, mask=crf_mask)
        return -crf_score.mean()
def _merge_integrity(all_token_data: List[Dict[str, Any]],
                     column_separator_x: Optional[int]) -> List[List[Dict[str, Any]]]:
    """Split a page's tokens into left/right column chunks at ``column_separator_x``.

    A token goes to the left chunk when the horizontal centre of its
    ``bbox_raw_pdf_space`` box is strictly less than the separator; otherwise it goes
    right. Empty chunks are dropped. With no separator the whole page is one chunk.
    """
    if column_separator_x is None:
        print(" -> No column separator. Treating as one chunk.")
        return [all_token_data]

    left_column, right_column = [], []
    for token in all_token_data:
        box = token['bbox_raw_pdf_space']
        centre = (box[0] + box[2]) / 2
        target = left_column if centre < column_separator_x else right_column
        target.append(token)

    chunks = [chunk for chunk in (left_column, right_column) if chunk]
    print(f" -> Data split into {len(chunks)} column chunk(s) using separator X={column_separator_x}.")
    return chunks
def _to_token_record(item: Dict[str, Any], page_width: float, page_height: float,
                     scale_factor: float = 2.0) -> Dict[str, Any]:
    """Convert one preprocessed item into the token record used for inference.

    ``item['bbox']`` is in pipeline pixels (``scale_factor`` x PDF points). It is mapped
    back to integer PDF points, and also normalised to 0-1000 (clamped) relative to
    the page size for the LayoutLMv3 tokenizer.
    """
    raw = item['bbox']
    bbox_pdf = [int(raw[0] / scale_factor), int(raw[1] / scale_factor),
                int(raw[2] / scale_factor), int(raw[3] / scale_factor)]
    extents = (page_width, page_height, page_width, page_height)
    normalized_bbox = [max(0, min(1000, int(1000 * v / extent))) for v, extent in zip(bbox_pdf, extents)]
    return {
        "word": item['word'],
        "bbox_raw_pdf_space": bbox_pdf,
        "bbox_normalized": normalized_bbox,
        "item_original_data": item,
    }


def _predict_word_labels(model, tokenizer, device, words: List[str], bboxes: List[List[int]],
                         log_prefix: str = "", max_words: int = 500,
                         max_length: int = 512) -> List[str]:
    """Return one BIO label string per word, using the LayoutLMv3 + CRF model.

    Words are encoded in windows of at most ``max_words``. The tokenizer truncates each
    window at ``max_length`` sub-word tokens, so a window of many multi-token words can
    overflow. Words that received no token in that pass are not labelled there; instead
    the next window starts at the first of them, so they are not silently labelled "O".

    Each word takes the label predicted for its first sub-token. A word that still gets
    no prediction (no token at all, or an empty decode) falls back to label id 0.
    """
    labels: List[str] = []
    start = 0
    window_no = 0
    while start < len(words):
        window_no += 1
        window_words = words[start:start + max_words]
        window_boxes = bboxes[start:start + max_words]
        print(f" -> {log_prefix}, Sub-chunk {window_no}: {len(window_words)} words. Running Inference...")

        encoded = tokenizer(
            window_words, boxes=window_boxes, truncation=True, padding="max_length",
            max_length=max_length, return_tensors="pt"
        )
        with torch.no_grad():
            decoded = model(
                encoded['input_ids'].to(device),
                encoded['bbox'].to(device),
                encoded['attention_mask'].to(device),
            )

        word_ids = encoded.word_ids()
        present = [w for w in word_ids if w is not None]
        # Words after the last one that got a token were cut off by truncation; retry them.
        covered = max(present) + 1 if present else len(window_words)
        if covered < len(window_words):
            print(f" -> {len(window_words) - covered} word(s) exceeded {max_length} tokens; re-queued.")

        predictions = decoded[0] if decoded else []
        first_token_pred = {}
        for token_idx, word_idx in enumerate(word_ids):
            if word_idx is None or word_idx in first_token_pred or token_idx >= len(predictions):
                continue
            first_token_pred[word_idx] = predictions[token_idx]

        for word_idx in range(covered):
            pred = first_token_pred.get(word_idx, 0)
            pred_id = pred.item() if torch.is_tensor(pred) else pred
            labels.append(ID_TO_LABEL[pred_id])
        start += covered
    return labels


def run_inference_and_get_raw_words(pdf_path: str, model_path: str,
                                    preprocessed_json_path: str,
                                    column_detection_params: Optional[Dict] = None) -> List[Dict[str, Any]]:
    """Label every preprocessed word on every page with the LayoutLMv3 + CRF tagger.

    Loads the checkpoint at ``model_path`` and reads the page JSON from the preprocessing
    phase. Each item's bbox is mapped from pipeline pixels (2x) back to PDF points. Each
    page is split at its saved column separator (left column first), and each column is
    labelled separately.

    Args:
        pdf_path: source PDF; used here only for page dimensions.
        model_path: checkpoint holding a state dict, either directly or under
            ``'model_state_dict'``.
        preprocessed_json_path: JSON written by ``run_single_pdf_preprocessing``.
        column_detection_params: unused; kept for interface compatibility.

    Returns:
        ``[{"page_number", "data": [{"word", "bbox", "predicted_label", "page_number"}]}]``
        with bboxes in PDF points. Pages that yield no words are omitted. Returns [] if
        the model, JSON or PDF cannot be loaded.
    """
    print("\n" + "=" * 80)
    print("--- 2. STARTING LAYOUTLMV3 INFERENCE PIPELINE (Raw Word Output) ---")
    print("=" * 80)
    tokenizer = LayoutLMv3TokenizerFast.from_pretrained("microsoft/layoutlmv3-base")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f" -> Using device: {device}")

    try:
        model = LayoutLMv3ForTokenClassification(num_labels=NUM_LABELS)
        # NOTE(security): torch.load unpickles the file - only load trusted checkpoints.
        checkpoint = torch.load(model_path, map_location=device)
        model_state = checkpoint.get('model_state_dict', checkpoint)
        # Map a 'layoutlm.' key prefix onto this class's 'layoutlmv3' encoder attribute.
        fixed_state_dict = {key.replace('layoutlm.', 'layoutlmv3.'): value for key, value in model_state.items()}
        model.load_state_dict(fixed_state_dict)
        model.to(device)
        model.eval()
        print(f"✅ LayoutLMv3 Model loaded successfully from {os.path.basename(model_path)}.")
    except Exception as e:
        print(f"❌ FATAL ERROR during LayoutLMv3 model loading: {e}")
        return []

    try:
        with open(preprocessed_json_path, 'r', encoding='utf-8') as f:
            preprocessed_data = json.load(f)
        print(f"✅ Loaded preprocessed data with {len(preprocessed_data)} pages.")
    except Exception:
        print("❌ Error loading preprocessed JSON.")
        return []

    try:
        doc = fitz.open(pdf_path)
    except Exception:
        print("❌ Error loading PDF.")
        return []

    final_page_predictions = []
    scale_factor = 2.0  # pipeline pixels per PDF point, as used by the preprocessing phase
    try:
        for page_data in preprocessed_data:
            page_num_1_based = page_data['page_number']
            print(f"\n *** Processing Page {page_num_1_based} ({len(page_data['data'])} raw tokens) ***")
            fitz_page = doc.load_page(page_num_1_based - 1)
            page_width, page_height = fitz_page.rect.width, fitz_page.rect.height
            print(f" -> Page dimensions: {page_width:.0f}x{page_height:.0f} (PDF points).")

            all_token_data = [
                _to_token_record(item, page_width, page_height, scale_factor)
                for item in page_data['data']
            ]
            if not all_token_data:
                continue

            column_separator_x = page_data.get('column_separator_x', None)
            if column_separator_x is not None:
                print(f" -> Using SAVED column separator: X={column_separator_x}")
            else:
                print(" -> No column separator found. Assuming single chunk.")
            token_chunks = _merge_integrity(all_token_data, column_separator_x)

            page_raw_predictions = []
            for chunk_idx, chunk_tokens in enumerate(token_chunks):
                if not chunk_tokens:
                    continue
                chunk_labels = _predict_word_labels(
                    model, tokenizer, device,
                    [t['word'] for t in chunk_tokens],
                    [t['bbox_normalized'] for t in chunk_tokens],
                    log_prefix=f"Chunk {chunk_idx + 1}/{len(token_chunks)}",
                )
                for token, predicted_label in zip(chunk_tokens, chunk_labels):
                    page_raw_predictions.append({
                        "word": token['word'],
                        "bbox": token['bbox_raw_pdf_space'],
                        "predicted_label": predicted_label,
                        "page_number": page_num_1_based,
                    })

            if page_raw_predictions:
                final_page_predictions.append({
                    "page_number": page_num_1_based,
                    "data": page_raw_predictions,
                })
                print(f" *** Page {page_num_1_based} Finalized: {len(page_raw_predictions)} labeled words. ***")
    finally:
        doc.close()

    print("\n" + "=" * 80)
    print("--- LAYOUTLMV3 INFERENCE COMPLETE ---")
    print("=" * 80)
    return final_page_predictions
def create_label_studio_span(page_results, start_idx, end_idx, label):
    """Build one Label Studio "labels" result for a contiguous run of words.

    Args:
        page_results: Per-word dicts for one page; each carries a ``'word'``
            string and a ``'bbox'`` indexable as ``[x0, y0, x1, y1]``.
        start_idx: Index of the first word in the span (inclusive).
        end_idx: Index of the last word in the span (inclusive).
        label: Tag to attach to the span.

    Returns:
        A Label Studio result dict. ``start``/``end`` are character offsets into
        the page's words joined by single spaces, and ``bbox`` is the union of
        the span's word boxes, expressed as x/y/width/height.
    """
    selected = [page_results[k] for k in range(start_idx, end_idx + 1)]
    span_text = " ".join(entry['word'] for entry in selected)
    boxes = [entry['bbox'] for entry in selected]
    left = min(box[0] for box in boxes)
    top = min(box[1] for box in boxes)
    right = max(box[2] for box in boxes)
    bottom = max(box[3] for box in boxes)
    # The span starts after all preceding words and their separating spaces,
    # plus one more space unless it is the very first word on the page.
    preceding_text = " ".join(entry['word'] for entry in page_results[:start_idx])
    start_char = len(preceding_text) + (1 if start_idx != 0 else 0)
    end_char = start_char + len(span_text)
    value = {
        "start": start_char,
        "end": end_char,
        "text": span_text,
        "labels": [label],
        "bbox": {"x": left, "y": top, "width": right - left, "height": bottom - top},
    }
    return {
        "from_name": "label",
        "to_name": "text",
        "type": "labels",
        "value": value,
        "score": 0.99,
    }
def convert_raw_predictions_to_label_studio(page_data_list, output_path: str):
    """Convert per-page BIO word predictions into Label Studio tasks and save them.

    Each non-empty page becomes one task whose ``data.text`` is the page's words
    joined by single spaces. Entities are decoded strictly:

    - a ``B-X`` label opens a span;
    - following ``I-X`` words extend it;
    - any other label (``O``, a new ``B-`` tag, or an ``I-`` tag that does not
      continue the open entity) closes it.

    Stray ``I-`` words that do not continue an open entity are left unlabeled.

    Args:
        page_data_list: List of ``{'page_number': ..., 'data': [...]}`` dicts whose
            data entries carry ``'word'``, ``'bbox'`` and ``'predicted_label'``.
        output_path: Destination JSON file; it is overwritten.
    """
    tasks = []
    print("\n[PHASE: LABEL STUDIO CONVERSION]")
    for page in page_data_list:
        page_number = page['page_number']
        words_data = page['data']
        if not words_data:
            continue
        words = [entry['word'] for entry in words_data]
        page_text = " ".join(words)
        spans = []
        open_tag = None
        open_start = None
        for idx, entry in enumerate(words_data):
            tag = entry['predicted_label']
            bare_tag = tag.split('-', 1)[-1] if '-' in tag else tag
            if tag.startswith('I-') and open_tag == bare_tag:
                continue  # continuation of the currently open entity
            # Anything else ends the open entity (if there is one) at the previous word.
            if open_tag:
                spans.append(create_label_studio_span(words_data, open_start, idx - 1, open_tag))
            if tag.startswith('B-'):
                open_tag, open_start = bare_tag, idx
            else:
                open_tag, open_start = None, None
        if open_tag:
            spans.append(create_label_studio_span(words_data, open_start, len(words_data) - 1, open_tag))
        tasks.append({
            "data": {
                "text": page_text,
                "original_words": words,
                "original_bboxes": [entry['bbox'] for entry in words_data],
            },
            "annotations": [{"result": spans}],
            "meta": {"page_number": page_number},
        })
    with open(output_path, "w", encoding='utf-8') as handle:
        json.dump(tasks, handle, indent=2, ensure_ascii=False)
    print(f"\n✅ Label Studio tasks saved to {output_path}.")
# ============================================================================
# --- PHASE 3: BIO TO STRUCTURED JSON DECODER ---
# ============================================================================
def convert_bio_to_structured_json_relaxed(input_path: str, output_path: str) -> Optional[List[Dict[str, Any]]]:
    """Decode BIO word predictions into structured question items.

    *input_path* holds a JSON list of page dicts. Each page dict has a ``'data'``
    list of ``{'word', 'predicted_label', ...}`` entries. All pages are flattened
    into one word stream and walked with a small state machine:

    - Words before the first ``B-QUESTION`` become a leading
      ``{'type': 'METADATA', 'passage': ...}`` item. It gets a ``'text'`` key
      only when there was non-blank header text.
    - Each ``B-QUESTION`` closes the previous item and opens a new one with the
      keys ``question``, ``options``, ``answer``, ``passage`` and ``text``.
    - OPTION, ANSWER and QUESTION words extend the matching field.
    - PASSAGE words are buffered and merged into ``passage``.
    - A PASSAGE word that comes right after an ``I-OPTION`` word starts a
      ``new_passage`` instead.
    - ``O`` words are appended to the question while the last entity seen is
      still the question.

    An item's ``text`` is every word from its ``B-QUESTION`` up to, but not
    including, the next one, with whitespace collapsed.

    Args:
        input_path: Raw prediction JSON to read.
        output_path: Where the structured result is written. This is best
            effort: a write failure is reported, not raised.

    Returns:
        The structured item list, or ``None`` if *input_path* cannot be loaded.
    """
    print("\n" + "=" * 80)
    print("--- 3. STARTING BIO TO STRUCTURED JSON DECODING ---")
    print("=" * 80)
    try:
        with open(input_path, 'r', encoding='utf-8') as f:
            predictions_by_page = json.load(f)
    except Exception as e:
        print(f"❌ Error loading raw prediction file: {e}")
        return None

    # Flatten all pages into one word stream; page entries without 'data' are skipped.
    predictions = []
    for page_item in predictions_by_page:
        if isinstance(page_item, dict) and 'data' in page_item:
            predictions.extend(page_item['data'])

    structured_data = []
    current_item = None              # question item currently being built
    current_option_key = None        # key (the B-OPTION word) of the option being extended
    current_passage_buffer = []      # PASSAGE words not yet merged into an item
    current_text_buffer = []         # every word since the current item's B-QUESTION
    first_question_started = False
    last_entity_type = None          # entity of the latest B- word (or passage start)
    just_finished_i_option = False   # True right after an I-OPTION word
    is_in_new_passage = False        # True while words go to current_item['new_passage']

    def finalize_passage_to_item(item, passage_buffer):
        """Append buffered passage words to item['passage'] and empty the buffer."""
        if passage_buffer:
            passage_text = re.sub(r'\s{2,}', ' ', ' '.join(passage_buffer)).strip()
            if item.get('passage'):
                item['passage'] += ' ' + passage_text
            else:
                item['passage'] = passage_text
            passage_buffer.clear()

    for pred in predictions:
        word = pred['word']
        label = pred['predicted_label']
        entity_type = label[2:].strip() if label.startswith(('B-', 'I-')) else None
        current_text_buffer.append(word)
        previous_entity_type = last_entity_type
        is_passage_label = (entity_type == 'PASSAGE')

        # Before the first question, only passage words are kept (they feed METADATA).
        if not first_question_started:
            if label != 'B-QUESTION' and not is_passage_label:
                just_finished_i_option = False
                is_in_new_passage = False
                continue
            if is_passage_label:
                current_passage_buffer.append(word)
                last_entity_type = 'PASSAGE'
                just_finished_i_option = False
                is_in_new_passage = False
                continue

        if label == 'B-QUESTION':
            if not first_question_started:
                # Everything before this word is the document header.
                header_text = ' '.join(current_text_buffer[:-1]).strip()
                if header_text or current_passage_buffer:
                    metadata_item = {'type': 'METADATA', 'passage': ''}
                    finalize_passage_to_item(metadata_item, current_passage_buffer)
                    if header_text:
                        metadata_item['text'] = header_text
                    structured_data.append(metadata_item)
                first_question_started = True
                current_text_buffer = [word]
            if current_item is not None:
                finalize_passage_to_item(current_item, current_passage_buffer)
                # The buffer's last word is this B-QUESTION, which belongs to the next item.
                current_item['text'] = ' '.join(current_text_buffer[:-1]).strip()
                structured_data.append(current_item)
                current_text_buffer = [word]
            current_item = {
                'question': word, 'options': {}, 'answer': '', 'passage': '', 'text': ''
            }
            current_option_key = None
            last_entity_type = 'QUESTION'
            just_finished_i_option = False
            is_in_new_passage = False
            continue

        if current_item is None:
            continue

        if is_in_new_passage:
            # Every word, O words included, extends the new passage until a B- tag
            # or a non-passage I- tag ends it.
            # NOTE(review): the terminating word is itself appended to new_passage
            # and not processed as its own entity (e.g. a B-OPTION key is absorbed)
            # — confirm this is intended.
            if 'new_passage' not in current_item:
                current_item['new_passage'] = word
            else:
                current_item['new_passage'] += f' {word}'
            if label.startswith('B-') or (label.startswith('I-') and entity_type != 'PASSAGE'):
                is_in_new_passage = False
            if label.startswith(('B-', 'I-')):
                last_entity_type = entity_type
            continue

        if label.startswith('B-'):
            if entity_type in ['QUESTION', 'OPTION', 'ANSWER', 'SECTION_HEADING']:
                # A new structural entity closes any pending passage.
                finalize_passage_to_item(current_item, current_passage_buffer)
                current_passage_buffer = []
            last_entity_type = entity_type
            if entity_type == 'PASSAGE':
                if previous_entity_type == 'OPTION' and just_finished_i_option:
                    # A passage starting right after option text begins a new passage.
                    current_item['new_passage'] = word
                    is_in_new_passage = True
                else:
                    current_passage_buffer.append(word)
            elif entity_type == 'OPTION':
                current_option_key = word
                current_item['options'][current_option_key] = word
                just_finished_i_option = False
            elif entity_type == 'ANSWER':
                current_item['answer'] = word
                current_option_key = None
                just_finished_i_option = False
        elif label.startswith('I-'):
            if entity_type == 'QUESTION':
                current_item['question'] += f' {word}'
            elif entity_type == 'PASSAGE':
                if previous_entity_type == 'OPTION' and just_finished_i_option:
                    current_item['new_passage'] = word
                    is_in_new_passage = True
                else:
                    if not current_passage_buffer:
                        last_entity_type = 'PASSAGE'
                    current_passage_buffer.append(word)
            elif entity_type == 'OPTION' and current_option_key is not None:
                current_item['options'][current_option_key] += f' {word}'
            elif entity_type == 'ANSWER':
                current_item['answer'] += f' {word}'
            just_finished_i_option = (entity_type == 'OPTION')
        elif label == 'O':
            if last_entity_type == 'QUESTION':
                current_item['question'] += f' {word}'
            just_finished_i_option = False

    if current_item is not None:
        finalize_passage_to_item(current_item, current_passage_buffer)
        current_item['text'] = ' '.join(current_text_buffer).strip()
        structured_data.append(current_item)

    for entry in structured_data:
        # A METADATA item has 'text' only when header text existed; the guard
        # avoids a KeyError for passage-only headers.
        if 'text' in entry:
            entry['text'] = re.sub(r'\s{2,}', ' ', entry['text']).strip()
        if 'new_passage' in entry:
            entry['new_passage'] = re.sub(r'\s{2,}', ' ', entry['new_passage']).strip()

    try:
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(structured_data, f, indent=2, ensure_ascii=False)
    except Exception as e:
        # The intermediate file is best effort: report the failure rather than hiding it.
        print(f"⚠️ Could not write structured output to {output_path}: {e}")
    return structured_data
def create_query_text(entry: Dict[str, Any]) -> str:
    """Build the similarity query for an item: its question followed by its option texts.

    Option values are read from the ``options`` key and then the ``options_text``
    key, each only when it is a non-empty dict. Only non-empty string values are
    used. The parts are joined with single spaces.
    """
    pieces = [entry["question"]] if entry.get("question") else []
    for field in ("options", "options_text"):
        mapping = entry.get(field)
        if not (mapping and isinstance(mapping, dict)):
            continue
        pieces.extend(text for text in mapping.values() if text and isinstance(text, str))
    return " ".join(pieces)
def calculate_similarity(doc1: str, doc2: str) -> float:
    """Return the cosine similarity of two texts' bag-of-words count vectors.

    Before vectorising, a leading token such as ``1.`` is stripped from the start
    of every line of each text. Vectorisation lowercases, drops English stop
    words, and keeps only tokens of two or more word characters.

    Returns 0.0 when either text is empty, when no vocabulary survives filtering,
    or when vectorisation fails for any reason.
    """
    if not doc1 or not doc2:
        return 0.0
    # NOTE(review): ``[\(\d\w]+`` also matches ordinary words, so the first word of
    # each line is removed even when it is not a list marker — confirm intended.
    leading_marker = r'^\s*[\(\d\w]+\.?\s*'
    corpus = [re.sub(leading_marker, '', text, flags=re.MULTILINE) for text in (doc1, doc2)]
    try:
        vectorizer = CountVectorizer(stop_words='english', lowercase=True, token_pattern=r'(?u)\b\w\w+\b')
        counts = vectorizer.fit_transform(corpus)
        if counts.shape[1] == 0:
            return 0.0
        dense = counts.toarray()  # always two rows: one per document
        return cosine_similarity(dense[0:1], dense[1:2])[0][0]
    except Exception:
        return 0.0
def process_context_linking(data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Link questions to reading passages by text similarity, mutating *data* in place.

    Items are walked in order while two contexts are tracked: the most recent
    non-blank ``passage`` (the standard context) and the most recent non-blank
    ``new_passage`` (the override context). For each non-METADATA item with a
    question whose query text (``create_query_text``) has at least 5 characters:

    1. the new passage is linked if it beats the standard score by more than
       ``RESOLUTION_MARGIN`` and reaches ``SIMILARITY_THRESHOLD``;
    2. otherwise the standard passage is linked if it reaches the threshold;
    3. otherwise the miss is counted. After ``MAX_CONSECUTIVE_FAILURES`` (2)
       consecutive misses the standard context is dropped ("decay"). The counter
       resets on a standard link or when an item supplies its own passage.

    Afterwards, two cleanup passes run:

    - Items that defined a passage before linking have it cleared when it is
      dissimilar to their own query.
    - A question without a passage that sits between two items sharing the same
      non-blank passage inherits that passage.

    Returns:
        *data* itself, or ``[]`` when *data* is empty.
    """
    print("\n" + "=" * 80)
    print("--- STARTING CONTEXT LINKING (WITH DECAY LOGIC) ---")
    print("=" * 80)
    if not data:
        return []

    def _has_text(record, key):
        value = record.get(key)
        return bool(value and value.strip())

    # Phase 1: remember which items carried their own passage before linking mutates them.
    definer_indices = []
    for idx, record in enumerate(data):
        has_standard = _has_text(record, "passage")
        has_override = _has_text(record, "new_passage")
        if has_standard or has_override:
            definer_indices.append(idx)

    # Phase 2: carry contexts forward and link each question.
    active_passage = None
    active_new_passage = None
    consecutive_failures = 0
    MAX_CONSECUTIVE_FAILURES = 2
    for i, entry in enumerate(data):
        kind = entry.get("type", "Question")
        if _has_text(entry, "passage"):
            active_passage = entry["passage"]
            consecutive_failures = 0  # fresh explicit context
        if _has_text(entry, "new_passage"):
            active_new_passage = entry["new_passage"]

        if not entry.get("question") or kind == "METADATA":
            continue
        query = create_query_text(entry)
        if len(query.strip()) < 5:
            continue  # too short to score meaningfully

        score_old = calculate_similarity(active_passage, query) if active_passage else 0.0
        score_new = calculate_similarity(active_new_passage, query) if active_new_passage else 0.0
        q_preview = entry['question'][:30] + '...'

        new_wins = score_new > score_old + RESOLUTION_MARGIN and score_new >= SIMILARITY_THRESHOLD
        if active_new_passage and new_wins:
            # Matching the override context does not touch the decay counter.
            entry["passage"] = active_new_passage
            print(f" [Linker] 🚀 Q{i} ('{q_preview}') -> NEW PASSAGE (Score: {score_new:.3f})")
            continue
        if active_passage and score_old >= SIMILARITY_THRESHOLD:
            entry["passage"] = active_passage
            print(f" [Linker] ✅ Q{i} ('{q_preview}') -> STANDARD PASSAGE (Score: {score_old:.3f})")
            consecutive_failures = 0
            continue
        if not active_passage:
            print(f" [Linker] ⚠️ Q{i} NOT LINKED (No active context).")
            continue
        consecutive_failures += 1
        print(f" [Linker] ⚠️ Q{i} NOT LINKED. (Failures: {consecutive_failures}/{MAX_CONSECUTIVE_FAILURES})")
        if consecutive_failures >= MAX_CONSECUTIVE_FAILURES:
            print(f" [Linker] 🗑️ Context dropped due to {consecutive_failures} consecutive misses.")
            active_passage = None
            consecutive_failures = 0

    # Phase 3: cleanup and interpolation.
    print(" [Linker] Running Cleanup & Interpolation...")
    for i in definer_indices:
        entry = data[i]
        if not entry.get("question") or entry.get("type") == "METADATA":
            continue
        candidate = entry.get("passage") or entry.get("new_passage")
        if not candidate:
            continue
        if calculate_similarity(candidate, create_query_text(entry)) < SIMILARITY_THRESHOLD:
            entry["passage"] = ""
            if "new_passage" in entry:
                entry["new_passage"] = ""
            print(f" [Cleanup] Removed weak link for Q{i}")

    # Only gaps exactly one item wide are filled, so decayed contexts are not revived.
    for i in range(1, len(data) - 1):
        entry = data[i]
        if not entry.get("question") or entry.get("passage"):
            continue
        before = data[i - 1].get("passage")
        after = data[i + 1].get("passage")
        if before and after and before == after and before.strip():
            entry["passage"] = before
            print(f" [Linker] 🥪 Q{i} Interpolated from neighbors.")
    return data
def correct_misaligned_options(structured_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Move a figure/equation tag back into an empty option that lost it.

    The check runs on each adjacent pair of options of a non-METADATA item. It
    fires when the first option holds nothing but its own key and the rest of
    the next option contains exactly two ``EQUATION<n>``/``FIGURE<n>`` tags.
    The first tag is then moved to the empty option, and the next option is
    rewritten as its key plus the second tag.

    Options are modified in place and the same list is returned.
    """
    print("\n" + "=" * 80)
    print("--- 5. STARTING POST-PROCESSING: OPTION ALIGNMENT CORRECTION ---")
    print("=" * 80)
    tag_pattern = re.compile(r'(EQUATION\d+|FIGURE\d+)')
    corrected_count = 0
    for entry in structured_data:
        if entry.get('type') == 'METADATA':
            continue
        opts = entry.get('options')
        if not opts or len(opts) < 2:
            continue
        keys = list(opts)
        for here, after in zip(keys, keys[1:]):
            # Values are re-read each step because a previous fix may have changed them.
            here_value = opts[here].strip()
            after_value = opts[after].strip()
            after_body = after_value.replace(after, '', 1).strip()
            found_tags = tag_pattern.findall(after_body)
            if here_value == here and len(found_tags) == 2:
                opts[here] = f"{here} {found_tags[0]}".strip()
                opts[after] = f"{after} {found_tags[1]}".strip()
                corrected_count += 1
    print(f"✅ Option alignment correction finished. Total corrections: {corrected_count}.")
    return structured_data
# ============================================================================
# --- PHASE 4: IMAGE EMBEDDING (Base64) ---
# ============================================================================
def get_base64_for_file(filepath: str) -> str:
    """Return the file's bytes Base64-encoded as text, or "" if it cannot be read.

    Read errors are printed, not raised.
    """
    try:
        with open(filepath, 'rb') as handle:
            payload = handle.read()
        encoded = base64.b64encode(payload)
    except Exception as e:
        print(f" ❌ Error encoding file {filepath}: {e}")
        return ""
    return encoded.decode('utf-8')
def embed_images_as_base64_in_memory(structured_data: List[Dict[str, Any]], figure_extraction_dir: str) -> List[Dict[str, Any]]:
    """Attach Base64-encoded figure/equation images to the items that reference them.

    PNG files in *figure_extraction_dir* whose names end in ``_figure<N>.png`` or
    ``_equation<N>.png`` (any case) are indexed as ``FIGURE<N>`` / ``EQUATION<N>``.
    For each item, such tags are searched for in ``question``, ``passage``,
    ``new_passage`` and the option values. Each tag with a matching image is
    stored on the item under the lower-cased tag, e.g. ``item['figure3']``.

    Items are updated in place and returned in a new list; an empty input
    returns ``[]``.
    """
    print("\n" + "=" * 80)
    print("--- 4. STARTING IMAGE EMBEDDING (Base64) ---")
    print("=" * 80)
    if not structured_data:
        return []
    filename_regex = re.compile(r'_(figure|equation)(\d+)\.png$', re.IGNORECASE)
    tag_regex = re.compile(r'(figure|equation)(\d+)', re.IGNORECASE)
    image_lookup = {}
    for png_path in glob.glob(os.path.join(figure_extraction_dir, "*.png")):
        found = filename_regex.search(os.path.basename(png_path))
        if found is None:
            continue
        image_lookup[found.group(1).upper() + found.group(2)] = png_path
    print(f" -> Found {len(image_lookup)} image components.")
    embedded = []
    for item in structured_data:
        sources = [item.get('question', ''), item.get('passage', '')]
        if 'options' in item:
            sources.extend(item['options'].values())
        if 'new_passage' in item:
            sources.append(item['new_passage'])
        tags = {
            hit.group(0).upper()
            for text in sources if text
            for hit in tag_regex.finditer(text)
            if hit.group(0).upper() in image_lookup
        }
        for tag in sorted(tags):
            item[tag.replace(' ', '').lower()] = get_base64_for_file(image_lookup[tag])
        embedded.append(item)
    print(f"✅ Image embedding complete.")
    return embedded
# ============================================================================
# --- MAIN FUNCTION ---
# ============================================================================
def run_document_pipeline(input_pdf_path: str, layoutlmv3_model_path: str, label_studio_output_path: str) -> Optional[List[Dict[str, Any]]]:
    """Run the full PDF-to-structured-questions pipeline for one document.

    Phases:

    1. Preprocessing with ``run_single_pdf_preprocessing``.
    2. Word labelling with ``run_inference_and_get_raw_words``.
    3. BIO decoding, option-alignment correction and passage linking, plus a
       best-effort Label Studio export to *label_studio_output_path*.
    4. Base64 embedding of images found in ``FIGURE_EXTRACTION_DIR``.

    Intermediate files are written to a private temporary directory, which is
    removed afterwards.

    Returns:
        The final structured item list, or ``None`` when the input does not
        exist, a phase yields nothing, or an exception occurs.
    """
    if not os.path.exists(input_pdf_path):
        print(f"❌ Input PDF not found: {input_pdf_path}")
        return None
    print("\n" + "#" * 80)
    print("### STARTING OPTIMIZED FULL DOCUMENT ANALYSIS PIPELINE ###")
    print("#" * 80)
    pdf_name = os.path.splitext(os.path.basename(input_pdf_path))[0]
    # mkdtemp gives every call a unique directory. A name derived only from
    # the PDF name and PID collides when the same PDF name is processed
    # concurrently within one process, so one run's cleanup would delete the
    # other's intermediates.
    temp_pipeline_dir = tempfile.mkdtemp(prefix=f"pipeline_run_{pdf_name}_")
    preprocessed_json_path = os.path.join(temp_pipeline_dir, f"{pdf_name}_preprocessed.json")
    raw_output_path = os.path.join(temp_pipeline_dir, f"{pdf_name}_raw_predictions.json")
    structured_intermediate_output_path = os.path.join(temp_pipeline_dir, f"{pdf_name}_structured_intermediate.json")
    final_result = None
    try:
        # Phase 1: Preprocessing with YOLO First + Masking
        preprocessed_json_path_out = run_single_pdf_preprocessing(input_pdf_path, preprocessed_json_path)
        if not preprocessed_json_path_out:
            return None
        # Phase 2: Inference
        page_raw_predictions_list = run_inference_and_get_raw_words(
            input_pdf_path, layoutlmv3_model_path, preprocessed_json_path_out
        )
        if not page_raw_predictions_list:
            return None
        with open(raw_output_path, 'w', encoding='utf-8') as f:
            json.dump(page_raw_predictions_list, f, indent=4)
        # Phase 3: Decoding, option correction and passage linking
        structured_data_list = convert_bio_to_structured_json_relaxed(
            raw_output_path, structured_intermediate_output_path
        )
        if not structured_data_list:
            return None
        structured_data_list = correct_misaligned_options(structured_data_list)
        structured_data_list = process_context_linking(structured_data_list)
        try:
            # The Label Studio export is optional; a failure must not abort the pipeline.
            convert_raw_predictions_to_label_studio(page_raw_predictions_list, label_studio_output_path)
        except Exception as e:
            print(f"❌ Error during Label Studio conversion: {e}")
        # Phase 4: Embedding
        final_result = embed_images_as_base64_in_memory(structured_data_list, FIGURE_EXTRACTION_DIR)
    except Exception as e:
        print(f"❌ FATAL ERROR: {e}")
        import traceback
        traceback.print_exc()
        return None
    finally:
        # rmtree also removes any sub-directories; ignore_errors keeps cleanup
        # from masking the pipeline's own result.
        shutil.rmtree(temp_pipeline_dir, ignore_errors=True)
    print("\n" + "#" * 80)
    print("### OPTIMIZED PIPELINE EXECUTION COMPLETE ###")
    print("#" * 80)
    return final_result
if __name__ == "__main__":
    # CLI entry point: run the pipeline on one PDF and write the embedded JSON
    # result into the current working directory, named after the PDF.
    parser = argparse.ArgumentParser(description="Complete Pipeline")
    parser.add_argument("--input_pdf", type=str, required=True, help="Input PDF")
    parser.add_argument("--layoutlmv3_model_path", type=str, default=DEFAULT_LAYOUTLMV3_MODEL_PATH, help="Model Path")
    parser.add_argument("--ls_output_path", type=str, default=None, help="Label Studio Output Path")
    args = parser.parse_args()
    pdf_name = os.path.splitext(os.path.basename(args.input_pdf))[0]
    final_output_path = os.path.abspath(f"{pdf_name}_final_output_embedded.json")
    ls_output_path = os.path.abspath(args.ls_output_path or f"{pdf_name}_label_studio_tasks.json")
    final_json_data = run_document_pipeline(args.input_pdf, args.layoutlmv3_model_path, ls_output_path)
    if not final_json_data:
        print("\n❌ Pipeline Failed.")
        sys.exit(1)
    with open(final_output_path, 'w', encoding='utf-8') as f:
        json.dump(final_json_data, f, indent=2, ensure_ascii=False)
    print(f"\n✅ Final Data Saved: {final_output_path}")