Instructions to use nikraf/directionality_probe with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use nikraf/directionality_probe with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("feature-extraction", model="nikraf/directionality_probe", trust_remote_code=True)

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("nikraf/directionality_probe", trust_remote_code=True, dtype="auto")
- Notebooks
- Google Colab
- Kaggle
| import os | |
| import pickle | |
| import numpy as np | |
| import torch | |
| from pathlib import Path | |
| from tqdm import tqdm | |
| import pandas as pd | |
| from boltz.data.mol import load_molecules | |
| from boltz.data import const | |
| from boltz.data.parse.mmcif_with_constraints import parse_mmcif | |
| from multiprocessing import Pool | |
def compute_torsion_angles(coords, torsion_index):
    """Return the signed torsion (dihedral) angle of each atom quadruple.

    Args:
        coords: Array whose last axis holds the Cartesian components and
            whose second-to-last axis is indexed by atom.
        torsion_index: Four integer index arrays ``(i, j, k, l)`` (for
            example the rows of a ``(4, N)`` array) selecting the atoms of
            each torsion along the atom axis.

    Returns:
        Array of angles in radians in ``[-pi, pi]``, one per quadruple.
        Degenerate quadruples (collinear atoms) yield ``0`` instead of NaN.
    """
    r_ij = coords[..., torsion_index[0], :] - coords[..., torsion_index[1], :]
    r_kj = coords[..., torsion_index[2], :] - coords[..., torsion_index[1], :]
    r_kl = coords[..., torsion_index[2], :] - coords[..., torsion_index[3], :]

    # Normals of the i-j-k and j-k-l planes.
    n_ijk = np.cross(r_ij, r_kj, axis=-1)
    n_jkl = np.cross(r_kj, r_kl, axis=-1)

    # Both normals are perpendicular to r_kj, so their cross product is
    # parallel to r_kj and
    #   r_kj . (n_ijk x n_jkl) = |r_kj| |n_ijk| |n_jkl| sin(phi)
    #   |r_kj| (n_ijk . n_jkl) = |r_kj| |n_ijk| |n_jkl| cos(phi)
    # arctan2 of the two recovers the signed angle without dividing by the
    # norms (no NaN for degenerate inputs) and, unlike sign(.) * arccos(.),
    # still returns +/-pi for an exactly planar trans arrangement, where the
    # sine term is zero.
    sin_term = np.sum(r_kj * np.cross(n_ijk, n_jkl, axis=-1), axis=-1)
    cos_term = np.linalg.norm(r_kj, axis=-1) * np.sum(n_ijk * n_jkl, axis=-1)
    return np.arctan2(sin_term, cos_term)
def check_ligand_distance_geometry(
    structure, constraints, bond_buffer=0.25, angle_buffer=0.25, clash_buffer=0.2
):
    """Count atom pairs whose distance falls outside the RDKit bounds.

    Each entry of ``constraints.rdkit_bounds_constraints`` describes one atom
    pair with a lower and an upper distance bound, flagged as a bond pair, an
    angle pair, or neither.

    Args:
        structure: Object exposing ``coords["coords"]`` (atom coordinates)
            and ``chains`` (records with a ``"mol_type"`` field).
        constraints: Object exposing ``rdkit_bounds_constraints``.
        bond_buffer: Relative tolerance applied to both bounds of bond pairs.
        angle_buffer: Relative tolerance applied to both bounds of angle pairs.
        clash_buffer: Relative tolerance applied to the lower bound of the
            remaining pairs.

    Returns:
        Dict with the number of ligand chains and, for each pair category,
        the violation count and the total number of pairs.
    """
    xyz = structure.coords["coords"]
    bounds = constraints.rdkit_bounds_constraints

    pairs = bounds["atom_idxs"].copy().astype(np.int64).T
    is_bond = bounds["is_bond"].copy().astype(bool)
    is_angle = bounds["is_angle"].copy().astype(bool)
    upper = bounds["upper_bound"].copy().astype(np.float32)
    lower = bounds["lower_bound"].copy().astype(np.float32)

    # Pairs that are neither bonded nor part of a bond angle.
    is_other = ~is_bond & ~is_angle

    pair_dists = np.linalg.norm(xyz[pairs[0]] - xyz[pairs[1]], axis=-1)

    bond_dists = pair_dists[is_bond]
    too_short_bond = bond_dists <= lower[is_bond] * (1.0 - bond_buffer)
    too_long_bond = bond_dists >= upper[is_bond] * (1.0 + bond_buffer)
    bond_bad = too_short_bond | too_long_bond

    angle_dists = pair_dists[is_angle]
    too_short_angle = angle_dists <= lower[is_angle] * (1.0 - angle_buffer)
    too_long_angle = angle_dists >= upper[is_angle] * (1.0 + angle_buffer)
    angle_bad = too_short_angle | too_long_angle

    # Only the lower bound matters for the remaining pairs.
    clash_bad = pair_dists[is_other] <= lower[is_other] * (1.0 - clash_buffer)

    ligand_count = sum(
        int(const.chain_types[chain["mol_type"]] == "NONPOLYMER")
        for chain in structure.chains
    )

    return {
        "num_ligands": ligand_count,
        "num_bond_length_violations": bond_bad.sum(),
        "num_bonds": is_bond.sum(),
        "num_bond_angle_violations": angle_bad.sum(),
        "num_angles": is_angle.sum(),
        "num_internal_clash_violations": clash_bad.sum(),
        "num_non_neighbors": is_other.sum(),
    }
