Instructions to use nikraf/directionality_probe with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use nikraf/directionality_probe with Transformers:
# Use a pipeline as a high-level helper from transformers import pipeline pipe = pipeline("feature-extraction", model="nikraf/directionality_probe", trust_remote_code=True)# Load model directly from transformers import AutoModel model = AutoModel.from_pretrained("nikraf/directionality_probe", trust_remote_code=True, dtype="auto") - Notebooks
- Google Colab
- Kaggle
| import argparse | |
| import json | |
| import os | |
| import random | |
| import re | |
| import subprocess | |
| import sys | |
| import tempfile | |
| import time | |
| import urllib.request | |
| import biotite.structure as struc | |
| import matplotlib | |
| import numpy as np | |
| import torch | |
| from tqdm.auto import tqdm | |
| from pathlib import Path | |
| from typing import Any | |
| from typing import Dict | |
| from typing import List | |
| from typing import Optional | |
| from typing import Tuple | |
| from transformers import AutoModel | |
| from boltz_fastplms.cif_writer import write_cif | |
| from boltz_fastplms.get_boltz2_weights import BOLTZ2_CKPT_URL | |
| from boltz_fastplms.minimal_featurizer import build_boltz2_features | |
| from boltz_fastplms.minimal_structures import ProteinStructureTemplate | |
| from testing.common import autocast_context | |
| from testing.common import build_output_dir | |
| from testing.common import login_if_needed | |
| from testing.common import resolve_device | |
| from testing.common import resolve_dtype | |
| from testing.reporting import write_csv | |
| from testing.reporting import write_json | |
| from testing.reporting import write_summary | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| assert "tm_score" in dir(struc), ( | |
| "biotite.structure.tm_score is unavailable. Install biotite>=1.5.0 in the target environment." | |
| ) | |
| TM_SCORE_FN = struc.tm_score | |
| BOLTZ2_FIXED_RECYCLING_STEPS = 3 | |
| BOLTZ2_FIXED_SAMPLING_STEPS = 200 | |
| BOLTZ2_FIXED_DIFFUSION_SAMPLES = 20 | |
| MIN_SEED_VALUE = int(np.iinfo(np.uint32).min) | |
| MAX_SEED_VALUE = int(np.iinfo(np.uint32).max) | |
| SEQUENCE_OPTIONS = [ | |
| "MDDADPEERNYDNMLKMLSDLNKDLEKLLEEMEKISVQATWMAYDMVVMRTNPTLAESMRRLEDAFVNCKEEMEKNWQELLHETKQRL", | |
| "MASLGHILVFCVGLLTMAKAESPKEHDPFTYDYQSLQIGGLVIAGILFILGILIVLSRRCRCKFNQQQRTGEPDEEEGTFRSSIRRLSTRRR", | |
| "MAVESRVTQEEIKKEPEKPIDREKTCPLLLRVFTTNNGRHHRMDEFSRGNVPSSELQIYTWMDATLKELTSLVKEVYPEARKKGTHFNFAIVFTDVKRPGYRVKEIGSTMSGRKGTDDSMTLQSQKFQIGDYLDIAITPPNRAPPPSGRMRPY", | |
| ] | |
| def _enforce_determinism() -> None: | |
| if "CUBLAS_WORKSPACE_CONFIG" not in os.environ: | |
| os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" | |
| torch.backends.cudnn.benchmark = False | |
| torch.backends.cudnn.deterministic = True | |
| torch.backends.cudnn.allow_tf32 = False | |
| if torch.cuda.is_available(): | |
| torch.backends.cuda.matmul.allow_tf32 = False | |
| torch.use_deterministic_algorithms(True) | |
| def _seed_everything(seed: Optional[int] = None, workers: bool = False) -> int: | |
| if seed is None: | |
| env_seed = os.environ.get("PL_GLOBAL_SEED") | |
| if env_seed is None: | |
| seed = 0 | |
| else: | |
| seed = int(env_seed) | |
| elif isinstance(seed, int) is False: | |
| seed = int(seed) | |
| if not (MIN_SEED_VALUE <= seed <= MAX_SEED_VALUE): | |
| raise ValueError(f"{seed} is not in bounds, numpy accepts from {MIN_SEED_VALUE} to {MAX_SEED_VALUE}") | |
| os.environ["PL_GLOBAL_SEED"] = str(seed) | |
| os.environ["PL_SEED_WORKERS"] = f"{int(workers)}" | |
| os.environ["PYTHONHASHSEED"] = str(seed) | |
| random.seed(seed) | |
| np.random.seed(seed) | |
| torch.manual_seed(seed) | |
| return seed | |
| def _download_checkpoint_if_needed(checkpoint_path: Path) -> Path: | |
| checkpoint_path.parent.mkdir(parents=True, exist_ok=True) | |
| if not checkpoint_path.exists(): | |
| urllib.request.urlretrieve(BOLTZ2_CKPT_URL, str(checkpoint_path)) # noqa: S310 | |
| return checkpoint_path | |
| def _detect_no_kernels_support() -> bool: | |
| command = [sys.executable, "-m", "boltz.main", "predict", "--help"] | |
| completed = subprocess.run(command, capture_output=True, text=True, check=False) | |
| combined_output = f"{completed.stdout}\n{completed.stderr}" | |
| return "--no_kernels" in combined_output | |
| def _set_sequence_seed(seed: int, sequence_index: int) -> None: | |
| _seed_everything(seed=seed + sequence_index, workers=False) | |
| def _to_device(feats: Dict[str, torch.Tensor], device: torch.device, dtype: torch.dtype) -> Dict[str, torch.Tensor]: | |
| output: Dict[str, torch.Tensor] = {} | |
| for key in feats: | |
| value = feats[key] | |
| if value.is_floating_point(): | |
| output[key] = value.to(device=device, dtype=dtype) | |
| else: | |
| output[key] = value.to(device=device) | |
| return output | |
| def _clone_feats(feats: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: | |
| output: Dict[str, torch.Tensor] = {} | |
| for key in feats: | |
| output[key] = feats[key].clone() | |
| return output | |
| def _summary_metric(value: torch.Tensor) -> torch.Tensor: | |
| if value.ndim == 0: | |
| return value.reshape(1) | |
| if value.ndim == 1: | |
| return value | |
| return value.reshape(value.shape[0], -1)[:, 0] | |
| def _extract_primary_plddt_vector(output: Dict[str, torch.Tensor], feats: Dict[str, torch.Tensor]) -> torch.Tensor: | |
| assert "plddt" in output, "Missing pLDDT in model output." | |
| plddt = output["plddt"].detach().cpu() | |
| if plddt.ndim == 0: | |
| return plddt.reshape(1).float() | |
| if plddt.ndim >= 2: | |
| plddt = plddt[0] | |
| plddt = plddt.reshape(-1).float() | |
| token_mask = feats["token_pad_mask"][0].detach().cpu().reshape(-1) > 0 | |
| atom_mask = feats["atom_pad_mask"][0].detach().cpu().reshape(-1) > 0 | |
| if plddt.numel() == token_mask.numel(): | |
| plddt = plddt[token_mask] | |
| elif plddt.numel() == atom_mask.numel(): | |
| plddt = plddt[atom_mask] | |
| return plddt | |
| def _compute_confidence_score(ptm: torch.Tensor, iptm: torch.Tensor, complex_plddt: torch.Tensor) -> torch.Tensor: | |
| if torch.allclose(iptm, torch.zeros_like(iptm)): | |
| return (4 * complex_plddt + ptm) / 5 | |
| return (4 * complex_plddt + iptm) / 5 | |
| def _run_ours_forward( | |
| model, | |
| feats_ours: Dict[str, torch.Tensor], | |
| args: argparse.Namespace, | |
| device: torch.device, | |
| dtype: torch.dtype, | |
| sequence_index: int, | |
| ) -> Dict[str, torch.Tensor]: | |
| with torch.no_grad(), autocast_context(device=device, dtype=dtype): | |
| _set_sequence_seed(args.seed, sequence_index) | |
| return model.forward( | |
| feats=feats_ours, | |
| recycling_steps=BOLTZ2_FIXED_RECYCLING_STEPS, | |
| num_sampling_steps=BOLTZ2_FIXED_SAMPLING_STEPS, | |
| diffusion_samples=BOLTZ2_FIXED_DIFFUSION_SAMPLES, | |
| run_confidence_sequentially=args.run_confidence_sequentially, | |
| ) | |
| def _vector_metrics(lhs: torch.Tensor, rhs: torch.Tensor) -> Tuple[float, float, float]: | |
| delta = lhs.float() - rhs.float() | |
| abs_delta = torch.abs(delta) | |
| mae = float(abs_delta.mean().item()) | |
| rmse = float(torch.sqrt(torch.mean(delta * delta)).item()) | |
| max_abs = float(abs_delta.max().item()) | |
| return mae, rmse, max_abs | |
| def _kabsch_align_mobile_to_target(mobile: torch.Tensor, target: torch.Tensor) -> torch.Tensor: | |
| assert mobile.ndim == 2 and target.ndim == 2, "Expected coordinate tensors with shape [N, 3]." | |
| assert mobile.shape == target.shape, "Coordinate tensors must have matching shapes." | |
| assert mobile.shape[1] == 3, "Coordinate tensors must have last dimension size 3." | |
| assert mobile.shape[0] > 0, "Expected at least one shared atom for alignment." | |
| mobile_32 = mobile.float() | |
| target_32 = target.float() | |
| if mobile_32.shape[0] < 3: | |
| mobile_centroid = mobile_32.mean(dim=0, keepdim=True) | |
| target_centroid = target_32.mean(dim=0, keepdim=True) | |
| return mobile_32 - mobile_centroid + target_centroid | |
| mobile_centroid = mobile_32.mean(dim=0, keepdim=True) | |
| target_centroid = target_32.mean(dim=0, keepdim=True) | |
| mobile_centered = mobile_32 - mobile_centroid | |
| target_centered = target_32 - target_centroid | |
| covariance = mobile_centered.transpose(0, 1).matmul(target_centered) | |
| u_mat, _, vh_mat = torch.linalg.svd(covariance, full_matrices=False) | |
| correction = torch.eye(3, dtype=mobile_32.dtype, device=mobile_32.device) | |
| det_sign = torch.det(vh_mat.transpose(0, 1).matmul(u_mat.transpose(0, 1))).item() | |
| if det_sign < 0: | |
| correction[2, 2] = -1.0 | |
| rotation = vh_mat.transpose(0, 1).matmul(correction).matmul(u_mat.transpose(0, 1)) | |
| return mobile_centered.matmul(rotation) + target_centroid | |
| def _pairwise_distance_mae(lhs: torch.Tensor, rhs: torch.Tensor) -> float: | |
| assert lhs.ndim == 2 and rhs.ndim == 2, "Expected coordinate tensors with shape [N, 3]." | |
| assert lhs.shape == rhs.shape, "Coordinate tensors must have matching shapes." | |
| assert lhs.shape[1] == 3, "Coordinate tensors must have last dimension size 3." | |
| lhs_dist = torch.cdist(lhs.float(), lhs.float()) | |
| rhs_dist = torch.cdist(rhs.float(), rhs.float()) | |
| return float(torch.mean(torch.abs(lhs_dist - rhs_dist)).item()) | |
| def _write_single_chain_fasta(sequence: str, path: Path) -> None: | |
| text = f">A|protein|empty\n{sequence}\n" | |
| path.write_text(text, encoding="utf-8") | |
| def _parse_pdb_atom_map(path: Path) -> Dict[Tuple[str, int, str], torch.Tensor]: | |
| atom_map: Dict[Tuple[str, int, str], torch.Tensor] = {} | |
| for line in path.read_text(encoding="utf-8").splitlines(): | |
| if not (line.startswith("ATOM") or line.startswith("HETATM")): | |
| continue | |
| atom_name = line[12:16].strip() | |
| chain_id = line[21:22].strip() | |
| residue_index = int(line[22:26]) | |
| x_val = float(line[30:38]) | |
| y_val = float(line[38:46]) | |
| z_val = float(line[46:54]) | |
| atom_map[(chain_id, residue_index, atom_name)] = torch.tensor([x_val, y_val, z_val], dtype=torch.float32) | |
| assert len(atom_map) > 0, f"No atoms parsed from PDB: {path}" | |
| return atom_map | |
| def _extract_model_id_from_name(filename: str) -> int: | |
| match = re.search(r"_model_(\d+)\.", filename) | |
| assert match is not None, f"Could not parse model id from filename: {filename}" | |
| return int(match.group(1)) | |
| def _map_paths_by_model(paths: List[Path]) -> Dict[int, Path]: | |
| path_map: Dict[int, Path] = {} | |
| for path in paths: | |
| model_id = _extract_model_id_from_name(path.name) | |
| assert model_id not in path_map, f"Found duplicate artifacts for model id {model_id}: {path}" | |
| path_map[model_id] = path | |
| return path_map | |
| def _build_ours_atom_maps( | |
| sample_coords: torch.Tensor, | |
| atom_mask: torch.Tensor, | |
| atom_names: List[str], | |
| atom_residue_index: List[int], | |
| atom_chain_id: List[str], | |
| ) -> List[Dict[Tuple[str, int, str], torch.Tensor]]: | |
| coords = sample_coords.detach().cpu() | |
| if coords.ndim == 4: | |
| assert coords.shape[0] == 1, "Expected singleton batch dimension for sample coordinates." | |
| coords = coords[0] | |
| if coords.ndim == 2: | |
| coords = coords.unsqueeze(0) | |
| assert coords.ndim == 3, f"Expected sample_atom_coords with 3 dimensions, got shape {coords.shape}." | |
| assert coords.shape[0] >= BOLTZ2_FIXED_DIFFUSION_SAMPLES, ( | |
| f"Expected at least {BOLTZ2_FIXED_DIFFUSION_SAMPLES} samples, got {coords.shape[0]}." | |
| ) | |
| atom_mask_bool = atom_mask.detach().cpu() > 0 | |
| output: List[Dict[Tuple[str, int, str], torch.Tensor]] = [] | |
| for sample_index in range(BOLTZ2_FIXED_DIFFUSION_SAMPLES): | |
| valid_coords = coords[sample_index][atom_mask_bool] | |
| assert valid_coords.shape[0] >= len(atom_names), ( | |
| "Our model returned fewer valid atom coordinates than template atoms." | |
| ) | |
| atom_map: Dict[Tuple[str, int, str], torch.Tensor] = {} | |
| for atom_idx in range(len(atom_names)): | |
| key = ( | |
| atom_chain_id[atom_idx], | |
| atom_residue_index[atom_idx] + 1, | |
| atom_names[atom_idx], | |
| ) | |
| atom_map[key] = valid_coords[atom_idx].float().cpu() | |
| output.append(atom_map) | |
| return output | |
| def _build_reference_cif_tensors( | |
| template: ProteinStructureTemplate, | |
| atom_pad_mask: torch.Tensor, | |
| ref_atom_map: Dict[Tuple[str, int, str], torch.Tensor], | |
| ) -> Tuple[torch.Tensor, torch.Tensor]: | |
| assert atom_pad_mask.ndim == 1, "Expected atom pad mask with shape [atoms]." | |
| atom_slots = atom_pad_mask.shape[0] | |
| coords = torch.zeros((1, atom_slots, 3), dtype=torch.float32) | |
| ref_mask = atom_pad_mask.detach().cpu().float().clone() | |
| for atom_idx in range(template.num_atoms): | |
| key = ( | |
| template.atom_chain_id[atom_idx], | |
| template.atom_residue_index[atom_idx] + 1, | |
| template.atom_names[atom_idx], | |
| ) | |
| if key in ref_atom_map: | |
| coords[0, atom_idx] = ref_atom_map[key].float().cpu() | |
| else: | |
| ref_mask[atom_idx] = 0.0 | |
| return coords, ref_mask | |
| def _run_boltz_cli_reference( | |
| sequence: str, | |
| sequence_index: int, | |
| checkpoint_path: Path, | |
| args: argparse.Namespace, | |
| device: torch.device, | |
| supports_no_kernels: bool, | |
| ) -> Tuple[List[Dict[Tuple[str, int, str], torch.Tensor]], List[torch.Tensor], List[Dict[str, float]]]: | |
| sequence_seed = args.seed + sequence_index | |
| with tempfile.TemporaryDirectory(prefix=f"boltz2_ref_{sequence_index}_") as tmp_dir_str: | |
| tmp_dir = Path(tmp_dir_str) | |
| fasta_path = tmp_dir / f"seq_{sequence_index}.fasta" | |
| out_root = tmp_dir / "ref_out" | |
| _write_single_chain_fasta(sequence=sequence, path=fasta_path) | |
| command = [ | |
| sys.executable, | |
| "-m", | |
| "boltz.main", | |
| "predict", | |
| str(fasta_path), | |
| "--out_dir", | |
| str(out_root), | |
| "--model", | |
| "boltz2", | |
| "--checkpoint", | |
| str(checkpoint_path), | |
| "--recycling_steps", | |
| str(BOLTZ2_FIXED_RECYCLING_STEPS), | |
| "--sampling_steps", | |
| str(BOLTZ2_FIXED_SAMPLING_STEPS), | |
| "--diffusion_samples", | |
| str(BOLTZ2_FIXED_DIFFUSION_SAMPLES), | |
| "--seed", | |
| str(sequence_seed), | |
| "--output_format", | |
| "pdb", | |
| ] | |
| if supports_no_kernels: | |
| command.append("--no_kernels") | |
| env = os.environ.copy() | |
| env["PL_GLOBAL_SEED"] = str(sequence_seed) | |
| env["PL_SEED_WORKERS"] = "0" | |
| env["PYTHONHASHSEED"] = str(sequence_seed) | |
| completed = subprocess.run(command, capture_output=True, text=True, check=False, env=env) | |
| if completed.returncode != 0: | |
| stderr = completed.stderr[-4000:] | |
| stdout = completed.stdout[-4000:] | |
| raise RuntimeError( | |
| "pip boltz CLI prediction failed.\n" | |
| f"Command: {' '.join(command)}\n" | |
| f"STDOUT tail:\n{stdout}\n" | |
| f"STDERR tail:\n{stderr}" | |
| ) | |
| results_root = out_root / f"boltz_results_{fasta_path.stem}" / "predictions" | |
| assert results_root.exists(), f"Reference predictions directory not found: {results_root}" | |
| pdb_candidates = sorted(results_root.rglob("*_model_*.pdb")) | |
| plddt_candidates = sorted(results_root.rglob("plddt_*_model_*.npz")) | |
| confidence_candidates = sorted(results_root.rglob("confidence_*_model_*.json")) | |
| assert len(pdb_candidates) > 0, f"No reference PDB artifacts found under {results_root}" | |
| assert len(plddt_candidates) > 0, f"No reference pLDDT npz artifacts found under {results_root}" | |
| assert len(confidence_candidates) > 0, f"No reference confidence json artifacts found under {results_root}" | |
| pdb_by_model = _map_paths_by_model(pdb_candidates) | |
| plddt_by_model = _map_paths_by_model(plddt_candidates) | |
| confidence_by_model = _map_paths_by_model(confidence_candidates) | |
| expected_model_ids = list(range(BOLTZ2_FIXED_DIFFUSION_SAMPLES)) | |
| for model_id in expected_model_ids: | |
| assert model_id in pdb_by_model, f"Missing reference PDB for model {model_id}" | |
| assert model_id in plddt_by_model, f"Missing reference pLDDT for model {model_id}" | |
| assert model_id in confidence_by_model, f"Missing reference confidence JSON for model {model_id}" | |
| atom_maps: List[Dict[Tuple[str, int, str], torch.Tensor]] = [] | |
| plddt_samples: List[torch.Tensor] = [] | |
| confidence_summaries: List[Dict[str, float]] = [] | |
| for model_id in expected_model_ids: | |
| atom_maps.append(_parse_pdb_atom_map(pdb_by_model[model_id])) | |
| with np.load(plddt_by_model[model_id]) as handle: | |
| assert "plddt" in handle.files, f"Missing 'plddt' array in {plddt_by_model[model_id]}" | |
| plddt_samples.append(torch.tensor(handle["plddt"], dtype=torch.float32)) | |
| confidence_summary = json.loads(confidence_by_model[model_id].read_text(encoding="utf-8")) | |
| for key in ["ptm", "iptm", "complex_plddt", "confidence_score"]: | |
| assert key in confidence_summary, ( | |
| f"Reference confidence summary missing key '{key}' in {confidence_by_model[model_id]}" | |
| ) | |
| confidence_summaries.append(confidence_summary) | |
| return atom_maps, plddt_samples, confidence_summaries | |
| def _shared_ca_key_order( | |
| ours_atom_maps: List[Dict[Tuple[str, int, str], torch.Tensor]], | |
| ref_atom_maps: List[Dict[Tuple[str, int, str], torch.Tensor]], | |
| ) -> List[Tuple[str, int, str]]: | |
| shared_keys = {key for key in ours_atom_maps[0] if key[2] == "CA"} | |
| for atom_map in ours_atom_maps: | |
| ca_keys = {key for key in atom_map if key[2] == "CA"} | |
| shared_keys = shared_keys.intersection(ca_keys) | |
| for atom_map in ref_atom_maps: | |
| ca_keys = {key for key in atom_map if key[2] == "CA"} | |
| shared_keys = shared_keys.intersection(ca_keys) | |
| assert len(shared_keys) > 0, "No shared CA atoms found across all samples." | |
| ordered_keys = list(shared_keys) | |
| ordered_keys.sort() | |
| return ordered_keys | |
| def _stack_coords_for_keys( | |
| atom_maps: List[Dict[Tuple[str, int, str], torch.Tensor]], | |
| ordered_keys: List[Tuple[str, int, str]], | |
| ) -> torch.Tensor: | |
| stacked_samples: List[torch.Tensor] = [] | |
| for atom_map in atom_maps: | |
| coords: List[torch.Tensor] = [] | |
| for key in ordered_keys: | |
| assert key in atom_map, f"Missing key {key} in atom map." | |
| coords.append(atom_map[key].float()) | |
| stacked_samples.append(torch.stack(coords, dim=0)) | |
| return torch.stack(stacked_samples, dim=0) | |
| def _coords_to_ca_atom_array(coords: np.ndarray) -> struc.AtomArray: | |
| assert coords.ndim == 2, "Expected coordinate array with shape [N, 3]." | |
| assert coords.shape[1] == 3, "Expected coordinate array with shape [N, 3]." | |
| assert coords.shape[0] > 0, "Expected at least one CA atom for TM-score." | |
| assert np.all(np.isfinite(coords)), "Coordinate array for TM-score contains non-finite values." | |
| atom_count = coords.shape[0] | |
| array = struc.AtomArray(atom_count) | |
| array.coord = coords.astype(np.float32, copy=False) | |
| array.atom_name = np.full(atom_count, "CA") | |
| array.res_name = np.full(atom_count, "GLY") | |
| array.chain_id = np.full(atom_count, "A") | |
| array.res_id = np.arange(1, atom_count + 1, dtype=np.int32) | |
| array.element = np.full(atom_count, "C") | |
| return array | |
| def _tm_score_from_coords(reference_coords: torch.Tensor, subject_coords: torch.Tensor) -> float: | |
| aligned_subject = _kabsch_align_mobile_to_target(subject_coords, reference_coords) | |
| reference_np = reference_coords.detach().cpu().numpy().astype(np.float64) | |
| aligned_np = aligned_subject.detach().cpu().numpy().astype(np.float64) | |
| reference_atom_array = _coords_to_ca_atom_array(reference_np) | |
| subject_atom_array = _coords_to_ca_atom_array(aligned_np) | |
| index_array = np.arange(reference_np.shape[0], dtype=np.int32) | |
| tm_value = float( | |
| TM_SCORE_FN( | |
| reference=reference_atom_array, | |
| subject=subject_atom_array, | |
| reference_indices=index_array, | |
| subject_indices=index_array, | |
| reference_length="shorter", | |
| ) | |
| ) | |
| assert np.isfinite(tm_value), "TM-score computation produced non-finite value." | |
| return tm_value | |
| def _build_tm_matrix(reference_stack: torch.Tensor, subject_stack: torch.Tensor, symmetric: bool) -> np.ndarray: | |
| assert reference_stack.ndim == 3 and subject_stack.ndim == 3, "Expected stacks with shape [S, N, 3]." | |
| assert reference_stack.shape[1:] == subject_stack.shape[1:], "Reference and subject stacks must share atom layout." | |
| matrix = np.zeros((reference_stack.shape[0], subject_stack.shape[0]), dtype=np.float32) | |
| if symmetric: | |
| assert reference_stack.shape[0] == subject_stack.shape[0], "Symmetric matrix requires same sample count." | |
| for row_idx in range(reference_stack.shape[0]): | |
| for col_idx in range(row_idx, subject_stack.shape[0]): | |
| tm_value = _tm_score_from_coords(reference_stack[row_idx], subject_stack[col_idx]) | |
| matrix[row_idx, col_idx] = tm_value | |
| matrix[col_idx, row_idx] = tm_value | |
| return matrix | |
| for row_idx in range(reference_stack.shape[0]): | |
| for col_idx in range(subject_stack.shape[0]): | |
| matrix[row_idx, col_idx] = _tm_score_from_coords(reference_stack[row_idx], subject_stack[col_idx]) | |
| return matrix | |
| def _write_tm_matrix_heatmap(path: Path, matrix: np.ndarray, title: str) -> None: | |
| fig, axis = plt.subplots(figsize=(7, 6)) | |
| image = axis.imshow(matrix, cmap="viridis", vmin=0.0, vmax=1.0, aspect="auto") | |
| axis.set_title(title) | |
| axis.set_xlabel("Column sample index") | |
| axis.set_ylabel("Row sample index") | |
| fig.colorbar(image, ax=axis, fraction=0.046, pad=0.04) | |
| fig.tight_layout() | |
| fig.savefig(path, dpi=300) | |
| plt.close(fig) | |
| def _write_tm_matrix_artifacts( | |
| matrix_dir: Path, | |
| matrix_name: str, | |
| title: str, | |
| matrix: np.ndarray, | |
| ) -> Tuple[str, str, str]: | |
| csv_path = matrix_dir / f"{matrix_name}.csv" | |
| npy_path = matrix_dir / f"{matrix_name}.npy" | |
| png_path = matrix_dir / f"{matrix_name}.png" | |
| np.savetxt(csv_path, matrix, delimiter=",", fmt="%.6f") | |
| np.save(npy_path, matrix) | |
| _write_tm_matrix_heatmap(path=png_path, matrix=matrix, title=title) | |
| return str(csv_path), str(npy_path), str(png_path) | |
| def _matrix_stats(matrix: np.ndarray) -> Dict[str, float]: | |
| flattened = matrix.reshape(-1) | |
| return { | |
| "mean": float(np.mean(flattened)), | |
| "median": float(np.median(flattened)), | |
| "min": float(np.min(flattened)), | |
| "max": float(np.max(flattened)), | |
| } | |
| def run_boltz2_compliance_suite(args: argparse.Namespace) -> int: | |
| if args.enforce_determinism: | |
| _enforce_determinism() | |
| login_if_needed(args.token) | |
| device = resolve_device(args.device) | |
| dtype = resolve_dtype(args.dtype, device) | |
| _seed_everything(seed=args.seed, workers=False) | |
| output_dir = build_output_dir(args.output_dir, "boltz2_compliance") | |
| checkpoint_path = _download_checkpoint_if_needed(Path(args.checkpoint_path)) | |
| sequences = SEQUENCE_OPTIONS | |
| supports_no_kernels = _detect_no_kernels_support() | |
| model = AutoModel.from_pretrained(args.repo_id, trust_remote_code=True) | |
| model = model.to(device=device, dtype=torch.float32).eval() | |
| rows: List[Dict[str, object]] = [] | |
| overall_pass = True | |
| for sequence_index, sequence in tqdm(list(enumerate(sequences)), desc="Boltz2 sequences", unit="seq"): | |
| started = time.perf_counter() | |
| row: Dict[str, object] = { | |
| "sequence_index": sequence_index, | |
| "sequence": sequence, | |
| "sequence_seed": args.seed + sequence_index, | |
| "ours_dtype_effective": str(dtype), | |
| "num_ours_samples": 0, | |
| "num_ref_samples": 0, | |
| "shared_atoms": 0, | |
| "shared_ca_atoms": 0, | |
| "coord_mae": float("nan"), | |
| "coord_rmse": float("nan"), | |
| "coord_max_abs": float("nan"), | |
| "coord_mae_aligned": float("nan"), | |
| "coord_rmse_aligned": float("nan"), | |
| "coord_max_abs_aligned": float("nan"), | |
| "pairwise_dist_mae": float("nan"), | |
| "plddt_mae": float("nan"), | |
| "ptm_abs_diff": float("nan"), | |
| "iptm_abs_diff": float("nan"), | |
| "complex_plddt_abs_diff": float("nan"), | |
| "confidence_score_abs_diff": float("nan"), | |
| "tm_cross_median": float("nan"), | |
| "tm_cross_mean": float("nan"), | |
| "tm_cross_min": float("nan"), | |
| "tm_cross_max": float("nan"), | |
| "tm_ref_within_median": float("nan"), | |
| "tm_ref_within_mean": float("nan"), | |
| "tm_ref_within_min": float("nan"), | |
| "tm_ref_within_max": float("nan"), | |
| "tm_ours_within_median": float("nan"), | |
| "tm_ours_within_mean": float("nan"), | |
| "tm_ours_within_min": float("nan"), | |
| "tm_ours_within_max": float("nan"), | |
| "tm_official_vs_ours_csv": "", | |
| "tm_official_vs_ours_npy": "", | |
| "tm_official_vs_ours_png": "", | |
| "tm_official_vs_official_csv": "", | |
| "tm_official_vs_official_npy": "", | |
| "tm_official_vs_official_png": "", | |
| "tm_ours_vs_ours_csv": "", | |
| "tm_ours_vs_ours_npy": "", | |
| "tm_ours_vs_ours_png": "", | |
| "ours_cif_path": "", | |
| "ref_cif_path": "", | |
| "pass": False, | |
| "seconds": 0.0, | |
| "error": "", | |
| } | |
| try: | |
| feats, template = build_boltz2_features( | |
| amino_acid_sequence=sequence, | |
| num_bins=model.config.num_bins, | |
| atoms_per_window_queries=model.core.input_embedder.atom_encoder.atoms_per_window_queries, | |
| ) | |
| feats_ours = _to_device(_clone_feats(feats), device=device, dtype=torch.float32) | |
| try: | |
| out_ours = _run_ours_forward( | |
| model=model, | |
| feats_ours=feats_ours, | |
| args=args, | |
| device=device, | |
| dtype=dtype, | |
| sequence_index=sequence_index, | |
| ) | |
| except RuntimeError as exc: | |
| error_text = str(exc) | |
| bf16_mismatch = "expected scalar type Float but found BFloat16" in error_text | |
| fp16_mismatch = "expected scalar type Float but found Half" in error_text | |
| if bf16_mismatch or fp16_mismatch: | |
| out_ours = _run_ours_forward( | |
| model=model, | |
| feats_ours=feats_ours, | |
| args=args, | |
| device=device, | |
| dtype=torch.float32, | |
| sequence_index=sequence_index, | |
| ) | |
| row["ours_dtype_effective"] = str(torch.float32) | |
| else: | |
| raise | |
| ours_atom_maps = _build_ours_atom_maps( | |
| sample_coords=out_ours["sample_atom_coords"], | |
| atom_mask=feats_ours["atom_pad_mask"][0], | |
| atom_names=template.atom_names, | |
| atom_residue_index=template.atom_residue_index, | |
| atom_chain_id=template.atom_chain_id, | |
| ) | |
| ref_atom_maps, ref_plddt_samples, ref_confidence_samples = _run_boltz_cli_reference( | |
| sequence=sequence, | |
| sequence_index=sequence_index, | |
| checkpoint_path=checkpoint_path, | |
| args=args, | |
| device=device, | |
| supports_no_kernels=supports_no_kernels, | |
| ) | |
| row["num_ours_samples"] = len(ours_atom_maps) | |
| row["num_ref_samples"] = len(ref_atom_maps) | |
| ours_atom_map_primary = ours_atom_maps[0] | |
| ref_atom_map_primary = ref_atom_maps[0] | |
| ref_plddt_primary = ref_plddt_samples[0].float().cpu().reshape(-1) | |
| ref_confidence_primary = ref_confidence_samples[0] | |
| shared_keys = [] | |
| for key in ours_atom_map_primary: | |
| if key in ref_atom_map_primary: | |
| shared_keys.append(key) | |
| shared_keys.sort() | |
| assert len(shared_keys) > 0, "No overlapping atom keys between our output and pip boltz CLI output." | |
| row["shared_atoms"] = len(shared_keys) | |
| shared_ca_keys = _shared_ca_key_order(ours_atom_maps=ours_atom_maps, ref_atom_maps=ref_atom_maps) | |
| row["shared_ca_atoms"] = len(shared_ca_keys) | |
| ours_coords_stack = torch.stack([ours_atom_map_primary[key] for key in shared_keys], dim=0) | |
| ref_coords_stack = torch.stack([ref_atom_map_primary[key] for key in shared_keys], dim=0) | |
| coord_mae, coord_rmse, coord_max_abs = _vector_metrics(ours_coords_stack, ref_coords_stack) | |
| row["coord_mae"] = coord_mae | |
| row["coord_rmse"] = coord_rmse | |
| row["coord_max_abs"] = coord_max_abs | |
| ours_coords_aligned = _kabsch_align_mobile_to_target(ours_coords_stack, ref_coords_stack) | |
| coord_mae_aligned, coord_rmse_aligned, coord_max_abs_aligned = _vector_metrics( | |
| ours_coords_aligned, | |
| ref_coords_stack, | |
| ) | |
| row["coord_mae_aligned"] = coord_mae_aligned | |
| row["coord_rmse_aligned"] = coord_rmse_aligned | |
| row["coord_max_abs_aligned"] = coord_max_abs_aligned | |
| row["pairwise_dist_mae"] = _pairwise_distance_mae(ours_coords_stack, ref_coords_stack) | |
| ours_plddt = _extract_primary_plddt_vector(out_ours, feats_ours) | |
| assert ours_plddt.numel() == ref_plddt_primary.numel(), ( | |
| f"pLDDT size mismatch (ours={ours_plddt.numel()}, ref={ref_plddt_primary.numel()})." | |
| ) | |
| row["plddt_mae"] = float(torch.mean(torch.abs(ours_plddt - ref_plddt_primary)).item()) | |
| ours_ptm = _summary_metric(out_ours["ptm"]).float().cpu() | |
| ours_iptm = _summary_metric(out_ours["iptm"]).float().cpu() | |
| ours_complex_plddt = _summary_metric(out_ours["complex_plddt"]).float().cpu() | |
| ours_confidence_score = _compute_confidence_score( | |
| ptm=ours_ptm, | |
| iptm=ours_iptm, | |
| complex_plddt=ours_complex_plddt, | |
| ) | |
| ref_ptm = torch.tensor([float(ref_confidence_primary["ptm"])], dtype=torch.float32) | |
| ref_iptm = torch.tensor([float(ref_confidence_primary["iptm"])], dtype=torch.float32) | |
| ref_complex_plddt = torch.tensor([float(ref_confidence_primary["complex_plddt"])], dtype=torch.float32) | |
| ref_confidence_score = torch.tensor([float(ref_confidence_primary["confidence_score"])], dtype=torch.float32) | |
| row["ptm_abs_diff"] = float(torch.mean(torch.abs(ours_ptm - ref_ptm)).item()) | |
| row["iptm_abs_diff"] = float(torch.mean(torch.abs(ours_iptm - ref_iptm)).item()) | |
| row["complex_plddt_abs_diff"] = float( | |
| torch.mean(torch.abs(ours_complex_plddt - ref_complex_plddt)).item() | |
| ) | |
| row["confidence_score_abs_diff"] = float( | |
| torch.mean(torch.abs(ours_confidence_score - ref_confidence_score)).item() | |
| ) | |
| if args.write_cif_artifacts: | |
| structure_dir = output_dir / "structures" / f"seq_{sequence_index}" | |
| structure_dir.mkdir(parents=True, exist_ok=True) | |
| ours_cif_path = structure_dir / f"ours_seq{sequence_index}.cif" | |
| write_cif( | |
| structure_template=template, | |
| atom_coords=out_ours["sample_atom_coords"].detach().cpu(), | |
| atom_mask=feats_ours["atom_pad_mask"][0].detach().cpu(), | |
| output_path=str(ours_cif_path), | |
| plddt=out_ours["plddt"].detach().cpu() if "plddt" in out_ours else None, | |
| sample_index=0, | |
| ) | |
| row["ours_cif_path"] = str(ours_cif_path) | |
| ref_coords_cif, ref_atom_mask_cif = _build_reference_cif_tensors( | |
| template=template, | |
| atom_pad_mask=feats_ours["atom_pad_mask"][0].detach().cpu(), | |
| ref_atom_map=ref_atom_map_primary, | |
| ) | |
| ref_cif_path = structure_dir / f"ref_seq{sequence_index}.cif" | |
| write_cif( | |
| structure_template=template, | |
| atom_coords=ref_coords_cif, | |
| atom_mask=ref_atom_mask_cif, | |
| output_path=str(ref_cif_path), | |
| plddt=ref_plddt_primary, | |
| sample_index=0, | |
| ) | |
| row["ref_cif_path"] = str(ref_cif_path) | |
| ours_ca_stack = _stack_coords_for_keys(atom_maps=ours_atom_maps, ordered_keys=shared_ca_keys) | |
| ref_ca_stack = _stack_coords_for_keys(atom_maps=ref_atom_maps, ordered_keys=shared_ca_keys) | |
| tm_official_vs_ours = _build_tm_matrix( | |
| reference_stack=ref_ca_stack, | |
| subject_stack=ours_ca_stack, | |
| symmetric=False, | |
| ) | |
| tm_official_vs_official = _build_tm_matrix( | |
| reference_stack=ref_ca_stack, | |
| subject_stack=ref_ca_stack, | |
| symmetric=True, | |
| ) | |
| tm_ours_vs_ours = _build_tm_matrix( | |
| reference_stack=ours_ca_stack, | |
| subject_stack=ours_ca_stack, | |
| symmetric=True, | |
| ) | |
| matrix_dir = output_dir / "tm_matrices" / f"seq_{sequence_index}" | |
| matrix_dir.mkdir(parents=True, exist_ok=True) | |
| csv_path, npy_path, png_path = _write_tm_matrix_artifacts( | |
| matrix_dir=matrix_dir, | |
| matrix_name="official_vs_ours", | |
| title=f"Sequence {sequence_index}: official vs ours TM-score", | |
| matrix=tm_official_vs_ours, | |
| ) | |
| row["tm_official_vs_ours_csv"] = csv_path | |
| row["tm_official_vs_ours_npy"] = npy_path | |
| row["tm_official_vs_ours_png"] = png_path | |
| csv_path, npy_path, png_path = _write_tm_matrix_artifacts( | |
| matrix_dir=matrix_dir, | |
| matrix_name="official_vs_official", | |
| title=f"Sequence {sequence_index}: official vs official TM-score", | |
| matrix=tm_official_vs_official, | |
| ) | |
| row["tm_official_vs_official_csv"] = csv_path | |
| row["tm_official_vs_official_npy"] = npy_path | |
| row["tm_official_vs_official_png"] = png_path | |
| csv_path, npy_path, png_path = _write_tm_matrix_artifacts( | |
| matrix_dir=matrix_dir, | |
| matrix_name="ours_vs_ours", | |
| title=f"Sequence {sequence_index}: ours vs ours TM-score", | |
| matrix=tm_ours_vs_ours, | |
| ) | |
| row["tm_ours_vs_ours_csv"] = csv_path | |
| row["tm_ours_vs_ours_npy"] = npy_path | |
| row["tm_ours_vs_ours_png"] = png_path | |
| cross_stats = _matrix_stats(tm_official_vs_ours) | |
| row["tm_cross_mean"] = cross_stats["mean"] | |
| row["tm_cross_median"] = cross_stats["median"] | |
| row["tm_cross_min"] = cross_stats["min"] | |
| row["tm_cross_max"] = cross_stats["max"] | |
| ref_within_stats = _matrix_stats(tm_official_vs_official) | |
| row["tm_ref_within_mean"] = ref_within_stats["mean"] | |
| row["tm_ref_within_median"] = ref_within_stats["median"] | |
| row["tm_ref_within_min"] = ref_within_stats["min"] | |
| row["tm_ref_within_max"] = ref_within_stats["max"] | |
| ours_within_stats = _matrix_stats(tm_ours_vs_ours) | |
| row["tm_ours_within_mean"] = ours_within_stats["mean"] | |
| row["tm_ours_within_median"] = ours_within_stats["median"] | |
| row["tm_ours_within_min"] = ours_within_stats["min"] | |
| row["tm_ours_within_max"] = ours_within_stats["max"] | |
| row["pass"] = bool(row["tm_cross_median"] >= args.tm_pass_threshold) | |
| if row["pass"] is False: | |
| overall_pass = False | |
| except Exception as exc: | |
| row["error"] = str(exc) | |
| overall_pass = False | |
| finally: | |
| row["seconds"] = round(time.perf_counter() - started, 4) | |
| rows.append(row) | |
| payload: Dict[str, object] = { | |
| "suite": "boltz2_compliance", | |
| "all_passed": overall_pass, | |
| "repo_id": args.repo_id, | |
| "checkpoint_path": str(checkpoint_path), | |
| "device": str(device), | |
| "dtype": str(dtype), | |
| "seed": args.seed, | |
| "enforce_determinism": args.enforce_determinism, | |
| "write_cif_artifacts": args.write_cif_artifacts, | |
| "num_sequences": len(sequences), | |
| "recycling_steps": BOLTZ2_FIXED_RECYCLING_STEPS, | |
| "num_sampling_steps": BOLTZ2_FIXED_SAMPLING_STEPS, | |
| "diffusion_samples": BOLTZ2_FIXED_DIFFUSION_SAMPLES, | |
| "tm_pass_threshold": args.tm_pass_threshold, | |
| "supports_no_kernels": supports_no_kernels, | |
| "rows": rows, | |
| } | |
| write_json(output_dir / "metrics.json", payload) | |
| write_csv(output_dir / "metrics.csv", rows) | |
| passed_count = 0 | |
| for row in rows: | |
| if bool(row["pass"]): | |
| passed_count += 1 | |
| summary_lines = [ | |
| "Suite: boltz2_compliance", | |
| f"Sequences tested: {len(rows)}", | |
| f"Sequences passed: {passed_count}", | |
| f"Sequences failed: {len(rows) - passed_count}", | |
| f"Output directory: {output_dir}", | |
| f"Device: {device}", | |
| f"Dtype: {dtype}", | |
| f"Recycling steps (fixed): {BOLTZ2_FIXED_RECYCLING_STEPS}", | |
| f"Sampling steps (fixed): {BOLTZ2_FIXED_SAMPLING_STEPS}", | |
| f"Diffusion samples (fixed): {BOLTZ2_FIXED_DIFFUSION_SAMPLES}", | |
| f"TM pass threshold: {args.tm_pass_threshold}", | |
| f"Write CIF artifacts: {args.write_cif_artifacts}", | |
| f"Reference CLI supports --no_kernels: {supports_no_kernels}", | |
| ] | |
| for row in rows: | |
| status = "PASS" if bool(row["pass"]) else "FAIL" | |
| summary_lines.append( | |
| f"{status} | idx={row['sequence_index']} | seed={row['sequence_seed']} | " | |
| f"ours_dtype={row['ours_dtype_effective']} | shared_atoms={row['shared_atoms']} | " | |
| f"shared_ca={row['shared_ca_atoms']} | tm_cross_median={row['tm_cross_median']} | " | |
| f"tm_ref_within_median={row['tm_ref_within_median']} | " | |
| f"tm_ours_within_median={row['tm_ours_within_median']} | " | |
| f"coord_aln_mae={row['coord_mae_aligned']} | plddt_mae={row['plddt_mae']} | " | |
| f"official_vs_ours_csv={row['tm_official_vs_ours_csv']} | " | |
| f"official_vs_official_csv={row['tm_official_vs_official_csv']} | " | |
| f"ours_vs_ours_csv={row['tm_ours_vs_ours_csv']} | " | |
| f"ours_cif={row['ours_cif_path']} | ref_cif={row['ref_cif_path']} | error={row['error']}" | |
| ) | |
| write_summary(output_dir / "summary.txt", summary_lines) | |
| print("\n".join(summary_lines)) | |
| if overall_pass: | |
| return 0 | |
| return 1 | |
| def build_parser() -> argparse.ArgumentParser: | |
| parser = argparse.ArgumentParser(description="Run Boltz2 compliance test against pip boltz CLI outputs.") | |
| parser.add_argument("--token", type=str, default=None) | |
| parser.add_argument("--repo-id", type=str, default="Synthyra/Boltz2") | |
| parser.add_argument("--checkpoint-path", type=str, default="boltz_fastplms/weights/boltz2_conf.ckpt") | |
| parser.add_argument("--device", type=str, default="auto") | |
| parser.add_argument("--dtype", type=str, default="float32", choices=["auto", "float32", "float16", "bfloat16"]) | |
| parser.add_argument("--seed", type=int, default=42) | |
| parser.add_argument("--enforce-determinism", action=argparse.BooleanOptionalAction, default=True) | |
| parser.add_argument("--write-cif-artifacts", action=argparse.BooleanOptionalAction, default=True) | |
| parser.add_argument("--pass-coord-metric", type=str, default="aligned", choices=["raw", "aligned"]) | |
| parser.add_argument("--run-confidence-sequentially", action="store_true") | |
| parser.add_argument("--coord-mae-threshold", type=float, default=5e-3) | |
| parser.add_argument("--coord-rmse-threshold", type=float, default=5e-3) | |
| parser.add_argument("--coord-max-abs-threshold", type=float, default=5e-2) | |
| parser.add_argument("--plddt-mae-threshold", type=float, default=5e-3) | |
| parser.add_argument("--summary-metric-abs-threshold", type=float, default=5e-3) | |
| parser.add_argument("--tm-pass-threshold", type=float, default=0.60) | |
| parser.add_argument("--output-dir", type=str, default=None) | |
| return parser | |
| def main(argv: List[str] | None = None) -> int: | |
| parser = build_parser() | |
| args = parser.parse_args(argv) | |
| return run_boltz2_compliance_suite(args) | |
| if __name__ == "__main__": | |
| raise SystemExit(main()) | |