File size: 4,216 Bytes
da7d46c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
from typing import Optional
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import difflib


class GrammarAgent:
    """Grammar-correction agent backed by a Hugging Face seq2seq model.

    The agent loads a tokenizer and model once at construction time and
    exposes :meth:`correct`, which returns the corrected text together with
    a word-level list of changes and a heuristic confidence score.
    """

    # Lower bound for the heuristic confidence reported by ``correct``.
    _MIN_CONFIDENCE = 0.3

    def __init__(
        self,
        model_name: str = "vennify/t5-base-grammar-correction",
        device: Optional[str] = None,
    ):
        """
        Load the tokenizer and model and move the model to the target device.

        - model_name: HF model id
        - device: "cuda" or "cpu" (auto-detect if None)
        """
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

        if device is None:
            # Prefer the GPU when one is visible to torch.
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device

        self.model.to(self.device)

    def _generate(self, text: str, max_length: int = 256, num_beams: int = 5) -> str:
        """
        Internal helper to call the model and get corrected text.

        - text: raw input text (the task prefix is added here)
        - max_length: token limit applied to both the tokenized input and
          the generated output; longer inputs are truncated by the tokenizer
        - num_beams: beam width passed to ``generate``

        Returns the decoded model output with surrounding whitespace removed.
        """
        prefixed = "grammar: " + text  # T5-style task prefix

        inputs = self.tokenizer(
            prefixed,
            return_tensors="pt",
            truncation=True,
            max_length=max_length,
        ).to(self.device)

        # Inference only: no gradients are needed.
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_length=max_length,
                num_beams=num_beams,
                early_stopping=True,
            )

        corrected = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return corrected.strip()

    def _diff_explanation(self, original: str, corrected: str) -> list:
        """
        Create a simple, human-readable explanation of changes.

        Both texts are split on whitespace and compared word by word. Every
        contiguous run of differing words becomes one entry.

        Returns a list of {type, from, to} dicts, in order of appearance.
        "from" / "to" are the space-joined removed / added words, or None
        when nothing was removed / added.
        """
        source_words = original.split()
        target_words = corrected.split()

        # Opcodes give the changed runs directly. autojunk is disabled so
        # frequent words are still matched in long texts.
        matcher = difflib.SequenceMatcher(
            None, source_words, target_words, autojunk=False
        )

        changes = []
        for tag, src_lo, src_hi, dst_lo, dst_hi in matcher.get_opcodes():
            if tag == "equal":
                continue
            removed = source_words[src_lo:src_hi]
            added = target_words[dst_lo:dst_hi]
            changes.append(
                {
                    "from": " ".join(removed) if removed else None,
                    "to": " ".join(added) if added else None,
                    "type": self._infer_change_type(removed, added),
                }
            )
        return changes

    @staticmethod
    def _infer_change_type(deleted_tokens, added_tokens):
        """
        Very simple heuristic for change type.

        Returns "deletion" when only tokens were removed, "insertion" when
        only tokens were added, and "replacement" otherwise.
        """
        if deleted_tokens and not added_tokens:
            return "deletion"
        if added_tokens and not deleted_tokens:
            return "insertion"
        return "replacement"

    @staticmethod
    def _score_confidence(text: str, changes: list) -> float:
        """
        Heuristic confidence derived from the share of words that changed.

        Each change contributes the larger of its removed / added word
        counts, so a multi-word rewrite weighs more than a one-word fix.
        The result is floored at ``_MIN_CONFIDENCE`` and rounded to two
        decimals.
        """
        changed_words = 0
        for change in changes:
            removed = len((change["from"] or "").split())
            added = len((change["to"] or "").split())
            changed_words += max(removed, added)

        # max(..., 1) guards against division by zero for blank text.
        change_ratio = changed_words / max(len(text.split()), 1)
        confidence = max(GrammarAgent._MIN_CONFIDENCE, 1.0 - change_ratio)
        return round(confidence, 2)

    def correct(self, text: str) -> dict:
        """
        Main method your system will call.

        Empty or whitespace-only input is returned unchanged without
        calling the model, with no changes and confidence 1.0.

        Returns a dict:
        {
            "original": ...,
            "corrected": ...,
            "changes": [ {type, from, to}, ... ],
            "confidence": float,
            "agent": "grammar"
        }
        """
        if not text.strip():
            # Nothing to correct; skip the model rather than prompting it
            # with only the task prefix.
            return {
                "original": text,
                "corrected": text,
                "changes": [],
                "confidence": 1.0,
                "agent": "grammar",
            }

        corrected = self._generate(text)
        changes = self._diff_explanation(text, corrected)

        return {
            "original": text,
            "corrected": corrected,
            "changes": changes,
            "confidence": self._score_confidence(text, changes),
            "agent": "grammar",
        }