# AutoTune Models

Trained models from the AutoTune hyperparameter optimization study for astronomical transient classification.

## Models

| Model | Architecture | Checkpoint | Val AUC |
|-------|--------------|------------|---------|
| `autotune_btsbot_optuna_asha` | DeiT3 | `DeiT3-epoch=14-val_auc=0.9999.ckpt` | 0.9999 |
| `autotune_btsbot_optuna_fifo` | DeiT3 | `DeiT3-epoch=15-val_auc=0.9995.ckpt` | 0.9995 |
| `autotune_btsbot_optuna_hyperband` | DeiT3 | `DeiT3-epoch=17-val_auc=0.9999.ckpt` | 0.9999 |
| `autotune_btsbot_optuna_median` | DeiT3 | `DeiT3-epoch=19-val_auc=0.9996.ckpt` | 0.9996 |
| `autotune_btsbot_random_asha` | DeiT | `DeiT-epoch=14-val_auc=0.9995.ckpt` | 0.9995 |
| `autotune_btsbot_random_hyperband` | CaiT | `CaiT-epoch=14-val_auc=0.9997.ckpt` | 0.9997 |
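
To check exactly which files ship in each study folder (the `model.safetensors` weights and the Lightning checkpoints listed above), you can enumerate the repository; a small sketch using `huggingface_hub.list_repo_files`:

```python
from huggingface_hub import list_repo_files

# Print every file in the model repository so the per-study
# weight and checkpoint paths can be located.
for path in list_repo_files("parlange/autotune-models"):
    print(path)
```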

## Usage

### Load from Hugging Face Hub

```python
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import timm

# Download model weights
model_path = hf_hub_download(
    repo_id="parlange/autotune-models",
    filename="autotune_btsbot_optuna_asha/model.safetensors"
)

# Load weights
state_dict = load_file(model_path)

# Create model architecture (DeiT3 example)
model = timm.create_model("deit3_base_patch16_224", pretrained=False, num_classes=2)
model.load_state_dict(state_dict, strict=False)
model.eval()
```
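
The snippet above targets the DeiT3 checkpoints. For the other rows in the table, the exact timm variants are not stated in this card, so the mapping below is an assumption (base-size models) that should be verified against the checkpoint weights:

```python
# Assumed timm model names for each architecture in the table above;
# "DeiT3" matches the snippet, while "DeiT" and "CaiT" are guesses at
# base-size variants — verify against the saved weight shapes.
TIMM_NAMES = {
    "DeiT3": "deit3_base_patch16_224",
    "DeiT": "deit_base_patch16_224",
    "CaiT": "cait_s24_224",
}
```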

### Google Cloud / Colab

```python
!pip install huggingface_hub safetensors timm torch torchvision
```

```python
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import timm
import torch
from torchvision import transforms
from PIL import Image

# Download model
model_path = hf_hub_download(
    repo_id="parlange/autotune-models",
    filename="autotune_btsbot_optuna_asha/model.safetensors"
)

# Load model
state_dict = load_file(model_path)
model = timm.create_model("deit3_base_patch16_224", pretrained=False, num_classes=2)
model.load_state_dict(state_dict, strict=False)
model.eval()

# Inference
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load your triplet image (3-channel: science, reference, difference)
# image = Image.open("triplet.png").convert("RGB")
# input_tensor = transform(image).unsqueeze(0)
# with torch.no_grad():
#     output = model(input_tensor)
#     prediction = torch.softmax(output, dim=1)
#     print(f"Real probability: {prediction[0, 1]:.4f}")
```

### Load Lightning Checkpoint

```python
import timm
import torch

checkpoint = torch.load("checkpoint.ckpt", map_location="cpu")
state_dict = checkpoint["state_dict"]

# Strip the 'model.' prefix Lightning adds to submodule keys
# (removeprefix only touches the start of each key, unlike str.replace)
state_dict = {k.removeprefix("model."): v for k, v in state_dict.items()}

# Apply the cleaned weights to the matching timm architecture
model = timm.create_model("deit3_base_patch16_224", pretrained=False, num_classes=2)
model.load_state_dict(state_dict, strict=False)
model.eval()
```

## HPO Study Details

These models were trained using Ray Tune, pairing two search algorithms with different trial schedulers (a minimal configuration sketch follows the list):

- **Optuna + ASHA**: Bayesian optimization with aggressive early stopping
- **Optuna + FIFO**: Bayesian optimization, all trials run to completion
- **Optuna + HyperBand**: Bayesian optimization with HyperBand scheduling
- **Optuna + Median**: Bayesian optimization with median stopping rule
- **Random + ASHA**: Random search with ASHA early stopping
- **Random + HyperBand**: Random search with HyperBand scheduling
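
As a rough illustration of how one of these pairings is wired up, here is a minimal Ray Tune sketch (assuming Ray 2.x) combining `OptunaSearch` with the `ASHAScheduler`. The search space, the dummy objective, and all numeric settings are placeholders, not the study's actual configuration:

```python
from ray import train, tune
from ray.tune.schedulers import ASHAScheduler
from ray.tune.search.optuna import OptunaSearch

def objective(config):
    # Stand-in trial: a real objective would train the classifier and
    # report validation AUC each epoch so ASHA can stop weak trials early.
    for epoch in range(20):
        val_auc = 1.0 - abs(config["lr"] - 1e-3)  # dummy metric
        train.report({"val_auc": val_auc})

tuner = tune.Tuner(
    objective,
    param_space={"lr": tune.loguniform(1e-5, 1e-2)},
    tune_config=tune.TuneConfig(
        metric="val_auc",
        mode="max",
        search_alg=OptunaSearch(),                          # Bayesian (TPE) suggestions
        scheduler=ASHAScheduler(max_t=20, grace_period=2),  # aggressive early stopping
        num_samples=20,
    ),
)
results = tuner.fit()
print(results.get_best_result().config)
```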

## Dataset

Trained on the BTSBot dataset for real/bogus classification of astronomical transients from the Zwicky Transient Facility (ZTF).
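
As context for the commented triplet-loading step in the Colab snippet, here is one way such a triplet could be assembled into the expected 3-channel tensor. The 63×63 cutout size and the random placeholder arrays are assumptions, not the dataset's documented format:

```python
import numpy as np
import torch
import torch.nn.functional as F

# Hypothetical science/reference/difference cutouts; BTSBot-style
# triplets are small single-band image stamps (size assumed here).
science = np.random.rand(63, 63).astype(np.float32)
reference = np.random.rand(63, 63).astype(np.float32)
difference = np.random.rand(63, 63).astype(np.float32)

# Stack as channels, then batch and resize for a 224x224 ViT input
triplet = np.stack([science, reference, difference], axis=0)  # (3, H, W)
x = torch.from_numpy(triplet).unsqueeze(0)                    # (1, 3, 63, 63)
x = F.interpolate(x, size=(224, 224), mode="bilinear", align_corners=False)
```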

## Citation

If you use these models, please cite:
