import argparse
import math
import os
from functools import partial
from collections import Counter

import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_from_disk
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.strategies import DDPStrategy
from rdkit import Chem

from smiles_tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
from peptide_analyzer import PeptideAnalyzer
import dataloading_for_dynamic_batching as dynamic_dataloader

# NOTE(review): partial, Counter, load_from_disk, DataLoader and Chem are not
# referenced in this script; they are kept in case importing them is relied on.


# Default locations used when the corresponding command-line flags are omitted.
DEFAULT_TOKENIZER_VOCAB = '/scratch/pranamlab/tong/ReDi_discrete/smiles/smiles_tokenizer/new_vocab.txt'
DEFAULT_TOKENIZER_SPLITS = '/scratch/pranamlab/tong/ReDi_discrete/smiles/smiles_tokenizer/new_splits.txt'
DEFAULT_DATA_DIR = '/scratch/pranamlab/tong/data/smiles/v1'

# The global gradient norm is logged once every this many optimizer steps.
GRAD_NORM_LOG_INTERVAL = 1000


# --- Rotary positional embeddings ---

class RotaryPositionalEmbedding(nn.Module):
    """Produce the cos/sin tables used by rotary position embeddings (RoPE).

    Args:
        dim: Size of the last dimension the rotation is applied to (the
            per-head dimension in this file). Assumed even -- TODO confirm for
            new configurations.
        max_position_embeddings: Stored on the module; ``forward`` does not use
            it to bound ``seq_len``.
        base: Base of the geometric frequency progression.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        # A registered buffer follows the module across devices and is part of
        # the state dict (existing checkpoints contain it).
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, x, seq_len=None):
        """Return ``(cos, sin)`` tables, each of shape ``(1, seq_len, dim)`` for even ``dim``.

        ``x`` is only used for its device and, when ``seq_len`` is None, for
        ``x.shape[1]``.
        """
        if seq_len is None:
            seq_len = x.shape[1]
        positions = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", positions, self.inv_freq)  # (seq_len, dim/2)
        emb = torch.cat((freqs, freqs), dim=-1)  # (seq_len, dim)
        cos_emb = emb.cos()[None, :, :]
        sin_emb = emb.sin()[None, :, :]
        return cos_emb, sin_emb


def rotate_half(x):
    """Map the halves ``(x1, x2)`` of the last dimension to ``(-x2, x1)``."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin):
    """Apply the rotary embedding given by ``cos``/``sin`` to ``q`` and ``k``."""
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


# --- Model Architecture with RoPE ---

def modulate(x, shift, scale):
    """adaLN modulation: ``x * (1 + scale) + shift``, broadcasting over dim 1 of ``x``."""
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


class TimestepEmbedder(nn.Module):
    """Embed a scalar time per example with a two-layer MLP."""

    def __init__(self, hidden_size):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(1, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )

    def forward(self, t):
        """Map ``t`` (``(batch,)`` in this file) to ``(batch, hidden_size)``."""
        return self.mlp(t.unsqueeze(-1))


class MultiHeadAttentionWithRoPE(nn.Module):
    """Unmasked multi-head self-attention with rotary embeddings applied to Q and K."""

    def __init__(self, hidden_size, n_heads):
        super().__init__()
        self.hidden_size = hidden_size
        self.n_heads = n_heads
        self.head_dim = hidden_size // n_heads
        assert self.head_dim * n_heads == hidden_size, "hidden_size must be divisible by n_heads"

        self.q_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.k_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.v_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.out_proj = nn.Linear(hidden_size, hidden_size)

        self.rope = RotaryPositionalEmbedding(self.head_dim)

    def forward(self, x):
        """Attend over ``x`` of shape ``(batch, seq_len, hidden_size)``; output has the same shape."""
        batch_size, seq_len, hidden_size = x.shape

        # (batch, seq_len, hidden) -> (batch, n_heads, seq_len, head_dim)
        q = self.q_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)

        # cos/sin are (1, seq_len, head_dim) and broadcast over batch and heads.
        cos, sin = self.rope(q, seq_len)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        attn_weights = F.softmax(scores, dim=-1)
        attn_output = torch.matmul(attn_weights, v)

        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, hidden_size)
        return self.out_proj(attn_output)


