| | from pathlib import Path |
| | from time import time |
| | import argparse |
| | import shutil |
| | import random |
| | import yaml |
| | from collections import defaultdict |
| |
|
| | import torch |
| | from tqdm import tqdm |
| | import numpy as np |
| | from Bio.PDB import PDBParser |
| | from rdkit import Chem |
| |
|
| | import sys |
| | basedir = Path(__file__).resolve().parent.parent.parent |
| | sys.path.append(str(basedir)) |
| |
|
| | from src.data.data_utils import process_raw_pair, get_n_nodes, get_type_histogram |
| | from src.data.data_utils import rdmol_to_smiles |
| | from src.constants import atom_encoder, bond_encoder |
| |
|
| |
|
if __name__ == '__main__':
    # Preprocess CrossDocked pocket/ligand pairs into tensor datasets
    # (train/val/test .pt files) plus training-set statistics used by the model.
    parser = argparse.ArgumentParser()
    parser.add_argument('basedir', type=Path)
    parser.add_argument('--outdir', type=Path, default=None)
    parser.add_argument('--split_path', type=Path, default=None)
    parser.add_argument('--pocket', type=str, default='CA+',
                        choices=['side_chain_bead', 'CA+'])
    parser.add_argument('--random_seed', type=int, default=42)
    parser.add_argument('--val_size', type=int, default=100)
    parser.add_argument('--normal_modes', action='store_true')
    parser.add_argument('--flex', action='store_true')
    parser.add_argument('--toy', action='store_true')
    args = parser.parse_args()

    # Seed so the val-split carve-out and toy subsampling are reproducible.
    random.seed(args.random_seed)

    datadir = args.basedir / 'crossdocked_pocket10/'

    # Output directory name encodes all processing options so that different
    # configurations never overwrite each other's results.
    dirname = f"processed_crossdocked_{args.pocket}"
    if args.flex:
        dirname += '_flex'
    if args.normal_modes:
        dirname += '_nma'
    if args.toy:
        dirname += '_toy'
    processed_dir = Path(args.basedir, dirname) if args.outdir is None else args.outdir
    # No exist_ok: deliberately fail rather than silently mix results with a
    # previous run's output.
    processed_dir.mkdir(parents=True)

    split_path = Path(args.basedir, 'split_by_name.pt') if args.split_path is None else args.split_path
    data_split = torch.load(split_path)

    # If the provided split has no validation set, carve one out of the
    # (shuffled) training set.
    if 'val' not in data_split:
        random.shuffle(data_split['train'])
        data_split['val'] = data_split['train'][-args.val_size:]
        data_split['train'] = data_split['train'][:-args.val_size]

    if args.toy:
        # Small random subset for quick debugging runs.
        data_split['train'] = random.sample(data_split['train'], 100)

    failed = {}        # (split, sdf path, pdb path) -> (exception name, message)
    train_smiles = []  # SMILES of successfully processed *training* ligands only

    # Per-sample tensor keys; NeRF-related keys are only present when --flex
    # is set (loop-invariant, so built once outside the loops).
    nerf_keys = ['fixed_coord', 'atom_mask', 'nerf_indices', 'length',
                 'theta', 'chi', 'ddihedral', 'chi_indices']
    tensor_keys = (['x', 'one_hot', 'bonds', 'bond_one_hot', 'v', 'nma_vec']
                   + nerf_keys + ['axis_angle'])

    for split in data_split.keys():

        print(f"Processing {split} dataset...")

        ligands = defaultdict(list)
        pockets = defaultdict(list)

        tic = time()
        pbar = tqdm(data_split[split])
        for pocket_fn, ligand_fn in pbar:

            pbar.set_description(f'#failed: {len(failed)}')

            sdffile = datadir / f'{ligand_fn}'
            pdbfile = datadir / f'{pocket_fn}'

            try:
                # First model of the PDB file; quiet parser suppresses warnings.
                pdb_model = PDBParser(QUIET=True).get_structure('', pdbfile)[0]

                # First (only) molecule in the SDF file; may be None on parse
                # failure, which surfaces as AttributeError below.
                rdmol = Chem.SDMolSupplier(str(sdffile))[0]

                ligand, pocket = process_raw_pair(
                    pdb_model, rdmol, pocket_representation=args.pocket,
                    compute_nerf_params=args.flex, compute_bb_frames=args.flex,
                    nma_input=pdbfile if args.normal_modes else None)

            except (KeyError, AssertionError, FileNotFoundError, IndexError,
                    ValueError, AttributeError) as e:
                # Record and skip; a summary is written to errors.txt at the end.
                failed[(split, sdffile, pdbfile)] = (type(e).__name__, str(e))
                continue

            for k in tensor_keys:
                if k in ligand:
                    ligands[k].append(ligand[k])
                if k in pocket:
                    pockets[k].append(pocket[k])

            # Replace '_' in the source names because '_' is used below as the
            # separator between the pocket stem and the ligand file name.
            pocket_file = pdbfile.name.replace('_', '-')
            ligand_file = Path(pocket_file).stem + '_' + Path(sdffile).name.replace('_', '-')
            ligands['name'].append(ligand_file)
            pockets['name'].append(pocket_file)

            # BUGFIX: collect SMILES for the training split only. Previously
            # every split appended here, so val/test molecules could leak into
            # train_smiles.npy whenever 'train' was not the first split key.
            if split == 'train':
                train_smiles.append(rdmol_to_smiles(rdmol))

            if split in {'val', 'test'}:
                # Keep raw PDB/SDF copies of evaluation structures next to the
                # tensors for downstream analysis/visualization.
                pdb_sdf_dir = processed_dir / split
                pdb_sdf_dir.mkdir(exist_ok=True)

                pdb_file_out = Path(pdb_sdf_dir, pocket_file)
                shutil.copy(pdbfile, pdb_file_out)

                sdf_file_out = Path(pdb_sdf_dir, ligand_file)
                shutil.copy(sdffile, sdf_file_out)

        data = {'ligands': ligands, 'pockets': pockets}
        torch.save(data, Path(processed_dir, f'{split}.pt'))

        if split == 'train':
            np.save(Path(processed_dir, 'train_smiles.npy'), train_smiles)

        print(f"Processing {split} set took {(time() - tic) / 60.0:.2f} minutes")

    # ------------------------------------------------------------------
    # Training-set statistics (size/type distributions) used as priors.
    # ------------------------------------------------------------------
    train_data = torch.load(Path(processed_dir, 'train.pt'))

    max_ligand_size = max([len(x) for x in train_data['ligands']['x']])

    # Joint (ligand size, pocket size) histogram for sampling ligand sizes.
    pocket_coords = train_data['pockets']['x']
    ligand_coords = train_data['ligands']['x']
    n_nodes = get_n_nodes(ligand_coords, pocket_coords)
    np.save(Path(processed_dir, 'size_distribution.npy'), n_nodes)

    # Atom-type frequencies over training ligands.
    lig_one_hot = [x.numpy() for x in train_data['ligands']['one_hot']]
    ligand_hist = get_type_histogram(lig_one_hot, atom_encoder)
    np.save(Path(processed_dir, 'ligand_type_histogram.npy'), ligand_hist)

    # Bond-type frequencies over training ligands.
    lig_bond_one_hot = [x.numpy() for x in train_data['ligands']['bond_one_hot']]
    ligand_bond_hist = get_type_histogram(lig_bond_one_hot, bond_encoder)
    np.save(Path(processed_dir, 'ligand_bond_type_histogram.npy'), ligand_bond_hist)

    # Human-readable log of all skipped pairs and why they failed.
    error_str = ""
    for k, v in failed.items():
        error_str += f"{'Split':<15}: {k[0]}\n"
        error_str += f"{'Ligand':<15}: {k[1]}\n"
        error_str += f"{'Pocket':<15}: {k[2]}\n"
        error_str += f"{'Error type':<15}: {v[0]}\n"
        error_str += f"{'Error msg':<15}: {v[1]}\n\n"

    with open(Path(processed_dir, 'errors.txt'), 'w') as f:
        f.write(error_str)

    metadata = {
        'max_ligand_size': max_ligand_size
    }
    with open(Path(processed_dir, 'metadata.yml'), 'w') as f:
        yaml.dump(metadata, f, default_flow_style=False)
|