import math
from typing import Optional

import torch
import torch.nn as nn


class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding for transformer models."""

    def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.d_model = d_model

        # Create the positional encoding matrix once, up to max_len positions
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()

        # div_term holds the sin/cos frequencies: 10000^(-2i / d_model)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             -(math.log(10000.0) / d_model))

        # Apply sin to even indices, cos to odd indices
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        # Add batch dimension and register as a non-trainable buffer
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input tensor [batch_size, seq_len, d_model]
        Returns:
            Tensor with positional encoding added
        """
        # Add the fixed positional encoding for the first seq_len positions
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)
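
# The buffer built above follows the standard sinusoidal scheme from
# "Attention Is All You Need":
#   PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
#   PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))
# Illustrative shape check (a sketch; the sizes 2, 10, 16 are arbitrary and
# not part of this module):
#   >>> enc = PositionalEncoding(d_model=16)
#   >>> enc(torch.zeros(2, 10, 16)).shape
#   torch.Size([2, 10, 16])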

class TokenEmbedding(nn.Module):
    """Token embedding scaled by sqrt(d_model)."""

    def __init__(self, vocab_size: int, d_model: int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.d_model = d_model
        self.scale = math.sqrt(d_model)

        # Initialize embeddings with std = d_model^-0.5
        nn.init.normal_(self.embedding.weight, mean=0, std=d_model**-0.5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input token indices [batch_size, seq_len]
        Returns:
            Scaled embeddings [batch_size, seq_len, d_model]
        """
        return self.embedding(x) * self.scale
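
# Quick illustration of the sqrt(d_model) scaling (a sketch; the sizes below
# are arbitrary assumptions, not part of this module):
#   >>> emb = TokenEmbedding(vocab_size=100, d_model=64)
#   >>> tokens = torch.randint(0, 100, (2, 5))
#   >>> emb(tokens).shape
#   torch.Size([2, 5, 64])
# The output equals emb.embedding(tokens) * math.sqrt(64); a common rationale
# is that scaling keeps the token embeddings on a comparable magnitude to the
# positional encoding added afterwards.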

class TransformerEmbedding(nn.Module):
    """Combined token and positional embedding for a transformer."""

    def __init__(self, vocab_size: int, d_model: int, max_len: int = 5000,
                 dropout: float = 0.1, scale_embedding: bool = True,
                 use_learned_pos: bool = False):
        super().__init__()
        # Token embedding: scaled by sqrt(d_model) when scale_embedding is True,
        # otherwise a plain (unscaled) nn.Embedding
        if scale_embedding:
            self.token_embedding = TokenEmbedding(vocab_size, d_model)
        else:
            self.token_embedding = nn.Embedding(vocab_size, d_model)

        # Sinusoidal positional encoding (also applies dropout)
        self.positional_encoding = PositionalEncoding(d_model, max_len, dropout)

        # Optional learned positional embeddings (alternative to sinusoidal)
        self.use_learned_pos = use_learned_pos
        if self.use_learned_pos:
            self.pos_embedding = nn.Embedding(max_len, d_model)

    def forward(self, x: torch.Tensor, pos: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Args:
            x: Input token indices [batch_size, seq_len]
            pos: Optional position indices for learned positional embeddings
        Returns:
            Embedded and encoded tensor [batch_size, seq_len, d_model]
        """
        # Get token embeddings (TokenEmbedding applies sqrt(d_model) scaling
        # internally; the plain nn.Embedding path is left unscaled)
        token_emb = self.token_embedding(x)

        # Add positional information
        if self.use_learned_pos and pos is not None:
            pos_emb = self.pos_embedding(pos)
            output = token_emb + pos_emb
            # Reuse the sinusoidal module's dropout so both paths regularize alike
            output = self.positional_encoding.dropout(output)
        else:
            output = self.positional_encoding(token_emb)
        return output
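
# Example wiring of the combined embedding (illustrative sketch; vocab_size,
# d_model, and the input shape are arbitrary assumptions):
#   >>> embed = TransformerEmbedding(vocab_size=1000, d_model=64)
#   >>> tokens = torch.randint(0, 1000, (2, 10))
#   >>> embed(tokens).shape
#   torch.Size([2, 10, 64])
# With use_learned_pos=True, pass explicit position indices:
#   >>> pos = torch.arange(10).unsqueeze(0).expand(2, -1)
#   >>> TransformerEmbedding(1000, 64, use_learned_pos=True)(tokens, pos).shape
#   torch.Size([2, 10, 64])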

class LearnedPositionalEmbedding(nn.Module):
    """Learned positional embeddings (alternative to sinusoidal)."""

    def __init__(self, max_len: int, d_model: int, dropout: float = 0.1):
        super().__init__()
        self.embedding = nn.Embedding(max_len, d_model)
        self.dropout = nn.Dropout(dropout)

        # Initialize with std = d_model^-0.5
        nn.init.normal_(self.embedding.weight, mean=0, std=d_model**-0.5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input tensor [batch_size, seq_len, d_model]
        Returns:
            Tensor with learned positional embeddings added
        """
        batch_size, seq_len = x.size(0), x.size(1)

        # Position indices 0..seq_len-1, broadcast over the batch
        positions = torch.arange(seq_len, device=x.device).unsqueeze(0).expand(batch_size, -1)

        # Add positional embeddings
        x = x + self.embedding(positions)
        return self.dropout(x)
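

if __name__ == "__main__":
    # Minimal smoke test (illustrative only; the sizes here are arbitrary
    # assumptions, not values taken from the module definitions above)
    torch.manual_seed(0)

    embed = TransformerEmbedding(vocab_size=1000, d_model=64, max_len=128)
    tokens = torch.randint(0, 1000, (2, 16))
    out = embed(tokens)
    print(out.shape)  # expected: torch.Size([2, 16, 64])

    learned = LearnedPositionalEmbedding(max_len=128, d_model=64)
    print(learned(out).shape)  # expected: torch.Size([2, 16, 64])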