import torch
import sentencepiece as spm
from typing import List, Optional, Tuple
import logging
from model.transformer import Transformer
from inference.beam_search import BeamSearch, GreedyDecoder

logger = logging.getLogger(__name__)

class Translator:
    """High-level translation interface"""
    
    def __init__(self, model: Transformer, tokenizer_path: str, 
                 device: Optional[torch.device] = None,
                 beam_size: int = 4, use_beam_search: bool = True):
        """
        Initialize translator
        
        Args:
            model: Trained transformer model
            tokenizer_path: Path to sentencepiece model
            device: Device to run on
            beam_size: Beam size for beam search
            use_beam_search: Whether to use beam search or greedy decoding
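
        Example (a minimal sketch; assumes ``model`` is a trained Transformer
        and the tokenizer path below is a placeholder):
            translator = Translator(model, "spm.model", beam_size=4)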
        """
        self.model = model
        self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)
        self.model.eval()
        
        # Load tokenizer
        self.sp = spm.SentencePieceProcessor()
        self.sp.load(tokenizer_path)
        
        # Special tokens
        self.bos_id = self.sp.bos_id()
        self.eos_id = self.sp.eos_id()
        self.pad_id = self.sp.pad_id()
        
        # Decoder - no_repeat_ngram_size=3 blocks repeated trigrams in the output
        self.use_beam_search = use_beam_search
        if use_beam_search:
            self.decoder = BeamSearch(beam_size=beam_size, no_repeat_ngram_size=3)
        else:
            # Greedy decoding calls the static GreedyDecoder.decode directly,
            # so no decoder instance is needed.
            self.decoder = None
        
        logger.info(f"Translator initialized on {self.device}")
        logger.info(f"Vocab size: {self.sp.vocab_size()}")
        logger.info(f"Using {'beam search' if use_beam_search else 'greedy'} decoding")
    
    def translate(self, text: str, max_length: int = 50) -> str:
        """
        Translate a single text
        
        Args:
            text: Source text to translate
            max_length: Maximum translation length
        
        Returns:
            Translated text
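
        Example (illustrative; the commented output is hypothetical):
            english = translator.translate("Das Wetter ist heute schön.")
            # e.g. "The weather is nice today."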
        """
        # Tokenize
        tokens = self.sp.encode(text)
        
        # Add special tokens
        tokens = [self.bos_id] + tokens + [self.eos_id]
        
        # Convert to tensor
        src = torch.tensor([tokens], dtype=torch.long).to(self.device)
        
        # Decode
        with torch.no_grad():
            if self.use_beam_search:
                translations = self.decoder.search(
                    self.model, src, max_length,
                    self.bos_id, self.eos_id, self.pad_id
                )
            else:
                translations = GreedyDecoder.decode(
                    self.model, src, max_length,
                    self.bos_id, self.eos_id, self.pad_id
                )
        
        # Best hypothesis for the single input sentence
        translated_tokens = translations[0]
        
        # Remove special tokens
        if self.bos_id in translated_tokens:
            translated_tokens = translated_tokens[translated_tokens.index(self.bos_id) + 1:]
        if self.eos_id in translated_tokens:
            translated_tokens = translated_tokens[:translated_tokens.index(self.eos_id)]
        
        # Decode to text
        translated_text = self.sp.decode(translated_tokens)
        
        return translated_text
    
    def translate_batch(self, texts: List[str], max_length: int = 50) -> List[str]:
        """
        Translate multiple texts in batch
        
        Args:
            texts: List of source texts
            max_length: Maximum translation length
        
        Returns:
            List of translated texts
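
        Example (illustrative):
            sentences = ["Guten Morgen.", "Wie geht es dir?"]
            translations = translator.translate_batch(sentences)
            # len(translations) == len(sentences)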
        """
        # Tokenize all texts
        tokenized = []
        for text in texts:
            tokens = self.sp.encode(text)
            tokens = [self.bos_id] + tokens + [self.eos_id]
            tokenized.append(tokens)
        
        # Pad sequences
        max_len = max(len(tokens) for tokens in tokenized)
        padded = []
        for tokens in tokenized:
            padded_tokens = tokens + [self.pad_id] * (max_len - len(tokens))
            padded.append(padded_tokens)
        
        # Convert to tensor
        src = torch.tensor(padded, dtype=torch.long).to(self.device)
        
        # Decode
        with torch.no_grad():
            if self.use_beam_search:
                translations = self.decoder.search(
                    self.model, src, max_length,
                    self.bos_id, self.eos_id, self.pad_id
                )
            else:
                translations = GreedyDecoder.decode(
                    self.model, src, max_length,
                    self.bos_id, self.eos_id, self.pad_id
                )
        
        # Decode all translations
        results = []
        for translated_tokens in translations:
            # Remove special tokens
            if self.bos_id in translated_tokens:
                translated_tokens = translated_tokens[translated_tokens.index(self.bos_id) + 1:]
            if self.eos_id in translated_tokens:
                translated_tokens = translated_tokens[:translated_tokens.index(self.eos_id)]
            
            # Decode to text
            translated_text = self.sp.decode(translated_tokens)
            results.append(translated_text)
        
        return results
    
    def translate_with_attention(self, text: str, max_length: int = 50) -> Tuple[str, torch.Tensor]:
        """
        Translate and return attention weights
        
        Args:
            text: Source text to translate
            max_length: Maximum translation length
        
        Returns:
            Tuple of (translated_text, attention_weights)

        Note:
            Attention extraction is not wired into the model yet; the returned
            weights are random placeholders with the expected shape.
        """
        # This is a placeholder - would need to modify model to return attention
        translation = self.translate(text, max_length)
        
        # For now, return dummy attention
        src_len = len(self.sp.encode(text)) + 2  # +2 for BOS/EOS
        tgt_len = len(self.sp.encode(translation)) + 2
        attention = torch.rand(1, self.model.n_heads, tgt_len, src_len)
        
        return translation, attention
    
    @classmethod
    def from_checkpoint(cls, checkpoint_path: str, tokenizer_path: str,
                       device: Optional[torch.device] = None, **kwargs):
        """
        Load translator from checkpoint
        
        Args:
            checkpoint_path: Path to model checkpoint
            tokenizer_path: Path to tokenizer
            device: Device to load on
            **kwargs: Additional arguments for translator
        
        Returns:
            Translator instance
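
        Example (a minimal sketch; paths are placeholders):
            translator = Translator.from_checkpoint(
                "checkpoints/best_model.pt",
                "tokenizer/spm.model",
                beam_size=4,
            )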
        """
        # Load checkpoint
        device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        checkpoint = torch.load(checkpoint_path, map_location=device)
        
        # Create model
        config = checkpoint['config']
        model = Transformer(
            vocab_size=config['model']['vocab_size'],
            d_model=config['model']['d_model'],
            n_heads=config['model']['n_heads'],
            n_layers=config['model']['n_layers'],
            d_ff=config['model']['d_ff'],
            max_seq_length=config['model']['max_seq_length'],
            dropout=0.0  # No dropout during inference
        )
        
        # Load weights
        model.load_state_dict(checkpoint['model_state_dict'])
        
        # Create translator
        return cls(model, tokenizer_path, device, **kwargs)


def interactive_translation(checkpoint_path: str, tokenizer_path: str):
    """
    Interactive translation in terminal
    
    Args:
        checkpoint_path: Path to model checkpoint
        tokenizer_path: Path to tokenizer
    """
    # Load translator
    translator = Translator.from_checkpoint(checkpoint_path, tokenizer_path)
    
    print("TransLingo Interactive Translation")
    print("Type 'quit' to exit")
    print("-" * 50)
    
    while True:
        # Get input
        text = input("\nEnter German text: ").strip()
        
        if text.lower() == 'quit':
            break
        
        if not text:
            continue
        
        # Translate
        try:
            translation = translator.translate(text)
            print(f"English translation: {translation}")
        except Exception as e:
            print(f"Error: {e}")
    
    print("\nGoodbye!")


if __name__ == "__main__":
    import sys
    
    if len(sys.argv) != 3:
        print("Usage: python translate.py <checkpoint_path> <tokenizer_path>")
        sys.exit(1)
    
    checkpoint_path = sys.argv[1]
    tokenizer_path = sys.argv[2]
    
    interactive_translation(checkpoint_path, tokenizer_path)