# AutoTune Models
Trained models from the AutoTune hyperparameter optimization study for astronomical transient classification.
## Models
| Model | Architecture | Checkpoint | Val AUC |
|---|---|---|---|
| autotune_btsbot_optuna_asha | DeiT3 | DeiT3-epoch=14-val_auc=0.9999.ckpt | 0.9999 |
| autotune_btsbot_optuna_fifo | DeiT3 | DeiT3-epoch=15-val_auc=0.9995.ckpt | 0.9995 |
| autotune_btsbot_optuna_hyperband | DeiT3 | DeiT3-epoch=17-val_auc=0.9999.ckpt | 0.9999 |
| autotune_btsbot_optuna_median | DeiT3 | DeiT3-epoch=19-val_auc=0.9996.ckpt | 0.9996 |
| autotune_btsbot_random_asha | DeiT | DeiT-epoch=14-val_auc=0.9995.ckpt | 0.9995 |
| autotune_btsbot_random_hyperband | CaiT | CaiT-epoch=14-val_auc=0.9997.ckpt | 0.9997 |
## Usage

### Load from Hugging Face Hub
```python
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import timm

# Download model weights
model_path = hf_hub_download(
    repo_id="parlange/autotune-models",
    filename="autotune_btsbot_optuna_asha/model.safetensors"
)

# Load weights
state_dict = load_file(model_path)

# Create the matching architecture (DeiT3 example; see the table above)
model = timm.create_model("deit3_base_patch16_224", pretrained=False, num_classes=2)

# strict=False tolerates minor key mismatches; inspect the returned
# missing/unexpected keys if predictions look wrong
model.load_state_dict(state_dict, strict=False)
model.eval()
```
### Google Cloud / Colab
```python
!pip install huggingface_hub safetensors timm torch torchvision

from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import timm
import torch
from torchvision import transforms
from PIL import Image

# Download model weights
model_path = hf_hub_download(
    repo_id="parlange/autotune-models",
    filename="autotune_btsbot_optuna_asha/model.safetensors"
)

# Load model
state_dict = load_file(model_path)
model = timm.create_model("deit3_base_patch16_224", pretrained=False, num_classes=2)
model.load_state_dict(state_dict, strict=False)
model.eval()

# Preprocessing (ImageNet normalization)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Inference on your triplet image (3-channel: science, reference, difference)
# image = Image.open("triplet.png").convert("RGB")
# input_tensor = transform(image).unsqueeze(0)
# with torch.no_grad():
#     output = model(input_tensor)
# prediction = torch.softmax(output, dim=1)
# print(f"Real probability: {prediction[0, 1]:.4f}")  # assumes index 1 = "real"
```
### Load Lightning Checkpoint
```python
import torch
import timm

# Lightning checkpoints pickle non-tensor objects, so PyTorch >= 2.6
# needs weights_only=False here
checkpoint = torch.load("checkpoint.ckpt", map_location="cpu", weights_only=False)
state_dict = checkpoint["state_dict"]

# Remove the 'model.' prefix Lightning adds to the wrapped module's keys
state_dict = {k.removeprefix("model."): v for k, v in state_dict.items()}

# Load into the matching timm architecture (DeiT3 example)
model = timm.create_model("deit3_base_patch16_224", pretrained=False, num_classes=2)
model.load_state_dict(state_dict, strict=False)
```
## HPO Study Details
These models were trained with Ray Tune using the following combinations of search algorithm and scheduler (a configuration sketch follows the list):
- Optuna + ASHA: Bayesian optimization with aggressive early stopping
- Optuna + FIFO: Bayesian optimization, all trials run to completion
- Optuna + HyperBand: Bayesian optimization with HyperBand scheduling
- Optuna + Median: Bayesian optimization with median stopping rule
- Random + ASHA: Random search with ASHA early stopping
- Random + HyperBand: Random search with HyperBand scheduling
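As a rough illustration, here is how one of these combinations (Optuna search with the ASHA scheduler) might be wired up in Ray Tune. The training function, search space, and sample count below are placeholders, not the study's actual configuration, and the metric-reporting API varies slightly across Ray versions.

```python
from ray import tune
from ray.tune.schedulers import ASHAScheduler
from ray.tune.search.optuna import OptunaSearch

# Placeholder trainable: substitute the real Lightning training loop.
def train_fn(config):
    for epoch in range(20):
        val_auc = 0.5 + 0.025 * epoch  # stand-in for a real validation AUC
        tune.report({"val_auc": val_auc})  # lets ASHA stop weak trials early

# Hypothetical search space; the study's actual ranges are not documented here.
search_space = {
    "lr": tune.loguniform(1e-5, 1e-2),
    "weight_decay": tune.loguniform(1e-6, 1e-2),
}

tuner = tune.Tuner(
    train_fn,
    param_space=search_space,
    tune_config=tune.TuneConfig(
        metric="val_auc",
        mode="max",
        search_alg=OptunaSearch(),                # Bayesian (TPE) trial suggestion
        scheduler=ASHAScheduler(grace_period=1),  # aggressive early stopping
        num_samples=20,
    ),
)
results = tuner.fit()
print(results.get_best_result().config)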
## Dataset
Trained on the BTSBot dataset for real/bogus classification of astronomical transients from the Zwicky Transient Facility (ZTF).
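For reference, a minimal sketch of loading this dataset with the Hugging Face `datasets` library, assuming the `MultimodalUniverse/btsbot` Hub identifier cited below; the available splits and field names depend on the published schema.

```python
from datasets import load_dataset

# Stream the dataset to avoid downloading everything up front,
# then inspect one sample to discover the actual field names.
ds = load_dataset("MultimodalUniverse/btsbot", split="train", streaming=True)
print(next(iter(ds)).keys())
```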
## Citation
If you use these models, please cite:
- AutoTune Repository: https://github.com/parlange/autotune
- BTSBot Dataset: MultimodalUniverse/btsbot