P2DFlow / openfold /utils /import_weights.py
Holmes
test
ca7299e
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import Enum
from dataclasses import dataclass
from functools import partial
import numpy as np
import torch
from typing import Union, List
# Scope prepended by _process_translations_dict to every key derived from a
# top-level sub-dict of the translation table.
_NPZ_KEY_PREFIX = "alphafold/alphafold_iteration/"
try:
    # Python >= 3.11: explicitly force an attribute to become an enum member.
    from enum import member as _enum_member
except ImportError:
    # Older Pythons lack enum.member; there functools.partial is not a
    # descriptor, so Enum already turns partial-valued attributes into members.
    def _enum_member(value):
        return value


# With Param, a poor man's enum with attributes (Rust-style)
class ParamType(Enum):
    """How a checkpoint array is transformed before being copied into a tensor.

    Each member exposes a ``transformation`` attribute: a callable that takes
    the source tensor and returns the tensor to copy into the destination.

    Values are wrapped in ``functools.partial`` so the lambdas are not turned
    into methods. They are also wrapped in ``enum.member`` (when available)
    because newer Python versions make ``partial`` a method descriptor, and
    Enum does not turn descriptor attributes into members.
    """

    # (..., a, b) -> (..., b, a)
    LinearWeight = _enum_member(partial(lambda w: w.transpose(-1, -2)))
    # (..., a, b, c) -> (..., b*c, a)
    LinearWeightMHA = _enum_member(
        partial(lambda w: w.reshape(*w.shape[:-2], -1).transpose(-1, -2))
    )
    # (..., a, b, c) -> (..., c, a*b)
    LinearMHAOutputWeight = _enum_member(
        partial(
            lambda w: w.reshape(*w.shape[:-3], -1, w.shape[-1]).transpose(-1, -2)
        )
    )
    # (..., a, b) -> (..., a*b)
    LinearBiasMHA = _enum_member(partial(lambda w: w.reshape(*w.shape[:-2], -1)))
    # (..., a, b, c) -> (..., c, a*b)
    LinearWeightOPM = _enum_member(
        partial(
            lambda w: w.reshape(*w.shape[:-3], -1, w.shape[-1]).transpose(-1, -2)
        )
    )
    # Identity: the array is copied as-is.
    Other = _enum_member(partial(lambda w: w))

    def __init__(self, fn):
        # Enum passes each member's value (the partial) here; expose it
        # under a descriptive name used by ``assign``.
        self.transformation = fn
@dataclass
class Param:
    """A destination tensor (or tensors) plus how to fill it from a checkpoint.

    ``assign`` passes the source array through ``param_type.transformation``
    and copies the result into ``param``. When ``stacked`` is True, ``param``
    is a list of tensors, and the source array is split along its first axis,
    one slice per tensor.
    """

    # Destination tensor, or a list of tensors when ``stacked`` is True.
    param: Union[torch.Tensor, List[torch.Tensor]]
    # Transformation applied to the source array before copying.
    param_type: ParamType = ParamType.Other
    # Whether ``param`` is a list filled from slices of one source array.
    stacked: bool = False
def _process_translations_dict(d, top_layer=True):
    """Flatten a nested translation dict into checkpoint-style keys.

    Nested keys are joined with "/". A leaf found below the top level gets an
    extra leading "/", so it is separated from its parent path by "//"
    (e.g. ``"evoformer/preprocess_1d//weights"``). Every key produced from a
    top-level sub-dict is additionally prefixed with ``_NPZ_KEY_PREFIX``.
    Leaves stored directly at the top level keep their key unchanged.

    Args:
        d: Nested dict (any ``dict`` subclass counts as nesting) whose
            non-dict values are the leaves to keep.
        top_layer: True for the outermost call; recursive calls pass False.

    Returns:
        A flat dict mapping the joined keys to the leaf values.
    """
    flat = {}
    for key, value in d.items():
        if isinstance(value, dict):
            prefix = _NPZ_KEY_PREFIX if top_layer else ""
            sub_flat = _process_translations_dict(value, top_layer=False)
            for sub_key, leaf in sub_flat.items():
                flat[prefix + key + "/" + sub_key] = leaf
        else:
            flat[key if top_layer else "/" + key] = value
    return flat
