# /tools/reranker.py
from typing import Optional

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer


class Reranker:
    """Cross-encoder re-ranker that scores (query, document) pairs.

    Wraps a Hugging Face sequence-classification checkpoint: each
    (query, doc) pair is tokenized jointly and the model's logit is used as
    the relevance score (results are ordered highest first).
    """

    def __init__(self, model_id="BAAI/bge-reranker-base"):
        """Load the tokenizer and model for *model_id* and switch to eval mode.

        Args:
            model_id: Hugging Face hub id or local path passed to
                ``from_pretrained``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_id)
        # eval() disables dropout so repeated calls give the same scores.
        self.model.eval()

    @torch.no_grad()
    def rerank(
        self,
        query: str,
        docs: list[str],
        top_k: int = 5,
        batch_size: Optional[int] = None,
    ):
        """Re-rank *docs* by relevance to *query*.

        Args:
            query: The search query.
            docs: Candidate document texts.
            top_k: Maximum number of results to return; fewer are returned
                when ``len(docs) < top_k``.
            batch_size: If given, score the pairs in chunks of this many
                pairs to bound peak memory. ``None`` (default) scores all
                pairs in a single forward pass.

        Returns:
            A tuple ``(indices, scores)``: ``indices`` are positions into
            *docs* ordered from highest to lowest score, and ``scores`` are
            the matching model scores as Python floats. Both are empty lists
            when *docs* is empty.

        Raises:
            ValueError: If ``top_k`` is negative or ``batch_size`` is not
                positive.
        """
        if top_k < 0:
            raise ValueError(f"top_k must be non-negative, got {top_k}")
        if batch_size is not None and batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        if not docs:
            # Nothing to score; avoid running the tokenizer/model on an empty batch.
            return [], []

        pairs = [(query, d) for d in docs]
        step = len(pairs) if batch_size is None else batch_size
        scores = torch.cat(
            [
                self._score_pairs(pairs[start:start + step])
                for start in range(0, len(pairs), step)
            ]
        )

        # topk requires k <= number of elements, so clamp to len(docs).
        k = min(top_k, scores.numel())
        top_scores, top_ids = torch.topk(scores, k)
        return top_ids.tolist(), top_scores.tolist()

    @torch.no_grad()
    def _score_pairs(self, pairs):
        """Return a 1-D tensor of model scores for a batch of (query, doc) pairs."""
        inputs = self.tokenizer(
            pairs, padding=True, truncation=True, return_tensors="pt", max_length=512
        )
        # Put the inputs on the model's device so a GPU-placed model works too.
        inputs = {name: tensor.to(self.model.device) for name, tensor in inputs.items()}
        logits = self.model(**inputs, return_dict=True).logits
        # NOTE(review): assumes a single-label head (num_labels == 1), i.e.
        # logits of shape (batch, 1), so squeezing yields one score per pair.
        return logits.squeeze(-1)