This README was auto-generated by the HF Job run linked below. The whole repository is a reproducible artifact of that Job.

Ahead-of-time (AoT) repository

AoT repos contain pre-compiled binaries of PyTorch models, enabling:

  • fast startup times (no torch.compile needed)
  • faster inference (the speedup measured for this artifact is listed under Samples)
  • ZeroGPU compatibility

How to use


import gc
from contextlib import contextmanager

import spaces
import torch
import torch.utils._pytree as pytree
from huggingface_hub import hf_hub_download, snapshot_download

from ltx_pipelines.distilled import DistilledPipeline
from ltx_pipelines.utils.denoisers import SimpleDenoiser
from ltx_pipelines.utils.types import ModalitySpec
from ltx_pipelines.utils.constants import DISTILLED_SIGMAS, STAGE_2_DISTILLED_SIGMAS
from ltx_pipelines.utils.helpers import combined_image_conditionings
from ltx_pipelines.utils.media_io import encode_video
from ltx_core.model.video_vae import TilingConfig, get_video_chunks_number
from ltx_core.model.transformer.transformer_args import TransformerArgs
from ltx_core.components.noisers import GaussianNoiser


import os as _os
# Reduce CUDA-allocator fragmentation; the LTX-2 distilled pipeline exhausts a
# 96GB GPU otherwise.
_os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

# cuDNN 9.20 + cuBLAS 13.x mismatch causes a `cublasLtGetVersion` segfault at
# teardown on cu130. The transformer block has no convs anyway.
torch.backends.cudnn.enabled = False


# Register TransformerArgs as a pytree node so torch.export can flatten the
# dataclass at the export boundary. We exclude `enabled` (bool) so AOTI's
# boxed_run gets only tensor leaves — otherwise the unflattening at runtime
# produces 2 extra non-tensor leaves that the compiled artifact doesn't expect.
_TA_TENSOR_FIELDS = [
    "x", "context", "context_mask",
    "timesteps", "embedded_timestep",
    "positional_embeddings", "cross_positional_embeddings",
    "cross_scale_shift_timestep", "cross_gate_timestep",
    "prompt_timestep", "self_attention_mask",
]


def _ta_flatten(obj):
    values, nones = [], []
    for name in _TA_TENSOR_FIELDS:
        v = getattr(obj, name)
        (nones if v is None else values).append((name, v))
    flat_names = [n for n, _ in values]
    none_names = [n for n, _ in nones]
    return [v for _, v in values], (flat_names, none_names)


def _ta_unflatten(values, context):
    flat_names, none_names = context
    kwargs = dict(zip(flat_names, values, strict=True))
    for n in none_names:
        kwargs[n] = None
    kwargs["enabled"] = True
    return TransformerArgs(**kwargs)


def _ta_flatten_with_keys(obj):
    flat, ctx = _ta_flatten(obj)
    flat_names, _ = ctx
    return [(pytree.GetAttrKey(n), v) for n, v in zip(flat_names, flat, strict=True)], ctx


pytree.register_pytree_node(
    TransformerArgs,
    _ta_flatten,
    _ta_unflatten,
    serialized_type_name="ltx_core.model.transformer.transformer_args.TransformerArgs",
    flatten_with_keys_fn=_ta_flatten_with_keys,
)


distilled_path = hf_hub_download("Lightricks/LTX-2", "ltx-2-19b-distilled.safetensors")
upsampler_path = hf_hub_download("Lightricks/LTX-2", "ltx-2-spatial-upscaler-x2-1.0.safetensors")
gemma_root = snapshot_download("google/gemma-3-12b-it-qat-q4_0-unquantized")


pipeline = DistilledPipeline(
    distilled_checkpoint_path=distilled_path,
    gemma_root=gemma_root,
    spatial_upsampler_path=upsampler_path,
    loras=[],
    device=torch.device("cuda"),
)
# DiffusionStage rebuilds the transformer on every __call__; we want a single
# persistent instance so the AOTI artifact survives between samples.
pipeline._transformer = pipeline.stage._build_transformer()


@contextmanager
def _persistent_transformer_ctx(**_kwargs):
    yield pipeline._transformer


pipeline.stage._transformer_ctx = _persistent_transformer_ctx

spaces.aoti_load(
    module=pipeline._transformer,
    repo_id='cbensimon/X0Model-sm120-cu130-rd1',
)
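
The flatten/unflatten pair registered above can be hard to follow without running it. Below is a minimal pure-Python sketch of the same round trip, assuming a hypothetical stand-in dataclass (`ArgsSketch`) with only three fields and plain lists in place of tensors. It does not use `torch.utils._pytree` or the real `TransformerArgs`; it only illustrates how `None` fields move into the context and why `enabled` always comes back as `True`.

```python
from dataclasses import dataclass
from typing import Any, Optional


# Hypothetical stand-in for TransformerArgs (illustration only).
@dataclass
class ArgsSketch:
    x: Any
    context: Optional[Any] = None
    enabled: bool = True


# Fields treated as leaves; `enabled` is deliberately left out.
TENSOR_FIELDS = ["x", "context"]


def flatten(obj):
    # Non-None fields become leaves; names of None fields go into the context.
    present = [(n, getattr(obj, n)) for n in TENSOR_FIELDS if getattr(obj, n) is not None]
    absent = [n for n in TENSOR_FIELDS if getattr(obj, n) is None]
    return [v for _, v in present], ([n for n, _ in present], absent)


def unflatten(values, context):
    flat_names, none_names = context
    kwargs = dict(zip(flat_names, values))
    kwargs.update({n: None for n in none_names})
    kwargs["enabled"] = True  # not a leaf, so it is always restored as True
    return ArgsSketch(**kwargs)


original = ArgsSketch(x=[1.0, 2.0], context=None, enabled=False)
leaves, ctx = flatten(original)
restored = unflatten(leaves, ctx)

print(leaves)    # [[1.0, 2.0]]
print(ctx)       # (['x'], ['context'])
print(restored)  # ArgsSketch(x=[1.0, 2.0], context=None, enabled=True)
```

Note that `enabled=False` on the original object is not preserved by the round trip. In this sketch, as in the registration above, the flag is hard-coded on unflatten, so callers should not rely on it crossing the export boundary.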

How to reproduce or customize

# Install hf CLI
curl -LsSf https://hf.co/cli/install.sh | bash

# Login
hf auth login

# Get the job file and edit its user section if needed
hf download cbensimon/X0Model-sm120-cu130-rd1 job.py --local-dir .

# Run the job and change flavor or image if needed
hf jobs uv run job.py \
    --flavor rtx-pro-6000 \
    --image pytorch/pytorch:2.9.1-cuda13.0-cudnn9-devel \
    --secrets HF_TOKEN

# Or run locally with Docker
docker run --rm --gpus all \
    -v "$PWD/job.py:/workspace/job.py" \
    -e HF_TOKEN=$(hf auth token) \
    -e JOB_IMAGE=pytorch/pytorch:2.9.1-cuda13.0-cudnn9-devel \
    -e JOB_FLAVOR=rtx-pro-6000 \
    pytorch/pytorch:2.9.1-cuda13.0-cudnn9-devel \
    uv run /workspace/job.py

The following job environment variables can be used to customize the repo name generation:

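  • OUTPUT_REPO_NAMESPACE: namespace of the output repo; if unset, it is taken from HF_TOKEN
  • OUTPUT_REPO_BASE_NAME: base name of the output repo; defaults to the module class name
  • OUTPUT_REPO_ID: full repo ID; when set, it overrides name generation entirely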
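
The precedence among these variables can be sketched as follows. This is a hypothetical illustration, not the logic in job.py: the function name `resolve_repo_id` and its parameters are assumptions, and any suffix the job appends to the generated name (such as hardware or CUDA tags) is left out.

```python
def resolve_repo_id(env, default_namespace, default_base_name):
    """Hypothetical precedence sketch; the real job.py may differ."""
    # OUTPUT_REPO_ID wins outright when set.
    repo_id = env.get("OUTPUT_REPO_ID")
    if repo_id:
        return repo_id
    # Otherwise each part falls back to its default independently.
    namespace = env.get("OUTPUT_REPO_NAMESPACE") or default_namespace
    base_name = env.get("OUTPUT_REPO_BASE_NAME") or default_base_name
    return f"{namespace}/{base_name}"


# "alice" stands in for the namespace derived from HF_TOKEN,
# "X0Model" for the module class name.
print(resolve_repo_id({}, "alice", "X0Model"))
# alice/X0Model
print(resolve_repo_id({"OUTPUT_REPO_NAMESPACE": "my-org"}, "alice", "X0Model"))
# my-org/X0Model
print(resolve_repo_id({"OUTPUT_REPO_ID": "my-org/custom"}, "alice", "X0Model"))
# my-org/custom
```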

Samples

Generated as part of the compilation job, before and after compilation:

  • Before compilation: 3.27s
  • After compilation: 3.20s
  • Speedup: 1.02x

Note that this might not always reflect the actual performance gain.
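
The reported speedup is the ratio of the two timings. A short check, using only the two numbers listed above (the variable names are illustrative):

```python
before_s = 3.27  # timing reported before compilation, in seconds
after_s = 3.20   # timing reported after compilation, in seconds

speedup = before_s / after_s
time_saved_pct = 100 * (before_s - after_s) / before_s

print(f"Speedup: {speedup:.2f}x")            # Speedup: 1.02x
print(f"Time saved: {time_saved_pct:.1f}%")  # Time saved: 2.1%
```

A single pair of timings like this is sensitive to run-to-run variation, so a difference of about 2% should be read with caution.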

Environment

PyTorch version: 2.12.0+cu130
Is debug build: False
CUDA used to build PyTorch: 13.0
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.5 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 4.1.2
Libc version: glibc-2.35

Python version: 3.10.19 (main, Oct 31 2025, 23:02:46) [Clang 21.1.4 ] (64-bit runtime)
Python platform: Linux-6.17.0-1013-aws-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 13.0.48
CUDA_MODULE_LOADING set to: 
GPU models and configuration: GPU 0: NVIDIA RTX PRO 6000 Blackwell Server Edition
Nvidia driver version: 595.58.03
cuDNN version: Could not collect
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: {'PYTORCH_CUDA_ALLOC_CONF': 'expandable_segments:True'}

CPU:
Architecture:                            x86_64
CPU op-mode(s):                          32-bit, 64-bit
Address sizes:                           46 bits physical, 48 bits virtual
Byte Order:                              Little Endian
CPU(s):                                  16
On-line CPU(s) list:                     0-15
Vendor ID:                               GenuineIntel
Model name:                              Intel(R) Xeon(R) Platinum 8559C
CPU family:                              6
Model:                                   207
Thread(s) per core:                      2
Core(s) per socket:                      8
Socket(s):                               1
Stepping:                                2
BogoMIPS:                                4800.00
Flags:                                   fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16 pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch cpuid_fault ssbd ibrs ibpb stibp ibrs_enhanced fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves avx_vnni avx512_bf16 wbnoinvd ida arat avx512vbmi umip pku ospke waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg tme avx512_vpopcntdq rdpid cldemote movdiri movdir64b md_clear serialize amx_bf16 avx512_fp16 amx_tile amx_int8 flush_l1d arch_capabilities
Hypervisor vendor:                       KVM
Virtualization type:                     full
L1d cache:                               384 KiB (8 instances)
L1i cache:                               256 KiB (8 instances)
L2 cache:                                16 MiB (8 instances)
L3 cache:                                320 MiB (1 instance)
NUMA node(s):                            1
NUMA node0 CPU(s):                       0-15
Vulnerability Gather data sampling:      Not affected
Vulnerability Ghostwrite:                Not affected
Vulnerability Indirect target selection: Not affected
Vulnerability Itlb multihit:             Not affected
Vulnerability L1tf:                      Not affected
Vulnerability Mds:                       Not affected
Vulnerability Meltdown:                  Not affected
Vulnerability Mmio stale data:           Not affected
Vulnerability Old microcode:             Not affected
Vulnerability Reg file data sampling:    Not affected
Vulnerability Retbleed:                  Not affected
Vulnerability Spec rstack overflow:      Not affected
Vulnerability Spec store bypass:         Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:                Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:                Mitigation; Enhanced / Automatic IBRS; IBPB conditional; PBRSB-eIBRS SW sequence; BHI BHI_DIS_S
Vulnerability Srbds:                     Not affected
Vulnerability Tsa:                       Not affected
Vulnerability Tsx async abort:           Not affected
Vulnerability Vmscape:                   Not affected

Versions of relevant libraries:
[pip3] Could not collect
[conda] numpy                        2.3.4            py311h2e04523_0       conda-forge
[conda] nvidia-cublas                13.0.0.19        pypi_0                pypi
[conda] nvidia-cuda-cupti            13.0.48          pypi_0                pypi
[conda] nvidia-cuda-nvrtc            13.0.48          pypi_0                pypi
[conda] nvidia-cuda-runtime          13.0.48          pypi_0                pypi
[conda] nvidia-cudnn-cu13            9.13.0.50        pypi_0                pypi
[conda] nvidia-cufft                 12.0.0.15        pypi_0                pypi
[conda] nvidia-curand                10.4.0.35        pypi_0                pypi
[conda] nvidia-cusolver              12.0.3.29        pypi_0                pypi
[conda] nvidia-cusparse              12.6.2.49        pypi_0                pypi
[conda] nvidia-cusparselt-cu13       0.8.0            pypi_0                pypi
[conda] nvidia-nccl-cu13             2.27.7           pypi_0                pypi
[conda] nvidia-nvjitlink             13.0.39          pypi_0                pypi
[conda] nvidia-nvtx                  13.0.39          pypi_0                pypi
[conda] optree                       0.17.0           pypi_0                pypi
[conda] torch                        2.9.1+cu130      pypi_0                pypi
[conda] torchaudio                   2.9.1+cu130      pypi_0                pypi
[conda] torchelastic                 0.2.2            pypi_0                pypi
[conda] torchvision                  0.24.1+cu130     pypi_0                pypi
[conda] triton                       3.5.1            pypi_0                pypi

Job run
