from functools import partial

import torch

from openfold.data import data_transforms


def nonensembled_transform_fns(common_cfg, mode_cfg):
    """Input pipeline data transformers that are not ensembled."""
    transforms = [
        data_transforms.cast_to_64bit_ints,
        data_transforms.correct_msa_restypes,
        data_transforms.squeeze_features,
        data_transforms.randomly_replace_msa_with_unknown(0.0),
        data_transforms.make_seq_mask,
        data_transforms.make_msa_mask,
        data_transforms.make_hhblits_profile,
    ]
    if common_cfg.use_templates:
        transforms.extend(
            [
                data_transforms.fix_templates_aatype,
                data_transforms.make_template_mask,
                data_transforms.make_pseudo_beta("template_"),
            ]
        )
        if common_cfg.use_template_torsion_angles:
            transforms.extend(
                [
                    data_transforms.atom37_to_torsion_angles("template_"),
                ]
            )

    transforms.extend(
        [
            data_transforms.make_atom14_masks,
        ]
    )

    if mode_cfg.supervised:
        transforms.extend(
            [
                data_transforms.make_atom14_positions,
                data_transforms.atom37_to_frames,
                data_transforms.atom37_to_torsion_angles(""),
                data_transforms.make_pseudo_beta(""),
                data_transforms.get_backbone_frames,
                data_transforms.get_chi_angles,
            ]
        )

    return transforms
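
# Illustration (not part of the pipeline): every transform above is a
# callable that maps a feature dict to a feature dict, so the returned
# list can be folded over the features directly. A minimal sketch, with
# `feats` standing in for a real feature dict:
#
#   fns = nonensembled_transform_fns(common_cfg, mode_cfg)
#   for fn in fns:
#       feats = fn(feats)  # same effect as compose(fns)(feats) below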


def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
    """Input pipeline data transformers that can be ensembled and averaged."""
    transforms = []

    if "max_distillation_msa_clusters" in mode_cfg:
        transforms.append(
            data_transforms.sample_msa_distillation(
                mode_cfg.max_distillation_msa_clusters
            )
        )

    if common_cfg.reduce_msa_clusters_by_max_templates:
        pad_msa_clusters = mode_cfg.max_msa_clusters - mode_cfg.max_templates
    else:
        pad_msa_clusters = mode_cfg.max_msa_clusters

    max_msa_clusters = pad_msa_clusters
    max_extra_msa = common_cfg.max_extra_msa

    msa_seed = None
    if not common_cfg.resample_msa_in_recycling:
        msa_seed = ensemble_seed

    transforms.append(
        data_transforms.sample_msa(
            max_msa_clusters,
            keep_extra=True,
            seed=msa_seed,
        )
    )

    if "masked_msa" in common_cfg:
        # Masked MSA must come *before* MSA clustering so that the cluster
        # profiles do not leak information about the masked positions.
        transforms.append(
            data_transforms.make_masked_msa(
                common_cfg.masked_msa, mode_cfg.masked_msa_replace_fraction
            )
        )

    if common_cfg.msa_cluster_features:
        transforms.append(data_transforms.nearest_neighbor_clusters())
        transforms.append(data_transforms.summarize_clusters())

    # Crop after creating the cluster profiles.
    if max_extra_msa:
        transforms.append(data_transforms.crop_extra_msa(max_extra_msa))
    else:
        transforms.append(data_transforms.delete_extra_msa)

    transforms.append(data_transforms.make_msa_feat())

    crop_feats = dict(common_cfg.feat)

    if mode_cfg.fixed_size:
        transforms.append(data_transforms.select_feat(list(crop_feats)))
        transforms.append(
            data_transforms.random_crop_to_size(
                mode_cfg.crop_size,
                mode_cfg.max_templates,
                crop_feats,
                mode_cfg.subsample_templates,
                seed=ensemble_seed + 1,
            )
        )
        transforms.append(
            data_transforms.make_fixed_size(
                crop_feats,
                pad_msa_clusters,
                common_cfg.max_extra_msa,
                mode_cfg.crop_size,
                mode_cfg.max_templates,
            )
        )
    else:
        transforms.append(
            data_transforms.crop_templates(mode_cfg.max_templates)
        )

    return transforms
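
# Note: this list is rebuilt for every recycling iteration (see
# wrap_ensemble_fn below), so stochastic transforms such as sample_msa and
# make_masked_msa can draw fresh randomness each time. When
# resample_msa_in_recycling is False, the shared ensemble_seed pins
# sample_msa so every iteration sees the same MSA sample.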


def process_tensors_from_config(tensors, common_cfg, mode_cfg):
    """Based on the config, apply filters and transformations to the data."""

    ensemble_seed = torch.Generator().seed()

    def wrap_ensemble_fn(data, i):
        """Function to be mapped over the ensemble dimension."""
        d = data.copy()
        fns = ensembled_transform_fns(
            common_cfg,
            mode_cfg,
            ensemble_seed,
        )
        fn = compose(fns)
        d["ensemble_index"] = i
        return fn(d)

    no_templates = True
    if "template_aatype" in tensors:
        no_templates = tensors["template_aatype"].shape[0] == 0

    nonensembled = nonensembled_transform_fns(
        common_cfg,
        mode_cfg,
    )

    tensors = compose(nonensembled)(tensors)

    if "no_recycling_iters" in tensors:
        num_recycling = int(tensors["no_recycling_iters"])
    else:
        num_recycling = common_cfg.max_recycling_iters

    tensors = map_fn(
        lambda x: wrap_ensemble_fn(tensors, x), torch.arange(num_recycling + 1)
    )

    return tensors
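
# Shape convention: map_fn stacks the per-iteration feature dicts along a
# new trailing dimension, so each returned tensor ends in a recycling axis
# of size num_recycling + 1. A hypothetical sketch:
#
#   out = process_tensors_from_config(tensors, cfg.common, cfg.train)
#   assert out["msa_feat"].shape[-1] == num_recycling + 1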


@data_transforms.curry1
def compose(x, fs):
    """Applies each transform in fs to x in order."""
    for f in fs:
        x = f(x)
    return x
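
# curry1 supplies every argument but the first, so compose(fns) returns a
# function of the feature dict that applies each fn left to right. A
# self-contained toy example (ints standing in for feature dicts):
#
#   f = compose([lambda v: v + 1, lambda v: v * 2])
#   f(3)  # -> 8, i.e. (3 + 1) * 2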


def map_fn(fun, x):
    """Applies fun to each element of x and stacks the resulting feature
    dicts along a new trailing dimension."""
    ensembles = [fun(elem) for elem in x]
    features = ensembles[0].keys()
    ensembled_dict = {}
    for feat in features:
        ensembled_dict[feat] = torch.stack(
            [dict_i[feat] for dict_i in ensembles], dim=-1
        )
    return ensembled_dict
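
# Toy example: two single-feature dicts become one dict whose tensor gains
# a trailing ensemble dimension indexed by i:
#
#   map_fn(lambda i: {"a": torch.ones(2) * i}, torch.arange(2))
#   # -> {"a": tensor([[0., 1.], [0., 1.]])}, shape [2, 2]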