Fused int4 SDPA for Apple Silicon

Core attention kernel from Open-TQ-Metal.

Open-TQ-Metal is a Metal-native implementation of fused compressed-domain attention, built by Ensue. The full release runs Llama 3.1 70B at 128K context on a single 64GB Mac, with attention 48x faster at that context length. It also includes a C++ inference engine, multiple attention kernels, a 330-experiment cross-architecture analysis, and a paper.

This repository contains the Metal implementation of TurboQuant's fused compressed-domain attention (Zandieh et al., ICLR 2026): a fused int4 Scaled Dot-Product Attention (SDPA) kernel for Apple Silicon (M1/M2/M3/M4).

The kernel computes softmax(Q @ dequant(K_int4)^T * scale) @ dequant(V_int4) in a single dispatch using online softmax, so the full dequantized K/V tensors and the score matrix are never materialized.
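The single-pass structure can be sketched in NumPy for one query head. This illustrates online softmax in general, not the Metal kernel itself: for brevity, `K` and `V` are plain float arrays standing in for rows the kernel would dequantize from int4 on the fly, and the function names are hypothetical.

```python
import numpy as np

def online_softmax_attention(q, k_rows, v_rows, scale):
    """Hypothetical sketch: one-pass attention for a single query vector.

    Keeps only a running max, a running normalizer and a running output,
    so memory does not grow with sequence length.
    """
    m = -np.inf                                  # running max of scores
    l = 0.0                                      # running sum of exp(score - m)
    acc = np.zeros(q.shape, dtype=np.float64)    # running weighted sum of V rows
    for k, v in zip(k_rows, v_rows):             # one token at a time
        s = float(q @ k) * scale
        m_new = max(m, s)
        correction = np.exp(m - m_new)           # rescale old state to the new max
        p = np.exp(s - m_new)
        l = l * correction + p
        acc = acc * correction + p * v
        m = m_new
    return acc / l

def reference_attention(q, K, V, scale):
    """Naive two-pass version that materializes the full score vector."""
    s = (K @ q) * scale
    p = np.exp(s - s.max())
    return (p / p.sum()) @ V

rng = np.random.default_rng(0)
D, N = 128, 1000
q = rng.standard_normal(D)
K = rng.standard_normal((N, D))
V = rng.standard_normal((N, D))
scale = 1.0 / np.sqrt(D)
out = online_softmax_attention(q, K, V, scale)
print(np.allclose(out, reference_attention(q, K, V, scale)))  # True
```

The rescaling by `exp(m - m_new)` is what lets the kernel stream K/V once: earlier contributions are corrected whenever a larger score appears.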

Performance

Benchmarked on an M1 Max (32 GPU cores, 64GB unified memory) across 56 experiments.

vs MLX int4 dequant+SDPA (apples-to-apples)

| Config | N | Fused int4 | MLX int4 | Speedup |
|---|---|---|---|---|
| Gemma 4 sliding (sw=1024) | 4096 | 0.34ms | 3.81ms | 11.1x |
| Gemma 4 sliding (sw=1024) | 8192 | 0.40ms | 5.79ms | 14.3x |
| Gemma 4 full (D=512) | 8192 | 0.98ms | 7.87ms | 8.1x |
| Llama 3.1 (D=128) | 8192 | 0.89ms | 2.53ms | 2.9x |

vs PyTorch native SDPA (FP32)

| Config | N | Fused int4 | Native SDPA | Speedup |
|---|---|---|---|---|
| Gemma 4 (D=256) | 2048 | 0.45ms | 1.35ms | 3.0x |
| Llama 3.1 (D=128) | 8192 | 0.88ms | 1.53ms | 1.7x |

Long-context scaling (Gemma 4 31B, 60 layers pipelined)

| Context | SDPA time per token | SDPA throughput ceiling | int4 KV cache |
|---|---|---|---|
| 4K | 5.1ms | 195 tok/s | 30 MB |
| 8K | 7.2ms | 138 tok/s | 60 MB |
| 32K | 19.6ms | 51 tok/s | 240 MB |
| 128K | 88ms | 11 tok/s | 960 MB |

6.4x KV cache compression: 128K context fits in 960 MB of int4 KV cache, versus 6.1 GB in FP32.
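The 6.4x figure follows from the int4 format described below (64-element groups, one scale and one bias per group) if scales and biases are stored as float32, as in the usage example's tensor dtypes. The bfloat16 line is our own arithmetic for the MLX-style layout listed under Features, not a measured result; names are illustrative.

```python
GROUP = 64        # elements per quantization group (from the int4 Format section)
FP32_BYTES = 4

def group_bytes(meta_bytes: int) -> int:
    """Bytes for one int4 group: 64 nibbles plus one scale and one bias."""
    packed = GROUP * 4 // 8            # 64 x 4 bits = 32 bytes (8 uint32 words)
    return packed + 2 * meta_bytes     # + scale + bias

fp32_bytes = GROUP * FP32_BYTES                  # 256 bytes per group in FP32
ratio_fp32_meta = fp32_bytes / group_bytes(4)    # float32 scale/bias: 256 / 40
ratio_bf16_meta = fp32_bytes / group_bytes(2)    # bfloat16 scale/bias: 256 / 36

print(ratio_fp32_meta)                # 6.4
print(round(ratio_bf16_meta, 2))      # 7.11
fp32_equiv_mb = 960 * ratio_fp32_meta # about 6144 MB, the ~6.1 GB quoted above
```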

Quality

| Metric | Value |
|---|---|
| Cosine similarity (int4 vs FP32 output) | 0.992 |
| Attention KL divergence | 0.004 |
| Top-10 token overlap | 90% |
| Kernel numerical precision | < 1e-6 relative error |
| Deterministic | Yes (bit-identical across runs) |
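The metrics above are not formally defined here. One plausible reading, sketched with hypothetical helper names, is cosine similarity between flattened outputs, KL(FP32 || int4) between attention-weight distributions, and the fraction of top-10 logits shared by both runs. The exact definitions behind the reported numbers are an assumption, and the inputs below are synthetic.

```python
import numpy as np

def cosine_similarity(a, b):
    """Cosine similarity between two flattened output tensors."""
    a, b = np.ravel(a), np.ravel(b)
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

def kl_divergence(p, q, eps=1e-12):
    """KL(p || q) for two probability vectors, e.g. FP32 vs int4 attention weights."""
    p = np.clip(p, eps, None)
    q = np.clip(q, eps, None)
    return float(np.sum(p * np.log(p / q)))

def topk_overlap(logits_a, logits_b, k=10):
    """Fraction of the top-k indices that both logit vectors agree on."""
    top_a = set(np.argsort(logits_a)[-k:].tolist())
    top_b = set(np.argsort(logits_b)[-k:].tolist())
    return len(top_a & top_b) / k

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

# Synthetic example: small noise stands in for quantization error
rng = np.random.default_rng(0)
out_fp32 = rng.standard_normal((8, 128))
out_int4 = out_fp32 + 0.05 * rng.standard_normal((8, 128))
scores = rng.standard_normal(1024)
print(cosine_similarity(out_fp32, out_int4))
print(kl_divergence(softmax(scores), softmax(scores + 0.01 * rng.standard_normal(1024))))
```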

Features

  • Online softmax: O(1) memory in sequence length
  • qdot pattern: eliminates per-nibble shift operations in the K dot product
  • exp2 optimization: uses the native Apple GPU exp2 instruction for a 3-12% speedup
  • Adaptive split-K: doubles GPU utilization for models with few attention heads
  • Sliding window fast-skip: jumps directly to valid tokens
  • GQA support: arbitrary grouped-query attention factors (1-32 tested)
  • Batched decode: speculative decoding via a flattened Q tensor
  • bfloat16 scales/biases: supports the native MLX KV cache format
  • D = 128, 256, 512: covers the head dimensions of all major models
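The qdot pattern is not spelled out in this card. The sketch below shows the trick as it appears in MLX's quantized kernels, which we assume this kernel follows: pre-divide query element i by 16**i once, so each nibble can be masked in place (`word & (0xF << 4*i)`, which equals nibble × 16**i) instead of shifted down. The asymmetric bias is folded in afterwards via sum(q). Nibble order (element i in bits 4i..4i+3) and all function names are assumptions.

```python
import numpy as np

NIBBLES = 8  # one uint32 holds eight 4-bit values

def dot_with_shifts(word, q8):
    """Baseline: shift each nibble down, mask, then multiply."""
    return sum(float((word >> (4 * i)) & 0xF) * q8[i] for i in range(NIBBLES))

def prescale_query(q8):
    """Done once per query chunk: divide element i by 16**i (exact for powers of two)."""
    return [q8[i] / 16.0 ** i for i in range(NIBBLES)]

def dot_qdot(word, q8_prescaled):
    """qdot-style: mask each nibble in place; the prescale cancels the 16**i factor."""
    return sum(float(word & (0xF << (4 * i))) * q8_prescaled[i] for i in range(NIBBLES))

rng = np.random.default_rng(0)
word = int(rng.integers(0, 2**32))          # 8 packed nibbles
q8 = rng.standard_normal(NIBBLES).tolist()  # matching slice of the query
scale, bias = 0.05, -0.4                    # illustrative group scale and bias

nibble_dot = dot_qdot(word, prescale_query(q8))
# Asymmetric dequant applied after the loop:
# sum(q * (scale * n + bias)) = scale * sum(q * n) + bias * sum(q)
k_dot = scale * nibble_dot + bias * sum(q8)
```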

Reliability

  • 40 correctness tests (reference + Metal + architecture-specific + batched)
  • 43 stress tests (N=1-131072, GQA=1-32, sliding window edge cases)
  • 12 Gemma 4 31B end-to-end eval tests (factual, instruction, coherence, perplexity)
  • Zero memory leaks (10000 calls tested)
  • No thermal throttling (60s sustained load)
  • Deterministic (100 calls bit-identical)

int4 Format

  • Each uint32 packs eight 4-bit nibbles
  • Per-group (64 elements) asymmetric quantization: value = scale * nibble + bias
  • 6.4x compression vs FP32
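The format above can be sketched in NumPy for a single group. This is an illustration under stated assumptions, not the kernel's quantizer: the min/max scale choice with round-to-nearest, and the nibble order (element i of each word in bits 4i..4i+3), are our assumptions. The helper names are hypothetical.

```python
import numpy as np

GROUP = 64  # elements per quantization group

def quantize_group(x):
    """Asymmetric 4-bit quantization of one group: x ≈ scale * n + bias, n in [0, 15]."""
    lo, hi = float(x.min()), float(x.max())
    scale = (hi - lo) / 15.0 if hi > lo else 1.0
    bias = lo
    n = np.clip(np.round((x - bias) / scale), 0, 15).astype(np.uint32)
    return n, np.float32(scale), np.float32(bias)

SHIFTS = (4 * np.arange(8)).astype(np.uint32)  # bit offsets of the 8 nibbles

def pack_nibbles(n):
    """Pack 8 nibbles per uint32 (element i of each word in bits 4i..4i+3)."""
    return np.bitwise_or.reduce(n.reshape(-1, 8) << SHIFTS, axis=1).astype(np.uint32)

def unpack_nibbles(words):
    """Inverse of pack_nibbles: one uint32 nibble value per element."""
    return ((words[:, None] >> SHIFTS) & np.uint32(0xF)).reshape(-1)

def dequantize_group(words, scale, bias):
    return scale * unpack_nibbles(words).astype(np.float32) + bias

rng = np.random.default_rng(0)
x = rng.standard_normal(GROUP)
n, scale, bias = quantize_group(x)
words = pack_nibbles(n)                   # 8 uint32 words per 64-element group
x_hat = dequantize_group(words, scale, bias)
print(np.max(np.abs(x_hat - x)) <= scale / 2 + 1e-5)  # True: error within half a step
```

A 64-element group packs into 8 words, which matches the D//8 and D//64 shapes in the usage example.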

Usage

```python
import kernels

sdpa_int4 = kernels.get_kernel("EnsueAI/metal-int4-sdpa", "sdpa_int4")

# Single-token decode; all tensors are expected on the MPS device
output = sdpa_int4.sdpa_int4(
    queries,       # (num_heads, D) float32, MPS
    k_quant,       # (num_kv_heads, N, D//8) uint32, MPS
    k_scales,      # (num_kv_heads, N, D//64) float32, MPS
    k_biases,      # (num_kv_heads, N, D//64) float32, MPS
    v_quant,       # (num_kv_heads, N, D//8) uint32, MPS
    v_scales,      # (num_kv_heads, N, D//64) float32, MPS
    v_biases,      # (num_kv_heads, N, D//64) float32, MPS
    gqa_factor=4,  # num_heads // num_kv_heads
    N=2048,        # sequence length
    scale=0.0625,  # 1/sqrt(D); 0.0625 corresponds to D=256
    sliding_window=0,  # 0 = full attention
)
```
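The shape comments and scalar arguments in the call above all follow from the model configuration. The hypothetical helper below (not part of the kernel's API) derives them, assuming the 64-element group size from the int4 Format section; it does not cover `sliding_window`.

```python
import math

def sdpa_int4_shapes(num_heads, num_kv_heads, head_dim, seq_len, group=64):
    """Hypothetical helper: expected input shapes and scalar kwargs for sdpa_int4."""
    if num_heads % num_kv_heads:
        raise ValueError("num_heads must be a multiple of num_kv_heads")
    if head_dim % group:
        raise ValueError(f"head_dim must be a multiple of the group size ({group})")
    packed = (num_kv_heads, seq_len, head_dim // 8)      # 8 nibbles per uint32
    groups = (num_kv_heads, seq_len, head_dim // group)  # one scale/bias per group
    return {
        "shapes": {
            "queries": (num_heads, head_dim),
            "k_quant": packed, "v_quant": packed,
            "k_scales": groups, "k_biases": groups,
            "v_scales": groups, "v_biases": groups,
        },
        "kwargs": {
            "gqa_factor": num_heads // num_kv_heads,
            "N": seq_len,
            "scale": 1.0 / math.sqrt(head_dim),
        },
    }

# 32 query heads, 8 KV heads, D=256 reproduces the example call's kwargs
cfg = sdpa_int4_shapes(num_heads=32, num_kv_heads=8, head_dim=256, seq_len=2048)
print(cfg["kwargs"])  # {'gqa_factor': 4, 'N': 2048, 'scale': 0.0625}
```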

Tested Models

  • Llama 3.1 70B: 80 layers, 64 heads, 8 KV heads, D=128, GQA=8
  • Gemma 4 31B: passes 12/12 end-to-end eval tests (9.2 tok/s on M1 Max)
  • Llama 3.1 8B: 32 heads, 8 KV heads, D=128, GQA=4
  • Qwen2 7B: 28 heads, 4 KV heads, D=128, GQA=7
  • Mistral 7B: 32 heads, 8 KV heads, D=128, sliding window=4096
  • Mixtral 8x7B: 32 heads, 8 KV heads, D=128, MoE with sliding window
  • Arbitrary GQA factors (1-32), non-power-of-2 head counts
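The GQA factors in this list are just num_heads // num_kv_heads. The sketch below assumes the standard contiguous grouping (as in Hugging Face Transformers' `repeat_kv`), where query head h reads KV head h // gqa_factor; that this kernel uses the same mapping is an assumption. It also shows that a non-power-of-2 factor such as Qwen2's 7 needs nothing special.

```python
# (num_heads, num_kv_heads) as listed under Tested Models
TESTED = {
    "Llama 3.1 70B": (64, 8),
    "Llama 3.1 8B": (32, 8),
    "Qwen2 7B": (28, 4),
    "Mistral 7B": (32, 8),
    "Mixtral 8x7B": (32, 8),
}

def kv_head_for(q_head: int, num_heads: int, num_kv_heads: int) -> int:
    """KV head used by query head `q_head` under contiguous GQA grouping (assumed)."""
    gqa_factor = num_heads // num_kv_heads
    return q_head // gqa_factor

gqa_factors = {name: h // kv for name, (h, kv) in TESTED.items()}
for name, factor in gqa_factors.items():
    print(f"{name}: gqa_factor={factor}")

# Qwen2 7B: heads 0-6 share KV head 0, heads 7-13 share KV head 1, and so on
qwen_map = [kv_head_for(h, 28, 4) for h in range(28)]
```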

Origin

Ported from TurboQuant's sdpa_int4_vector kernel.

Paper: TurboQuant (arxiv.org/abs/2504.19874), ICLR 2026.