def check_ligand_stereochemistry(structure, constraints):
    """Count chiral-centre and stereo-bond mismatches against the constraints.

    Only constraint entries flagged ``is_reference`` are evaluated. A chiral
    entry is read as ``is_r`` when its torsion angle is positive; a
    stereo-bond entry is read as ``is_e`` when the absolute torsion angle
    exceeds ``pi / 2``.

    Args:
        structure: Object exposing ``coords["coords"]`` (atom coordinates).
        constraints: Object exposing ``chiral_atom_constraints`` and
            ``stereo_bond_constraints``.

    Returns:
        Dict with violation counts and totals for chiral atoms and stereo
        bonds.
    """
    xyz = structure.coords["coords"]

    # --- chiral centres -----------------------------------------------------
    chiral = constraints.chiral_atom_constraints
    chiral_keep = chiral["is_reference"]
    chiral_quads = chiral["atom_idxs"].T[:, chiral_keep]
    expected_is_r = chiral["is_r"][chiral_keep]
    observed_is_r = compute_torsion_angles(xyz, chiral_quads) > 0
    chiral_mismatch = observed_is_r != expected_is_r

    # --- stereo (E/Z) bonds -------------------------------------------------
    stereo = constraints.stereo_bond_constraints
    stereo_keep = stereo["is_reference"]
    stereo_quads = stereo["atom_idxs"].T[:, stereo_keep]
    expected_is_e = stereo["is_e"][stereo_keep]
    stereo_angles = np.abs(compute_torsion_angles(xyz, stereo_quads))
    observed_is_e = stereo_angles > np.pi / 2
    stereo_mismatch = observed_is_e != expected_is_e

    return {
        "num_chiral_atom_violations": chiral_mismatch.sum(),
        "num_chiral_atoms": chiral_quads.shape[1],
        "num_stereo_bond_violations": stereo_mismatch.sum(),
        "num_stereo_bonds": stereo_quads.shape[1],
    }
def _planarity_violations(coords, atom_index, buffer):
    """Flag atom groups with any atom at least ``buffer`` off the best-fit plane."""
    group_coords = coords[atom_index, :]
    centered = group_coords - group_coords.mean(axis=-2, keepdims=True)
    # The last right-singular vector is the direction of least variance,
    # i.e. the normal of the least-squares plane through the group.
    normal = np.linalg.svd(centered)[2][..., -1, :, None]
    plane_dists = np.abs(centered @ normal).squeeze(axis=-1)
    return np.any(plane_dists >= buffer, axis=-1)


def check_ligand_flatness(structure, constraints, buffer=0.25):
    """Count planar groups whose atoms deviate from their best-fit plane.

    A group (5-membered ring, 6-membered ring, or the atoms around a planar
    double bond) is a violation when at least one of its atoms lies
    ``buffer`` or more away from the least-squares plane through the group.
    All three group types use the same criterion.

    Args:
        structure: Object exposing ``coords["coords"]`` (atom coordinates).
        constraints: Object exposing ``planar_ring_5_constraints``,
            ``planar_ring_6_constraints`` and ``planar_bond_constraints``,
            each with an ``"atom_idxs"`` array holding one row of atom
            indices per group.
        buffer: Maximum tolerated out-of-plane distance, in the units of the
            coordinates.

    Returns:
        Dict with the violation count and the number of groups for each of
        the three group types.
    """
    coords = structure.coords["coords"]

    ring_5_violations = _planarity_violations(
        coords, constraints.planar_ring_5_constraints["atom_idxs"], buffer
    )
    ring_6_violations = _planarity_violations(
        coords, constraints.planar_ring_6_constraints["atom_idxs"], buffer
    )
    bond_violations = _planarity_violations(
        coords, constraints.planar_bond_constraints["atom_idxs"], buffer
    )

    return {
        "num_planar_5_ring_violations": ring_5_violations.sum(),
        "num_planar_5_rings": ring_5_violations.shape[0],
        "num_planar_6_ring_violations": ring_6_violations.sum(),
        "num_planar_6_rings": ring_6_violations.shape[0],
        "num_planar_double_bond_violations": bond_violations.sum(),
        "num_planar_double_bonds": bond_violations.shape[0],
    }
