# AReUReDi / smiles / generation.py
# Tong Chen
# add files
# 295b1cd
import argparse
from pathlib import Path
import torch
import torch.nn.functional as F
from tqdm import tqdm
# Import necessary classes from your training script
from smiles_train import MDLMLightningModule, PeptideAnalyzer
from smiles_tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
import pdb
def generate_smiles(model, tokenizer, args):
    """
    Generate peptide SMILES strings with the trained MDLM model.

    Sampling starts from uniformly random token IDs and walks a time grid
    from t=0 to t=1. At every step the model's argmax prediction is taken
    and each position is independently replaced by a fresh random token with
    probability ``1 - t_next``; the last step keeps the prediction unchanged.

    Args:
        model (MDLMLightningModule): The trained PyTorch Lightning model. It is
            called as ``model(x, t)`` and must expose ``model.model.vocab_size``.
        tokenizer (SMILES_SPE_Tokenizer): The tokenizer used for training; only
            its ``batch_decode`` method is used here.
        args (argparse.Namespace): Sampling parameters. Reads ``device``,
            ``n_samples``, ``seq_len``, ``n_steps`` and ``temperature``.

    Returns:
        list[str]: The generated SMILES strings accepted by
            ``PeptideAnalyzer.is_peptide`` (rejected sequences are dropped).
        float: Fraction of all generated sequences that were accepted, or
            0.0 when no sequence was generated.
    """
    print("Starting SMILES generation with forward flow matching (t=0 to t=1)...")
    model.eval()
    device = args.device
    # Loop-invariant: read the vocabulary size once instead of on every step.
    vocab_size = model.model.vocab_size

    # 1. Start with a tensor of random tokens (pure noise at t=0)
    x = torch.randint(
        0,
        vocab_size,
        (args.n_samples, args.seq_len),
        device=device,
    )

    # 2. Time grid for the forward process: n_steps intervals over [0, 1]
    time_steps = torch.linspace(0.0, 1.0, args.n_steps + 1, device=device)

    # 3. Iteratively follow the flow from noise to data
    with torch.no_grad():
        for i in tqdm(range(args.n_steps), desc="Flow Matching Steps"):
            t_curr = time_steps[i]
            t_next = time_steps[i + 1]

            # One timestep value per sample in the batch
            t_tensor = torch.full((args.n_samples,), t_curr, device=device)

            # Model prediction for the final clean sequence (at t=1).
            # NOTE(review): argmax is unaffected by dividing the logits by a
            # positive temperature, so args.temperature does not currently
            # change the result. Kept as-is to preserve the sampler's output;
            # confirm whether stochastic sampling was intended.
            # Presumably logits are (n_samples, seq_len, vocab) - TODO confirm.
            logits = model(x, t_tensor)
            logits = logits / args.temperature
            pred_x1 = torch.argmax(logits, dim=-1)

            # On the last step, the result is the final prediction
            if i == args.n_steps - 1:
                x = pred_x1
                break

            # --- Construct the next state x_{t_next} ---
            # Each position is re-noised with probability (1 - t_next).
            # The rand -> randint call order is kept so seeded runs reproduce.
            noise_prob = 1.0 - t_next
            mask = torch.rand(x.shape, device=device) < noise_prob
            noise = torch.randint(0, vocab_size, x.shape, device=device)

            # Noise where masked, model prediction everywhere else
            x = torch.where(mask, noise, pred_x1)

    # 4. Decode the final token IDs into SMILES strings
    generated_sequences = tokenizer.batch_decode(x)

    # 5. Keep only the sequences the analyzer accepts as peptides
    peptide_analyzer = PeptideAnalyzer()
    valid_smiles = [
        seq for seq in generated_sequences if peptide_analyzer.is_peptide(seq)
    ]

    # Guard against an empty batch (e.g. n_samples == 0) instead of raising
    # ZeroDivisionError.
    total = len(generated_sequences)
    validity_rate = len(valid_smiles) / total if total else 0.0

    print(f"\nGeneration complete. Validity rate: {validity_rate:.2%}")
    return valid_smiles, validity_rate
def main():
    """
    Command-line entry point: load a checkpoint, sample SMILES, save the valid ones.

    Parses the sampling and environment options, builds the tokenizer, restores
    the model from ``--checkpoint_path`` using the hyperparameters stored in
    that checkpoint, runs ``generate_smiles`` and appends every valid SMILES
    string (one per line) to ``--output_file``. The validity rate is printed
    at the end.
    """
    parser = argparse.ArgumentParser(description="Sample from a trained ReDi model.")

    # --- Required Arguments ---
    parser.add_argument("--checkpoint_path", type=str, required=True, help="Path to the model checkpoint (.ckpt file).")

    # --- Sampling Arguments ---
    parser.add_argument("--n_samples", type=int, default=16, help="Number of SMILES strings to generate.")
    parser.add_argument("--seq_len", type=int, default=256, help="Maximum sequence length for generated SMILES.")
    parser.add_argument("--n_steps", type=int, default=100, help="Number of denoising steps for sampling.")
    parser.add_argument("--temperature", type=float, default=1.0, help="Sampling temperature. Higher values increase diversity.")

    # --- Environment Arguments ---
    parser.add_argument("--vocab_path", type=str, default='/scratch/pranamlab/tong/ReDi_discrete/smiles/smiles_tokenizer/new_vocab.txt', help="Path to tokenizer vocabulary file.")
    parser.add_argument("--splits_path", type=str, default='/scratch/pranamlab/tong/ReDi_discrete/smiles/smiles_tokenizer/new_splits.txt', help="Path to tokenizer splits file.")
    parser.add_argument("--output_file", type=str, default="generated_smiles.txt", help="File to save the valid generated SMILES.")

    args = parser.parse_args()

    # Set up device; generate_smiles reads it back from args.device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    args.device = device
    print(f"Using device: {device}")

    # --- Load Model and Tokenizer ---
    print("Loading tokenizer...")
    tokenizer = SMILES_SPE_Tokenizer(args.vocab_path, args.splits_path)

    print(f"Loading model from checkpoint: {args.checkpoint_path}")
    # Load hyperparameters from the checkpoint to ensure model architecture matches.
    # SECURITY: weights_only=False unpickles arbitrary Python objects, so only
    # point --checkpoint_path at checkpoints from a trusted source.
    checkpoint = torch.load(args.checkpoint_path, map_location=device, weights_only=False)
    model_hparams = checkpoint["hyper_parameters"]["args"]

    # Instantiate the model with the loaded hyperparameters
    model = MDLMLightningModule.load_from_checkpoint(
        args.checkpoint_path,
        args=model_hparams,
        tokenizer=tokenizer,
        map_location=device,
        strict=False,  # Recommended if you have updated the code since training
    )
    model.to(device)

    # --- Generate SMILES ---
    valid_smiles, validity_rate = generate_smiles(model, tokenizer, args)

    # Honour --output_file (previously parsed but ignored in favour of a
    # hard-coded path). Append mode is kept so repeated runs accumulate
    # samples in the same file.
    with open(args.output_file, 'a', encoding="utf-8") as f:
        f.writelines(f"{smiles}\n" for smiles in valid_smiles)

    print(validity_rate)
# Run the sampler only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()