import os
import multiprocessing as mp
from typing import *

import numpy as np
import torch
from scipy.spatial import distance
from scipy.linalg import fractional_matrix_power
from sklearn.mixture import GaussianMixture
from Bio.PDB import PDBParser
from Bio.PDB.Polypeptide import PPBuilder

from src.common.geo_utils import rmsd, _find_rigid_alignment, squared_deviation

EPS = 1e-12
PSEUDO_C = 1e-6


def adjacent_ca_distance(coords):
    """Calculate distances between adjacent CA atoms of a single chain (k=1 neighbors only).

    Args:
        coords: (..., L, 3)
    Returns:
        dist: (..., L-1)
    """
    assert len(coords.shape) in (2, 3), f"CA coords should be 2D or 3D, got {coords.shape}"
    dX = coords[..., :-1, :] - coords[..., 1:, :]
    dist = np.sqrt(np.sum(dX**2, axis=-1))
    return dist


def distance_matrix_ca(coords):
    """Calculate the full distance matrix for a single chain of CA atoms (neighbors included).

    Args:
        coords: (..., L, 3)
    Returns:
        dist: (..., L, L)
    """
    assert len(coords.shape) in (2, 3), f"CA coords should be 2D or 3D, got {coords.shape}"
    dX = coords[..., None, :, :] - coords[..., :, None, :]
    dist = np.sqrt(np.sum(dX**2, axis=-1))
    return dist


def pairwise_distance_ca(coords, k=1):
    """Calculate the flattened pairwise-distance vector for a single chain of CA atoms.

    Pairs with sequence offset < k are excluded (k=1 keeps all off-diagonal pairs).

    Args:
        coords: (..., L, 3)
    Returns:
        dist: (..., D), where D = L * (L - 1) // 2 when k=1.
    """
    assert len(coords.shape) in (2, 3), f"CA coords should be 2D or 3D, got {coords.shape}"
    dist = distance_matrix_ca(coords)
    L = dist.shape[-1]
    row, col = np.triu_indices(L, k=k)
    triu = dist[..., row, col]
    return triu


def radius_of_gyration(coords, masses=None):
    """Compute the radius of gyration for every frame.

    Args:
        coords: (..., num_atoms, 3)
        masses: (num_atoms,). If None, equal masses are assumed.

    Returns:
        Rg: (...,)
    """
    assert len(coords.shape) in (2, 3), f"CA coords should be 2D or 3D, got {coords.shape}"

    if masses is None:
        masses = np.ones(coords.shape[-2])
    else:
        assert len(masses.shape) == 1, f"masses should be 1D, got {masses.shape}"
        assert masses.shape[0] == coords.shape[-2], f"masses {masses.shape} != number of particles {coords.shape[-2]}"

    weights = masses / masses.sum()
    # Center on the mass-weighted center of mass; with equal masses this is the mean.
    com = (coords * weights[:, None]).sum(-2, keepdims=True)
    centered = coords - com
    squared_dists = (centered ** 2).sum(-1)
    Rg = (squared_dists * weights).sum(-1) ** 0.5
    return Rg


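# Minimal usage sketch for the geometry helpers above (illustrative only; the
# _demo_* name is hypothetical and not referenced elsewhere). Assumes CA
# coordinates stored as a numpy array of shape (B, L, 3).
def _demo_geometry_helpers():
    rng = np.random.default_rng(0)
    coords = rng.normal(size=(4, 10, 3))      # 4 conformations, 10 residues
    bond = adjacent_ca_distance(coords)       # (4, 9)  consecutive CA-CA distances
    dmat = distance_matrix_ca(coords)         # (4, 10, 10)  full distance matrix
    pwd = pairwise_distance_ca(coords, k=1)   # (4, 45)  upper triangle, offset >= 1
    rg = radius_of_gyration(coords)           # (4,)  per-frame radius of gyration
    return bond, dmat, pwd, rg

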
def _steric_clash(coords, ca_vdw_radius=1.7, allowable_overlap=0.4, k_exclusion=0):
    """Count steric clashes in chains of CA atoms.

    Two CA atoms clash when their distance is below 2 * ca_vdw_radius - allowable_overlap.
    Reference: https://www.schrodinger.com/sites/default/files/s3/public/python_api/2022-3/_modules/schrodinger/structutils/interactions/steric_clash.html#clash_iterator

    Usage:
        n_clash = _steric_clash(coords)

    Args:
        coords: (B, n_atoms, 3) CA coordinates; each conformation should come from one protein chain.
        ca_vdw_radius: float, default 1.7.
        allowable_overlap: float, default 0.4.
        k_exclusion: int, default 0. Exclude pairs with sequence separation <= k_exclusion.

    Returns:
        n_clash: (B,) number of clashing pairs per conformation.
    """
    assert np.isnan(coords).sum() == 0, "coords should not contain nan"
    assert len(coords.shape) in (2, 3), f"CA coords should be 2D or 3D, got {coords.shape}"
    assert k_exclusion >= 0, "k_exclusion should be non-negative"
    # Clash cutoff; with the defaults this is 2 * 1.7 - 0.4 = 3.0 Angstrom.
    bar = 2 * ca_vdw_radius - allowable_overlap

    pwd = pairwise_distance_ca(coords, k=k_exclusion + 1)
    assert len(pwd.shape) == 2, f"pwd should be 2D, got {pwd.shape}"
    n_clash = np.sum(pwd < bar, axis=-1)
    return n_clash.astype(int)


def validity(ca_coords_dict, **clash_kwargs):
    """Calculate clash validity of ensembles.

    Args:
        ca_coords_dict: {k: (B, L, 3)}
    Returns:
        results: {k: validity in [0, 1]}, i.e. 1 - mean number of clashes per residue.
    """
    num_residue = float(ca_coords_dict['target'].shape[1])
    n_clash = {
        k: _steric_clash(v, **clash_kwargs)
        for k, v in ca_coords_dict.items()
    }

    results = {
        k: 1.0 - (v / num_residue).mean() for k, v in n_clash.items()
    }
    results = {k: np.around(v, decimals=4) for k, v in results.items()}
    return results


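# Minimal usage sketch for the clash-validity metric (illustrative only; the
# _demo_* name is hypothetical). Assumes the ensemble dict maps 'target'/'pred'
# to (B, L, 3) numpy arrays; values near 1.0 mean few clashing pairs per residue.
def _demo_validity():
    rng = np.random.default_rng(0)
    ensembles = {
        'target': rng.normal(scale=10.0, size=(8, 50, 3)),
        'pred': rng.normal(scale=10.0, size=(8, 50, 3)),
    }
    return validity(ensembles, ca_vdw_radius=1.7, allowable_overlap=0.4)

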
def bonding_validity(ca_coords_dict, ref_key='target', eps=1e-6):
    """Calculate bond-dissociation validity of ensembles.

    A CA-CA pseudo-bond is considered intact when its length does not exceed the
    largest adjacent CA-CA distance observed in the reference ensemble (plus eps).
    """
    adj_dist = {k: adjacent_ca_distance(v)
                for k, v in ca_coords_dict.items()
                }
    thres = adj_dist[ref_key].max() + eps

    results = {
        k: (v < thres).mean()
        for k, v in adj_dist.items()
    }
    results = {k: np.around(v, decimals=4) for k, v in results.items()}
    return results


def js_pwd(ca_coords_dict, ref_key='target', n_bins=50, pwd_offset=3, weights=None):
    """Jensen-Shannon divergence between per-pair distance histograms.

    Pairwise CA distances (sequence offset >= pwd_offset) are binned per pair over the
    range observed in the reference ensemble, and the JS divergence to the reference
    histogram is averaged over all pairs.

    Args:
        ca_coords_dict: {k: (B, L, 3)}
        weights: optional {k: (B,)} per-frame histogram weights.
    Returns:
        results: {k: mean JS divergence to ref_key}; ref_key maps to 0.0.
    """
    ca_pwd = {
        k: pairwise_distance_ca(v, k=pwd_offset) for k, v in ca_coords_dict.items()
    }

    if weights is None:
        weights = {}
    weights.update({k: np.ones(len(v)) for k, v in ca_coords_dict.items() if k not in weights})

    # Bin each pair over the reference range; the last two rows carry (min, max) per pair.
    d_min = ca_pwd[ref_key].min(axis=0)
    d_max = ca_pwd[ref_key].max(axis=0)
    ca_pwd_binned = {
        k: np.apply_along_axis(
            lambda a: np.histogram(a[:-2], bins=n_bins, weights=weights[k],
                                   range=(a[-2], a[-1]))[0] + PSEUDO_C,
            0, np.concatenate([v, d_min[None], d_max[None]], axis=0))
        for k, v in ca_pwd.items()
    }

    results = {k: distance.jensenshannon(v, ca_pwd_binned[ref_key], axis=0).mean()
               for k, v in ca_pwd_binned.items() if k != ref_key}
    results[ref_key] = 0.0
    results = {k: np.around(v, decimals=4) for k, v in results.items()}
    return results


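# Minimal usage sketch for the pairwise-distance JS divergence (illustrative only;
# the _demo_* name is hypothetical). Assumes both ensembles share the same chain
# length L; the reference key maps to 0.0 by construction.
def _demo_js_pwd():
    rng = np.random.default_rng(0)
    ensembles = {
        'target': rng.normal(scale=10.0, size=(100, 30, 3)),
        'pred': rng.normal(scale=10.0, size=(100, 30, 3)),
    }
    return js_pwd(ensembles, ref_key='target', n_bins=50, pwd_offset=3)

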
def js_rg(ca_coords_dict, ref_key='target', n_bins=50, weights=None):
    """Jensen-Shannon divergence between radius-of-gyration histograms."""
    ca_rg = {
        k: radius_of_gyration(v) for k, v in ca_coords_dict.items()
    }
    if weights is None:
        weights = {}
    weights.update({k: np.ones(len(v)) for k, v in ca_coords_dict.items() if k not in weights})

    d_min = ca_rg[ref_key].min()
    d_max = ca_rg[ref_key].max()
    ca_rg_binned = {
        k: np.histogram(v, bins=n_bins, weights=weights[k], range=(d_min, d_max))[0] + PSEUDO_C
        for k, v in ca_rg.items()
    }

    results = {k: distance.jensenshannon(v, ca_rg_binned[ref_key], axis=0).mean()
               for k, v in ca_rg_binned.items() if k != ref_key}
    results[ref_key] = 0.0
    results = {k: np.around(v, decimals=4) for k, v in results.items()}
    return results


def div_rmsd(ca_coords_dict):
    """Ensemble diversity from deviations between all pairs of frames (self-pairs included).

    The 'pred' entry is reported as a relative difference to the 'target' diversity.
    """
    results = {}
    for k, v in ca_coords_dict.items():
        v = torch.as_tensor(v)

        count = 0
        rmsd_2_sum = 0
        for coord1 in v:
            for coord2 in v:
                count += 1
                rmsd_2_sum += squared_deviation(coord1, coord2, reduction='none')

        results[k] = torch.sqrt(rmsd_2_sum / count)
        results[k] = np.around(float(torch.mean(results[k])), decimals=4)

    results['pred'] = (results['pred'] - results['target']) / results['target']
    results = {k: np.around(v, decimals=4) for k, v in results.items()}
    return results


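# Minimal usage sketch for the RMSD-diversity metric (illustrative only; the
# _demo_* name is hypothetical). Assumes (B, L, 3) arrays for both ensembles; the
# returned 'pred' value is its diversity relative to the target's.
def _demo_div_rmsd():
    rng = np.random.default_rng(0)
    ensembles = {
        'target': rng.normal(scale=5.0, size=(16, 40, 3)),
        'pred': rng.normal(scale=5.0, size=(16, 40, 3)),
    }
    return div_rmsd(ensembles)

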
def div_rmsf(ca_coords_dict):
    """Ensemble diversity from per-residue fluctuations around the mean structure.

    The 'pred' entry is reported as a relative difference to the 'target' diversity.
    """
    results = {}
    for k, v in ca_coords_dict.items():
        v = torch.as_tensor(v)

        count = 0
        rmsd_2_sum = 0
        mean_str = torch.mean(v, dim=0)
        for coord1 in v:
            count += 1
            rmsd_2_sum += squared_deviation(coord1, mean_str, reduction='none')

        results[k] = torch.sqrt(rmsd_2_sum / count)
        results[k] = np.around(float(torch.mean(results[k])), decimals=4)

    results['pred'] = (results['pred'] - results['target']) / results['target']
    results = {k: np.around(v, decimals=4) for k, v in results.items()}
    return results


def w2_rmwd(ca_coords_dict):
    """Residue-wise 2-Wasserstein distance between the 'target' and 'pred' ensembles.

    Each residue's coordinate distribution is approximated by a single Gaussian, and the
    squared closed-form 2-Wasserstein distance between the two per-residue Gaussians is
    averaged over residues.
    """
    result = {}
    means_total = {}
    covariances_total = {}
    for k, v in ca_coords_dict.items():
        v = torch.as_tensor(v)

        means_total[k] = []
        covariances_total[k] = []
        # Fit a single Gaussian to each residue's positions across the ensemble.
        for idx_residue in range(v.shape[1]):
            gmm = GaussianMixture(n_components=1)
            gmm.fit(v[:, idx_residue, :])
            means = torch.as_tensor(gmm.means_[0])
            covariances = torch.as_tensor(gmm.covariances_[0])

            means_total[k].append(means)
            covariances_total[k].append(covariances)
        means_total[k] = torch.stack(means_total[k], dim=0)
        covariances_total[k] = torch.stack(covariances_total[k], dim=0)

    # W2^2 between Gaussians: ||mu1 - mu2||^2 + Tr(S1 + S2 - 2 (S1 S2)^{1/2}).
    sigma_1_2_sqrt = [torch.as_tensor(fractional_matrix_power(i, 0.5))
                      for i in torch.matmul(covariances_total['target'], covariances_total['pred'])]
    sigma_1_2_sqrt = torch.stack(sigma_1_2_sqrt, dim=0)
    sigma_trace = covariances_total['target'] + covariances_total['pred'] - 2 * sigma_1_2_sqrt
    sigma_trace = [torch.trace(i) for i in sigma_trace]
    sigma_trace = torch.stack(sigma_trace, dim=0)

    result_1D = torch.sum((means_total['target'] - means_total['pred'])**2, dim=-1) + sigma_trace
    result['pred'] = np.around(float(torch.mean(result_1D)), decimals=4)

    return result


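# The per-residue term in w2_rmwd above is the squared 2-Wasserstein distance between
# two Gaussians, W2^2 = ||mu1 - mu2||^2 + Tr(S1 + S2 - 2 (S1 S2)^{1/2}); for positive
# semi-definite covariances the trace of (S1 S2)^{1/2} matches that of the symmetric
# form (S2^{1/2} S1 S2^{1/2})^{1/2}. A standalone numpy sketch of the same formula for
# one residue (illustrative only; the helper name is hypothetical):
def _gaussian_w2_squared(mu1, cov1, mu2, cov2):
    cross_sqrt = fractional_matrix_power(cov1 @ cov2, 0.5)
    trace_term = np.real(np.trace(cov1 + cov2 - 2 * cross_sqrt))
    return float(np.sum((mu1 - mu2) ** 2) + trace_term)

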
def pro_w_contacts(ca_coords_dict, cry_ca_coords, dist_threshold=8.0, percent_threshold=0.1):
    """Jaccard overlap of weak crystal contacts between the 'pred' and 'target' ensembles.

    A weak contact is a residue pair that is in contact in the crystal structure
    (distance < dist_threshold) but broken in more than percent_threshold of the
    ensemble frames.

    Args:
        ca_coords_dict: {k: (B, L, 3)}
        cry_ca_coords: (L, 3) crystal/reference CA coordinates.
    """
    result = {}
    w_contacts_total = {}

    # Contacts present in the crystal structure.
    dist = distance_matrix_ca(cry_ca_coords)
    L = dist.shape[-1]
    row, col = np.triu_indices(L, k=1)
    triu = dist[..., row, col]
    w_contacts_crystal = torch.as_tensor(triu < dist_threshold)

    for k, v in ca_coords_dict.items():
        dist = distance_matrix_ca(v)

        L = dist.shape[-1]
        row, col = np.triu_indices(L, k=1)
        triu = dist[..., row, col]

        # Pairs broken (distance above threshold) in more than percent_threshold of frames.
        w_contacts = (torch.tensor(triu) > dist_threshold).type(torch.float32)
        w_contacts = torch.mean(w_contacts, dim=0)
        w_contacts = w_contacts > percent_threshold

        w_contacts_total[k] = w_contacts & w_contacts_crystal

    jac_w_contacts = torch.sum(w_contacts_total['target'] & w_contacts_total['pred']) \
        / torch.sum(w_contacts_total['target'] | w_contacts_total['pred'])
    result['pred'] = np.around(float(jac_w_contacts), decimals=4)

    return result


def pro_t_contacts(ca_coords_dict, cry_ca_coords, dist_threshold=8.0, percent_threshold=0.1):
    """Jaccard overlap of transient contacts between the 'pred' and 'target' ensembles.

    A transient contact is a residue pair that is not in contact in the crystal
    structure (distance >= dist_threshold) but formed in more than percent_threshold
    of the ensemble frames.

    Args:
        ca_coords_dict: {k: (B, L, 3)}
        cry_ca_coords: (L, 3) crystal/reference CA coordinates.
    """
    result = {}
    t_contacts_total = {}

    # Pairs that are NOT in contact in the crystal structure.
    dist = distance_matrix_ca(cry_ca_coords)
    L = dist.shape[-1]
    row, col = np.triu_indices(L, k=1)
    triu = dist[..., row, col]
    t_contacts_crystal = torch.as_tensor(triu >= dist_threshold)

    for k, v in ca_coords_dict.items():
        dist = distance_matrix_ca(v)

        L = dist.shape[-1]
        row, col = np.triu_indices(L, k=1)
        triu = dist[..., row, col]

        # Pairs formed (distance below threshold) in more than percent_threshold of frames.
        t_contacts = (torch.tensor(triu) <= dist_threshold).type(torch.float32)
        t_contacts = torch.mean(t_contacts, dim=0)
        t_contacts = t_contacts > percent_threshold

        t_contacts_total[k] = t_contacts & t_contacts_crystal

    jac_t_contacts = torch.sum(t_contacts_total['target'] & t_contacts_total['pred']) \
        / torch.sum(t_contacts_total['target'] | t_contacts_total['pred'])
    result['pred'] = np.around(float(jac_t_contacts), decimals=4)

    return result


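# Minimal usage sketch for the contact-based metrics (illustrative only; the _demo_*
# name is hypothetical). pro_w_contacts scores crystal contacts that break in the
# ensembles, pro_t_contacts scores non-crystal contacts that form; both report a
# Jaccard overlap between the 'pred' and 'target' contact sets.
def _demo_contacts():
    rng = np.random.default_rng(0)
    crystal = rng.normal(scale=10.0, size=(30, 3))           # (L, 3) reference structure
    ensembles = {
        'target': rng.normal(scale=10.0, size=(20, 30, 3)),  # (B, L, 3) ensembles
        'pred': rng.normal(scale=10.0, size=(20, 30, 3)),
    }
    weak = pro_w_contacts(ensembles, crystal, dist_threshold=8.0, percent_threshold=0.1)
    transient = pro_t_contacts(ensembles, crystal, dist_threshold=8.0, percent_threshold=0.1)
    return weak, transient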