IBM Granite Speech 4.1 2b - ONNX export
ONNX export of ibm-granite/granite-speech-4.1-2b produced by smcleod. Three precision tiers (fp32/, int8/, fp16w/) ship in this repo - see Files below for sizes and trade-offs. The graphs target opset 20 / IR 10 / ai.onnx-only, so they load under the ort 2.0-rc.x Rust crate and onnxruntime 1.17 - 1.25.
Four graphs cooperate: encoder.onnx projects mel features to audio embeddings; embed_tokens.onnx looks up text-token embeddings from the LLM vocab; the host splices the audio embeddings in at <|audio|> placeholder positions (token id 100352); prompt_encode.onnx runs the LLM forward over the full inputs_embeds and returns the first-token logits plus a 40-layer KV cache; decode_step.onnx consumes one token at a time plus the past KV cache and emits the next logits. See How to use for the call sequence and KV-cache layout.
The audio placeholder token id is 100352. Replace those positions in the prompt with the projector outputs from encoder.onnx before running prompt_encode.onnx.
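A minimal sketch of that host-side splice (batch size 1, numpy), assuming text_embeds and audio_embeds are the arrays already returned by embed_tokens.onnx and encoder.onnx - the full call sequence is under How to use below:

```python
import numpy as np

AUDIO_PLACEHOLDER_ID = 100352  # <|audio|>

def splice_audio(token_ids, text_embeds, audio_embeds):
    """Overwrite the <|audio|> placeholder rows with the encoder's projector outputs."""
    inputs_embeds = text_embeds.copy()                              # [1, N, 2048]
    audio_pos = np.where(token_ids[0] == AUDIO_PLACEHOLDER_ID)[0]
    # The placeholder count must match the frame count reported by audio_embed_sizes.
    inputs_embeds[0, audio_pos, :] = audio_embeds[0, : len(audio_pos), :]
    return inputs_embeds
```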
Files
Each precision tier ships in its own subdirectory (fp32/, int8/, fp16w/). Inside, files use the clean stem (no precision suffix) - the directory name carries the tier. Download a single subdirectory if you only need one precision; the tokeniser, processor, scripts, and metadata at the bundle root are shared across all tiers.
fp32/ - FP32 (reference, full precision) - 15.8 GB total
Use when you need byte-for-byte parity with the upstream PyTorch reference, or as a baseline for quantisation/conversion experiments.
- fp32/encoder.onnx + fp32/encoder.onnx_data
- fp32/prompt_encode.onnx + fp32/prompt_encode.onnx_data
- fp32/decode_step.onnx + fp32/decode_step.onnx_data
- fp32/embed_tokens.onnx + fp32/embed_tokens.onnx_data
int8/ - INT8 (smallest) - 4.0 GB total
Dynamic weights-only INT8 (MatMulInteger + ConvInteger, all ai.onnx). Mild quality drop on case/punctuation but transcripts remain semantically accurate. Choose when disk or memory is tight.
- int8/encoder.onnx + int8/encoder.onnx_data
- int8/prompt_encode.onnx + int8/prompt_encode.onnx_data
- int8/decode_step.onnx + int8/decode_step.onnx_data
- int8/embed_tokens.onnx + int8/embed_tokens.onnx_data
fp16w/ - FP16w (recommended for highest quality at smaller-than-FP32 size) - 7.9 GB total
Weights-FP16 with FP32 compute and IO. Each FP32 initializer is rewritten to FP16 storage with a Cast(FP16->FP32) inserted before each consumer; arithmetic and IO stay FP32. Quality is essentially identical to FP32 (mean norm WER 0.04% vs 0.72% for INT8) at 50% of FP32 storage. Choose when you have the disk and want FP32-grade transcripts.
- fp16w/encoder.onnx + fp16w/encoder.onnx_data
- fp16w/prompt_encode.onnx + fp16w/prompt_encode.onnx_data
- fp16w/decode_step.onnx + fp16w/decode_step.onnx_data
- fp16w/embed_tokens.onnx + fp16w/embed_tokens.onnx_data
Shared (used by every precision tier)
- Tokeniser / processor: tokenizer.json, tokenizer_config.json, processor_config.json, chat_template.jinja, special_tokens_map.json, preprocessor_config.json
- Export scripts: export_speech_2b_ar.py, export_embed_tokens.py, quantise.py, convert_fp.py
- granite_export_metadata.json (graph IO, parity numbers, toolchain)
- LICENSE (Apache 2.0)
- test_fixtures/ - golden inputs/outputs for integration testing. See test_fixtures/README.md.
How to use end-to-end
Audio frontend (shared across variants)
16 kHz mono float32 PCM. Mel filter bank: n_fft=512, hop_length=160,
win_length=400, n_mels=80, log10 with -8 dB clamp, no preemphasis.
The frontend frame-stacks pairs of mel frames, doubling the feature dim to
160 (so the encoder input is [B, T, 160] not [B, T, 80]). The reference
implementation lives in the transformers AutoProcessor for this repo;
the parameters above are reproduced for runtimes that build their own STFT.
The included test_fixtures/expected_input_features.npy is the post-mel
output - if your audio frontend produces the same array on the included
sample, your preprocessing is byte-correct.
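A quick way to check your own frontend against that fixture (compute_log_mel is a hypothetical stand-in for your STFT/mel code; the sample clip and its name are documented in test_fixtures/README.md):

```python
import numpy as np

expected = np.load("test_fixtures/expected_input_features.npy")   # [B, T, 160]
mine = compute_log_mel("path/to/sample.wav")                      # hypothetical: your frontend
assert mine.shape == expected.shape, (mine.shape, expected.shape)
print("max abs diff:", float(np.max(np.abs(mine - expected))))    # ~0 if byte-correct
```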
Call sequence (autoregressive)
Four graphs and a host-side splice. Output of one is input to the next; only
prompt_encode and decode_step consume the KV cache.
1. encoder.onnx (16 kHz audio mel-bank features) -> audio_embeds, audio_embed_sizes
2. embed_tokens.onnx (chat-template token IDs) -> text_embeds
3. splice (host-side) (text_embeds with audio_embeds at <|audio|> positions, token id 100352)
-> inputs_embeds [B, N, 2048]
4. prompt_encode.onnx (inputs_embeds, position_ids, 4-D attention_mask)
-> first-token logits, 40-layer KV cache
5. loop decode_step.onnx (next inputs_embeds via embed_tokens(prev_token), past KV cache)
-> next logits + grown KV cache
... until argmax(logits) == EOS
The chat template (rendered from chat_template.jinja + tokenizer.json)
inserts the <|audio|> placeholder; the host code replaces those positions
with rows from audio_embeds before calling prompt_encode. The placeholder
token id is 100352.
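A sketch of steps 1-4 with the onnxruntime Python API (fp32/ tier, batch size 1). Tensor names, output ordering, and the prefill mask builder are assumptions here - confirm them against session.get_inputs()/get_outputs() and granite_export_metadata.json. input_features and token_ids are the arrays from the frontend and tokeniser sections; splice_audio is the sketch shown earlier.

```python
import numpy as np
import onnxruntime as ort

enc = ort.InferenceSession("fp32/encoder.onnx")
emb = ort.InferenceSession("fp32/embed_tokens.onnx")
pre = ort.InferenceSession("fp32/prompt_encode.onnx")

# 1. mel-bank features [1, T, 160] -> audio embeddings
audio_embeds, audio_embed_sizes = enc.run(None, {"input_features": input_features})

# 2. chat-template token ids [1, N] -> text embeddings [1, N, 2048]
text_embeds = emb.run(None, {"input_ids": token_ids})[0]

# 3. host-side splice at the <|audio|> positions
inputs_embeds = splice_audio(token_ids, text_embeds, audio_embeds)

# 4. prefill: first-token logits + 40-layer KV cache
n = inputs_embeds.shape[1]
position_ids = np.arange(n, dtype=np.int64)[None, :]
attention_mask = build_prefill_mask(n)  # hypothetical: 4-D additive mask, exact shape in the metadata
outputs = pre.run(None, {"inputs_embeds": inputs_embeds,
                         "position_ids": position_ids,
                         "attention_mask": attention_mask})
logits, present = outputs[0], list(outputs[1:])  # assumed order: logits, then present.{i}.{key,value}
```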
Initial KV cache layout
decode_step.onnx consumes past_key_values.{i}.key and
past_key_values.{i}.value for i in [0, 40). For the first decode step
(after prompt_encode has already run), feed in the present.{i}.{key,value}
tensors that prompt_encode returned. There is no separate "empty cache"
ceremony - prompt_encode does the whole prefill.
The KV tensor shapes (GQA with 4 heads, head_dim 128):
past_key_values.{i}.key : float32 [B, 4, T_past, 128]
past_key_values.{i}.value : float32 [B, 4, T_past, 128]
present.{i}.key : float32 [B, 4, T_total, 128] (T_total = T_past + 1)
present.{i}.value : float32 [B, 4, T_total, 128]
attention_mask : float32 [B, 1, 1, T_total] (additive: 0 for valid, -inf for masked)
position_ids : int64 [B, 1] (= T_past at step k)
At step k, set position_ids = T_past, slide the previous present.{i}.{key,value}
into the next call's past_key_values.{i}.{key,value}, and grow attention_mask
by one column of 0.0. The logits from the lm_head MatMul are already divided by
config.text_config.logits_scaling = 8 inside the graph; argmax is unaffected,
but if you compare raw logit values against another runtime, remember that the
divide-by-8 is baked in.
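A sketch of the decode loop, continuing from the prefill sketch above (greedy argmax, batch size 1; tensor names and output ordering are again assumptions to verify against the graphs):

```python
dec = ort.InferenceSession("fp32/decode_step.onnx")
N_LAYERS = 40
max_new_tokens = 256
eos_token_id = ...  # placeholder: take the EOS id from tokenizer_config.json

past = present                        # KV cache returned by prompt_encode
t_past = past[0].shape[2]             # prompt length after prefill
next_token = int(np.argmax(logits[0, -1]))
generated = [next_token]

for _ in range(max_new_tokens):
    step_embeds = emb.run(None, {"input_ids": np.array([[next_token]], dtype=np.int64)})[0]
    feeds = {
        "inputs_embeds": step_embeds,                                    # [1, 1, 2048]
        "position_ids": np.array([[t_past]], dtype=np.int64),            # = T_past
        "attention_mask": np.zeros((1, 1, 1, t_past + 1), np.float32),   # all positions valid
    }
    for i in range(N_LAYERS):
        feeds[f"past_key_values.{i}.key"] = past[2 * i]
        feeds[f"past_key_values.{i}.value"] = past[2 * i + 1]
    out = dec.run(None, feeds)
    step_logits, past = out[0], list(out[1:])
    t_past += 1
    next_token = int(np.argmax(step_logits[0, -1]))
    if next_token == eos_token_id:
        break
    generated.append(next_token)
```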
Tokeniser
We ship tokenizer.json + tokenizer_config.json + chat_template.jinja from
ibm-granite/granite-speech-4.1-2b. For Rust, load via
tokenizers::Tokenizer::from_file("tokenizer.json") and apply the chat template
yourself (the tokeniser crate doesn't render Jinja). The chat template is short
- a system role, one user role with the <|audio|> placeholder, and a final
<|assistant|> opener.
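For Python hosts, the same file loads via the tokenizers package (the Python binding of the same Rust crate). A minimal sketch, assuming <|audio|> is registered as an added token (so it maps to id 100352) and that rendered_prompt is the template your host code has already rendered from chat_template.jinja:

```python
from tokenizers import Tokenizer

tok = Tokenizer.from_file("tokenizer.json")
ids = tok.encode(rendered_prompt, add_special_tokens=False).ids
audio_positions = [i for i, t in enumerate(ids) if t == 100352]  # <|audio|> slots to splice
```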
Runtime / EP notes
- The graphs are ai.onnx-only at opset 20; no com.microsoft.* ops, no MatMulNBits. They load under the ort 2.0-rc.x Rust crate and onnxruntime Python 1.17 - 1.25.
- CoreML EP: opset 20 contains a few ops that the CoreML EP doesn't kernel. ORT will silently fall back to CPU for those nodes at session-load time. If you want fully-accelerated MPS, FP16w is the closest available analogue (FP16 weights, FP32 compute) - the encoder is mostly conv + matmul, which CoreML can take.
- CUDA EP: works out of the box at FP32 / INT8. FP16w is also fine - the Cast(FP16->FP32) is a cheap kernel.
- CPU: VNNI / AVX-512 systems get a meaningful win from INT8 (MatMulInteger / ConvInteger); systems without those land in roughly the same ballpark as FP32.
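A sketch of EP selection with the onnxruntime Python API; nodes an EP can't kernel fall back to the CPU EP at session load, as noted above:

```python
import onnxruntime as ort

sess = ort.InferenceSession(
    "fp16w/encoder.onnx",
    providers=["CoreMLExecutionProvider", "CPUExecutionProvider"],  # or CUDAExecutionProvider
)
print(sess.get_providers())  # which EPs were actually registered for this session
```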
How the tiers are produced
- INT8 is dynamic, weights-only, per-channel QInt8 over MatMul + Conv ops. The quantiser emits MatMulInteger + ConvInteger and leaves activations in FP32. The unquantised ~22% of MatMul nodes in the LLM body graphs are activation x activation (attention QK^T and attention_weights x V); dynamic weight-only INT8 cannot quantise those, so this is the expected ceiling, not a coverage gap.
- FP16w stores weights as FP16 initializers with a Cast(FP16->FP32) inserted before each consumer, so arithmetic and IO stay FP32. Quality matches FP32 within numeric tolerance at ~50% of FP32 storage.
- embed_tokens is shipped as its own graph in all three tiers. INT8 uses per-row symmetric quantisation rather than the dynamic MatMul/Conv quantiser (Gather is not in that op set), giving the embedding table its own ~4x storage win at INT8.
- No com.microsoft.* ops are used. Re-validate the op-domain set with assert_pure_ai_onnx in quantise.py / convert_fp.py after any change.
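A sketch of that op-domain check (the canonical assert_pure_ai_onnx lives in the shipped scripts; this version only walks the top-level graph):

```python
import onnx

model = onnx.load("fp32/decode_step.onnx", load_external_data=False)
domains = {node.domain for node in model.graph.node}
assert domains <= {"", "ai.onnx"}, f"unexpected op domains: {domains}"
print("opset imports:", [(op.domain, op.version) for op in model.opset_import])
```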
Parity
Parity is taken against the upstream PyTorch reference on a single LibriSpeech
clip (10226_10111_000000.wav, 8.43 seconds, 844 mel frames). FP32 graphs
match the reference within numeric tolerance; INT8 graphs are validated in
argmax-only mode (logit values shift but token argmax is preserved, so the
decoded transcript is unchanged).
Encoder (numeric output, no argmax decoding):
| precision | max-abs-err | mean-abs-err | p99-abs-err |
|---|---|---|---|
| FP32 | 4.48e-06 | 1.24e-07 | 6.46e-07 |
| INT8 | 0.169 | 0.0109 | 0.0447 |
LLM stages (argmax decoding; INT8 logit max-abs delta is large but argmax is preserved):
| graph | precision | max-abs-err | argmax mismatches | transcript match |
|---|---|---|---|---|
| prompt_encode | FP32 | 0.000364 | 0/190 | Y |
| prompt_encode | INT8 | 10.1 | 58/190 | Y |
| decode_step | FP32 | n/a | 0/51 | Y |
| decode_step | INT8 | 5.76 | 0/51 | Y |
Multi-clip transcript parity
Three additional 16 kHz mono clips cover longer utterances (39 to 94 seconds) with single- and two-speaker conversational content. Word error rate (WER) and Levenshtein edit distance are computed against the upstream PyTorch reference. Numbers are measured end-to-end through the full ONNX pipeline (no PyTorch encoder fallback).
WER is the strict word-error rate against the PyTorch reference (case + punctuation sensitive). norm WER lower-cases both transcripts and strips punctuation before comparing - the dominant driver of strict WER on this model at INT8 is capitalisation and trailing punctuation drift, not actual word substitution. Pick whichever metric matches your downstream task. FP16w is essentially FP32 quality at 50% of FP32 storage; INT8 is the smallest tier with a mild quality drop.
| Clip | Duration | FP32 byte-exact | INT8 byte-exact | INT8 WER | INT8 norm WER | FP16w byte-exact | FP16w WER | FP16w norm WER |
|---|---|---|---|---|---|---|---|---|
| is-it-more-wood | 46.9 s | Y | N | 37.4% | 0.00% | Y | 0.0% | 0.00% |
| two-speakers-1 | 93.8 s | Y | N | 3.1% | 1.36% | Y | 0.0% | 0.00% |
| two-speakers-2 | 38.8 s | Y | N | 23.5% | 0.00% | Y | 0.0% | 0.00% |
Raw multi-clip data including full transcripts: see granite_export_metadata.json multi_clip_parity block.
Reference transcript:
After his nap, Timothy lazily stretched, first one gray velvet foot, then another, strolled indolently to his plate, turning over the food, carefully selecting choice bits, nosing out that which he scorned upon the clean hearth
The FP32 and FP16w paths reproduce this transcript exactly on the test clip, and INT8 reproduces it within argmax-only tolerance (token argmax preserved).
Toolchain
- transformers 5.8.0
- torch 2.11.0
- onnx 1.21.0
- onnxruntime 1.25.1
- exporter: torch.onnx.export TorchScript path (dynamo=False)
- opset: 20 (ai.onnx only)
- IR version: 10
- external data layout: single <stem>.onnx_data sidecar per graph
Compatibility
Targeted at the ort 2.0-rc.x Rust crate.
Compatible with onnxruntime Python 1.17 through 1.25. No com.microsoft
ops are used. Graphs were emitted via the TorchScript path
(torch.onnx.export(..., dynamo=False)); the dynamo exporter was deliberately
avoided because it injects aten::* ops that ort does not understand. See the
Runtime / EP notes above for CoreML / CUDA / CPU
specifics including which precision tier to pick per backend.
Reproducing the export
The included export, quantisation, and conversion scripts regenerate every artefact in this
bundle. The export pipeline writes flat-layout files into exports/<variant>/;
the per-tier subdirectory layout you see in this repo is produced by
scripts/stage_bundles.py (in the source tree at
https://github.com/sammcj/granite-speech-4.1-onnx). From a checkout:
python export_speech_2b_ar.py \
--model-dir <path-to-ibm-granite/granite-speech-4.1-2b> \
--out-dir exports/granite-speech-4.1-2b
python export_embed_tokens.py --variant base
# INT8 (per-variant default for 2b: exclude Conv ops in encoder layer 0).
# embed_tokens uses a hand-rolled per-row INT8 path baked into export_embed_tokens.py;
# the standard `quantise.py` (MatMul/Conv only) does not touch Gather ops.
python quantise.py --input exports/granite-speech-4.1-2b/encoder.onnx --output exports/granite-speech-4.1-2b/encoder_int8.onnx --exclude-pattern '/encoder/layers\.0/conv/.*'
python quantise.py --input exports/granite-speech-4.1-2b/prompt_encode.onnx --output exports/granite-speech-4.1-2b/prompt_encode_int8.onnx
python quantise.py --input exports/granite-speech-4.1-2b/decode_step.onnx --output exports/granite-speech-4.1-2b/decode_step_int8.onnx
# FP16w (weights-FP16, FP32 compute - no exclusions needed).
# embed_tokens FP16w is also produced by export_embed_tokens.py in the same run.
python convert_fp.py --precision fp16w --input exports/granite-speech-4.1-2b/encoder.onnx --output exports/granite-speech-4.1-2b/encoder_fp16w.onnx
python convert_fp.py --precision fp16w --input exports/granite-speech-4.1-2b/prompt_encode.onnx --output exports/granite-speech-4.1-2b/prompt_encode_fp16w.onnx
python convert_fp.py --precision fp16w --input exports/granite-speech-4.1-2b/decode_step.onnx --output exports/granite-speech-4.1-2b/decode_step_fp16w.onnx
Licence
Apache 2.0 for both the upstream IBM model and this ONNX export. See
LICENSE for the full text.