import copy
from functools import partial
import json
import logging
import os
import pickle
from typing import Optional, Sequence, List, Any

import ml_collections as mlc
import numpy as np
import pytorch_lightning as pl
import torch
from torch.utils.data import RandomSampler

from openfold.data import (
    data_pipeline,
    feature_pipeline,
    mmcif_parsing,
    templates,
)
from openfold.utils.tensor_utils import tensor_tree_map, dict_multimap


class OpenFoldSingleDataset(torch.utils.data.Dataset):
    def __init__(self,
        data_dir: str,
        alignment_dir: str,
        template_mmcif_dir: str,
        max_template_date: str,
        config: mlc.ConfigDict,
        kalign_binary_path: str = '/usr/bin/kalign',
        max_template_hits: int = 4,
        obsolete_pdbs_file_path: Optional[str] = None,
        template_release_dates_cache_path: Optional[str] = None,
        shuffle_top_k_prefiltered: Optional[int] = None,
        treat_pdb_as_distillation: bool = True,
        mapping_path: Optional[str] = None,
        mode: str = "train",
        _output_raw: bool = False,
        _alignment_index: Optional[Any] = None
    ):
""" |
|
|
Args: |
|
|
data_dir: |
|
|
A path to a directory containing mmCIF files (in train |
|
|
mode) or FASTA files (in inference mode). |
|
|
alignment_dir: |
|
|
A path to a directory containing only data in the format |
|
|
output by an AlignmentRunner |
|
|
(defined in openfold.features.alignment_runner). |
|
|
I.e. a directory of directories named {PDB_ID}_{CHAIN_ID} |
|
|
or simply {PDB_ID}, each containing .a3m, .sto, and .hhr |
|
|
files. |
|
|
template_mmcif_dir: |
|
|
Path to a directory containing template mmCIF files. |
|
|
config: |
|
|
A dataset config object. See openfold.config |
|
|
kalign_binary_path: |
|
|
Path to kalign binary. |
|
|
max_template_hits: |
|
|
An upper bound on how many templates are considered. During |
|
|
training, the templates ultimately used are subsampled |
|
|
from this total quantity. |
|
|
template_release_dates_cache_path: |
|
|
Path to the output of scripts/generate_mmcif_cache. |
|
|
obsolete_pdbs_file_path: |
|
|
Path to the file containing replacements for obsolete PDBs. |
|
|
shuffle_top_k_prefiltered: |
|
|
Whether to uniformly shuffle the top k template hits before |
|
|
parsing max_template_hits of them. Can be used to |
|
|
approximate DeepMind's training-time template subsampling |
|
|
scheme much more performantly. |
|
|
treat_pdb_as_distillation: |
|
|
Whether to assume that .pdb files in the data_dir are from |
|
|
the self-distillation set (and should be subjected to |
|
|
special distillation set preprocessing steps). |
|
|
                mode:
                    "train", "eval", or "predict"
        """
        super(OpenFoldSingleDataset, self).__init__()
        self.data_dir = data_dir
        self.alignment_dir = alignment_dir
        self.config = config
        self.treat_pdb_as_distillation = treat_pdb_as_distillation
        self.mode = mode
        self._output_raw = _output_raw
        self._alignment_index = _alignment_index

        valid_modes = ["train", "eval", "predict"]
        if(mode not in valid_modes):
            raise ValueError(f'mode must be one of {valid_modes}')

        if(template_release_dates_cache_path is None):
            logging.warning(
                "Template release dates cache does not exist. Remember to run "
                "scripts/generate_mmcif_cache.py before running OpenFold"
            )

        if(_alignment_index is not None):
            self._chain_ids = list(_alignment_index.keys())
        elif(mapping_path is None):
            self._chain_ids = list(os.listdir(alignment_dir))
        else:
            with open(mapping_path, "r") as f:
                self._chain_ids = [l.strip() for l in f.readlines()]

        self._chain_id_to_idx_dict = {
            chain: i for i, chain in enumerate(self._chain_ids)
        }

        template_featurizer = templates.TemplateHitFeaturizer(
            mmcif_dir=template_mmcif_dir,
            max_template_date=max_template_date,
            max_hits=max_template_hits,
            kalign_binary_path=kalign_binary_path,
            release_dates_path=template_release_dates_cache_path,
            obsolete_pdbs_path=obsolete_pdbs_file_path,
            _shuffle_top_k_prefiltered=shuffle_top_k_prefiltered,
        )

        self.data_pipeline = data_pipeline.DataPipeline(
            template_featurizer=template_featurizer,
        )

        if(not self._output_raw):
            self.feature_pipeline = feature_pipeline.FeaturePipeline(config)

    def _parse_mmcif(self, path, file_id, chain_id, alignment_dir, _alignment_index):
        with open(path, 'r') as f:
            mmcif_string = f.read()

        mmcif_object = mmcif_parsing.parse(
            file_id=file_id, mmcif_string=mmcif_string
        )

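        # If parsing failed, surface the first recorded error rather than
        # continuing with an empty structure.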
        if(mmcif_object.mmcif_object is None):
            raise list(mmcif_object.errors.values())[0]

        mmcif_object = mmcif_object.mmcif_object

        data = self.data_pipeline.process_mmcif(
            mmcif=mmcif_object,
            alignment_dir=alignment_dir,
            chain_id=chain_id,
            _alignment_index=_alignment_index
        )

        return data

    def chain_id_to_idx(self, chain_id):
        return self._chain_id_to_idx_dict[chain_id]

    def idx_to_chain_id(self, idx):
        return self._chain_ids[idx]

    def __getitem__(self, idx):
        name = self.idx_to_chain_id(idx)
        alignment_dir = os.path.join(self.alignment_dir, name)

        _alignment_index = None
        if(self._alignment_index is not None):
            alignment_dir = self.alignment_dir
            _alignment_index = self._alignment_index[name]

        if(self.mode == 'train' or self.mode == 'eval'):
            spl = name.rsplit('_', 1)
            if(len(spl) == 2):
                file_id, chain_id = spl
            else:
                file_id, = spl
                chain_id = None

            path = os.path.join(self.data_dir, file_id)
            if(os.path.exists(path + ".cif")):
                data = self._parse_mmcif(
                    path + ".cif", file_id, chain_id, alignment_dir, _alignment_index,
                )
            elif(os.path.exists(path + ".core")):
                data = self.data_pipeline.process_core(
                    path + ".core", alignment_dir, _alignment_index,
                )
            elif(os.path.exists(path + ".pdb")):
                data = self.data_pipeline.process_pdb(
                    pdb_path=path + ".pdb",
                    alignment_dir=alignment_dir,
                    is_distillation=self.treat_pdb_as_distillation,
                    chain_id=chain_id,
                    _alignment_index=_alignment_index,
                )
            else:
                raise ValueError("Invalid file type")
        else:
            path = os.path.join(self.data_dir, name + ".fasta")
            data = self.data_pipeline.process_fasta(
                fasta_path=path,
                alignment_dir=alignment_dir,
                _alignment_index=_alignment_index,
            )

        if(self._output_raw):
            return data

        feats = self.feature_pipeline.process_features(
            data, self.mode
        )

        return feats

    def __len__(self):
        return len(self._chain_ids)


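# Deterministic training-time filter: a chain is rejected if its resolution
# is worse than max_resolution or if a single residue type makes up more than
# max_single_aa_prop of the sequence.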
def deterministic_train_filter(
    chain_data_cache_entry: Any,
    max_resolution: float = 9.,
    max_single_aa_prop: float = 0.8,
) -> bool:
    resolution = chain_data_cache_entry.get("resolution", None)
    if(resolution is not None and resolution > max_resolution):
        return False

    seq = chain_data_cache_entry["seq"]
    counts = {}
    for aa in seq:
        counts.setdefault(aa, 0)
        counts[aa] += 1
    largest_aa_count = max(counts.values())
    largest_single_aa_prop = largest_aa_count / len(seq)
    if(largest_single_aa_prop > max_single_aa_prop):
        return False

    return True


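# Stochastic training-time filter: the probability of keeping a chain is the
# product of an inverse-cluster-size term (when cluster_size is available)
# and a length term that clamps the chain length to [256, 512] and divides by
# 512, mirroring the stochastic filters applied during AlphaFold's training.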
def get_stochastic_train_filter_prob(
    chain_data_cache_entry: Any,
) -> float:
    probabilities = []

    cluster_size = chain_data_cache_entry.get("cluster_size", None)
    if(cluster_size is not None and cluster_size > 0):
        probabilities.append(1 / cluster_size)

    chain_length = len(chain_data_cache_entry["seq"])
    probabilities.append((1 / 512) * (max(min(chain_length, 512), 256)))

    out = 1
    for p in probabilities:
        out *= p

    return out


class OpenFoldDataset(torch.utils.data.Dataset):
    """
        Implements the stochastic filters applied during AlphaFold's training.
        Because samples are selected from constituent datasets randomly, the
        length of an OpenFoldDataset is arbitrary. Samples are selected and
        filtered at initialization (and again whenever reroll() is called).
    """
    def __init__(self,
        datasets: Sequence[OpenFoldSingleDataset],
        probabilities: Sequence[float],
        epoch_len: int,
        chain_data_cache_paths: List[str],
        generator: torch.Generator = None,
        _roll_at_init: bool = True,
    ):
        self.datasets = datasets
        self.probabilities = probabilities
        self.epoch_len = epoch_len
        self.generator = generator

        self.chain_data_caches = []
        for path in chain_data_cache_paths:
            with open(path, "r") as fp:
                self.chain_data_caches.append(json.load(fp))

        def looped_shuffled_dataset_idx(dataset_len):
            while True:
                # Uniformly shuffle the dataset's indices
                weights = [1. for _ in range(dataset_len)]
                shuf = torch.multinomial(
                    torch.tensor(weights),
                    num_samples=dataset_len,
                    replacement=False,
                    generator=self.generator,
                )
                for idx in shuf:
                    yield idx

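        # Generator over datapoint indices of one constituent dataset.
        # Candidate chains are drawn in shuffled order, discarded by the
        # deterministic filter, and then kept or rejected by a Bernoulli draw
        # with the stochastic filter probability; surviving indices are
        # yielded one at a time.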
        def looped_samples(dataset_idx):
            max_cache_len = int(epoch_len * probabilities[dataset_idx])
            dataset = self.datasets[dataset_idx]
            idx_iter = looped_shuffled_dataset_idx(len(dataset))
            chain_data_cache = self.chain_data_caches[dataset_idx]
            while True:
                weights = []
                idx = []
                for _ in range(max_cache_len):
                    candidate_idx = next(idx_iter)
                    chain_id = dataset.idx_to_chain_id(candidate_idx)
                    chain_data_cache_entry = chain_data_cache[chain_id]
                    if(not deterministic_train_filter(chain_data_cache_entry)):
                        continue

                    p = get_stochastic_train_filter_prob(
                        chain_data_cache_entry,
                    )
                    weights.append([1. - p, p])
                    idx.append(candidate_idx)

                samples = torch.multinomial(
                    torch.tensor(weights),
                    num_samples=1,
                    generator=self.generator,
                )
                samples = samples.squeeze()

                cache = [i for i, s in zip(idx, samples) if s]

                for datapoint_idx in cache:
                    yield datapoint_idx

        self._samples = [looped_samples(i) for i in range(len(self.datasets))]

        if(_roll_at_init):
            self.reroll()

    def __getitem__(self, idx):
        dataset_idx, datapoint_idx = self.datapoints[idx]
        return self.datasets[dataset_idx][datapoint_idx]

    def __len__(self):
        return self.epoch_len

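    # Draw a fresh epoch's worth of (dataset, datapoint) pairs: the dataset
    # for each slot is chosen according to self.probabilities, and the
    # datapoint is pulled from that dataset's filtered sample generator.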
    def reroll(self):
        dataset_choices = torch.multinomial(
            torch.tensor(self.probabilities),
            num_samples=self.epoch_len,
            replacement=True,
            generator=self.generator,
        )

        self.datapoints = []
        for dataset_idx in dataset_choices:
            samples = self._samples[dataset_idx]
            datapoint_idx = next(samples)
            self.datapoints.append((dataset_idx, datapoint_idx))


class OpenFoldBatchCollator:
    def __init__(self, config, stage="train"):
        self.stage = stage
        self.feature_pipeline = feature_pipeline.FeaturePipeline(config)

    def __call__(self, raw_prots):
        processed_prots = []
        for prot in raw_prots:
            features = self.feature_pipeline.process_features(
                prot, self.stage
            )
            processed_prots.append(features)

        stack_fn = partial(torch.stack, dim=0)
        return dict_multimap(stack_fn, processed_prots)


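# Wraps the standard DataLoader so that per-batch stochastic properties
# (clamped-FAPE usage and the number of recycling iterations) can be sampled
# and attached to each batch after collation.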
class OpenFoldDataLoader(torch.utils.data.DataLoader):
    def __init__(self, *args, config, stage="train", generator=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.config = config
        self.stage = stage

        if(generator is None):
            generator = torch.Generator()

        self.generator = generator
        self._prep_batch_properties_probs()

    def _prep_batch_properties_probs(self):
        keyed_probs = []
        stage_cfg = self.config[self.stage]

        max_iters = self.config.common.max_recycling_iters
        if(stage_cfg.supervised):
            clamp_prob = self.config.supervised.clamp_prob
            keyed_probs.append(
                ("use_clamped_fape", [1 - clamp_prob, clamp_prob])
            )

        if(stage_cfg.uniform_recycling):
            recycling_probs = [
                1. / (max_iters + 1) for _ in range(max_iters + 1)
            ]
        else:
            recycling_probs = [
                0. for _ in range(max_iters + 1)
            ]
            recycling_probs[-1] = 1.

        keyed_probs.append(
            ("no_recycling_iters", recycling_probs)
        )

        keys, probs = zip(*keyed_probs)
        max_len = max([len(p) for p in probs])
        padding = [[0.] * (max_len - len(p)) for p in probs]

        self.prop_keys = keys
        self.prop_probs_tensor = torch.tensor(
            [p + pad for p, pad in zip(probs, padding)],
            dtype=torch.float32,
        )

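    # Sample one value per property for this batch, broadcast it across the
    # batch and recycling dimensions, and truncate all recycled features to
    # the sampled number of recycling iterations.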
    def _add_batch_properties(self, batch):
        samples = torch.multinomial(
            self.prop_probs_tensor,
            num_samples=1,
            replacement=True,
            generator=self.generator
        )

        aatype = batch["aatype"]
        batch_dims = aatype.shape[:-2]
        recycling_dim = aatype.shape[-1]
        no_recycling = recycling_dim
        for i, key in enumerate(self.prop_keys):
            sample = int(samples[i][0])
            sample_tensor = torch.tensor(
                sample,
                device=aatype.device,
                requires_grad=False
            )
            orig_shape = sample_tensor.shape
            sample_tensor = sample_tensor.view(
                (1,) * len(batch_dims) + sample_tensor.shape + (1,)
            )
            sample_tensor = sample_tensor.expand(
                batch_dims + orig_shape + (recycling_dim,)
            )
            batch[key] = sample_tensor

            if(key == "no_recycling_iters"):
                no_recycling = sample

        resample_recycling = lambda t: t[..., :no_recycling + 1]
        batch = tensor_tree_map(resample_recycling, batch)

        return batch

    def __iter__(self):
        it = super().__iter__()

        def _batch_prop_gen(iterator):
            for batch in iterator:
                yield self._add_batch_properties(batch)

        return _batch_prop_gen(it)


class OpenFoldDataModule(pl.LightningDataModule):
    def __init__(self,
        config: mlc.ConfigDict,
        template_mmcif_dir: str,
        max_template_date: str,
        train_data_dir: Optional[str] = None,
        train_alignment_dir: Optional[str] = None,
        train_chain_data_cache_path: Optional[str] = None,
        distillation_data_dir: Optional[str] = None,
        distillation_alignment_dir: Optional[str] = None,
        distillation_chain_data_cache_path: Optional[str] = None,
        val_data_dir: Optional[str] = None,
        val_alignment_dir: Optional[str] = None,
        predict_data_dir: Optional[str] = None,
        predict_alignment_dir: Optional[str] = None,
        kalign_binary_path: str = '/usr/bin/kalign',
        train_mapping_path: Optional[str] = None,
        distillation_mapping_path: Optional[str] = None,
        obsolete_pdbs_file_path: Optional[str] = None,
        template_release_dates_cache_path: Optional[str] = None,
        batch_seed: Optional[int] = None,
        train_epoch_len: int = 50000,
        _alignment_index_path: Optional[str] = None,
        **kwargs
    ):
        super(OpenFoldDataModule, self).__init__()

        self.config = config
        self.template_mmcif_dir = template_mmcif_dir
        self.max_template_date = max_template_date
        self.train_data_dir = train_data_dir
        self.train_alignment_dir = train_alignment_dir
        self.train_chain_data_cache_path = train_chain_data_cache_path
        self.distillation_data_dir = distillation_data_dir
        self.distillation_alignment_dir = distillation_alignment_dir
        self.distillation_chain_data_cache_path = (
            distillation_chain_data_cache_path
        )
        self.val_data_dir = val_data_dir
        self.val_alignment_dir = val_alignment_dir
        self.predict_data_dir = predict_data_dir
        self.predict_alignment_dir = predict_alignment_dir
        self.kalign_binary_path = kalign_binary_path
        self.train_mapping_path = train_mapping_path
        self.distillation_mapping_path = distillation_mapping_path
        self.template_release_dates_cache_path = (
            template_release_dates_cache_path
        )
        self.obsolete_pdbs_file_path = obsolete_pdbs_file_path
        self.batch_seed = batch_seed
        self.train_epoch_len = train_epoch_len

        if(self.train_data_dir is None and self.predict_data_dir is None):
            raise ValueError(
                'At least one of train_data_dir or predict_data_dir must be '
                'specified'
            )

        self.training_mode = self.train_data_dir is not None

        if(self.training_mode and train_alignment_dir is None):
            raise ValueError(
                'In training mode, train_alignment_dir must be specified'
            )
        elif(not self.training_mode and predict_alignment_dir is None):
            raise ValueError(
                'In inference mode, predict_alignment_dir must be specified'
            )
        elif(val_data_dir is not None and val_alignment_dir is None):
            raise ValueError(
                'If val_data_dir is specified, val_alignment_dir must '
                'be specified as well'
            )

        self._alignment_index = None
        if(_alignment_index_path is not None):
            with open(_alignment_index_path, "r") as fp:
                self._alignment_index = json.load(fp)

    def setup(self):
        # Most of the constructor arguments are shared between the train,
        # eval, and predict datasets
        dataset_gen = partial(OpenFoldSingleDataset,
            template_mmcif_dir=self.template_mmcif_dir,
            max_template_date=self.max_template_date,
            config=self.config,
            kalign_binary_path=self.kalign_binary_path,
            template_release_dates_cache_path=
                self.template_release_dates_cache_path,
            obsolete_pdbs_file_path=
                self.obsolete_pdbs_file_path,
        )

        if(self.training_mode):
            train_dataset = dataset_gen(
                data_dir=self.train_data_dir,
                alignment_dir=self.train_alignment_dir,
                mapping_path=self.train_mapping_path,
                max_template_hits=self.config.train.max_template_hits,
                shuffle_top_k_prefiltered=
                    self.config.train.shuffle_top_k_prefiltered,
                treat_pdb_as_distillation=False,
                mode="train",
                _output_raw=True,
                _alignment_index=self._alignment_index,
            )

            distillation_dataset = None
            if(self.distillation_data_dir is not None):
                distillation_dataset = dataset_gen(
                    data_dir=self.distillation_data_dir,
                    alignment_dir=self.distillation_alignment_dir,
                    mapping_path=self.distillation_mapping_path,
                    max_template_hits=self.config.train.max_template_hits,
                    treat_pdb_as_distillation=True,
                    mode="train",
                    _output_raw=True,
                )

            if(distillation_dataset is not None):
                datasets = [train_dataset, distillation_dataset]
                d_prob = self.config.train.distillation_prob
                probabilities = [1 - d_prob, d_prob]
                chain_data_cache_paths = [
                    self.train_chain_data_cache_path,
                    self.distillation_chain_data_cache_path,
                ]
            else:
                datasets = [train_dataset]
                probabilities = [1.]
                chain_data_cache_paths = [
                    self.train_chain_data_cache_path,
                ]

            self.train_dataset = OpenFoldDataset(
                datasets=datasets,
                probabilities=probabilities,
                epoch_len=self.train_epoch_len,
                chain_data_cache_paths=chain_data_cache_paths,
                _roll_at_init=False,
            )

            if(self.val_data_dir is not None):
                self.eval_dataset = dataset_gen(
                    data_dir=self.val_data_dir,
                    alignment_dir=self.val_alignment_dir,
                    mapping_path=None,
                    max_template_hits=self.config.eval.max_template_hits,
                    mode="eval",
                    _output_raw=True,
                )
            else:
                self.eval_dataset = None
        else:
            self.predict_dataset = dataset_gen(
                data_dir=self.predict_data_dir,
                alignment_dir=self.predict_alignment_dir,
                mapping_path=None,
                max_template_hits=self.config.predict.max_template_hits,
                mode="predict",
            )

    def _gen_dataloader(self, stage):
        generator = torch.Generator()
        if(self.batch_seed is not None):
            generator = generator.manual_seed(self.batch_seed)

        dataset = None
        if(stage == "train"):
            dataset = self.train_dataset
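            # Resample the (dataset, datapoint) pairs that make up the
            # training epoch every time a training dataloader is requested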
            dataset.reroll()
        elif(stage == "eval"):
            dataset = self.eval_dataset
        elif(stage == "predict"):
            dataset = self.predict_dataset
        else:
            raise ValueError("Invalid stage")

        batch_collator = OpenFoldBatchCollator(self.config, stage)

        dl = OpenFoldDataLoader(
            dataset,
            config=self.config,
            stage=stage,
            generator=generator,
            batch_size=self.config.data_module.data_loaders.batch_size,
            num_workers=self.config.data_module.data_loaders.num_workers,
            collate_fn=batch_collator,
        )

        return dl

    def train_dataloader(self):
        return self._gen_dataloader("train")

    def val_dataloader(self):
        if(self.eval_dataset is not None):
            return self._gen_dataloader("eval")
        return None

    def predict_dataloader(self):
        return self._gen_dataloader("predict")


class DummyDataset(torch.utils.data.Dataset):
    def __init__(self, batch_path):
        with open(batch_path, "rb") as f:
            self.batch = pickle.load(f)

    def __getitem__(self, idx):
        return copy.deepcopy(self.batch)

    def __len__(self):
        return 1000


class DummyDataLoader(pl.LightningDataModule):
    def __init__(self, batch_path):
        super().__init__()
        self.dataset = DummyDataset(batch_path)

    def train_dataloader(self):
        return torch.utils.data.DataLoader(self.dataset)