| | import warnings |
| |
|
| | import torch |
| | from rdkit import Chem |
| | from rdkit.Chem import Draw, AllChem |
| | from rdkit.Chem import SanitizeFlags |
| | from src.analysis.metrics import check_mol |
| | from src import utils |
| | from src.data.molecule_builder import build_molecule |
| | from src.data.misc import protein_letters_1to3 |
| |
|
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | def pocket_to_rdkit(pocket, pocket_representation, atom_encoder=None, |
| | atom_decoder=None, aa_decoder=None, residue_decoder=None, |
| | aa_atom_index=None): |
| |
|
| | rdpockets = [] |
| | for i in torch.unique(pocket['mask']): |
| |
|
| | node_coord = pocket['x'][pocket['mask'] == i] |
| | h = pocket['one_hot'][pocket['mask'] == i] |
| | atom_mask = pocket['atom_mask'][pocket['mask'] == i] |
| |
|
| | pdb_infos = [] |
| |
|
| | if pocket_representation == 'side_chain_bead': |
| | coord = node_coord |
| |
|
| | node_types = [residue_decoder[b] for b in h[:, -len(residue_decoder):].argmax(-1)] |
| | atom_types = ['C' if r == 'CA' else 'F' for r in node_types] |
| |
|
| | elif pocket_representation == 'CA+': |
| | aa_types = [aa_decoder[b] for b in h.argmax(-1)] |
| | side_chain_vec = pocket['v'][pocket['mask'] == i] |
| |
|
| | coord = [] |
| | atom_types = [] |
| | for resi, (xyz, aa, vec, am) in enumerate(zip(node_coord, aa_types, side_chain_vec, atom_mask)): |
| |
|
| | |
| | for atom_name, idx in aa_atom_index[aa].items(): |
| |
|
| | if ~am[idx]: |
| | warnings.warn(f"Missing atom {atom_name} in {aa}:{resi}") |
| | continue |
| |
|
| | coord.append(xyz + vec[idx]) |
| | atom_types.append(atom_name[0]) |
| |
|
| | info = Chem.AtomPDBResidueInfo() |
| | |
| | info.SetResidueName(protein_letters_1to3[aa]) |
| | info.SetResidueNumber(resi + 1) |
| | info.SetOccupancy(1.0) |
| | info.SetTempFactor(0.0) |
| | info.SetName(f' {atom_name:<3}') |
| | pdb_infos.append(info) |
| |
|
| | coord = torch.stack(coord, dim=0) |
| |
|
| | else: |
| | raise NotImplementedError(f"{pocket_representation} residue representation not supported") |
| |
|
| | atom_types = torch.tensor([atom_encoder[a] for a in atom_types]) |
| | rdmol = build_molecule(coord, atom_types, atom_decoder=atom_decoder) |
| |
|
| | if len(pdb_infos) == len(rdmol.GetAtoms()): |
| | for a, info in zip(rdmol.GetAtoms(), pdb_infos): |
| | a.SetPDBResidueInfo(info) |
| |
|
| | rdpockets.append(rdmol) |
| |
|
| | return rdpockets |
| |
|
| |
|
| | def mols_to_pdbfile(rdmols, filename, flavor=0): |
| | pdb_str = "" |
| | for i, mol in enumerate(rdmols): |
| | pdb_str += f"MODEL{i + 1:>9}\n" |
| | block = Chem.MolToPDBBlock(mol, flavor=flavor) |
| | block = "\n".join(block.split("\n")[:-2]) |
| | pdb_str += block + "\n" |
| | pdb_str += f"ENDMDL\n" |
| | pdb_str += f"END\n" |
| |
|
| | with open(filename, 'w') as f: |
| | f.write(pdb_str) |
| |
|
| | return pdb_str |
| |
|
| |
|
| | def mol_as_pdb(rdmol, filename=None, bfactor=None): |
| |
|
| | _rdmol = Chem.Mol(rdmol) |
| | for a in _rdmol.GetAtoms(): |
| | a.SetIsAromatic(False) |
| | for b in _rdmol.GetBonds(): |
| | b.SetIsAromatic(False) |
| |
|
| | if bfactor is not None: |
| | for a in _rdmol.GetAtoms(): |
| | val = a.GetPropsAsDict()[bfactor] |
| |
|
| | info = Chem.AtomPDBResidueInfo() |
| | info.SetResidueName('UNL') |
| | info.SetResidueNumber(1) |
| | info.SetName(f' {a.GetSymbol():<3}') |
| | info.SetIsHeteroAtom(True) |
| | info.SetOccupancy(1.0) |
| | info.SetTempFactor(val) |
| | a.SetPDBResidueInfo(info) |
| |
|
| | pdb_str = Chem.MolToPDBBlock(_rdmol) |
| |
|
| | if filename is not None: |
| | with open(filename, 'w') as f: |
| | f.write(pdb_str) |
| |
|
| | return pdb_str |
| |
|
| |
|
| | def draw_grid(molecules, mols_per_row=5, fig_size=(200, 200), |
| | label=check_mol, |
| | highlight_atom=lambda atom: False, |
| | highlight_bond=lambda bond: False): |
| |
|
| | draw_mols = [] |
| | marked_atoms = [] |
| | marked_bonds = [] |
| | for mol in molecules: |
| | draw_mol = Chem.Mol(mol) |
| | Chem.SanitizeMol(draw_mol, sanitizeOps=SanitizeFlags.SANITIZE_NONE) |
| | AllChem.Compute2DCoords(draw_mol) |
| | draw_mol = Draw.rdMolDraw2D.PrepareMolForDrawing(draw_mol, |
| | kekulize=False) |
| | draw_mols.append(draw_mol) |
| | marked_atoms.append([a.GetIdx() for a in draw_mol.GetAtoms() if highlight_atom(a)]) |
| | marked_bonds.append([b.GetIdx() for b in draw_mol.GetBonds() if highlight_bond(b)]) |
| |
|
| | drawOptions = Draw.rdMolDraw2D.MolDrawOptions() |
| | drawOptions.prepareMolsBeforeDrawing = False |
| | drawOptions.highlightBondWidthMultiplier = 20 |
| |
|
| | return Draw.MolsToGridImage(draw_mols, |
| | molsPerRow=mols_per_row, |
| | subImgSize=fig_size, |
| | drawOptions=drawOptions, |
| | highlightAtomLists=marked_atoms, |
| | highlightBondLists=marked_bonds, |
| | legends=[f'[{i}] {label(mol)}' for |
| | i, mol in enumerate(draw_mols)]) |
| |
|