---
license: mit
tags:
- language-model
- multi-token-prediction
- push-forward-language-model
- text-generation
- distillation
datasets:
- lm1b
- openwebtext
arxiv: "2606.10820"
---

# K-Forcing: Joint Next-K-Token Decoding via Push-Forward Language Modeling

[arXiv](https://arxiv.org/abs/2606.10820) [GitHub](https://github.com/alibaba-damo-academy/K-Forcing)

## Overview

K-Forcing distills an autoregressive (AR) language model into a **push-forward language model (PFLM)** that generates **k tokens in one forward pass**. It maps k independent uniform noise variables to k future tokens jointly via an inverse-CDF construction, enabling fixed-length multi-token decoding that is fully compatible with standard KV-cache batch serving.

**Key results**: ~2.4–3.5× batch-serving throughput speedup with modest quality degradation on LM1B and OpenWebText, using ~100M-parameter Transformers.

## Checkpoints

This repository contains four checkpoints:

| File | Model | Dataset | Parameters | Description |
|------|-------|---------|------------|-------------|
| `ar_openwebtxt.ckpt` | AR | OpenWebText | ~100M | Autoregressive teacher model (GPT-2 tokenizer, seq_len=1024) |
| `ar_best_lm1b.ckpt` | AR | LM1B | ~100M | Autoregressive teacher model (custom tokenizer, seq_len=128) |
| `pflm_owt_k4.ckpt` | PFLM (k=4) | OpenWebText | ~100M | Push-forward LM, decodes 4 tokens per forward pass |
| `pflm_lm1b_k4.ckpt` | PFLM (k=4) | LM1B | ~100M | Push-forward LM, decodes 4 tokens per forward pass |

All models share a 12-layer causal Transformer backbone (768 hidden dim, 12 heads), following the architecture from [MDLM](https://arxiv.org/abs/2406.07524) (Sahoo et al., 2024).
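The inverse-CDF idea from the overview can be illustrated with a toy example. This is a minimal sketch under stated assumptions, not the repository's training code. Inverse-CDF sampling maps a uniform draw u to the smallest token index whose cumulative probability exceeds u. `toy_teacher_probs` is a hypothetical stand-in for the AR teacher's next-token distribution. Once the noise z₁..z_k is fixed, an AR rollout becomes a deterministic function of the context and the noise, and token j depends only on z₁..z_j. Under this illustrative reading, that is the kind of joint mapping a push-forward model would learn to produce in a single pass. See the paper for the exact construction.

```python
import bisect
import itertools
import math
import random


def inverse_cdf_sample(probs, u):
    """Return the smallest token index i with CDF(i) > u, for u in [0, 1)."""
    cdf = list(itertools.accumulate(probs))
    idx = bisect.bisect_right(cdf, u)
    # Clamp in case floating-point round-off leaves cdf[-1] slightly below u.
    return min(idx, len(probs) - 1)


def toy_teacher_probs(prefix, vocab_size=5):
    """Hypothetical AR teacher: favors the token after the last one (illustration only)."""
    last = prefix[-1] if prefix else 0
    logits = [2.0 if t == (last + 1) % vocab_size else 0.0 for t in range(vocab_size)]
    norm = sum(math.exp(l) for l in logits)
    return [math.exp(l) / norm for l in logits]


def teacher_rollout(context, noise):
    """Decode len(noise) tokens autoregressively, consuming one uniform draw per token."""
    tokens = list(context)
    out = []
    for u in noise:
        x = inverse_cdf_sample(toy_teacher_probs(tokens), u)
        tokens.append(x)
        out.append(x)
    return out


rng = random.Random(0)
k = 4
context = [3, 1]
noise = [rng.random() for _ in range(k)]  # z_1..z_k ~ Uniform(0, 1)
tokens = teacher_rollout(context, noise)
print(tokens)

# Fixed noise gives fixed tokens, and perturbing z_k leaves tokens 1..k-1 unchanged.
assert teacher_rollout(context, noise) == tokens
assert teacher_rollout(context, noise[:-1] + [0.999])[:-1] == tokens[:-1]
```

The causal dependence of token j on z₁..z_j mirrors the "fully causal" noise-token attention listed under Architecture.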
## Download

```python
from huggingface_hub import hf_hub_download

# Download a specific checkpoint
ckpt_path = hf_hub_download(
    repo_id="zwave/K-Forcing",
    filename="pflm_owt_k4.ckpt",  # or: ar_openwebtxt.ckpt, ar_best_lm1b.ckpt, pflm_lm1b_k4.ckpt
)
```

Or download all checkpoints at once:

```python
from huggingface_hub import snapshot_download

snapshot_download(repo_id="zwave/K-Forcing", local_dir="./checkpoints")
```

Or via CLI:

```bash
huggingface-cli download zwave/K-Forcing --local-dir ./checkpoints
```

## Usage

Clone the [K-Forcing repository](https://github.com/alibaba-damo-academy/K-Forcing) and follow the setup instructions there:

```bash
git clone https://github.com/alibaba-damo-academy/K-Forcing.git
cd K-Forcing

# Set up the environment
mkdir -p wheels
wget -P wheels https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.6/flash_attn-2.5.6+cu122torch2.2cxx11abiFALSE-cp39-cp39-linux_x86_64.whl
uv sync
```

### AR Inference

```bash
python batch_inference_with_prefix.py \
    --model ar --task owt \
    --ckpt_path ./checkpoints/ar_openwebtxt.ckpt \
    --prefix_file assets/prefix_owt_examples.jsonl \
    --batch_size 4 --n_per_prefix 1
```

### PFLM Inference (K=2 tokens per forward pass)

```bash
python batch_inference_with_prefix.py \
    --model pflm --task owt \
    --ckpt_path ./checkpoints/pflm_owt_k4.ckpt \
    --prefix_file assets/prefix_owt_examples.jsonl \
    --batch_size 4 --n_per_prefix 1 --K 2 --freq_penalty 0.3
```

The PFLM checkpoint trained with k=4 supports inference with any K ≤ 4.
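The two PFLM inference knobs above, `--K` and `--freq_penalty`, can be reasoned about with a small sketch. It assumes an OpenAI-style additive frequency penalty, in which each token's logit is reduced by `penalty × count` of its previous occurrences. The repository's actual `--freq_penalty` implementation may differ. The helper names `apply_frequency_penalty` and `num_forward_passes` are hypothetical. The pass count gives only the ideal reduction in sequential steps. Measured throughput also depends on per-pass cost, batching, and hardware.

```python
import math
from collections import Counter


def apply_frequency_penalty(logits, generated, penalty):
    """Subtract penalty * (occurrences so far) from each token's logit.

    Assumption: OpenAI-style additive penalty; not necessarily the repo's formula.
    """
    counts = Counter(generated)
    return [logit - penalty * counts[tok] for tok, logit in enumerate(logits)]


def num_forward_passes(num_new_tokens, K):
    """Sequential decoding passes needed when every pass emits K tokens."""
    return math.ceil(num_new_tokens / K)


logits = [1.0, 1.0, 1.0]
# Token 0 was generated twice and token 2 once, so both are pushed down.
print(apply_frequency_penalty(logits, generated=[0, 0, 2], penalty=0.3))
for K in (1, 2, 4):
    print(K, num_forward_passes(128, K))
```

Under these assumptions, `--K 2` turns a 128-token continuation into 64 sequential passes, and K=4 turns it into 32.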
## Architecture

- **Backbone**: 12-layer causal Transformer (~100M params), 768 hidden dim, 12 heads
- **Noise encoder**: sinusoidal + MLP, encodes each Uniform(0,1) noise variable into a token embedding
- **Fully causal design**: noise tokens attend causally, so each zⱼ sees the context plus z₁ through zⱼ
- **Shared prediction head**: the same linear head as the AR model, applied at each noise-token position
- **Training**: progressive self-forcing distillation (AR → k=1 → k=2 → k=4)

## Citation

```bibtex
@misc{tang2026kforcingjointnextktokendecoding,
      title={K-Forcing: Joint Next-K-Token Decoding via Push-Forward Language Modeling},
      author={Zhiwei Tang and Yuanyu He and Yizheng Han and Wangbo Zhao and Jiasheng Tang and Fan Wang and Bohan Zhuang},
      year={2026},
      eprint={2606.10820},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2606.10820},
}
```

## License

This project is licensed under the MIT License.