import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple

from model.attention import MultiHeadAttention


class FeedForward(nn.Module):
    """Position-wise feed-forward network"""

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.ReLU()

        # Layer normalization
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input tensor [batch_size, seq_len, d_model]

        Returns:
            Output tensor [batch_size, seq_len, d_model]
        """
        # Store residual
        residual = x

        # Feed-forward network
        x = self.linear1(x)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.linear2(x)
        x = self.dropout(x)

        # Add and normalize
        x = self.layer_norm(x + residual)

        return x


class EncoderLayer(nn.Module):
    """Single encoder layer"""

    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()

        # Multi-head attention
        self.self_attention = MultiHeadAttention(d_model, n_heads, dropout)

        # Feed-forward network
        self.feed_forward = FeedForward(d_model, d_ff, dropout)

    def forward(self, x: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Args:
            x: Input tensor [batch_size, seq_len, d_model]
            mask: Attention mask

        Returns:
            Output tensor [batch_size, seq_len, d_model]
        """
        # Self-attention
        x, _ = self.self_attention(x, x, x, mask)

        # Feed-forward
        x = self.feed_forward(x)

        return x


class DecoderLayer(nn.Module):
    """Single decoder layer"""

    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()

        # Masked self-attention
        self.self_attention = MultiHeadAttention(d_model, n_heads, dropout)

        # Cross-attention
        self.cross_attention = MultiHeadAttention(d_model, n_heads, dropout)

        # Feed-forward network
        self.feed_forward = FeedForward(d_model, d_ff, dropout)

    def forward(self, x: torch.Tensor, memory: torch.Tensor,
                tgt_mask: Optional[torch.Tensor] = None,
                memory_mask: Optional[torch.Tensor] = None
                ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            x: Input tensor [batch_size, tgt_len, d_model]
            memory: Encoder output [batch_size, src_len, d_model]
            tgt_mask: Target attention mask
            memory_mask: Memory attention mask

        Returns:
            output: Output tensor [batch_size, tgt_len, d_model]
            self_attn: Self-attention weights
            cross_attn: Cross-attention weights
        """
        # Masked self-attention
        x, self_attn = self.self_attention(x, x, x, tgt_mask)

        # Cross-attention
        x, cross_attn = self.cross_attention(x, memory, memory, memory_mask)

        # Feed-forward
        x = self.feed_forward(x)

        return x, self_attn, cross_attn


class Encoder(nn.Module):
    """Transformer encoder"""

    def __init__(self, n_layers: int, d_model: int, n_heads: int, d_ff: int,
                 dropout: float = 0.1):
        super().__init__()

        # Stack of encoder layers
        self.layers = nn.ModuleList([
            EncoderLayer(d_model, n_heads, d_ff, dropout)
            for _ in range(n_layers)
        ])

        # Final layer normalization
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Args:
            x: Input tensor [batch_size, seq_len, d_model]
            mask: Attention mask

        Returns:
            Output tensor [batch_size, seq_len, d_model]
        """
        # Pass through encoder layers
        for layer in self.layers:
            x = layer(x, mask)

        # Final layer normalization
        x = self.layer_norm(x)

        return x


class Decoder(nn.Module):
    """Transformer decoder"""

    def __init__(self, n_layers: int, d_model: int, n_heads: int, d_ff: int,
                 dropout: float = 0.1):
        super().__init__()

        # Stack of decoder layers
        self.layers = nn.ModuleList([
            DecoderLayer(d_model, n_heads, d_ff, dropout)
            for _ in range(n_layers)
        ])

        # Final layer normalization
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor, memory: torch.Tensor,
                tgt_mask: Optional[torch.Tensor] = None,
                memory_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Args:
            x: Input tensor [batch_size, tgt_len, d_model]
            memory: Encoder output [batch_size, src_len, d_model]
            tgt_mask: Target attention mask
            memory_mask: Memory attention mask

        Returns:
            Output tensor [batch_size, tgt_len, d_model]
        """
        # Pass through decoder layers
        for layer in self.layers:
            x, _, _ = layer(x, memory, tgt_mask, memory_mask)

        # Final layer normalization
        x = self.layer_norm(x)

        return x
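

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the module API). It
# assumes model.attention.MultiHeadAttention applies its own residual
# connection and layer normalization (mirroring FeedForward above), returns an
# (output, attention_weights) tuple as the layers above expect, and accepts a
# boolean causal mask broadcastable over [batch_size, n_heads, tgt_len,
# tgt_len]; the exact mask shape and dtype depend on that implementation.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    batch_size, src_len, tgt_len, d_model = 2, 10, 7, 64

    encoder = Encoder(n_layers=2, d_model=d_model, n_heads=8, d_ff=256, dropout=0.1)
    decoder = Decoder(n_layers=2, d_model=d_model, n_heads=8, d_ff=256, dropout=0.1)

    # Inputs are assumed to already be embedded (token + positional encodings).
    src = torch.randn(batch_size, src_len, d_model)
    tgt = torch.randn(batch_size, tgt_len, d_model)

    # Causal mask: each target position may attend only to itself and earlier
    # positions (assumed convention; adapt to MultiHeadAttention's mask format).
    tgt_mask = torch.tril(torch.ones(tgt_len, tgt_len, dtype=torch.bool))

    memory = encoder(src)                             # [batch_size, src_len, d_model]
    output = decoder(tgt, memory, tgt_mask=tgt_mask)  # [batch_size, tgt_len, d_model]
    print(memory.shape, output.shape)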