# NOTE(review): removed non-Python scrape residue that preceded the code
# (Hugging Face Spaces page chrome: "Spaces:"/"Sleeping" status lines, file
# size, commit hash, and a line-number gutter). It made the file syntactically
# invalid Python.
import faiss
import numpy as np
from typing import Dict, List, Any, Optional
from sentence_transformers import SentenceTransformer, CrossEncoder
class SemanticSearchEngine:
    """FAISS vector search with optional CrossEncoder reranking.

    Wraps an already-loaded SentenceTransformer bi-encoder used to embed
    queries, plus an optional CrossEncoder used to rerank the FAISS
    candidates. Type annotations for the heavy third-party models are
    string forward references, matching the existing ``"faiss.Index"``
    annotation style in this file.
    """

    def __init__(
        self,
        indexer: "SentenceTransformer",
        reranker: Optional["CrossEncoder"] = None,
        device: str = "cuda",
        normalize: bool = True,
        top_k: int = 20,
        rerank_k: int = 10,
        rerank_batch_size: int = 16,
    ):
        """Store configuration and validate the supplied models.

        Args:
            indexer: Pre-loaded SentenceTransformer used to encode queries.
            reranker: Optional pre-loaded CrossEncoder for reranking.
            device: Device string passed to ``indexer.encode`` (e.g. "cuda").
            normalize: L2-normalize query embeddings (for cosine/IP indexes).
            top_k: Default number of FAISS candidates to retrieve.
            rerank_k: Default number of results kept after reranking.
            rerank_batch_size: Batch size for CrossEncoder prediction.

        Raises:
            TypeError: If ``indexer`` is not a SentenceTransformer, or
                ``reranker`` is neither ``None`` nor a CrossEncoder.
        """
        self.device = device
        self.normalize = normalize
        self.top_k = int(top_k)
        self.rerank_k = int(rerank_k)
        self.rerank_batch_size = int(rerank_batch_size)

        # Accept only an already-loaded bi-encoder model.
        if not isinstance(indexer, SentenceTransformer):
            raise TypeError("indexer phải là SentenceTransformer đã load sẵn.")
        self._indexer = indexer

        # Reranker is optional. Use `is not None` (not truthiness) so a
        # falsy-but-invalid object is still rejected by the type check.
        if reranker is not None and not isinstance(reranker, CrossEncoder):
            raise TypeError("reranker phải là CrossEncoder hoặc None.")
        self.reranker = reranker

    # ---------------------------
    # Internal utilities
    # ---------------------------
    @staticmethod
    def _l2_normalize(x: np.ndarray, axis: int = 1, eps: float = 1e-12) -> np.ndarray:
        """Return *x* scaled to unit L2 norm along *axis* (eps avoids /0)."""
        denom = np.linalg.norm(x, axis=axis, keepdims=True)
        denom = np.maximum(denom, eps)
        return x / denom

    @staticmethod
    def _build_idx_maps(Mapping: Dict[str, Any], MapData: Dict[str, Any]):
        """Build index→text and index→key lookup dicts.

        Keys are coerced to ``int`` because JSON-loaded mappings carry
        string keys while FAISS returns integer ids.
        """
        items = MapData.get("items", [])
        idx2text = {int(item["index"]): item.get("text", None) for item in items}
        raw_i2k = Mapping.get("index_to_key", {})
        idx2key = {int(i): k for i, k in raw_i2k.items()}
        return idx2text, idx2key

    # ---------------------------
    # 1. SEARCH: FAISS vector search
    # ---------------------------
    def search(
        self,
        query: str,
        faissIndex: "faiss.Index",  # type: ignore
        Mapping: Dict[str, Any],
        MapData: Dict[str, Any],
        MapChunk: Optional[Dict[str, Any]] = None,
        top_k: Optional[int] = None,
        query_embedding: Optional[np.ndarray] = None,
    ) -> List[Dict[str, Any]]:
        """Run a FAISS nearest-neighbor search for *query*.

        Args:
            query: Query text (ignored when ``query_embedding`` is given).
            faissIndex: A FAISS index exposing ``search(q, k)``.
            Mapping: Dict with an ``"index_to_key"`` sub-dict.
            MapData: Dict with an ``"items"`` list of {"index", "text"}.
            MapChunk: Optional dict with an ``"index_to_chunk"`` sub-dict
                keyed by stringified index.
            top_k: Override for the default candidate count; ``None``
                (not falsiness) selects ``self.top_k``.
            query_embedding: Precomputed embedding to use instead of
                encoding ``query``.

        Returns:
            [{"index", "key", "text", "faiss_score", "chunk_ids"}, ...]
        """
        k = self.top_k if top_k is None else int(top_k)

        # 1. Encode the query (or use the provided embedding).
        if query_embedding is None:
            q = self._indexer.encode(
                [query], convert_to_tensor=True, device=str(self.device)
            )
            q = q.detach().cpu().numpy().astype("float32")
        else:
            q = np.asarray(query_embedding, dtype="float32")
            if q.ndim == 1:
                q = q[None, :]

        # 2. Normalize when using cosine/inner-product similarity.
        if self.normalize:
            q = self._l2_normalize(q)

        # 3. FAISS search.
        scores, ids = faissIndex.search(q, k)
        idx2text, idx2key = self._build_idx_maps(Mapping, MapData)

        # 4. Map raw ids back to keys/texts/chunks.
        chunk_map = MapChunk.get("index_to_chunk", {}) if MapChunk else {}
        results = []
        for score, idx in zip(scores[0].tolist(), ids[0].tolist()):
            # FAISS pads with -1 when the index holds fewer than k vectors;
            # skip those placeholders instead of emitting bogus rows.
            if idx < 0:
                continue
            chunk_ids = chunk_map.get(str(idx), [])
            results.append({
                "index": int(idx),
                "key": idx2key.get(int(idx)),
                "text": idx2text.get(int(idx)),
                "faiss_score": float(score),
                "chunk_ids": chunk_ids,
            })
        return results

    # ---------------------------
    # 2. RERANK: CrossEncoder rerank
    # ---------------------------
    def rerank(
        self,
        query: str,
        results: List[Dict[str, Any]],
        top_k: Optional[int] = None,
        show_progress: bool = False,
    ) -> List[Dict[str, Any]]:
        """Rerank *results* with the CrossEncoder and return the top_k.

        Mutates each scored result dict in place by adding a
        ``"rerank_score"`` key; entries with empty/non-string text are
        skipped and excluded from the output.

        Args:
            query: The query text paired with each candidate text.
            results: Output of :meth:`search`.
            top_k: Override for the kept-result count; ``None`` selects
                ``self.rerank_k``.
            show_progress: Forwarded to ``CrossEncoder.predict``.

        Raises:
            ValueError: If no reranker was supplied at construction.
        """
        if not results:
            return []
        if self.reranker is None:
            raise ValueError("⚠️ Không có reranker được cung cấp khi khởi tạo.")

        k = self.rerank_k if top_k is None else int(top_k)

        # Pair the query with each non-empty candidate text, remembering
        # which original positions were scored.
        pairs = []
        valid_indices = []
        for i, r in enumerate(results):
            text = r.get("text")
            if isinstance(text, str) and text.strip():
                pairs.append([query, text])
                valid_indices.append(i)
        if not pairs:
            return []

        scores = self.reranker.predict(
            pairs, batch_size=self.rerank_batch_size, show_progress_bar=show_progress
        )
        for i, s in zip(valid_indices, scores):
            results[i]["rerank_score"] = float(s)

        reranked = [r for r in results if "rerank_score" in r]
        reranked.sort(key=lambda x: x["rerank_score"], reverse=True)
        return reranked[:k]
# NOTE(review): removed trailing "|" scrape residue (page-gutter artifact).