# File size: 966 Bytes
# 3e46c4a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# /tools/reranker.py
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

class Reranker:
    """Cross-encoder re-ranker backed by a Hugging Face sequence-classification model.

    Each (query, document) pair is fed to the model jointly and the model's
    logit is used as the relevance score; higher means more relevant.
    """

    def __init__(self, model_id="BAAI/bge-reranker-base", device=None):
        """Load the tokenizer and model for ``model_id``.

        Args:
            model_id: Hugging Face Hub id or local path of a sequence-classification
                checkpoint. ``rerank`` assumes the model emits a single logit per
                input pair (NOTE(review): confirm when using non-BGE checkpoints).
            device: Optional torch device (e.g. ``"cuda"``) to move the model to.
                ``None`` leaves the model where ``from_pretrained`` placed it.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_id)
        if device is not None:
            self.model.to(device)
        # Inference mode: disables dropout and similar training-only behavior.
        self.model.eval()

    @torch.no_grad()
    def rerank(self, query: str, docs: list[str], top_k: int = 5):
        """Score ``docs`` against ``query`` and return the best ``top_k`` of them.

        Args:
            query: The search query.
            docs: Candidate document texts to score.
            top_k: Maximum number of results to return; fewer are returned when
                ``docs`` has fewer entries. Must be non-negative.

        Returns:
            A tuple ``(indices, scores)`` of two lists ordered from most to least
            relevant: ``indices`` are positions into ``docs`` and ``scores`` are
            the corresponding raw model logits.

        Raises:
            ValueError: If ``top_k`` is negative.
        """
        if top_k < 0:
            raise ValueError(f"top_k must be non-negative, got {top_k}")
        # Nothing to rank: skip the tokenizer/model, which fail on an empty batch.
        if not docs or top_k == 0:
            return [], []

        pairs = [(query, d) for d in docs]
        inputs = self.tokenizer(
            pairs,
            padding=True,
            truncation=True,
            return_tensors="pt",
            max_length=512,
        )
        # The tokenizer yields CPU tensors; move them next to the model so a
        # GPU-resident model also works.
        device = self.model.device
        inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
        # Drop the trailing single-logit dimension: one score per (query, doc) pair.
        scores = self.model(**inputs, return_dict=True).logits.squeeze(-1)
        k = min(top_k, len(docs))
        top_scores, top_ids = torch.topk(scores, k)  # sorted descending
        return top_ids.tolist(), top_scores.tolist()