def check_steric_clash(structure, molecules, buffer=0.25):
    """Count pairs of chains that sterically clash, grouped by chain type.

    Two chains clash when any inter-chain atom pair is closer than
    ``(1 - buffer)`` times the sum of the two atoms' van der Waals radii.
    Chains consisting of a single atom and chain pairs joined by a bond
    listed in ``structure.bonds`` are skipped. A pair is counted as
    "symmetric" when both chains share ``entity_id`` and ``atom_num``, and as
    "asymmetric" otherwise.

    Args:
        structure: Object exposing ``coords["coords"]``, ``atoms``,
            ``residues``, ``chains`` and ``bonds`` record arrays.
        molecules: Mapping from residue name to a reference molecule whose
            atoms carry a ``"name"`` property and an atomic number.
        buffer: Relative overlap tolerated before a contact counts as a clash.

    Returns:
        Dict of ``num_chain_pairs_*`` and ``num_chain_clashes_*`` counters,
        one pair of keys per symmetric chain type and per ordered pair of
        chain types.
    """

    def _label(chain_type):
        # Non-polymer chains are reported under the name "ligand".
        lowered = chain_type.lower()
        return lowered if lowered != "nonpolymer" else "ligand"

    # Pre-create every counter so the output has the same keys for every
    # structure, whatever chain types it actually contains.
    result = {}
    labels = [_label(chain_type) for chain_type in const.chain_types]
    for label_i in labels:
        result[f"num_chain_pairs_sym_{label_i}"] = 0
        result[f"num_chain_clashes_sym_{label_i}"] = 0
        for label_j in labels:
            result[f"num_chain_pairs_asym_{label_i}_{label_j}"] = 0
            result[f"num_chain_clashes_asym_{label_i}_{label_j}"] = 0

    # Chain pairs joined by a bond are expected to be in close contact.
    connected_chains = set()
    for bond in structure.bonds:
        if bond["chain_1"] != bond["chain_2"]:
            connected_chains.add(tuple(sorted((bond["chain_1"], bond["chain_2"]))))

    # Per-atom van der Waals radii, in the same order as structure.atoms.
    vdw_radii = []
    for res in structure.residues:
        mol = molecules[res["name"]]
        token_atoms = structure.atoms[
            res["atom_idx"] : res["atom_idx"] + res["atom_num"]
        ]
        atom_name_to_ref = {a.GetProp("name"): a for a in mol.GetAtoms()}
        token_atoms_ref = [atom_name_to_ref[a["name"]] for a in token_atoms]
        vdw_radii.extend(
            [const.vdw_radii[a.GetAtomicNum() - 1] for a in token_atoms_ref]
        )
    vdw_radii = np.array(vdw_radii, dtype=np.float32)

    all_coords = structure.coords["coords"]
    for i, chain_i in enumerate(structure.chains):
        for j, chain_j in enumerate(structure.chains):
            # Visit each unordered pair once (j > i).
            # NOTE(review): the membership test compares enumerate positions
            # with the chain ids stored on bonds; this assumes the two
            # coincide -- confirm against the structure schema.
            if (
                chain_i["atom_num"] == 1
                or chain_j["atom_num"] == 1
                or j <= i
                or (i, j) in connected_chains
            ):
                continue

            span_i = slice(
                chain_i["atom_idx"], chain_i["atom_idx"] + chain_i["atom_num"]
            )
            span_j = slice(
                chain_j["atom_idx"], chain_j["atom_idx"] + chain_j["atom_num"]
            )
            coords_i = all_coords[span_i]
            coords_j = all_coords[span_j]
            dists = np.linalg.norm(coords_i[:, None, :] - coords_j[None, :, :], axis=-1)
            radii_sum = vdw_radii[span_i][:, None] + vdw_radii[span_j][None, :]
            is_clashing = np.any(dists < radii_sum * (1.00 - buffer))

            type_i = _label(const.chain_types[chain_i["mol_type"]])
            type_j = _label(const.chain_types[chain_j["mol_type"]])

            is_symmetric = (
                chain_i["entity_id"] == chain_j["entity_id"]
                and chain_i["atom_num"] == chain_j["atom_num"]
            )
            if is_symmetric:
                key = "sym_" + type_i
            else:
                key = "asym_" + type_i + "_" + type_j
            result["num_chain_pairs_" + key] += 1
            result["num_chain_clashes_" + key] += int(is_clashing)
    return result
# Local Boltz cache: the pickled CCD dictionary and the per-molecule directory
# that are handed to the mmCIF parser.
cache_dir = Path("/data/rbg/users/jwohlwend/boltz-cache")
ccd_path = cache_dir / "ccd.pkl"
moldir = cache_dir / "mols"
with ccd_path.open("rb") as file:
    ccd = pickle.load(file)

