|
|
import argparse |
|
|
from pathlib import Path |
|
|
|
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from tqdm import tqdm |
|
|
|
|
|
|
|
|
from smiles_train import MDLMLightningModule, PeptideAnalyzer |
|
|
from smiles_tokenizer.my_tokenizers import SMILES_SPE_Tokenizer |
|
|
|
|
|
import pdb |
|
|
|
|
|
|
|
|
def generate_smiles(model, tokenizer, args):
    """Generate peptide SMILES with the trained MDLM model via forward (t=0 -> t=1) flow matching.

    Starting from uniformly random tokens, each step predicts the clean sequence
    and re-noises it with probability ``1 - t_next``; the last step keeps the
    clean prediction. Decoded sequences are then filtered by ``PeptideAnalyzer``.

    Args:
        model (MDLMLightningModule): The trained PyTorch Lightning model. It is
            called as ``model(x, t)`` and must expose ``model.model.vocab_size``.
        tokenizer (SMILES_SPE_Tokenizer): The tokenizer used for training.
        args (argparse.Namespace): Sampling parameters; reads ``device``,
            ``n_samples``, ``seq_len``, ``n_steps`` and ``temperature``.

    Returns:
        list[str]: Only the generated SMILES that ``PeptideAnalyzer.is_peptide`` accepts.
        float: Fraction of all generated sequences that were accepted
            (0.0 when nothing was generated).

    Raises:
        ValueError: If ``args.temperature`` is not strictly positive.
    """
    if args.temperature <= 0:
        # Dividing logits by 0 yields inf/nan, and a negative value inverts the argmax.
        raise ValueError(f"temperature must be > 0, got {args.temperature}")

    print("Starting SMILES generation with forward flow matching (t=0 to t=1)...")
    model.eval()

    tokens = _sample_tokens(model, args)
    generated_sequences = tokenizer.batch_decode(tokens)
    valid_smiles, validity_rate = _filter_valid_peptides(generated_sequences)

    print(f"\nGeneration complete. Validity rate: {validity_rate:.2%}")
    return valid_smiles, validity_rate


def _sample_tokens(model, args):
    """Run the iterative predict-and-re-noise loop and return the final token ids."""
    device = args.device
    vocab_size = model.model.vocab_size

    # t = 0: every position starts as a uniformly random vocabulary token.
    x = torch.randint(0, vocab_size, (args.n_samples, args.seq_len), device=device)
    time_steps = torch.linspace(0.0, 1.0, args.n_steps + 1, device=device)

    with torch.no_grad():
        for i in tqdm(range(args.n_steps), desc="Flow Matching Steps"):
            t_curr = time_steps[i]
            t_next = time_steps[i + 1]
            t_tensor = torch.full((args.n_samples,), t_curr, device=device)

            logits = model(x, t_tensor) / args.temperature
            # NOTE(review): dividing by a positive temperature never changes the
            # argmax, so --temperature currently has no effect on this greedy
            # decoding; sampling from softmax(logits) would be needed for it to
            # control diversity.
            pred_x1 = torch.argmax(logits, dim=-1)

            if i == args.n_steps - 1:
                # Final step: keep the clean prediction without re-noising.
                x = pred_x1
                break

            # Re-noise: each position becomes a uniformly random token with
            # probability 1 - t_next, otherwise keeps the predicted token.
            noise_prob = 1.0 - t_next
            mask = torch.rand(x.shape, device=device) < noise_prob
            noise = torch.randint(0, vocab_size, x.shape, device=device)
            x = torch.where(mask, noise, pred_x1)

    return x


def _filter_valid_peptides(sequences):
    """Keep the sequences PeptideAnalyzer accepts and return them with the acceptance rate."""
    analyzer = PeptideAnalyzer()
    valid = [seq for seq in sequences if analyzer.is_peptide(seq)]
    total = len(sequences)
    # Guard against an empty batch instead of raising ZeroDivisionError.
    rate = len(valid) / total if total else 0.0
    return valid, rate
|
|
|
|
|
|
|
|
def main():
    """CLI entry point: load the tokenizer and checkpoint, sample SMILES, and append the valid ones to a file."""
    parser = argparse.ArgumentParser(description="Sample from a trained ReDi model.")

    parser.add_argument("--checkpoint_path", type=str, required=True, help="Path to the model checkpoint (.ckpt file).")

    parser.add_argument("--n_samples", type=int, default=16, help="Number of SMILES strings to generate.")
    parser.add_argument("--seq_len", type=int, default=256, help="Maximum sequence length for generated SMILES.")
    parser.add_argument("--n_steps", type=int, default=100, help="Number of denoising steps for sampling.")
    parser.add_argument("--temperature", type=float, default=1.0, help="Sampling temperature. Higher values increase diversity.")

    parser.add_argument("--vocab_path", type=str, default='/scratch/pranamlab/tong/ReDi_discrete/smiles/smiles_tokenizer/new_vocab.txt', help="Path to tokenizer vocabulary file.")
    parser.add_argument("--splits_path", type=str, default='/scratch/pranamlab/tong/ReDi_discrete/smiles/smiles_tokenizer/new_splits.txt', help="Path to tokenizer splits file.")
    # The default matches the path this script previously hard-coded, so runs
    # without the flag still write to the same file; the flag is now honoured.
    parser.add_argument("--output_file", type=str, default="./v0_samples_200.csv", help="File to append the valid generated SMILES to (one per line).")

    args = parser.parse_args()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    args.device = device
    print(f"Using device: {device}")

    print("Loading tokenizer...")
    tokenizer = SMILES_SPE_Tokenizer(args.vocab_path, args.splits_path)

    print(f"Loading model from checkpoint: {args.checkpoint_path}")
    # SECURITY: weights_only=False unpickles arbitrary objects - only load trusted checkpoints.
    checkpoint = torch.load(args.checkpoint_path, map_location=device, weights_only=False)
    model_hparams = checkpoint["hyper_parameters"]["args"]
    # Only the hparams are needed here; load_from_checkpoint re-reads the file,
    # so drop the full dict to avoid holding two copies of the weights.
    del checkpoint

    model = MDLMLightningModule.load_from_checkpoint(
        args.checkpoint_path,
        args=model_hparams,
        tokenizer=tokenizer,
        map_location=device,
        strict=False,
    )
    model.to(device)

    valid_smiles, validity_rate = generate_smiles(model, tokenizer, args)

    output_path = Path(args.output_file)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    # Append mode accumulates samples across repeated runs.
    with open(output_path, "a", encoding="utf-8") as f:
        f.writelines(smiles + "\n" for smiles in valid_smiles)

    print(validity_rate)
|
|
|
|
|
if __name__ == "__main__":
    # Run the sampler only when this file is executed as a script, not when imported.
    main()