import json
import argparse
import os
import re
import torch
import torch.nn as nn
from TorchCRF import CRF
from transformers import LayoutLMv3TokenizerFast, LayoutLMv3Model, LayoutLMv3Config
from typing import List, Dict, Any, Optional, Union, Tuple
import fitz  # PyMuPDF
import numpy as np
import cv2
from ultralytics import YOLO
import glob
import pytesseract
from PIL import Image
from scipy.signal import find_peaks
from scipy.ndimage import gaussian_filter1d
import sys
import io
import base64
import tempfile
import time
import shutil
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics.pairwise import cosine_similarity

# ============================================================================
# --- CONFIGURATION AND CONSTANTS ---
# ============================================================================

# NOTE: Update these paths to match your environment before running!
WEIGHTS_PATH = 'YOLO_MATH/yolo_split_data/runs/detect/math_figure_detector_v3/weights/best.pt'
DEFAULT_LAYOUTLMV3_MODEL_PATH = "97.pth"

# DIRECTORY CONFIGURATION
OCR_JSON_OUTPUT_DIR = './ocr_json_output_final'
FIGURE_EXTRACTION_DIR = './figure_extraction'
TEMP_IMAGE_DIR = './temp_pdf_images'

# Detection parameters
CONF_THRESHOLD = 0.2
TARGET_CLASSES = ['figure', 'equation']
IOU_MERGE_THRESHOLD = 0.4
IOA_SUPPRESSION_THRESHOLD = 0.7
LINE_TOLERANCE = 15


# Similarity parameters
SIMILARITY_THRESHOLD = 0.10
RESOLUTION_MARGIN = 0.05

# Global counters for sequential numbering across the entire PDF
GLOBAL_FIGURE_COUNT = 0
GLOBAL_EQUATION_COUNT = 0

# LayoutLMv3 Labels
ID_TO_LABEL = {
    0: "O",
    1: "B-QUESTION", 2: "I-QUESTION",
    3: "B-OPTION", 4: "I-OPTION",
    5: "B-ANSWER", 6: "I-ANSWER",
    7: "B-SECTION_HEADING", 8: "I-SECTION_HEADING",
    9: "B-PASSAGE", 10: "I-PASSAGE"
}
NUM_LABELS = len(ID_TO_LABEL)


# ============================================================================
# --- PERFORMANCE OPTIMIZATION: OCR CACHE ---
# ============================================================================

class OCRCache:
    """Caches OCR results per page to avoid redundant Tesseract runs."""

    def __init__(self):
        self.cache = {}

    def get_key(self, pdf_path: str, page_num: int) -> str:
        return f"{pdf_path}:{page_num}"

    def has_ocr(self, pdf_path: str, page_num: int) -> bool:
        return self.get_key(pdf_path, page_num) in self.cache

    def get_ocr(self, pdf_path: str, page_num: int) -> Optional[list]:
        return self.cache.get(self.get_key(pdf_path, page_num))

    def set_ocr(self, pdf_path: str, page_num: int, ocr_data: list):
        self.cache[self.get_key(pdf_path, page_num)] = ocr_data

    def clear(self):
        self.cache.clear()


# Global OCR cache instance
_ocr_cache = OCRCache()
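
# Usage sketch (hypothetical path/page): entries are keyed by "<pdf_path>:<page_num>",
# so each page is OCR'd at most once per run:
#   if not _ocr_cache.has_ocr("exam.pdf", 1):
#       _ocr_cache.set_ocr("exam.pdf", 1, word_tuples)   # word_tuples from Tesseract
#   words = _ocr_cache.get_ocr("exam.pdf", 1)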


# ============================================================================
# --- PHASE 1: YOLO/OCR PREPROCESSING FUNCTIONS ---
# ============================================================================

def calculate_iou(box1, box2):
    x1_a, y1_a, x2_a, y2_a = box1
    x1_b, y1_b, x2_b, y2_b = box2
    x_left = max(x1_a, x1_b)
    y_top = max(y1_a, y1_b)
    x_right = min(x2_a, x2_b)
    y_bottom = min(y2_a, y2_b)
    intersection_area = max(0, x_right - x_left) * max(0, y_bottom - y_top)
    box_a_area = (x2_a - x1_a) * (y2_a - y1_a)
    box_b_area = (x2_b - x1_b) * (y2_b - y1_b)
    union_area = float(box_a_area + box_b_area - intersection_area)
    return intersection_area / union_area if union_area > 0 else 0


def calculate_ioa(box1, box2):
    x1_a, y1_a, x2_a, y2_a = box1
    x1_b, y1_b, x2_b, y2_b = box2
    x_left = max(x1_a, x1_b)
    y_top = max(y1_a, y1_b)
    x_right = min(x2_a, x2_b)
    y_bottom = min(y2_a, y2_b)
    intersection_area = max(0, x_right - x_left) * max(0, y_bottom - y_top)
    box_a_area = (x2_a - x1_a) * (y2_a - y1_a)
    return intersection_area / box_a_area if box_a_area > 0 else 0
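
# Worked example (hypothetical boxes) showing why the two ratios differ: for
# box A = (0, 0, 10, 10) and box B = (5, 5, 15, 15), the intersection is 25,
# the union is 175, and the area of A is 100, so:
#   calculate_iou(A, B) == 25 / 175 ≈ 0.143
#   calculate_ioa(A, B) == 25 / 100 == 0.25
# IoA is asymmetric (it divides by the FIRST box's area), which is what makes
# it suitable for the containment checks below.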


def filter_nested_boxes(detections, ioa_threshold=0.80):
    """
    Removes boxes that are inside larger boxes (Containment Check).
    Prioritizes keeping the LARGEST box (the 'parent' container).
    """
    if not detections: 
        return []

    # 1. Calculate Area for all detections
    for d in detections:
        x1, y1, x2, y2 = d['coords']
        d['area'] = (x2 - x1) * (y2 - y1)

    # 2. Sort by Area Descending (Largest to Smallest)
    # This ensures we process the 'container' first
    detections.sort(key=lambda x: x['area'], reverse=True)

    keep_indices = []
    is_suppressed = [False] * len(detections)

    for i in range(len(detections)):
        if is_suppressed[i]: continue

        keep_indices.append(i)
        box_a = detections[i]['coords']
        
        # Compare with all smaller boxes
        for j in range(i + 1, len(detections)):
            if is_suppressed[j]: continue

            box_b = detections[j]['coords']
            
            # Calculate Intersection
            x_left = max(box_a[0], box_b[0])
            y_top = max(box_a[1], box_b[1])
            x_right = min(box_a[2], box_b[2])
            y_bottom = min(box_a[3], box_b[3])

            if x_right < x_left or y_bottom < y_top:
                intersection = 0
            else:
                intersection = (x_right - x_left) * (y_bottom - y_top)

            # Calculate IoA (Intersection over Area of the SMALLER box)
            # Since we sorted by area, 'box_b' (detections[j]) is the smaller one.
            area_b = detections[j]['area']
            
            if area_b > 0:
                ioa_small = intersection / area_b
                
                # If the small box sits more than ioa_threshold (80% by default) inside the big box, suppress it.
                if ioa_small > ioa_threshold:
                    is_suppressed[j] = True
                    #print(f"    [Suppress] Removed nested object inside larger '{detections[i]['class']}'")

    return [detections[i] for i in keep_indices]    


def merge_overlapping_boxes(detections, iou_threshold):
    if not detections: return []
    detections.sort(key=lambda d: d['conf'], reverse=True)
    merged_detections = []
    is_merged = [False] * len(detections)
    for i in range(len(detections)):
        if is_merged[i]: continue
        current_box = detections[i]['coords']
        current_class = detections[i]['class']
        merged_x1, merged_y1, merged_x2, merged_y2 = current_box
        for j in range(i + 1, len(detections)):
            if is_merged[j] or detections[j]['class'] != current_class: continue
            other_box = detections[j]['coords']
            iou = calculate_iou(current_box, other_box)
            if iou > iou_threshold:
                merged_x1 = min(merged_x1, other_box[0])
                merged_y1 = min(merged_y1, other_box[1])
                merged_x2 = max(merged_x2, other_box[2])
                merged_y2 = max(merged_y2, other_box[3])
                is_merged[j] = True
        merged_detections.append({
            'coords': (merged_x1, merged_y1, merged_x2, merged_y2),
            'y1': merged_y1, 'class': current_class, 'conf': detections[i]['conf']
        })
    return merged_detections
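
# Example (hypothetical detections): two same-class boxes whose IoU exceeds the
# threshold collapse into their bounding hull, keeping the higher confidence:
#   dets = [{'coords': (0, 0, 100, 100),   'y1': 0,  'class': 'figure', 'conf': 0.9},
#           {'coords': (10, 10, 110, 110), 'y1': 10, 'class': 'figure', 'conf': 0.5}]
#   merge_overlapping_boxes(dets, 0.4)
#   # -> [{'coords': (0, 0, 110, 110), 'y1': 0, 'class': 'figure', 'conf': 0.9}]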


def merge_yolo_into_word_data(raw_word_data: list, yolo_detections: list, scale_factor: float) -> list:
    """
    Filters out raw words that are inside YOLO boxes and replaces them with 
    a single solid 'placeholder' block for the column detector.
    """
    if not yolo_detections:
        return raw_word_data

    # 1. Convert YOLO boxes (Pixels) to PDF Coordinates (Points)
    pdf_space_boxes = []
    for det in yolo_detections:
        x1, y1, x2, y2 = det['coords']
        pdf_box = (
            x1 / scale_factor, 
            y1 / scale_factor, 
            x2 / scale_factor, 
            y2 / scale_factor
        )
        pdf_space_boxes.append(pdf_box)

    # 2. Filter out raw words that are inside YOLO boxes
    cleaned_word_data = []
    for word_tuple in raw_word_data:
        wx1, wy1, wx2, wy2 = word_tuple[1], word_tuple[2], word_tuple[3], word_tuple[4]
        w_center_x = (wx1 + wx2) / 2
        w_center_y = (wy1 + wy2) / 2
        
        is_inside_yolo = False
        for px1, py1, px2, py2 in pdf_space_boxes:
            if px1 <= w_center_x <= px2 and py1 <= w_center_y <= py2:
                is_inside_yolo = True
                break
        
        if not is_inside_yolo:
            cleaned_word_data.append(word_tuple)

    # 3. Add the YOLO boxes themselves as "Solid Words"
    for i, (px1, py1, px2, py2) in enumerate(pdf_space_boxes):
        dummy_entry = (f"BLOCK_{i}", px1, py1, px2, py2)
        cleaned_word_data.append(dummy_entry)

    return cleaned_word_data
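
# Example (hypothetical values): with scale_factor=2.0, a YOLO box at pixel
# coords (200, 400, 600, 500) maps back to PDF points (100, 200, 300, 250).
# Every word whose centre falls inside that rectangle is dropped, and a single
# placeholder tuple ("BLOCK_0", 100, 200, 300, 250) is appended, so the column
# detector sees the figure as one solid obstacle instead of scattered fragments.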



# ============================================================================
# --- IMAGE PREPROCESSING HELPER ---
# ============================================================================

def preprocess_image_for_ocr(img_np):
    """
    Converts image to grayscale and applies Otsu's Binarization 
    to separate text from background clearly.
    """
    # 1. Convert to Grayscale if needed
    if len(img_np.shape) == 3:
        gray = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
    else:
        gray = img_np
        
    # 2. Apply Otsu's Thresholding (Automatic binary threshold)
    # This makes text solid black and background solid white
    _, thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    
    return thresh


def calculate_vertical_gap_coverage(word_data: list, sep_x: int, page_height: float, gutter_width: int = 10) -> float:
    """
    Calculates what fraction of the page's vertical text span is cleanly split by the separator.
    The caller (calculate_x_gutters) accepts a split only when this coverage is at least 80%.
    """
    if not word_data:
        return 0.0

    # Determine the vertical span of the actual text content
    y_coords = [w[2] for w in word_data] + [w[4] for w in word_data] # y1 and y2
    min_y, max_y = min(y_coords), max(y_coords)
    total_text_height = max_y - min_y
    
    if total_text_height <= 0:
        return 0.0

    # Create a boolean array representing the Y-axis (1 pixel per unit)
    gap_open_mask = np.ones(int(total_text_height) + 1, dtype=bool)
    
    zone_left = sep_x - (gutter_width / 2)
    zone_right = sep_x + (gutter_width / 2)
    offset_y = int(min_y)
    
    for _, x1, y1, x2, y2 in word_data:
        # Check if this word horizontally interferes with the separator
        if x2 > zone_left and x1 < zone_right:
            y_start_idx = max(0, int(y1) - offset_y)
            y_end_idx = min(len(gap_open_mask), int(y2) - offset_y)
            if y_end_idx > y_start_idx:
                gap_open_mask[y_start_idx:y_end_idx] = False

    open_pixels = np.sum(gap_open_mask)
    coverage_ratio = open_pixels / len(gap_open_mask)
    
    return coverage_ratio
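
# Worked example (hypothetical words): two words spanning y=0..10 and y=20..30,
# both ending well left of a separator at x=150, leave every row of the 31-row
# mask open, so the coverage is 1.0:
#   words = [("a", 0, 0, 50, 10), ("b", 0, 20, 50, 30)]
#   calculate_vertical_gap_coverage(words, 150, page_height=800)  # -> 1.0
# A third word spanning x=140..160 at y=0..15 would close rows 0-14 and drop
# the coverage to roughly 0.5.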




def calculate_x_gutters(word_data: list, params: Dict, page_height: float) -> List[int]:
    """
    Calculates X-axis histogram and validates using BRIDGING DENSITY and Vertical Coverage.
    """
    if not word_data: return []

    x_points = []
    # Use only word_data elements 1 (x1) and 3 (x2)
    for item in word_data:
        x_points.extend([item[1], item[3]])

    if not x_points: return []
    max_x = max(x_points)

    # 1. Determine total text height for ratio calculation
    y_coords = [item[2] for item in word_data] + [item[4] for item in word_data]
    min_y, max_y = min(y_coords), max(y_coords)
    total_text_height = max_y - min_y
    if total_text_height <= 0: return []

    # Histogram Setup
    bin_size = params.get('cluster_bin_size', 5)
    smoothing = params.get('cluster_smoothing', 1)
    min_width = params.get('cluster_min_width', 20)
    threshold_percentile = params.get('cluster_threshold_percentile', 85)

    num_bins = int(np.ceil(max_x / bin_size))
    hist, bin_edges = np.histogram(x_points, bins=num_bins, range=(0, max_x))
    smoothed_hist = gaussian_filter1d(hist.astype(float), sigma=smoothing)
    inverted_signal = np.max(smoothed_hist) - smoothed_hist

    peaks, properties = find_peaks(
        inverted_signal,
        height=np.max(inverted_signal) - np.percentile(smoothed_hist, threshold_percentile),
        distance=min_width / bin_size
    )

    if not peaks.size: return []
    separator_x_coords = [int(bin_edges[p]) for p in peaks]
    final_separators = []
    
    for x_coord in separator_x_coords:
        # --- CHECK 1: BRIDGING DENSITY (The "Cut Through" Check) ---
        # Calculate the total vertical height of words that physically cross this line.
        bridging_height = 0
        bridging_count = 0
        
        for item in word_data:
            wx1, wy1, wx2, wy2 = item[1], item[2], item[3], item[4]
            
            # Check if this word physically sits on top of the separator line
            if wx1 < x_coord and wx2 > x_coord:
                word_h = wy2 - wy1
                bridging_height += word_h
                bridging_count += 1
        
        # Calculate Ratio: How much of the page's text height is blocked by these crossing words?
        bridging_ratio = bridging_height / total_text_height
        
        # THRESHOLD: If bridging blocks > 8% of page height, REJECT.
        # This allows for page numbers or headers (usually < 5%) to cross, but NOT paragraphs.
        if bridging_ratio > 0.08: 
            print(f"      ❌ Separator X={x_coord} REJECTED: Bridging Ratio {bridging_ratio:.1%} (>15%) cuts through text.")
            continue

        # --- CHECK 2: VERTICAL GAP COVERAGE (The "Clean Split" Check) ---
        # The gap must exist cleanly for at least 80% of the text height.
        coverage = calculate_vertical_gap_coverage(word_data, x_coord, page_height, gutter_width=min_width)
        
        if coverage >= 0.80:
            final_separators.append(x_coord)
            print(f"      -> Separator X={x_coord} ACCEPTED (Coverage: {coverage:.1%}, Bridging: {bridging_ratio:.1%})")
        else:
            print(f"      ❌ Separator X={x_coord} REJECTED (Coverage: {coverage:.1%}, Bridging: {bridging_ratio:.1%})")

    return sorted(final_separators)





def get_word_data_for_detection(page: fitz.Page, pdf_path: str, page_num: int,
                                top_margin_percent=0.10, bottom_margin_percent=0.10) -> list:
    """Extract word data with OCR caching to avoid redundant Tesseract runs."""
    word_data = page.get_text("words")

    if len(word_data) > 0:
        word_data = [(w[4], w[0], w[1], w[2], w[3]) for w in word_data]
    else:
        if _ocr_cache.has_ocr(pdf_path, page_num):
            word_data = _ocr_cache.get_ocr(pdf_path, page_num)
        else:
            try:
                # --- OPTIMIZATION START ---
                # 1. Render at Higher Resolution (Zoom 4.0 = ~300 DPI)
                zoom_level = 4.0
                pix = page.get_pixmap(matrix=fitz.Matrix(zoom_level, zoom_level))
                
                # 2. Convert directly to OpenCV format (Faster than PIL)
                img_np = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.height, pix.width, pix.n)
                if pix.n == 3: img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
                elif pix.n == 4: img_np = cv2.cvtColor(img_np, cv2.COLOR_RGBA2BGR)

                # 3. Apply Preprocessing (Thresholding)
                processed_img = preprocess_image_for_ocr(img_np)

                # 4. Optimized Tesseract Config
                # --psm 6: Assume a single uniform block of text (Great for columns/questions)
                # --oem 3: Default engine (LSTM)
                custom_config = r'--oem 3 --psm 6'
                
                data = pytesseract.image_to_data(processed_img, output_type=pytesseract.Output.DICT, config=custom_config)
                
                full_word_data = []
                for i in range(len(data['level'])):
                    text = data['text'][i].strip()
                    if text:
                        # Scale coordinates back to PDF points
                        x1 = data['left'][i] / zoom_level
                        y1 = data['top'][i] / zoom_level
                        x2 = (data['left'][i] + data['width'][i]) / zoom_level
                        y2 = (data['top'][i] + data['height'][i]) / zoom_level
                        full_word_data.append((text, x1, y1, x2, y2))
                
                word_data = full_word_data
                _ocr_cache.set_ocr(pdf_path, page_num, word_data)
                # --- OPTIMIZATION END ---
            except Exception as e:
                print(f"  ❌ OCR Error in detection phase: {e}")
                return []

    # Apply margin filtering
    page_height = page.rect.height
    y_min = page_height * top_margin_percent
    y_max = page_height * (1 - bottom_margin_percent)
    return [d for d in word_data if d[2] >= y_min and d[4] <= y_max]
    


def pixmap_to_numpy(pix: fitz.Pixmap) -> np.ndarray:
    img_data = pix.samples
    img = np.frombuffer(img_data, dtype=np.uint8).reshape(pix.height, pix.width, pix.n)
    if pix.n == 4: img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGR)
    elif pix.n == 3: img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    return img


def extract_native_words_and_convert(fitz_page, scale_factor: float = 2.0) -> list:
    raw_word_data = fitz_page.get_text("words")
    converted_ocr_output = []
    DEFAULT_CONFIDENCE = 99.0

    for x1, y1, x2, y2, word, *rest in raw_word_data:
        if not word.strip(): continue
        x1_pix = int(x1 * scale_factor)
        y1_pix = int(y1 * scale_factor)
        x2_pix = int(x2 * scale_factor)
        y2_pix = int(y2 * scale_factor)
        converted_ocr_output.append({
            'type': 'text',
            'word': word,
            'confidence': DEFAULT_CONFIDENCE,
            'bbox': [x1_pix, y1_pix, x2_pix, y2_pix],
            'y0': y1_pix, 'x0': x1_pix
        })
    return converted_ocr_output
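
# Example (hypothetical word): at scale_factor=2.0, a native word at PDF points
# (72.0, 100.0, 120.0, 112.0) becomes the pixel bbox [144, 200, 240, 224],
# matching the 2.0-zoom page render used throughout the pipeline.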





def preprocess_and_ocr_page(original_img: np.ndarray, model, pdf_path: str,
                            page_num: int, fitz_page: fitz.Page,
                            pdf_name: str) -> Tuple[Optional[List[Dict[str, Any]]], Optional[int]]:
    """
    OPTIMIZED FLOW:
    1. Run YOLO to find figures and equations.
    2. Mask raw text with YOLO boxes.
    3. Run Column Detection on the MASKED data.
    4. Proceed with OCR (Native or High-Res Tesseract Fallback) and Output.
    """
    global GLOBAL_FIGURE_COUNT, GLOBAL_EQUATION_COUNT

    start_time_total = time.time()

    if original_img is None:
        print(f"  ❌ Invalid image for page {page_num}.")
        return None, None

    # ====================================================================
    # --- STEP 1: YOLO DETECTION ---
    # ====================================================================
    start_time_yolo = time.time()
    results = model.predict(source=original_img, conf=CONF_THRESHOLD, imgsz=640, verbose=False)
    
    relevant_detections = []
    if results and results[0].boxes:
        for box in results[0].boxes:
            class_id = int(box.cls[0])
            class_name = model.names[class_id]
            if class_name in TARGET_CLASSES:
                x1, y1, x2, y2 = box.xyxy[0].cpu().numpy().astype(int)
                relevant_detections.append(
                    {'coords': (x1, y1, x2, y2), 'y1': y1, 'class': class_name, 'conf': float(box.conf[0])}
                )

    merged_detections = merge_overlapping_boxes(relevant_detections, IOU_MERGE_THRESHOLD)
    print(f"    [LOG] YOLO found {len(merged_detections)} objects in {time.time() - start_time_yolo:.3f}s.")

    # ====================================================================
    # --- STEP 2: PREPARE DATA FOR COLUMN DETECTION (MASKING) ---
    # ====================================================================
    # Note: This uses the updated 'get_word_data_for_detection' which has its own optimizations
    raw_words_for_layout = get_word_data_for_detection(
        fitz_page, pdf_path, page_num,
        top_margin_percent=0.10, bottom_margin_percent=0.10
    )

    masked_word_data = merge_yolo_into_word_data(raw_words_for_layout, merged_detections, scale_factor=2.0)

    # ====================================================================
    # --- STEP 3: COLUMN DETECTION ---
    # ====================================================================
    page_width_pdf = fitz_page.rect.width
    page_height_pdf = fitz_page.rect.height 
    
    column_detection_params = {
        'cluster_bin_size': 2, 'cluster_smoothing': 2,
        'cluster_min_width': 10, 'cluster_threshold_percentile': 85,
    }

    separators = calculate_x_gutters(masked_word_data, column_detection_params, page_height_pdf)
    
    page_separator_x = None
    if separators:
        central_min = page_width_pdf * 0.35
        central_max = page_width_pdf * 0.65
        central_separators = [s for s in separators if central_min <= s <= central_max]
        
        if central_separators:
            center_x = page_width_pdf / 2
            page_separator_x = min(central_separators, key=lambda x: abs(x - center_x))
            print(f"      ✅ Column Split Confirmed at X={page_separator_x:.1f}")
        else:
            print("      ⚠️ Gutter found off-center. Ignoring.")
    else:
        print("      -> Single Column Layout Confirmed.")

    # ====================================================================
    # --- STEP 4: COMPONENT EXTRACTION (Save Images) ---
    # ====================================================================
    start_time_components = time.time()
    component_metadata = []
    fig_count_page = 0
    eq_count_page = 0

    for detection in merged_detections:
        x1, y1, x2, y2 = detection['coords']
        class_name = detection['class']

        if class_name == 'figure':
            GLOBAL_FIGURE_COUNT += 1
            counter = GLOBAL_FIGURE_COUNT
            component_word = f"FIGURE{counter}"
            fig_count_page += 1
        elif class_name == 'equation':
            GLOBAL_EQUATION_COUNT += 1
            counter = GLOBAL_EQUATION_COUNT
            component_word = f"EQUATION{counter}"
            eq_count_page += 1
        else:
            continue

        component_crop = original_img[y1:y2, x1:x2]
        component_filename = f"{pdf_name}_page{page_num}_{class_name}{counter}.png"
        cv2.imwrite(os.path.join(FIGURE_EXTRACTION_DIR, component_filename), component_crop)

        y_midpoint = (y1 + y2) // 2
        component_metadata.append({
            'type': class_name, 'word': component_word,
            'bbox': [int(x1), int(y1), int(x2), int(y2)],
            'y0': int(y_midpoint), 'x0': int(x1)
        })

    # ====================================================================
    # --- STEP 5: HYBRID OCR (Native Text + Cached Tesseract Fallback) ---
    # ====================================================================
    raw_ocr_output = []
    scale_factor = 2.0 # Pipeline standard scale

    try:
        # Try getting native text first
        raw_ocr_output = extract_native_words_and_convert(fitz_page, scale_factor=scale_factor)
    except Exception as e:
        print(f"  ❌ Native text extraction failed: {e}")

    # If native text is missing, fall back to OCR
    if not raw_ocr_output:
        if _ocr_cache.has_ocr(pdf_path, page_num):
            print(f"  ⚡ Using cached Tesseract OCR for page {page_num}")
            cached_word_data = _ocr_cache.get_ocr(pdf_path, page_num)
            for word_tuple in cached_word_data:
                word_text, x1, y1, x2, y2 = word_tuple
                
                # Scale from PDF points to Pipeline Pixels (2.0)
                x1_pix = int(x1 * scale_factor)
                y1_pix = int(y1 * scale_factor)
                x2_pix = int(x2 * scale_factor)
                y2_pix = int(y2 * scale_factor)
                
                raw_ocr_output.append({
                    'type': 'text', 'word': word_text, 'confidence': 95.0,
                    'bbox': [x1_pix, y1_pix, x2_pix, y2_pix],
                    'y0': y1_pix, 'x0': x1_pix
                })
        else:
            # === START OF OPTIMIZED OCR BLOCK ===
            try:
                # 1. Re-render Page at High Resolution (Zoom 4.0 = ~300 DPI)
                # We do this specifically for OCR accuracy, separate from the pipeline image
                ocr_zoom = 4.0
                pix_ocr = fitz_page.get_pixmap(matrix=fitz.Matrix(ocr_zoom, ocr_zoom))
                
                # Convert PyMuPDF Pixmap to OpenCV format
                img_ocr_np = np.frombuffer(pix_ocr.samples, dtype=np.uint8).reshape(pix_ocr.height, pix_ocr.width, pix_ocr.n)
                if pix_ocr.n == 3: img_ocr_np = cv2.cvtColor(img_ocr_np, cv2.COLOR_RGB2BGR)
                elif pix_ocr.n == 4: img_ocr_np = cv2.cvtColor(img_ocr_np, cv2.COLOR_RGBA2BGR)

                # 2. Preprocess (binarization) with the shared helper defined above
                processed_img = preprocess_image_for_ocr(img_ocr_np)
                
                # 3. Run Tesseract with Optimized Configuration
                # --oem 3: Default LSTM engine
                # --psm 6: Assume a single uniform block of text (Critical for lists/questions)
                custom_config = r'--oem 3 --psm 6'
                
                hocr_data = pytesseract.image_to_data(
                    processed_img, 
                    output_type=pytesseract.Output.DICT, 
                    config=custom_config
                )
                
                for i in range(len(hocr_data['level'])):
                    text = hocr_data['text'][i].strip()
                    if text and hocr_data['conf'][i] > -1:
                        
                        # 4. Coordinate Mapping
                        # We scanned at Zoom 4.0, but our pipeline expects Zoom 2.0.
                        # Scale Factor = (Target 2.0) / (Source 4.0) = 0.5
                        scale_adjustment = scale_factor / ocr_zoom 
                        
                        x1 = int(hocr_data['left'][i] * scale_adjustment)
                        y1 = int(hocr_data['top'][i] * scale_adjustment)
                        w = int(hocr_data['width'][i] * scale_adjustment)
                        h = int(hocr_data['height'][i] * scale_adjustment)
                        x2 = x1 + w
                        y2 = y1 + h
                        
                        raw_ocr_output.append({
                            'type': 'text', 
                            'word': text, 
                            'confidence': float(hocr_data['conf'][i]),
                            'bbox': [x1, y1, x2, y2], 
                            'y0': y1, 
                            'x0': x1
                        })
            except Exception as e:
                print(f"  ❌ Tesseract OCR Error: {e}")
            # === END OF OPTIMIZED OCR BLOCK ===

    # ====================================================================
    # --- STEP 6: OCR CLEANING AND MERGING ---
    # ====================================================================
    items_to_sort = []
    
    for ocr_word in raw_ocr_output:
        is_suppressed = False
        for component in component_metadata:
            # Do not include words that are inside figure/equation boxes
            ioa = calculate_ioa(ocr_word['bbox'], component['bbox'])
            if ioa > IOA_SUPPRESSION_THRESHOLD:
                is_suppressed = True
                break
        if not is_suppressed:
            items_to_sort.append(ocr_word)

    # Add figures/equations back into the flow as "words"
    items_to_sort.extend(component_metadata)

    # ====================================================================
    # --- STEP 7: LINE-BASED SORTING ---
    # ====================================================================
    items_to_sort.sort(key=lambda x: (x['y0'], x['x0']))
    lines = []

    for item in items_to_sort:
        placed = False
        for line in lines:
            y_ref = min(it['y0'] for it in line)
            if abs(y_ref - item['y0']) < LINE_TOLERANCE:
                line.append(item)
                placed = True
                break
        if not placed and item['type'] in ['equation', 'figure']:
            for line in lines:
                y_ref = min(it['y0'] for it in line)
                if abs(y_ref - item['y0']) < 20:
                    line.append(item)
                    placed = True
                    break
        if not placed:
            lines.append([item])

    for line in lines:
        line.sort(key=lambda x: x['x0'])

    final_output = []
    for line in lines:
        for item in line:
            data_item = {"word": item["word"], "bbox": item["bbox"], "type": item["type"]}
            if 'tag' in item: data_item['tag'] = item['tag']
            final_output.append(data_item)

    return final_output, page_separator_x
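
# Output sketch (hypothetical record): each entry of `final_output` has the form
#   {"word": "FIGURE1", "bbox": [120, 340, 560, 620], "type": "figure"}
# with bbox in 2.0-zoom pixel space, while `page_separator_x` is either None
# (single-column page) or the gutter's X position in PDF points.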



def run_single_pdf_preprocessing(pdf_path: str, preprocessed_json_path: str) -> Optional[str]:
    global GLOBAL_FIGURE_COUNT, GLOBAL_EQUATION_COUNT

    GLOBAL_FIGURE_COUNT = 0
    GLOBAL_EQUATION_COUNT = 0
    _ocr_cache.clear()

    print("\n" + "=" * 80)
    print("--- 1. STARTING OPTIMIZED YOLO/OCR PREPROCESSING PIPELINE ---")
    print("=" * 80)

    if not os.path.exists(pdf_path):
        print(f"❌ FATAL ERROR: Input PDF not found at {pdf_path}.")
        return None

    out_dir = os.path.dirname(preprocessed_json_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    os.makedirs(FIGURE_EXTRACTION_DIR, exist_ok=True)

    model = YOLO(WEIGHTS_PATH)
    pdf_name = os.path.splitext(os.path.basename(pdf_path))[0]

    try:
        doc = fitz.open(pdf_path)
        print(f"✅ Opened PDF: {pdf_name} ({doc.page_count} pages)")
    except Exception as e:
        print(f"❌ ERROR loading PDF file: {e}")
        return None

    all_pages_data = []
    total_pages_processed = 0
    mat = fitz.Matrix(2.0, 2.0)

    print("\n[STEP 1.2: ITERATING PAGES - IN-MEMORY PROCESSING]")

    for page_num_0_based in range(doc.page_count):
        page_num = page_num_0_based + 1
        print(f"  -> Processing Page {page_num}/{doc.page_count}...")

        fitz_page = doc.load_page(page_num_0_based)

        try:
            pix = fitz_page.get_pixmap(matrix=mat)
            original_img = pixmap_to_numpy(pix)
        except Exception as e:
            print(f"  ❌ Error converting page {page_num} to image: {e}")
            continue

        final_output, page_separator_x = preprocess_and_ocr_page(
            original_img,
            model,
            pdf_path,
            page_num,
            fitz_page,
            pdf_name
        )

        if final_output is not None:
            page_data = {
                "page_number": page_num,
                "data": final_output,
                "column_separator_x": page_separator_x
            }
            all_pages_data.append(page_data)
            total_pages_processed += 1
        else:
            print(f"  ❌ Skipped page {page_num} due to processing error.")

    doc.close()

    if all_pages_data:
        try:
            with open(preprocessed_json_path, 'w') as f:
                json.dump(all_pages_data, f, indent=4)
            print(f"\n  ✅ Combined structured OCR JSON saved to: {os.path.basename(preprocessed_json_path)}")
        except Exception as e:
            print(f"❌ ERROR saving combined JSON output: {e}")
            return None
    else:
        print("❌ WARNING: No page data generated. Halting pipeline.")
        return None

    print("\n" + "=" * 80)
    print(f"--- YOLO/OCR PREPROCESSING COMPLETE ({total_pages_processed} pages processed) ---")
    print("=" * 80)

    return preprocessed_json_path


# ============================================================================
# --- PHASE 2: LAYOUTLMV3 INFERENCE FUNCTIONS ---
# ============================================================================

class LayoutLMv3ForTokenClassification(nn.Module):
    def __init__(self, num_labels: int = NUM_LABELS):
        super().__init__()
        self.num_labels = num_labels
        config = LayoutLMv3Config.from_pretrained("microsoft/layoutlmv3-base", num_labels=num_labels)
        self.layoutlmv3 = LayoutLMv3Model.from_pretrained("microsoft/layoutlmv3-base", config=config)
        self.classifier = nn.Linear(config.hidden_size, num_labels)
        self.crf = CRF(num_labels)
        self.init_weights()

    def init_weights(self):
        nn.init.xavier_uniform_(self.classifier.weight)
        if self.classifier.bias is not None: nn.init.zeros_(self.classifier.bias)

    def forward(self, input_ids: torch.Tensor, bbox: torch.Tensor, attention_mask: torch.Tensor, labels: Optional[torch.Tensor] = None):
        outputs = self.layoutlmv3(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask, return_dict=True)
        sequence_output = outputs.last_hidden_state
        emissions = self.classifier(sequence_output)
        mask = attention_mask.bool()
        if labels is not None:
            loss = -self.crf(emissions, labels, mask=mask).mean()
            return loss
        else:
            return self.crf.viterbi_decode(emissions, mask=mask)
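
# Usage sketch (tensor shapes are assumptions): the model returns a scalar CRF
# negative log-likelihood when `labels` is supplied and a list of decoded
# label-id paths (one per batch element, trimmed to the mask) otherwise:
#   model = LayoutLMv3ForTokenClassification(num_labels=NUM_LABELS)
#   loss  = model(input_ids, bbox, attention_mask, labels=labels)  # training
#   paths = model(input_ids, bbox, attention_mask)                 # inference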

def _merge_integrity(all_token_data: List[Dict[str, Any]],
                      column_separator_x: Optional[int]) -> List[List[Dict[str, Any]]]:
    """Splits the token data objects into column chunks based on a separator."""
    if column_separator_x is None:
        print("    -> No column separator. Treating as one chunk.")
        return [all_token_data]

    left_column_tokens, right_column_tokens = [], []
    for token_data in all_token_data:
        bbox_raw = token_data['bbox_raw_pdf_space']
        center_x = (bbox_raw[0] + bbox_raw[2]) / 2
        if center_x < column_separator_x:
            left_column_tokens.append(token_data)
        else:
            right_column_tokens.append(token_data)

    chunks = [c for c in [left_column_tokens, right_column_tokens] if c]
    print(f"    -> Data split into {len(chunks)} column chunk(s) using separator X={column_separator_x}.")
    return chunks
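
# Example (hypothetical tokens): with a separator at X=300, a token whose raw
# PDF bbox is [100, 50, 180, 60] (centre x = 140) lands in the left chunk and
# one at [320, 50, 400, 60] (centre x = 360) in the right chunk; a page with
# all text on one side yields a single chunk, never an empty one.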

def run_inference_and_get_raw_words(pdf_path: str, model_path: str,
                                    preprocessed_json_path: str,
                                    column_detection_params: Optional[Dict] = None) -> List[Dict[str, Any]]:
    print("\n" + "=" * 80)
    print("--- 2. STARTING LAYOUTLMV3 INFERENCE PIPELINE (Raw Word Output) ---")
    print("=" * 80)

    tokenizer = LayoutLMv3TokenizerFast.from_pretrained("microsoft/layoutlmv3-base")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"  -> Using device: {device}")

    try:
        model = LayoutLMv3ForTokenClassification(num_labels=NUM_LABELS)
        checkpoint = torch.load(model_path, map_location=device)
        model_state = checkpoint.get('model_state_dict', checkpoint)
        fixed_state_dict = {key.replace('layoutlm.', 'layoutlmv3.'): value for key, value in model_state.items()}
        model.load_state_dict(fixed_state_dict)
        model.to(device)
        model.eval()
        print(f"✅ LayoutLMv3 Model loaded successfully from {os.path.basename(model_path)}.")
    except Exception as e:
        print(f"❌ FATAL ERROR during LayoutLMv3 model loading: {e}")
        return []

    try:
        with open(preprocessed_json_path, 'r', encoding='utf-8') as f:
            preprocessed_data = json.load(f)
        print(f"✅ Loaded preprocessed data with {len(preprocessed_data)} pages.")
    except Exception:
        print("❌ Error loading preprocessed JSON.")
        return []

    try:
        doc = fitz.open(pdf_path)
    except Exception:
        print("❌ Error loading PDF.")
        return []

    final_page_predictions = []
    CHUNK_SIZE = 500

    for page_data in preprocessed_data:
        page_num_1_based = page_data['page_number']
        page_num_0_based = page_num_1_based - 1
        page_raw_predictions = []
        print(f"\n  *** Processing Page {page_num_1_based} ({len(page_data['data'])} raw tokens) ***")

        fitz_page = doc.load_page(page_num_0_based)
        page_width, page_height = fitz_page.rect.width, fitz_page.rect.height
        print(f"    -> Page dimensions: {page_width:.0f}x{page_height:.0f} (PDF points).")

        all_token_data = []
        scale_factor = 2.0

        for item in page_data['data']:
            raw_yolo_bbox = item['bbox']
            bbox_pdf = [
                int(raw_yolo_bbox[0] / scale_factor), int(raw_yolo_bbox[1] / scale_factor),
                int(raw_yolo_bbox[2] / scale_factor), int(raw_yolo_bbox[3] / scale_factor)
            ]
            normalized_bbox = [
                max(0, min(1000, int(1000 * bbox_pdf[0] / page_width))),
                max(0, min(1000, int(1000 * bbox_pdf[1] / page_height))),
                max(0, min(1000, int(1000 * bbox_pdf[2] / page_width))),
                max(0, min(1000, int(1000 * bbox_pdf[3] / page_height)))
            ]
            all_token_data.append({
                "word": item['word'],
                "bbox_raw_pdf_space": bbox_pdf,
                "bbox_normalized": normalized_bbox,
                "item_original_data": item
            })

        if not all_token_data: continue

        column_separator_x = page_data.get('column_separator_x', None)
        if column_separator_x is not None:
            print(f"    -> Using SAVED column separator: X={column_separator_x}")
        else:
            print("    -> No column separator found. Assuming single chunk.")

        token_chunks = _merge_integrity(all_token_data, column_separator_x)
        total_chunks = len(token_chunks)

        for chunk_idx, chunk_tokens in enumerate(token_chunks):
            if not chunk_tokens: continue

            chunk_words = [t['word'] for t in chunk_tokens]
            chunk_normalized_bboxes = [t['bbox_normalized'] for t in chunk_tokens]

            total_sub_chunks = (len(chunk_words) + CHUNK_SIZE - 1) // CHUNK_SIZE
            for i in range(0, len(chunk_words), CHUNK_SIZE):
                sub_chunk_idx = i // CHUNK_SIZE + 1
                sub_words = chunk_words[i:i + CHUNK_SIZE]
                sub_bboxes = chunk_normalized_bboxes[i:i + CHUNK_SIZE]
                sub_tokens_data = chunk_tokens[i:i + CHUNK_SIZE]

                print(f"      -> Chunk {chunk_idx + 1}/{total_chunks}, Sub-chunk {sub_chunk_idx}/{total_sub_chunks}: {len(sub_words)} words. Running Inference...")

                encoded_input = tokenizer(
                    sub_words, boxes=sub_bboxes, truncation=True, padding="max_length",
                    max_length=512, return_tensors="pt"
                )
                input_ids = encoded_input['input_ids'].to(device)
                bbox = encoded_input['bbox'].to(device)
                attention_mask = encoded_input['attention_mask'].to(device)

                with torch.no_grad():
                    predictions_int_list = model(input_ids, bbox, attention_mask)

                if not predictions_int_list: continue
                predictions_int = predictions_int_list[0]
                word_ids = encoded_input.word_ids()
                word_idx_to_pred_id = {}

                for token_idx, word_idx in enumerate(word_ids):
                    if word_idx is not None and word_idx < len(sub_words):
                        if word_idx not in word_idx_to_pred_id:
                            word_idx_to_pred_id[word_idx] = predictions_int[token_idx]

                for current_word_idx in range(len(sub_words)):
                    pred_id_or_tensor = word_idx_to_pred_id.get(current_word_idx, 0)
                    pred_id = pred_id_or_tensor.item() if torch.is_tensor(pred_id_or_tensor) else pred_id_or_tensor
                    predicted_label = ID_TO_LABEL[pred_id]
                    original_token = sub_tokens_data[current_word_idx]
                    page_raw_predictions.append({
                        "word": original_token['word'],
                        "bbox": original_token['bbox_raw_pdf_space'],
                        "predicted_label": predicted_label,
                        "page_number": page_num_1_based
                    })

        if page_raw_predictions:
            final_page_predictions.append({
                "page_number": page_num_1_based,
                "data": page_raw_predictions
            })
            print(f"  *** Page {page_num_1_based} Finalized: {len(page_raw_predictions)} labeled words. ***")

    doc.close()
    print("\n" + "=" * 80)
    print("--- LAYOUTLMV3 INFERENCE COMPLETE ---")
    print("=" * 80)
    return final_page_predictions


def create_label_studio_span(page_results, start_idx, end_idx, label):
    entity_words = [page_results[i]['word'] for i in range(start_idx, end_idx + 1)]
    entity_bboxes = [page_results[i]['bbox'] for i in range(start_idx, end_idx + 1)]
    x0 = min(bbox[0] for bbox in entity_bboxes)
    y0 = min(bbox[1] for bbox in entity_bboxes)
    x1 = max(bbox[2] for bbox in entity_bboxes)
    y1 = max(bbox[3] for bbox in entity_bboxes)
    all_words_on_page = [r['word'] for r in page_results]
    start_char = len(" ".join(all_words_on_page[:start_idx]))
    if start_idx != 0: start_char += 1
    end_char = start_char + len(" ".join(entity_words))
    span_text = " ".join(entity_words)
    return {
        "from_name": "label", "to_name": "text", "type": "labels",
        "value": {
            "start": start_char, "end": end_char, "text": span_text,
            "labels": [label],
            "bbox": {"x": x0, "y": y0, "width": x1 - x0, "height": y1 - y0}
        }, "score": 0.99
    }
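
# Worked example (hypothetical page): for page words ["Q1.", "What", "is", "x?"]
# and a span over word indices 1..3, start_char = len("Q1.") + 1 = 4 and
# end_char = 4 + len("What is x?") = 14, which are exactly the offsets of the
# span inside the space-joined page text "Q1. What is x?".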

def convert_raw_predictions_to_label_studio(page_data_list, output_path: str):
    final_tasks = []
    print("\n[PHASE: LABEL STUDIO CONVERSION]")
    for page_data in page_data_list:
        page_num = page_data['page_number']
        page_results = page_data['data']
        if not page_results: continue
        original_words = [r['word'] for r in page_results]
        text_string = " ".join(original_words)
        results = []
        current_entity_label = None
        current_entity_start_word_index = None

        for i, pred_item in enumerate(page_results):
            label = pred_item['predicted_label']
            tag_only = label.split('-', 1)[-1] if '-' in label else label
            if label.startswith('B-'):
                if current_entity_label:
                    results.append(create_label_studio_span(page_results, current_entity_start_word_index, i - 1, current_entity_label))
                current_entity_label = tag_only
                current_entity_start_word_index = i
            elif label.startswith('I-') and current_entity_label == tag_only:
                continue
            else:
                if current_entity_label:
                    results.append(create_label_studio_span(page_results, current_entity_start_word_index, i - 1, current_entity_label))
                current_entity_label = None
                current_entity_start_word_index = None
        if current_entity_label:
            results.append(create_label_studio_span(page_results, current_entity_start_word_index, len(page_results) - 1, current_entity_label))

        final_tasks.append({
            "data": {
                "text": text_string, "original_words": original_words,
                "original_bboxes": [r['bbox'] for r in page_results]
            },
            "annotations": [{"result": results}],
            "meta": {"page_number": page_num}
        })
    with open(output_path, "w", encoding='utf-8') as f:
        json.dump(final_tasks, f, indent=2, ensure_ascii=False)
    print(f"\n✅ Label Studio tasks saved to {output_path}.")


# ============================================================================
# --- PHASE 3: BIO TO STRUCTURED JSON DECODER ---
# ============================================================================



def convert_bio_to_structured_json_relaxed(input_path: str, output_path: str) -> Optional[List[Dict[str, Any]]]:
    print("\n" + "=" * 80)
    print("--- 3. STARTING BIO TO STRUCTURED JSON DECODING ---")
    print("=" * 80)
    try:
        with open(input_path, 'r', encoding='utf-8') as f:
            predictions_by_page = json.load(f)
    except Exception as e:
        print(f"❌ Error loading raw prediction file: {e}")
        return None

    predictions = []
    for page_item in predictions_by_page:
        if isinstance(page_item, dict) and 'data' in page_item:
            predictions.extend(page_item['data'])

    structured_data = []
    current_item = None
    current_option_key = None
    current_passage_buffer = []
    current_text_buffer = []
    first_question_started = False  # becomes True at the first B-QUESTION
    last_entity_type = None         # entity type of the most recent tagged word
    just_finished_i_option = False  # previous word extended an option (I-OPTION)
    is_in_new_passage = False       # currently inside a passage that interrupted the options

    def finalize_passage_to_item(item, passage_buffer):
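        """Flush the passage buffer into item['passage'] (appending to any
        existing passage) and clear the buffer."""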
        if passage_buffer:
            passage_text = re.sub(r'\s{2,}', ' ', ' '.join(passage_buffer)).strip()
            if item.get('passage'): item['passage'] += ' ' + passage_text
            else: item['passage'] = passage_text
        passage_buffer.clear()

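    # Single pass over the word stream: B-QUESTION opens a new item; OPTION,
    # ANSWER and PASSAGE tags fill it in; a passage that interrupts an option
    # list is routed to 'new_passage' so process_context_linking can later
    # decide which passage the surrounding questions belong to.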
    for item in predictions:
        word = item['word']
        label = item['predicted_label']
        entity_type = label[2:].strip() if label.startswith(('B-', 'I-')) else None
        current_text_buffer.append(word)
        previous_entity_type = last_entity_type
        is_passage_label = (entity_type == 'PASSAGE')

        # Before the first question: buffer passage words, skip everything else
        if not first_question_started:
            if label != 'B-QUESTION' and not is_passage_label:
                just_finished_i_option = False
                is_in_new_passage = False
                continue
            if is_passage_label:
                current_passage_buffer.append(word)
                last_entity_type = 'PASSAGE'
                just_finished_i_option = False
                is_in_new_passage = False
                continue

        if label == 'B-QUESTION':
            if not first_question_started:
                header_text = ' '.join(current_text_buffer[:-1]).strip()
                if header_text or current_passage_buffer:
                    metadata_item = {'type': 'METADATA', 'passage': ''}
                    finalize_passage_to_item(metadata_item, current_passage_buffer)
                    if header_text: metadata_item['text'] = header_text
                    structured_data.append(metadata_item)
                first_question_started = True
                current_text_buffer = [word]

            if current_item is not None:
                finalize_passage_to_item(current_item, current_passage_buffer)
                current_item['text'] = ' '.join(current_text_buffer[:-1]).strip()
                structured_data.append(current_item)
                current_text_buffer = [word]

            current_item = {
                'question': word, 'options': {}, 'answer': '', 'passage': '', 'text': ''
            }
            current_option_key = None
            last_entity_type = 'QUESTION'
            just_finished_i_option = False
            is_in_new_passage = False
            continue

        if current_item is not None:
            if is_in_new_passage:
                if entity_type == 'PASSAGE' or label == 'O':
                    # 🔑 Keep accumulating the interrupting passage in 'new_passage'
                    if 'new_passage' not in current_item:
                        current_item['new_passage'] = word
                    else:
                        current_item['new_passage'] += f' {word}'
                    if entity_type: last_entity_type = entity_type
                    continue
                # A non-PASSAGE entity ends the interrupting passage; fall through
                # so this word is handled by the normal B-/I- logic below instead
                # of being swallowed into the passage text
                is_in_new_passage = False

            if label.startswith('B-'):
                if entity_type in ['QUESTION', 'OPTION', 'ANSWER', 'SECTION_HEADING']:
                    finalize_passage_to_item(current_item, current_passage_buffer)
                    current_passage_buffer = []
                last_entity_type = entity_type
                if entity_type == 'PASSAGE':
                    if previous_entity_type == 'OPTION' and just_finished_i_option:
                        # Passage opening right after an option list: capture as 'new_passage'
                        if current_item.get('new_passage'):
                            current_item['new_passage'] += f' {word}'
                        else:
                            current_item['new_passage'] = word
                        is_in_new_passage = True
                    else:
                        current_passage_buffer.append(word)
                elif entity_type == 'OPTION':
                    current_option_key = word
                    current_item['options'][current_option_key] = word
                    just_finished_i_option = False
                elif entity_type == 'ANSWER':
                    current_item['answer'] = word
                    current_option_key = None
                    just_finished_i_option = False
                elif entity_type == 'QUESTION':
                    current_item['question'] += f' {word}'
                    just_finished_i_option = False

            elif label.startswith('I-'):
                if entity_type == 'QUESTION':
                    current_item['question'] += f' {word}'
                elif entity_type == 'PASSAGE':
                    if previous_entity_type == 'OPTION' and just_finished_i_option:
                        # Passage resuming after an option list: route to 'new_passage'
                        if current_item.get('new_passage'):
                            current_item['new_passage'] += f' {word}'
                        else:
                            current_item['new_passage'] = word
                        is_in_new_passage = True
                    else:
                        # I-PASSAGE without a preceding B-: treat as a fresh passage
                        if not current_passage_buffer: last_entity_type = 'PASSAGE'
                        current_passage_buffer.append(word)
                elif entity_type == 'OPTION' and current_option_key is not None:
                    current_item['options'][current_option_key] += f' {word}'
                elif entity_type == 'ANSWER':
                    current_item['answer'] += f' {word}'
                just_finished_i_option = (entity_type == 'OPTION')

            elif label == 'O':
                if last_entity_type == 'QUESTION':
                    current_item['question'] += f' {word}'
                just_finished_i_option = False

    if current_item is not None:
        finalize_passage_to_item(current_item, current_passage_buffer)
        current_item['text'] = ' '.join(current_text_buffer).strip()
        structured_data.append(current_item)
    
    # Normalize whitespace ('text' may be absent on metadata-only items)
    for item in structured_data:
        item['text'] = re.sub(r'\s{2,}', ' ', item.get('text', '')).strip()
        if 'new_passage' in item:
            item['new_passage'] = re.sub(r'\s{2,}', ' ', item['new_passage']).strip()

    try:
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(structured_data, f, indent=2, ensure_ascii=False)
    except Exception as e:
        print(f"⚠️ Could not write structured JSON to {output_path}: {e}")

    return structured_data

def create_query_text(entry: Dict[str, Any]) -> str:
    """Combines question and options into a single string for similarity matching."""
    query_parts = []
    if entry.get("question"):
        query_parts.append(entry["question"])

    for key in ["options", "options_text"]:
        options = entry.get(key)
        if options and isinstance(options, dict):
            for value in options.values():
                if value and isinstance(value, str):
                    query_parts.append(value)
    return " ".join(query_parts)

    

def calculate_similarity(doc1: str, doc2: str) -> float:
    """Calculates Cosine Similarity between two text strings."""
    if not doc1 or not doc2:
        return 0.0

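    # Strip leading enumerators such as "(a)", "1.", or "Q" from each line so
    # option/question markers do not inflate the similarity score.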
    def clean_text(text):
        return re.sub(r'^\s*[\(\d\w]+\.?\s*', '', text, flags=re.MULTILINE)

    clean_doc1 = clean_text(doc1)
    clean_doc2 = clean_text(doc2)
    corpus = [clean_doc1, clean_doc2]

    try:
        vectorizer = CountVectorizer(stop_words='english', lowercase=True, token_pattern=r'(?u)\b\w\w+\b')
        count_matrix = vectorizer.fit_transform(corpus)  # bag-of-words counts (not TF-IDF)
        if count_matrix.shape[1] == 0:
            return 0.0
        vectors = count_matrix.toarray()
        # Guard against a degenerate corpus (cosine needs two vectors)
        if len(vectors) < 2:
            return 0.0
        score = cosine_similarity(vectors[0:1], vectors[1:2])[0][0]
        return float(score)
    except Exception:
        return 0.0

def process_context_linking(data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Links questions to passages based on 'passage' flow vs 'new_passage' priority.
    Includes 'Decay Logic': If 2 consecutive questions fail to match the active passage,
    the passage context is dropped to prevent false positives downstream.
    """
    print("\n" + "=" * 80)
    print("--- STARTING CONTEXT LINKING (WITH DECAY LOGIC) ---")
    print("=" * 80)
    
    if not data: return []

    # --- PHASE 1: IDENTIFY PASSAGE DEFINERS ---
    passage_definer_indices = []
    for i, entry in enumerate(data):
        if entry.get("passage") and entry["passage"].strip():
            passage_definer_indices.append(i)
        if entry.get("new_passage") and entry["new_passage"].strip():
            if i not in passage_definer_indices:
                passage_definer_indices.append(i)

    # --- PHASE 2: CONTEXT TRANSFER & LINKING ---
    current_passage_text = None
    current_new_passage_text = None
    
    # Counter for consecutive linking failures (drives the decay/kill switch)
    consecutive_failures = 0
    MAX_CONSECUTIVE_FAILURES = 2

    for i, entry in enumerate(data):
        item_type = entry.get("type", "Question")

        # A. UNCONDITIONALLY UPDATE CONTEXTS (And Reset Decay Counter)
        if entry.get("passage") and entry["passage"].strip():
            current_passage_text = entry["passage"]
            consecutive_failures = 0 # Reset because we have fresh explicit context
            # print(f"  [Flow] Updated Standard Context from Item {i}")
        
        if entry.get("new_passage") and entry["new_passage"].strip():
            current_new_passage_text = entry["new_passage"]
            # We don't necessarily reset standard failures here as this is a local override

        # B. QUESTION LINKING
        if entry.get("question") and item_type != "METADATA":
            combined_query = create_query_text(entry)
            
            # Skip if query is too short (noise)
            if len(combined_query.strip()) < 5: 
                continue

            # Calculate scores
            score_old = calculate_similarity(current_passage_text, combined_query) if current_passage_text else 0.0
            score_new = calculate_similarity(current_new_passage_text, combined_query) if current_new_passage_text else 0.0

            q_preview = entry['question'][:30] + '...'
            
            # RESOLUTION LOGIC
            linked = False
            
            # 1. Prefer New Passage if significantly better
            if current_new_passage_text and (score_new > score_old + RESOLUTION_MARGIN) and (score_new >= SIMILARITY_THRESHOLD):
                entry["passage"] = current_new_passage_text
                print(f"  [Linker] 🚀 Q{i} ('{q_preview}') -> NEW PASSAGE (Score: {score_new:.3f})")
                linked = True
                # Note: We do not reset 'consecutive_failures' for the standard passage here,
                # because we matched the *new* passage, not the standard one.
            
            # 2. Otherwise use Standard Passage if it meets threshold
            elif current_passage_text and (score_old >= SIMILARITY_THRESHOLD):
                entry["passage"] = current_passage_text
                print(f"  [Linker] ✅ Q{i} ('{q_preview}') -> STANDARD PASSAGE (Score: {score_old:.3f})")
                linked = True
                consecutive_failures = 0 # Success! Reset the kill switch.
            
            if not linked:
                # 3. DECAY LOGIC
                if current_passage_text:
                    consecutive_failures += 1
                    print(f"  [Linker] ⚠️ Q{i} NOT LINKED. (Failures: {consecutive_failures}/{MAX_CONSECUTIVE_FAILURES})")
                    
                    if consecutive_failures >= MAX_CONSECUTIVE_FAILURES:
                        print(f"  [Linker] 🗑️  Context dropped due to {consecutive_failures} consecutive misses.")
                        current_passage_text = None
                        consecutive_failures = 0
                else:
                    print(f"  [Linker] ⚠️ Q{i} NOT LINKED (No active context).")

    # --- PHASE 3: CLEANUP AND INTERPOLATION ---
    print("  [Linker] Running Cleanup & Interpolation...")
    
    # 3A. Self-Correction (Remove weak links)
    for i in passage_definer_indices:
        entry = data[i]
        if entry.get("question") and entry.get("type") != "METADATA":
            passage_to_check = entry.get("passage") or entry.get("new_passage")
            if passage_to_check:
                self_sim = calculate_similarity(passage_to_check, create_query_text(entry))
                if self_sim < SIMILARITY_THRESHOLD:
                    entry["passage"] = ""
                    if "new_passage" in entry: entry["new_passage"] = ""
                    print(f"  [Cleanup] Removed weak link for Q{i}")

    # 3B. Interpolation (Fill gaps)
    # We only interpolate if the gap is strictly 1 question wide to avoid undoing the decay logic
    for i in range(1, len(data) - 1):
        current_entry = data[i]
        is_gap = current_entry.get("question") and not current_entry.get("passage")
        if is_gap:
            prev_p = data[i - 1].get("passage")
            next_p = data[i + 1].get("passage")
            if prev_p and next_p and (prev_p == next_p) and prev_p.strip():
                current_entry["passage"] = prev_p
                print(f"  [Linker] 🥪 Q{i} Interpolated from neighbors.")

    return data


def correct_misaligned_options(structured_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
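    """Repair a common extraction misalignment: if an option is empty (its
    value is just the key) and the next option carries two EQUATION/FIGURE
    tags, move the first tag back to the empty option."""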
    print("\n" + "=" * 80)
    print("--- 5. STARTING POST-PROCESSING: OPTION ALIGNMENT CORRECTION ---")
    print("=" * 80)
    tag_pattern = re.compile(r'(EQUATION\d+|FIGURE\d+)')
    corrected_count = 0
    for item in structured_data:
        if item.get('type') == 'METADATA': continue
        options = item.get('options')
        if not options or len(options) < 2: continue
        option_keys = list(options.keys())
        for i in range(len(option_keys) - 1):
            current_key = option_keys[i]
            next_key = option_keys[i + 1]
            current_value = options[current_key].strip()
            next_value = options[next_key].strip()
            is_current_empty = current_value == current_key
            content_in_next = next_value.replace(next_key, '', 1).strip()
            tags_in_next = tag_pattern.findall(content_in_next)
            has_two_tags = len(tags_in_next) == 2
            if is_current_empty and has_two_tags:
                tag_to_move = tags_in_next[0]
                options[current_key] = f"{current_key} {tag_to_move}".strip()
                options[next_key] = f"{next_key} {tags_in_next[1]}".strip()
                corrected_count += 1
    print(f"✅ Option alignment correction finished. Total corrections: {corrected_count}.")
    return structured_data


# ============================================================================
# --- PHASE 4: IMAGE EMBEDDING (Base64) ---
# ============================================================================

def get_base64_for_file(filepath: str) -> str:
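    """Read a file and return its contents Base64-encoded ('' on failure)."""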
    try:
        with open(filepath, 'rb') as f:
            return base64.b64encode(f.read()).decode('utf-8')
    except Exception as e:
        print(f"  ❌ Error encoding file {filepath}: {e}")
        return ""

def embed_images_as_base64_in_memory(structured_data: List[Dict[str, Any]], figure_extraction_dir: str) -> List[Dict[str, Any]]:
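    """Scan each item's question/options/passage text for FIGURE<n> /
    EQUATION<n> tags and attach the matching extracted PNGs as Base64
    fields (e.g. item['figure3'])."""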
    print("\n" + "=" * 80)
    print("--- 4. STARTING IMAGE EMBEDDING (Base64) ---")
    print("=" * 80)
    if not structured_data: return []
    image_files = glob.glob(os.path.join(figure_extraction_dir, "*.png"))
    image_lookup = {}
    tag_regex = re.compile(r'(figure|equation)(\d+)', re.IGNORECASE)
    for filepath in image_files:
        filename = os.path.basename(filepath)
        match = re.search(r'_(figure|equation)(\d+)\.png$', filename, re.IGNORECASE)
        if match:
            key = f"{match.group(1).upper()}{match.group(2)}"
            image_lookup[key] = filepath
    print(f"  -> Found {len(image_lookup)} image components.")
    final_structured_data = []
    for item in structured_data:
        text_fields = [item.get('question', ''), item.get('passage', '')]
        if 'options' in item:
            for opt_val in item['options'].values(): text_fields.append(opt_val)
        if 'new_passage' in item: text_fields.append(item['new_passage'])
        unique_tags_to_embed = set()
        for text in text_fields:
            if not text: continue
            for match in tag_regex.finditer(text):
                tag = match.group(0).upper()
                if tag in image_lookup: unique_tags_to_embed.add(tag)
        for tag in sorted(list(unique_tags_to_embed)):
            filepath = image_lookup[tag]
            base64_code = get_base64_for_file(filepath)
            base_key = tag.lower()  # e.g. 'FIGURE3' -> 'figure3'
            item[base_key] = base64_code
        final_structured_data.append(item)
    print(f"✅ Image embedding complete.")
    return final_structured_data

# ============================================================================
# --- MAIN FUNCTION ---
# ============================================================================

def run_document_pipeline(input_pdf_path: str, layoutlmv3_model_path: str, label_studio_output_path: str) -> Optional[List[Dict[str, Any]]]:
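    """End-to-end driver: preprocess the PDF, run LayoutLMv3 inference,
    decode BIO tags into structured items, correct options, link passages,
    export Label Studio tasks, and embed referenced images as Base64.
    Returns the final item list, or None on failure."""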
    if not os.path.exists(input_pdf_path):
        print(f"❌ Input PDF not found: {input_pdf_path}")
        return None
    
    print("\n" + "#" * 80)
    print("### STARTING OPTIMIZED FULL DOCUMENT ANALYSIS PIPELINE ###")
    print("#" * 80)
    
    pdf_name = os.path.splitext(os.path.basename(input_pdf_path))[0]
    temp_pipeline_dir = os.path.join(tempfile.gettempdir(), f"pipeline_run_{pdf_name}_{os.getpid()}")
    os.makedirs(temp_pipeline_dir, exist_ok=True)
    
    preprocessed_json_path = os.path.join(temp_pipeline_dir, f"{pdf_name}_preprocessed.json")
    raw_output_path = os.path.join(temp_pipeline_dir, f"{pdf_name}_raw_predictions.json")
    structured_intermediate_output_path = os.path.join(temp_pipeline_dir, f"{pdf_name}_structured_intermediate.json")
    
    final_result = None
    try:
        # Phase 1: Preprocessing with YOLO First + Masking
        preprocessed_json_path_out = run_single_pdf_preprocessing(input_pdf_path, preprocessed_json_path)
        if not preprocessed_json_path_out: return None
        
        # Phase 2: Inference
        page_raw_predictions_list = run_inference_and_get_raw_words(
            input_pdf_path, layoutlmv3_model_path, preprocessed_json_path_out
        )
        if not page_raw_predictions_list: return None
        
        with open(raw_output_path, 'w', encoding='utf-8') as f:
            json.dump(page_raw_predictions_list, f, indent=4)
            
        # Phase 3: Decoding
        structured_data_list = convert_bio_to_structured_json_relaxed(
            raw_output_path, structured_intermediate_output_path
        )
        if not structured_data_list: return None
        structured_data_list = correct_misaligned_options(structured_data_list)
        structured_data_list = process_context_linking(structured_data_list)
        
        try:
            convert_raw_predictions_to_label_studio(page_raw_predictions_list, label_studio_output_path)
        except Exception as e:
            print(f"❌ Error during Label Studio conversion: {e}")
            
        # Phase 4: Embedding
        final_result = embed_images_as_base64_in_memory(structured_data_list, FIGURE_EXTRACTION_DIR)
        
    except Exception as e:
        print(f"❌ FATAL ERROR: {e}")
        import traceback
        traceback.print_exc()
        return None
        
    finally:
        try:
            for f in glob.glob(os.path.join(temp_pipeline_dir, '*')):
                os.remove(f)
            os.rmdir(temp_pipeline_dir)
        except Exception: pass
        
    print("\n" + "#" * 80)
    print("### OPTIMIZED PIPELINE EXECUTION COMPLETE ###")
    print("#" * 80)
    return final_result


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Complete Pipeline")
    parser.add_argument("--input_pdf", type=str, required=True, help="Input PDF")
    parser.add_argument("--layoutlmv3_model_path", type=str, default=DEFAULT_LAYOUTLMV3_MODEL_PATH, help="Model Path")
    parser.add_argument("--ls_output_path", type=str, default=None, help="Label Studio Output Path")
    args = parser.parse_args()
    
    pdf_name = os.path.splitext(os.path.basename(args.input_pdf))[0]
    final_output_path = os.path.abspath(f"{pdf_name}_final_output_embedded.json")
    ls_output_path = os.path.abspath(args.ls_output_path if args.ls_output_path else f"{pdf_name}_label_studio_tasks.json")
    
    final_json_data = run_document_pipeline(args.input_pdf, args.layoutlmv3_model_path, ls_output_path)
    
    if final_json_data:
        with open(final_output_path, 'w', encoding='utf-8') as f:
            json.dump(final_json_data, f, indent=2, ensure_ascii=False)
        print(f"\n✅ Final Data Saved: {final_output_path}")
    else:
        print("\n❌ Pipeline Failed.")
        sys.exit(1)
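
# Example invocation (script name and paths are illustrative):
#   python pipeline.py --input_pdf ./exam_paper.pdf --ls_output_path ./tasks.json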