# Prediction output directories, one per tool being compared.
boltz1_dir = Path(
    "/data/rbg/shared/projects/foldeverything/boltz_results_final/outputs/test/boltz/predictions"
)
boltz1x_dir = Path(
    "/data/scratch/getzn/boltz_private/boltz_1x_test_results_final_new/full_predictions"
)
chai_dir = Path(
    "/data/rbg/shared/projects/foldeverything/boltz_results_final/outputs/test/chai"
)
af3_dir = Path(
    "/data/rbg/shared/projects/foldeverything/boltz_results_final/outputs/test/af3"
)

# Directory entries are treated as target ids; only targets present for every
# tool are evaluated.
boltz1_pdb_ids = set(os.listdir(boltz1_dir))
boltz1x_pdb_ids = set(os.listdir(boltz1x_dir))
chai_pdb_ids = set(os.listdir(chai_dir))
af3_pdb_ids = set(os.listdir(af3_dir))
common_pdb_ids = set.intersection(
    boltz1_pdb_ids, boltz1x_pdb_ids, chai_pdb_ids, af3_pdb_ids
)

tools = ["boltz1", "boltz1x", "chai", "af3"]
# Number of model indices (0 .. num_samples - 1) evaluated per target and tool.
num_samples = 5
def process_fn(key):
    """Run every physical check on one predicted structure.

    Args:
        key: Tuple ``(tool, pdb_id, model_idx)`` where ``tool`` is one of
            ``"boltz1"``, ``"boltz1x"``, ``"chai"`` or ``"af3"``.

    Returns:
        Dict holding the three key fields plus every metric returned by the
        distance-geometry, stereochemistry, flatness and steric-clash checks.

    Raises:
        ValueError: If ``tool`` is not one of the supported names.
    """
    tool, pdb_id, model_idx = key

    # Each tool lays out its predictions differently on disk.
    if tool == "boltz1":
        cif_path = boltz1_dir / pdb_id / f"{pdb_id}_model_{model_idx}.cif"
    elif tool == "boltz1x":
        cif_path = boltz1x_dir / pdb_id / f"{pdb_id}_model_{model_idx}.cif"
    elif tool == "chai":
        cif_path = chai_dir / pdb_id / f"pred.model_idx_{model_idx}.cif"
    elif tool == "af3":
        cif_path = af3_dir / pdb_id.lower() / f"seed-1_sample-{model_idx}" / "model.cif"
    else:
        # Fail with a clear message instead of an UnboundLocalError on
        # cif_path further down.
        raise ValueError(f"Unknown tool: {tool!r}")

    parsed_structure = parse_mmcif(
        cif_path,
        ccd,
        moldir,
    )
    structure = parsed_structure.data
    constraints = parsed_structure.residue_constraints

    record = {
        "tool": tool,
        "pdb_id": pdb_id,
        "model_idx": model_idx,
    }
    record.update(check_ligand_distance_geometry(structure, constraints))
    record.update(check_ligand_stereochemistry(structure, constraints))
    record.update(check_ligand_flatness(structure, constraints))
    record.update(check_steric_clash(structure, molecules=ccd))
    return record
# The driver is guarded so that worker processes started with the "spawn" or
# "forkserver" start methods, which re-import this module, do not launch the
# whole evaluation again.
if __name__ == "__main__":
    # One work item per (tool, target, sample).
    keys = []
    for tool in tools:
        for pdb_id in common_pdb_ids:
            for model_idx in range(num_samples):
                keys.append((tool, pdb_id, model_idx))

    # Run a single item in the parent process first; the result is discarded
    # (presumably a fail-fast check before starting the pool).
    process_fn(keys[0])

    records = []
    with Pool(48) as p:
        with tqdm(total=len(keys)) as pbar:
            # Results arrive in completion order, so row order is not stable
            # between runs.
            for record in p.imap_unordered(process_fn, keys):
                records.append(record)
                pbar.update(1)

    df = pd.DataFrame.from_records(records)

    # Aggregate the per-type clash counters into per-structure totals.
    df["num_chain_clashes_all"] = df[
        [key for key in df.columns if "chain_clash" in key]
    ].sum(axis=1)
    df["num_pairs_all"] = df[[key for key in df.columns if "chain_pair" in key]].sum(
        axis=1
    )
    df["clash_free"] = df["num_chain_clashes_all"] == 0

    # A ligand is valid when every "*violation*" column is zero.
    df["valid_ligand"] = (
        df[[key for key in df.columns if "violation" in key]].sum(axis=1) == 0
    )
    df["valid"] = (df["clash_free"]) & (df["valid_ligand"])

    df.to_csv("physical_checks_test.csv")