finalized
Browse files
- inference/beam_search.py +17 -7
- inference/translate.py +6 -6
inference/beam_search.py
CHANGED
@@ -15,11 +15,11 @@ class BeamSearch:
     """Beam search decoder for transformer models"""
 
     def __init__(self, beam_size: int = 4, length_penalty: float = 0.6,
-                 coverage_penalty: float = 0.0, no_repeat_ngram_size: int = 0):
+                 coverage_penalty: float = 0.0, no_repeat_ngram_size: int = 3):
         self.beam_size = beam_size
         self.length_penalty = length_penalty
         self.coverage_penalty = coverage_penalty
-        self.no_repeat_ngram_size = no_repeat_ngram_size
+        self.no_repeat_ngram_size = no_repeat_ngram_size  # Changed default from 0 to 3
 
     def search(self, model, src: torch.Tensor, max_length: int = 100,
                bos_id: int = 2, eos_id: int = 3, pad_id: int = 0) -> List[List[int]]:
@@ -51,7 +51,11 @@ class BeamSearch:
             all_candidates = []
 
             for batch_idx in range(batch_size):
-                # Skip if all beams are finished
+                # NEW: Stop if the BEST beam (first one after sorting) is finished
+                if beams[batch_idx] and beams[batch_idx][0].finished:
+                    continue
+
+                # Also skip if all beams are finished
                 if all(hyp.finished for hyp in beams[batch_idx]):
                     continue
 
@@ -98,19 +102,21 @@ class BeamSearch:
                 for token_rank, (token_log_prob, token_id) in enumerate(
                         zip(beam_log_probs, beam_indices_local)):
 
+                    new_tokens = hypothesis.tokens + [token_id.item()]
+
                     # Apply no-repeat penalty
-                    if self._has_repeated_ngram(hypothesis.tokens + [token_id.item()]):
+                    if self._has_repeated_ngram(new_tokens):
                         continue
 
                     new_log_prob = hypothesis.log_prob + token_log_prob.item()
 
                     # Apply length penalty
-                    score = self._apply_length_penalty(new_log_prob, len(hypothesis.tokens) + 1)
+                    score = self._apply_length_penalty(new_log_prob, len(new_tokens))
 
                     candidates.append((
                         score,
                         BeamHypothesis(
-                            tokens=hypothesis.tokens + [token_id.item()],
+                            tokens=new_tokens,
                             log_prob=new_log_prob,
                             finished=(token_id.item() == eos_id)
                         )
@@ -123,6 +129,10 @@ class BeamSearch:
                 for score, hypothesis in candidates[:self.beam_size]:
                     new_beams.append(hypothesis)
 
+                # If we have no candidates, keep the old beams
+                if not new_beams:
+                    new_beams = beams[batch_idx]
+
                 beams[batch_idx] = new_beams
 
             # Extract best sequences
@@ -202,4 +212,4 @@ class GreedyDecoder:
                 tokens = tokens[:eos_idx + 1]
             results.append(tokens)
 
-        return results
\ No newline at end of file
+        return results
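The hunks above call two pieces that this commit does not show: the BeamHypothesis container and the _has_repeated_ngram check. Below is a minimal sketch of what they plausibly look like, inferred only from the fields and call sites in the diff; the actual definitions in inference/beam_search.py may differ, and in the class the n-gram size would come from self.no_repeat_ngram_size rather than a parameter.

from dataclasses import dataclass
from typing import List

@dataclass
class BeamHypothesis:
    tokens: List[int]   # token ids generated so far
    log_prob: float     # cumulative log-probability of the sequence
    finished: bool      # True once eos_id has been emitted

def has_repeated_ngram(tokens: List[int], n: int = 3) -> bool:
    """Return True if the trailing n-gram of `tokens` already occurred earlier."""
    if n <= 0 or len(tokens) < n + 1:
        return False
    tail = tuple(tokens[-n:])
    earlier = {tuple(tokens[i:i + n]) for i in range(len(tokens) - n)}
    return tail in earlier

With no_repeat_ngram_size=3, a candidate like [5, 9, 7, 5, 9, 7] is pruned because the trigram (5, 9, 7) occurs twice; that looping behavior is exactly what the new default is meant to suppress.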
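_apply_length_penalty is likewise not part of the diff. Given the default length_penalty=0.6, the conventional choice is the GNMT length penalty from Wu et al. (2016), which divides the cumulative log-probability so that longer hypotheses are not unfairly penalized. A sketch under that assumption:

def apply_length_penalty(log_prob: float, length: int, alpha: float = 0.6) -> float:
    """Normalize a cumulative log-probability by the GNMT length penalty.

    score = log_prob / lp(length), where lp(length) = ((5 + length) / 6) ** alpha.
    alpha = 0 disables normalization; larger alpha favors longer outputs.
    """
    lp = ((5.0 + length) / 6.0) ** alpha
    return log_prob / lp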
inference/translate.py
CHANGED
@@ -37,16 +37,16 @@ class Translator:
         self.eos_id = self.sp.eos_id()
         self.pad_id = self.sp.pad_id()
 
-        # Decoder
+        # Decoder - with no_repeat_ngram_size=3 to prevent repetition
         self.use_beam_search = use_beam_search
         if use_beam_search:
-            self.decoder = BeamSearch(beam_size=beam_size)
+            self.decoder = BeamSearch(beam_size=beam_size, no_repeat_ngram_size=3)
 
         logger.info(f"Translator initialized on {self.device}")
         logger.info(f"Vocab size: {self.sp.vocab_size()}")
         logger.info(f"Using {'beam search' if use_beam_search else 'greedy'} decoding")
 
-    def translate(self, text: str, max_length: int = 100) -> str:
+    def translate(self, text: str, max_length: int = 50) -> str:  # Changed default from 100 to 50
         """
         Translate a single text
 
@@ -93,7 +93,7 @@ class Translator:
 
         return translated_text
 
-    def translate_batch(self, texts: List[str], max_length: int = 100) -> List[str]:
+    def translate_batch(self, texts: List[str], max_length: int = 50) -> List[str]:  # Changed from 100 to 50
         """
         Translate multiple texts in batch
 
@@ -149,7 +149,7 @@ class Translator:
 
         return results
 
-    def translate_with_attention(self, text: str, max_length: int = 100) -> Tuple[str, torch.Tensor]:
+    def translate_with_attention(self, text: str, max_length: int = 50) -> Tuple[str, torch.Tensor]:
         """
         Translate and return attention weights
 
@@ -253,4 +253,4 @@ if __name__ == "__main__":
     checkpoint_path = sys.argv[1]
     tokenizer_path = sys.argv[2]
 
-    interactive_translation(checkpoint_path, tokenizer_path)
\ No newline at end of file
+    interactive_translation(checkpoint_path, tokenizer_path)
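For reference, a typical way to exercise the new defaults end to end. Only use_beam_search, beam_size, and the three translate* methods are visible in this diff; the positional checkpoint and tokenizer arguments below are assumptions based on the __main__ block, and the paths are placeholders.

# Hypothetical usage -- paths are placeholders.
translator = Translator("checkpoints/best.pt", "spm.model",
                        use_beam_search=True, beam_size=4)

print(translator.translate("Hello, world!"))    # max_length now defaults to 50
print(translator.translate_batch([
    "Good morning.",
    "How are you?",
]))                                             # List[str] in, List[str] out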
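The decoder can also be driven directly. Per the signature shown above, search() takes a model and a batch of source token ids and returns a List[List[int]], which appears to be one decoded token sequence per batch element; the model variable and the token ids below are placeholders.

import torch

decoder = BeamSearch(beam_size=4, length_penalty=0.6, no_repeat_ngram_size=3)
src = torch.tensor([[2, 45, 87, 3]])   # one source sentence: <bos> ... <eos>
sequences = decoder.search(model, src, max_length=50,
                           bos_id=2, eos_id=3, pad_id=0)
print(sequences[0])                    # token ids of the best hypothesis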