Ratan1 committed
Commit aabde66 · 1 Parent(s): 1620846

made changes to attention and download

Files changed (4)
  1. PROJECT_SUMMARY.md +151 -0
  2. data/download.py +104 -49
  3. model/attention.py +110 -27
  4. requirements.txt +1 -0
PROJECT_SUMMARY.md ADDED
@@ -0,0 +1,151 @@
# TransLingo Project Summary

## ✅ Project Setup Complete!

All components of the TransLingo translation system have been implemented and tested.

## 📁 Project Structure

```
translingo/
├── data/                     # Data processing pipeline
│   ├── download.py           # Multi30k dataset downloader
│   └── preprocessing.py      # Dataset and dataloader utilities
├── model/                    # Transformer implementation
│   ├── transformer.py        # Main model class
│   ├── attention.py          # Multi-head attention
│   ├── embeddings.py         # Positional encoding
│   └── layers.py             # Encoder/decoder layers
├── training/                 # Training components
│   ├── train.py              # Main training script with CUDA support
│   ├── loss.py               # Label smoothing loss
│   └── optimizer.py          # Noam learning rate scheduler
├── inference/                # Inference modules
│   ├── beam_search.py        # Beam search decoder
│   └── translate.py          # Translation interface
├── frontend/                 # User interfaces
│   └── gradio_app.py         # Gradio web interface
├── notebooks/                # Training notebooks
│   └── colab_training.py     # Google Colab training script
└── configs/                  # Configuration
    └── config.yaml           # Model and training configs
```

## 🚀 Next Steps

### 1. Push to GitHub
```bash
# Add your GitHub repository as remote
git remote add origin https://github.com/YOUR_USERNAME/translingo.git

# Push the code
git push -u origin main
```

### 2. Train on Google Colab
1. Go to [Google Colab](https://colab.research.google.com/)
2. Create a new notebook
3. Copy the contents from `notebooks/colab_training.py`
4. Follow these steps in the notebook (sketched below):
   - Mount Google Drive (optional, for saving checkpoints)
   - Clone your GitHub repository
   - Install dependencies
   - Run the training script
5. Training will use GPU acceleration automatically

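A minimal sketch of those cells, assuming your fork's URL and the `training/train.py` entry point listed in the structure above (`!` and `%cd` are Colab/IPython cell syntax, not plain Python):

```python
# Hypothetical Colab cells — the repo URL is a placeholder for your fork
from google.colab import drive
drive.mount('/content/drive')   # optional: persist checkpoints to Drive

!git clone https://github.com/YOUR_USERNAME/translingo.git
%cd translingo
!pip install -r requirements.txt
!python training/train.py
```
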
### 3. Download Trained Model
After training completes:
1. Download the checkpoint files from Colab (see the snippet below)
2. Place them in your local `checkpoints/` directory
3. The files you need:
   - `best.pt` or `latest.pt` (model checkpoint)
   - `data/processed/tokenizer.model` (tokenizer)

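One way to pull them out of a Colab session — a sketch assuming the checkpoint paths above:

```python
# Run inside Colab after training finishes; paths assume the layout above
from google.colab import files

files.download('checkpoints/best.pt')
files.download('data/processed/tokenizer.model')
```
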
### 4. Run Gradio Demo
```bash
# Activate the virtual environment
source venv/bin/activate

# Run the demo
python frontend/gradio_app.py

# Or run without a public URL
python frontend/gradio_app.py --no-share
```

## 📊 Model Configuration

- **Architecture**: 3-layer Transformer (optimized for faster training)
- **Model dimension**: 256
- **Attention heads**: 4
- **Feed-forward dimension**: 1024
- **Vocabulary size**: 10,000 (shared BPE)
- **Expected BLEU score**: 18-22 (with full training)

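As a quick sanity check of these dimensions, a minimal sketch using the `MultiHeadAttention` class from `model/attention.py` in this commit (assumes the repo root is on `PYTHONPATH`; shapes follow the class docstrings):

```python
import torch
from model.attention import MultiHeadAttention

# d_model=256 with n_heads=4 gives d_k = 256 // 4 = 64 per head
mha = MultiHeadAttention(d_model=256, n_heads=4, dropout=0.1)
x = torch.randn(2, 10, 256)    # [batch_size, seq_len, d_model]
out, attn = mha(x, x, x)       # self-attention
print(out.shape)               # torch.Size([2, 10, 256])
print(attn.shape)              # torch.Size([2, 4, 10, 10])
```
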
## 🔧 Customization Options

### For Faster Testing
Edit `configs/config.yaml`:
```yaml
model:
  n_layers: 2       # Reduce layers
training:
  num_epochs: 5     # Fewer epochs
  batch_size: 16    # Smaller batches if memory is limited
```

97
+ ### For Better Quality
98
+ ```yaml
99
+ model:
100
+ n_layers: 6 # More layers
101
+ d_model: 512 # Larger model
102
+ training:
103
+ num_epochs: 50 # More training
104
+ vocab_size: 20000 # Larger vocabulary
105
+ ```
106
+
## 🐛 Troubleshooting

### CUDA/GPU Issues
- Ensure you're using a GPU runtime in Colab (Runtime → Change runtime type → GPU)
- Check GPU availability with `torch.cuda.is_available()`, as shown below

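A quick check, using standard PyTorch calls:

```python
# Run in a Colab cell or a local Python shell
import torch

print(torch.cuda.is_available())
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))
```
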
### Memory Issues
- Reduce the batch size in `configs/config.yaml`
- Enable gradient accumulation (already configured; the pattern is sketched below)
- Clear the GPU cache periodically (automatic in the training script)

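Gradient accumulation sums gradients over several small batches and applies a single optimizer step, giving the effect of a larger batch in the same memory. A self-contained sketch of the pattern only — the dummy model and data are illustrative, not the actual loop in `training/train.py`:

```python
import torch
import torch.nn as nn

model = nn.Linear(8, 2)                    # stand-in for the Transformer
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()
accum_steps = 4                            # effective batch = batch_size * accum_steps

optimizer.zero_grad()
for step in range(16):
    x, y = torch.randn(4, 8), torch.randn(4, 2)   # stand-in mini-batch
    loss = loss_fn(model(x), y)
    (loss / accum_steps).backward()        # scale so accumulated grads average
    if (step + 1) % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```
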
### Import Errors
- The torchtext warning on macOS is normal and handled
- All other imports should work correctly

## 📝 Additional Features

### While the Model is Training
You can work on these components locally:
- FastAPI backend (`api/` directory)
- React frontend (`frontend/web/` directory)
- Docker deployment (`deployment/` directory)
- Additional visualization tools

### Testing Translation
Once you have a trained model:
```bash
# Interactive translation
python inference/translate.py checkpoints/best.pt data/processed/tokenizer.model
```

## 🎯 Success Metrics

- **Training Loss**: should decrease below 2.0
- **Validation BLEU**: target 18-22 for this configuration
- **Inference Speed**: < 500 ms per sentence on GPU

## 📧 Support

If you encounter any issues:
1. Check the test script: `python test_setup.py`
2. Review the logs in the `logs/` directory
3. Ensure all dependencies are installed correctly

Good luck with your translation system! 🌍🔤
data/download.py CHANGED
@@ -1,22 +1,27 @@
 import os
 import torch
+import sentencepiece as spm
+from typing import List, Tuple, Optional, Dict
+import yaml
+import logging
+from tqdm import tqdm
+import urllib.request
+
+try:
+    from datasets import load_dataset
+    HUGGINGFACE_AVAILABLE = True
+except ImportError:
+    HUGGINGFACE_AVAILABLE = False
+    print("Warning: datasets library not available. Install with: pip install datasets")
+
 try:
     from torchtext.datasets import Multi30k
     from torchtext.data.utils import get_tokenizer
     from torchtext.vocab import build_vocab_from_iterator
     TORCHTEXT_AVAILABLE = True
 except Exception as e:
-    print(f"Warning: torchtext import failed: {e}")
-    print("Will use manual download method")
     TORCHTEXT_AVAILABLE = False
-    import sentencepiece as spm
-    from typing import List, Tuple, Optional, Dict
-    import yaml
-    import logging
-    from tqdm import tqdm
-    import urllib.request
-    import tarfile
-    import zipfile
+    print(f"Warning: torchtext import failed: {e}")
 
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
@@ -32,37 +37,68 @@ class DataDownloader:
         os.makedirs(os.path.join(self.data_dir, 'processed'), exist_ok=True)
 
     def download_multi30k(self) -> Tuple[List[Tuple[str, str]], List[Tuple[str, str]], List[Tuple[str, str]]]:
-        """Download Multi30k dataset"""
+        """Download Multi30k dataset - tries multiple methods"""
         logger.info("Downloading Multi30k dataset...")
 
-        try:
-            # Try using torchtext first if available
-            if TORCHTEXT_AVAILABLE:
+        # Method 1: Try Hugging Face first (most reliable)
+        if HUGGINGFACE_AVAILABLE:
+            try:
+                logger.info("Attempting download from Hugging Face...")
+                return self._download_from_huggingface()
+            except Exception as e:
+                logger.warning(f"Hugging Face download failed: {e}")
+
+        # Method 2: Try torchtext if available
+        if TORCHTEXT_AVAILABLE:
+            try:
+                logger.info("Attempting download with torchtext...")
                 train_data = list(Multi30k(split='train', language_pair=('de', 'en')))
                 valid_data = list(Multi30k(split='valid', language_pair=('de', 'en')))
                 test_data = list(Multi30k(split='test', language_pair=('de', 'en')))
-            else:
-                raise Exception("torchtext not available")
-
-            logger.info(f"Train samples: {len(train_data)}")
-            logger.info(f"Valid samples: {len(valid_data)}")
-            logger.info(f"Test samples: {len(test_data)}")
-
-            # Save to files for later use
-            self._save_data_to_files(train_data, valid_data, test_data)
-
-            return train_data, valid_data, test_data
-
-        except Exception as e:
-            logger.warning(f"Torchtext download failed: {e}")
-            logger.info("Attempting alternative download method...")
-
-            # Alternative: Download from direct URLs
-            return self._download_multi30k_manual()
+
+                logger.info(f"Train samples: {len(train_data)}")
+                logger.info(f"Valid samples: {len(valid_data)}")
+                logger.info(f"Test samples: {len(test_data)}")
+
+                self._save_data_to_files(train_data, valid_data, test_data)
+                return train_data, valid_data, test_data
+            except Exception as e:
+                logger.warning(f"Torchtext download failed: {e}")
+
+        # Method 3: Try manual download from GitHub
+        logger.info("Attempting manual download from GitHub...")
+        return self._download_multi30k_manual()
+
+    def _download_from_huggingface(self) -> Tuple[List[Tuple[str, str]], List[Tuple[str, str]], List[Tuple[str, str]]]:
+        """Download Multi30k from Hugging Face datasets hub"""
+        logger.info("Downloading from Hugging Face datasets hub...")
+
+        # Load dataset
+        dataset = load_dataset("bentrevett/multi30k")
+
+        # Convert to expected format: List[Tuple[str, str]]
+        train_data = [(item['de'], item['en']) for item in dataset['train']]
+        valid_data = [(item['de'], item['en']) for item in dataset['validation']]
+        test_data = [(item['de'], item['en']) for item in dataset['test']]
+
+        logger.info(f"✅ Downloaded from Hugging Face:")
+        logger.info(f"   Train samples: {len(train_data)}")
+        logger.info(f"   Valid samples: {len(valid_data)}")
+        logger.info(f"   Test samples: {len(test_data)}")
+
+        # Save to files for consistency with other methods
+        self._save_data_to_files(train_data, valid_data, test_data)
+
+        return train_data, valid_data, test_data
 
     def _download_multi30k_manual(self) -> Tuple[List[Tuple[str, str]], List[Tuple[str, str]], List[Tuple[str, str]]]:
-        """Manual download of Multi30k dataset"""
-        base_url = "https://raw.githubusercontent.com/multi30k/dataset/master/data/task1/raw/"
+        """Manual download of Multi30k dataset from GitHub"""
+        # Try multiple mirror URLs
+        base_urls = [
+            "https://raw.githubusercontent.com/multi30k/dataset/master/data/task1/raw/",
+            "https://github.com/multi30k/dataset/raw/master/data/task1/raw/",
+            "https://raw.githubusercontent.com/bentrevett/pytorch-seq2seq/master/assets/data/"
+        ]
 
         files_to_download = {
             'train.de': 'train.de',
@@ -73,17 +109,28 @@ class DataDownloader:
             'test_2016_flickr.en': 'test.en'
         }
 
-        for remote_file, local_file in files_to_download.items():
-            url = base_url + remote_file
-            output_path = os.path.join(self.data_dir, 'raw', local_file)
-
-            if not os.path.exists(output_path):
-                logger.info(f"Downloading {remote_file}...")
-                try:
-                    urllib.request.urlretrieve(url, output_path)
-                except Exception as e:
-                    logger.error(f"Failed to download {remote_file}: {e}")
-                    return [], [], []
+        success = False
+        for base_url in base_urls:
+            try:
+                for remote_file, local_file in files_to_download.items():
+                    url = base_url + remote_file
+                    output_path = os.path.join(self.data_dir, 'raw', local_file)
+
+                    if not os.path.exists(output_path):
+                        logger.info(f"Downloading {remote_file} from {base_url}...")
+                        urllib.request.urlretrieve(url, output_path)
+
+                success = True
+                logger.info(f"✅ Successfully downloaded from {base_url}")
+                break
+            except Exception as e:
+                logger.warning(f"Failed to download from {base_url}: {e}")
+                continue
+
+        if not success:
+            logger.error("❌ Failed to download from all sources")
+            logger.info("Please install datasets library: pip install datasets")
+            return [], [], []
 
         # Load data from files
         train_data = self._load_parallel_data('train')
@@ -159,12 +206,13 @@ class DataDownloader:
             pad_piece='<pad>',
             unk_piece='<unk>',
             bos_piece='<bos>',
-            eos_piece='<eos>'
+            eos_piece='<eos>',
+            character_coverage=1.0  # Important for handling all characters
         )
 
         # Clean up
         os.remove(temp_file)
-        logger.info(f"SentencePiece model saved to {model_path}")
+        logger.info(f"SentencePiece model saved to {model_path}")
 
     def prepare_tokenizer(self, train_data: List[Tuple[str, str]]) -> None:
         """Prepare tokenizer from training data"""
@@ -182,12 +230,19 @@ class DataDownloader:
         self.train_sentencepiece(all_texts, "tokenizer", vocab_size=self.config['model']['vocab_size'])
 
 if __name__ == "__main__":
+    # Install datasets if not available
+    if not HUGGINGFACE_AVAILABLE:
+        import subprocess
+        print("Installing datasets library...")
+        subprocess.run(["pip", "install", "datasets", "-q"])
+        from datasets import load_dataset
+
     downloader = DataDownloader()
     train_data, valid_data, test_data = downloader.download_multi30k()
 
     if train_data:
         # Train tokenizer
         downloader.prepare_tokenizer(train_data)
-        logger.info("Data download and tokenizer training completed!")
+        logger.info("Data download and tokenizer training completed!")
     else:
-        logger.error("Failed to download data.")
+        logger.error("Failed to download data.")
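For a standalone check of the new Hugging Face path, a minimal sketch — the dataset id and split names are taken from the diff above:

```python
from datasets import load_dataset

ds = load_dataset("bentrevett/multi30k")
print(ds)                # DatasetDict with 'train', 'validation', 'test' splits
print(ds["train"][0])    # one parallel pair with 'de' and 'en' fields
```
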
model/attention.py CHANGED
@@ -5,7 +5,7 @@ import math
 from typing import Optional, Tuple
 
 class ScaledDotProductAttention(nn.Module):
-    """Scaled Dot-Product Attention mechanism"""
+    """Scaled Dot-Product Attention mechanism with numerical stability"""
 
     def __init__(self, temperature: float = 1.0, dropout: float = 0.1):
         super().__init__()
@@ -25,15 +25,26 @@ class ScaledDotProductAttention(nn.Module):
             output: Attention output [batch_size, n_heads, seq_len, d_k]
             attention: Attention weights [batch_size, n_heads, seq_len, seq_len]
         """
-        # Calculate attention scores
-        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.temperature * math.sqrt(q.size(-1)))
+        # Calculate attention scores with temperature scaling
+        d_k = q.size(-1)
+        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.temperature * math.sqrt(d_k))
 
-        # Apply mask if provided
+        # Apply mask if provided - using fp16-safe value
         if mask is not None:
-            scores = scores.masked_fill(mask == 0, -1e9)
-
-        # Apply softmax
+            # Determine safe mask value based on dtype
+            if scores.dtype == torch.float16:
+                mask_value = -65504.0  # Max negative value for fp16
+            else:
+                mask_value = -1e9  # Original value for fp32
+
+            # Use torch.finfo for more robust dtype handling
+            mask_value = torch.finfo(scores.dtype).min if hasattr(torch, 'finfo') else mask_value
+            scores = scores.masked_fill(mask == 0, mask_value)
+
+        # Apply softmax with numerical stability
         attention = F.softmax(scores, dim=-1)
+
+        # Apply dropout
         attention = self.dropout(attention)
 
         # Apply attention to values
@@ -43,21 +54,26 @@
 
 
 class MultiHeadAttention(nn.Module):
-    """Multi-Head Attention mechanism"""
+    """Multi-Head Attention mechanism with improved stability"""
 
-    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
+    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1,
+                 use_bias: bool = True, pre_norm: bool = False):
         super().__init__()
         assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
 
         self.d_model = d_model
         self.n_heads = n_heads
         self.d_k = d_model // n_heads
+        self.pre_norm = pre_norm
+
+        # Linear projections with optional bias
+        self.W_q = nn.Linear(d_model, d_model, bias=use_bias)
+        self.W_k = nn.Linear(d_model, d_model, bias=use_bias)
+        self.W_v = nn.Linear(d_model, d_model, bias=use_bias)
+        self.W_o = nn.Linear(d_model, d_model, bias=use_bias)
 
-        # Linear projections
-        self.W_q = nn.Linear(d_model, d_model)
-        self.W_k = nn.Linear(d_model, d_model)
-        self.W_v = nn.Linear(d_model, d_model)
-        self.W_o = nn.Linear(d_model, d_model)
+        # Initialize weights using Xavier uniform
+        self._init_weights()
 
         # Attention
         self.attention = ScaledDotProductAttention(temperature=1.0, dropout=dropout)
@@ -66,8 +82,15 @@
         self.dropout = nn.Dropout(dropout)
 
         # Layer normalization
-        self.layer_norm = nn.LayerNorm(d_model)
-
+        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
+
+    def _init_weights(self):
+        """Initialize weights with Xavier uniform distribution"""
+        for module in [self.W_q, self.W_k, self.W_v, self.W_o]:
+            nn.init.xavier_uniform_(module.weight)
+            if module.bias is not None:
+                nn.init.zeros_(module.bias)
+
     def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                 mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
         """
@@ -81,8 +104,13 @@
             output: Multi-head attention output [batch_size, seq_len, d_model]
             attention: Attention weights [batch_size, n_heads, seq_len, seq_len]
         """
-        batch_size = query.size(0)
-        seq_len = query.size(1)
+        batch_size, seq_len, _ = query.size()
+
+        # Pre-norm variant (if enabled)
+        if self.pre_norm:
+            query = self.layer_norm(query)
+            key = self.layer_norm(key)
+            value = self.layer_norm(value)
 
         # Store residual
         residual = query
@@ -104,8 +132,10 @@
         output = self.W_o(attn_output)
         output = self.dropout(output)
 
-        # Add and normalize
-        output = self.layer_norm(output + residual)
+        # Add residual and normalize
+        output = output + residual
+        if not self.pre_norm:
+            output = self.layer_norm(output)
 
         return output, attention_weights
 
@@ -121,7 +151,9 @@ def create_padding_mask(seq: torch.Tensor, pad_idx: int = 0) -> torch.Tensor:
     Returns:
         mask: Padding mask [batch_size, 1, 1, seq_len]
     """
-    return (seq != pad_idx).unsqueeze(1).unsqueeze(2)
+    # Create boolean mask
+    mask = (seq != pad_idx).unsqueeze(1).unsqueeze(2)
+    return mask.to(torch.bool)
 
 
 def create_look_ahead_mask(size: int, device: torch.device) -> torch.Tensor:
@@ -135,8 +167,11 @@ def create_look_ahead_mask(size: int, device: torch.device) -> torch.Tensor:
    Returns:
        mask: Look-ahead mask [1, 1, size, size]
    """
-    mask = torch.triu(torch.ones(size, size, device=device), diagonal=1)
-    return (1 - mask).unsqueeze(0).unsqueeze(0)
+    # Create upper triangular matrix
+    mask = torch.triu(torch.ones(size, size, device=device, dtype=torch.bool), diagonal=1)
+    # Invert it (1 for allowed positions, 0 for masked)
+    mask = ~mask
+    return mask.unsqueeze(0).unsqueeze(0)
 
 
 def create_masks(src: torch.Tensor, tgt: torch.Tensor,
@@ -157,14 +192,62 @@
     # Source mask (padding only)
     src_mask = create_padding_mask(src, pad_idx)
 
-    # Target mask (padding + look-ahead)
+    # Target padding mask
     tgt_pad_mask = create_padding_mask(tgt, pad_idx)
+
+    # Target look-ahead mask
     tgt_len = tgt.size(1)
     tgt_look_ahead_mask = create_look_ahead_mask(tgt_len, tgt.device)
-    tgt_mask = tgt_pad_mask.float() * tgt_look_ahead_mask.float()
-    tgt_mask = tgt_mask.bool()
 
-    # Memory mask (same as source mask but different shape)
+    # Combine padding and look-ahead masks for target
+    # Both masks should be True where attention is allowed
+    tgt_mask = tgt_pad_mask & tgt_look_ahead_mask
+
+    # Memory mask (same as source mask)
     memory_mask = src_mask
 
     return src_mask, tgt_mask, memory_mask
+
+
+# Optional: Flash Attention wrapper (if available)
+try:
+    from torch.nn.functional import scaled_dot_product_attention
+    FLASH_ATTENTION_AVAILABLE = True
+except ImportError:
+    FLASH_ATTENTION_AVAILABLE = False
+
+class FlashAttention(nn.Module):
+    """Flash Attention wrapper for better performance (if available)"""
+
+    def __init__(self, dropout: float = 0.1):
+        super().__init__()
+        self.dropout = dropout
+
+    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
+                mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, None]:
+        """
+        Uses PyTorch's scaled_dot_product_attention if available (includes Flash Attention)
+        """
+        if FLASH_ATTENTION_AVAILABLE and mask is None:
+            # Use efficient implementation when no mask
+            output = scaled_dot_product_attention(
+                q, k, v,
+                dropout_p=self.dropout if self.training else 0.0,
+                is_causal=False
+            )
+            return output, None
+        else:
+            # Fallback to standard implementation
+            d_k = q.size(-1)
+            scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)

+            if mask is not None:
+                mask_value = torch.finfo(scores.dtype).min
+                scores = scores.masked_fill(mask == 0, mask_value)
+
+            attention = F.softmax(scores, dim=-1)
+            if self.training and self.dropout > 0:
+                attention = F.dropout(attention, p=self.dropout)
+
+            output = torch.matmul(attention, v)
+            return output, attention
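The new boolean mask helpers can be exercised in isolation; a minimal sketch (assumes the repo root is on `PYTHONPATH`; `pad_idx=0` is the default in `create_padding_mask`'s signature):

```python
import torch
from model.attention import create_padding_mask, create_look_ahead_mask

src = torch.tensor([[5, 6, 7, 0]])            # last position is padding
src_mask = create_padding_mask(src, pad_idx=0)
print(src_mask.shape, src_mask.dtype)         # torch.Size([1, 1, 1, 4]) torch.bool

look_ahead = create_look_ahead_mask(4, torch.device("cpu"))
print(look_ahead[0, 0].int())                 # lower-triangular 1s = allowed positions

# Broadcasting [1, 1, 1, 4] & [1, 1, 4, 4] combines padding and causality,
# matching what create_masks does for the target
tgt_mask = src_mask & look_ahead
print(tgt_mask.shape)                         # torch.Size([1, 1, 4, 4])
```
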
requirements.txt CHANGED
@@ -17,3 +17,4 @@ aiofiles>=23.1.0
 pytest>=7.3.0
 black>=23.3.0
 flake8>=6.0.0
+datasets>=4.4.1