"""Helpers for assembling atom37 coordinates into proteins and writing PDB files."""
import numpy as np
import os
import re
from data import protein
from openfold.utils import rigid_utils


Rigid = rigid_utils.Rigid

_PDB_SUFFIX = '.pdb'


def create_full_prot(
        atom37: np.ndarray,
        atom37_mask: np.ndarray,
        aatype=None,
        b_factors=None,
    ):
    """Wrap atom37 coordinates of a single structure in a ``protein.Protein``.

    Args:
        atom37: Coordinates of shape ``[N, 37, 3]`` (enforced by the
            assertions below).
        atom37_mask: Atom mask passed through unchanged as ``atom_mask``.
        aatype: Optional residue types; defaults to ``np.zeros(N, dtype=int)``.
        b_factors: Optional b-factors; defaults to ``np.zeros([N, 37])``.

    Returns:
        A ``protein.Protein`` with ``residue_index = arange(N)`` and an
        all-zero ``chain_index`` (i.e. every residue is put on one chain).

    Raises:
        AssertionError: If ``atom37`` is not of shape ``[N, 37, 3]``.
    """
    assert atom37.ndim == 3
    assert atom37.shape[-1] == 3
    assert atom37.shape[-2] == 37
    n = atom37.shape[0]
    residue_index = np.arange(n)
    chain_index = np.zeros(n)
    if b_factors is None:
        b_factors = np.zeros([n, 37])
    if aatype is None:
        aatype = np.zeros(n, dtype=int)
    return protein.Protein(
        atom_positions=atom37,
        atom_mask=atom37_mask,
        aatype=aatype,
        residue_index=residue_index,
        chain_index=chain_index,
        b_factors=b_factors)


def _strip_pdb_suffix(path: str) -> str:
    """Return ``path`` without a trailing ``.pdb`` (other occurrences are kept)."""
    if path.endswith(_PDB_SUFFIX):
        return path[:-len(_PDB_SUFFIX)]
    return path


def _next_pdb_index(file_path: str) -> int:
    """Return one more than the largest ``<stem>_<idx>.pdb`` index next to ``file_path``."""
    # dirname() is '' for a bare file name, and os.listdir('') raises.
    file_dir = os.path.dirname(file_path) or '.'
    stem = _strip_pdb_suffix(os.path.basename(file_path))
    # Anchored and escaped so that only files this function would itself
    # produce for the same stem are counted.
    pattern = re.compile(r'^' + re.escape(stem) + r'_(\d+)\.pdb$')
    matches = (pattern.match(name) for name in os.listdir(file_dir))
    indices = [int(m.group(1)) for m in matches if m]
    return max(indices + [0]) + 1


def write_prot_to_pdb(
        prot_pos: np.ndarray,
        file_path: str,
        aatype: np.ndarray=None,
        overwrite=False,
        no_indexing=False,
        b_factors=None,
    ):
    """Write one structure or a trajectory of structures to a PDB file.

    Atoms whose coordinates are all (numerically) zero are masked out.

    Args:
        prot_pos: Either ``[N, 37, 3]`` (single structure, written as model 1)
            or ``[T, N, 37, 3]`` (trajectory, frame ``t`` written as model
            ``t + 1``).
        file_path: Target path. Unless ``no_indexing`` is set, a trailing
            ``.pdb`` is removed and ``_<idx>.pdb`` is appended.
        aatype: Optional residue types forwarded to ``create_full_prot``.
        overwrite: If True the index is always 1; otherwise it is one more
            than the largest existing ``<stem>_<idx>.pdb`` in the directory.
        no_indexing: If True, write to ``file_path`` exactly as given.
        b_factors: Optional b-factors forwarded to ``create_full_prot`` (the
            same array is used for every frame of a trajectory).

    Returns:
        The path that was written.

    Raises:
        ValueError: If ``prot_pos`` is neither 3- nor 4-dimensional. This is
            checked before the file is opened so no empty file is left behind.
    """
    if prot_pos.ndim not in (3, 4):
        raise ValueError(f'Invalid positions shape {prot_pos.shape}')

    if no_indexing:
        # The index is unused here, so the directory is not scanned.
        save_path = file_path
    else:
        next_idx = 1 if overwrite else _next_pdb_index(file_path)
        save_path = _strip_pdb_suffix(file_path) + f'_{next_idx}.pdb'

    # A single structure is handled as a one-frame trajectory.
    frames = prot_pos if prot_pos.ndim == 4 else prot_pos[None]
    with open(save_path, 'w') as f:
        for t, pos37 in enumerate(frames):
            # An atom counts as present when any coordinate is non-zero.
            atom37_mask = np.sum(np.abs(pos37), axis=-1) > 1e-7
            prot = create_full_prot(
                pos37, atom37_mask, aatype=aatype, b_factors=b_factors)
            pdb_prot = protein.to_pdb(prot, model=t + 1, add_end=False)
            f.write(pdb_prot)
        f.write('END')
    return save_path