# translingo/model/attention.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional, Tuple
class ScaledDotProductAttention(nn.Module):
"""Scaled Dot-Product Attention mechanism with numerical stability"""
def __init__(self, temperature: float = 1.0, dropout: float = 0.1):
super().__init__()
self.temperature = temperature
self.dropout = nn.Dropout(dropout)
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
q: Query tensor [batch_size, n_heads, seq_len, d_k]
k: Key tensor [batch_size, n_heads, seq_len, d_k]
v: Value tensor [batch_size, n_heads, seq_len, d_k]
mask: Mask tensor [batch_size, 1, seq_len, seq_len] or [batch_size, 1, 1, seq_len]
Returns:
output: Attention output [batch_size, n_heads, seq_len, d_k]
attention: Attention weights [batch_size, n_heads, seq_len, seq_len]
"""
# Calculate attention scores with temperature scaling
d_k = q.size(-1)
scores = torch.matmul(q, k.transpose(-2, -1)) / (self.temperature * math.sqrt(d_k))
        # Apply the mask (if provided) with a dtype-safe fill value:
        # torch.finfo(...).min is the most negative finite value for the current
        # dtype (about -65504 for fp16, about -3.4e38 for fp32), so masked
        # positions get effectively zero probability after softmax without
        # overflowing in half precision.
        if mask is not None:
            mask_value = torch.finfo(scores.dtype).min
            scores = scores.masked_fill(mask == 0, mask_value)
        # Softmax over the key dimension to get attention weights (F.softmax is numerically stable internally)
attention = F.softmax(scores, dim=-1)
# Apply dropout
attention = self.dropout(attention)
# Apply attention to values
output = torch.matmul(attention, v)
return output, attention
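# Illustrative usage sketch (an addition for clarity, not part of the original
# module): a quick shape check for ScaledDotProductAttention. The helper name
# and the batch/head/length/d_k sizes below are arbitrary placeholders.
def _demo_scaled_dot_product_attention():
    attn = ScaledDotProductAttention(dropout=0.0)
    q = k = v = torch.randn(2, 8, 16, 64)       # [batch_size, n_heads, seq_len, d_k]
    output, weights = attn(q, k, v)
    assert output.shape == (2, 8, 16, 64)       # same shape as the value tensor
    assert weights.shape == (2, 8, 16, 16)      # [batch_size, n_heads, seq_len, seq_len]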
class MultiHeadAttention(nn.Module):
"""Multi-Head Attention mechanism with improved stability"""
def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1,
use_bias: bool = True, pre_norm: bool = False):
super().__init__()
assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
self.d_model = d_model
self.n_heads = n_heads
self.d_k = d_model // n_heads
self.pre_norm = pre_norm
# Linear projections with optional bias
self.W_q = nn.Linear(d_model, d_model, bias=use_bias)
self.W_k = nn.Linear(d_model, d_model, bias=use_bias)
self.W_v = nn.Linear(d_model, d_model, bias=use_bias)
self.W_o = nn.Linear(d_model, d_model, bias=use_bias)
# Initialize weights using Xavier uniform
self._init_weights()
# Attention
self.attention = ScaledDotProductAttention(temperature=1.0, dropout=dropout)
# Dropout
self.dropout = nn.Dropout(dropout)
# Layer normalization
self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
def _init_weights(self):
"""Initialize weights with Xavier uniform distribution"""
for module in [self.W_q, self.W_k, self.W_v, self.W_o]:
nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.zeros_(module.bias)
def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
query: Query tensor [batch_size, seq_len_q, d_model]
key: Key tensor [batch_size, seq_len_k, d_model]
value: Value tensor [batch_size, seq_len_v, d_model]
mask: Mask tensor
Returns:
output: Multi-head attention output [batch_size, seq_len_q, d_model]
attention: Attention weights [batch_size, n_heads, seq_len_q, seq_len_k]
"""
batch_size = query.size(0)
seq_len_q = query.size(1) # Query sequence length
seq_len_k = key.size(1) # Key sequence length (can be different!)
seq_len_v = value.size(1) # Value sequence length (same as key)
        # Store the residual before any normalization so the skip connection
        # carries the un-normalized input: output = x + sublayer(LayerNorm(x))
        residual = query
        # Pre-norm variant (if enabled): normalize inputs before the projections
        if self.pre_norm:
            query = self.layer_norm(query)
            key = self.layer_norm(key)
            value = self.layer_norm(value)
        # Linear projections, then split into heads; query and key/value sequence lengths may differ
Q = self.W_q(query).view(batch_size, seq_len_q, self.n_heads, self.d_k).transpose(1, 2)
K = self.W_k(key).view(batch_size, seq_len_k, self.n_heads, self.d_k).transpose(1, 2)
V = self.W_v(value).view(batch_size, seq_len_v, self.n_heads, self.d_k).transpose(1, 2)
# Apply attention
attn_output, attention_weights = self.attention(Q, K, V, mask)
# Concatenate heads - use seq_len_q for output
attn_output = attn_output.transpose(1, 2).contiguous().view(
batch_size, seq_len_q, self.d_model
)
# Final linear projection
output = self.W_o(attn_output)
output = self.dropout(output)
# Add residual and normalize
output = output + residual
if not self.pre_norm:
output = self.layer_norm(output)
return output, attention_weights
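# Illustrative usage sketch (an addition, not part of the original module):
# MultiHeadAttention with different query and key/value lengths, as in decoder
# cross-attention. The helper name and tensor sizes are arbitrary placeholders.
def _demo_multi_head_attention():
    mha = MultiHeadAttention(d_model=512, n_heads=8, dropout=0.0)
    query = torch.randn(2, 10, 512)             # [batch_size, seq_len_q, d_model]
    memory = torch.randn(2, 7, 512)             # [batch_size, seq_len_k, d_model]
    output, weights = mha(query, memory, memory)
    assert output.shape == (2, 10, 512)         # follows the query length
    assert weights.shape == (2, 8, 10, 7)       # [batch_size, n_heads, seq_len_q, seq_len_k]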
def create_padding_mask(seq: torch.Tensor, pad_idx: int = 0) -> torch.Tensor:
"""
Create padding mask for attention
Args:
seq: Input sequence [batch_size, seq_len]
pad_idx: Padding index
Returns:
mask: Padding mask [batch_size, 1, 1, seq_len]
"""
    # Boolean mask: True for real tokens, False for padding, shaped [batch_size, 1, 1, seq_len]
    mask = (seq != pad_idx).unsqueeze(1).unsqueeze(2)
    return mask
def create_look_ahead_mask(size: int, device: torch.device) -> torch.Tensor:
"""
Create look-ahead mask for decoder self-attention
Args:
size: Sequence length
device: Device to create mask on
Returns:
mask: Look-ahead mask [1, 1, size, size]
"""
# Create upper triangular matrix
mask = torch.triu(torch.ones(size, size, device=device, dtype=torch.bool), diagonal=1)
# Invert it (1 for allowed positions, 0 for masked)
mask = ~mask
return mask.unsqueeze(0).unsqueeze(0)
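# Illustrative sketch (an addition, not part of the original module): for size=3
# the look-ahead mask lets each position attend to itself and earlier positions only.
def _demo_look_ahead_mask():
    mask = create_look_ahead_mask(3, torch.device("cpu"))
    expected = torch.tensor([[True, False, False],
                             [True, True, False],
                             [True, True, True]])
    assert mask.shape == (1, 1, 3, 3)
    assert torch.equal(mask[0, 0], expected)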
def create_masks(src: torch.Tensor, tgt: torch.Tensor,
pad_idx: int = 0) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Create all masks needed for transformer
Args:
src: Source sequence [batch_size, src_len]
tgt: Target sequence [batch_size, tgt_len]
pad_idx: Padding index
Returns:
src_mask: Source padding mask
tgt_mask: Target mask (padding + look-ahead)
memory_mask: Memory mask for decoder cross-attention
"""
# Source mask (padding only)
src_mask = create_padding_mask(src, pad_idx)
# Target padding mask
tgt_pad_mask = create_padding_mask(tgt, pad_idx)
# Target look-ahead mask
tgt_len = tgt.size(1)
tgt_look_ahead_mask = create_look_ahead_mask(tgt_len, tgt.device)
# Combine padding and look-ahead masks for target
# Both masks should be True where attention is allowed
tgt_mask = tgt_pad_mask & tgt_look_ahead_mask
# Memory mask (same as source mask)
memory_mask = src_mask
return src_mask, tgt_mask, memory_mask
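# Illustrative sketch (an addition, not part of the original module): the padding
# mask [B, 1, 1, T] and the look-ahead mask [1, 1, T, T] broadcast to a combined
# target mask of shape [B, 1, T, T]. pad_idx=0 and the token ids are placeholders.
def _demo_create_masks():
    src = torch.tensor([[5, 6, 7, 0, 0]])       # one sequence, padded with pad_idx=0
    tgt = torch.tensor([[1, 2, 3, 0]])
    src_mask, tgt_mask, memory_mask = create_masks(src, tgt)
    assert src_mask.shape == (1, 1, 1, 5)
    assert tgt_mask.shape == (1, 1, 4, 4)
    assert memory_mask is src_mask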
# Optional: Flash Attention wrapper (if available)
try:
from torch.nn.functional import scaled_dot_product_attention
FLASH_ATTENTION_AVAILABLE = True
except ImportError:
FLASH_ATTENTION_AVAILABLE = False
class FlashAttention(nn.Module):
"""Flash Attention wrapper for better performance (if available)"""
def __init__(self, dropout: float = 0.1):
super().__init__()
self.dropout = dropout
    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
Uses PyTorch's scaled_dot_product_attention if available (includes Flash Attention)
"""
if FLASH_ATTENTION_AVAILABLE and mask is None:
# Use efficient implementation when no mask
output = scaled_dot_product_attention(
q, k, v,
dropout_p=self.dropout if self.training else 0.0,
is_causal=False
)
return output, None
else:
# Fallback to standard implementation
d_k = q.size(-1)
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
if mask is not None:
mask_value = torch.finfo(scores.dtype).min
scores = scores.masked_fill(mask == 0, mask_value)
attention = F.softmax(scores, dim=-1)
if self.training and self.dropout > 0:
attention = F.dropout(attention, p=self.dropout)
output = torch.matmul(attention, v)
return output, attention
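# Minimal smoke test (an addition, not part of the original module): runs the
# illustrative demos above when the file is executed directly.
if __name__ == "__main__":
    _demo_scaled_dot_product_attention()
    _demo_multi_head_attention()
    _demo_look_ahead_mask()
    _demo_create_masks()
    print("attention.py illustrative shape checks passed")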