# doc-ai-api / Libraries / Faiss_Embedding.py
import logging
import re, os, json
import torch
import faiss
import numpy as np
from typing import Dict, List, Any, Tuple, Optional
from . import Common_MyUtils as MyUtils
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
class DirectFaissIndexer:
"""
1) FaissPath (.faiss): chỉ chứa vectors,
2) MapDataPath (.json): content + index,
3) MappingPath (.json): ánh xạ key <-> index.
"""
def __init__(
self,
indexer: Any,
device: str = "cpu",
batch_size: int = 32,
show_progress: bool = False,
flatten_mode: str = "split",
join_sep: str = "\n",
allowed_schema_types: Tuple[str, ...] = ("string", "array", "dict"),
max_chars_per_text: Optional[int] = None,
normalize: bool = True,
verbose: bool = False,
list_policy: str = "split", # "merge" | "split"
):
self.indexer = indexer
self.device = device
self.batch_size = batch_size
self.show_progress = show_progress
self.flatten_mode = flatten_mode
self.join_sep = join_sep
self.allowed_schema_types = allowed_schema_types
self.max_chars_per_text = max_chars_per_text
self.normalize = normalize
self.verbose = verbose
self.list_policy = list_policy
self._non_keep_pattern = re.compile(r"[^\w\s\(\)\.\,\;\:\-–]", flags=re.UNICODE)
    # ---------- Schema & field selection ----------
@staticmethod
def _base_key_for_schema(key: str) -> str:
return re.sub(r"\[\d+\]", "", key)
def _eligible_by_schema(self, key: str, schema: Optional[Dict[str, str]]) -> bool:
if schema is None:
return True
base_key = self._base_key_for_schema(key)
typ = schema.get(base_key)
return (typ in self.allowed_schema_types) if typ is not None else False
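    # Illustration (hypothetical schema, matching the checks above): with
    # schema = {"authors": "array", "id": "number"}, the flattened key
    # "authors[2]" reduces to the base key "authors" and is eligible because
    # "array" is in allowed_schema_types, while "id" is skipped ("number" is not).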
    # ---------- Preprocessing & flatten ----------
    def _preprocess_data(self, data: Any) -> Any:
        if MyUtils and hasattr(MyUtils, "preprocess_data"):
            return MyUtils.preprocess_data(
                data,
                non_keep_pattern=self._non_keep_pattern,
                max_chars_per_text=self.max_chars_per_text
            )
        # Fall back to the raw data when no preprocessor is available
        return data
def _flatten_json(self, data: Any) -> Dict[str, Any]:
"""
Flatten JSON theo list_policy:
- merge: gộp list/dict chứa chuỗi thành 1 đoạn text duy nhất
- split: tách từng phần tử
"""
# Nếu merge, xử lý JSON trước khi flatten
if self.list_policy == "merge":
def _merge_lists(obj):
if isinstance(obj, dict):
return {k: _merge_lists(v) for k, v in obj.items()}
elif isinstance(obj, list):
# Nếu list chỉ chứa chuỗi / số, gộp lại
if all(isinstance(i, (str, int, float)) for i in obj):
return self.join_sep.join(map(str, obj))
# Nếu list chứa dict hoặc list lồng, đệ quy
return [_merge_lists(v) for v in obj]
else:
return obj
data = _merge_lists(data)
# Sau đó gọi MyUtils.flatten_json như cũ
return MyUtils.flatten_json(
data,
prefix="",
flatten_mode=self.flatten_mode,
join_sep=self.join_sep
)
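    # Illustration of the "merge" pre-step (the exact key format of the final
    # result depends on Common_MyUtils.flatten_json, which is not shown here):
    # with join_sep="\n", {"tags": ["a", "b"]} becomes {"tags": "a\nb"} before
    # flattening, whereas list_policy="split" keeps each element separate
    # (e.g. keys such as "tags[0]", "tags[1]").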
    # ---------- Encode (batch) with CPU fallback on OOM ----------
def _encode_texts(self, texts: List[str]) -> torch.Tensor:
try:
embs = self.indexer.encode(
sentences=texts,
batch_size=self.batch_size,
convert_to_tensor=True,
device=self.device,
show_progress_bar=self.show_progress,
)
return embs
        except RuntimeError as e:
            if "CUDA out of memory" in str(e):
                logging.warning("⚠️ CUDA OOM → falling back to CPU.")
                try:
                    self.indexer.to("cpu")
                except Exception:
                    pass
                embs = self.indexer.encode(
                    sentences=texts,
                    batch_size=self.batch_size,
                    convert_to_tensor=True,
                    device="cpu",
                    show_progress_bar=self.show_progress,
                )
                return embs
            # Re-raise any other runtime error unchanged
            raise
# ---------- Build FAISS ----------
@staticmethod
def _l2_normalize(mat: np.ndarray) -> np.ndarray:
norms = np.linalg.norm(mat, axis=1, keepdims=True)
norms[norms == 0.0] = 1.0
return mat / norms
def _create_faiss_index(self, matrix: np.ndarray) -> faiss.Index:
dim = int(matrix.shape[1])
index = faiss.IndexFlatIP(dim)
index.add(matrix.astype("float32"))
return index
# ================================================================
    # Deduplicate, but keep grouping the corresponding chunks
# ================================================================
def deduplicates_with_mask(
self,
pairs: List[Tuple[str, str]],
chunk_map: List[int]
) -> Tuple[List[Tuple[str, str]], List[List[int]]]:
        assert len(pairs) == len(chunk_map), "pairs and chunk_map must have the same length"
        seen_per_key: Dict[str, Dict[str, int]] = {}
        # base_key -> normalized text -> index into filtered_pairs
        filtered_pairs: List[Tuple[str, str]] = []
        chunk_groups: List[List[int]] = []  # parallel to filtered_pairs
for (key, text), c in zip(pairs, chunk_map):
text_norm = text.strip()
if not text_norm:
continue
base_key = re.sub(r"\[\d+\]", "", key)
if base_key not in seen_per_key:
seen_per_key[base_key] = {}
            # Text already seen → add this chunk to the existing group
if text_norm in seen_per_key[base_key]:
idx = seen_per_key[base_key][text_norm]
if c not in chunk_groups[idx]:
chunk_groups[idx].append(c)
continue
            # Otherwise, create a new entry
seen_per_key[base_key][text_norm] = len(filtered_pairs)
filtered_pairs.append((key, text_norm))
chunk_groups.append([c])
return filtered_pairs, chunk_groups
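    # Example: pairs = [("tags[0]", "ml"), ("tags[1]", "ml"), ("title", "Intro")]
    # with chunk_map = [1, 2, 2] yields
    #   filtered_pairs = [("tags[0]", "ml"), ("title", "Intro")]
    #   chunk_groups   = [[1, 2], [2]]
    # because both "tags[*]" entries share the base key "tags" and the same text.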
# ================================================================
    # Write the ChunkMapping
# ================================================================
def write_chunk_mapping(self, MapChunkPath: str, SegmentPath: str, chunk_groups: List[List[int]]) -> None:
        # Write the chunk mapping compactly: one index per line
with open(MapChunkPath, "w", encoding="utf-8") as f:
f.write('{\n')
f.write(' "index_to_chunk": {\n')
items = list(enumerate(chunk_groups))
for i, (idx, group) in enumerate(items):
group_str = "[" + ", ".join(map(str, group)) + "]"
comma = "," if i < len(items) - 1 else ""
f.write(f' "{idx}": {group_str}{comma}\n')
f.write(' },\n')
f.write(' "meta": {\n')
f.write(f' "count": {len(chunk_groups)},\n')
f.write(f' "source": "{os.path.basename(SegmentPath)}"\n')
f.write(' }\n')
f.write('}\n')
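    # The file produced above has the shape:
    #   {
    #     "index_to_chunk": {"0": [1, 2], "1": [2], ...},
    #     "meta": {"count": <number of groups>, "source": "<segment file name>"}
    #   }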
# ================================================================
    # build_from_json
# ================================================================
    def build_from_json(
        self,
        SegmentPath: str,
        SchemaDict: Optional[Dict[str, str]],
        FaissPath: str,
        MapDataPath: str,
        MappingPath: str,
        MapChunkPath: Optional[str] = None,
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        assert os.path.exists(SegmentPath), f"JSON file not found: {SegmentPath}"
os.makedirs(os.path.dirname(FaissPath), exist_ok=True)
os.makedirs(os.path.dirname(MapDataPath), exist_ok=True)
os.makedirs(os.path.dirname(MappingPath), exist_ok=True)
if MapChunkPath:
os.makedirs(os.path.dirname(MapChunkPath), exist_ok=True)
schema = SchemaDict
# 1️⃣ Read JSON
data_obj = MyUtils.read_json(SegmentPath)
data_list = data_obj if isinstance(data_obj, list) else [data_obj]
        # 2️⃣ Flatten + record the chunk_id
pair_list: List[Tuple[str, str]] = []
chunk_map: List[int] = []
for chunk_id, item in enumerate(data_list, start=1):
processed = self._preprocess_data(item)
flat = self._flatten_json(processed)
for k, v in flat.items():
if not self._eligible_by_schema(k, schema):
continue
if isinstance(v, str) and v.strip():
pair_list.append((k, v.strip()))
chunk_map.append(chunk_id)
        if not pair_list:
            raise ValueError("No valid text content found to encode.")
        # 3️⃣ Deduplicate while grouping chunks
pair_list, chunk_groups = self.deduplicates_with_mask(pair_list, chunk_map)
# 4️⃣ Encode
keys = [k for k, _ in pair_list]
texts = [t for _, t in pair_list]
embs_t = self._encode_texts(texts)
embs = embs_t.detach().cpu().numpy()
if self.normalize:
embs = self._l2_normalize(embs)
# 5️⃣ FAISS
index = self._create_faiss_index(embs)
faiss.write_index(index, FaissPath)
logging.info(f"✅ Đã xây FAISS: {FaissPath}")
# 6️⃣ Mapping + MapData
index_to_key = {str(i): k for i, k in enumerate(keys)}
Mapping = {
"meta": {
"count": len(keys),
"dim": int(embs.shape[1]),
"metric": "ip",
"normalized": bool(self.normalize),
},
"index_to_key": index_to_key,
}
MapData = {
"items": [{"index": i, "key": k, "text": t} for i, (k, t) in enumerate(pair_list)],
"meta": {
"count": len(keys),
"flatten_mode": self.flatten_mode,
"schema_used": schema is not None,
"list_policy": self.list_policy
}
}
        # 7️⃣ Persist the mapping artifacts
        with open(MappingPath, "w", encoding="utf-8") as f:
            json.dump(Mapping, f, ensure_ascii=False, indent=2)
        with open(MapDataPath, "w", encoding="utf-8") as f:
            json.dump(MapData, f, ensure_ascii=False, indent=2)
        if MapChunkPath:
            self.write_chunk_mapping(MapChunkPath, SegmentPath, chunk_groups)
        return Mapping, MapData
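
# Minimal usage sketch (not part of the class). It assumes the indexer is a
# sentence-transformers model, since DirectFaissIndexer only relies on
# .encode(sentences=..., batch_size=..., convert_to_tensor=..., device=...,
# show_progress_bar=...) and .to(device); the model name and every path below
# are hypothetical placeholders.
#
# from sentence_transformers import SentenceTransformer
#
# model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
# indexer = DirectFaissIndexer(indexer=model, device="cpu", list_policy="split")
# mapping, map_data = indexer.build_from_json(
#     SegmentPath="Data/segments.json",
#     SchemaDict={"title": "string", "sections": "array"},
#     FaissPath="Data/index.faiss",
#     MapDataPath="Data/map_data.json",
#     MappingPath="Data/mapping.json",
#     MapChunkPath="Data/map_chunk.json",
# )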