Fused int4 SDPA for Apple Silicon
Core attention kernel from Open-TQ-Metal.
Open-TQ-Metal is a Metal-native implementation of fused compressed-domain attention built by Ensue. The full release runs Llama 3.1 70B at 128K context on a single 64 GB Mac, delivers 48x faster attention at 128K context, and includes a C++ inference engine, multiple attention kernels, a 330-experiment cross-architecture analysis, and a paper.
- Paper: https://arxiv.org/pdf/2604.16957
- Write-up: https://ensue.dev/blog/introducing-open-tq-metal/
- Llama 3.1 70B engine: https://github.com/mutable-state-inc/turboquant-llama3.170B
- Gemma 4 31B engine: https://github.com/mutable-state-inc/gemma4metal
This is the Metal implementation of TurboQuant's fused compressed-domain attention (Zandieh et al., ICLR 2026): a fused int4 Scaled Dot-Product Attention (SDPA) kernel for Apple Silicon (M1/M2/M3/M4).
It computes `softmax(Q @ dequant(K_int4)^T * scale) @ dequant(V_int4)` in a single kernel dispatch, using online softmax so that the full K/V matrices and the score matrix are never materialized.
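The single-pass computation can be illustrated on the CPU. The sketch below is a minimal pure-Python model, not the Metal kernel: `dequant`, `naive_attention` and `online_attention` are hypothetical names, the toy uses one quantization group per row, and it assumes the standard online-softmax recurrence (running max, running denominator, rescaled accumulator). It also folds `log2(e)` into the score so that `2**x` replaces `e**x`, which is the idea behind the exp2 optimization listed under Features.

```python
import math

def dequant(nibbles, scale, bias):
    # Asymmetric int4 dequantization: value = scale * nibble + bias
    return [scale * n + bias for n in nibbles]

def naive_attention(q, keys, values, scale):
    # Reference path: materializes every score, then normalizes.
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    z = sum(weights)
    return [sum(w * v[d] for w, v in zip(weights, values)) / z
            for d in range(len(values[0]))]

def online_attention(q, keys, values, scale):
    # One pass over tokens keeping only a running max (m), a running
    # denominator (l) and an unnormalized output accumulator (acc).
    log2e = 1.0 / math.log(2.0)
    m, l = -math.inf, 0.0
    acc = [0.0] * len(values[0])
    for k, v in zip(keys, values):
        # Folding log2(e) into the score lets 2**x stand in for e**x.
        s = scale * log2e * sum(qi * ki for qi, ki in zip(q, k))
        m_new = max(m, s)
        correction = 2.0 ** (m - m_new)  # rescales earlier terms if the max grew
        p = 2.0 ** (s - m_new)
        l = l * correction + p
        acc = [a * correction + p * vd for a, vd in zip(acc, v)]
        m = m_new
    return [a / l for a in acc]

# Toy decode step: D = 4, three cached tokens, one quantization group per row.
q = [0.5, -1.0, 0.25, 2.0]
keys = [dequant(n, 0.1, -0.75) for n in ([1, 15, 7, 0], [8, 3, 12, 5], [0, 0, 15, 9])]
values = [dequant(n, 0.2, -1.5) for n in ([2, 4, 6, 8], [15, 0, 1, 3], [7, 7, 7, 7])]
scale = 1.0 / math.sqrt(len(q))

out_naive = naive_attention(q, keys, values, scale)
out_online = online_attention(q, keys, values, scale)
```

Both functions return the same vector up to floating-point rounding, but the online version only ever holds one score at a time, which is why memory stays constant in sequence length.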
Performance
Benchmarked on an M1 Max (32 GPU cores, 64 GB unified memory) across 56 experiments.
vs MLX int4 dequant+SDPA (apples-to-apples)
| Config | N | Fused int4 | MLX int4 | Speedup |
|---|---|---|---|---|
| Gemma 4 sliding (sw=1024) | 4096 | 0.34ms | 3.81ms | 11.1x |
| Gemma 4 sliding (sw=1024) | 8192 | 0.40ms | 5.79ms | 14.3x |
| Gemma 4 full (D=512) | 8192 | 0.98ms | 7.87ms | 8.1x |
| Llama 3.1 (D=128) | 8192 | 0.89ms | 2.53ms | 2.9x |
vs PyTorch native SDPA (FP32)
| Config | N | Fused int4 | Native SDPA | Speedup |
|---|---|---|---|---|
| Gemma 4 (D=256) | 2048 | 0.45ms | 1.35ms | 3.0x |
| Llama 3.1 (D=128) | 8192 | 0.88ms | 1.53ms | 1.7x |
Long-context scaling (Gemma 4 31B, 60 layers pipelined)
| Context | SDPA time per token | SDPA-bound throughput ceiling | int4 KV cache size |
|---|---|---|---|
| 4K | 5.1ms | 195 tok/s | 30 MB |
| 8K | 7.2ms | 138 tok/s | 60 MB |
| 32K | 19.6ms | 51 tok/s | 240 MB |
| 128K | 88ms | 11 tok/s | 960 MB |
6.4x KV cache compression: 128K context in 960 MB int4 vs 6.1 GB FP32.
Quality
| Metric | Value |
|---|---|
| Cosine similarity (int4 vs FP32 output) | 0.992 |
| Attention KL divergence | 0.004 |
| Top-10 token overlap | 90% |
| Kernel numerical precision | < 1e-6 relative error |
| Deterministic | Yes (bit-identical across runs) |
Features
- Online softmax: O(1) memory in sequence length
- qdot pattern: eliminates per-nibble shift operations in the K dot product
- exp2 optimization: uses the native Apple GPU exp2 instruction for a 3-12% speedup
- Adaptive split-K: doubles GPU utilization for models with few attention heads
- Sliding window fast-skip: jumps directly to valid tokens
- GQA support: arbitrary grouped-query attention factors (1-32 tested)
- Batched decode: speculative decoding via a flattened Q tensor
- bfloat16 scales/biases: supports MLX's native KV cache format
- D = 128, 256, 512: covers all major model head dimensions
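The qdot idea can be sketched in a few lines. This assumes the pattern is the MLX-style trick (an assumption; the card does not spell it out): the query is pre-divided by `16**j` once, so each nibble can be masked in place instead of shifted down, and the per-group dequantization `scale * nibble + bias` is factored out of the sum. It also assumes nibble `j` of a uint32 sits at bits `[4j, 4j+4)`. All function names are hypothetical.

```python
def pack_nibbles(nibbles):
    # Pack 8 nibbles into one uint32; nibble j sits at bits [4j, 4j+4) (assumed order).
    assert len(nibbles) == 8 and all(0 <= n < 16 for n in nibbles)
    word = 0
    for j, n in enumerate(nibbles):
        word |= n << (4 * j)
    return word

def dot_with_shifts(q, word, scale, bias):
    # Baseline: shift each nibble down to bit 0, dequantize, multiply.
    return sum(qj * (scale * ((word >> (4 * j)) & 0xF) + bias) for j, qj in enumerate(q))

def prescale_query(q):
    # Done once per query: divide element j by 16**j (exact, a power of two).
    return [qj / 16.0 ** j for j, qj in enumerate(q)]

def qdot(q_pre, q_sum, word, scale, bias):
    # Mask nibbles in place: (word & (0xF << 4j)) == n_j * 16**j, and the
    # 16**j cancels against the prescaled query, so no shift is needed.
    accum = sum(qp * (word & (0xF << (4 * j))) for j, qp in enumerate(q_pre))
    # Dequantization factored out of the sum:
    # sum_j q_j * (scale * n_j + bias) == scale * sum_j q_j * n_j + bias * sum_j q_j
    return scale * accum + bias * q_sum

q = [0.5, -1.25, 2.0, 0.75, -0.5, 1.5, -2.0, 0.25]
word = pack_nibbles([3, 15, 0, 7, 9, 1, 12, 4])
expected = dot_with_shifts(q, word, 0.125, -0.9)
got = qdot(prescale_query(q), sum(q), word, 0.125, -0.9)
```

Both paths give the same dot product; on a GPU the gain comes from reusing the prescaled query and `sum(q)` across every token in the group, which leaves one AND and one multiply-add per nibble.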
Reliability
- 40 correctness tests (reference + Metal + architecture-specific + batched)
- 43 stress tests (N=1-131072, GQA=1-32, sliding window edge cases)
- 12 Gemma 4 31B end-to-end eval tests (factual, instruction, coherence, perplexity)
- No memory leaks across 10,000 calls
- No thermal throttling under 60 s of sustained load
- Deterministic: 100 calls produce bit-identical output
int4 Format
- uint32 packs 8 x 4-bit nibbles
- Per-group (64 elements) asymmetric quantization: `value = scale * nibble + bias`
- 6.4x compression vs FP32
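A round trip through this format, and one accounting that reproduces the 6.4x figure, might look like the sketch below. Assumptions: a standard min-max scheme for choosing `scale` and `bias` (the card only specifies how values are reconstructed), low-nibble-first packing, and one scale plus one bias per 64-element group. Under those assumptions FP32 scales/biases give exactly 6.4x, and bfloat16 scales/biases give about 7.1x. All names are hypothetical.

```python
import math

GROUP = 64
BITS_PER_ELEM = 4

def quantize_group(x):
    # Min-max asymmetric quantization of one group (assumed scheme):
    # value ~= scale * nibble + bias, nibble in [0, 15].
    lo, hi = min(x), max(x)
    scale = (hi - lo) / 15.0 or 1.0  # fall back to 1.0 for a constant group
    bias = lo
    nibbles = [min(15, max(0, round((v - bias) / scale))) for v in x]
    return nibbles, scale, bias

def pack_uint32(nibbles):
    # 8 nibbles per uint32, low nibble first (assumed order).
    words = []
    for i in range(0, len(nibbles), 8):
        w = 0
        for j, n in enumerate(nibbles[i:i + 8]):
            w |= n << (4 * j)
        words.append(w)
    return words

def dequant_words(words, scale, bias):
    # Unpack each word nibble by nibble and apply value = scale * nibble + bias.
    return [scale * ((w >> (4 * j)) & 0xF) + bias for w in words for j in range(8)]

def compression_ratio(scale_bits, group=GROUP):
    # FP32 baseline vs packed int4 payload plus one scale and one bias per group.
    fp32_bits = group * 32
    int4_bits = group * BITS_PER_ELEM + 2 * scale_bits
    return fp32_bits / int4_bits

x = [math.sin(0.3 * i) for i in range(GROUP)]
nibbles, scale, bias = quantize_group(x)
words = pack_uint32(nibbles)
x_hat = dequant_words(words, scale, bias)

print(len(words))                       # 8
print(compression_ratio(32))            # 6.4
print(round(compression_ratio(16), 2))  # 7.11
```

The reconstruction error per element is at most half a quantization step (`scale / 2`), since each value is rounded to the nearest of 16 evenly spaced levels.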
Usage
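```python
import kernels

# get_kernel takes the Hub repo id (plus an optional revision); the
# sdpa_int4 function is then accessed as an attribute of the returned module.
sdpa_int4 = kernels.get_kernel("EnsueAI/metal-int4-sdpa")

# Single-token decode. All tensors are assumed to already exist on the MPS
# device; D is the head dimension and N the sequence length.
output = sdpa_int4.sdpa_int4(
    queries,   # (num_heads, D) float32, MPS
    k_quant,   # (num_kv_heads, N, D//8) uint32, MPS (8 nibbles per word)
    k_scales,  # (num_kv_heads, N, D//64) float32, MPS (one per 64-element group)
    k_biases,  # (num_kv_heads, N, D//64) float32, MPS
    v_quant,   # (num_kv_heads, N, D//8) uint32, MPS
    v_scales,  # (num_kv_heads, N, D//64) float32, MPS
    v_biases,  # (num_kv_heads, N, D//64) float32, MPS
    gqa_factor=4,      # num_heads // num_kv_heads
    N=2048,            # sequence length
    scale=0.0625,      # 1/sqrt(D), here D=256
    sliding_window=0,  # 0 = full attention
)
```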
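Before calling the kernel, the expected shapes and scalar arguments can be derived from a model's attention config. The helper below is a hypothetical sketch that only follows the shape comments in the usage example: `AttnConfig`, `sdpa_int4_args` and `kv_head_for` are illustrative names, and the query-head to KV-head mapping is the common GQA convention, assumed rather than confirmed for this kernel.

```python
import math
from dataclasses import dataclass

@dataclass
class AttnConfig:
    num_heads: int
    num_kv_heads: int
    head_dim: int

def sdpa_int4_args(cfg, n_tokens, group=64, nibbles_per_word=8):
    # Hypothetical helper: expected tensor shapes and scalar arguments for one
    # decode call, following the shape comments in the usage example.
    assert cfg.num_heads % cfg.num_kv_heads == 0, "GQA needs an integer group size"
    assert cfg.head_dim % group == 0, "D must be a multiple of the 64-element group"
    packed = (cfg.num_kv_heads, n_tokens, cfg.head_dim // nibbles_per_word)
    groups = (cfg.num_kv_heads, n_tokens, cfg.head_dim // group)
    return {
        "queries": (cfg.num_heads, cfg.head_dim),
        "k_quant": packed, "v_quant": packed,
        "k_scales": groups, "k_biases": groups,
        "v_scales": groups, "v_biases": groups,
        "gqa_factor": cfg.num_heads // cfg.num_kv_heads,
        "N": n_tokens,
        "scale": 1.0 / math.sqrt(cfg.head_dim),
    }

def kv_head_for(query_head, gqa_factor):
    # Common GQA convention (assumed): consecutive query heads share one KV head.
    return query_head // gqa_factor

# Qwen2 7B from the tested-models list: 28 heads, 4 KV heads, D = 128.
args = sdpa_int4_args(AttnConfig(num_heads=28, num_kv_heads=4, head_dim=128), n_tokens=2048)
print(args["gqa_factor"], args["k_quant"], args["k_scales"])  # 7 (4, 2048, 16) (4, 2048, 2)
```

Checking shapes up front this way can catch a head-dimension or GQA mismatch on the Python side, where a wrong shape is easier to diagnose than inside a Metal dispatch.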
Tested Models
- Llama 3.1 70B: 80 layers, 64 heads, 8 KV heads, D=128, GQA=8
- Gemma 4 31B: 12/12 end-to-end eval tests pass (9.2 tok/s on M1 Max)
- Llama 3.1 8B: 32 heads, 8 KV heads, D=128, GQA=4
- Qwen2 7B: 28 heads, 4 KV heads, D=128, GQA=7
- Mistral 7B: 32 heads, 8 KV heads, D=128, sliding window=4096
- Mixtral 8x7B: 32 heads, 8 KV heads, D=128, MoE with sliding window
- Arbitrary GQA factors (1-32), non-power-of-2 head counts
Origin
Ported from TurboQuant's sdpa_int4_vector kernel.
Paper: TurboQuant (arxiv.org/abs/2504.19874), ICLR 2026.