import difflib
from typing import Dict, List, Optional

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM


class GrammarAgent:
    """Grammar-correction agent backed by a seq2seq (T5-style) Hugging Face model.

    Input text is prefixed with ``"grammar: "``, corrected with beam search,
    and compared word-by-word against the original to produce a list of
    changes and a heuristic confidence score.
    """

    def __init__(
        self,
        model_name: str = "vennify/t5-base-grammar-correction",
        device: Optional[str] = None,
    ):
        """Load the tokenizer and model and move the model to ``device``.

        Args:
            model_name: Hugging Face model id passed to ``from_pretrained``.
            device: ``"cuda"`` or ``"cpu"``; auto-detected when ``None``
                (CUDA if available, otherwise CPU).
        """
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device

        self.model.to(self.device)
        # Inference only: make sure dropout and similar layers are disabled.
        self.model.eval()

    def _generate(self, text: str, max_length: int = 256, num_beams: int = 5) -> str:
        """Run the model on ``text`` and return the decoded correction.

        Args:
            text: Raw input text (without the task prefix).
            max_length: Token limit applied both to the encoded input (prefix
                included) and to the generated output.
            num_beams: Beam width used by ``generate``.

        Returns:
            The decoded model output with special tokens removed and
            surrounding whitespace stripped.
        """
        prefixed = "grammar: " + text  # T5-style task prefix
        # NOTE: inputs longer than ``max_length`` tokens are silently truncated,
        # so the dropped tail will show up as a "deletion" in the diff.
        inputs = self.tokenizer(
            prefixed,
            return_tensors="pt",
            truncation=True,
            max_length=max_length,
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_length=max_length,
                num_beams=num_beams,
                early_stopping=True,
            )

        corrected = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return corrected.strip()

    def _diff_explanation(
        self, original: str, corrected: str
    ) -> List[Dict[str, Optional[str]]]:
        """Describe the word-level differences between two texts.

        Consecutive deleted/inserted words (i.e. runs not interrupted by an
        unchanged word) are grouped into a single change.

        Args:
            original: Text before correction.
            corrected: Text after correction.

        Returns:
            A list of dicts with keys ``"from"`` (removed words joined by a
            space, or ``None``), ``"to"`` (inserted words, or ``None``) and
            ``"type"`` (see ``_infer_change_type``).
        """
        changes: List[Dict[str, Optional[str]]] = []
        deleted: List[str] = []
        added: List[str] = []

        for entry in difflib.ndiff(original.split(), corrected.split()):
            tag, word = entry[:2], entry[2:]
            if tag == "- ":
                deleted.append(word)
            elif tag == "+ ":
                added.append(word)
            elif tag == "  ":
                # An unchanged word closes the current group of edits.
                if deleted or added:
                    changes.append(self._make_change(deleted, added))
                    deleted, added = [], []
            # "? " lines are ndiff intraline hints and carry no tokens.

        if deleted or added:
            changes.append(self._make_change(deleted, added))

        return changes

    @staticmethod
    def _make_change(deleted_tokens, added_tokens) -> Dict[str, Optional[str]]:
        """Build one change record from a group of deleted/added words."""
        return {
            "from": " ".join(deleted_tokens) if deleted_tokens else None,
            "to": " ".join(added_tokens) if added_tokens else None,
            "type": GrammarAgent._infer_change_type(deleted_tokens, added_tokens),
        }

    @staticmethod
    def _infer_change_type(deleted_tokens, added_tokens) -> str:
        """Classify a change group by which side has words.

        Returns ``"deletion"`` when only words were removed, ``"insertion"``
        when only words were added, and ``"replacement"`` otherwise.
        This is a deliberately simple heuristic.
        """
        if deleted_tokens and not added_tokens:
            return "deletion"
        if added_tokens and not deleted_tokens:
            return "insertion"
        return "replacement"

    def correct(self, text: str) -> dict:
        """Correct ``text`` and explain the changes.

        Args:
            text: The text to correct.

        Returns:
            A dict::

                {
                    "original": <input text>,
                    "corrected": <model output>,
                    "changes": [ {"from", "to", "type"}, ... ],
                    "confidence": float in [0.3, 1.0], rounded to 2 places,
                    "agent": "grammar",
                }
        """
        if not text.strip():
            # Nothing to correct. Calling the model here would feed it only
            # the bare task prefix, and whatever it emits would be reported
            # as corrections.
            return {
                "original": text,
                "corrected": text,
                "changes": [],
                "confidence": 1.0,
                "agent": "grammar",
            }

        corrected = self._generate(text)
        changes = self._diff_explanation(text, corrected)

        # Simple heuristic: the more change groups per input word, the lower
        # the confidence, floored at 0.3.
        change_ratio = len(changes) / max(len(text.split()), 1)
        confidence = max(0.3, 1.0 - change_ratio)

        return {
            "original": text,
            "corrected": corrected,
            "changes": changes,
            "confidence": round(confidence, 2),
            "agent": "grammar",
        }