# TextDoctor / agents/grammar_agent.py
# Last commit by FurqanIshaq: "Update agents/grammar_agent.py" (da7d46c, verified)
import difflib
import re
from typing import Optional

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
class GrammarAgent:
    """Grammar-correction agent backed by a Hugging Face seq2seq (T5-style) model.

    ``correct`` is the public entry point. It runs the model on the input text,
    diffs the result against the input word by word, and returns both, together
    with a heuristic confidence score.
    """

    # Task prefix prepended to every model input (T5-style prompting).
    _PREFIX = "grammar: "
    # Used only when an input is too long for one model call: split on the
    # whitespace that follows '.', '!' or '?'.
    _SENTENCE_BOUNDARY = re.compile(r"(?<=[.!?])\s+")

    def __init__(
        self,
        model_name: str = "vennify/t5-base-grammar-correction",
        device: Optional[str] = None,
    ):
        """
        Load the tokenizer and model, then move the model to ``device``.

        Args:
            model_name: Hugging Face model id passed to ``from_pretrained``.
            device: "cuda" or "cpu". When None, "cuda" is used if
                ``torch.cuda.is_available()``, otherwise "cpu".
        """
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.model.to(self.device)
        # The agent is inference-only; state that explicitly (disables dropout).
        self.model.eval()

    def _fits(self, text: str, max_length: int) -> bool:
        """Return True if the prefixed ``text`` tokenizes to at most ``max_length`` ids."""
        # The count includes special tokens, matching what the model call
        # would see with truncation=True, max_length=max_length.
        return len(self.tokenizer(self._PREFIX + text)["input_ids"]) <= max_length

    def _chunk_text(self, text: str, max_length: int) -> list:
        """
        Split ``text`` into pieces that each fit in one model call.

        Sentences are packed greedily. A sentence that does not fit on its
        own is broken into whitespace-separated words, which are then packed
        the same way. A single word that still does not fit becomes its own
        chunk and will be truncated by the tokenizer.
        """
        units = []
        for sentence in self._SENTENCE_BOUNDARY.split(text.strip()):
            if self._fits(sentence, max_length):
                units.append(sentence)
            else:
                units.extend(sentence.split())

        chunks = []
        current = ""
        for unit in units:
            candidate = f"{current} {unit}" if current else unit
            if current and not self._fits(candidate, max_length):
                chunks.append(current)
                current = unit
            else:
                current = candidate
        if current:
            chunks.append(current)
        return chunks

    def _generate_chunk(self, text: str, max_length: int, num_beams: int) -> str:
        """Run a single model call on ``text`` and return the decoded, stripped output."""
        prefixed = self._PREFIX + text  # T5-style task prefix
        inputs = self.tokenizer(
            prefixed,
            return_tensors="pt",
            truncation=True,
            max_length=max_length,
        ).to(self.device)
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_length=max_length,
                num_beams=num_beams,
                early_stopping=True,
            )
        corrected = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return corrected.strip()

    def _generate(self, text: str, max_length: int = 256, num_beams: int = 5) -> str:
        """
        Internal helper to call the model and get corrected text.

        If the prefixed input fits in ``max_length`` tokens, one model call is
        made on the whole text. Longer inputs are split into chunks (see
        ``_chunk_text``) instead of being silently truncated. Each chunk is
        corrected separately, and the results are joined with single spaces,
        so line breaks from the original text are not preserved.

        NOTE(review): generation is also capped at ``max_length`` tokens, so a
        chunk whose correction grows past that limit could still be cut.
        """
        if self._fits(text, max_length):
            return self._generate_chunk(text, max_length, num_beams)
        return " ".join(
            self._generate_chunk(chunk, max_length, num_beams)
            for chunk in self._chunk_text(text, max_length)
        )

    def _diff_explanation(self, original: str, corrected: str) -> list:
        """
        Create a simple, human-readable explanation of changes.

        Both strings are split on whitespace and compared with
        ``difflib.ndiff``. Consecutive deleted/added words between two
        unchanged words are merged into a single change.

        Returns:
            A list of dicts ``{"from", "to", "type"}``. ``from``/``to`` hold
            space-joined words, or None for a pure insertion/deletion.
        """
        changes = []
        deleted = []
        added = []

        def flush():
            # Record the pending run of deletions/additions, if any.
            if deleted or added:
                changes.append(
                    {
                        "from": " ".join(deleted) if deleted else None,
                        "to": " ".join(added) if added else None,
                        "type": self._infer_change_type(deleted, added),
                    }
                )
                deleted.clear()
                added.clear()

        for token in difflib.ndiff(original.split(), corrected.split()):
            tag, word = token[:2], token[2:]
            if tag == "- ":
                deleted.append(word)
            elif tag == "+ ":
                added.append(word)
            elif tag == "  ":
                # An unchanged word closes the current change group.
                flush()
            # "? " lines are ndiff intraline hints and carry no words.
        flush()
        return changes

    @staticmethod
    def _infer_change_type(deleted_tokens, added_tokens):
        """
        Classify a change group from its deleted and added words.

        Returns "deletion" when only words were removed, "insertion" when
        only words were added, and "replacement" otherwise.
        """
        if deleted_tokens and not added_tokens:
            return "deletion"
        if added_tokens and not deleted_tokens:
            return "insertion"
        return "replacement"

    def correct(self, text: str, max_length: int = 256, num_beams: int = 5) -> dict:
        """
        Main method your system will call.

        Args:
            text: Text to correct. Empty or whitespace-only text is returned
                unchanged (stripped) without calling the model.
            max_length: Token budget per model call, covering both input and
                generated output.
            num_beams: Beam width passed to ``generate``.

        Returns a dict:
        {
            "original": ...,
            "corrected": ...,
            "changes": [ {type, from, to}, ... ],
            "confidence": float,
            "agent": "grammar"
        }
        """
        if not text.strip():
            # Nothing to correct; prompting the model with a bare prefix
            # could make it generate unrelated text.
            corrected = text.strip()
            changes = []
        else:
            corrected = self._generate(text, max_length=max_length, num_beams=num_beams)
            changes = self._diff_explanation(text, corrected)
        # Simple heuristic: more change groups per input word -> lower
        # confidence, floored at 0.3.
        change_ratio = len(changes) / max(len(text.split()), 1)
        confidence = max(0.3, 1.0 - change_ratio)
        return {
            "original": text,
            "corrected": corrected,
            "changes": changes,
            "confidence": round(confidence, 2),
            "agent": "grammar",
        }