Spaces:
Sleeping
Sleeping
File size: 5,768 Bytes
dbe2c62 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 |
import torch
from typing import Dict
from . import Json_ChunkUnder
class RecursiveSummarizer:
"""
Bộ tóm tắt học thuật tiếng Việt theo hướng:
Extractive (chunk semantic) + Abstractive (recursive summarization)
"""
def __init__(
self,
tokenizer,
summarizer,
sum_device: str,
chunk_builder: Json_ChunkUnder.ChunkUndertheseaBuilder,
max_length: int = 256,
min_length: int = 64,
max_depth: int = 5
):
"""
tokenizer: AutoTokenizer đã load sẵn.
summarizer: AutoModelForSeq2SeqLM (ViT5 / BartPho / mT5)
sum_device: 'cuda' hoặc 'cpu'
chunk_builder: ChunkUndertheseaBuilder instance.
"""
self.tokenizer = tokenizer
self.model = summarizer
self.device = sum_device
self.chunk_builder = chunk_builder
self.max_length = max_length
self.min_length = min_length
self.max_depth = max_depth
# ============================================================
# 1️⃣ Hàm tóm tắt 1 đoạn
# ============================================================
def summarize_single(self, text: str) -> str:
"""
Tóm tắt 1 đoạn đơn bằng mô hình abstractive (ViT5/BartPho).
"""
if not text or len(text.strip()) == 0:
return ""
if "vit5" in str(self.model.__class__).lower():
input_text = f"vietnews: {text.strip()} </s>"
else:
input_text = text.strip()
try:
inputs = self.tokenizer(
input_text,
return_tensors="pt",
truncation=True,
max_length=1024
).to(self.device)
with torch.no_grad():
summary_ids = self.model.generate(
**inputs,
max_length=self.max_length,
min_length=self.min_length,
num_beams=4,
no_repeat_ngram_size=3,
early_stopping=True
)
summary = self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
return summary.strip()
except torch.cuda.OutOfMemoryError:
print("⚠️ GPU OOM – fallback sang CPU.")
self.model = self.model.to("cpu")
inputs = inputs.to("cpu")
with torch.no_grad():
summary_ids = self.model.generate(
**inputs,
max_length=self.max_length,
min_length=self.min_length,
num_beams=4
)
return self.tokenizer.decode(summary_ids[0], skip_special_tokens=True).strip()
except Exception as e:
print(f"❌ Lỗi khi tóm tắt đoạn: {e}")
return ""
# ============================================================
# 2️⃣ Đệ quy tóm tắt văn bản dài
# ============================================================
def summarize_recursive(self, text: str, depth: int = 0, minInput: int = 256, maxInput: int = 1024) -> str:
"""
Đệ quy tóm tắt văn bản dài:
- <256 từ: giữ nguyên
- <1024 từ: tóm tắt trực tiếp
- >=1024 từ: chia chunk + tóm tắt từng phần → gộp → đệ quy
"""
word_count = len(text.split())
indent = " " * depth
print(f"{indent}🔹 Level {depth}: {word_count} từ")
# 1️⃣ Văn bản ngắn
if word_count < minInput:
return self.summarize_single(text)
else:
chunks = self.chunk_builder.build(text)
summaries = []
for item in chunks:
content = item.get("Content", "")
print(content)
idx = item.get("Index", "?")
wc = len(content.split())
if wc < 20:
print(f"{indent}⚠️ Bỏ qua chunk {idx} (quá ngắn)")
continue
print(f"{indent}🔸 Chunk {idx}: {wc} từ")
sub_summary = self.summarize_single(content)
if sub_summary:
summaries.append(sub_summary)
merged_summary = "\n".join(summaries)
merged_len = len(merged_summary.split())
print(f"{indent}🔁 Gộp {len(summaries)} summary → {merged_len} từ")
# Đệ quy nếu vẫn dài
if merged_len > 1024 and depth < self.max_depth:
return self.summarize_recursive(merged_summary, depth + 1)
else:
return merged_summary
# ============================================================
# 3️⃣ Hàm chính cho người dùng
# ============================================================
def summarize(self, full_text: str, minInput: int = 256, maxInput: int = 1024) -> Dict[str, str]:
"""
Giao diện chính:
- Nhận text dài
- Tự động chia chunk, tóm tắt, gộp
- Trả về dict gồm summary và thống kê
"""
original_len = len(full_text.split())
summary = self.summarize_recursive(full_text, depth = 0, minInput = minInput, maxInput = maxInput)
summary_len = len(summary.split())
ratio = round(summary_len / original_len, 3) if original_len else 0
print(f"\n✨ FINAL SUMMARY ({summary_len}/{original_len} từ, r={ratio}) ✨")
return {
"summary_text": summary,
"original_words": original_len,
"summary_words": summary_len,
"compression_ratio": ratio
}
|