Spaces:
Runtime error
Runtime error
Manik Sheokand
committed on
Commit
·
7b2b305
1
Parent(s):
cad6baf
Download FAISS index/metadata from HF Hub at runtime if missing
Browse files- tools/retriever.py +32 -13
tools/retriever.py
CHANGED
|
@@ -1,29 +1,21 @@
|
|
| 1 |
-
import os, json, faiss, numpy as np
|
| 2 |
from pathlib import Path
|
| 3 |
from sentence_transformers import SentenceTransformer
|
| 4 |
-
|
| 5 |
-
# from tools.reranker import Reranker
|
| 6 |
|
| 7 |
INDEX_PATH = os.environ.get("INDEX_PATH", "indexes/cosmetics_faiss_ip.index")
|
| 8 |
META_PATH = os.environ.get("META_PATH", "indexes/cosmetics_meta.json")
|
|
|
|
|
|
|
| 9 |
EMB_MODEL = os.environ.get("EMB_MODEL_ID", "intfloat/multilingual-e5-base")
|
| 10 |
|
| 11 |
_embedder = None
|
| 12 |
_index = None
|
| 13 |
_meta = None
|
| 14 |
-
# reranker = Reranker()
|
| 15 |
-
|
| 16 |
-
# def refined_search(query, k_initial=20, k_final=5):
|
| 17 |
-
# # 1. retrieve coarse top-20
|
| 18 |
-
# cands = search(query, k=k_initial)
|
| 19 |
-
# texts = [f"{c['brand_en']} {c['product_name_en']} {c['description_en']}" for c in cands]
|
| 20 |
-
|
| 21 |
-
# # 2. re-rank with cross-encoder
|
| 22 |
-
# idxs, scores = reranker.rerank(query, texts, top_k=k_final)
|
| 23 |
-
# return [cands[i] | {"rerank_score": scores[j]} for j, i in enumerate(idxs)]
|
| 24 |
|
| 25 |
def _load():
|
| 26 |
global _embedder, _index, _meta
|
|
|
|
| 27 |
if _embedder is None:
|
| 28 |
_embedder = SentenceTransformer(EMB_MODEL)
|
| 29 |
if _index is None:
|
|
@@ -32,6 +24,33 @@ def _load():
|
|
| 32 |
_meta = json.load(open(META_PATH, "r", encoding="utf-8"))
|
| 33 |
return _embedder, _index, _meta
|
| 34 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
def search(query: str, k: int = 8):
|
| 36 |
emb, idx, meta = _load()
|
| 37 |
q = emb.encode([query], normalize_embeddings=True).astype("float32")
|
|
|
|
| 1 |
+
import os, json, faiss, numpy as np, shutil
|
| 2 |
from pathlib import Path
|
| 3 |
from sentence_transformers import SentenceTransformer
|
| 4 |
+
from huggingface_hub import hf_hub_download
|
|
|
|
| 5 |
|
| 6 |
INDEX_PATH = os.environ.get("INDEX_PATH", "indexes/cosmetics_faiss_ip.index")
|
| 7 |
META_PATH = os.environ.get("META_PATH", "indexes/cosmetics_meta.json")
|
| 8 |
+
HUB_REPO_ID = os.environ.get("HUB_REPO_ID", os.environ.get("REPO_ID", "ColdSlim/DermalCare"))
|
| 9 |
+
HUB_REPO_TYPE = os.environ.get("HUB_REPO_TYPE", "space")
|
| 10 |
EMB_MODEL = os.environ.get("EMB_MODEL_ID", "intfloat/multilingual-e5-base")
|
| 11 |
|
| 12 |
_embedder = None
|
| 13 |
_index = None
|
| 14 |
_meta = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
def _load():
|
| 17 |
global _embedder, _index, _meta
|
| 18 |
+
_ensure_index_files()
|
| 19 |
if _embedder is None:
|
| 20 |
_embedder = SentenceTransformer(EMB_MODEL)
|
| 21 |
if _index is None:
|
|
|
|
| 24 |
_meta = json.load(open(META_PATH, "r", encoding="utf-8"))
|
| 25 |
return _embedder, _index, _meta
|
| 26 |
|
| 27 |
+
def _ensure_index_files():
    """Ensure the FAISS index and metadata files exist locally.

    For each of ``INDEX_PATH`` and ``META_PATH`` that is missing on disk, the
    file ``indexes/<basename>`` is fetched with ``hf_hub_download`` from the
    repository identified by ``HUB_REPO_ID`` / ``HUB_REPO_TYPE`` and copied to
    the configured location. Files that already exist are left untouched.

    Any exception raised by ``hf_hub_download`` or by the file copy is
    propagated to the caller.
    """
    for target in (Path(INDEX_PATH), Path(META_PATH)):
        if target.exists():
            continue

        # INDEX_PATH and META_PATH are configured independently, so each
        # destination directory must be created, not only the index's.
        target.parent.mkdir(parents=True, exist_ok=True)

        cached = hf_hub_download(
            repo_id=HUB_REPO_ID,
            repo_type=HUB_REPO_TYPE,
            filename=f"indexes/{target.name}",
        )

        # Copy to a sibling temp file and rename into place, so an interrupted
        # copy never leaves a truncated file that a later exists() check
        # would mistake for a complete one.
        partial = target.with_name(target.name + ".part")
        shutil.copy2(cached, partial)
        partial.replace(target)
|
| 53 |
+
|
| 54 |
def search(query: str, k: int = 8):
|
| 55 |
emb, idx, meta = _load()
|
| 56 |
q = emb.encode([query], normalize_embeddings=True).astype("float32")
|