| import argparse |
| import sys |
| import os |
| import warnings |
| import tempfile |
| import pandas as pd |
|
|
| from Bio.PDB import PDBParser |
| from pathlib import Path |
| from rdkit import Chem |
| from torch.utils.data import DataLoader |
| from functools import partial |
|
|
| basedir = Path(__file__).resolve().parent.parent |
| sys.path.append(str(basedir)) |
| warnings.filterwarnings("ignore") |
|
|
| from src import utils |
| from src.data.dataset import ProcessedLigandPocketDataset |
| from src.data.data_utils import TensorDict, process_raw_pair |
| from src.model.lightning import DrugFlow |
| from src.sbdd_metrics.metrics import FullEvaluator |
|
|
| from tqdm import tqdm |
| from pdb import set_trace |
|
|
|
|
def _mean_of_prefixed_columns(table, prefix, skip_suffix=None):
    """Store in ``table[prefix]`` the row-wise mean of its component columns.

    A component column is any column whose name starts with ``prefix``, is not
    ``prefix`` itself and (when ``skip_suffix`` is given) does not end with
    ``skip_suffix``. Missing values are counted as 0. The table is modified in
    place; an already existing ``prefix`` column is overwritten.
    """
    # Collect the components first so the aggregate column never feeds itself.
    components = [
        column for column in table.columns
        if column.startswith(prefix)
        and column != prefix
        and not (skip_suffix is not None and column.endswith(skip_suffix))
    ]

    table[prefix] = 0.0
    for column in components:
        table[prefix] += table[column].fillna(0).astype(float)

    if components:
        table[prefix] = table[prefix] / len(components)
    else:
        # Nothing to average: mark the aggregate as undefined explicitly
        # instead of relying on a 0 / 0 division.
        table[prefix] = float('nan')


def aggregate_metrics(table):
    """Add aggregate score columns for three groups of metrics.

    For each of the prefixes ``'posebusters'``, ``'reos'`` and
    ``'chembl_ring_systems'`` a column named exactly like the prefix is added
    (or overwritten). It holds, per row, the mean of all columns whose name
    starts with that prefix, with missing values treated as 0. For
    ``'chembl_ring_systems'`` columns whose name ends with ``'smi'`` are left
    out of the mean. If a group has no component columns its aggregate is NaN.

    Args:
        table: ``pandas.DataFrame`` with string column names. It is modified
            in place.

    Returns:
        The same ``table`` object, with the three aggregate columns set.
    """
    _mean_of_prefixed_columns(table, 'posebusters')
    _mean_of_prefixed_columns(table, 'reos')
    _mean_of_prefixed_columns(table, 'chembl_ring_systems', skip_suffix='smi')
    return table
|
|
|
|
def parse_molecule_size(spec):
    """Parse the value of the ``--molecule_size`` command-line option.

    Args:
        spec: ``None``, a single non-negative integer such as ``"15"``, or two
            comma-separated non-negative integers such as ``"15,20"``.
            Whitespace around each number is ignored.

    Returns:
        ``None`` if ``spec`` is ``None``, an ``int`` for a single number, or a
        ``(left, right)`` tuple of ints for a range.

    Raises:
        ValueError: If ``spec`` is not one or two decimal integers, or if the
            range is descending.
    """
    if spec is None:
        return None

    tokens = [token.strip() for token in spec.split(',')]
    # `isdecimal` (unlike `isdigit`) only accepts characters `int()` can parse.
    if len(tokens) not in (1, 2) or not all(token.isdecimal() for token in tokens):
        raise ValueError(
            f"Invalid molecule size '{spec}': expected a single number or a "
            "range such as '15,20'."
        )

    bounds = tuple(int(token) for token in tokens)
    if len(bounds) == 1:
        return bounds[0]
    if bounds[0] > bounds[1]:
        raise ValueError(
            f"Invalid molecule size '{spec}': the lower bound exceeds the upper bound."
        )
    return bounds


