Spaces:
Runtime error
Runtime error
| # /tools/reranker.py | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| import torch | |
class Reranker:
    """Score (query, document) pairs with a sequence-classification model.

    The tokenizer and model are both loaded from ``model_id``; the model is
    switched to eval mode right after loading.
    """

    def __init__(self, model_id="BAAI/bge-reranker-base"):
        """Load the tokenizer and classification model named by ``model_id``."""
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_id)
        # This class only runs inference, so keep the model out of training mode.
        self.model.eval()

    def rerank(
        self, query: str, docs: list[str], top_k: int = 5
    ) -> tuple[list[int], list[float]]:
        """Re-rank a list of doc texts for a query and return sorted indices.

        Args:
            query: The query text paired with every document.
            docs: Candidate document texts.
            top_k: Maximum number of results to return.

        Returns:
            A ``(indices, scores)`` tuple. ``indices`` are positions into
            ``docs`` ordered from highest to lowest score, and ``scores`` are
            the matching model logits. Both lists are empty when ``docs`` is
            empty or ``top_k`` is not positive.
        """
        # Nothing to rank: skip the tokenizer/model entirely. This also stops
        # a negative ``top_k`` from being read as "drop the last |k|" by the
        # slice below.
        if not docs or top_k <= 0:
            return [], []

        pairs = [(query, d) for d in docs]
        inputs = self.tokenizer(
            pairs,
            padding=True,
            truncation=True,
            return_tensors="pt",
            max_length=512,
        )
        # Inference only: do not build or retain an autograd graph.
        with torch.no_grad():
            scores = self.model(**inputs, return_dict=True).logits.squeeze(-1)
        sorted_ids = torch.argsort(scores, descending=True)
        top = sorted_ids[:top_k]
        return top.tolist(), scores[top].tolist()