import argparse

import torch
import torch.nn.functional as F
from tqdm import tqdm

# Import necessary classes from the training script
from smiles_train import MDLMLightningModule, PeptideAnalyzer
from smiles_tokenizer.my_tokenizers import SMILES_SPE_Tokenizer


def generate_smiles(model, tokenizer, args):
    """
    Generates peptide SMILES strings using the trained MDLM model with a
    forward (t=0 to t=1) flow matching process.

    Args:
        model (MDLMLightningModule): The trained PyTorch Lightning model.
        tokenizer (SMILES_SPE_Tokenizer): The tokenizer used for training.
        args (argparse.Namespace): Command-line arguments with sampling parameters.

    Returns:
        list[str]: The generated SMILES strings that pass the peptide validity check.
        float: The validity rate over all generated sequences.
    """
    print("Starting SMILES generation with forward flow matching (t=0 to t=1)...")
    model.eval()
    device = args.device

    # 1. Start with a tensor of random tokens (pure noise at t=0)
    x = torch.randint(
        0, model.model.vocab_size, (args.n_samples, args.seq_len), device=device
    )

    # 2. Define the time schedule for the forward process (0.0 to 1.0)
    time_steps = torch.linspace(0.0, 1.0, args.n_steps + 1, device=device)

    # 3. Iteratively follow the flow from noise to data
    with torch.no_grad():
        for i in tqdm(range(args.n_steps), desc="Flow Matching Steps"):
            t_curr = time_steps[i]
            t_next = time_steps[i + 1]

            # Prepare the current timestep tensor for the model
            t_tensor = torch.full((args.n_samples,), t_curr, device=device)

            # Predict the final clean sequence (at t=1). Sampling from the
            # temperature-scaled softmax (rather than argmax, which ignores
            # temperature) is assumed here so that --temperature actually
            # affects diversity.
            logits = model(x, t_tensor)
            probs = F.softmax(logits / args.temperature, dim=-1)
            pred_x1 = torch.multinomial(
                probs.view(-1, probs.size(-1)), num_samples=1
            ).view(x.shape)

            # On the last step, the prediction is the final result
            if i == args.n_steps - 1:
                x = pred_x1
                break

            # --- Construct the next state x_{t_next} ---
            # The probability of a token being noise at time t_next is (1 - t_next).
            noise_prob = 1.0 - t_next
            mask = torch.rand(x.shape, device=device) < noise_prob

            # Generate new random tokens for the noise positions
            noise = torch.randint(
                0, model.model.vocab_size, x.shape, device=device
            )

            # Combine the prediction with noise to form the next intermediate state
            x = torch.where(mask, noise, pred_x1)

    # 4. Decode the final token IDs into SMILES strings
    generated_sequences = tokenizer.batch_decode(x)

    # 5. Analyze the validity of the generated sequences
    peptide_analyzer = PeptideAnalyzer()
    valid_smiles = [seq for seq in generated_sequences if peptide_analyzer.is_peptide(seq)]

    validity_rate = len(valid_smiles) / len(generated_sequences)
    print(f"\nGeneration complete. Validity rate: {validity_rate:.2%}")

    return valid_smiles, validity_rate


def main():
    parser = argparse.ArgumentParser(description="Sample from a trained ReDi model.")

    # --- Required Arguments ---
    parser.add_argument("--checkpoint_path", type=str, required=True,
                        help="Path to the model checkpoint (.ckpt file).")

    # --- Sampling Arguments ---
    parser.add_argument("--n_samples", type=int, default=16,
                        help="Number of SMILES strings to generate.")
    parser.add_argument("--seq_len", type=int, default=256,
                        help="Maximum sequence length for generated SMILES.")
    parser.add_argument("--n_steps", type=int, default=100,
                        help="Number of denoising steps for sampling.")
    parser.add_argument("--temperature", type=float, default=1.0,
                        help="Sampling temperature. Higher values increase diversity.")
    # --- Environment Arguments ---
    parser.add_argument("--vocab_path", type=str,
                        default='/scratch/pranamlab/tong/ReDi_discrete/smiles/smiles_tokenizer/new_vocab.txt',
                        help="Path to the tokenizer vocabulary file.")
    parser.add_argument("--splits_path", type=str,
                        default='/scratch/pranamlab/tong/ReDi_discrete/smiles/smiles_tokenizer/new_splits.txt',
                        help="Path to the tokenizer splits file.")
    parser.add_argument("--output_file", type=str, default="generated_smiles.txt",
                        help="File to save the valid generated SMILES.")

    args = parser.parse_args()

    # Set up device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    args.device = device
    print(f"Using device: {device}")

    # --- Load Model and Tokenizer ---
    print("Loading tokenizer...")
    tokenizer = SMILES_SPE_Tokenizer(args.vocab_path, args.splits_path)

    print(f"Loading model from checkpoint: {args.checkpoint_path}")
    # Load hyperparameters from the checkpoint so the model architecture matches
    checkpoint = torch.load(args.checkpoint_path, map_location=device, weights_only=False)
    model_hparams = checkpoint["hyper_parameters"]["args"]

    # Instantiate the model with the loaded hyperparameters
    model = MDLMLightningModule.load_from_checkpoint(
        args.checkpoint_path,
        args=model_hparams,
        tokenizer=tokenizer,
        map_location=device,
        strict=False,  # recommended if the code has changed since training
    )
    model.to(device)

    # --- Generate SMILES ---
    valid_smiles, validity_rate = generate_smiles(model, tokenizer, args)

    # Append the valid SMILES to the requested output file
    with open(args.output_file, 'a') as f:
        for smiles in valid_smiles:
            f.write(smiles + '\n')
    print(f"Saved {len(valid_smiles)} valid SMILES to {args.output_file} "
          f"(validity rate: {validity_rate:.2%})")


if __name__ == "__main__":
    main()
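
# Example invocation (a sketch only: the script filename and checkpoint path
# below are placeholders, not names taken from the repository):
#
#   python sample_smiles.py \
#       --checkpoint_path checkpoints/last.ckpt \
#       --n_samples 64 \
#       --n_steps 100 \
#       --temperature 1.0 \
#       --output_file generated_smiles.txt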