FurqanIshaq committed on
Commit
8d21146
·
verified ·
1 Parent(s): da7d46c

Update agents/style_agent.py

Browse files
Files changed (1) hide show
  1. agents/style_agent.py +126 -104
agents/style_agent.py CHANGED
@@ -1,104 +1,126 @@
1
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
2
- import torch
3
- import difflib
4
-
5
- class StyleAgent:
6
- def __init__(self, model_name: str = "rajistics/informal_formal_style_transfer", device: str | None = None):
7
- """
8
- Style Agent
9
- - model_name: HF model id for informal -> formal style transfer
10
- - device: "cuda" or "cpu" (auto-detect if None)
11
- """
12
- self.model_name = model_name
13
- self.tokenizer = AutoTokenizer.from_pretrained(model_name)
14
- self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
15
-
16
- if device is None:
17
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
18
- else:
19
- self.device = device
20
-
21
- self.model.to(self.device)
22
-
23
- def _generate(self, text: str, max_length: int = 128, num_beams: int = 5) -> str:
24
- """
25
- Internal helper to call the model and get a more formal / professional version.
26
- """
27
- # Many style-transfer T5 models work directly on raw text
28
- inputs = self.tokenizer(
29
- text,
30
- return_tensors="pt",
31
- truncation=True,
32
- max_length=max_length
33
- ).to(self.device)
34
-
35
- with torch.no_grad():
36
- outputs = self.model.generate(
37
- **inputs,
38
- max_length=max_length,
39
- num_beams=num_beams,
40
- early_stopping=True
41
- )
42
-
43
- styled = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
44
- return styled.strip()
45
-
46
- def _diff_explanation(self, original: str, styled: str):
47
- """
48
- Compare original vs styled sentence and return simple word-level changes.
49
- """
50
- diff = list(difflib.ndiff(original.split(), styled.split()))
51
- changes = []
52
- current_del = []
53
- current_add = []
54
-
55
- for token in diff:
56
- if token.startswith("- "):
57
- current_del.append(token[2:])
58
- elif token.startswith("+ "):
59
- current_add.append(token[2:])
60
- elif token.startswith(" "):
61
- if current_del or current_add:
62
- changes.append({
63
- "from": " ".join(current_del) if current_del else None,
64
- "to": " ".join(current_add) if current_add else None,
65
- "type": self._infer_change_type(current_del, current_add)
66
- })
67
- current_del, current_add = [], []
68
-
69
- if current_del or current_add:
70
- changes.append({
71
- "from": " ".join(current_del) if current_del else None,
72
- "to": " ".join(current_add) if current_add else None,
73
- "type": self._infer_change_type(current_del, current_add)
74
- })
75
-
76
- changes = [c for c in changes if c["from"] or c["to"]]
77
- return changes
78
-
79
- @staticmethod
80
- def _infer_change_type(deleted_tokens, added_tokens):
81
- if deleted_tokens and not added_tokens:
82
- return "deletion"
83
- if added_tokens and not deleted_tokens:
84
- return "insertion"
85
- return "replacement"
86
-
87
- def stylize(self, text: str) -> dict:
88
- """
89
- Main method for TextDoctor.
90
-
91
- Returns:
92
- {
93
- "original": ...,
94
- "styled": ...,
95
- "changes": [ {type, from, to}, ... ]
96
- }
97
- """
98
- styled = self._generate(text)
99
- changes = self._diff_explanation(text, styled)
100
- return {
101
- "original": text,
102
- "styled": styled,
103
- "changes": changes
104
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
+ import torch
4
+ import difflib
5
+
6
+
7
class StyleAgent:
    """Rewrites informal text into a more formal register using a seq2seq model."""

    def __init__(
        self,
        model_name: str = "rajistics/informal_formal_style_transfer",
        device: Optional[str] = None,
    ):
        """
        Style Agent
        - model_name: HF model id for informal -> formal style transfer
        - device: "cuda" or "cpu" (auto-detect if None)
        """
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

        # Prefer the GPU when one is available, unless the caller pinned a device.
        self.device = device if device is not None else (
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        self.model.to(self.device)

    def _generate(self, text: str, max_length: int = 256, num_beams: int = 5) -> str:
        """
        Internal helper to call the model and get a more formal / professional version.
        """
        # Many style-transfer T5 models work directly on raw text.
        # If the model card suggests a prefix, add it here, e.g.:
        # text = "formal: " + text
        encoded = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=max_length,
        ).to(self.device)

        with torch.no_grad():
            generated = self.model.generate(
                **encoded,
                max_length=max_length,
                num_beams=num_beams,
                early_stopping=True,
            )

        return self.tokenizer.decode(generated[0], skip_special_tokens=True).strip()

    def _diff_explanation(self, original: str, styled: str):
        """
        Compare original vs styled sentence and return simple word-level changes.
        """
        changes = []
        removed: list = []
        inserted: list = []

        def flush() -> None:
            # Close out the current run of -/+ tokens as one change record.
            if removed or inserted:
                changes.append({
                    "from": " ".join(removed) if removed else None,
                    "to": " ".join(inserted) if inserted else None,
                    "type": self._infer_change_type(removed, inserted),
                })
                removed.clear()
                inserted.clear()

        for entry in difflib.ndiff(original.split(), styled.split()):
            if entry.startswith("- "):
                removed.append(entry[2:])
            elif entry.startswith("+ "):
                inserted.append(entry[2:])
            elif entry.startswith("  "):
                # An unchanged word ends the current change group ("? " guide
                # lines from ndiff fall through and are ignored).
                flush()
        flush()

        return [c for c in changes if c["from"] or c["to"]]

    @staticmethod
    def _infer_change_type(deleted_tokens, added_tokens):
        # Only one side present -> pure deletion/insertion; otherwise replacement.
        if bool(deleted_tokens) == bool(added_tokens):
            return "replacement"
        return "deletion" if deleted_tokens else "insertion"

    def stylize(self, text: str) -> dict:
        """
        Main method for TextDoctor.

        Returns:
            {
                "original": ...,
                "styled": ...,
                "changes": [ {type, from, to}, ... ],
                "confidence": float,
                "agent": "style"
            }
        """
        styled = self._generate(text)
        changes = self._diff_explanation(text, styled)

        # simple heuristic confidence
        word_count = max(len(text.split()), 1)
        confidence = max(0.3, 1.0 - len(changes) / word_count)

        return {
            "original": text,
            "styled": styled,
            "changes": changes,
            "confidence": round(confidence, 2),
            "agent": "style",
        }