Ratan1 committed
Commit 9618094 · 1 Parent(s): aad62f1
Files changed (2):
  1. inference/beam_search.py +17 -7
  2. inference/translate.py +6 -6
inference/beam_search.py CHANGED

@@ -15,11 +15,11 @@ class BeamSearch:
     """Beam search decoder for transformer models"""
 
     def __init__(self, beam_size: int = 4, length_penalty: float = 0.6,
-                 coverage_penalty: float = 0.0, no_repeat_ngram_size: int = 0):
+                 coverage_penalty: float = 0.0, no_repeat_ngram_size: int = 3):
         self.beam_size = beam_size
         self.length_penalty = length_penalty
         self.coverage_penalty = coverage_penalty
-        self.no_repeat_ngram_size = no_repeat_ngram_size
+        self.no_repeat_ngram_size = no_repeat_ngram_size  # Changed default from 0 to 3
 
     def search(self, model, src: torch.Tensor, max_length: int = 100,
                bos_id: int = 2, eos_id: int = 3, pad_id: int = 0) -> List[List[int]]:
@@ -51,7 +51,11 @@ class BeamSearch:
         all_candidates = []
 
         for batch_idx in range(batch_size):
-            # Skip if all beams are finished
+            # NEW: Stop if the BEST beam (first one after sorting) is finished
+            if beams[batch_idx] and beams[batch_idx][0].finished:
+                continue
+
+            # Also skip if all beams are finished
             if all(hyp.finished for hyp in beams[batch_idx]):
                 continue
 
@@ -98,19 +102,21 @@ class BeamSearch:
                 for token_rank, (token_log_prob, token_id) in enumerate(
                         zip(beam_log_probs, beam_indices_local)):
 
+                    new_tokens = hypothesis.tokens + [token_id.item()]
+
                     # Apply no-repeat penalty
-                    if self._has_repeated_ngram(hypothesis.tokens + [token_id.item()]):
+                    if self._has_repeated_ngram(new_tokens):
                         continue
 
                     new_log_prob = hypothesis.log_prob + token_log_prob.item()
 
                     # Apply length penalty
-                    score = self._apply_length_penalty(new_log_prob, len(hypothesis.tokens) + 1)
+                    score = self._apply_length_penalty(new_log_prob, len(new_tokens))
 
                     candidates.append((
                         score,
                         BeamHypothesis(
-                            tokens=hypothesis.tokens + [token_id.item()],
+                            tokens=new_tokens,
                             log_prob=new_log_prob,
                             finished=(token_id.item() == eos_id)
                         )
@@ -123,6 +129,10 @@ class BeamSearch:
             for score, hypothesis in candidates[:self.beam_size]:
                 new_beams.append(hypothesis)
 
+            # If we have no candidates, keep the old beams
+            if not new_beams:
+                new_beams = beams[batch_idx]
+
             beams[batch_idx] = new_beams
 
         # Extract best sequences
@@ -202,4 +212,4 @@ class GreedyDecoder:
                 tokens = tokens[:eos_idx + 1]
             results.append(tokens)
 
-        return results
+        return results
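The changed hunks call two helpers, _has_repeated_ngram and _apply_length_penalty, whose bodies lie outside this diff. Below is a minimal sketch of how such helpers are commonly written, assuming standard n-gram blocking and the GNMT-style length normalization suggested by the length_penalty=0.6 default; the standalone function signatures are illustrative, not the file's actual methods.

from typing import List

def _has_repeated_ngram(tokens: List[int], no_repeat_ngram_size: int = 3) -> bool:
    """Return True if the trailing n-gram of tokens already appeared earlier (assumed blocking rule)."""
    n = no_repeat_ngram_size
    if n <= 0 or len(tokens) < n:
        return False  # blocking disabled, or sequence still shorter than one n-gram
    last_ngram = tuple(tokens[-n:])
    # Scan every earlier n-gram; a match means this candidate would repeat itself
    return any(tuple(tokens[i:i + n]) == last_ngram for i in range(len(tokens) - n))

def _apply_length_penalty(log_prob: float, length: int, alpha: float = 0.6) -> float:
    """GNMT-style normalization (assumed): score = log_prob / ((5 + length) / 6) ** alpha."""
    return log_prob / (((5.0 + length) / 6.0) ** alpha)

With the new no_repeat_ngram_size=3 default, any candidate token that would complete a trigram already present in the hypothesis is skipped before scoring, which is the usual remedy for the repetition loops this commit targets.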
inference/translate.py CHANGED

@@ -37,16 +37,16 @@ class Translator:
         self.eos_id = self.sp.eos_id()
         self.pad_id = self.sp.pad_id()
 
-        # Decoder
+        # Decoder - with no_repeat_ngram_size=3 to prevent repetition
         self.use_beam_search = use_beam_search
         if use_beam_search:
-            self.decoder = BeamSearch(beam_size=beam_size)
+            self.decoder = BeamSearch(beam_size=beam_size, no_repeat_ngram_size=3)
 
         logger.info(f"Translator initialized on {self.device}")
         logger.info(f"Vocab size: {self.sp.vocab_size()}")
         logger.info(f"Using {'beam search' if use_beam_search else 'greedy'} decoding")
 
-    def translate(self, text: str, max_length: int = 100) -> str:
+    def translate(self, text: str, max_length: int = 50) -> str:  # Changed default from 100 to 50
         """
         Translate a single text
 
@@ -93,7 +93,7 @@ class Translator:
 
         return translated_text
 
-    def translate_batch(self, texts: List[str], max_length: int = 100) -> List[str]:
+    def translate_batch(self, texts: List[str], max_length: int = 50) -> List[str]:  # Changed from 100 to 50
         """
         Translate multiple texts in batch
 
@@ -149,7 +149,7 @@ class Translator:
 
         return results
 
-    def translate_with_attention(self, text: str, max_length: int = 100) -> Tuple[str, torch.Tensor]:
+    def translate_with_attention(self, text: str, max_length: int = 50) -> Tuple[str, torch.Tensor]:
         """
         Translate and return attention weights
 
@@ -253,4 +253,4 @@ if __name__ == "__main__":
     checkpoint_path = sys.argv[1]
     tokenizer_path = sys.argv[2]
 
-    interactive_translation(checkpoint_path, tokenizer_path)
+    interactive_translation(checkpoint_path, tokenizer_path)
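For context, a hedged usage sketch of how the new defaults play out end to end. The Translator constructor signature is an assumption (a checkpoint path and a SentencePiece model path, in the order interactive_translation receives them, plus the use_beam_search and beam_size options that appear in this diff); the file paths are placeholders.

from inference.translate import Translator

# Hypothetical paths; substitute real checkpoint and tokenizer files
translator = Translator(
    "checkpoints/best_model.pt",
    "tokenizer/spm.model",
    use_beam_search=True,  # decoder becomes BeamSearch(beam_size=4, no_repeat_ngram_size=3)
    beam_size=4,
)

# max_length now defaults to 50 rather than 100
print(translator.translate("How are you today?"))

# The batch and attention variants share the same shorter default
print(translator.translate_batch(["Good morning.", "See you tomorrow."]))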