Manik Sheokand committed on
Commit
7b2b305
·
1 Parent(s): cad6baf

Download FAISS index/metadata from HF Hub at runtime if missing

Browse files
Files changed (1) hide show
  1. tools/retriever.py +32 -13
tools/retriever.py CHANGED
@@ -1,29 +1,21 @@
1
- import os, json, faiss, numpy as np
2
  from pathlib import Path
3
  from sentence_transformers import SentenceTransformer
4
- # from tools.retriever import search
5
- # from tools.reranker import Reranker
6
 
7
  INDEX_PATH = os.environ.get("INDEX_PATH", "indexes/cosmetics_faiss_ip.index")
8
  META_PATH = os.environ.get("META_PATH", "indexes/cosmetics_meta.json")
 
 
9
  EMB_MODEL = os.environ.get("EMB_MODEL_ID", "intfloat/multilingual-e5-base")
10
 
11
  _embedder = None
12
  _index = None
13
  _meta = None
14
- # reranker = Reranker()
15
-
16
- # def refined_search(query, k_initial=20, k_final=5):
17
- # # 1. retrieve coarse top-20
18
- # cands = search(query, k=k_initial)
19
- # texts = [f"{c['brand_en']} {c['product_name_en']} {c['description_en']}" for c in cands]
20
-
21
- # # 2. re-rank with cross-encoder
22
- # idxs, scores = reranker.rerank(query, texts, top_k=k_final)
23
- # return [cands[i] | {"rerank_score": scores[j]} for j, i in enumerate(idxs)]
24
 
25
  def _load():
26
  global _embedder, _index, _meta
 
27
  if _embedder is None:
28
  _embedder = SentenceTransformer(EMB_MODEL)
29
  if _index is None:
@@ -32,6 +24,33 @@ def _load():
32
  _meta = json.load(open(META_PATH, "r", encoding="utf-8"))
33
  return _embedder, _index, _meta
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  def search(query: str, k: int = 8):
36
  emb, idx, meta = _load()
37
  q = emb.encode([query], normalize_embeddings=True).astype("float32")
 
1
+ import os, json, faiss, numpy as np, shutil
2
  from pathlib import Path
3
  from sentence_transformers import SentenceTransformer
4
+ from huggingface_hub import hf_hub_download
 
5
 
6
  INDEX_PATH = os.environ.get("INDEX_PATH", "indexes/cosmetics_faiss_ip.index")
7
  META_PATH = os.environ.get("META_PATH", "indexes/cosmetics_meta.json")
8
+ HUB_REPO_ID = os.environ.get("HUB_REPO_ID", os.environ.get("REPO_ID", "ColdSlim/DermalCare"))
9
+ HUB_REPO_TYPE = os.environ.get("HUB_REPO_TYPE", "space")
10
  EMB_MODEL = os.environ.get("EMB_MODEL_ID", "intfloat/multilingual-e5-base")
11
 
12
  _embedder = None
13
  _index = None
14
  _meta = None
 
 
 
 
 
 
 
 
 
 
15
 
16
  def _load():
17
  global _embedder, _index, _meta
18
+ _ensure_index_files()
19
  if _embedder is None:
20
  _embedder = SentenceTransformer(EMB_MODEL)
21
  if _index is None:
 
24
  _meta = json.load(open(META_PATH, "r", encoding="utf-8"))
25
  return _embedder, _index, _meta
26
 
27
def _ensure_index_files():
    """Ensure the FAISS index and metadata files exist locally.

    Each of INDEX_PATH and META_PATH that is missing on disk is fetched with
    ``hf_hub_download`` from the repository named by HUB_REPO_ID (of type
    HUB_REPO_TYPE) and copied into place. The remote file is always requested
    as ``indexes/<basename>``, whatever local directory the path points to.

    Files that already exist are left untouched. Download or copy errors
    propagate to the caller.
    """
    for local_path in (Path(INDEX_PATH), Path(META_PATH)):
        if local_path.exists():
            continue

        # Create this file's own directory: META_PATH is configured
        # independently and need not share INDEX_PATH's parent.
        local_path.parent.mkdir(parents=True, exist_ok=True)

        cached = hf_hub_download(
            repo_id=HUB_REPO_ID,
            repo_type=HUB_REPO_TYPE,
            filename=f"indexes/{local_path.name}",
        )

        # Copy to a sibling temp file and rename it into place, so an
        # interrupted copy never leaves a truncated file at the final path
        # (which would pass the exists() check above and never be repaired).
        partial = local_path.with_name(local_path.name + ".part")
        shutil.copy2(cached, partial)
        os.replace(partial, local_path)
53
+
54
  def search(query: str, k: int = 8):
55
  emb, idx, meta = _load()
56
  q = emb.encode([query], normalize_embeddings=True).astype("float32")