# AReUReDi/smiles/rectify_train.py: train a ReDi model on peptide SMILES
# with a staged validity loss.
import argparse
import math
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.strategies import DDPStrategy

from smiles_tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
from peptide_analyzer import PeptideAnalyzer
import dataloading_for_dynamic_batching as dynamic_dataloader


# --- RoPE utilities ---
class RotaryPositionalEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000):
super().__init__()
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer("inv_freq", inv_freq)
def forward(self, x, seq_len=None):
if seq_len is None:
seq_len = x.shape[1]
t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
cos_emb = emb.cos()[None, :, :]
sin_emb = emb.sin()[None, :, :]
return cos_emb, sin_emb


def rotate_half(x):
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin):
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
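

# Minimal sanity-check sketch for the RoPE helpers above (never called during
# training; assumes an even dim so rotate_half splits cleanly). It verifies the
# defining RoPE property: rotated q.k scores depend only on the offset i - j.
def _rope_relative_offset_check(dim=64, seq_len=8, atol=1e-5):
    torch.manual_seed(0)
    rope = RotaryPositionalEmbedding(dim)
    # Place the same q/k vectors at every position: (1, seq_len, dim)
    q = torch.randn(dim).expand(1, seq_len, dim)
    k = torch.randn(dim).expand(1, seq_len, dim)
    cos, sin = rope(q, seq_len)
    q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
    # Both pairs below have relative offset 2, so their scores should match
    score_a = q_rot[0, 2] @ k_rot[0, 0]
    score_b = q_rot[0, 5] @ k_rot[0, 3]
    return torch.allclose(score_a, score_b, atol=atol)

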
# --- Model Architecture with RoPE ---
def modulate(x, shift, scale):
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


# Embeds a continuous timestep t in [0, 1] with a small MLP.
class TimestepEmbedder(nn.Module):
def __init__(self, hidden_size):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(1, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True),
)
def forward(self, t):
return self.mlp(t.unsqueeze(-1))


class MultiHeadAttentionWithRoPE(nn.Module):
def __init__(self, hidden_size, n_heads):
super().__init__()
self.hidden_size = hidden_size
self.n_heads = n_heads
self.head_dim = hidden_size // n_heads
assert self.head_dim * n_heads == hidden_size, "hidden_size must be divisible by n_heads"
self.q_proj = nn.Linear(hidden_size, hidden_size, bias=False)
self.k_proj = nn.Linear(hidden_size, hidden_size, bias=False)
self.v_proj = nn.Linear(hidden_size, hidden_size, bias=False)
self.out_proj = nn.Linear(hidden_size, hidden_size)
self.rope = RotaryPositionalEmbedding(self.head_dim)
def forward(self, x):
batch_size, seq_len, hidden_size = x.shape
q = self.q_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
k = self.k_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
v = self.v_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
cos, sin = self.rope(q, seq_len)
q, k = apply_rotary_pos_emb(q, k, cos, sin)
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
attn_weights = F.softmax(scores, dim=-1)
attn_output = torch.matmul(attn_weights, v)
attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, hidden_size)
output = self.out_proj(attn_output)
return output
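
# Note: for this non-causal, unmasked case the manual softmax attention above
# is equivalent to the fused kernel available in PyTorch >= 2.0:
#   attn_output = F.scaled_dot_product_attention(q, k, v)
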


# Transformer block with adaptive LayerNorm (adaLN) conditioning, as in DiT.
class DiTBlock(nn.Module):
def __init__(self, hidden_size, n_heads):
super().__init__()
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.attn = MultiHeadAttentionWithRoPE(hidden_size, n_heads)
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.mlp = nn.Sequential(
nn.Linear(hidden_size, 4 * hidden_size),
nn.GELU(),
nn.Linear(4 * hidden_size, hidden_size)
)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
def forward(self, x, c):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
x_norm1 = modulate(self.norm1(x), shift_msa, scale_msa)
attn_output = self.attn(x_norm1)
x = x + gate_msa.unsqueeze(1) * attn_output
x_norm2 = modulate(self.norm2(x), shift_mlp, scale_mlp)
mlp_output = self.mlp(x_norm2)
x = x + gate_mlp.unsqueeze(1) * mlp_output
return x
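

# The DiT paper zero-initializes the adaLN modulation projection ("adaLN-Zero")
# so each block starts as an identity mapping; the generic 0.02-normal init in
# MDLM._init_weights below does not. An optional helper sketch (never called in
# this script) that applies that convention:
def zero_init_adaln(model):
    for module in model.modules():
        if isinstance(module, DiTBlock):
            final_linear = module.adaLN_modulation[-1]  # Linear producing 6 * hidden
            nn.init.zeros_(final_linear.weight)
            nn.init.zeros_(final_linear.bias)

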
class MDLM(nn.Module):
def __init__(self, vocab_size, model_dim, n_heads, n_layers):
super().__init__()
self.vocab_size = vocab_size
self.model_dim = model_dim
        # NOTE: mask_token_id == vocab_size falls outside the embedding table.
        # This script never embeds the mask token (x_t mixes x_0 and x_1), but
        # embedding a masked sequence directly would index out of range.
        self.mask_token_id = vocab_size
        self.token_embedder = nn.Embedding(vocab_size, model_dim)
self.time_embedder = TimestepEmbedder(model_dim)
self.transformer_blocks = nn.ModuleList([
DiTBlock(model_dim, n_heads) for _ in range(n_layers)
])
self.final_norm = nn.LayerNorm(model_dim)
self.lm_head = nn.Linear(model_dim, vocab_size)
self.apply(self._init_weights)
def _init_weights(self, module):
if isinstance(module, (nn.Linear, nn.Embedding)):
module.weight.data.normal_(mean=0.0, std=0.02)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
if module.bias is not None:
module.bias.data.zero_()
if module.weight is not None:
module.weight.data.fill_(1.0)
def forward(self, x, t):
x_embed = self.token_embedder(x)
t_embed = self.time_embedder(t)
for block in self.transformer_blocks:
x_embed = block(x_embed, t_embed)
x_embed = self.final_norm(x_embed)
logits = self.lm_head(x_embed)
return logits
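

# Throwaway shape smoke test for MDLM (toy sizes; the real vocab size comes
# from the tokenizer in main()):
def _mdlm_smoke_test():
    model = MDLM(vocab_size=100, model_dim=64, n_heads=4, n_layers=2)
    x = torch.randint(0, 100, (2, 10))  # (batch, seq) token ids
    t = torch.rand(2)                   # continuous timestep per sample
    logits = model(x, t)
    assert logits.shape == (2, 10, 100)  # (batch, seq, vocab)

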
# --- PyTorch Lightning Module ---
class MDLMLightningModule(pl.LightningModule):
def __init__(self, args, tokenizer):
super().__init__()
self.save_hyperparameters(ignore=['tokenizer'])
self.args = args
self.tokenizer = tokenizer
self.peptide_analyzer = PeptideAnalyzer()
# Initialize model
self.model = MDLM(
vocab_size=tokenizer.vocab_size,
model_dim=args.model_dim,
n_heads=args.n_heads,
n_layers=args.n_layers
)
self.automatic_optimization = True
self.validation_step_outputs = []
        # Track training progress (registered as a buffer so it is checkpointed;
        # not currently updated anywhere)
        self.register_buffer('epoch_progress', torch.tensor(0.0))
def forward(self, x, t):
return self.model(x, t)
def _compute_invalid_loss(self, logits, t_continuous=None):
"""
Original invalid loss computation from PepTune
with optional time-dependent weighting
"""
batch_token_ids = torch.argmax(logits, dim=-1) # (batch_size, seq_length)
sampled_sequences = self.tokenizer.batch_decode(batch_token_ids)
# Check validity using peptide analyzer
penalties = torch.tensor(
[1.0 if not self.peptide_analyzer.is_peptide(seq) else 0.0 for seq in sampled_sequences],
dtype=torch.float32,
device=self.device
) # (batch_size,)
# Optional: Apply time-dependent scaling
if t_continuous is not None and self.args.time_dependent_validity:
# Less penalty at early timesteps (when t is close to 0)
time_weight = t_continuous ** self.args.validity_time_power # Default power = 0.5
penalties = penalties * time_weight
        # Softmax probabilities of the argmax tokens: (batch_size, seq_length)
        sampled_probs = torch.softmax(logits, dim=-1).gather(
            dim=-1, index=batch_token_ids.unsqueeze(-1)
        ).squeeze(-1)
        # Scale the (non-differentiable) validity indicator by these probabilities
        # so a gradient path to the logits exists
        scaled_penalty = penalties[:, None] * sampled_probs  # (batch_size, seq_length)
return scaled_penalty
def get_validity_weight(self):
"""
Compute annealed validity weight based on training progress
"""
current_epoch = self.current_epoch
# Stage 1: No validity loss for first N epochs
if current_epoch < self.args.validity_start_epoch:
return 0.0
# Stage 2: Gradually increase validity weight
        epochs_with_validity = current_epoch - self.args.validity_start_epoch
        # Guard against division by zero when epochs == validity_start_epoch
        max_epochs_with_validity = max(1, self.args.epochs - self.args.validity_start_epoch)
if self.args.validity_schedule == 'linear':
# Linear increase from min to max weight
progress = epochs_with_validity / max_epochs_with_validity
weight = (self.args.validity_weight_min +
(self.args.validity_weight_max - self.args.validity_weight_min) * progress)
elif self.args.validity_schedule == 'exponential':
# Exponential increase (starts slow, accelerates)
progress = epochs_with_validity / max_epochs_with_validity
weight = (self.args.validity_weight_min *
(self.args.validity_weight_max / self.args.validity_weight_min) ** progress)
elif self.args.validity_schedule == 'cosine':
# Cosine schedule (smooth increase)
progress = epochs_with_validity / max_epochs_with_validity
cosine_factor = 0.5 * (1 - math.cos(math.pi * progress))
weight = (self.args.validity_weight_min +
(self.args.validity_weight_max - self.args.validity_weight_min) * cosine_factor)
elif self.args.validity_schedule == 'step':
# Step-wise increase
steps = [0.25, 0.5, 0.75, 1.0]
weights = [self.args.validity_weight_min,
self.args.validity_weight_min * 2,
self.args.validity_weight_min * 5,
self.args.validity_weight_max]
progress = epochs_with_validity / max_epochs_with_validity
for i, step in enumerate(steps):
if progress <= step:
weight = weights[i]
break
else:
# Constant weight
weight = self.args.validity_weight_max
return weight
def _loss(self, logits, x_1, attn_mask, t_continuous=None):
"""
Combined loss with staged validity loss
"""
        # Standard cross-entropy loss (with optional label smoothing)
        ce_loss = F.cross_entropy(
            logits.view(-1, self.model.vocab_size),
            x_1.view(-1),
            reduction='none',
            label_smoothing=self.args.label_smoothing
        ).view(x_1.shape[0], -1)
# Get current validity weight
validity_weight = self.get_validity_weight()
# Compute invalid loss only if weight > 0
if validity_weight > 0:
invalid_loss = self._compute_invalid_loss(logits, t_continuous)
else:
invalid_loss = torch.zeros_like(ce_loss)
# Combine losses
total_loss = ce_loss + validity_weight * invalid_loss
# Apply attention mask
masked_loss = total_loss * attn_mask
num_tokens = attn_mask.sum()
token_nll = masked_loss.sum() / num_tokens
# Individual components for logging
ce_token_loss = (ce_loss * attn_mask).sum() / num_tokens
invalid_token_loss = (invalid_loss * attn_mask).sum() / num_tokens
return token_nll, ce_token_loss, invalid_token_loss, validity_weight
def training_step(self, batch, batch_idx):
x_0 = batch['source_ids'].to(self.device)
x_1 = batch['target_ids'].to(self.device)
attn_mask = torch.ones_like(x_1).to(self.device)
bond_mask = batch['bond_mask'].to(self.device).bool()
batch_size, _ = x_1.shape
        # ReDi-style rectification: interpolate from the source sequence x_0
        # toward the target x_1 with a per-sample timestep t ~ U(0, 1)
        t_continuous = torch.rand(batch_size, device=self.device)
# Bond-aware masking
peptide_bond_prob = t_continuous.view(-1, 1) ** self.args.gamma
non_peptide_prob = t_continuous.view(-1, 1)
masking_prob = torch.where(bond_mask, peptide_bond_prob, non_peptide_prob)
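        # e.g. with gamma=2 at t=0.5, a peptide-bond token takes its target value
        # with probability 0.25 vs. 0.5 for other tokens, so bond tokens are
        # rectified later in the trajectory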
mask = torch.rand(x_1.shape, device=self.device) < masking_prob
x_t = torch.where(mask, x_1, x_0)
# Forward pass
logits = self.model(x_t, t_continuous)
# Compute loss with staged validity
token_nll, ce_loss, invalid_loss, validity_weight = self._loss(
logits, x_1, attn_mask, t_continuous
)
# Extensive logging
self.log('train/token_nll', token_nll.item(), on_step=True, on_epoch=True, prog_bar=True, batch_size=batch_size, sync_dist=True)
self.log('train/ce_loss', ce_loss.item(), on_step=True, on_epoch=True, batch_size=batch_size, sync_dist=True)
self.log('train/invalid_loss', invalid_loss.item(), on_step=True, on_epoch=True, batch_size=batch_size, sync_dist=True)
self.log('train/validity_weight', validity_weight, on_step=False, on_epoch=True, batch_size=batch_size, sync_dist=True)
        # Log gradient norm for debugging (with automatic optimization, the
        # gradients visible here are left over from the previous step)
        if batch_idx % 1000 == 0:
total_norm = 0
for p in self.model.parameters():
if p.grad is not None:
param_norm = p.grad.data.norm(2)
total_norm += param_norm.item() ** 2
total_norm = total_norm ** 0.5
self.log('train/grad_norm', total_norm, batch_size=batch_size, sync_dist=True)
return token_nll
def validation_step(self, batch, batch_idx):
x_0 = batch['source_ids'].to(self.device)
x_1 = batch['target_ids'].to(self.device)
attn_mask = torch.ones_like(x_1).to(self.device)
bond_mask = batch['bond_mask'].to(self.device).bool()
batch_size, _ = x_1.shape
# Same masking as training
t_continuous = torch.rand(batch_size, device=self.device)
peptide_bond_prob = t_continuous.view(-1, 1) ** self.args.gamma
non_peptide_prob = t_continuous.view(-1, 1)
masking_prob = torch.where(bond_mask, peptide_bond_prob, non_peptide_prob)
mask = torch.rand(x_1.shape, device=self.device) < masking_prob
x_t = torch.where(mask, x_1, x_0)
logits = self.model(x_t, t_continuous)
token_nll, ce_loss, invalid_loss, validity_weight = self._loss(
logits, x_1, attn_mask, t_continuous
)
self.log('val/token_nll', token_nll.item(), on_step=True, on_epoch=True, prog_bar=True, batch_size=batch_size, sync_dist=True)
self.log('val/ce_loss', ce_loss.item(), on_step=True, on_epoch=True, batch_size=batch_size, sync_dist=True)
self.log('val/invalid_loss', invalid_loss.item(), on_step=True, on_epoch=True, batch_size=batch_size, sync_dist=True)
# Sample and check validity at different timesteps
if batch_idx == 0:
with torch.no_grad():
validity_results = {}
for t_val in [0.9, 0.5, 0.1]: # Different timesteps
t_test = torch.full((batch_size,), t_val, device=self.device)
test_mask = torch.rand(x_1.shape, device=self.device) < t_val
x_test = torch.where(test_mask, x_1, x_0)
test_logits = self.model(x_test, t_test)
test_preds = torch.argmax(test_logits, dim=-1)
sequences = self.tokenizer.batch_decode(test_preds)
valid_count = sum(1 for seq in sequences if self.peptide_analyzer.is_peptide(seq))
validity_rate = valid_count / len(sequences)
self.log(f'val/validity_rate_t{t_val}', validity_rate, batch_size=batch_size, sync_dist=True)
def configure_optimizers(self):
optimizer = AdamW(
self.parameters(),
lr=self.args.learning_rate,
weight_decay=self.args.weight_decay
)
# Calculate total steps
if hasattr(self.trainer, 'estimated_stepping_batches'):
num_training_steps = self.trainer.estimated_stepping_batches
else:
num_training_steps = len(self.trainer.datamodule.train_dataloader()) * self.trainer.max_epochs
warmup_steps = int(num_training_steps * 0.1)
def lr_lambda(current_step):
if current_step < warmup_steps:
# Linear warmup
lr_factor = current_step / warmup_steps
return lr_factor
else:
# Cosine decay with min LR
progress = (current_step - warmup_steps) / (num_training_steps - warmup_steps)
cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress))
min_lr_ratio = 0.1
return min_lr_ratio + (1 - min_lr_ratio) * cosine_decay
scheduler = LambdaLR(optimizer, lr_lambda)
return {
"optimizer": optimizer,
"lr_scheduler": {
"scheduler": scheduler,
"interval": "step",
"frequency": 1,
},
}
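

# Standalone preview sketch mirroring the 'linear' branch of
# get_validity_weight (defaults match the argparse defaults below):
def _preview_linear_validity_schedule(epochs=5, start_epoch=2, w_min=10.0, w_max=200.0):
    span = max(1, epochs - start_epoch)
    for epoch in range(epochs):
        if epoch < start_epoch:
            weight = 0.0  # stage 1: cross-entropy only
        else:
            progress = (epoch - start_epoch) / span
            weight = w_min + (w_max - w_min) * progress
        print(f"epoch {epoch}: validity_weight = {weight:.1f}")

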
def main(args):
    # Set up checkpoint directory
    checkpoint_dir = os.path.join(
        args.checkpoint_dir,
        f"new_lr{args.learning_rate}_layer{args.n_layers}_"
        f"head{args.n_heads}_{args.validity_schedule}"
    )
    print(f"Saving to {checkpoint_dir}")
    os.makedirs(checkpoint_dir, exist_ok=True)
print("Loading tokenizer...")
tokenizer = SMILES_SPE_Tokenizer('/scratch/pranamlab/tong/ReDi_discrete/smiles/smiles_tokenizer/new_vocab.txt',
'/scratch/pranamlab/tong/ReDi_discrete/smiles/smiles_tokenizer/new_splits.txt')
print(f"Tokenizer loaded. Vocab size: {tokenizer.vocab_size}")
# Initialize data module
data_module = dynamic_dataloader.RectifyDataModule('/scratch/pranamlab/tong/data/smiles/v1')
    # Initialize the LightningModule from the pretrained checkpoint
    # (--checkpoint is required, so training always resumes from existing weights)
    model = MDLMLightningModule.load_from_checkpoint(
        checkpoint_path=args.checkpoint,
        args=args,
        tokenizer=tokenizer
    )
# Set up logger
logger = WandbLogger(
project="smiles-redi-staged-training",
entity="programmablebio",
name=f"v1_lr{args.learning_rate}_epochs{args.validity_start_epoch}_{args.validity_schedule}",
save_dir=checkpoint_dir
)
# Set up callbacks
callbacks = [
ModelCheckpoint(
dirpath=checkpoint_dir,
filename='best',
monitor='val/token_nll',
mode='min',
save_top_k=1,
save_last=True,
# every_n_train_steps=5000
),
# Save every epoch
ModelCheckpoint(
dirpath=checkpoint_dir,
filename='{epoch:02d}',
save_top_k=-1,
every_n_epochs=1,
save_on_train_epoch_end=True
),
LearningRateMonitor(logging_interval='step')
]
# Initialize trainer
trainer = pl.Trainer(
max_epochs=args.epochs,
devices=torch.cuda.device_count(),
accelerator='gpu',
strategy=DDPStrategy(find_unused_parameters=False),
num_nodes=int(os.environ.get("SLURM_NNODES", 1)),
precision="bf16",
gradient_clip_val=args.grad_clip if args.grad_clip > 0 else None,
callbacks=callbacks,
logger=logger,
log_every_n_steps=100,
        # Lightning requires an integer val_check_interval when
        # check_val_every_n_epoch is None, so validate once per epoch instead
        check_val_every_n_epoch=1,
        # val_check_interval=5000,
accumulate_grad_batches=1,
enable_progress_bar=True,
enable_model_summary=True
)
print(f"Model initialized with {sum(p.numel() for p in model.parameters()):,} parameters.")
print(f"Training strategy: CE-only for {args.validity_start_epoch} epochs, then staged validity loss")
print("Starting training...")
# Train the model
trainer.fit(model, data_module)
print("Training complete.")
print(f"Best checkpoint saved at: {trainer.checkpoint_callback.best_model_path}")


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Train ReDi model with staged validity loss")
# Model arguments
parser.add_argument("--model_dim", type=int, default=1024)
parser.add_argument("--n_heads", type=int, default=8)
parser.add_argument("--n_layers", type=int, default=6)
# Training arguments
parser.add_argument("--epochs", type=int, default=5)
parser.add_argument("--learning_rate", type=float, default=1e-4)
parser.add_argument("--weight_decay", type=float, default=1e-5)
parser.add_argument("--label_smoothing", type=float, default=0)
parser.add_argument("--grad_clip", type=float, default=1.0)
parser.add_argument("--gamma", type=float, default=2.0)
# Staged validity arguments
parser.add_argument("--validity_start_epoch", type=int, default=2, help="Epoch to start adding validity loss (0-indexed)")
parser.add_argument("--validity_weight_min", type=float, default=10.0, help="Initial validity weight when starting")
parser.add_argument("--validity_weight_max", type=float, default=200.0, help="Maximum validity weight")
parser.add_argument("--validity_schedule", type=str, default="linear", choices=['linear', 'exponential', 'cosine', 'step', 'constant'], help="Schedule for increasing validity weight")
    parser.add_argument("--time_dependent_validity", action="store_true", help="Apply time-dependent scaling to the validity loss")
parser.add_argument("--validity_time_power", type=float, default=0.5, help="Power for time-dependent validity scaling")
# Other arguments
parser.add_argument("--checkpoint_dir", type=str, default="./checkpoints_smiles")
parser.add_argument("--checkpoint", type=str, required=True)
args = parser.parse_args()
main(args)
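
# Example invocation (a sketch; the checkpoint path is a placeholder):
#   python rectify_train.py --checkpoint ./checkpoints_smiles/best.ckpt \
#       --epochs 5 --validity_start_epoch 2 --validity_schedule cosine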