def stacked(param_dict_list, out=None):
    """Merge structurally identical Param dicts into one dict of stacked Params.

    Args:
        param_dict_list:
            A list of (nested) Param dicts to stack. The structure of
            each dict must be identical (down to the ParamTypes of
            "parallel" Params). There must be at least one dict
            in the list.
        out:
            Dict to fill in place; a new dict is created when None.

    Returns:
        ``out``, keyed like ``param_dict_list[0]``. Nested dicts are stacked
        recursively. Each Param leaf becomes a single Param with
        ``stacked=True``, whose ``param`` is the list of the corresponding
        ``.param`` values (in list order) and whose ``param_type`` is taken
        from the first dict. Leaves that are neither dicts nor Params are
        skipped.
    """
    if out is None:
        out = {}
    template = param_dict_list[0]
    for key in template:
        values = [d[key] for d in param_dict_list]
        first = values[0]
        if isinstance(first, dict):
            out[key] = stacked(values, out={})
        elif isinstance(first, Param):
            out[key] = Param(
                param=[v.param for v in values],
                param_type=first.param_type,
                stacked=True,
            )
    return out
def assign(translation_dict, orig_weights):
    """Copy checkpoint arrays into the tensors referenced by ``Param``s.

    Args:
        translation_dict: Flat mapping from checkpoint key to ``Param``.
        orig_weights: Mapping from checkpoint key to array; every key of
            ``translation_dict`` is looked up in it.

    Each source array is converted with ``torch.as_tensor`` and passed through
    ``param.param_type.transformation``. The result is then copied in place
    (under ``torch.no_grad``) into ``param.param``. For a stacked Param the
    source is split along its first axis, and slice ``i`` goes to
    ``param.param[i]``.

    Raises:
        KeyError: propagated from ``orig_weights`` if a key is absent.
        ValueError: if a stacked source has a different number of slices
            than ``param.param`` has tensors (previously ``zip`` silently
            left the extra tensors unassigned).
        Any error from the transformation or ``copy_`` is re-raised after
        printing the key and the first destination/source shapes.
    """
    with torch.no_grad():
        for key, param in translation_dict.items():
            weights = torch.as_tensor(orig_weights[key])
            ref, param_type = param.param, param.param_type
            if param.stacked:
                weights = torch.unbind(weights, 0)
            else:
                weights = [weights]
                ref = [ref]
            try:
                if len(weights) != len(ref):
                    raise ValueError(
                        f"{key}: checkpoint provides {len(weights)} stacked "
                        f"slice(s) but the model expects {len(ref)}"
                    )
                weights = [param_type.transformation(w) for w in weights]
                for p, w in zip(ref, weights):
                    p.copy_(w)
            except Exception:
                # Diagnostic output identifying the failing entry.
                print(key)
                if len(ref) > 0:
                    print(ref[0].shape)
                if len(weights) > 0:
                    print(weights[0].shape)
                raise
