# Spaces: Sleeping
# File size: 4,216 Bytes
# da7d46c
from typing import Optional
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import difflib
class GrammarAgent:
    """Grammar-correction agent backed by a Hugging Face seq2seq model.

    Loads a tokenizer/model pair with ``from_pretrained`` and exposes
    :meth:`correct`, which returns the corrected text together with a
    word-level list of changes and a heuristic confidence score.
    """

    def __init__(
        self,
        model_name: str = "vennify/t5-base-grammar-correction",
        device: Optional[str] = None,
    ):
        """Load the tokenizer and model and move the model to ``device``.

        Args:
            model_name: Hugging Face model id passed to ``from_pretrained``.
            device: Target device, e.g. ``"cuda"`` or ``"cpu"``. When
                ``None``, ``"cuda"`` is chosen if
                ``torch.cuda.is_available()`` is true, otherwise ``"cpu"``.
        """
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.model.to(self.device)

    def _generate(self, text: str, max_length: int = 256, num_beams: int = 5) -> str:
        """Run the model on ``text`` and return the decoded, stripped output.

        Args:
            text: Raw input text; the ``"grammar: "`` prefix is added here.
            max_length: Token limit used both for tokenizing the input
                (``truncation=True``, so longer inputs are cut) and for
                the generated sequence.
            num_beams: Beam width passed to ``generate``.

        Returns:
            The first generated sequence, decoded without special tokens
            and with surrounding whitespace removed.
        """
        prefixed = "grammar: " + text  # T5-style task prefix
        inputs = self.tokenizer(
            prefixed,
            return_tensors="pt",
            truncation=True,
            max_length=max_length,
        ).to(self.device)
        # Inference only: no gradients are needed.
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_length=max_length,
                num_beams=num_beams,
                early_stopping=True,
            )
        corrected = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return corrected.strip()

    def _diff_explanation(self, original: str, corrected: str) -> list:
        """Describe the word-level differences between two texts.

        Both strings are split on whitespace and compared with
        ``difflib.ndiff``. Every run of consecutive removed and/or added
        words (i.e. everything between two unchanged words) becomes one
        change entry.

        Args:
            original: Text before correction.
            corrected: Text after correction.

        Returns:
            A list of dicts, in order of appearance, each with the keys
            ``"from"`` (removed words joined by a space, or ``None``),
            ``"to"`` (added words joined by a space, or ``None``) and
            ``"type"`` (see :meth:`_infer_change_type`).
        """
        changes = []
        removed = []
        added = []

        def flush():
            # Emit the pending run of removals/additions, if any, and reset.
            if not (removed or added):
                return
            changes.append(
                {
                    "from": " ".join(removed) if removed else None,
                    "to": " ".join(added) if added else None,
                    "type": self._infer_change_type(removed, added),
                }
            )
            removed.clear()
            added.clear()

        for token in difflib.ndiff(original.split(), corrected.split()):
            if token.startswith("- "):
                removed.append(token[2:])
            elif token.startswith("+ "):
                added.append(token[2:])
            elif token.startswith(" "):
                # An unchanged word closes the current run of edits.
                flush()
            # Remaining lines are ndiff's "? " intraline hints; they carry
            # no words and are intentionally skipped.

        # A run that reaches the end of the text has no closing unchanged word.
        flush()
        return changes

    @staticmethod
    def _infer_change_type(deleted_tokens, added_tokens) -> str:
        """Classify a change from its removed and added word lists.

        Returns ``"deletion"`` when only words were removed,
        ``"insertion"`` when only words were added, and ``"replacement"``
        in every other case.
        """
        if deleted_tokens and not added_tokens:
            return "deletion"
        if added_tokens and not deleted_tokens:
            return "insertion"
        return "replacement"

    def correct(self, text: str) -> dict:
        """Correct ``text`` and report what changed.

        Empty or whitespace-only input is returned as an empty correction
        with no changes and without calling the model, since there is
        nothing to correct and the model would otherwise generate from the
        bare task prefix.

        Args:
            text: The text to correct.

        Returns:
            A dict of the form::

                {
                    "original": <input text>,
                    "corrected": <corrected text>,
                    "changes": [ {type, from, to}, ... ],
                    "confidence": <float, rounded to 2 decimals>,
                    "agent": "grammar",
                }
        """
        if text.strip():
            corrected = self._generate(text)
            changes = self._diff_explanation(text, corrected)
        else:
            corrected = ""
            changes = []

        # Simple heuristic: the more change groups relative to the number
        # of input words, the lower the confidence, floored at 0.3.
        change_ratio = len(changes) / max(len(text.split()), 1)
        confidence = max(0.3, 1.0 - change_ratio)
        return {
            "original": text,
            "corrected": corrected,
            "changes": changes,
            "confidence": round(confidence, 2),
            "agent": "grammar",
        }
|