# TextDoctor — agents/clarity_agent.py
# Clarity agent: paraphrases sentences for readability with a T5-style model.
from typing import Optional
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import difflib
class ClarityAgent:
    """Restate sentences more clearly using a seq2seq paraphrasing model.

    Wraps a Hugging Face T5-style paraphraser and, for each clarified
    sentence, reports simple word-level changes (insertions, deletions,
    replacements) derived from a diff of the two sentences.
    """

    def __init__(
        self,
        model_name: str = "Vamsi/T5_Paraphrase_Paws",
        device: Optional[str] = None,
    ):
        """
        Clarity Agent.

        Args:
            model_name: Hugging Face model ID of a seq2seq paraphraser.
            device: "cuda" or "cpu"; auto-detected when None.
        """
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device
        self.model.to(self.device)
        # Inference only. from_pretrained already returns the model in eval
        # mode, but being explicit guards against an accidental .train() call.
        self.model.eval()

    def _generate(
        self,
        text: str,
        max_length: int = 256,
        num_beams: int = 5,
        num_return_sequences: int = 1,
    ) -> str:
        """Generate a clearer paraphrase of *text*.

        Args:
            text: Input sentence to paraphrase.
            max_length: Token budget for both encoding and generation.
            num_beams: Beam-search width.
            num_return_sequences: Sequences to generate (only the first is
                returned to the caller).

        Returns:
            The first generated paraphrase, stripped of surrounding whitespace.
        """
        # Many T5 paraphrase models expect a task prefix. Do NOT append a
        # literal " </s>": the tokenizer adds the EOS token itself, and the
        # fast tokenizer would also encode the literal text as a second EOS.
        prefixed = "paraphrase: " + text
        inputs = self.tokenizer(
            [prefixed],
            max_length=max_length,
            padding="longest",
            truncation=True,
            return_tensors="pt",
        ).to(self.device)
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_length=max_length,
                num_beams=num_beams,
                num_return_sequences=num_return_sequences,
                early_stopping=True,
            )
        # Take the first generated sequence only.
        paraphrased = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return paraphrased.strip()

    def _diff_explanation(self, original: str, clarified: str):
        """Compare sentences and return word-level changes.

        Consecutive deleted/added words are grouped into a single change.
        ndiff's "? " hint lines are intentionally ignored.

        Returns:
            A list of dicts with keys "from" (deleted words or None),
            "to" (added words or None), and "type" (see _infer_change_type).
        """
        diff = list(difflib.ndiff(original.split(), clarified.split()))
        changes = []
        current_del = []
        current_add = []

        def _flush():
            # Emit the pending deletion/addition group, if any.
            if current_del or current_add:
                changes.append(
                    {
                        "from": " ".join(current_del) if current_del else None,
                        "to": " ".join(current_add) if current_add else None,
                        "type": self._infer_change_type(current_del, current_add),
                    }
                )

        for token in diff:
            if token.startswith("- "):
                current_del.append(token[2:])
            elif token.startswith("+ "):
                current_add.append(token[2:])
            elif token.startswith("  "):
                # Unchanged word closes the current change group.
                _flush()
                current_del, current_add = [], []
        # Flush a change group that runs to the end of the sentence.
        _flush()
        # Drop degenerate entries with neither side populated.
        return [c for c in changes if c["from"] or c["to"]]

    @staticmethod
    def _infer_change_type(deleted_tokens, added_tokens):
        """Classify a change group as deletion, insertion, or replacement."""
        if deleted_tokens and not added_tokens:
            return "deletion"
        if added_tokens and not deleted_tokens:
            return "insertion"
        return "replacement"

    def clarify(self, text: str) -> dict:
        """Main entry point for TextDoctor.

        Args:
            text: Sentence to clarify.

        Returns:
            {
                "original": input text,
                "clarified": paraphrased text,
                "changes": [ {"type", "from", "to"}, ... ],
                "confidence": float in [0.3, 1.0], rounded to 2 decimals,
                "agent": "clarity",
            }
        """
        clarified = self._generate(text)
        changes = self._diff_explanation(text, clarified)
        # Simple heuristic: confidence falls as the fraction of changed word
        # groups rises, floored at 0.3 so heavy rewrites are not reported as 0.
        change_ratio = len(changes) / max(len(text.split()), 1)
        confidence = max(0.3, 1.0 - change_ratio)
        return {
            "original": text,
            "clarified": clarified,
            "changes": changes,
            "confidence": round(confidence, 2),
            "agent": "clarity",
        }