# TextDoctor / agents/style_agent.py
# Last change: "Update agents/style_agent.py" by FurqanIshaq (commit 8d21146, verified)
import difflib
import re
from typing import Optional

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM


class StyleAgent:
    """Rewrite informal text into a more formal / professional register.

    Wraps a Hugging Face seq2seq style-transfer model and reports the
    word-level changes between the input and the rewritten text.
    """

    # Whitespace that follows ".", "!" or "?". Used to split inputs that are
    # too long for a single model pass, so no text is silently truncated.
    _SENTENCE_BOUNDARY = re.compile(r"(?<=[.!?])\s+")

    def __init__(
        self,
        model_name: str = "rajistics/informal_formal_style_transfer",
        device: Optional[str] = None,
        input_prefix: str = "",
    ):
        """
        Style Agent
        - model_name: HF model id for informal -> formal style transfer
        - device: "cuda" or "cpu" (auto-detect if None)
        - input_prefix: text prepended to every model input; leave empty
          unless the model card asks for a task prefix
        """
        self.model_name = model_name
        self.input_prefix = input_prefix
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device
        self.model.to(self.device)
        # from_pretrained already returns the model in eval mode; stated
        # explicitly so dropout is guaranteed off during generation.
        self.model.eval()

    def _count_tokens(self, text: str) -> int:
        """Return how many input tokens (incl. special tokens) *text* needs."""
        return len(self.tokenizer(self.input_prefix + text)["input_ids"])

    @staticmethod
    def _chunk_text(text: str, max_tokens: int, count_tokens) -> list:
        """Split *text* into pieces that each fit in *max_tokens* tokens.

        - Blank or whitespace-only text yields an empty list.
        - Text that already fits is returned unchanged as a single chunk.
        - Otherwise the text is split at sentence boundaries, and consecutive
          sentences are packed greedily into chunks.
        - A single sentence that is itself too long becomes its own chunk.
          The tokenizer still truncates that chunk.

        *count_tokens* is a callable mapping a string to its token count.
        """
        if not text.strip():
            return []
        if count_tokens(text) <= max_tokens:
            return [text]

        chunks = []
        current = ""
        for sentence in StyleAgent._SENTENCE_BOUNDARY.split(text.strip()):
            candidate = f"{current} {sentence}" if current else sentence
            if current and count_tokens(candidate) > max_tokens:
                # Adding this sentence would overflow: close the chunk.
                chunks.append(current)
                current = sentence
            else:
                current = candidate
        if current:
            chunks.append(current)
        return chunks

    def _generate_chunk(self, text: str, max_length: int, num_beams: int) -> str:
        """Run one model pass over *text* and return the stripped decoded output."""
        inputs = self.tokenizer(
            self.input_prefix + text,
            return_tensors="pt",
            truncation=True,
            max_length=max_length,
        ).to(self.device)
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_length=max_length,
                num_beams=num_beams,
                early_stopping=True,
            )
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True).strip()

    def _generate(self, text: str, max_length: int = 256, num_beams: int = 5) -> str:
        """Return a more formal / professional version of *text*.

        Text that fits in *max_length* tokens is sent to the model in one
        pass. Longer text is processed sentence-chunk by sentence-chunk, and
        the outputs are joined with single spaces, so the tail is not
        dropped. Blank input returns "" without calling the model.
        """
        chunks = self._chunk_text(text, max_length, self._count_tokens)
        outputs = (
            self._generate_chunk(chunk, max_length, num_beams) for chunk in chunks
        )
        return " ".join(piece for piece in outputs if piece)

    def _diff_explanation(self, original: str, styled: str) -> list:
        """Return the word-level changes that turn *original* into *styled*.

        Each maximal run of differing words between two matching runs becomes
        one dict: {"from": str | None, "to": str | None, "type": str}.
        "from" is None for pure insertions and "to" is None for pure deletions.
        """
        original_words = original.split()
        styled_words = styled.split()
        # autojunk=False: by default, words whose duplicates make up more than
        # 1% of a 200+ word styled text (e.g. "the") are ignored while
        # aligning, which misplaces change groups on long passages.
        matcher = difflib.SequenceMatcher(
            None, original_words, styled_words, autojunk=False
        )
        changes = []
        for tag, i1, i2, j1, j2 in matcher.get_opcodes():
            if tag == "equal":
                continue
            deleted = original_words[i1:i2]
            added = styled_words[j1:j2]
            changes.append(
                {
                    "from": " ".join(deleted) if deleted else None,
                    "to": " ".join(added) if added else None,
                    "type": self._infer_change_type(deleted, added),
                }
            )
        return changes

    @staticmethod
    def _infer_change_type(deleted_tokens, added_tokens):
        """Classify a change as "deletion", "insertion" or "replacement"."""
        if deleted_tokens and not added_tokens:
            return "deletion"
        if added_tokens and not deleted_tokens:
            return "insertion"
        return "replacement"

    def stylize(self, text: str, max_length: int = 256, num_beams: int = 5) -> dict:
        """
        Main method for TextDoctor.
        - max_length: token limit per model pass (input and output)
        - num_beams: beam-search width used for generation
        Returns:
        {
            "original": ...,
            "styled": ...,
            "changes": [ {type, from, to}, ... ],
            "confidence": float,
            "agent": "style"
        }
        """
        styled = self._generate(text, max_length=max_length, num_beams=num_beams)
        changes = self._diff_explanation(text, styled)
        # Simple heuristic: 1 - (change groups / original words), floored at
        # 0.3. It counts groups, not changed words, so one large rewrite
        # counts once.
        change_ratio = len(changes) / max(len(text.split()), 1)
        confidence = max(0.3, 1.0 - change_ratio)
        return {
            "original": text,
            "styled": styled,
            "changes": changes,
            "confidence": round(confidence, 2),
            "agent": "style",
        }