# translingo/model/embeddings.py
import math
from typing import Optional

import torch
import torch.nn as nn
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding for transformer models."""

    def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.d_model = d_model

        # Create positional encoding matrix:
        #   PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
        #   PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()

        # Create div_term for sin/cos frequencies
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             -(math.log(10000.0) / d_model))

        # Apply sin to even indices
        pe[:, 0::2] = torch.sin(position * div_term)
        # Apply cos to odd indices
        pe[:, 1::2] = torch.cos(position * div_term)

        # Add batch dimension and register as a buffer (not a learned parameter)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input tensor [batch_size, seq_len, d_model]
        Returns:
            Tensor with positional encoding added
        """
        # Add positional encoding for the first seq_len positions
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)
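# Illustrative usage sketch (not part of the module; d_model/shape values are assumed):
#
#     pos_enc = PositionalEncoding(d_model=512)
#     x = torch.zeros(2, 10, 512)   # [batch_size, seq_len, d_model]
#     y = pos_enc(x)                # same shape; sinusoidal positions added, dropout applied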
class TokenEmbedding(nn.Module):
    """Token embedding with scaling"""

    def __init__(self, vocab_size: int, d_model: int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.d_model = d_model
        self.scale = math.sqrt(d_model)

        # Initialize embeddings
        nn.init.normal_(self.embedding.weight, mean=0, std=d_model**-0.5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input token indices [batch_size, seq_len]
        Returns:
            Scaled embeddings [batch_size, seq_len, d_model]
        """
        return self.embedding(x) * self.scale
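# Illustrative usage sketch (hypothetical vocab_size/d_model values):
#
#     tok_emb = TokenEmbedding(vocab_size=32000, d_model=512)
#     ids = torch.randint(0, 32000, (2, 10))   # [batch_size, seq_len]
#     e = tok_emb(ids)                          # [2, 10, 512], scaled by sqrt(512)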
class TransformerEmbedding(nn.Module):
    """Combined token and positional embedding for transformer"""

    def __init__(self, vocab_size: int, d_model: int, max_len: int = 5000,
                 dropout: float = 0.1, scale_embedding: bool = True,
                 use_learned_pos: bool = False):
        super().__init__()
        # Token embedding (scaled by sqrt(d_model) when scale_embedding is True)
        self.token_embedding = (TokenEmbedding(vocab_size, d_model)
                                if scale_embedding else nn.Embedding(vocab_size, d_model))
        # Sinusoidal positional encoding
        self.positional_encoding = PositionalEncoding(d_model, max_len, dropout)
        # Optional learned positional embeddings (alternative to sinusoidal)
        self.use_learned_pos = use_learned_pos
        if self.use_learned_pos:
            self.pos_embedding = nn.Embedding(max_len, d_model)

    def forward(self, x: torch.Tensor, pos: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Args:
            x: Input token indices [batch_size, seq_len]
            pos: Optional position indices for learned positional embeddings
        Returns:
            Embedded and encoded tensor [batch_size, seq_len, d_model]
        """
        # Get token embeddings (TokenEmbedding scales internally; the plain
        # nn.Embedding path stays unscaled, matching scale_embedding=False)
        token_emb = self.token_embedding(x)

        # Add positional information
        if self.use_learned_pos and pos is not None:
            pos_emb = self.pos_embedding(pos)
            output = token_emb + pos_emb
            output = self.positional_encoding.dropout(output)
        else:
            output = self.positional_encoding(token_emb)
        return output
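# Illustrative usage sketch (hypothetical hyperparameters; sinusoidal positions by default):
#
#     embed = TransformerEmbedding(vocab_size=32000, d_model=512, dropout=0.1)
#     ids = torch.randint(0, 32000, (2, 10))   # [batch_size, seq_len]
#     h = embed(ids)                            # [2, 10, 512], ready for the encoder/decoder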
class LearnedPositionalEmbedding(nn.Module):
    """Learned positional embeddings (alternative to sinusoidal)"""

    def __init__(self, max_len: int, d_model: int, dropout: float = 0.1):
        super().__init__()
        self.embedding = nn.Embedding(max_len, d_model)
        self.dropout = nn.Dropout(dropout)

        # Initialize
        nn.init.normal_(self.embedding.weight, mean=0, std=d_model**-0.5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input tensor [batch_size, seq_len, d_model]
        Returns:
            Tensor with learned positional embeddings added
        """
        batch_size, seq_len = x.size(0), x.size(1)

        # Create position indices
        positions = torch.arange(seq_len, device=x.device).unsqueeze(0).expand(batch_size, -1)

        # Add positional embeddings
        x = x + self.embedding(positions)
        return self.dropout(x)
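if __name__ == "__main__":
    # Minimal shape-check sketch, assuming small hypothetical hyperparameters.
    # It only verifies that the modules compose and preserve the expected shapes.
    vocab_size, d_model, batch_size, seq_len = 100, 16, 2, 7

    ids = torch.randint(0, vocab_size, (batch_size, seq_len))

    embed = TransformerEmbedding(vocab_size, d_model, max_len=64, dropout=0.0)
    out = embed(ids)
    assert out.shape == (batch_size, seq_len, d_model)

    learned = LearnedPositionalEmbedding(max_len=64, d_model=d_model, dropout=0.0)
    out2 = learned(out)
    assert out2.shape == (batch_size, seq_len, d_model)

    print("Embedding shape checks passed:", tuple(out.shape))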