FurqanIshaq committed on
Commit
da7d46c
·
verified ·
1 Parent(s): 258b3c8

Update agents/grammar_agent.py

Browse files
Files changed (1) hide show
  1. agents/grammar_agent.py +129 -114
agents/grammar_agent.py CHANGED
@@ -1,114 +1,129 @@
1
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
2
- import torch
3
- import difflib
4
-
5
- class GrammarAgent:
6
- def __init__(self, model_name: str = "vennify/t5-base-grammar-correction", device: str | None = None):
7
- """
8
- Grammar Agent
9
- - model_name: HF model id
10
- - device: "cuda" or "cpu" (auto-detect if None)
11
- """
12
- self.model_name = model_name
13
- self.tokenizer = AutoTokenizer.from_pretrained(model_name)
14
- self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
15
-
16
- if device is None:
17
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
18
- else:
19
- self.device = device
20
-
21
- self.model.to(self.device)
22
-
23
- def _generate(self, text: str, max_length: int = 128, num_beams: int = 5) -> str:
24
- """
25
- Internal helper to call the model and get corrected text.
26
- """
27
- # T5-style models expect a task prefix
28
- prefixed = "grammar: " + text
29
-
30
- inputs = self.tokenizer(
31
- prefixed,
32
- return_tensors="pt",
33
- truncation=True,
34
- max_length=max_length
35
- ).to(self.device)
36
-
37
- with torch.no_grad():
38
- outputs = self.model.generate(
39
- **inputs,
40
- max_length=max_length,
41
- num_beams=num_beams,
42
- early_stopping=True
43
- )
44
-
45
- corrected = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
46
- return corrected.strip()
47
-
48
- def _diff_explanation(self, original: str, corrected: str):
49
- """
50
- Create a simple, human-readable explanation of changes.
51
- Returns a list of {type, from, to}.
52
- """
53
- diff = list(difflib.ndiff(original.split(), corrected.split()))
54
- changes = []
55
- current_del = []
56
- current_add = []
57
-
58
- for token in diff:
59
- if token.startswith("- "):
60
- current_del.append(token[2:])
61
- elif token.startswith("+ "):
62
- current_add.append(token[2:])
63
- elif token.startswith(" "):
64
- # flush any pending change block
65
- if current_del or current_add:
66
- changes.append({
67
- "from": " ".join(current_del) if current_del else None,
68
- "to": " ".join(current_add) if current_add else None,
69
- "type": self._infer_change_type(current_del, current_add)
70
- })
71
- current_del, current_add = [], []
72
-
73
- # flush at end
74
- if current_del or current_add:
75
- changes.append({
76
- "from": " ".join(current_del) if current_del else None,
77
- "to": " ".join(current_add) if current_add else None,
78
- "type": self._infer_change_type(current_del, current_add)
79
- })
80
-
81
- # remove empty no-op changes
82
- changes = [c for c in changes if c["from"] or c["to"]]
83
- return changes
84
-
85
- @staticmethod
86
- def _infer_change_type(deleted_tokens, added_tokens):
87
- """
88
- Very simple heuristic for change type.
89
- You can later improve this with more logic.
90
- """
91
- if deleted_tokens and not added_tokens:
92
- return "deletion"
93
- if added_tokens and not deleted_tokens:
94
- return "insertion"
95
- return "replacement"
96
-
97
- def correct(self, text: str) -> dict:
98
- """
99
- Main method your system will call.
100
-
101
- Returns a dict:
102
- {
103
- "original": ...,
104
- "corrected": ...,
105
- "changes": [ {type, from, to}, ... ]
106
- }
107
- """
108
- corrected = self._generate(text)
109
- changes = self._diff_explanation(text, corrected)
110
- return {
111
- "original": text,
112
- "corrected": corrected,
113
- "changes": changes
114
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
+ import torch
4
+ import difflib
5
+
6
+
7
class GrammarAgent:
    """Grammar-correction agent backed by a Hugging Face seq2seq model.

    The agent prepends the ``"grammar: "`` task prefix to the input, runs
    beam-search generation, and reports the word-level differences between
    the input and the model output together with a heuristic confidence.
    """

    def __init__(
        self,
        model_name: str = "vennify/t5-base-grammar-correction",
        device: Optional[str] = None,
    ):
        """Load the tokenizer and model and move the model to the device.

        Args:
            model_name: Hugging Face model id passed to ``from_pretrained``.
            device: "cuda" or "cpu". When ``None``, "cuda" is used if
                ``torch.cuda.is_available()`` is true, otherwise "cpu".
        """
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device

        self.model.to(self.device)

    def _generate(self, text: str, max_length: int = 256, num_beams: int = 5) -> str:
        """Call the model on ``text`` and return the corrected string.

        Args:
            text: Raw input text, without the task prefix.
            max_length: Used both as the tokenizer truncation limit for the
                prefixed input and as ``max_length`` for generation.
            num_beams: Beam width passed to ``generate``.

        Returns:
            The first generated sequence, decoded without special tokens
            and stripped of surrounding whitespace.
        """
        prefixed = "grammar: " + text  # T5-style task prefix

        inputs = self.tokenizer(
            prefixed,
            return_tensors="pt",
            truncation=True,
            max_length=max_length,
        ).to(self.device)

        # Inference only: gradients are not needed.
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_length=max_length,
                num_beams=num_beams,
                early_stopping=True,
            )

        corrected = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return corrected.strip()

    def _diff_explanation(self, original: str, corrected: str):
        """Create a simple, human-readable explanation of changes.

        Both strings are split on whitespace and compared with
        ``difflib.ndiff``. Runs of removed/added words that sit between two
        unchanged words are grouped into one change record.

        Returns:
            A list of dicts with keys ``from`` (removed words joined by a
            space, or ``None``), ``to`` (added words joined by a space, or
            ``None``) and ``type``.
        """
        changes = []
        current_del = []
        current_add = []

        def flush():
            # Emit the pending change block, if any, and reset the buffers.
            if current_del or current_add:
                changes.append(
                    {
                        "from": " ".join(current_del) if current_del else None,
                        "to": " ".join(current_add) if current_add else None,
                        "type": self._infer_change_type(current_del, current_add),
                    }
                )
                current_del.clear()
                current_add.clear()

        for token in difflib.ndiff(original.split(), corrected.split()):
            if token.startswith("- "):
                current_del.append(token[2:])
            elif token.startswith("+ "):
                current_add.append(token[2:])
            elif token.startswith(" "):
                # An unchanged word closes the current change block.
                flush()
            # Any other entry (ndiff's "? " hint lines) is ignored.

        # A change block may run to the very end of the text.
        flush()
        return changes

    @staticmethod
    def _infer_change_type(deleted_tokens, added_tokens):
        """Classify a change block as deletion, insertion or replacement.

        Very simple heuristic; can be improved later with more logic.
        """
        if deleted_tokens and not added_tokens:
            return "deletion"
        if added_tokens and not deleted_tokens:
            return "insertion"
        return "replacement"

    def correct(self, text: str) -> dict:
        """Correct ``text`` and explain what changed.

        This is the main method the surrounding system calls.

        Args:
            text: The text to correct.

        Returns:
            A dict of the form::

                {
                    "original": <input text>,
                    "corrected": <model output, stripped>,
                    "changes": [ {type, from, to}, ... ],
                    "confidence": float,
                    "agent": "grammar"
                }

            Empty or whitespace-only input is returned with an empty
            ``corrected`` string, no changes and confidence 1.0, without
            calling the model.
        """
        # Nothing to correct: avoid prompting the model with only the
        # task prefix, which would produce output unrelated to the input.
        if not text.strip():
            return {
                "original": text,
                "corrected": "",
                "changes": [],
                "confidence": 1.0,
                "agent": "grammar",
            }

        corrected = self._generate(text)
        changes = self._diff_explanation(text, corrected)

        # Simple heuristic: the larger the share of change blocks relative
        # to the input word count, the lower the confidence (floor of 0.3).
        change_ratio = len(changes) / max(len(text.split()), 1)
        confidence = max(0.3, 1.0 - change_ratio)

        return {
            "original": text,
            "corrected": corrected,
            "changes": changes,
            "confidence": round(confidence, 2),
            "agent": "grammar",
        }