# CAIAssignmentGradio / generate_indexes.py
import os
import re
import json
import pickle
from typing import List, Dict
import numpy as np
import faiss
import pandas as pd
import tabula
from sentence_transformers import SentenceTransformer
from rank_bm25 import BM25Okapi
# ---------------- Config ----------------
EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
PDF_PATH = "MakeMyTrip_Financial_Statements.pdf"
OUT_DIR = "data/index_merged"
# Paths for saved chunks & indices
CHUNKS_100_PATH = os.path.join(OUT_DIR, "chunks_100.json")
CHUNKS_400_PATH = os.path.join(OUT_DIR, "chunks_400.json")
CHUNKS_MERGED_PATH = os.path.join(OUT_DIR, "chunks_merged.json")
FAISS_PATH = os.path.join(OUT_DIR, "faiss_merged.index")
BM25_PATH = os.path.join(OUT_DIR, "bm25_merged.pkl")
META_PATH = os.path.join(OUT_DIR, "meta_merged.pkl")
# ---------------- Utils ----------------
_tok_pat = re.compile(r"[a-z0-9]+", re.I)
def simple_tokenize(text: str):
    return _tok_pat.findall((text or "").lower())
def create_chunks(texts: List[str], max_tokens: int) -> List[str]:
    """Greedily pack texts into chunks of roughly max_tokens words (simple \\w+ word count)."""
    chunks, current_chunk, current_tokens = [], [], 0
    for text in texts:
        tokens = re.findall(r"\w+", text)
        # Start a new chunk once adding this text would exceed the budget
        # (guard on current_chunk so an oversized first text never emits an empty chunk).
        if current_chunk and current_tokens + len(tokens) > max_tokens:
            chunks.append(" ".join(current_chunk))
            current_chunk, current_tokens = [], 0
        current_chunk.append(text)
        current_tokens += len(tokens)
    if current_chunk:
        chunks.append(" ".join(current_chunk))
    return chunks
def extract_tables_from_pdf(pdf_path: str, pages="all") -> List[Dict]:
    """Extract tables from the financial PDF into structured metric/year/value dicts."""
    tables = tabula.read_pdf(
        pdf_path,
        pages=pages,
        multiple_tables=True,
        pandas_options={'dtype': str}
    )
    table_rows = []
    row_id = 0
    for df in tables:
        if df.empty:
            continue
        df = df.replace(r'\n', ' ', regex=True).fillna("")
        # If the first row holds year headers (e.g. "2023"), promote it to column names.
        headers = list(df.iloc[0])
        if any(re.match(r"20\d{2}", str(c)) for c in headers):
            df.columns = [c.strip() for c in headers]
            df = df.drop(0).reset_index(drop=True)
        for _, row in df.iterrows():
            metric = str(row.iloc[0]).strip()
            if not metric or metric.lower() in ["note", ""]:
                continue
            # Collect the cleaned value under every year column.
            values = {}
            for col, val in row.items():
                if re.match(r"20\d{2}", str(col)):
                    clean_val = str(val).replace(",", "").strip()
                    if clean_val and clean_val not in ["-", "—", "nan"]:
                        values[str(col)] = clean_val
            if not values:
                continue
            table_rows.append({
                "id": f"table-{row_id}",
                "metric": metric,
                "years": list(values.keys()),
                "values": values,
                "content": f"{metric} values: {json.dumps(values)}",
                "source": "table"
            })
            row_id += 1
    print(f"Extracted {len(table_rows)} rows from PDF tables")
    return table_rows
def build_dense_faiss(texts: List[str], out_path: str):
    print(f"Embedding {len(texts)} docs with {EMBED_MODEL} ...")
    model = SentenceTransformer(EMBED_MODEL)
    emb = model.encode(texts, convert_to_numpy=True, batch_size=64, show_progress_bar=True)
    # L2-normalize so inner-product search is equivalent to cosine similarity.
    faiss.normalize_L2(emb)
    dim = emb.shape[1]
    index = faiss.IndexFlatIP(dim)
    index.add(emb)
    faiss.write_index(index, out_path)
    print(f"FAISS index built & saved -> {out_path}")
def build_bm25(texts: List[str], out_path: str):
    tokenized = [simple_tokenize(t) for t in texts]
    bm25 = BM25Okapi(tokenized)
    with open(out_path, "wb") as f:
        pickle.dump({"bm25": bm25, "tokenized_corpus": tokenized}, f)
    print(f"BM25 index built & saved -> {out_path}")
# ---------------- Main ----------------
def main():
    os.makedirs(OUT_DIR, exist_ok=True)
    # 1) Extract table rows
    docs = extract_tables_from_pdf(PDF_PATH, pages="all")
    all_texts = [d["content"] for d in docs]
    # 2) Create chunks of size 100 and 400
    chunks_100 = create_chunks(all_texts, 100)
    chunks_400 = create_chunks(all_texts, 400)
    # 3) Save them separately
    with open(CHUNKS_100_PATH, "w", encoding="utf-8") as f:
        json.dump(chunks_100, f, indent=2, ensure_ascii=False)
    with open(CHUNKS_400_PATH, "w", encoding="utf-8") as f:
        json.dump(chunks_400, f, indent=2, ensure_ascii=False)
    print(f"Saved {len(chunks_100)} chunks_100 -> {CHUNKS_100_PATH}")
    print(f"Saved {len(chunks_400)} chunks_400 -> {CHUNKS_400_PATH}")
    # 4) Merge with metadata
    merged = []
    for i, ch in enumerate(chunks_100):
        merged.append({"id": f"100-{i}", "chunk_size": 100, "content": ch})
    for i, ch in enumerate(chunks_400):
        merged.append({"id": f"400-{i}", "chunk_size": 400, "content": ch})
    # 5) Save merged chunks
    with open(CHUNKS_MERGED_PATH, "w", encoding="utf-8") as f:
        json.dump(merged, f, indent=2, ensure_ascii=False)
    print(f"Saved {len(merged)} merged chunks -> {CHUNKS_MERGED_PATH}")
    # 6) Build FAISS & BM25 on merged chunks
    texts = [m["content"] for m in merged]
    build_dense_faiss(texts, FAISS_PATH)
    build_bm25(texts, BM25_PATH)
    # 7) Save metadata
    with open(META_PATH, "wb") as f:
        pickle.dump(merged, f)
    print(f"Saved metadata -> {META_PATH}")
    print("\n✅ Done. Created 100 + 400 chunks separately and merged them for unified FAISS & BM25 indexes!")
if __name__ == "__main__":
    main()