def import_jax_weights_(model, npz_path, version="model_1"):
    """Copy AlphaFold JAX ``.npz`` parameters into ``model`` in place.

    Builds a nested translation table that mirrors the checkpoint's module
    hierarchy, with ``Param`` leaves pointing at tensors of ``model``. The
    table is flattened to checkpoint keys (``_process_translations_dict``)
    and each array is copied into its tensor (``assign``).

    Args:
        model: Model whose parameters are overwritten. It must expose every
            submodule referenced below (``input_embedder``, ``evoformer``,
            ``template_pair_stack``, ``structure_module``, ``aux_heads``, ...).
        npz_path: Path of the ``.npz`` parameter archive.
        version: Checkpoint name. For ``model_3``..``model_5`` (with or
            without ``_ptm``), entries whose key contains ``"template_"`` are
            dropped from the ``evoformer`` table. Any name containing
            ``"_ptm"`` also loads ``predicted_aligned_error_head``.

    Raises:
        ValueError: if a parameter required by the translation table is
            absent from the checkpoint. Checkpoint entries without a
            counterpart in the table are ignored.
    """
    #######################
    # Param templates
    #######################
    LinearWeight = lambda w: Param(w, param_type=ParamType.LinearWeight)
    LinearBias = lambda b: Param(b)
    LinearWeightMHA = lambda w: Param(w, param_type=ParamType.LinearWeightMHA)
    LinearBiasMHA = lambda b: Param(b, param_type=ParamType.LinearBiasMHA)
    LinearWeightOPM = lambda w: Param(w, param_type=ParamType.LinearWeightOPM)

    LinearParams = lambda lin: {
        "weights": LinearWeight(lin.weight),
        "bias": LinearBias(lin.bias),
    }

    LayerNormParams = lambda ln: {
        "scale": Param(ln.weight),
        "offset": Param(ln.bias),
    }

    AttentionParams = lambda att: {
        "query_w": LinearWeightMHA(att.linear_q.weight),
        "key_w": LinearWeightMHA(att.linear_k.weight),
        "value_w": LinearWeightMHA(att.linear_v.weight),
        "output_w": Param(
            att.linear_o.weight,
            param_type=ParamType.LinearMHAOutputWeight,
        ),
        "output_b": LinearBias(att.linear_o.bias),
    }

    AttentionGatedParams = lambda att: {
        **AttentionParams(att),
        "gating_w": LinearWeightMHA(att.linear_g.weight),
        "gating_b": LinearBiasMHA(att.linear_g.bias),
    }

    # Same layout as gated attention, except key/value weights use the
    # LinearWeight transformation instead of LinearWeightMHA.
    GlobalAttentionParams = lambda att: dict(
        AttentionGatedParams(att),
        key_w=LinearWeight(att.linear_k.weight),
        value_w=LinearWeight(att.linear_v.weight),
    )

    TriAttParams = lambda tri_att: {
        "query_norm": LayerNormParams(tri_att.layer_norm),
        "feat_2d_weights": LinearWeight(tri_att.linear.weight),
        "attention": AttentionGatedParams(tri_att.mha),
    }

    TriMulOutParams = lambda tri_mul: {
        "layer_norm_input": LayerNormParams(tri_mul.layer_norm_in),
        "left_projection": LinearParams(tri_mul.linear_a_p),
        "right_projection": LinearParams(tri_mul.linear_b_p),
        "left_gate": LinearParams(tri_mul.linear_a_g),
        "right_gate": LinearParams(tri_mul.linear_b_g),
        "center_layer_norm": LayerNormParams(tri_mul.layer_norm_out),
        "output_projection": LinearParams(tri_mul.linear_z),
        "gating_linear": LinearParams(tri_mul.linear_g),
    }

    # see commit b88f8da on the Alphafold repo
    # Alphafold swaps the pseudocode's a and b between the incoming/outgoing
    # iterations of triangle multiplication, which is confusing and not
    # reproduced in our implementation.
    TriMulInParams = lambda tri_mul: {
        "layer_norm_input": LayerNormParams(tri_mul.layer_norm_in),
        "left_projection": LinearParams(tri_mul.linear_b_p),
        "right_projection": LinearParams(tri_mul.linear_a_p),
        "left_gate": LinearParams(tri_mul.linear_b_g),
        "right_gate": LinearParams(tri_mul.linear_a_g),
        "center_layer_norm": LayerNormParams(tri_mul.layer_norm_out),
        "output_projection": LinearParams(tri_mul.linear_z),
        "gating_linear": LinearParams(tri_mul.linear_g),
    }

    PairTransitionParams = lambda pt: {
        "input_layer_norm": LayerNormParams(pt.layer_norm),
        "transition1": LinearParams(pt.linear_1),
        "transition2": LinearParams(pt.linear_2),
    }

    MSAAttParams = lambda matt: {
        "query_norm": LayerNormParams(matt.layer_norm_m),
        "attention": AttentionGatedParams(matt.mha),
    }

    MSAColAttParams = lambda matt: {
        "query_norm": LayerNormParams(matt._msa_att.layer_norm_m),
        "attention": AttentionGatedParams(matt._msa_att.mha),
    }

    MSAGlobalAttParams = lambda matt: {
        "query_norm": LayerNormParams(matt.layer_norm_m),
        "attention": GlobalAttentionParams(matt.global_attention),
    }

    MSAAttPairBiasParams = lambda matt: {
        **MSAAttParams(matt),
        "feat_2d_norm": LayerNormParams(matt.layer_norm_z),
        "feat_2d_weights": LinearWeight(matt.linear_z.weight),
    }

    IPAParams = lambda ipa: {
        "q_scalar": LinearParams(ipa.linear_q),
        "kv_scalar": LinearParams(ipa.linear_kv),
        "q_point_local": LinearParams(ipa.linear_q_points),
        "kv_point_local": LinearParams(ipa.linear_kv_points),
        "trainable_point_weights": Param(
            param=ipa.head_weights, param_type=ParamType.Other
        ),
        "attention_2d": LinearParams(ipa.linear_b),
        "output_projection": LinearParams(ipa.linear_out),
    }

    TemplatePairBlockParams = lambda b: {
        "triangle_attention_starting_node": TriAttParams(b.tri_att_start),
        "triangle_attention_ending_node": TriAttParams(b.tri_att_end),
        "triangle_multiplication_outgoing": TriMulOutParams(b.tri_mul_out),
        "triangle_multiplication_incoming": TriMulInParams(b.tri_mul_in),
        "pair_transition": PairTransitionParams(b.pair_transition),
    }

    MSATransitionParams = lambda m: {
        "input_layer_norm": LayerNormParams(m.layer_norm),
        "transition1": LinearParams(m.linear_1),
        "transition2": LinearParams(m.linear_2),
    }

    OuterProductMeanParams = lambda o: {
        "layer_norm_input": LayerNormParams(o.layer_norm),
        "left_projection": LinearParams(o.linear_1),
        "right_projection": LinearParams(o.linear_2),
        "output_w": LinearWeightOPM(o.linear_out.weight),
        "output_b": LinearBias(o.linear_out.bias),
    }

    def EvoformerBlockParams(b, is_extra_msa=False):
        # Extra-MSA blocks use global column attention under a different key.
        if is_extra_msa:
            col_att_name = "msa_column_global_attention"
            msa_col_att_params = MSAGlobalAttParams(b.msa_att_col)
        else:
            col_att_name = "msa_column_attention"
            msa_col_att_params = MSAColAttParams(b.msa_att_col)

        return {
            "msa_row_attention_with_pair_bias": MSAAttPairBiasParams(
                b.msa_att_row
            ),
            col_att_name: msa_col_att_params,
            "msa_transition": MSATransitionParams(b.core.msa_transition),
            "outer_product_mean": OuterProductMeanParams(
                b.core.outer_product_mean
            ),
            "triangle_multiplication_outgoing": TriMulOutParams(
                b.core.tri_mul_out
            ),
            "triangle_multiplication_incoming": TriMulInParams(
                b.core.tri_mul_in
            ),
            "triangle_attention_starting_node": TriAttParams(
                b.core.tri_att_start
            ),
            "triangle_attention_ending_node": TriAttParams(
                b.core.tri_att_end
            ),
            "pair_transition": PairTransitionParams(b.core.pair_transition),
        }

    ExtraMSABlockParams = partial(EvoformerBlockParams, is_extra_msa=True)

    FoldIterationParams = lambda sm: {
        "invariant_point_attention": IPAParams(sm.ipa),
        "attention_layer_norm": LayerNormParams(sm.layer_norm_ipa),
        "transition": LinearParams(sm.transition.layers[0].linear_1),
        "transition_1": LinearParams(sm.transition.layers[0].linear_2),
        "transition_2": LinearParams(sm.transition.layers[0].linear_3),
        "transition_layer_norm": LayerNormParams(sm.transition.layer_norm),
        "affine_update": LinearParams(sm.bb_update.linear),
        "rigid_sidechain": {
            "input_projection": LinearParams(sm.angle_resnet.linear_in),
            "input_projection_1": LinearParams(sm.angle_resnet.linear_initial),
            "resblock1": LinearParams(sm.angle_resnet.layers[0].linear_1),
            "resblock2": LinearParams(sm.angle_resnet.layers[0].linear_2),
            "resblock1_1": LinearParams(sm.angle_resnet.layers[1].linear_1),
            "resblock2_1": LinearParams(sm.angle_resnet.layers[1].linear_2),
            "unnormalized_angles": LinearParams(sm.angle_resnet.linear_out),
        },
    }

    ############################
    # Translation table
    ############################
    # Repeated blocks are merged into single stacked Params; their checkpoint
    # arrays are split along the first axis by ``assign``.
    tps_blocks = model.template_pair_stack.blocks
    tps_blocks_params = stacked(
        [TemplatePairBlockParams(b) for b in tps_blocks]
    )

    ems_blocks = model.extra_msa_stack.blocks
    ems_blocks_params = stacked([ExtraMSABlockParams(b) for b in ems_blocks])

    evo_blocks = model.evoformer.blocks
    evo_blocks_params = stacked([EvoformerBlockParams(b) for b in evo_blocks])

    translations = {
        "evoformer": {
            "preprocess_1d": LinearParams(model.input_embedder.linear_tf_m),
            "preprocess_msa": LinearParams(model.input_embedder.linear_msa_m),
            "left_single": LinearParams(model.input_embedder.linear_tf_z_i),
            "right_single": LinearParams(model.input_embedder.linear_tf_z_j),
            "prev_pos_linear": LinearParams(model.recycling_embedder.linear),
            "prev_msa_first_row_norm": LayerNormParams(
                model.recycling_embedder.layer_norm_m
            ),
            "prev_pair_norm": LayerNormParams(
                model.recycling_embedder.layer_norm_z
            ),
            # (sic) Must match the checkpoint key exactly; do not "fix" it.
            "pair_activiations": LinearParams(
                model.input_embedder.linear_relpos
            ),
            "template_embedding": {
                "single_template_embedding": {
                    "embedding2d": LinearParams(
                        model.template_pair_embedder.linear
                    ),
                    "template_pair_stack": {
                        "__layer_stack_no_state": tps_blocks_params,
                    },
                    "output_layer_norm": LayerNormParams(
                        model.template_pair_stack.layer_norm
                    ),
                },
                "attention": AttentionParams(model.template_pointwise_att.mha),
            },
            "extra_msa_activations": LinearParams(
                model.extra_msa_embedder.linear
            ),
            "extra_msa_stack": ems_blocks_params,
            "template_single_embedding": LinearParams(
                model.template_angle_embedder.linear_1
            ),
            "template_projection": LinearParams(
                model.template_angle_embedder.linear_2
            ),
            "evoformer_iteration": evo_blocks_params,
            "single_activations": LinearParams(model.evoformer.linear),
        },
        "structure_module": {
            "single_layer_norm": LayerNormParams(
                model.structure_module.layer_norm_s
            ),
            "initial_projection": LinearParams(
                model.structure_module.linear_in
            ),
            "pair_layer_norm": LayerNormParams(
                model.structure_module.layer_norm_z
            ),
            "fold_iteration": FoldIterationParams(model.structure_module),
        },
        "predicted_lddt_head": {
            "input_layer_norm": LayerNormParams(
                model.aux_heads.plddt.layer_norm
            ),
            "act_0": LinearParams(model.aux_heads.plddt.linear_1),
            "act_1": LinearParams(model.aux_heads.plddt.linear_2),
            "logits": LinearParams(model.aux_heads.plddt.linear_3),
        },
        "distogram_head": {
            "half_logits": LinearParams(model.aux_heads.distogram.linear),
        },
        "experimentally_resolved_head": {
            "logits": LinearParams(
                model.aux_heads.experimentally_resolved.linear
            ),
        },
        "masked_msa_head": {
            "logits": LinearParams(model.aux_heads.masked_msa.linear),
        },
    }

    # Checkpoints for which the template-related evoformer entries are skipped.
    no_templ = (
        "model_3",
        "model_4",
        "model_5",
        "model_3_ptm",
        "model_4_ptm",
        "model_5_ptm",
    )
    if version in no_templ:
        evo_dict = translations["evoformer"]
        template_keys = [k for k in evo_dict if "template_" in k]
        for k in template_keys:
            del evo_dict[k]

    if "_ptm" in version:
        translations["predicted_aligned_error_head"] = {
            "logits": LinearParams(model.aux_heads.tm.linear)
        }

    # Flatten keys and insert missing key prefixes
    flat = _process_translations_dict(translations)

    # NpzFile holds the archive open and reads arrays lazily; the context
    # manager closes it even if copying fails.
    with np.load(npz_path) as data:
        # Every parameter in the table must exist in the checkpoint; extra
        # checkpoint entries are tolerated.
        available = set(data.keys())
        absent = [k for k in flat if k not in available]
        if absent:
            raise ValueError(
                f"{len(absent)} parameter(s) expected for version "
                f"{version!r} are missing from {npz_path!r}: {absent}"
            )

        # Set weights
        assign(flat, data)