class DiTBlock(nn.Module):
    """Transformer block whose norms are modulated/gated by a conditioning vector (adaLN)."""

    def __init__(self, hidden_size, n_heads):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = MultiHeadAttentionWithRoPE(hidden_size, n_heads)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, 4 * hidden_size),
            nn.GELU(),
            nn.Linear(4 * hidden_size, hidden_size),
        )
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True),
        )

    def forward(self, x, c):
        """Update ``x`` (``(batch, seq_len, hidden)``) conditioned on ``c`` (``(batch, hidden)``)."""
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = \
            self.adaLN_modulation(c).chunk(6, dim=1)

        x_norm1 = modulate(self.norm1(x), shift_msa, scale_msa)
        x = x + gate_msa.unsqueeze(1) * self.attn(x_norm1)

        x_norm2 = modulate(self.norm2(x), shift_mlp, scale_mlp)
        x = x + gate_mlp.unsqueeze(1) * self.mlp(x_norm2)
        return x


class MDLM(nn.Module):
    """Token embedding -> stack of time-conditioned DiT blocks -> vocabulary logits."""

    def __init__(self, vocab_size, model_dim, n_heads, n_layers):
        super().__init__()
        self.vocab_size = vocab_size
        self.model_dim = model_dim
        # NOTE(review): this id equals vocab_size, which is outside the range of
        # token_embedder (vocab_size rows); it is unused in this file -- confirm
        # callers never embed it.
        self.mask_token_id = vocab_size

        self.token_embedder = nn.Embedding(vocab_size, model_dim)
        self.time_embedder = TimestepEmbedder(model_dim)

        self.transformer_blocks = nn.ModuleList([
            DiTBlock(model_dim, n_heads) for _ in range(n_layers)
        ])

        self.final_norm = nn.LayerNorm(model_dim)
        self.lm_head = nn.Linear(model_dim, vocab_size)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """N(0, 0.02) for Linear/Embedding weights, zero biases, unit LayerNorm weights."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            # Non-affine LayerNorms (inside DiTBlock) have weight/bias set to None.
            if module.bias is not None:
                module.bias.data.zero_()
            if module.weight is not None:
                module.weight.data.fill_(1.0)

    def forward(self, x, t):
        """Return logits ``(batch, seq_len, vocab_size)`` for token ids ``x`` at times ``t``."""
        x_embed = self.token_embedder(x)
        t_embed = self.time_embedder(t)

        for block in self.transformer_blocks:
            x_embed = block(x_embed, t_embed)

        x_embed = self.final_norm(x_embed)
        return self.lm_head(x_embed)


# --- Validity-loss schedule ---

def validity_weight_schedule(epoch, start_epoch, total_epochs, schedule, weight_min, weight_max):
    """Return the validity-loss weight for ``epoch``.

    The weight is 0 before ``start_epoch``. From ``start_epoch`` it rises from
    ``weight_min`` and reaches ``weight_max`` on the final epoch
    (``total_epochs - 1``) for the 'linear', 'exponential', 'cosine' and 'step'
    schedules. Any other schedule name (e.g. 'constant') returns ``weight_max``.

    Raises:
        ValueError: for the 'exponential' schedule when ``weight_min <= 0``.
    """
    if epoch < start_epoch:
        return 0.0

    # Epochs strictly after the first validity epoch; dividing by this makes
    # progress reach exactly 1.0 on the final epoch. max(.., 1) avoids a zero
    # denominator when only one validity epoch exists (progress stays 0).
    span = max(total_epochs - start_epoch - 1, 1)
    progress = min(max((epoch - start_epoch) / span, 0.0), 1.0)

    if schedule == 'linear':
        return weight_min + (weight_max - weight_min) * progress

    if schedule == 'exponential':
        # Geometric interpolation: starts slowly, accelerates.
        if weight_min <= 0:
            raise ValueError("validity_weight_min must be > 0 for the exponential schedule")
        return weight_min * (weight_max / weight_min) ** progress

    if schedule == 'cosine':
        cosine_factor = 0.5 * (1 - math.cos(math.pi * progress))
        return weight_min + (weight_max - weight_min) * cosine_factor

    if schedule == 'step':
        thresholds = (0.25, 0.5, 0.75, 1.0)
        step_weights = (weight_min, weight_min * 2, weight_min * 5, weight_max)
        for threshold, weight in zip(thresholds, step_weights):
            if progress <= threshold:
                return weight
        return weight_max

    # 'constant' (and any unrecognised name) uses the maximum weight throughout.
    return weight_max


# --- PyTorch Lightning Module ---

class MDLMLightningModule(pl.LightningModule):
    """Train :class:`MDLM` on source->target pairs with an optional staged validity loss.

    Args:
        args: Parsed command-line namespace (see ``_build_arg_parser``).
        tokenizer: Tokenizer exposing ``vocab_size`` and ``batch_decode``.
    """

    def __init__(self, args, tokenizer):
        super().__init__()
        self.save_hyperparameters(ignore=['tokenizer'])
        self.args = args
        self.tokenizer = tokenizer
        self.peptide_analyzer = PeptideAnalyzer()

        self.model = MDLM(
            vocab_size=tokenizer.vocab_size,
            model_dim=args.model_dim,
            n_heads=args.n_heads,
            n_layers=args.n_layers,
        )

        self.automatic_optimization = True
        self.validation_step_outputs = []

        # Not updated anywhere in this file; kept because it is part of the
        # state dict and therefore of existing checkpoints.
        self.register_buffer('epoch_progress', torch.tensor(0.0))

    def forward(self, x, t):
        return self.model(x, t)

    def _compute_invalid_loss(self, logits, t_continuous=None):
        """Per-token validity penalty (PepTune-style), optionally scaled by time.

        The argmax sequence of each row is decoded and checked with
        ``PeptideAnalyzer.is_peptide``: rows that fail get penalty 1, others 0.
        The binary penalty is multiplied by the softmax probability of each
        argmax token so that gradients flow back through ``logits``.

        Args:
            logits: ``(batch, seq_len, vocab)`` model outputs.
            t_continuous: ``(batch,)`` times. Used only when
                ``args.time_dependent_validity`` is set, in which case
                penalties are scaled by ``t ** args.validity_time_power``.

        Returns:
            ``(batch, seq_len)`` penalty tensor.
        """
        batch_token_ids = torch.argmax(logits, dim=-1)  # (batch, seq_len)
        sampled_sequences = self.tokenizer.batch_decode(batch_token_ids)

        penalties = torch.tensor(
            [0.0 if self.peptide_analyzer.is_peptide(seq) else 1.0 for seq in sampled_sequences],
            dtype=torch.float32,
            device=self.device,
        )  # (batch,)

        if t_continuous is not None and self.args.time_dependent_validity:
            # Smaller t (x_t closer to the source sequence) -> smaller penalty.
            penalties = penalties * t_continuous ** self.args.validity_time_power

        sampled_probs = torch.softmax(logits, dim=-1).gather(
            dim=-1, index=batch_token_ids.unsqueeze(-1)
        ).squeeze(-1)  # (batch, seq_len)

        return penalties[:, None] * sampled_probs

    def get_validity_weight(self):
        """Return the validity-loss weight for the current epoch (see ``validity_weight_schedule``)."""
        return validity_weight_schedule(
            epoch=self.current_epoch,
            start_epoch=self.args.validity_start_epoch,
            total_epochs=self.args.epochs,
            schedule=self.args.validity_schedule,
            weight_min=self.args.validity_weight_min,
            weight_max=self.args.validity_weight_max,
        )

    def _loss(self, logits, x_1, attn_mask, t_continuous=None):
        """Combine token cross-entropy with the staged validity penalty.

        Returns:
            ``(token_nll, ce_token_loss, invalid_token_loss, validity_weight)``.
            These are the masked means of the combined loss, of the CE term alone
            and of the unweighted validity penalty, plus the weight used this epoch.
        """
        ce_loss = F.cross_entropy(
            logits.reshape(-1, self.model.vocab_size),
            x_1.reshape(-1),
            reduction='none',
            label_smoothing=float(getattr(self.args, 'label_smoothing', 0.0)),
        ).view(x_1.shape[0], -1)

        validity_weight = self.get_validity_weight()
        if validity_weight > 0:
            invalid_loss = self._compute_invalid_loss(logits, t_continuous)
        else:
            # Skip the (CPU-bound) decode + validity checks while the term is off.
            invalid_loss = torch.zeros_like(ce_loss)

        total_loss = ce_loss + validity_weight * invalid_loss

        # clamp_min guards against division by zero for an all-zero mask.
        num_tokens = attn_mask.sum().clamp_min(1)
        token_nll = (total_loss * attn_mask).sum() / num_tokens
        ce_token_loss = (ce_loss * attn_mask).sum() / num_tokens
        invalid_token_loss = (invalid_loss * attn_mask).sum() / num_tokens

        return token_nll, ce_token_loss, invalid_token_loss, validity_weight

    def _prepare_batch(self, batch):
        """Sample times and build the interpolated input ``x_t`` for a batch.

        Each target token replaces the source token with probability ``t``.
        Tokens flagged in ``batch['bond_mask']`` use ``t ** args.gamma`` instead
        (for gamma > 1 that is smaller than t on (0, 1), so those tokens switch
        to the target later).

        Returns:
            ``(x_0, x_1, x_t, attn_mask, t_continuous)``.
        """
        x_0 = batch['source_ids'].to(self.device)
        x_1 = batch['target_ids'].to(self.device)
        attn_mask = torch.ones_like(x_1)
        bond_mask = batch['bond_mask'].to(self.device).bool()
        batch_size = x_1.shape[0]

        t_continuous = torch.rand(batch_size, device=self.device)
        t_col = t_continuous.view(-1, 1)
        masking_prob = torch.where(bond_mask, t_col ** self.args.gamma, t_col)
        mask = torch.rand(x_1.shape, device=self.device) < masking_prob
        x_t = torch.where(mask, x_1, x_0)

        return x_0, x_1, x_t, attn_mask, t_continuous

    def training_step(self, batch, batch_idx):
        _, x_1, x_t, attn_mask, t_continuous = self._prepare_batch(batch)
        batch_size = x_1.shape[0]

        logits = self.model(x_t, t_continuous)
        token_nll, ce_loss, invalid_loss, validity_weight = self._loss(
            logits, x_1, attn_mask, t_continuous
        )

        # Detached tensors avoid a host sync per log call (unlike .item()).
        self.log('train/token_nll', token_nll.detach(), on_step=True, on_epoch=True,
                 prog_bar=True, batch_size=batch_size, sync_dist=True)
        self.log('train/ce_loss', ce_loss.detach(), on_step=True, on_epoch=True,
                 batch_size=batch_size, sync_dist=True)
        self.log('train/invalid_loss', invalid_loss.detach(), on_step=True, on_epoch=True,
                 batch_size=batch_size, sync_dist=True)
        self.log('train/validity_weight', validity_weight, on_step=False, on_epoch=True,
                 batch_size=batch_size, sync_dist=True)

        return token_nll

    def on_before_optimizer_step(self, optimizer, *args, **kwargs):
        """Log the global L2 gradient norm every ``GRAD_NORM_LOG_INTERVAL`` optimizer steps.

        Lightning calls this hook after ``backward`` and before gradient
        clipping / ``optimizer.step()``. The logged value is therefore the
        current step's unclipped gradient norm. ``*args`` absorbs the extra
        ``optimizer_idx`` argument passed by older Lightning releases.
        """
        if self.global_step % GRAD_NORM_LOG_INTERVAL != 0:
            return
        grad_norms = [p.grad.detach().norm(2) for p in self.model.parameters() if p.grad is not None]
        if not grad_norms:
            return
        total_norm = torch.stack(grad_norms).norm(2)
        self.log('train/grad_norm', total_norm, sync_dist=True)

    def validation_step(self, batch, batch_idx):
        x_0, x_1, x_t, attn_mask, t_continuous = self._prepare_batch(batch)
        batch_size = x_1.shape[0]

        logits = self.model(x_t, t_continuous)
        token_nll, ce_loss, invalid_loss, _ = self._loss(
            logits, x_1, attn_mask, t_continuous
        )

        self.log('val/token_nll', token_nll.detach(), on_step=True, on_epoch=True,
                 prog_bar=True, batch_size=batch_size, sync_dist=True)
        self.log('val/ce_loss', ce_loss.detach(), on_step=True, on_epoch=True,
                 batch_size=batch_size, sync_dist=True)
        self.log('val/invalid_loss', invalid_loss.detach(), on_step=True, on_epoch=True,
                 batch_size=batch_size, sync_dist=True)

        if batch_idx == 0:
            self._log_validity_probe(x_0, x_1)

    def _log_validity_probe(self, x_0, x_1, t_values=(0.9, 0.5, 0.1)):
        """Log the fraction of argmax predictions that are valid peptides at fixed times.

        Unlike training, this uses uniform (not bond-aware) interpolation with
        probability ``t`` per token.
        """
        batch_size = x_1.shape[0]
        with torch.no_grad():
            for t_val in t_values:
                t_test = torch.full((batch_size,), t_val, device=self.device)
                test_mask = torch.rand(x_1.shape, device=self.device) < t_val
                x_test = torch.where(test_mask, x_1, x_0)

                test_preds = torch.argmax(self.model(x_test, t_test), dim=-1)
                sequences = self.tokenizer.batch_decode(test_preds)
                valid_count = sum(1 for seq in sequences if self.peptide_analyzer.is_peptide(seq))
                validity_rate = valid_count / len(sequences)

                self.log(f'val/validity_rate_t{t_val}', validity_rate,
                         batch_size=batch_size, sync_dist=True)

    def configure_optimizers(self):
        """AdamW with 10% linear warmup followed by cosine decay to 10% of the base LR.

        Raises:
            ValueError: if the total number of training steps is not finite.
        """
        optimizer = AdamW(
            self.parameters(),
            lr=self.args.learning_rate,
            weight_decay=self.args.weight_decay,
        )

        if hasattr(self.trainer, 'estimated_stepping_batches'):
            num_training_steps = self.trainer.estimated_stepping_batches
        else:
            num_training_steps = len(self.trainer.datamodule.train_dataloader()) * self.trainer.max_epochs

        if not math.isfinite(num_training_steps):
            raise ValueError(
                "Cannot build the LR schedule: the total number of training steps is not finite. "
                "Use a train dataloader with a defined length or set max_steps."
            )
        num_training_steps = int(num_training_steps)

        warmup_steps = int(num_training_steps * 0.1)
        # max(.., 1) avoids a zero denominator for degenerate (tiny) runs.
        decay_steps = max(num_training_steps - warmup_steps, 1)
        min_lr_ratio = 0.1

        def lr_lambda(current_step):
            if current_step < warmup_steps:
                # Linear warmup from 0.
                return current_step / warmup_steps
            # Cosine decay; progress is clamped so the LR stays at the floor
            # instead of rising again if training runs past the estimate.
            progress = min((current_step - warmup_steps) / decay_steps, 1.0)
            cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress))
            return min_lr_ratio + (1 - min_lr_ratio) * cosine_decay

        scheduler = LambdaLR(optimizer, lr_lambda)

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",
                "frequency": 1,
            },
        }


def main(args):
    """Load an MDLM checkpoint and train it with the staged validity loss."""
    checkpoint_dir = os.path.join(
        args.checkpoint_dir,
        f"new_lr{args.learning_rate}_layer{args.n_layers}_head{args.n_heads}_{args.validity_schedule}",
    )
    print(f"Saving to {checkpoint_dir}")
    os.makedirs(checkpoint_dir, exist_ok=True)

    print("Loading tokenizer...")
    tokenizer = SMILES_SPE_Tokenizer(
        getattr(args, 'tokenizer_vocab', DEFAULT_TOKENIZER_VOCAB),
        getattr(args, 'tokenizer_splits', DEFAULT_TOKENIZER_SPLITS),
    )
    print(f"Tokenizer loaded. Vocab size: {tokenizer.vocab_size}")

    data_module = dynamic_dataloader.RectifyDataModule(getattr(args, 'data_dir', DEFAULT_DATA_DIR))

    # Weights come from args.checkpoint; args/tokenizer are passed explicitly so
    # the current command-line settings override the ones saved in the checkpoint.
    model = MDLMLightningModule.load_from_checkpoint(
        checkpoint_path=args.checkpoint,
        args=args,
        tokenizer=tokenizer,
    )

    logger = WandbLogger(
        project="smiles-redi-staged-training",
        entity="programmablebio",
        name=f"v1_lr{args.learning_rate}_epochs{args.validity_start_epoch}_{args.validity_schedule}",
        save_dir=checkpoint_dir,
    )

    callbacks = [
        # Best model by validation NLL, plus last.ckpt.
        ModelCheckpoint(
            dirpath=checkpoint_dir,
            filename='best',
            monitor='val/token_nll',
            mode='min',
            save_top_k=1,
            save_last=True,
        ),
        # Save every epoch.
        ModelCheckpoint(
            dirpath=checkpoint_dir,
            filename='{epoch:02d}',
            save_top_k=-1,
            every_n_epochs=1,
            save_on_train_epoch_end=True,
        ),
        LearningRateMonitor(logging_interval='step'),
    ]

    trainer = pl.Trainer(
        max_epochs=args.epochs,
        devices=torch.cuda.device_count(),
        accelerator='gpu',
        strategy=DDPStrategy(find_unused_parameters=False),
        num_nodes=int(os.environ.get("SLURM_NNODES", 1)),
        precision="bf16",
        gradient_clip_val=args.grad_clip if args.grad_clip > 0 else None,
        callbacks=callbacks,
        logger=logger,
        log_every_n_steps=100,
        # NOTE(review): per the Lightning docs, check_val_every_n_epoch=None
        # expects an integer val_check_interval, which is not set here --
        # confirm validation runs at the intended frequency.
        check_val_every_n_epoch=None,
        accumulate_grad_batches=1,
        enable_progress_bar=True,
        enable_model_summary=True,
    )

    print(f"Model initialized with {sum(p.numel() for p in model.parameters()):,} parameters.")
    print(f"Training strategy: CE-only for {args.validity_start_epoch} epochs, then staged validity loss")
    print("Starting training...")

    # No ckpt_path: training starts at epoch 0 from the loaded weights, so the
    # validity schedule (keyed on current_epoch) starts from the beginning.
    trainer.fit(model, data_module)

    print("Training complete.")
    print(f"Best checkpoint saved at: {trainer.checkpoint_callback.best_model_path}")


def _str2bool(value):
    """Parse a command-line boolean ('true'/'false', 'yes'/'no', '1'/'0', ...).

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    'False') as True, so an explicit parser is required.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def _build_arg_parser():
    """Build the command-line parser for this training script."""
    parser = argparse.ArgumentParser(description="Train ReDi model with staged validity loss")

    # Model arguments
    parser.add_argument("--model_dim", type=int, default=1024)
    parser.add_argument("--n_heads", type=int, default=8)
    parser.add_argument("--n_layers", type=int, default=6)

    # Training arguments
    parser.add_argument("--epochs", type=int, default=5)
    parser.add_argument("--learning_rate", type=float, default=1e-4)
    parser.add_argument("--weight_decay", type=float, default=1e-5)
    parser.add_argument("--label_smoothing", type=float, default=0)
    parser.add_argument("--grad_clip", type=float, default=1.0)
    parser.add_argument("--gamma", type=float, default=2.0)

    # Staged validity arguments
    parser.add_argument("--validity_start_epoch", type=int, default=2,
                        help="Epoch to start adding validity loss (0-indexed)")
    parser.add_argument("--validity_weight_min", type=float, default=10.0,
                        help="Initial validity weight when starting")
    parser.add_argument("--validity_weight_max", type=float, default=200.0,
                        help="Maximum validity weight")
    parser.add_argument("--validity_schedule", type=str, default="linear",
                        choices=['linear', 'exponential', 'cosine', 'step', 'constant'],
                        help="Schedule for increasing validity weight")
    # nargs='?' + const=True: a bare flag means True; an explicit value still works.
    parser.add_argument("--time_dependent_validity", type=_str2bool, nargs='?', const=True, default=False,
                        help="Whether to apply time-dependent scaling to validity loss")
    parser.add_argument("--validity_time_power", type=float, default=0.5,
                        help="Power for time-dependent validity scaling")

    # Data / tokenizer locations
    parser.add_argument("--tokenizer_vocab", type=str, default=DEFAULT_TOKENIZER_VOCAB)
    parser.add_argument("--tokenizer_splits", type=str, default=DEFAULT_TOKENIZER_SPLITS)
    parser.add_argument("--data_dir", type=str, default=DEFAULT_DATA_DIR)

    # Other arguments
    parser.add_argument("--checkpoint_dir", type=str, default="./checkpoints_smiles")
    parser.add_argument("--checkpoint", type=str, required=True)
    return parser


if __name__ == "__main__":
    parser = _build_arg_parser()
    args = parser.parse_args()
    # Fail fast instead of crashing when the exponential schedule first activates.
    if args.validity_schedule == 'exponential' and args.validity_weight_min <= 0:
        parser.error("--validity_weight_min must be > 0 for the exponential schedule")
    main(args)