| import numpy as np | |
| import os | |
| import re | |
| from data import protein | |
| from openfold.utils import rigid_utils | |
| Rigid = rigid_utils.Rigid | |
def create_full_prot(
        atom37: np.ndarray,
        atom37_mask: np.ndarray,
        aatype=None,
        b_factors=None,
):
    """Wrap atom37 coordinates in a single-chain ``protein.Protein``.

    Args:
        atom37: Atom positions; must have shape ``[num_res, 37, 3]``.
        atom37_mask: Per-atom presence mask, passed through unchanged
            (presumably shape ``[num_res, 37]`` -- not checked here).
        aatype: Optional per-residue types; defaults to an integer array of
            ``num_res`` zeros.
        b_factors: Optional per-atom B-factors; defaults to zeros of shape
            ``[num_res, 37]``.

    Returns:
        A ``protein.Protein`` whose residues are numbered ``0 .. num_res - 1``
        and all carry chain index 0.
    """
    assert atom37.ndim == 3
    assert atom37.shape[-1] == 3
    assert atom37.shape[-2] == 37
    n = atom37.shape[0]
    residue_index = np.arange(n)
    # Chain indices are discrete labels. Keep them integer (like the aatype
    # default below) rather than float64, so downstream code can use them as
    # indices, e.g. into a table of chain IDs.
    chain_index = np.zeros(n, dtype=int)
    if b_factors is None:
        b_factors = np.zeros([n, 37])
    if aatype is None:
        aatype = np.zeros(n, dtype=int)
    return protein.Protein(
        atom_positions=atom37,
        atom_mask=atom37_mask,
        aatype=aatype,
        residue_index=residue_index,
        chain_index=chain_index,
        b_factors=b_factors)
def _indexed_pdb_path(file_path, overwrite):
    """Return ``<stem>_<k>.pdb`` for the requested ``file_path``.

    ``<stem>`` is ``file_path`` with only a trailing ``.pdb`` removed;
    directory components are left untouched. ``k`` is 1 when ``overwrite`` is
    set. Otherwise it is one more than the largest index among files in the
    target directory named exactly ``<basename>_<int>.pdb``.
    """
    suffix = '.pdb'
    stem = file_path[:-len(suffix)] if file_path.endswith(suffix) else file_path
    max_existing_idx = 0
    if not overwrite:
        # dirname() returns '' for a bare file name, and listdir('') raises.
        file_dir = os.path.dirname(stem) or '.'
        pattern = re.compile(re.escape(os.path.basename(stem)) + r'_(\d+)\.pdb')
        existing = [
            int(match.group(1))
            for match in map(pattern.fullmatch, os.listdir(file_dir))
            if match
        ]
        max_existing_idx = max(existing, default=0)
    return f'{stem}_{max_existing_idx + 1}.pdb'


def write_prot_to_pdb(
        prot_pos: np.ndarray,
        file_path: str,
        aatype: np.ndarray = None,
        overwrite=False,
        no_indexing=False,
        b_factors=None,
):
    """Write atom37 coordinates to a PDB file and return the path written.

    Args:
        prot_pos: Either one structure (3-D array, expected
            ``[num_res, 37, 3]`` by ``create_full_prot``) or a stack of
            structures (4-D array). Each structure is rendered with
            ``protein.to_pdb(..., model=t + 1)``, where ``t`` is its
            0-based position in the stack.
        file_path: Requested output path. Unless ``no_indexing`` is set, an
            index suffix ``_<k>`` is inserted before the ``.pdb`` extension.
        aatype: Optional residue types forwarded to ``create_full_prot``.
            The same value is used for every structure in the stack.
        overwrite: If True, existing files are not scanned and the index
            suffix is always ``_1``.
        no_indexing: If True, write exactly to ``file_path``.
        b_factors: Optional B-factors forwarded to ``create_full_prot``.

    Returns:
        The path of the file that was written.

    Raises:
        ValueError: If ``prot_pos`` is neither 3- nor 4-dimensional. The
            check runs before any file is created.
    """
    # Validate before touching the filesystem, so that a bad call does not
    # leave an empty file behind.
    if prot_pos.ndim == 4:
        frames = prot_pos
    elif prot_pos.ndim == 3:
        frames = prot_pos[None]
    else:
        raise ValueError(f'Invalid positions shape {prot_pos.shape}')

    if no_indexing:
        save_path = file_path
    else:
        save_path = _indexed_pdb_path(file_path, overwrite)

    with open(save_path, 'w') as f:
        for t, pos37 in enumerate(frames):
            # An atom counts as present unless all three of its coordinates
            # are (near) zero.
            atom37_mask = np.sum(np.abs(pos37), axis=-1) > 1e-7
            prot = create_full_prot(
                pos37, atom37_mask, aatype=aatype, b_factors=b_factors)
            f.write(protein.to_pdb(prot, model=t + 1, add_end=False))
        f.write('END')
    return save_path