from dataclasses import dataclass
from enum import Enum
from functools import partial
from typing import List, Union

import numpy as np
import torch

_NPZ_KEY_PREFIX = "alphafold/alphafold_iteration/"
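
# The importer proceeds in three steps: build a nested "translation" dict
# mapping AlphaFold parameter names to Param wrappers around the model's
# tensors (see import_jax_weights_), flatten it into NPZ key paths
# (_process_translations_dict), then copy the NPZ arrays into the wrapped
# tensors in place (assign).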


class ParamType(Enum):
    # Each member carries a function that converts a weight from its JAX
    # (Haiku) layout to the corresponding PyTorch layout. partial() keeps
    # the lambdas from being bound as methods, so they survive as Enum
    # member values.
    LinearWeight = partial(
        lambda w: w.transpose(-1, -2)
    )
    LinearWeightMHA = partial(
        lambda w: w.reshape(*w.shape[:-2], -1).transpose(-1, -2)
    )
    LinearMHAOutputWeight = partial(
        lambda w: w.reshape(*w.shape[:-3], -1, w.shape[-1]).transpose(-1, -2)
    )
    LinearBiasMHA = partial(lambda w: w.reshape(*w.shape[:-2], -1))
    LinearWeightOPM = partial(
        lambda w: w.reshape(*w.shape[:-3], -1, w.shape[-1]).transpose(-1, -2)
    )
    Other = partial(lambda w: w)

    def __init__(self, fn):
        self.transformation = fn
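
# A minimal sketch of how a transformation is applied (the shape here is an
# illustrative assumption, not taken from the model):
#
#   w = torch.zeros(8, 4, 16)  # JAX layout: (..., n_heads, c_hidden)
#   w = ParamType.LinearWeightMHA.transformation(w)
#   assert w.shape == (64, 8)  # heads flattened, then transposed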


@dataclass
class Param:
    param: Union[torch.Tensor, List[torch.Tensor]]
    param_type: ParamType = ParamType.Other
    stacked: bool = False
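
# A stacked Param wraps one tensor per block of a repeated module, e.g.
# (a hypothetical illustration):
#
#   Param(param=[b.weight for b in blocks], stacked=True)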


def _process_translations_dict(d, top_layer=True):
    # Flatten a nested translation dict into NPZ-style "/"-separated keys,
    # prepending the AlphaFold scope prefix at the top level.
    flat = {}
    for k, v in d.items():
        if isinstance(v, dict):
            prefix = _NPZ_KEY_PREFIX if top_layer else ""
            sub_flat = {
                (prefix + "/".join([k, k_prime])): v_prime
                for k_prime, v_prime in _process_translations_dict(
                    v, top_layer=False
                ).items()
            }
            flat.update(sub_flat)
        else:
            k = "/" + k if not top_layer else k
            flat[k] = v

    return flat
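
# Flattened keys look like
# "alphafold/alphafold_iteration/evoformer/preprocess_1d//weights"; the
# double slash before the leaf name comes from the "/" prepended to
# non-top-level leaves above and matches the key layout of the official
# NPZ files.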


def stacked(param_dict_list, out=None):
    """
    Args:
        param_dict_list:
            A list of (nested) Param dicts to stack. The structure of
            each dict must be identical (down to the ParamTypes of
            "parallel" Params). There must be at least one dict
            in the list.
    """
    if out is None:
        out = {}
    template = param_dict_list[0]
    for k, _ in template.items():
        v = [d[k] for d in param_dict_list]
        if isinstance(v[0], dict):
            out[k] = {}
            stacked(v, out=out[k])
        elif isinstance(v[0], Param):
            stacked_param = Param(
                param=[param.param for param in v],
                param_type=v[0].param_type,
                stacked=True,
            )
            out[k] = stacked_param

    return out
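
# Example: for per-block dicts d0 and d1 with identical structure,
# stacked([d0, d1]) returns one dict of the same shape whose leaf Params
# wrap [d0_leaf, d1_leaf] lists, mirroring the stacked (num_blocks, ...)
# arrays in the NPZ file.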


def assign(translation_dict, orig_weights):
    for k, param in translation_dict.items():
        with torch.no_grad():
            weights = torch.as_tensor(orig_weights[k])
            ref, param_type = param.param, param.param_type
            if param.stacked:
                weights = torch.unbind(weights, 0)
            else:
                weights = [weights]
                ref = [ref]

            try:
                weights = list(map(param_type.transformation, weights))
                for p, w in zip(ref, weights):
                    p.copy_(w)
            except Exception:
                print(f"Failed to assign {k}")
                print(f"Model param shape: {ref[0].shape}")
                print(f"NPZ weight shape:  {weights[0].shape}")
                raise
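
# Note that copy_ writes into the model's existing tensors in place, so the
# caller's module is updated directly and nothing needs to be returned.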


def import_jax_weights_(model, npz_path, version="model_1"):
    # Load pretrained AlphaFold parameters from the official NPZ file at
    # npz_path and copy them into model in place. version names the
    # parameter set (e.g. "model_1" through "model_5", optionally with a
    # "_ptm" suffix).
    data = np.load(npz_path)
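
    # Note: np.load on an .npz archive returns a lazy NpzFile; the
    # individual arrays are only read when accessed by key in assign().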

    # Shorthands pairing a module's tensors with the ParamType
    # transformation needed to import the corresponding JAX weights.
    LinearWeight = lambda l: Param(l, param_type=ParamType.LinearWeight)
    LinearBias = lambda l: Param(l)
    LinearWeightMHA = lambda l: Param(l, param_type=ParamType.LinearWeightMHA)
    LinearBiasMHA = lambda b: Param(b, param_type=ParamType.LinearBiasMHA)
    LinearWeightOPM = lambda l: Param(l, param_type=ParamType.LinearWeightOPM)

    LinearParams = lambda l: {
        "weights": LinearWeight(l.weight),
        "bias": LinearBias(l.bias),
    }

    LayerNormParams = lambda l: {
        "scale": Param(l.weight),
        "offset": Param(l.bias),
    }
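
    # Haiku's LayerNorm names its parameters "scale" and "offset"; these
    # map directly onto PyTorch LayerNorm's weight and bias.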

    AttentionParams = lambda att: {
        "query_w": LinearWeightMHA(att.linear_q.weight),
        "key_w": LinearWeightMHA(att.linear_k.weight),
        "value_w": LinearWeightMHA(att.linear_v.weight),
        "output_w": Param(
            att.linear_o.weight,
            param_type=ParamType.LinearMHAOutputWeight,
        ),
        "output_b": LinearBias(att.linear_o.bias),
    }

    AttentionGatedParams = lambda att: dict(
        **AttentionParams(att),
        **{
            "gating_w": LinearWeightMHA(att.linear_g.weight),
            "gating_b": LinearBiasMHA(att.linear_g.bias),
        },
    )

    GlobalAttentionParams = lambda att: dict(
        AttentionGatedParams(att),
        key_w=LinearWeight(att.linear_k.weight),
        value_w=LinearWeight(att.linear_v.weight),
    )

    TriAttParams = lambda tri_att: {
        "query_norm": LayerNormParams(tri_att.layer_norm),
        "feat_2d_weights": LinearWeight(tri_att.linear.weight),
        "attention": AttentionGatedParams(tri_att.mha),
    }

    TriMulOutParams = lambda tri_mul: {
        "layer_norm_input": LayerNormParams(tri_mul.layer_norm_in),
        "left_projection": LinearParams(tri_mul.linear_a_p),
        "right_projection": LinearParams(tri_mul.linear_b_p),
        "left_gate": LinearParams(tri_mul.linear_a_g),
        "right_gate": LinearParams(tri_mul.linear_b_g),
        "center_layer_norm": LayerNormParams(tri_mul.layer_norm_out),
        "output_projection": LinearParams(tri_mul.linear_z),
        "gating_linear": LinearParams(tri_mul.linear_g),
    }

    TriMulInParams = lambda tri_mul: {
        "layer_norm_input": LayerNormParams(tri_mul.layer_norm_in),
        "left_projection": LinearParams(tri_mul.linear_b_p),
        "right_projection": LinearParams(tri_mul.linear_a_p),
        "left_gate": LinearParams(tri_mul.linear_b_g),
        "right_gate": LinearParams(tri_mul.linear_a_g),
        "center_layer_norm": LayerNormParams(tri_mul.layer_norm_out),
        "output_projection": LinearParams(tri_mul.linear_z),
        "gating_linear": LinearParams(tri_mul.linear_g),
    }
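
    # Note the a/b swap relative to TriMulOutParams: for the incoming
    # variant, AlphaFold's "left" projection/gate corresponds to this
    # model's linear_b_* modules and vice versa.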

    PairTransitionParams = lambda pt: {
        "input_layer_norm": LayerNormParams(pt.layer_norm),
        "transition1": LinearParams(pt.linear_1),
        "transition2": LinearParams(pt.linear_2),
    }

    MSAAttParams = lambda matt: {
        "query_norm": LayerNormParams(matt.layer_norm_m),
        "attention": AttentionGatedParams(matt.mha),
    }

    MSAColAttParams = lambda matt: {
        "query_norm": LayerNormParams(matt._msa_att.layer_norm_m),
        "attention": AttentionGatedParams(matt._msa_att.mha),
    }

    MSAGlobalAttParams = lambda matt: {
        "query_norm": LayerNormParams(matt.layer_norm_m),
        "attention": GlobalAttentionParams(matt.global_attention),
    }

    MSAAttPairBiasParams = lambda matt: dict(
        **MSAAttParams(matt),
        **{
            "feat_2d_norm": LayerNormParams(matt.layer_norm_z),
            "feat_2d_weights": LinearWeight(matt.linear_z.weight),
        },
    )

    IPAParams = lambda ipa: {
        "q_scalar": LinearParams(ipa.linear_q),
        "kv_scalar": LinearParams(ipa.linear_kv),
        "q_point_local": LinearParams(ipa.linear_q_points),
        "kv_point_local": LinearParams(ipa.linear_kv_points),
        "trainable_point_weights": Param(
            param=ipa.head_weights, param_type=ParamType.Other
        ),
        "attention_2d": LinearParams(ipa.linear_b),
        "output_projection": LinearParams(ipa.linear_out),
    }

    TemplatePairBlockParams = lambda b: {
        "triangle_attention_starting_node": TriAttParams(b.tri_att_start),
        "triangle_attention_ending_node": TriAttParams(b.tri_att_end),
        "triangle_multiplication_outgoing": TriMulOutParams(b.tri_mul_out),
        "triangle_multiplication_incoming": TriMulInParams(b.tri_mul_in),
        "pair_transition": PairTransitionParams(b.pair_transition),
    }

    MSATransitionParams = lambda m: {
        "input_layer_norm": LayerNormParams(m.layer_norm),
        "transition1": LinearParams(m.linear_1),
        "transition2": LinearParams(m.linear_2),
    }

    OuterProductMeanParams = lambda o: {
        "layer_norm_input": LayerNormParams(o.layer_norm),
        "left_projection": LinearParams(o.linear_1),
        "right_projection": LinearParams(o.linear_2),
        "output_w": LinearWeightOPM(o.linear_out.weight),
        "output_b": LinearBias(o.linear_out.bias),
    }

    def EvoformerBlockParams(b, is_extra_msa=False):
        if is_extra_msa:
            col_att_name = "msa_column_global_attention"
            msa_col_att_params = MSAGlobalAttParams(b.msa_att_col)
        else:
            col_att_name = "msa_column_attention"
            msa_col_att_params = MSAColAttParams(b.msa_att_col)

        d = {
            "msa_row_attention_with_pair_bias": MSAAttPairBiasParams(
                b.msa_att_row
            ),
            col_att_name: msa_col_att_params,
            "msa_transition": MSATransitionParams(b.core.msa_transition),
            "outer_product_mean": OuterProductMeanParams(
                b.core.outer_product_mean
            ),
            "triangle_multiplication_outgoing": TriMulOutParams(
                b.core.tri_mul_out
            ),
            "triangle_multiplication_incoming": TriMulInParams(
                b.core.tri_mul_in
            ),
            "triangle_attention_starting_node": TriAttParams(
                b.core.tri_att_start
            ),
            "triangle_attention_ending_node": TriAttParams(
                b.core.tri_att_end
            ),
            "pair_transition": PairTransitionParams(
                b.core.pair_transition
            ),
        }

        return d

    ExtraMSABlockParams = partial(EvoformerBlockParams, is_extra_msa=True)

    FoldIterationParams = lambda sm: {
        "invariant_point_attention": IPAParams(sm.ipa),
        "attention_layer_norm": LayerNormParams(sm.layer_norm_ipa),
        "transition": LinearParams(sm.transition.layers[0].linear_1),
        "transition_1": LinearParams(sm.transition.layers[0].linear_2),
        "transition_2": LinearParams(sm.transition.layers[0].linear_3),
        "transition_layer_norm": LayerNormParams(sm.transition.layer_norm),
        "affine_update": LinearParams(sm.bb_update.linear),
        "rigid_sidechain": {
            "input_projection": LinearParams(sm.angle_resnet.linear_in),
            "input_projection_1": LinearParams(
                sm.angle_resnet.linear_initial
            ),
            "resblock1": LinearParams(sm.angle_resnet.layers[0].linear_1),
            "resblock2": LinearParams(sm.angle_resnet.layers[0].linear_2),
            "resblock1_1": LinearParams(sm.angle_resnet.layers[1].linear_1),
            "resblock2_1": LinearParams(sm.angle_resnet.layers[1].linear_2),
            "unnormalized_angles": LinearParams(sm.angle_resnet.linear_out),
        },
    }
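
    # AlphaFold shares one set of FoldIteration weights across all structure
    # module iterations, so a single mapping from model.structure_module
    # suffices here (no stacking needed).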

    tps_blocks = model.template_pair_stack.blocks
    tps_blocks_params = stacked(
        [TemplatePairBlockParams(b) for b in tps_blocks]
    )

    ems_blocks = model.extra_msa_stack.blocks
    ems_blocks_params = stacked([ExtraMSABlockParams(b) for b in ems_blocks])

    evo_blocks = model.evoformer.blocks
    evo_blocks_params = stacked([EvoformerBlockParams(b) for b in evo_blocks])

    translations = {
        "evoformer": {
            "preprocess_1d": LinearParams(model.input_embedder.linear_tf_m),
            "preprocess_msa": LinearParams(model.input_embedder.linear_msa_m),
            "left_single": LinearParams(model.input_embedder.linear_tf_z_i),
            "right_single": LinearParams(model.input_embedder.linear_tf_z_j),
            "prev_pos_linear": LinearParams(model.recycling_embedder.linear),
            "prev_msa_first_row_norm": LayerNormParams(
                model.recycling_embedder.layer_norm_m
            ),
            "prev_pair_norm": LayerNormParams(
                model.recycling_embedder.layer_norm_z
            ),
            # sic: "activiations" matches the (misspelled) parameter name
            # in the official NPZ files
            "pair_activiations": LinearParams(
                model.input_embedder.linear_relpos
            ),
            "template_embedding": {
                "single_template_embedding": {
                    "embedding2d": LinearParams(
                        model.template_pair_embedder.linear
                    ),
                    "template_pair_stack": {
                        "__layer_stack_no_state": tps_blocks_params,
                    },
                    "output_layer_norm": LayerNormParams(
                        model.template_pair_stack.layer_norm
                    ),
                },
                "attention": AttentionParams(
                    model.template_pointwise_att.mha
                ),
            },
            "extra_msa_activations": LinearParams(
                model.extra_msa_embedder.linear
            ),
            "extra_msa_stack": ems_blocks_params,
            "template_single_embedding": LinearParams(
                model.template_angle_embedder.linear_1
            ),
            "template_projection": LinearParams(
                model.template_angle_embedder.linear_2
            ),
            "evoformer_iteration": evo_blocks_params,
            "single_activations": LinearParams(model.evoformer.linear),
        },
        "structure_module": {
            "single_layer_norm": LayerNormParams(
                model.structure_module.layer_norm_s
            ),
            "initial_projection": LinearParams(
                model.structure_module.linear_in
            ),
            "pair_layer_norm": LayerNormParams(
                model.structure_module.layer_norm_z
            ),
            "fold_iteration": FoldIterationParams(model.structure_module),
        },
        "predicted_lddt_head": {
            "input_layer_norm": LayerNormParams(
                model.aux_heads.plddt.layer_norm
            ),
            "act_0": LinearParams(model.aux_heads.plddt.linear_1),
            "act_1": LinearParams(model.aux_heads.plddt.linear_2),
            "logits": LinearParams(model.aux_heads.plddt.linear_3),
        },
        "distogram_head": {
            "half_logits": LinearParams(model.aux_heads.distogram.linear),
        },
        "experimentally_resolved_head": {
            "logits": LinearParams(
                model.aux_heads.experimentally_resolved.linear
            ),
        },
        "masked_msa_head": {
            "logits": LinearParams(model.aux_heads.masked_msa.linear),
        },
    }

    # Template-free parameter sets have no template-related weights; drop
    # those entries from the translation so the key check below still holds.
    no_templ = [
        "model_3",
        "model_4",
        "model_5",
        "model_3_ptm",
        "model_4_ptm",
        "model_5_ptm",
    ]
    if version in no_templ:
        evo_dict = translations["evoformer"]
        keys = list(evo_dict.keys())
        for k in keys:
            if "template_" in k:
                evo_dict.pop(k)

    # The pTM parameter sets additionally ship a predicted aligned error
    # head.
    if "_ptm" in version:
        translations["predicted_aligned_error_head"] = {
            "logits": LinearParams(model.aux_heads.tm.linear)
        }

    # Flatten the nested translation dict into NPZ-style key paths.
    flat = _process_translations_dict(translations)

    # Sanity check: every translated key must exist in the NPZ file. Keys
    # present in the file but absent from the translation are collected for
    # debugging but tolerated.
    keys = list(data.keys())
    flat_keys = list(flat.keys())
    incorrect = [k for k in flat_keys if k not in keys]
    missing = [k for k in keys if k not in flat_keys]

    assert len(incorrect) == 0, f"Unmapped keys: {incorrect}"

    assign(flat, data)
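

# Example usage (the model construction and checkpoint path are illustrative
# assumptions, not part of this module):
#
#   model = AlphaFold(config)
#   import_jax_weights_(
#       model, "params/params_model_1.npz", version="model_1"
#   )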