File size: 5,768 Bytes
dbe2c62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import torch

from typing import Dict

from . import Json_ChunkUnder


class RecursiveSummarizer:
    """
    Bộ tóm tắt học thuật tiếng Việt theo hướng:
    Extractive (chunk semantic) + Abstractive (recursive summarization)
    """

    def __init__(
        self,
        tokenizer,
        summarizer,
        sum_device: str,
        chunk_builder: Json_ChunkUnder.ChunkUndertheseaBuilder,
        max_length: int = 256,
        min_length: int = 64,
        max_depth: int = 5
    ):
        """
        tokenizer: AutoTokenizer đã load sẵn.
        summarizer: AutoModelForSeq2SeqLM (ViT5 / BartPho / mT5)
        sum_device: 'cuda' hoặc 'cpu'
        chunk_builder: ChunkUndertheseaBuilder instance.
        """
        self.tokenizer = tokenizer
        self.model = summarizer
        self.device = sum_device
        self.chunk_builder = chunk_builder
        self.max_length = max_length
        self.min_length = min_length
        self.max_depth = max_depth

    # ============================================================
    # 1️⃣ Hàm tóm tắt 1 đoạn
    # ============================================================
    def summarize_single(self, text: str) -> str:
        """
        Tóm tắt 1 đoạn đơn bằng mô hình abstractive (ViT5/BartPho).
        """
        if not text or len(text.strip()) == 0:
            return ""

        if "vit5" in str(self.model.__class__).lower():
            input_text = f"vietnews: {text.strip()} </s>"
        else:
            input_text = text.strip()

        try:
            inputs = self.tokenizer(
                input_text,
                return_tensors="pt",
                truncation=True,
                max_length=1024
            ).to(self.device)

            with torch.no_grad():
                summary_ids = self.model.generate(
                    **inputs,
                    max_length=self.max_length,
                    min_length=self.min_length,
                    num_beams=4,
                    no_repeat_ngram_size=3,
                    early_stopping=True
                )

            summary = self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
            return summary.strip()

        except torch.cuda.OutOfMemoryError:
            print("⚠️ GPU OOM – fallback sang CPU.")
            self.model = self.model.to("cpu")
            inputs = inputs.to("cpu")

            with torch.no_grad():
                summary_ids = self.model.generate(
                    **inputs,
                    max_length=self.max_length,
                    min_length=self.min_length,
                    num_beams=4
                )

            return self.tokenizer.decode(summary_ids[0], skip_special_tokens=True).strip()

        except Exception as e:
            print(f"❌ Lỗi khi tóm tắt đoạn: {e}")
            return ""

    # ============================================================
    # 2️⃣ Đệ quy tóm tắt văn bản dài
    # ============================================================
    def summarize_recursive(self, text: str, depth: int = 0, minInput: int = 256, maxInput: int = 1024) -> str:
        """
        Đệ quy tóm tắt văn bản dài:
        - <256 từ: giữ nguyên
        - <1024 từ: tóm tắt trực tiếp
        - >=1024 từ: chia chunk + tóm tắt từng phần → gộp → đệ quy
        """
        word_count = len(text.split())
        indent = "  " * depth
        print(f"{indent}🔹 Level {depth}: {word_count} từ")

        # 1️⃣ Văn bản ngắn
        if word_count < minInput:
            return self.summarize_single(text)

        else:
            chunks = self.chunk_builder.build(text)
            summaries = []

            for item in chunks:
                content = item.get("Content", "")
                print(content)
                idx = item.get("Index", "?")
                wc = len(content.split())

                if wc < 20:
                    print(f"{indent}⚠️ Bỏ qua chunk {idx} (quá ngắn)")
                    continue

                print(f"{indent}🔸 Chunk {idx}: {wc} từ")
                sub_summary = self.summarize_single(content)
                if sub_summary:
                    summaries.append(sub_summary)

            merged_summary = "\n".join(summaries)
            merged_len = len(merged_summary.split())
            print(f"{indent}🔁 Gộp {len(summaries)} summary → {merged_len} từ")

            # Đệ quy nếu vẫn dài
            if merged_len > 1024 and depth < self.max_depth:
                return self.summarize_recursive(merged_summary, depth + 1)
            else:
                return merged_summary

    # ============================================================
    # 3️⃣ Hàm chính cho người dùng
    # ============================================================
    def summarize(self, full_text: str, minInput: int = 256, maxInput: int = 1024) -> Dict[str, str]:
        """
        Giao diện chính:
        - Nhận text dài
        - Tự động chia chunk, tóm tắt, gộp
        - Trả về dict gồm summary và thống kê
        """
        original_len = len(full_text.split())
        summary = self.summarize_recursive(full_text, depth = 0, minInput = minInput, maxInput = maxInput)

        summary_len = len(summary.split())
        ratio = round(summary_len / original_len, 3) if original_len else 0

        print(f"\n✨ FINAL SUMMARY ({summary_len}/{original_len} từ, r={ratio}) ✨")
        return {
            "summary_text": summary,
            "original_words": original_len,
            "summary_words": summary_len,
            "compression_ratio": ratio
        }