import torch
import torch.nn.functional as F
from typing import List
from dataclasses import dataclass


@dataclass
class BeamHypothesis:
    """Single hypothesis in beam search."""
    tokens: List[int]
    log_prob: float
    finished: bool = False


class BeamSearch:
    """Beam search decoder for transformer models."""

    def __init__(self, beam_size: int = 4, length_penalty: float = 0.6,
                 coverage_penalty: float = 0.0, no_repeat_ngram_size: int = 3):
        self.beam_size = beam_size
        self.length_penalty = length_penalty
        # Reserved for a future coverage term; not applied in scoring yet.
        self.coverage_penalty = coverage_penalty
        self.no_repeat_ngram_size = no_repeat_ngram_size

    def search(self, model, src: torch.Tensor, max_length: int = 100,
               bos_id: int = 2, eos_id: int = 3,
               pad_id: int = 0) -> List[List[int]]:
        """
        Perform beam search decoding.

        Args:
            model: Transformer model exposing encode/decode/output_projection
            src: Source sequence [batch_size, src_len]
            max_length: Maximum decoding length
            bos_id: Beginning-of-sequence token id
            eos_id: End-of-sequence token id
            pad_id: Padding token id

        Returns:
            List of decoded token-id sequences, one per batch element
        """
        batch_size = src.size(0)
        device = src.device

        # Encode the source once; the mask broadcasts over heads and queries.
        src_mask = (src != pad_id).unsqueeze(1).unsqueeze(2)
        memory = model.encode(src, src_mask)

        # Each batch element starts with a single hypothesis: just <bos>.
        beams = [[BeamHypothesis([bos_id], 0.0)] for _ in range(batch_size)]

        for _ in range(max_length - 1):
            for batch_idx in range(batch_size):
                # Early stop: beams are kept sorted by score, so if the best
                # hypothesis is already finished no extension can beat it.
                if beams[batch_idx] and beams[batch_idx][0].finished:
                    continue
                if all(hyp.finished for hyp in beams[batch_idx]):
                    continue

                # Gather the unfinished hypotheses for this batch element.
                # They all share the same length, since each grows by exactly
                # one token per step.
                beam_tokens = []
                beam_indices = []
                for beam_idx, hypothesis in enumerate(beams[batch_idx]):
                    if not hypothesis.finished:
                        beam_tokens.append(hypothesis.tokens)
                        beam_indices.append(beam_idx)

                if not beam_tokens:
                    continue

                # Batch the live hypotheses as decoder input.
                tgt = torch.tensor(beam_tokens, device=device)

                # Causal (lower-triangular) mask over target positions.
                tgt_mask = torch.ones(len(beam_tokens), 1, tgt.size(1),
                                      tgt.size(1), device=device)
                tgt_mask = torch.tril(tgt_mask)

                # Broadcast this element's encoder output across the beams.
                expanded_memory = memory[batch_idx:batch_idx + 1].expand(
                    len(beam_tokens), -1, -1)
                expanded_src_mask = src_mask[batch_idx:batch_idx + 1].expand(
                    len(beam_tokens), -1, -1, -1)

                # Next-token distribution for each live hypothesis.
                decoder_output = model.decode(tgt, expanded_memory,
                                              tgt_mask, expanded_src_mask)
                logits = model.output_projection(decoder_output[:, -1, :])
                log_probs = F.log_softmax(logits, dim=-1)

                # Top-k continuations per beam.
                vocab_size = log_probs.size(-1)
                top_log_probs, top_indices = torch.topk(
                    log_probs, min(self.beam_size, vocab_size))

                # Expand every live hypothesis with its top-k tokens.
                candidates = []
                for beam_idx, beam_log_probs, beam_token_ids in zip(
                        beam_indices, top_log_probs, top_indices):
                    hypothesis = beams[batch_idx][beam_idx]
                    for token_log_prob, token_id in zip(beam_log_probs,
                                                        beam_token_ids):
                        new_tokens = hypothesis.tokens + [token_id.item()]

                        # Block candidates that repeat an n-gram.
                        if self._has_repeated_ngram(new_tokens):
                            continue

                        new_log_prob = (hypothesis.log_prob
                                        + token_log_prob.item())
                        score = self._apply_length_penalty(new_log_prob,
                                                           len(new_tokens))
                        candidates.append((
                            score,
                            BeamHypothesis(
                                tokens=new_tokens,
                                log_prob=new_log_prob,
                                finished=(token_id.item() == eos_id),
                            ),
                        ))

                # Finished hypotheses stay in contention at their final
                # score; otherwise they would be silently dropped here.
                for hypothesis in beams[batch_idx]:
                    if hypothesis.finished:
                        candidates.append((
                            self._apply_length_penalty(hypothesis.log_prob,
                                                       len(hypothesis.tokens)),
                            hypothesis,
                        ))

                # Keep the best beam_size candidates, highest score first.
                candidates.sort(key=lambda x: x[0], reverse=True)
                new_beams = [hyp for _, hyp in candidates[:self.beam_size]]

                # If every candidate was filtered out by the n-gram blocker,
                # fall back to the previous beams.
                if not new_beams:
                    new_beams = beams[batch_idx]
                beams[batch_idx] = new_beams

        # Extract the best hypothesis per batch element.
        results = []
        for batch_idx in range(batch_size):
            sorted_hyps = sorted(
                beams[batch_idx],
                key=lambda h: self._apply_length_penalty(h.log_prob,
                                                         len(h.tokens)),
                reverse=True,
            )
            results.append(sorted_hyps[0].tokens)
        return results

    def _apply_length_penalty(self, log_prob: float, length: int) -> float:
        """Length-normalize a score: log_prob / length ** alpha."""
        return log_prob / (length ** self.length_penalty)

    def _has_repeated_ngram(self, tokens: List[int]) -> bool:
        """Return True if the sequence contains a repeated n-gram."""
        if self.no_repeat_ngram_size <= 0:
            return False
        ngrams = set()
        for i in range(len(tokens) - self.no_repeat_ngram_size + 1):
            ngram = tuple(tokens[i:i + self.no_repeat_ngram_size])
            if ngram in ngrams:
                return True
            ngrams.add(ngram)
        return False


class GreedyDecoder:
    """Simple greedy decoder for fast inference."""

    @staticmethod
    def decode(model, src: torch.Tensor, max_length: int = 100,
               bos_id: int = 2, eos_id: int = 3,
               pad_id: int = 0) -> List[List[int]]:
        """
        Perform greedy decoding via the model's own generate method.

        Args:
            model: Transformer model with a generate() method
            src: Source sequence [batch_size, src_len]
            max_length: Maximum decoding length
            bos_id: Beginning-of-sequence token id
            eos_id: End-of-sequence token id
            pad_id: Padding token id

        Returns:
            List of decoded token-id sequences, one per batch element
        """
        batch_size = src.size(0)

        with torch.no_grad():
            translations = model.generate(
                src,
                max_length=max_length,
                bos_id=bos_id,
                eos_id=eos_id,
            )

        # Truncate each sequence at its first <eos>, if present.
        results = []
        for i in range(batch_size):
            tokens = translations[i].cpu().tolist()
            if eos_id in tokens:
                eos_idx = tokens.index(eos_id)
                tokens = tokens[:eos_idx + 1]
            results.append(tokens)
        return results
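
if __name__ == "__main__":
    # Minimal smoke test: a sketch only. `_ToyModel` is a hypothetical
    # stand-in with random weights; it merely mirrors the
    # encode/decode/output_projection interface that BeamSearch.search
    # assumes of the real transformer. (GreedyDecoder additionally needs
    # a generate() method, which the toy omits.)
    import torch.nn as nn

    class _ToyModel(nn.Module):
        """Tiny random-weight model matching BeamSearch's expected interface."""

        def __init__(self, vocab_size: int = 16, d_model: int = 8):
            super().__init__()
            self.src_embed = nn.Embedding(vocab_size, d_model)
            self.tgt_embed = nn.Embedding(vocab_size, d_model)
            self.output_projection = nn.Linear(d_model, vocab_size)

        def encode(self, src, src_mask):
            # Returns [batch, src_len, d_model]; attention is skipped here.
            return self.src_embed(src)

        def decode(self, tgt, memory, tgt_mask, src_mask):
            # Returns [batch, tgt_len, d_model]; masks/memory are unused
            # by the toy but accepted to match the real call signature.
            return self.tgt_embed(tgt)

    torch.manual_seed(0)
    model = _ToyModel()
    src = torch.tensor([[2, 5, 7, 3]])  # [batch_size=1, src_len=4]
    decoder = BeamSearch(beam_size=3, no_repeat_ngram_size=3)
    with torch.no_grad():
        # Prints one list of token ids per batch element, e.g. [2, ..., 3].
        print(decoder.search(model, src, max_length=10))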