import os import math import re from typing import List, Tuple, Optional import gradio as gr import numpy as np from sklearn.cluster import KMeans import torch from transformers import ( AutoTokenizer, AutoModel, AutoModelForSeq2SeqLM, AutoConfig, ) # ----------------------------- # Defaults (feel free to change) # ----------------------------- DEFAULT_EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2" # fast & solid # For legal focus, you can try: "nlpaueb/legal-bert-base-uncased" with mean-pooled embeddings (slower) DEFAULT_ABS_MODEL = "pszemraj/led-large-book-summary" # good LED variant for long docs FALLBACK_ABS_MODEL = "allenai/led-base-16384" MAX_INPUT_TOKENS = 12000 # safety cap before chunking for LED WINDOW_TOKENS = 3500 # per chunk for LED/Long models OVERLAP_TOKENS = 250 # overlap between chunks DEVICE = "cuda" if torch.cuda.is_available() else "cpu" # ----------------------------- # Utilities # ----------------------------- def simple_sentence_split(text: str) -> List[str]: """ Robust-enough sentence splitter without external downloads. Splits by . ! ? and newline while keeping abbreviations modestly safe. """ # Normalize whitespace text = re.sub(r"\s+", " ", text).strip() # Anchor on punctuation that likely ends sentences candidates = re.split(r"(?<=[.!?])\s+", text) # Merge tiny fragments back (e.g., "No." followed by "23.") merged = [] buf = "" for c in candidates: frag = c.strip() if not frag: continue if not buf: buf = frag else: # if the fragment is very short (like section numbers), attach it back if len(frag) <= 3 and re.match(r"^[\(\)\[\]\dA-Za-z\-:;+.,]+$", frag): buf += " " + frag else: merged.append(buf) buf = frag if buf: merged.append(buf) # Filter empties and duplicates merged = [s.strip() for s in merged if s.strip()] return merged def mean_pooling(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: """Mean-pool token embeddings with attention mask.""" mask = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float() summed = torch.sum(last_hidden_state * mask, dim=1) counts = torch.clamp(mask.sum(dim=1), min=1e-9) return summed / counts def load_embedder(model_name: str): """ Load an embedding model. If it's a sentence-transformers model, AutoModel works; we do manual mean pooling. """ tok = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) mdl = AutoModel.from_pretrained(model_name, trust_remote_code=True) mdl.to(DEVICE) mdl.eval() return tok, mdl @torch.inference_mode() def embed_sentences(sentences: List[str], tok, mdl, batch_size: int = 32) -> np.ndarray: vecs = [] for i in range(0, len(sentences), batch_size): batch = sentences[i:i+batch_size] enc = tok( batch, padding=True, truncation=True, max_length=256, return_tensors="pt" ) enc = {k: v.to(DEVICE) for k, v in enc.items()} out = mdl(**enc) sent_emb = mean_pooling(out.last_hidden_state, enc["attention_mask"]) sent_emb = torch.nn.functional.normalize(sent_emb, p=2, dim=1) vecs.append(sent_emb.cpu().numpy()) return np.vstack(vecs) def choose_k(n_sent: int, user_k: Optional[int]) -> int: if user_k and user_k > 0: return min(user_k, n_sent) # heuristic: sqrt(n) but clamped k = max(5, int(math.sqrt(max(1, n_sent)))) return min(k, n_sent) def kmeans_select(sentences: List[str], embeddings: np.ndarray, k: int, pick: int = 1) -> List[int]: """ KMeans and pick `pick` sentences closest to each centroid. Returns sorted indices to preserve a logical reading flow. """ if k <= 0 or len(sentences) == 0: return [] # Edge case: fewer sentences than k k = min(k, len(sentences)) kmeans = KMeans(n_clusters=k, n_init="auto", random_state=42) labels = kmeans.fit_predict(embeddings) chosen = [] for c in range(k): idxs = np.where(labels == c)[0] # distances to centroid dists = np.linalg.norm(embeddings[idxs] - kmeans.cluster_centers_[c], axis=1) local_order = np.argsort(dists)[:pick] chosen.extend(idxs[local_order].tolist()) chosen = sorted(set(chosen)) return chosen def load_abstractive_model(model_name: str): cfg = AutoConfig.from_pretrained(model_name, trust_remote_code=True) tok = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) mdl = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=True) mdl.to(DEVICE) mdl.eval() return tok, mdl, cfg def chunk_tokens(tokens: List[int], window: int, overlap: int) -> List[List[int]]: if window <= 0: return [tokens] chunks = [] i = 0 while i < len(tokens): chunks.append(tokens[i:i+window]) i += max(1, window - overlap) return chunks @torch.inference_mode() def run_abstractive( text: str, model_name: str, max_new_tokens: int = 256, temperature: float = 0.7, top_p: float = 0.9, min_len: int = 40, window_tokens: int = WINDOW_TOKENS, overlap_tokens: int = OVERLAP_TOKENS, ) -> str: tok, mdl, cfg = load_abstractive_model(model_name) # Tokenize large text and process in windows enc = tok(text, return_tensors="pt", truncation=False) input_ids = enc["input_ids"].squeeze(0).tolist() if len(input_ids) <= window_tokens: enc = tok(text, return_tensors="pt", truncation=True, max_length=window_tokens).to(DEVICE) gen = mdl.generate( **enc, max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p, repetition_penalty=1.1, no_repeat_ngram_size=3, min_length=min_len, ) return tok.decode(gen[0], skip_special_tokens=True) # sliding window parts = [] for chunk in chunk_tokens(input_ids, window_tokens, overlap_tokens): enc_chunk = {"input_ids": torch.tensor([chunk]).to(DEVICE), "attention_mask": torch.ones((1, len(chunk)), dtype=torch.long, device=DEVICE)} gen = mdl.generate( **enc_chunk, max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p, repetition_penalty=1.1, no_repeat_ngram_size=3, min_length=min_len, ) parts.append(tok.decode(gen[0], skip_special_tokens=True)) # Simple merge; optionally re-summarize the stitched text once more stitched = "\n".join(parts) if len(stitched.split()) > 600: enc2 = tok(stitched, return_tensors="pt", truncation=True, max_length=window_tokens).to(DEVICE) gen2 = mdl.generate( **enc2, max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p, repetition_penalty=1.1, no_repeat_ngram_size=3, min_length=min_len, ) return tok.decode(gen2[0], skip_special_tokens=True) return stitched def pipeline( text: str, embed_model: str, abs_model: Optional[str], k_clusters: Optional[int], pick_per_cluster: int, max_new_tokens: int, temperature: float, top_p: float, min_len: int, ) -> Tuple[str, str, str]: """ Returns: (extractive_core, abstractive_summary, debug_info) """ if not text or not text.strip(): return "", "", "No input text." # 1) sentence split sentences = simple_sentence_split(text) if len(sentences) == 0: return "", "", "No sentences detected after splitting." # 2) embeddings etok, emdl = load_embedder(embed_model) embs = embed_sentences(sentences, etok, emdl, batch_size=32) # 3) clustering + representative pick k = choose_k(len(sentences), k_clusters) chosen_idx = kmeans_select(sentences, embs, k, pick=pick_per_cluster) extractive = " ".join([sentences[i] for i in chosen_idx]) # 4) abstractive (optional) abstractive = "" model_used = abs_model or "" if abs_model and abs_model.strip().lower() != "none": try: abstractive = run_abstractive( extractive if len(extractive) > 0 else text, model_name=abs_model, max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p, min_len=min_len, ) except Exception as e: # fall back to different LED if available if abs_model != FALLBACK_ABS_MODEL: try: abstractive = run_abstractive( extractive if len(extractive) > 0 else text, model_name=FALLBACK_ABS_MODEL, max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p, min_len=min_len, ) model_used = f"{abs_model} -> fell back to {FALLBACK_ABS_MODEL}" except Exception as e2: abstractive = "" model_used = f"{abs_model} (failed) & fallback failed" else: abstractive = "" model_used = f"{abs_model} (failed)" # debug dbg = ( f"Device: {DEVICE}\n" f"Embedder: {embed_model}\n" f"Abstractive model: {model_used or 'None'}\n" f"Sentences: {len(sentences)} | K: {k} | Pick/cluster: {pick_per_cluster}\n" f"Chosen indices (sorted): {chosen_idx[:50]}{'...' if len(chosen_idx) > 50 else ''}\n" ) return extractive, (abstractive or extractive), dbg # ----------------------------- # Gradio UI # ----------------------------- EXAMPLE_LEGAL = """IN THE SUPREME COURT OF INDIA Civil Appeal No. 1234 of 2021 The appellant contends that the High Court erred in overlooking binding precedent on limitation. The respondent argues that the delay is inordinate and unexplained. The core issue is whether sufficient cause exists under Section 5 of the Limitation Act. After hearing the parties and perusing the record, we find that the appellant was prevented by bona fide reasons. Accordingly, the delay is condoned subject to costs of Rs. 10,000. The matter is remanded to the High Court for disposal on merits in accordance with law.""" with gr.Blocks(title="Legal Summarizer (K-Means + BERT + LED)") as demo: gr.Markdown( """ # ⚖️ Legal Text Summarizer — K-Means + BERT + LED Upload/paste a judgment or order. We cluster sentences with BERT embeddings (extractive core), then optionally refine with an LED/Long model for an abstractive final summary. - **Embedding model**: any BERT/sentence-transformers model - **Abstractive model**: LED/LongT5/T5 from Hugging Face (can handle long docs with chunking) - Works well on Indian legal text; swap to `nlpaueb/legal-bert-base-uncased` for domain flavor (slower). """ ) with gr.Row(): text_in = gr.Textbox( label="Paste Legal Text", lines=18, placeholder="Paste a long judgment, order, or legal article…", value=EXAMPLE_LEGAL ) with gr.Accordion("Models & Settings", open=True): with gr.Row(): embed_model = gr.Textbox( label="Embedding Model (BERT/Sentence-Transformers)", value=DEFAULT_EMBED_MODEL, info="E.g., sentence-transformers/all-MiniLM-L6-v2 or nlpaueb/legal-bert-base-uncased" ) abs_model = gr.Textbox( label="Abstractive Model (LED/LongT5/T5) or 'none'", value=DEFAULT_ABS_MODEL, info="Try: pszemraj/led-large-book-summary, allenai/led-base-16384, google/long-t5-tglobal-base" ) with gr.Row(): k_clusters = gr.Number(label="K (clusters). Leave 0 to auto (≈√N)", value=0, precision=0) pick_per = gr.Slider(label="Sentences picked per cluster", minimum=1, maximum=3, value=1, step=1) with gr.Row(): max_new = gr.Slider(label="Max new tokens (abstractive)", minimum=64, maximum=1024, value=256, step=16) temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.2, value=0.7, step=0.05) top_p = gr.Slider(label="Top-p", minimum=0.1, maximum=1.0, value=0.9, step=0.05) min_len = gr.Slider(label="Min summary length (tokens)", minimum=10, maximum=200, value=40, step=5) run_btn = gr.Button("Summarize 🚀", variant="primary") extractive_out = gr.Textbox(label="Extractive Core (cluster representatives)", lines=10) abstractive_out = gr.Textbox(label="Final Summary (abstractive if model provided)", lines=10) debug_out = gr.Textbox(label="Debug Info", lines=8) def _go(text, e_model, a_model, k, pick, mx, temp, topp, minl): k = int(k) if k else 0 return pipeline( text=text, embed_model=e_model.strip(), abs_model=a_model.strip() if a_model else "none", k_clusters=k, pick_per_cluster=int(pick), max_new_tokens=int(mx), temperature=float(temp), top_p=float(topp), min_len=int(minl), ) run_btn.click( _go, inputs=[text_in, embed_model, abs_model, k_clusters, pick_per, max_new, temperature, top_p, min_len], outputs=[extractive_out, abstractive_out, debug_out] ) if __name__ == "__main__": # For Spaces/Colab inline preview demo.launch()