| | import io |
| | from itertools import accumulate, chain |
| | from copy import deepcopy |
| | import random |
| | import torch |
| | import torch.nn.functional as F |
| | import numpy as np |
| | from rdkit import Chem |
| | from torch_scatter import scatter_mean |
| | from Bio.PDB import StructureBuilder, Chain, Model, Structure |
| | from Bio.PDB.PICIO import read_PIC, write_PIC |
| | from scipy.ndimage import gaussian_filter |
| | from pdb import set_trace |
| |
|
| | from src.constants import FLOAT_TYPE, INT_TYPE |
| | from src.constants import atom_encoder, bond_encoder, aa_encoder, residue_encoder, residue_bond_encoder, aa_atom_index |
| | from src import utils |
| | from src.data.misc import protein_letters_3to1, is_aa |
| | from src.data.normal_modes import pdb_to_normal_modes |
| | from src.data.nerf import get_nerf_params, ic_to_coords |
| | import src.data.so3_utils as so3 |
| |
|
| |
|
class TensorDict(dict):
    """A dict whose tensor values can be moved between devices / detached in bulk."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def _apply(self, func: str, *args, **kwargs):
        """Call tensor method `func` on every tensor value, replacing it in place."""
        for key, value in self.items():
            if torch.is_tensor(value):
                self[key] = getattr(value, func)(*args, **kwargs)
        return self

    def to(self, device):
        return self._apply("to", device)

    def cpu(self):
        return self.to('cpu')

    def cuda(self):
        return self.to('cuda')

    def detach(self):
        return self._apply("detach")

    def __repr__(self):
        def describe(value):
            # Tensors are summarized by shape, lists by length, anything else is opaque.
            if isinstance(value, torch.Tensor):
                return "%r" % list(value.size())
            if isinstance(value, list):
                return "[%r,]" % len(value)
            else:
                return "?"

        body = ', '.join(f'{key}={describe(value)}' for key, value in self.items())
        return f"{type(self).__name__}({body})"
| |
|
| |
|
def collate_entity(batch):
    """
    Merge a list of per-item dicts into a single batched dict.

    Node features are concatenated, 'bonds' indices are offset per item,
    and 'mask' / 'bond_mask' are rebuilt from 'x' / 'bond_one_hot'.
    """
    out = {}
    for prop in batch[0].keys():

        if prop == 'name':
            # Names stay a plain Python list.
            out[prop] = [item[prop] for item in batch]

        elif prop in ('size', 'n_bonds'):
            out[prop] = torch.tensor([item[prop] for item in batch])

        elif prop == 'bonds':
            # Shift node indices by the cumulative number of nodes before each item.
            offset = list(accumulate([item['size'] for item in batch], initial=0))
            out[prop] = torch.cat(
                [item[prop] + offset[i] for i, item in enumerate(batch)], dim=1)

        elif prop == 'residues':
            out[prop] = list(chain.from_iterable(item[prop] for item in batch))

        elif prop in {'mask', 'bond_mask'}:
            # Rebuilt below from 'x' / 'bond_one_hot'.
            pass

        else:
            out[prop] = torch.cat([item[prop] for item in batch], dim=0)

        # Assignment masks mapping every node/edge to its item index.
        if prop == 'x':
            out['mask'] = torch.cat(
                [i * torch.ones(len(item[prop]), dtype=torch.int64, device=item[prop].device)
                 for i, item in enumerate(batch)], dim=0)
        if prop == 'bond_one_hot':
            out['bond_mask'] = torch.cat(
                [i * torch.ones(len(item[prop]), dtype=torch.int64, device=item[prop].device)
                 for i, item in enumerate(batch)], dim=0)

    return out
| |
|
| |
|
def split_entity(
        batch,
        *,
        index_types=frozenset({'bonds'}),
        edge_types=frozenset({'bond_one_hot', 'bond_mask'}),
        no_split=frozenset({'name', 'size', 'n_bonds'}),
        skip=frozenset({'fragments'}),
        batch_mask=None,
        edge_mask=None
):
    """
    Splits a batch into items and returns a list.

    Args:
        batch: batched dict as produced by collate_entity.
        index_types: keys holding node-index tensors that need the per-item
            node offset removed when split.
        edge_types: keys indexed per edge (split with the edge mask).
        no_split: keys passed through unchanged for every item.
        skip: keys dropped from the output.
        batch_mask: per-node item assignment; defaults to batch['mask'].
        edge_mask: per-edge item assignment; defaults to batch['bond_mask'].

    Returns:
        list of per-item dicts.
    """
    # Fix: use immutable frozenset defaults instead of mutable set literals
    # (shared-mutable-default pitfall); callers may pass any container
    # supporting `in`, so this is backward-compatible.
    batch_mask = batch["mask"] if batch_mask is None else batch_mask
    edge_mask = batch["bond_mask"] if edge_mask is None else edge_mask
    sizes = batch['size'] if 'size' in batch else torch.unique(batch_mask, return_counts=True)[1].tolist()

    # NOTE(review): batch size is derived from batch['mask'] even when a
    # custom batch_mask is supplied — kept as-is to preserve behavior.
    batch_size = len(torch.unique(batch['mask']))
    out = {}
    for prop in batch.keys():
        if prop in skip:
            continue
        if prop in no_split:
            out[prop] = batch[prop]

        elif prop in index_types:
            # Undo the cumulative node-index offsets added during collation.
            offsets = list(accumulate(sizes[:-1], initial=0))
            out[prop] = utils.batch_to_list_for_indices(batch[prop], edge_mask, offsets)

        elif prop in edge_types:
            out[prop] = utils.batch_to_list(batch[prop], edge_mask)

        else:
            out[prop] = utils.batch_to_list(batch[prop], batch_mask)

    out = [{k: v[i] for k, v in out.items()} for i in range(batch_size)]
    return out
| |
|
| |
|
def repeat_items(batch, repeats):
    """Tile a batch `repeats` times and re-collate, preserving the container type."""
    items = split_entity(batch)
    tiled = [item for _ in range(repeats) for item in items]
    return type(batch)(**collate_entity(tiled))
| |
|
| |
|
def get_side_chain_bead_coord(biopython_residue):
    """
    Place the side-chain bead at the side-chain heavy atom farthest from CA.

    Returns None for glycine (no side chain) and the CB position for alanine.
    """
    resname = biopython_residue.get_resname()
    if resname == 'GLY':
        return None
    if resname == 'ALA':
        return biopython_residue['CB'].get_coord()

    ca_coord = biopython_residue['CA'].get_coord()
    backbone = {'N', 'CA', 'C', 'O'}
    side_chain_atoms = [a for a in biopython_residue.get_atoms()
                        if a.id not in backbone and a.element != 'H']
    side_chain_coords = np.stack([a.get_coord() for a in side_chain_atoms])

    # Index of the heavy atom with the largest squared distance to CA.
    farthest = np.argmax(np.sum((side_chain_coords - ca_coord[None, :]) ** 2, axis=-1))
    return side_chain_coords[farthest, :]
| |
|
| |
|
def get_side_chain_vectors(res, index_dict, size=None):
    """
    Return a (size, 3) array of side-chain atom offsets relative to CA.

    Rows not covered by `index_dict` for this residue type remain zero.
    If `size` is omitted it is inferred as the largest index in `index_dict` + 1.
    """
    if size is None:
        size = max(idx for aa in index_dict.values() for idx in aa.values()) + 1

    one_letter = protein_letters_3to1[res.get_resname()]
    atom_index = index_dict[one_letter]
    ca_coord = res['CA'].get_coord()

    vectors = np.zeros((size, 3))
    for atom in res.get_atoms():
        name = atom.get_name()
        if name in atom_index:
            vectors[atom_index[name]] = atom.get_coord() - ca_coord
    return vectors
| |
|
| |
|
def get_normal_modes(res, normal_mode_dict):
    """Look up the CA normal-mode vectors for `res`, keyed by (chain id, residue number, 'CA')."""
    key = (res.get_parent().id, res.id[1], 'CA')
    return normal_mode_dict[key]
| |
|
| |
|
def get_torsion_angles(res, device=None):
    """
    Return the five side-chain chi angles in radians as a tensor.

    Angles the residue does not define are filled with NaN.
    """
    ic_res = res.internal_coord
    angles = []
    for chi in ('chi1', 'chi2', 'chi3', 'chi4', 'chi5'):
        value = ic_res.get_angle(chi)
        angles.append(float('nan') if value is None else value)

    # Internal coordinates report degrees; convert to radians.
    return torch.tensor(angles, device=device) * np.pi / 180
| |
|
| |
|
def apply_torsion_angles(res, chi_angles):
    """
    Write side-chain chi torsion angles (radians) into a biopython residue's
    internal coordinates and rebuild its Cartesian atom coordinates.

    Angles the residue does not define are silently skipped.
    """
    # Internal coordinates expect degrees.
    degrees = chi_angles * 180 / np.pi

    ic_res = res.internal_coord
    for chi, angle in zip(('chi1', 'chi2', 'chi3', 'chi4', 'chi5'), degrees):
        if ic_res.pick_angle(chi) is not None:
            ic_res.bond_set(chi, angle)

    # Propagate the updated internal coordinates back to atom positions.
    res.parent.internal_to_atom_coordinates(verbose=False)
    return res
| |
|
| |
|
def prepare_internal_coord(res):
    """
    Serialize a single residue's internal coordinates to a PIC-format string.

    The residue is re-rooted inside a minimal one-chain scaffold structure
    because biopython's internal-coordinate machinery operates on chains.
    """
    scaffold = Structure.Structure('X')
    scaffold.header = {}
    model = Model.Model(0)
    scaffold.add(model)
    chain_obj = Chain.Chain('X')
    model.add(chain_obj)
    chain_obj.add(res)
    res.set_parent(chain_obj)

    # Derive internal coordinates from the current atom positions.
    chain_obj.atom_to_internal_coordinates()

    buffer = io.StringIO()
    write_PIC(scaffold, buffer)
    return buffer.getvalue()
| |
|
| |
|
def residue_from_internal_coord(ic_string):
    """Rebuild a residue (with Cartesian coordinates) from a PIC-format string."""
    struct = read_PIC(io.StringIO(ic_string), quick=True)
    # structure -> model -> chain -> first (only) residue
    residue = struct.child_list[0].child_list[0].child_list[0]
    residue.parent.internal_to_atom_coordinates(verbose=False)
    return residue
| |
|
| |
|
def prepare_pocket(biopython_residues, amino_acid_encoder, residue_encoder,
                   residue_bond_encoder, pocket_representation='side_chain_bead',
                   compute_nerf_params=False, compute_bb_frames=False,
                   nma_input=None):
    """
    Convert a list of biopython residues into a tensorized pocket representation.

    Args:
        biopython_residues: pocket residues (any order; sorted here by chain id
            and residue number).
        amino_acid_encoder: maps one-letter amino-acid codes to type indices.
        residue_encoder: maps bead types ('CA', 'SS') to indices.
        residue_bond_encoder: maps bead-bond types ('CA-CA', 'CA-SS') to indices.
        pocket_representation: 'side_chain_bead' (one CA bead plus one
            side-chain bead per residue) or 'CA+' (CA positions only, with
            per-residue vector features).
        compute_nerf_params: also attach per-residue NeRF internal-coordinate
            parameters (via get_nerf_params).
        compute_bb_frames: also attach backbone frames ('axis_angle') computed
            from N/CA/C coordinates via get_bb_transform.
        nma_input: normal-mode information; either a precomputed dict or a
            value passed to pdb_to_normal_modes. Only supported for 'CA+'.

    Returns:
        (pocket dict of tensors, sorted list of biopython residues)
    """

    assert nma_input is None or pocket_representation == 'CA+', \
        "vector features are only supported for CA+ pockets"

    # Canonical ordering: by chain id, then residue sequence number.
    biopython_residues = sorted(biopython_residues, key=lambda x: (x.parent.id, x.id[1]))

    if nma_input is not None:
        # Normal modes may be supplied precomputed or derived on the fly.
        if isinstance(nma_input, dict):
            nma_dict = nma_input
        else:
            nma_dict = pdb_to_normal_modes(str(nma_input))

    if pocket_representation == 'side_chain_bead':
        ca_coords = np.zeros((len(biopython_residues), 3))
        ca_types = np.zeros(len(biopython_residues), dtype='int64')
        side_chain_coords = []
        side_chain_aa_types = []
        edges = []
        edge_types = []
        last_res_id = None
        for i, res in enumerate(biopython_residues):
            aa = amino_acid_encoder[protein_letters_3to1[res.get_resname()]]
            ca_coords[i, :] = res['CA'].get_coord()
            ca_types[i] = aa
            side_chain_coord = get_side_chain_bead_coord(res)
            if side_chain_coord is not None:
                side_chain_coords.append(side_chain_coord)
                side_chain_aa_types.append(aa)
                # Side-chain beads are appended after all CA beads, so their
                # node index is offset by the number of residues.
                edges.append((i, len(ca_coords) + len(side_chain_coords) - 1))
                edge_types.append(residue_bond_encoder['CA-SS'])

            # Connect consecutive residues (by sequence number) along the backbone.
            # NOTE(review): consecutive numbering across a chain break would
            # still be connected here — assumes numbering restarts per chain.
            if i > 0 and res.id[1] == last_res_id + 1:
                edges.append((i - 1, i))
                edge_types.append(residue_bond_encoder['CA-CA'])

            last_res_id = res.id[1]

        # NOTE(review): np.stack raises if every residue is GLY (no side chains).
        side_chain_coords = np.stack(side_chain_coords)
        pocket_coords = np.concatenate([ca_coords, side_chain_coords], axis=0)
        pocket_coords = torch.from_numpy(pocket_coords)

        # Node features: amino-acid type one-hot concatenated with bead-type one-hot.
        amino_acid_onehot = F.one_hot(
            torch.cat([torch.from_numpy(ca_types), torch.tensor(side_chain_aa_types, dtype=torch.int64)], dim=0),
            num_classes=len(amino_acid_encoder)
        )
        side_chain_onehot = np.concatenate([
            np.tile(np.eye(1, len(residue_encoder), residue_encoder['CA']),
                    [len(ca_coords), 1]),
            np.tile(np.eye(1, len(residue_encoder), residue_encoder['SS']),
                    [len(side_chain_coords), 1])
        ], axis=0)
        side_chain_onehot = torch.from_numpy(side_chain_onehot)
        pocket_onehot = torch.cat([amino_acid_onehot, side_chain_onehot], dim=1)

        # This representation has no vector-valued features.
        vector_features = None
        nma_features = None

        # Edge list as a (2, E) index tensor plus one-hot edge types.
        edges = torch.tensor(edges).T
        edge_types = F.one_hot(torch.tensor(edge_types), num_classes=len(residue_bond_encoder))

    elif pocket_representation == 'CA+':
        ca_coords = np.zeros((len(biopython_residues), 3))
        ca_types = np.zeros(len(biopython_residues), dtype='int64')

        # One 3-vector slot per indexed side-chain atom (see aa_atom_index).
        v_dim = max([x for aa in aa_atom_index.values() for x in aa.values()]) + 1
        vec_feats = np.zeros((len(biopython_residues), v_dim, 3), dtype='float32')
        nf_nma = 5  # number of normal modes kept per residue
        nma_feats = np.zeros((len(biopython_residues), nf_nma, 3), dtype='float32')

        edges = []
        edge_types = []
        last_res_id = None
        for i, res in enumerate(biopython_residues):
            aa = amino_acid_encoder[protein_letters_3to1[res.get_resname()]]
            ca_coords[i, :] = res['CA'].get_coord()
            ca_types[i] = aa

            vec_feats[i] = get_side_chain_vectors(res, aa_atom_index, v_dim)
            if nma_input is not None:
                nma_feats[i] = get_normal_modes(res, nma_dict)

            # Connect consecutive residues along the backbone.
            if i > 0 and res.id[1] == last_res_id + 1:
                edges.append((i - 1, i))
                edge_types.append(residue_bond_encoder['CA-CA'])

            last_res_id = res.id[1]

        pocket_coords = torch.from_numpy(ca_coords)

        # Node features: amino-acid type one-hot only.
        pocket_onehot = F.one_hot(torch.from_numpy(ca_types),
                                  num_classes=len(amino_acid_encoder))

        vector_features = torch.from_numpy(vec_feats)
        nma_features = torch.from_numpy(nma_feats)

        # Guard against pockets without any consecutive residue pair.
        if len(edges) < 1:
            edges = torch.empty(2, 0)
            edge_types = torch.empty(0, len(residue_bond_encoder))
        else:
            edges = torch.tensor(edges).T
            edge_types = F.one_hot(torch.tensor(edge_types),
                                   num_classes=len(residue_bond_encoder))

    else:
        raise NotImplementedError(
            f"Pocket representation '{pocket_representation}' not implemented")

    pocket = {
        'x': pocket_coords.to(dtype=FLOAT_TYPE),
        'one_hot': pocket_onehot.to(dtype=FLOAT_TYPE),
        'size': torch.tensor([len(pocket_coords)], dtype=INT_TYPE),
        'mask': torch.zeros(len(pocket_coords), dtype=INT_TYPE),
        'bonds': edges.to(INT_TYPE),
        'bond_one_hot': edge_types.to(FLOAT_TYPE),
        'bond_mask': torch.zeros(edges.size(1), dtype=INT_TYPE),
        'n_bonds': torch.tensor([len(edge_types)], dtype=INT_TYPE),
    }

    if vector_features is not None:
        pocket['v'] = vector_features.to(dtype=FLOAT_TYPE)

    if nma_input is not None:
        pocket['nma_vec'] = nma_features.to(dtype=FLOAT_TYPE)

    if compute_nerf_params:
        # Stack the per-residue NeRF parameter dicts into per-key tensors.
        nerf_params = [get_nerf_params(r) for r in biopython_residues]
        nerf_params = {k: torch.stack([x[k] for x in nerf_params], dim=0)
                       for k in nerf_params[0].keys()}
        pocket.update(nerf_params)

    if compute_bb_frames:
        n_xyz = torch.from_numpy(np.stack([r['N'].get_coord() for r in biopython_residues]))
        ca_xyz = torch.from_numpy(np.stack([r['CA'].get_coord() for r in biopython_residues]))
        c_xyz = torch.from_numpy(np.stack([r['C'].get_coord() for r in biopython_residues]))
        pocket['axis_angle'], _ = get_bb_transform(n_xyz, ca_xyz, c_xyz)

    return pocket, biopython_residues
| |
|
| |
|
def encode_atom(rd_atom, atom_encoder):
    """
    Map an RDKit atom to its type index.

    Specialized entries take precedence over the plain element when the
    encoder provides them: '<El>H' for exactly one explicit hydrogen,
    '<El>+' / '<El>-' for formal charge +1 / -1.
    """
    element = rd_atom.GetSymbol().capitalize()

    if rd_atom.GetNumExplicitHs() == 1:
        key = f'{element}H'
        if key in atom_encoder:
            return atom_encoder[key]

    charge = rd_atom.GetFormalCharge()
    if charge == 1:
        key = f'{element}+'
        if key in atom_encoder:
            return atom_encoder[key]
    if charge == -1:
        key = f'{element}-'
        if key in atom_encoder:
            return atom_encoder[key]

    return atom_encoder[element]
| |
|
| |
|
def prepare_ligand(rdmol, atom_encoder, bond_encoder):
    """
    Tensorize an RDKit molecule into a ligand dict with a fully-connected
    upper-triangular edge list whose edge types come from the bond table
    ('NOBOND' where no chemical bond exists).
    """
    # Strip hydrogens unless the encoder models them explicitly.
    if 'H' not in atom_encoder:
        rdmol = Chem.RemoveAllHs(rdmol, sanitize=False)

    coords = torch.from_numpy(rdmol.GetConformer().GetPositions())

    atom_types = torch.tensor([encode_atom(a, atom_encoder) for a in rdmol.GetAtoms()])
    one_hot = F.one_hot(atom_types, num_classes=len(atom_encoder))

    # Dense symmetric adjacency, initialized to the 'NOBOND' type everywhere.
    n_atoms = rdmol.GetNumAtoms()
    adj = np.ones((n_atoms, n_atoms)) * bond_encoder['NOBOND']
    for bond in rdmol.GetBonds():
        begin, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        bond_type = bond_encoder[str(bond.GetBondType())]
        adj[begin, end] = bond_type
        adj[end, begin] = bond_type

    # All unordered atom pairs (i < j) become edges.
    pairs = np.stack(np.triu_indices(len(coords), k=1), axis=0)
    pair_types = adj[pairs[0], pairs[1]].astype('int64')
    bonds = torch.from_numpy(pairs)
    bond_one_hot = F.one_hot(torch.from_numpy(pair_types), num_classes=len(bond_encoder))

    return {
        'x': coords.to(dtype=FLOAT_TYPE),
        'one_hot': one_hot.to(dtype=FLOAT_TYPE),
        'mask': torch.zeros(len(coords), dtype=INT_TYPE),
        'bonds': bonds.to(INT_TYPE),
        'bond_one_hot': bond_one_hot.to(FLOAT_TYPE),
        'bond_mask': torch.zeros(bonds.size(1), dtype=INT_TYPE),
        'size': torch.tensor([len(coords)], dtype=INT_TYPE),
        'n_bonds': torch.tensor([len(bond_one_hot)], dtype=INT_TYPE),
    }
| |
|
| |
|
def process_raw_molecule_with_empty_pocket(rdmol):
    """Prepare a ligand dict paired with an all-empty pocket placeholder."""
    ligand = prepare_ligand(rdmol, atom_encoder, bond_encoder)

    float_keys = {'x', 'one_hot', 'bond_one_hot'}
    pocket = {}
    # Same key order as a real pocket dict; every entry is an empty tensor.
    for key in ('x', 'one_hot', 'size', 'mask', 'bonds',
                'bond_one_hot', 'bond_mask', 'n_bonds'):
        dtype = FLOAT_TYPE if key in float_keys else INT_TYPE
        pocket[key] = torch.tensor([], dtype=dtype)

    return ligand, pocket
| |
|
| |
|
def process_raw_pair(biopython_model, rdmol, dist_cutoff=None,
                     pocket_representation='side_chain_bead',
                     compute_nerf_params=False, compute_bb_frames=False,
                     nma_input=None, return_pocket_pdb=False):
    """
    Build (ligand, pocket) tensor dicts from a biopython model and an RDKit mol.

    Args:
        biopython_model: protein model whose residues are scanned for the pocket.
        rdmol: ligand as an RDKit molecule with a 3D conformer.
        dist_cutoff: if given, keep only residues with at least one atom closer
            than this distance to any ligand atom; None keeps all residues.
        pocket_representation, compute_nerf_params, compute_bb_frames,
            nma_input: forwarded to prepare_pocket.
        return_pocket_pdb: additionally attach a biopython structure of the
            pocket residues under pocket['pocket_pdb'].

    Returns:
        (ligand dict, pocket dict)
    """

    ligand = prepare_ligand(rdmol, atom_encoder, bond_encoder)

    # Select pocket residues: standard amino acids near the ligand.
    pocket_residues = []
    for residue in biopython_model.get_residues():

        # Skip anything that is not a standard amino acid (water, hetero atoms, ...).
        if not is_aa(residue.get_resname(), standard=True):
            continue

        res_coords = torch.from_numpy(np.array([a.get_coord() for a in residue.get_atoms()]))
        # Minimum pairwise atom distance between residue and ligand.
        if dist_cutoff is None or (((res_coords[:, None, :] - ligand['x'][None, :, :]) ** 2).sum(-1) ** 0.5).min() < dist_cutoff:
            pocket_residues.append(residue)

    pocket, pocket_residues = prepare_pocket(
        pocket_residues, aa_encoder, residue_encoder, residue_bond_encoder,
        pocket_representation, compute_nerf_params, compute_bb_frames, nma_input
    )

    if return_pocket_pdb:
        # Assemble a minimal structure containing only the pocket residues.
        builder = StructureBuilder.StructureBuilder()
        builder.init_structure("")
        builder.init_model(0)
        pocket_struct = builder.get_structure()
        for residue in pocket_residues:
            chain = residue.get_parent().get_id()

            # Create each chain on first use.
            if not pocket_struct[0].has_id(chain):
                builder.init_chain(chain)

            pocket_struct[0][chain].add(residue)

        pocket['pocket_pdb'] = pocket_struct

    return ligand, pocket
| |
|
| |
|
class AppendVirtualNodes:
    """
    Transform that pads a ligand to a fixed node count with 'virtual' atoms.

    Virtual coordinates are sampled from a Gaussian fitted to the real atom
    positions (mean + scaled covariance); virtual atoms get type 'NOATOM' and
    all new edges get type 'NOBOND'.
    """

    def __init__(self, atom_encoder, bond_encoder, max_ligand_size, scale=1.0):
        self.atom_encoder = atom_encoder
        self.bond_encoder = bond_encoder
        self.max_size = max_ligand_size
        self.vidx = atom_encoder['NOATOM']
        self.bidx = bond_encoder['NOBOND']
        self.scale = scale

    def __call__(self, ligand, max_size=None, eps=1e-6):
        if max_size is None:
            max_size = self.max_size

        n_virt = max_size - ligand['size']

        # Fit N(mu, C) to the real coordinates; eps regularizes the Cholesky factor.
        cov = torch.cov(ligand['x'].T)
        chol = torch.linalg.cholesky(cov + torch.eye(3) * eps)
        mu = ligand['x'].mean(0, keepdim=True)
        virt_coords = mu + torch.randn(n_virt, 3) @ chol.T * self.scale

        virt_one_hot = F.one_hot(
            torch.ones(n_virt, dtype=torch.int64) * self.vidx,
            num_classes=len(self.atom_encoder))
        virt_mask = torch.cat([torch.zeros(ligand['size'], dtype=bool),
                               torch.ones(n_virt, dtype=bool)])

        ligand['x'] = torch.cat([ligand['x'], virt_coords])
        ligand['one_hot'] = torch.cat([ligand['one_hot'], virt_one_hot])
        ligand['virtual_mask'] = virt_mask
        ligand['size'] = max_size

        # Rebuild the fully-connected upper-triangular edge list over all nodes.
        new_bonds = torch.triu_indices(max_size, max_size, offset=1)

        # Scatter the existing bond types into a dense matrix, then re-gather.
        dense = torch.ones(max_size, max_size, dtype=INT_TYPE) * self.bidx
        row, col = ligand['bonds']
        dense[row, col] = ligand['bond_one_hot'].argmax(dim=1)
        gathered = dense[new_bonds[0], new_bonds[1]]

        ligand['bonds'] = new_bonds
        ligand['bond_one_hot'] = F.one_hot(
            gathered, num_classes=len(self.bond_encoder)).to(ligand['bond_one_hot'].dtype)
        ligand['n_bonds'] = len(ligand['bond_one_hot'])

        return ligand
| |
|
| |
|
class AppendVirtualNodesInCoM:
    """
    Transform that appends a random number of 'virtual' atoms, all placed at
    the ligand's center of mass, with type 'NOATOM' and 'NOBOND' edges.
    """

    def __init__(self, atom_encoder, bond_encoder, add_min=0, add_max=10):
        self.atom_encoder = atom_encoder
        self.bond_encoder = bond_encoder
        self.vidx = atom_encoder['NOATOM']
        self.bidx = bond_encoder['NOBOND']
        self.add_min = add_min
        self.add_max = add_max

    def __call__(self, ligand):
        # Number of virtual atoms is sampled uniformly in [add_min, add_max].
        n_virt = random.randint(self.add_min, self.add_max)

        # Every virtual atom sits at the center of mass.
        com = ligand['x'].mean(0, keepdim=True)
        virt_coords = com.repeat(n_virt, 1)

        virt_one_hot = F.one_hot(
            torch.ones(n_virt, dtype=torch.int64) * self.vidx,
            num_classes=len(self.atom_encoder))
        virt_mask = torch.cat([torch.zeros(ligand['size'], dtype=bool),
                               torch.ones(n_virt, dtype=bool)])

        ligand['x'] = torch.cat([ligand['x'], virt_coords])
        ligand['one_hot'] = torch.cat([ligand['one_hot'], virt_one_hot])
        ligand['virtual_mask'] = virt_mask
        ligand['size'] = len(ligand['x'])

        # Rebuild the fully-connected upper-triangular edge list over all nodes.
        new_bonds = torch.triu_indices(ligand['size'], ligand['size'], offset=1)

        # Scatter the existing bond types into a dense matrix, then re-gather.
        dense = torch.ones(ligand['size'], ligand['size'], dtype=INT_TYPE) * self.bidx
        row, col = ligand['bonds']
        dense[row, col] = ligand['bond_one_hot'].argmax(dim=1)
        gathered = dense[new_bonds[0], new_bonds[1]]

        ligand['bonds'] = new_bonds
        ligand['bond_one_hot'] = F.one_hot(
            gathered, num_classes=len(self.bond_encoder)).to(ligand['bond_one_hot'].dtype)
        ligand['n_bonds'] = len(ligand['bond_one_hot'])

        return ligand
| |
|
| |
|
def rdmol_to_smiles(rdmol):
    """Return the SMILES of a copy of `rdmol`, without stereochemistry or hydrogens."""
    stripped = Chem.Mol(rdmol)  # work on a copy, leave the input untouched
    Chem.RemoveStereochemistry(stripped)
    stripped = Chem.RemoveHs(stripped)
    return Chem.MolToSmiles(stripped)
| |
|
| |
|
| | def get_n_nodes(lig_positions, pocket_positions, smooth_sigma=None): |
| | |
| | n_nodes_lig = [len(x) for x in lig_positions] |
| | n_nodes_pocket = [len(x) for x in pocket_positions] |
| |
|
| | joint_histogram = np.zeros((np.max(n_nodes_lig) + 1, |
| | np.max(n_nodes_pocket) + 1)) |
| |
|
| | for nlig, npocket in zip(n_nodes_lig, n_nodes_pocket): |
| | joint_histogram[nlig, npocket] += 1 |
| |
|
| | print(f'Original histogram: {np.count_nonzero(joint_histogram)}/' |
| | f'{joint_histogram.shape[0] * joint_histogram.shape[1]} bins filled') |
| |
|
| | |
| | if smooth_sigma is not None: |
| | filtered_histogram = gaussian_filter( |
| | joint_histogram, sigma=smooth_sigma, order=0, mode='constant', |
| | cval=0.0, truncate=4.0) |
| |
|
| | print(f'Smoothed histogram: {np.count_nonzero(filtered_histogram)}/' |
| | f'{filtered_histogram.shape[0] * filtered_histogram.shape[1]} bins filled') |
| |
|
| | joint_histogram = filtered_histogram |
| |
|
| | return joint_histogram |
| |
|
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|
def get_type_histogram(one_hot, type_encoder):
    """
    Count decoded argmax types over a list of one-hot arrays.

    Returns a dict covering every encoder key, zero-filled for unseen types.
    Assumes encoder values are 0..n-1 in key insertion order.
    """
    type_indices = np.concatenate(one_hot, axis=0).argmax(1)
    return {name: int(np.count_nonzero(type_indices == i))
            for i, name in enumerate(type_encoder)}
| |
|
| |
|
def get_residue_with_resi(pdb_chain, resi):
    """Return the unique residue in `pdb_chain` whose sequence number is `resi`."""
    hits = [residue for residue in pdb_chain.get_residues() if residue.id[1] == resi]
    assert len(hits) == 1
    return hits[0]
| |
|
| |
|
def get_pocket_from_ligand(pdb_model, ligand, dist_cutoff=8.0):
    """
    Collect all standard amino-acid residues within `dist_cutoff` of the ligand.

    Args:
        pdb_model: biopython model to search.
        ligand: either a path to an SDF file, or a '<chain>:<resi>' reference
            to a ligand residue inside `pdb_model`.
        dist_cutoff: inclusion radius (same units as the coordinates).

    Returns:
        list of biopython residue objects forming the pocket.
    """
    if ligand.endswith(".sdf"):
        # Ligand provided as a separate SDF file.
        rdmol = Chem.SDMolSupplier(str(ligand))[0]
        ligand_coords = torch.from_numpy(rdmol.GetConformer().GetPositions()).float()
        resi = None
    else:
        # Ligand referenced as '<chain>:<residue number>' within the model.
        chain, resi = ligand.split(':')
        # FIX: convert to int so the exclusion check below can ever match;
        # previously `residue.id[1] == resi` compared int to str and was
        # always False, so the ligand residue was never excluded.
        resi = int(resi)
        ligand = get_residue_with_resi(pdb_model[chain], resi)
        ligand_coords = torch.from_numpy(
            np.array([a.get_coord() for a in ligand.get_atoms()]))

    pocket_residues = []
    for residue in pdb_model.get_residues():
        # Skip the ligand residue itself (resi is None for SDF ligands).
        # NOTE(review): this matches by residue number only, so a residue
        # with the same number in another chain is skipped too.
        if residue.id[1] == resi:
            continue

        res_coords = torch.from_numpy(
            np.array([a.get_coord() for a in residue.get_atoms()]))
        if is_aa(residue.get_resname(), standard=True) \
                and torch.cdist(res_coords, ligand_coords).min() < dist_cutoff:
            pocket_residues.append(residue)

    return pocket_residues
| |
|
| |
|
def encode_residues(biopython_residues, type_encoder, level='atom',
                    remove_H=True):
    """
    Encode residues as coordinates plus one-hot types.

    level='atom': one entry per (heavy) atom, typed by element.
    level='residue': one entry per residue at the CA position, typed by
    one-letter amino-acid code.
    """
    assert level in {'atom', 'residue'}

    if level == 'atom':
        entities, types = [], []
        for res in biopython_residues:
            for atom in res.get_atoms():
                if remove_H and atom.element == 'H':
                    continue
                entities.append(atom)
                types.append(atom.element.capitalize())
    else:
        entities = [res['CA'] for res in biopython_residues]
        types = [protein_letters_3to1[res.get_resname()] for res in biopython_residues]

    coord = torch.tensor(np.stack([e.get_coord() for e in entities]))
    one_hot = F.one_hot(torch.tensor([type_encoder[t] for t in types]),
                        num_classes=len(type_encoder))
    return coord, one_hot
| |
|
| |
|
def center_data(ligand, pocket):
    """
    Shift each sample so it is centered on its pocket center of mass.

    Falls back to the ligand's own per-sample center of mass when the pocket
    is empty. The pocket is centered in place via Residues.center().
    """
    has_pocket = pocket['x'].numel() > 0
    com = pocket.center() if has_pocket \
        else scatter_mean(ligand['x'], ligand['mask'], dim=0)

    ligand['x'] = ligand['x'] - com[ligand['mask']]
    return ligand, pocket
| |
|
| |
|
def get_bb_transform(n_xyz, ca_xyz, c_xyz):
    """
    Compute translation and rotation of the canonical backbone frame (triangle N-Ca-C) from a position with
    Ca at the origin, N on the x-axis and C in the xy-plane to the global position of the backbone frame.

    Args:
        n_xyz: (n, 3) nitrogen coordinates
        ca_xyz: (n, 3) C-alpha coordinates
        c_xyz: (n, 3) carbonyl-carbon coordinates

    Returns:
        axis-angle representation of the rotation, shape (n, 3)
        translation vector of shape (n, 3)
    """

    def rotation_matrix(angle, axis):
        # Batched rotation about a single coordinate axis, built from an
        # axis-angle vector of magnitude `angle` along `axis`.
        axis_mapping = {'x': 0, 'y': 1, 'z': 2}
        axis = axis_mapping[axis]
        vector = torch.zeros(len(angle), 3)
        vector[:, axis] = 1
        return so3.matrix_from_rotation_vector(angle.view(-1, 1) * vector)

    # Ca defines the translation; work in Ca-centered coordinates below.
    translation = ca_xyz
    n_xyz = n_xyz - translation
    c_xyz = c_xyz - translation

    # Rotate about y so N's z-component vanishes (N moves into the xy-plane).
    theta_y = torch.arctan2(n_xyz[:, 2], -n_xyz[:, 0])
    Ry = rotation_matrix(theta_y, 'y')
    Ry = Ry.transpose(2, 1)  # inverse of a rotation matrix is its transpose
    n_xyz = torch.einsum('noi,ni->no', Ry, n_xyz)

    # Rotate about z so N's y-component vanishes (N lands on the x-axis).
    theta_z = torch.arctan2(n_xyz[:, 1], n_xyz[:, 0])
    Rz = rotation_matrix(theta_z, 'z')
    Rz = Rz.transpose(2, 1)

    # Rotate about x so C falls into the xy-plane; N is on the x-axis, so
    # this rotation leaves N unchanged.
    c_xyz = torch.einsum('noj,nji,ni->no', Rz, Ry, c_xyz)
    theta_x = torch.arctan2(c_xyz[:, 2], c_xyz[:, 1])
    Rx = rotation_matrix(theta_x, 'x')
    Rx = Rx.transpose(2, 1)

    # Undo the inversions and compose: R maps the canonical frame to the
    # global orientation (inverse of the alignment applied above).
    Ry = Ry.transpose(2, 1)
    Rz = Rz.transpose(2, 1)
    Rx = Rx.transpose(2, 1)
    R = torch.einsum('nok,nkj,nji->noi', Ry, Rz, Rx)

    return so3.rotation_vector_from_matrix(R), translation
| |
|
| |
|
class Residues(TensorDict):
    """
    Dictionary-like container for residues that supports some basic transformations.
    """

    # All tensor-valued keys this container may hold.
    KEYS = {'x', 'one_hot', 'bonds', 'bond_one_hot', 'v', 'nma_vec', 'fixed_coord',
            'atom_mask', 'nerf_indices', 'length', 'theta', 'chi', 'ddihedral',
            'chi_indices', 'axis_angle', 'mask', 'bond_mask'}

    # Keys holding absolute Cartesian coordinates (affected by translation).
    COORD_KEYS = {'x', 'fixed_coord'}

    # Keys holding direction-like features (affected by rotation only).
    VECTOR_KEYS = {'v', 'nma_vec'}

    # Changed by both side-chain and backbone updates.
    MUTABLE_PROPS_SS_AND_BB = {'v'}

    # Changed by side-chain (torsion) updates only.
    MUTABLE_PROPS_SS = {'chi'}

    # Changed by backbone (frame) updates only.
    MUTABLE_PROPS_BB = {'x', 'fixed_coord', 'axis_angle', 'nma_vec'}

    # Never modified by the transformations defined below.
    IMMUTABLE_PROPS = {'mask', 'one_hot', 'bonds', 'bond_one_hot', 'bond_mask',
                       'atom_mask', 'nerf_indices', 'length', 'theta',
                       'ddihedral', 'chi_indices', 'name', 'size', 'n_bonds'}

    def copy(self):
        """Shallow copy; tensor values are shared with the original."""
        data = super().copy()
        return Residues(**data)

    def deepcopy(self):
        """Deep copy; tensors are cloned, everything else is deep-copied."""
        data = {k: v.clone() if torch.is_tensor(v) else deepcopy(v)
                for k, v in self.items()}
        return Residues(**data)

    def center(self):
        """Subtract the per-sample center of mass from all coordinates; returns the COM."""
        com = scatter_mean(self['x'], self['mask'], dim=0)
        self['x'] = self['x'] - com[self['mask']]
        self['fixed_coord'] = self['fixed_coord'] - com[self['mask']].unsqueeze(1)
        return com

    def set_empty_v(self):
        # Placeholder vector features on the same device as the coordinates.
        self['v'] = torch.tensor([], device=self['x'].device)

    @torch.no_grad()
    def set_chi(self, chi_angles):
        """Set the first five chi torsions and rebuild side-chain vectors via NeRF."""
        self['chi'][:, :5] = chi_angles
        nerf_params = {k: self[k] for k in ['fixed_coord', 'atom_mask',
                                            'nerf_indices', 'length', 'theta',
                                            'chi', 'ddihedral', 'chi_indices']}
        # 'v' stores atom offsets relative to the CA position 'x'.
        self['v'] = ic_to_coords(**nerf_params) - self['x'].unsqueeze(1)

    @torch.no_grad()
    def set_frame(self, new_ca_coord, new_axis_angle):
        """
        Move each residue's backbone frame to a new CA position/orientation,
        rotating backbone coordinates and vector features accordingly.
        """
        bb_coord = self['fixed_coord']
        bb_coord = bb_coord - self['x'].unsqueeze(1)  # CA-centered coordinates
        rotmat_before = so3.matrix_from_rotation_vector(self['axis_angle'])
        rotmat_after = so3.matrix_from_rotation_vector(new_axis_angle)
        # Relative rotation taking the old orientation onto the new one.
        rotmat_diff = rotmat_after @ rotmat_before.transpose(-1, -2)
        bb_coord = torch.einsum('boi,bai->bao', rotmat_diff, bb_coord)
        bb_coord = bb_coord + new_ca_coord.unsqueeze(1)

        self['x'] = new_ca_coord
        self['axis_angle'] = new_axis_angle
        self['fixed_coord'] = bb_coord
        self['v'] = torch.einsum('boi,bai->bao', rotmat_diff, self['v'])

    @staticmethod
    def empty(device):
        """Minimal single-node placeholder Residues object on `device`."""
        return Residues(
            x=torch.zeros(1, 3, device=device).float(),
            mask=torch.zeros(1, 1, device=device).long(),
            size=torch.zeros(1, device=device).long(),
        )
| |
|
| |
|
def randomize_tensors(tensor_dict, exclude_keys=None):
    """Replace tensors with random tensors with the same shape (and dtype)."""
    excluded = set() if exclude_keys is None else set(exclude_keys)
    for key, value in tensor_dict.items():
        if key in excluded or not isinstance(value, torch.Tensor):
            continue
        if torch.is_floating_point(value):
            tensor_dict[key] = torch.randn_like(value)
        else:
            tensor_dict[key] = torch.randint_like(value, low=-42, high=42)
    return tensor_dict
| |
|