IBM Granite Speech 4.1 2b - ONNX export

ONNX export of ibm-granite/granite-speech-4.1-2b produced by smcleod. Three precision tiers (fp32/, int8/, fp16w/) ship in this repo - see Files below for sizes and trade-offs. The graphs target opset 20 / IR 10 / ai.onnx-only, so they load under the ort 2.0-rc.x Rust crate and onnxruntime 1.17 - 1.25.

Four graphs cooperate: encoder.onnx projects mel features to audio embeddings; embed_tokens.onnx looks up text-token embeddings from the LLM vocab; the host splices the audio embeddings in at <|audio|> placeholder positions (token id 100352); prompt_encode.onnx runs the LLM forward over the full inputs_embeds and returns the first-token logits plus a 40-layer KV cache; decode_step.onnx consumes one token at a time plus the past KV cache and emits the next logits. See How to use for the call sequence and KV-cache layout.

The audio placeholder token id is 100352. Replace those positions in the prompt with the projector outputs from encoder.onnx before running prompt_encode.onnx.

Files

Each precision tier ships in its own subdirectory (fp32/, int8/, fp16w/). Inside, files use the clean stem (no precision suffix) - the directory name carries the tier. Download a single subdirectory if you only need one precision; the tokeniser, processor, scripts, and metadata at the bundle root are shared across all tiers.

fp32/ - FP32 (reference, full precision) - 15.8 GB total

Use when you need byte-for-byte parity with the upstream PyTorch reference, or as a baseline for quantisation/conversion experiments.

  • fp32/encoder.onnx + fp32/encoder.onnx_data
  • fp32/prompt_encode.onnx + fp32/prompt_encode.onnx_data
  • fp32/decode_step.onnx + fp32/decode_step.onnx_data
  • fp32/embed_tokens.onnx + fp32/embed_tokens.onnx_data

int8/ - INT8 (smallest) - 4.0 GB total

Dynamic weights-only INT8 (MatMulInteger + ConvInteger, all ai.onnx). Mild quality drop on case/punctuation but transcripts remain semantically accurate. Choose when disk or memory is tight.

  • int8/encoder.onnx + int8/encoder.onnx_data
  • int8/prompt_encode.onnx + int8/prompt_encode.onnx_data
  • int8/decode_step.onnx + int8/decode_step.onnx_data
  • int8/embed_tokens.onnx + int8/embed_tokens.onnx_data

fp16w/ - FP16w (recommended for highest quality at smaller-than-FP32 size) - 7.9 GB total

Weights-FP16 with FP32 compute and IO. Each FP32 initializer is rewritten to FP16 storage with a Cast(FP16->FP32) inserted before each consumer; arithmetic and IO stay FP32. Quality is essentially identical to FP32 (mean norm WER 0.04% vs 0.72% for INT8) at 50% of FP32 storage. Choose when you have the disk and want FP32-grade transcripts.
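The storage trick is easy to see in a few lines of numpy: the weight lives in FP16 (half the bytes), is cast back to FP32 at use time, and only weight rounding contributes error, never the accumulation. A minimal sketch of the idea, not the convert_fp.py implementation:

```python
import numpy as np

rng = np.random.default_rng(0)
w = rng.standard_normal((256, 256)).astype(np.float32)  # FP32 initializer as exported
x = rng.standard_normal((4, 256)).astype(np.float32)    # FP32 activation (IO stays FP32)

# FP16w: the initializer is stored in FP16 -- half the storage ...
w_fp16 = w.astype(np.float16)

# ... and a Cast(FP16->FP32) runs before each consumer, so the matmul itself is FP32
y_fp16w = x @ w_fp16.astype(np.float32)
y_fp32 = x @ w

# The only error is FP16 rounding of the weights, not FP16 accumulation
max_err = np.abs(y_fp16w - y_fp32).max()
```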

  • fp16w/encoder.onnx + fp16w/encoder.onnx_data
  • fp16w/prompt_encode.onnx + fp16w/prompt_encode.onnx_data
  • fp16w/decode_step.onnx + fp16w/decode_step.onnx_data
  • fp16w/embed_tokens.onnx + fp16w/embed_tokens.onnx_data

Shared (used by every precision tier)

  • Tokeniser / processor: tokenizer.json, tokenizer_config.json, processor_config.json, chat_template.jinja, special_tokens_map.json, preprocessor_config.json
  • Export scripts: export_speech_2b_ar.py, export_embed_tokens.py, quantise.py, convert_fp.py
  • granite_export_metadata.json (graph IO, parity numbers, toolchain)
  • LICENSE (Apache 2.0)
  • test_fixtures/ - golden inputs/outputs for integration testing. See test_fixtures/README.md.

How to use end-to-end

Audio frontend (shared across variants)

16 kHz mono float32 PCM. Mel filter bank: n_fft=512, hop_length=160, win_length=400, n_mels=80, log10 with -8 dB clamp, no preemphasis. The frontend frame-stacks pairs of mel frames, doubling the feature dim to 160 (so the encoder input is [B, T, 160] not [B, T, 80]). The reference implementation lives in the transformers AutoProcessor for this repo; the parameters above are reproduced for runtimes that build their own STFT. The included test_fixtures/expected_input_features.npy is the post-mel output - if your audio frontend produces the same array on the included sample, your preprocessing is byte-correct.
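For runtimes building their own STFT, the parameters above translate into a numpy-only frontend along these lines. The mel scale (HTK here), padding, and clamp semantics are assumptions on my part; the authoritative implementation is the transformers AutoProcessor, so validate any reimplementation against test_fixtures/expected_input_features.npy:

```python
import numpy as np

SR, N_FFT, HOP, WIN, N_MELS = 16000, 512, 160, 400, 80

def _hz_to_mel(f):                      # HTK mel scale -- an assumption, verify vs fixtures
    return 2595.0 * np.log10(1.0 + f / 700.0)

def _mel_to_hz(m):
    return 700.0 * (10.0 ** (m / 2595.0) - 1.0)

def _mel_filterbank():
    mels = np.linspace(_hz_to_mel(0.0), _hz_to_mel(SR / 2), N_MELS + 2)
    bins = np.floor((N_FFT + 1) * _mel_to_hz(mels) / SR).astype(int)
    fb = np.zeros((N_MELS, N_FFT // 2 + 1), dtype=np.float32)
    for i in range(N_MELS):             # triangular filters between adjacent mel points
        l, c, r = bins[i], bins[i + 1], bins[i + 2]
        fb[i, l:c] = (np.arange(l, c) - l) / max(c - l, 1)
        fb[i, c:r] = (r - np.arange(c, r)) / max(r - c, 1)
    return fb

def log_mel(pcm):
    """16 kHz mono float32 PCM -> [1, T//2, 160] frame-stacked log-mel features."""
    window = np.hanning(WIN).astype(np.float32)
    n_frames = 1 + (len(pcm) - WIN) // HOP              # no edge padding -- an assumption
    frames = np.stack([pcm[i * HOP:i * HOP + WIN] * window for i in range(n_frames)])
    power = np.abs(np.fft.rfft(frames, n=N_FFT)) ** 2   # n_fft=512 zero-pads the 400-sample window
    mel = power @ _mel_filterbank().T
    logmel = np.log10(np.maximum(mel, 1e-10))
    logmel = np.maximum(logmel, logmel.max() - 8.0)     # -8 dB clamp (clamp semantics assumed)
    T = logmel.shape[0] // 2 * 2                        # frame-stack pairs: [T, 80] -> [T//2, 160]
    return logmel[:T].reshape(1, -1, 2 * N_MELS).astype(np.float32)
```

If this matches the fixture on the included sample, downstream encoder outputs should match too.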

Call sequence (autoregressive)

Four graphs and a host-side splice. Output of one is input to the next; only prompt_encode and decode_step consume the KV cache.

1. encoder.onnx            (16 kHz audio mel-bank features)  -> audio_embeds, audio_embed_sizes
2. embed_tokens.onnx       (chat-template token IDs)         -> text_embeds
3. splice (host-side)      (text_embeds with audio_embeds at <|audio|> positions, token id 100352)
                                                              -> inputs_embeds [B, N, 2048]
4. prompt_encode.onnx      (inputs_embeds, position_ids, 4-D attention_mask)
                                                              -> first-token logits, 40-layer KV cache
5. loop decode_step.onnx   (next inputs_embeds via embed_tokens(prev_token), past KV cache)
                                                              -> next logits + grown KV cache
                              ... until argmax(logits) == EOS

The chat template (rendered from chat_template.jinja + tokenizer.json) inserts the <|audio|> placeholder; the host code replaces those positions with rows from audio_embeds before calling prompt_encode. The placeholder token id is 100352.
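The splice in step 3 is a plain row replacement on the host; it is not a graph. A minimal numpy sketch, with shapes as described above:

```python
import numpy as np

AUDIO_TOKEN_ID = 100352

def splice(token_ids, text_embeds, audio_embeds):
    """Replace <|audio|> rows of text_embeds with the encoder's projector outputs.

    token_ids:    int64   [N]         chat-template token IDs
    text_embeds:  float32 [N, 2048]   from embed_tokens.onnx
    audio_embeds: float32 [M, 2048]   from encoder.onnx
    """
    positions = np.flatnonzero(token_ids == AUDIO_TOKEN_ID)
    assert len(positions) == len(audio_embeds), "placeholder count must match audio rows"
    inputs_embeds = text_embeds.copy()
    inputs_embeds[positions] = audio_embeds
    return inputs_embeds[None]  # [1, N, 2048], ready for prompt_encode.onnx
```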

Initial KV cache layout

decode_step.onnx consumes past_key_values.{i}.key and past_key_values.{i}.value for i in [0, 40). For the first decode step (after prompt_encode has already run), feed in the present.{i}.{key,value} tensors that prompt_encode returned. There is no separate "empty cache" ceremony - prompt_encode does the whole prefill.

The KV tensor shapes (GQA with 4 heads, head_dim 128):

past_key_values.{i}.key   : float32 [B, 4, T_past, 128]
past_key_values.{i}.value : float32 [B, 4, T_past, 128]
present.{i}.key           : float32 [B, 4, T_total, 128]    (T_total = T_past + 1)
present.{i}.value         : float32 [B, 4, T_total, 128]
attention_mask            : float32 [B, 1, 1, T_total]      (additive: 0 for valid, -inf for masked)
position_ids              : int64   [B, 1]                  (= T_past at step k)

At step k, set position_ids = T_past, slide the previous present.{i}.{key,value} into the next call's past_key_values.{i}.{key,value}, and grow attention_mask by one column of 0.0. The lm_head MatMul that produces logits is divided by config.text_config.logits_scaling = 8 inside the graph; argmax is unaffected, but if you compare logit values against another runtime, divide-by-8 is already baked in.
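The per-step bookkeeping (position_ids, mask growth, cache threading) can be exercised without the real graphs. A toy-sized sketch with a stub standing in for decode_step.onnx; the real call is an InferenceSession.run over the input/output names above, and 2 layers / head_dim 8 stand in for the model's 40 / 128:

```python
import numpy as np

rng = np.random.default_rng(0)
LAYERS, HEADS, HEAD_DIM, VOCAB = 2, 4, 8, 32  # toy sizes; real model: 40 layers, head_dim 128

def fake_decode_step(position_ids, attention_mask, past):
    """Stand-in for decode_step.onnx: returns logits and a KV cache grown by one column."""
    new_kv = lambda: rng.standard_normal((1, HEADS, 1, HEAD_DIM), dtype=np.float32)
    present = [(np.concatenate([k, new_kv()], axis=2),
                np.concatenate([v, new_kv()], axis=2)) for k, v in past]
    return rng.standard_normal((1, VOCAB), dtype=np.float32), present

# Seed with the present.* tensors prompt_encode returned (here: a fake 5-token prefill).
T_prompt = 5
past = [(rng.standard_normal((1, HEADS, T_prompt, HEAD_DIM), dtype=np.float32),
         rng.standard_normal((1, HEADS, T_prompt, HEAD_DIM), dtype=np.float32))
        for _ in range(LAYERS)]
mask = np.zeros((1, 1, 1, T_prompt), dtype=np.float32)  # additive: 0.0 = attend

tokens = []
for _ in range(3):                                      # real loop: until argmax == EOS
    t_past = past[0][0].shape[2]
    position_ids = np.array([[t_past]], dtype=np.int64)
    mask = np.concatenate([mask, np.zeros((1, 1, 1, 1), dtype=np.float32)], axis=3)
    logits, past = fake_decode_step(position_ids, mask, past)
    tokens.append(int(logits.argmax()))                 # logits already carry the /8 scaling
```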

Tokeniser

We ship tokenizer.json + tokenizer_config.json + chat_template.jinja from ibm-granite/granite-speech-4.1-2b. For Rust, load via tokenizers::Tokenizer::from_file("tokenizer.json") and apply the chat template yourself (the tokeniser crate doesn't render Jinja). The chat template is short: a system role, one user role with the <|audio|> placeholder, and a final <|assistant|> opener.

Runtime / EP notes

  • The graphs are ai.onnx-only at opset 20; no com.microsoft.* ops, no MatMulNBits. They load under the ort 2.0-rc.x Rust crate and onnxruntime Python 1.17 - 1.25.
  • CoreML EP: opset 20 contains a few ops that the CoreML EP doesn't kernel. ORT will silently fall back to CPU for those nodes at session-load time. If you want fully-accelerated MPS, FP16w is the closest available analogue (FP16 weights, FP32 compute) - the encoder is mostly conv + matmul which CoreML can take.
  • CUDA EP: works out of the box at FP32 / INT8. FP16w is also fine - the Cast(FP16->FP32) is a cheap kernel.
  • CPU: VNNI / AVX-512 systems get a meaningful win from INT8 (MatMulInteger + ConvInteger); systems without those land in roughly the same FP32 ballpark.

How the tiers are produced

  • INT8 is dynamic, weights-only, per-channel QInt8 over MatMul + Conv ops. The quantiser emits MatMulInteger + ConvInteger and leaves activations in FP32. The unquantised ~22% of MatMul nodes in the LLM body graphs are activation x activation (attention QK^T and attention_weights x V); dynamic weight-only INT8 cannot quantise those, so this is the expected ceiling, not a coverage gap.
  • FP16w stores weights as FP16 initializers with a Cast(FP16->FP32) inserted before each consumer, so arithmetic and IO stay FP32. Quality matches FP32 within numeric tolerance at ~50% of FP32 storage.
  • embed_tokens is shipped as its own graph in all three tiers. INT8 uses per-row symmetric quantisation rather than the dynamic MatMul/Conv quantiser (Gather is not in that op set), giving the embedding table its own ~4x storage win at INT8.
  • No com.microsoft.* ops are used. Re-validate the op-domain set with assert_pure_ai_onnx in quantise.py / convert_fp.py after any change.
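The per-row embedding scheme in the third point can be sketched in numpy: one FP32 scale per vocab row, symmetric around zero, so each Gather dequantises cheaply. Illustrative only, not the code in export_embed_tokens.py:

```python
import numpy as np

def quantise_per_row(table):
    """Per-row symmetric INT8 for an embedding table [vocab, hidden]: one FP32 scale per row."""
    scale = np.abs(table).max(axis=1, keepdims=True) / 127.0
    scale = np.maximum(scale, 1e-12)                    # guard all-zero rows
    q = np.clip(np.round(table / scale), -127, 127).astype(np.int8)
    return q, scale.astype(np.float32)

def lookup(q, scale, token_id):
    """Gather the INT8 row, then dequantise with that row's own scale."""
    return q[token_id].astype(np.float32) * scale[token_id]
```

Per-row error is bounded by half a quantisation step of that row's scale, which is why the embedding table tolerates INT8 well.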

Parity

Parity is taken against the upstream PyTorch reference on a single LibriSpeech clip (10226_10111_000000.wav, 8.43 seconds, 844 mel frames). FP32 graphs match the reference within numeric tolerance; INT8 graphs are validated in argmax-only mode (logit values shift but token argmax is preserved, so the decoded transcript is unchanged).

Encoder (numeric output, no argmax decoding):

precision   max-abs-err   mean-abs-err   p99-abs-err
FP32        4.48e-06      1.24e-07       6.46e-07
INT8        0.169         0.0109         0.0447

LLM stages (argmax decoding; INT8 logit max-abs delta is large but argmax is preserved):

graph           precision   max-abs-err   argmax mismatches   transcript match
prompt_encode   FP32        0.000364      0/190               Y
prompt_encode   INT8        10.1          58/190              Y
decode_step     FP32        n/a           0/51                Y
decode_step     INT8        5.76          0/51                Y

Multi-clip transcript parity

Parity was also measured on three additional 16 kHz mono clips covering longer utterances (39 to 94 seconds) with single- and two-speaker conversational content. Word error rate (WER) and Levenshtein edit distance were computed against the upstream PyTorch reference, measured end-to-end through the full ONNX pipeline (no PyTorch encoder fallback).

WER is the strict word-error rate against the PyTorch reference (case + punctuation sensitive). norm WER lower-cases both transcripts and strips punctuation before comparing - the dominant driver of strict WER on this model at INT8 is capitalisation and trailing punctuation drift, not actual word substitution. Pick whichever metric matches your downstream task. FP16w is essentially FP32 quality at 50% of FP32 storage; INT8 is the smallest tier with a mild quality drop.
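The two metrics differ only in a normalisation pass before the word-level edit distance. A sketch of both; the parity harness's exact normalisation may differ:

```python
import re

def wer(ref, hyp, normalise=False):
    """Word error rate via word-level Levenshtein distance.

    normalise=True lower-cases and strips punctuation first (the "norm WER" column).
    """
    def words(s):
        if normalise:
            s = re.sub(r"[^\w\s']", "", s.lower())
        return s.split()
    r, h = words(ref), words(hyp)
    d = list(range(len(h) + 1))                 # one-row DP over the edit-distance table
    for i in range(1, len(r) + 1):
        prev, d[0] = d[0], i
        for j in range(1, len(h) + 1):
            prev, d[j] = d[j], min(d[j] + 1,            # deletion
                                   d[j - 1] + 1,        # insertion
                                   prev + (r[i - 1] != h[j - 1]))  # substitution
    return d[len(h)] / max(len(r), 1)
```

Under strict WER a casing or trailing-punctuation change counts as a full word substitution, which is exactly the INT8 drift described above.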

Clip              Duration   FP32 byte-exact   INT8 byte-exact   INT8 WER   INT8 norm WER   FP16w byte-exact   FP16w WER   FP16w norm WER
is-it-more-wood   46.9 s     Y                 N                 37.4%      0.00%           Y                  0.0%        0.00%
two-speakers-1    93.8 s     Y                 N                 3.1%       1.36%           Y                  0.0%        0.00%
two-speakers-2    38.8 s     Y                 N                 23.5%      0.00%           Y                  0.0%        0.00%

Raw multi-clip data, including full transcripts, lives in the multi_clip_parity block of granite_export_metadata.json.

Reference transcript:

After his nap, Timothy lazily stretched, first one gray velvet foot, then another, strolled indolently to his plate, turning over the food, carefully selecting choice bits, nosing out that which he scorned upon the clean hearth

The FP32 and FP16w paths reproduce this transcript exactly on the test clip, and INT8 reproduces it within argmax-only tolerance (token argmax preserved).

Toolchain

  • transformers 5.8.0
  • torch 2.11.0
  • onnx 1.21.0
  • onnxruntime 1.25.1
  • exporter: torch.onnx.export TorchScript path (dynamo=False)
  • opset: 20 (ai.onnx only)
  • IR version: 10
  • external data layout: single <stem>.onnx_data sidecar per graph

Compatibility

Targeted at the ort 2.0-rc.x Rust crate. Compatible with onnxruntime Python 1.17 through 1.25. No com.microsoft ops are used. Graphs were emitted via the TorchScript path (torch.onnx.export(..., dynamo=False)); the dynamo exporter was deliberately avoided because it injects aten::* ops ort does not understand. See the Runtime / EP notes above for CoreML / CUDA / CPU specifics including which precision tier to pick per backend.

Reproducing the export

The included export, quantisation, and conversion scripts regenerate every artefact in this bundle. The export pipeline writes flat-layout files into exports/<variant>/; the per-tier subdirectory layout you see in this repo is produced by scripts/stage_bundles.py (in the source tree at https://github.com/sammcj/granite-speech-4.1-onnx). From a checkout:

python export_speech_2b_ar.py \
    --model-dir <path-to-ibm-granite/granite-speech-4.1-2b> \
    --out-dir exports/granite-speech-4.1-2b
python export_embed_tokens.py --variant base

# INT8 (per-variant default for 2b: exclude Conv ops in encoder layer 0).
# embed_tokens uses a hand-rolled per-row INT8 path baked into export_embed_tokens.py;
# the standard `quantise.py` (MatMul/Conv only) does not touch Gather ops.
python quantise.py --input exports/granite-speech-4.1-2b/encoder.onnx       --output exports/granite-speech-4.1-2b/encoder_int8.onnx       --exclude-pattern '/encoder/layers\.0/conv/.*'
python quantise.py --input exports/granite-speech-4.1-2b/prompt_encode.onnx --output exports/granite-speech-4.1-2b/prompt_encode_int8.onnx
python quantise.py --input exports/granite-speech-4.1-2b/decode_step.onnx   --output exports/granite-speech-4.1-2b/decode_step_int8.onnx

# FP16w (weights-FP16, FP32 compute - no exclusions needed).
# embed_tokens FP16w is also produced by export_embed_tokens.py in the same run.
python convert_fp.py --precision fp16w --input exports/granite-speech-4.1-2b/encoder.onnx       --output exports/granite-speech-4.1-2b/encoder_fp16w.onnx
python convert_fp.py --precision fp16w --input exports/granite-speech-4.1-2b/prompt_encode.onnx --output exports/granite-speech-4.1-2b/prompt_encode_fp16w.onnx
python convert_fp.py --precision fp16w --input exports/granite-speech-4.1-2b/decode_step.onnx   --output exports/granite-speech-4.1-2b/decode_step_fp16w.onnx

Licence

Apache 2.0 for both the upstream IBM model and this ONNX export. See LICENSE for the full text.
