from time import strftime
from typing import Any, Dict, List, Optional, Tuple
import os

import numpy as np
import pandas as pd
import torch

# import hydra
# import rootutils
# from lightning import LightningDataModule, LightningModule, Trainer
# from lightning.pytorch.loggers import Logger
from omegaconf import DictConfig

# rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
# ------------------------------------------------------------------------------------ #
# the setup_root above is equivalent to:
# - adding project root dir to PYTHONPATH
#   (so you don't need to force user to install project as a package)
#   (necessary before importing any local modules e.g. `from src import utils`)
# - setting up PROJECT_ROOT environment variable
#   (which is used as a base for paths in "configs/paths/default.yaml")
#   (this way all filepaths are the same no matter where you run the code)
# - loading environment variables from ".env" in root dir
#
# you can remove it if you:
# 1. either install project as a package or move entry files to project root dir
# 2. set `root_dir` to "." in "configs/paths/default.yaml"
#
# more info: https://github.com/ashleve/rootutils
# ------------------------------------------------------------------------------------ #

from src.utils import (
    RankedLogger,
    extras,
    instantiate_loggers,
    log_hyperparameters,
    task_wrapper,
    checkpoint_utils,
    plot_utils,
)
from src.common.pdb_utils import extract_backbone_coords
from src.metrics import metrics
from src.common.geo_utils import _find_rigid_alignment

log = RankedLogger(__name__, rank_zero_only=True)


def evaluate_prediction(
    pred_dir: str,
    target_dir: Optional[str] = None,
    crystal_dir: Optional[str] = None,
    tag: Optional[str] = None,
):
    """Evaluate predicted conformation ensembles against reference PDB files."""
    if target_dir is None or not os.path.isdir(target_dir):
        log.warning(f"target_dir {target_dir} does not exist. Skipping evaluation.")
        return {}
    assert os.path.isdir(pred_dir), f"pred_dir {pred_dir} is not a directory."
    # The contact metrics below need a crystal structure per target.
    assert crystal_dir is not None and os.path.isdir(crystal_dir), \
        f"crystal_dir {crystal_dir} is not a directory."

    targets = [d.replace(".pdb", "") for d in os.listdir(target_dir)]
    # pred_bases = os.listdir(pred_dir)
    output_dir = pred_dir
    tag = tag if tag is not None else "dev"
    timestamp = strftime("%m%d-%H-%M")

    # Metric name -> metric function.
    fns = {
        'val_clash': metrics.validity,
        'val_bond': metrics.bonding_validity,
        'js_pwd': metrics.js_pwd,
        'js_rg': metrics.js_rg,
        # 'js_tica_pos': metrics.js_tica_pos,
        'w2_rmwd': metrics.w2_rmwd,
        # 'div_rmsd': metrics.div_rmsd,
        'div_rmsf': metrics.div_rmsf,
        'pro_w_contacts': metrics.pro_w_contacts,
        'pro_t_contacts': metrics.pro_t_contacts,
        # 'pro_c_contacts': metrics.pro_c_contacts,
    }
    eval_res = {k: {} for k in fns}

    print(f"total_md_num = {len(targets)}")
    for count, target in enumerate(targets, start=1):
        print("")
        print(count, target)
        pred_file = os.path.join(pred_dir, f"{target}.pdb")
        # assert os.path.isfile(pred_file), f"pred_file {pred_file} does not exist."
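        # Skip targets that have no corresponding prediction file.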
        if not os.path.isfile(pred_file):
            continue
        target_file = os.path.join(target_dir, f"{target}.pdb")
        ca_coords = {
            'target': extract_backbone_coords(target_file),
            'pred': extract_backbone_coords(pred_file),
        }
        cry_target_file = os.path.join(crystal_dir, f"{target}.pdb")
        cry_ca_coords = extract_backbone_coords(cry_target_file)[0]

        for f_name, func in fns.items():
            print(f_name)
            if f_name == 'w2_rmwd':
                # Rigidly align every frame of both ensembles onto the first
                # target frame before computing the metric.
                v_ref = torch.as_tensor(ca_coords['target'][0])
                for k, v in ca_coords.items():
                    v = torch.as_tensor(v)  # (n_frames, n_residues, 3), e.g. (250, 356, 3)
                    for idx in range(v.shape[0]):
                        R, t = _find_rigid_alignment(v[idx], v_ref)
                        v[idx] = torch.matmul(R, v[idx].transpose(-2, -1)).transpose(-2, -1) + t.unsqueeze(0)
                    ca_coords[k] = v.numpy()

            if f_name.startswith('js_'):
                res = func(ca_coords, ref_key='target')
            elif f_name == 'pro_c_contacts':
                res = func(target_file, pred_file, cry_target_file)
            elif f_name.startswith('pro_'):
                res = func(ca_coords, cry_ca_coords)
            else:
                res = func(ca_coords)

            if f_name in ('js_tica', 'js_tica_pos'):
                pass
                # eval_res[f_name][target] = res[0]['pred']
                # save_to = os.path.join(output_dir, f"tica_{target}_{tag}_{timestamp}.png")
                # plot_utils.scatterplot_2d(res[1], save_to=save_to, ref_key='target')
            else:
                eval_res[f_name][target] = res['pred']

    csv_save_to = os.path.join(output_dir, f"metrics_{tag}_{timestamp}.csv")
    df = pd.DataFrame.from_dict(eval_res)  # row = target, col = metric name
    df.to_csv(csv_save_to)
    print(f"metrics saved to {csv_save_to}")
    mean_metrics = np.around(df.mean(), decimals=4)
    return mean_metrics


# @task_wrapper
# def evaluate(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
#     """Sample on a test set and report evaluation metrics.
#
#     This method is wrapped in optional @task_wrapper decorator, that controls the behavior
#     during failure. Useful for multiruns, saving info about the crash, etc.
#
#     :param cfg: DictConfig configuration composed by Hydra.
#     :return: Tuple[dict, dict] with metrics and dict with all instantiated objects.
#     """
#     # assert cfg.ckpt_path
#     pred_dir = cfg.get("pred_dir")
#     if pred_dir and os.path.isdir(pred_dir):
#         log.info(f"Found pre-computed prediction directory {pred_dir}.")
#         metric_dict = evaluate_prediction(pred_dir, target_dir=cfg.target_dir)
#         return metric_dict, None
#
#     log.info(f"Instantiating datamodule <{cfg.data._target_}>")
#     datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)
#
#     log.info(f"Instantiating model <{cfg.model._target_}>")
#     model: LightningModule = hydra.utils.instantiate(cfg.model)
#
#     log.info("Instantiating loggers...")
#     logger: List[Logger] = instantiate_loggers(cfg.get("logger"))
#
#     log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
#     trainer: Trainer = hydra.utils.instantiate(cfg.trainer, logger=logger)
#
#     object_dict = {
#         "cfg": cfg,
#         "datamodule": datamodule,
#         "model": model,
#         "logger": logger,
#         "trainer": trainer,
#     }
#
#     if logger:
#         log.info("Logging hyperparameters!")
#         log_hyperparameters(object_dict)
#
#     # Load checkpoint manually.
#     model, ckpt_path = checkpoint_utils.load_model_checkpoint(model, cfg.ckpt_path)
#     # log.info("Starting testing!")
#     # trainer.test(model=model, datamodule=datamodule, ckpt_path=cfg.ckpt_path)
#
#     # Get dataloader for prediction.
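#     # `setup` is called manually because the dataloader is fetched directly
#     # below, instead of letting the Trainer drive the datamodule hooks.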
#     datamodule.setup(stage="predict")
#     dataloaders = datamodule.test_dataloader()
#
#     log.info("Starting predictions.")
#     pred_dir = trainer.predict(model=model, dataloaders=dataloaders, ckpt_path=ckpt_path)[-1]
#     # metric_dict = trainer.callback_metrics
#
#     log.info("Starting evaluations.")
#     metric_dict = evaluate_prediction(pred_dir, target_dir=cfg.target_dir)
#
#     return metric_dict, object_dict


# @hydra.main(version_base="1.3", config_path="../configs", config_name="eval.yaml")
# def main(cfg: DictConfig) -> None:
#     """Main entry point for evaluation.
#
#     :param cfg: DictConfig configuration composed by Hydra.
#     """
#     # apply extra utilities
#     # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
#     extras(cfg)
#
#     evaluate(cfg)


# if __name__ == "__main__":
#     main()
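

# ------------------------------------------------------------------------------------ #
# Minimal standalone entry point: a sketch for running `evaluate_prediction` on a
# directory of pre-computed predictions without Hydra (the original `main` above is
# commented out). The CLI flags below are illustrative assumptions, not project config.
# ------------------------------------------------------------------------------------ #
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Evaluate predicted PDB ensembles.")
    parser.add_argument("--pred_dir", required=True, help="Directory with predicted .pdb files.")
    parser.add_argument("--target_dir", required=True, help="Directory with reference MD .pdb files.")
    parser.add_argument("--crystal_dir", required=True, help="Directory with crystal-structure .pdb files.")
    parser.add_argument("--tag", default=None, help="Tag used in the output CSV filename.")
    args = parser.parse_args()

    # Writes metrics_<tag>_<timestamp>.csv into --pred_dir and prints per-metric means.
    mean_metrics = evaluate_prediction(
        args.pred_dir,
        target_dir=args.target_dir,
        crystal_dir=args.crystal_dir,
        tag=args.tag,
    )
    print(mean_metrics)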