DermalCare / tools /reranker.py
Manik Sheokand
added recommender
3e46c4a
raw
history blame contribute delete
966 Bytes
# /tools/reranker.py
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
class Reranker:
    """Cross-encoder reranker that scores (query, document) pairs.

    Wraps a Hugging Face sequence-classification checkpoint (by default
    ``BAAI/bge-reranker-base``) and orders candidate documents by the
    model's logit for each (query, document) pair.
    """

    def __init__(self, model_id="BAAI/bge-reranker-base", device=None):
        """Load the tokenizer and classification model for *model_id*.

        Args:
            model_id: Hugging Face Hub id or local path of the reranker.
            device: Optional torch device (e.g. ``"cuda"``) to move the model
                to. ``None`` leaves the model wherever ``from_pretrained``
                placed it.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_id)
        if device is not None:
            self.model.to(device)
        # Inference only: switch off training-mode layers such as dropout.
        self.model.eval()

    @torch.no_grad()
    def rerank(
        self,
        query: str,
        docs: list[str],
        top_k: "int | None" = 5,
        batch_size: "int | None" = None,
    ) -> tuple[list[int], list[float]]:
        """Score *docs* against *query* and return the best ``top_k``.

        Args:
            query: The search query.
            docs: Candidate document texts.
            top_k: Maximum number of results to return. ``None`` returns
                every document; values larger than ``len(docs)`` are fine.
            batch_size: If given, score at most this many pairs per forward
                pass to bound memory use. ``None`` scores all pairs at once.

        Returns:
            ``(indices, scores)``: positions into *docs* ordered from highest
            to lowest score, and the matching raw model logits. Equal scores
            keep their original document order. Both lists are empty when
            *docs* is empty.

        Raises:
            ValueError: If ``top_k`` is negative or ``batch_size`` is not
                positive.
        """
        if top_k is not None and top_k < 0:
            # A negative slice bound would silently drop items from the end.
            raise ValueError(f"top_k must be non-negative, got {top_k}")
        if batch_size is not None and batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        if not docs:
            # Nothing to score; also avoids handing the tokenizer/model an
            # empty batch.
            return [], []

        step = batch_size or len(docs)
        chunks = [
            self._score_pairs(query, docs[start:start + step])
            for start in range(0, len(docs), step)
        ]
        scores = torch.cat(chunks)
        # Stable sort so ties are broken deterministically by input position.
        sorted_scores, sorted_ids = torch.sort(scores, descending=True, stable=True)
        return sorted_ids[:top_k].tolist(), sorted_scores[:top_k].tolist()

    @torch.no_grad()
    def _score_pairs(self, query: str, docs: list[str]) -> torch.Tensor:
        """Return a 1-D tensor with one model logit per document in *docs*."""
        inputs = self.tokenizer(
            [(query, doc) for doc in docs],
            padding=True,
            truncation=True,
            return_tensors="pt",
            max_length=512,
        )
        # Put the inputs on the model's device; otherwise a model that was
        # moved to a GPU fails with a device mismatch.
        device = getattr(self.model, "device", None)
        if device is not None:
            inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
        logits = self.model(**inputs, return_dict=True).logits
        # Drop the trailing size-1 label dimension: (batch, 1) -> (batch,).
        return logits.squeeze(-1)