File size: 4,156 Bytes
8d21146
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
from typing import Optional
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import difflib


class StyleAgent:
    """Rewrite informal text into a more formal register with a seq2seq model.

    The agent wraps a Hugging Face sequence-to-sequence checkpoint, generates a
    restyled version of the input, and reports a word-level diff between the
    original and the restyled text together with a heuristic confidence score.
    """

    def __init__(
        self,
        model_name: str = "rajistics/informal_formal_style_transfer",
        device: Optional[str] = None,
    ):
        """
        Load the tokenizer and model and move the model to the target device.

        - model_name: HF model id for informal -> formal style transfer
        - device: "cuda" or "cpu"; when None, "cuda" is chosen if
          torch.cuda.is_available() is True, otherwise "cpu"
        """
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device

        self.model.to(self.device)
        # The model is only used for inference; make eval mode explicit so
        # training-time layers such as dropout are disabled.
        self.model.eval()

    def _generate(self, text: str, max_length: int = 256, num_beams: int = 5) -> str:
        """
        Run the model once and return its more formal / professional rewrite.

        - text: input text to restyle
        - max_length: token budget used both to truncate the input and to cap
          the length of the generated output
        - num_beams: beam width for beam search
        Returns the decoded output with special tokens removed and surrounding
        whitespace stripped.
        """
        # Many style-transfer T5 models work directly on raw text.
        # If the model card suggests a prefix, add it here, e.g.:
        # text = "formal: " + text
        # NOTE: input longer than max_length tokens is truncated, so for very
        # long texts the output only covers the beginning of the input.
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=max_length,
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_length=max_length,
                num_beams=num_beams,
                early_stopping=True,
            )

        styled = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return styled.strip()

    def _diff_explanation(self, original: str, styled: str) -> list:
        """
        Compare original vs styled text word by word and list the changed spans.

        Words are obtained with str.split(). Each maximal run of differing words
        between two unchanged words becomes one change dict:
            {"from": deleted words joined by spaces, or None,
             "to": inserted words joined by spaces, or None,
             "type": "deletion" | "insertion" | "replacement"}
        Identical word sequences produce an empty list.
        """
        src_words = original.split()
        dst_words = styled.split()
        # autojunk=False: with difflib's default heuristic, words that are very
        # frequent in a sequence of 200+ items (e.g. "the") can no longer anchor
        # a match, which merges neighbouring edits into larger change groups.
        matcher = difflib.SequenceMatcher(None, src_words, dst_words, autojunk=False)

        changes = []
        # get_opcodes() never emits two non-"equal" opcodes back to back, so each
        # non-equal opcode is exactly one run of changes between unchanged words.
        for tag, i1, i2, j1, j2 in matcher.get_opcodes():
            if tag == "equal":
                continue
            deleted = src_words[i1:i2]
            added = dst_words[j1:j2]
            changes.append(
                {
                    "from": " ".join(deleted) if deleted else None,
                    "to": " ".join(added) if added else None,
                    "type": self._infer_change_type(deleted, added),
                }
            )
        return changes

    @staticmethod
    def _infer_change_type(deleted_tokens, added_tokens):
        """Classify a change group from its deleted and added word lists."""
        if deleted_tokens and not added_tokens:
            return "deletion"
        if added_tokens and not deleted_tokens:
            return "insertion"
        return "replacement"

    def stylize(self, text: str) -> dict:
        """
        Main method for TextDoctor.

        Returns:
        {
            "original": ...,
            "styled": ...,
            "changes": [ {type, from, to}, ... ],
            "confidence": float,
            "agent": "style"
        }

        Blank (empty or whitespace-only) input is returned unchanged without
        calling the model, with no changes and confidence 1.0.
        """
        if not text.strip():
            # Nothing to restyle. Any text the model produced from an empty
            # prompt would otherwise be reported as insertions.
            styled, changes = text, []
        else:
            styled = self._generate(text)
            changes = self._diff_explanation(text, styled)

        # Simple heuristic confidence: 1 minus the number of change groups per
        # original word, floored at 0.3.
        # NOTE(review): this counts change *groups*, not changed words, so one
        # wholesale rewrite scores higher than several small scattered edits.
        change_ratio = len(changes) / max(len(text.split()), 1)
        confidence = max(0.3, 1.0 - change_ratio)

        return {
            "original": text,
            "styled": styled,
            "changes": changes,
            "confidence": round(confidence, 2),
            "agent": "style",
        }