Spaces:
Sleeping
Sleeping
| from typing import Optional | |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
| import torch | |
| import difflib | |
class GrammarAgent:
    """Grammar-correction agent backed by a Hugging Face seq2seq model.

    The tokenizer and model are loaded once at construction time. The public
    entry point is :meth:`correct`, which returns the corrected text together
    with a word-level list of changes and a heuristic confidence score.
    """

    def __init__(
        self,
        model_name: str = "vennify/t5-base-grammar-correction",
        device: Optional[str] = None,
    ):
        """Load the tokenizer and model and move the model to ``device``.

        Args:
            model_name: Hugging Face model id passed to ``from_pretrained``.
            device: Target device such as ``"cuda"`` or ``"cpu"``. When
                ``None``, ``"cuda"`` is chosen if
                ``torch.cuda.is_available()`` is true, otherwise ``"cpu"``.
        """
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.model.to(self.device)

    def _generate(self, text: str, max_length: int = 256, num_beams: int = 5) -> str:
        """Run the model on ``text`` and return the decoded, stripped output.

        Args:
            text: Sentence or passage to correct.
            max_length: Used twice: as the tokenizer truncation limit for the
                input and as the ``max_length`` passed to ``generate``.
            num_beams: Beam count passed to ``generate``.

        Returns:
            The first generated sequence, decoded without special tokens and
            with surrounding whitespace removed.
        """
        prefixed = "grammar: " + text  # T5-style task prefix
        inputs = self.tokenizer(
            prefixed,
            return_tensors="pt",
            truncation=True,
            max_length=max_length,
        ).to(self.device)
        # Inference only: skip gradient tracking.
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_length=max_length,
                num_beams=num_beams,
                early_stopping=True,
            )
        corrected = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return corrected.strip()

    def _diff_explanation(self, original: str, corrected: str) -> list:
        """Describe the word-level differences between two strings.

        Both strings are split on whitespace and compared with
        ``difflib.ndiff``. Consecutive removed/added words are grouped into a
        single change; an unchanged word closes the current group.

        Args:
            original: Text before correction.
            corrected: Text after correction.

        Returns:
            A list of dicts with keys ``"from"`` (removed words joined by a
            space, or ``None``), ``"to"`` (added words joined by a space, or
            ``None``) and ``"type"`` (see :meth:`_infer_change_type`). The
            list is empty when the word sequences are identical.
        """
        changes = []
        removed = []
        added = []

        def flush():
            # Emit the pending group (if any) as one change record.
            if not (removed or added):
                return
            changes.append(
                {
                    "from": " ".join(removed) if removed else None,
                    "to": " ".join(added) if added else None,
                    "type": self._infer_change_type(removed, added),
                }
            )
            removed.clear()
            added.clear()

        for entry in difflib.ndiff(original.split(), corrected.split()):
            tag, word = entry[:2], entry[2:]
            if tag == "- ":
                removed.append(word)
            elif tag == "+ ":
                added.append(word)
            elif tag == "  ":
                # An unchanged word ends the current run of edits.
                flush()
            # "? " hint lines produced by ndiff carry no words; ignore them.

        # A run of edits at the very end has no unchanged word to close it.
        flush()
        return changes

    @staticmethod
    def _infer_change_type(deleted_tokens, added_tokens):
        """Classify a change group as deletion, insertion or replacement.

        Declared as a static method because it uses no instance state; this
        is what makes the ``self._infer_change_type(a, b)`` calls in
        :meth:`_diff_explanation` bind the two arguments correctly.

        Args:
            deleted_tokens: Words removed from the original text.
            added_tokens: Words added in the corrected text.

        Returns:
            ``"deletion"`` when only words were removed, ``"insertion"`` when
            only words were added, otherwise ``"replacement"``.
        """
        if deleted_tokens and not added_tokens:
            return "deletion"
        if added_tokens and not deleted_tokens:
            return "insertion"
        return "replacement"

    def correct(self, text: str) -> dict:
        """Correct ``text`` and report what changed.

        Blank or whitespace-only input is returned without calling the model.

        Args:
            text: Text to correct.

        Returns:
            A dict of the form::

                {
                    "original": <input text>,
                    "corrected": <model output>,
                    "changes": [ {from, to, type}, ... ],
                    "confidence": float,
                    "agent": "grammar",
                }

            ``confidence`` is ``1.0`` minus the ratio of change groups to
            whitespace-separated input words, floored at ``0.3`` and rounded
            to two decimals.
        """
        if not text.strip():
            # Nothing to correct; avoid generating text from an empty prompt.
            return {
                "original": text,
                "corrected": text.strip(),
                "changes": [],
                "confidence": 1.0,
                "agent": "grammar",
            }

        corrected = self._generate(text)
        changes = self._diff_explanation(text, corrected)
        # Simple heuristic: the more change groups per input word, the lower
        # the confidence.
        change_ratio = len(changes) / max(len(text.split()), 1)
        confidence = max(0.3, 1.0 - change_ratio)
        return {
            "original": text,
            "corrected": corrected,
            "changes": changes,
            "confidence": round(confidence, 2),
            "agent": "grammar",
        }