if __name__ == "__main__":
    p = argparse.ArgumentParser()
    p.add_argument('--protein', type=str, required=True, help="Input PDB file.")
    p.add_argument('--ref_ligand', type=str, required=True, help="SDF file with reference ligand used to define the pocket.")
    p.add_argument('--checkpoint', type=str, required=True, help="Model checkpoint file.")
    p.add_argument('--molecule_size', type=str, required=False, default=None, help="Maximum number of atoms in the sampled molecules. Can be a single number or a range, e.g. '15,20'. If None, size will be sampled.")
    p.add_argument('--output', type=str, required=False, default='samples.sdf', help="Output file.")
    p.add_argument('--n_samples', type=int, required=False, default=10, help="Number of sampled molecules.")
    p.add_argument('--batch_size', type=int, required=False, default=32, help="Batch size.")
    p.add_argument('--pocket_distance_cutoff', type=float, required=False, default=8.0, help="Distance cutoff to define the pocket around the reference ligand.")
    p.add_argument('--n_steps', type=int, required=False, default=None, help="Number of denoising steps.")
    p.add_argument('--device', type=str, required=False, default='cuda:0', help="Device to use.")
    p.add_argument('--datadir', type=Path, required=False, default=Path(basedir, 'src', 'default'), help="Needs to be specified to sample molecule sizes.")
    p.add_argument('--seed', type=int, required=False, default=42, help="Random seed.")
    p.add_argument('--filter', action='store_true', required=False, default=False, help="Apply basic filters and keep sampling until `n_samples` molecules passing these filters are found.")
    p.add_argument('--metrics_output', type=str, required=False, default=None, help="If provided, metrics will be computed and saved in csv format at this location.")
    p.add_argument('--gnina', type=str, required=False, default=None, help="Path to a gnina executable. Required for computing docking scores.")
    p.add_argument('--reduce', type=str, required=False, default=None, help="Path to a reduce executable. Required for computing interactions.")
    args = p.parse_args()

    # Reject unusable arguments before the checkpoint is loaded.
    if args.n_samples < 1:
        p.error("`--n_samples` must be a positive integer.")
    if args.batch_size < 1:
        p.error("`--batch_size` must be a positive integer.")
    try:
        size_spec = parse_molecule_size(args.molecule_size)
    except ValueError as err:
        p.error(str(err))

    utils.set_deterministic(seed=args.seed)
    utils.disable_rdkit_logging()

    if args.molecule_size is None and (args.datadir is None or not args.datadir.exists()):
        raise NotImplementedError(
            "Please provide a path to the processed dataset (using `--datadir`) "
            "to infer the number of nodes. It contains the size distribution histogram."
        )

    # Without filtering every sampled molecule is kept, so a batch never needs
    # to be larger than the requested number of samples.
    if not args.filter:
        args.batch_size = min(args.batch_size, args.n_samples)

    # Load and configure the model.
    model = DrugFlow.load_from_checkpoint(args.checkpoint, map_location=args.device, strict=False)
    if args.datadir is not None:
        model.datadir = args.datadir

    model.setup(stage='generation')
    model.batch_size = model.eval_batch_size = args.batch_size
    model.eval().to(args.device)
    if args.n_steps is not None:
        model.T = args.n_steps

    # Translate the parsed size into the `num_nodes` value given to
    # `model.sample`: None, an int, or a "uniform_<left>_<right>" string.
    molecule_size = None
    if isinstance(size_spec, int):
        molecule_size = size_spec
        print(f'Will generate molecules of size {molecule_size}')
    elif size_spec is not None:
        left, right = size_spec
        molecule_size = f"uniform_{left}_{right}"
        print(f'Will generate molecules with numbers of atoms sampled from U({left}, {right})')

    # Read the protein and the reference ligand and build the model input.
    pdb_model = PDBParser(QUIET=True).get_structure('', args.protein)[0]
    rdmol = Chem.SDMolSupplier(str(args.ref_ligand))[0]
    if rdmol is None:
        # RDKit returns None for an SDF record it cannot parse.
        raise ValueError(f"Could not read a reference ligand from '{args.ref_ligand}'.")

    ligand, pocket = process_raw_pair(
        pdb_model, rdmol,
        dist_cutoff=args.pocket_distance_cutoff,
        pocket_representation=model.pocket_representation,
        compute_nerf_params=True,
        nma_input=args.protein if model.dynamics.add_nma_feat else None
    )
    ligand['name'] = 'ligand'

    # The same ligand/pocket pair is repeated `batch_size` times, so one pass
    # over the dataloader yields exactly one batch.
    dataset = [{'ligand': ligand, 'pocket': pocket} for _ in range(args.batch_size)]
    dataloader = DataLoader(
        dataset=dataset,
        batch_size=args.batch_size,
        collate_fn=partial(ProcessedLigandPocketDataset.collate_fn, ligand_transform=None),
        pin_memory=True
    )

    # Sampling state.
    smiles = set()  # SMILES of the molecules accepted so far (filter mode only)
    sampled_molecules = []
    metrics = []
    Path(args.output).parent.absolute().mkdir(parents=True, exist_ok=True)
    if args.metrics_output is not None:
        Path(args.metrics_output).parent.absolute().mkdir(parents=True, exist_ok=True)
    print(f'Will generate {args.n_samples} samples')

    evaluator = FullEvaluator(gnina=args.gnina, reduce=args.reduce)

    # NOTE: in filter mode this loop only ends once enough molecules pass the
    # filters; there is no upper bound on the number of attempts.
    with tqdm(total=args.n_samples) as pbar:
        while len(sampled_molecules) < args.n_samples:
            for data in dataloader:
                n_before = len(sampled_molecules)
                new_data = {
                    'ligand': TensorDict(**data['ligand']).to(args.device),
                    'pocket': TensorDict(**data['pocket']).to(args.device),
                }
                rdmols, rdpockets, _ = model.sample(
                    new_data,
                    n_samples=1,
                    timesteps=args.n_steps,
                    num_nodes=molecule_size,
                )

                # Metrics are needed both for filtering and for the report.
                if args.filter or (args.metrics_output is not None):
                    results = []
                    with tempfile.TemporaryDirectory() as tmpdir:
                        receptor_path = Path(tmpdir, 'receptor.pdb')
                        for mol, receptor in zip(rdmols, rdpockets):
                            Chem.MolToPDBFile(receptor, str(receptor_path))
                            results.append(evaluator(mol, receptor_path))

                    table = pd.DataFrame(results)
                    table['novel'] = ~table['representation.smiles'].isin(smiles)
                    table = aggregate_metrics(table)

                if args.filter:
                    table['passed_filters'] = (
                        (table['posebusters'] == 1) &
                        (table['chembl_ring_systems'] == 1) &
                        (table['novel'] == 1) &
                        # `novel` only compares against earlier batches; also
                        # drop repeats of a SMILES within this batch.
                        ~table['representation.smiles'].duplicated()
                    )
                    for mol, passed, smi in zip(
                        rdmols, table['passed_filters'], table['representation.smiles']
                    ):
                        if passed:
                            sampled_molecules.append(mol)
                            smiles.add(smi)

                    if args.metrics_output is not None:
                        metrics.append(table[table['passed_filters']])
                else:
                    sampled_molecules.extend(rdmols)
                    if args.metrics_output is not None:
                        metrics.append(table)

                # Never advance the bar past its total when a batch overshoots.
                pbar.update(min(len(sampled_molecules), args.n_samples) - n_before)

    # The last batch may produce more molecules than requested; keep only the
    # first `n_samples`. Metric rows are in the same order as the molecules.
    utils.write_sdf_file(args.output, sampled_molecules[:args.n_samples])

    if args.metrics_output is not None:
        metrics = pd.concat(metrics).iloc[:args.n_samples]
        metrics.to_csv(args.metrics_output, index=False)
|
|