| --- |
| license: mit |
| tags: |
| - language-model |
| - multi-token-prediction |
| - push-forward-language-model |
| - text-generation |
| - distillation |
| datasets: |
| - lm1b |
| - openwebtext |
| arxiv: "2606.10820" |
| --- |
| |
| # K-Forcing: Joint Next-K-Token Decoding via Push-Forward Language Modeling |
|
|
| <p align="center"> |
| <a href="https://arxiv.org/abs/2606.10820"><img src="https://img.shields.io/badge/arXiv-2606.10820-b31b1b.svg" alt="arXiv"></a> |
| <a href="https://github.com/alibaba-damo-academy/K-Forcing"><img src="https://img.shields.io/badge/GitHub-Code-blue?logo=github" alt="GitHub"></a> |
| </p> |
|
|
| ## Overview |
|
|
| K-Forcing distills an autoregressive (AR) language model into a **push-forward language model (PFLM)** that generates **k tokens in one forward pass**. It maps k independent uniform noise variables to k future tokens jointly via an inverse-CDF construction, enabling fixed-length multi-token decoding that is fully compatible with standard KV-cache batch serving. |
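
To make the inverse-CDF idea concrete, the sketch below shows how one uniform noise value selects a token from a categorical distribution, and how k noise values then determine k tokens when the conditionals are applied one after another. It illustrates the general construction only and is not the repository's implementation: `inverse_cdf_token`, `push_forward_sequential`, and the three-token `toy_next_token_probs` teacher are hypothetical names introduced here. Per the description above, a PFLM produces the k tokens in a single forward pass, whereas this reference loop queries the toy teacher once per token.

```python
from bisect import bisect_right
from itertools import accumulate


def inverse_cdf_token(probs, u):
    """Return the token index whose CDF interval contains u, for u in [0, 1)."""
    cdf = list(accumulate(probs))
    idx = bisect_right(cdf, u)
    # Guard against floating-point round-off leaving the last CDF entry below 1.
    return min(idx, len(probs) - 1)


def push_forward_sequential(next_token_probs, context, noises):
    """Deterministically map k uniform noises to k tokens, one conditional at a time."""
    tokens = []
    for u in noises:
        probs = next_token_probs(context + tokens)
        tokens.append(inverse_cdf_token(probs, u))
    return tokens


def toy_next_token_probs(prefix):
    """Hypothetical 3-token teacher that favours repeating the last token."""
    if not prefix:
        return [0.5, 0.25, 0.25]
    probs = [0.25, 0.25, 0.25]
    probs[prefix[-1]] = 0.5
    return probs


if __name__ == "__main__":
    # Three noise values yield three tokens after the context [0].
    print(push_forward_sequential(toy_next_token_probs, [0], [0.1, 0.6, 0.9]))  # [0, 1, 2]
```

Because the mapping is deterministic once the noise is fixed, feeding the same noise values and context always yields the same tokens; all randomness lives in the uniform draws.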
|
|
| **Key results**: ~2.4–3.5× batch-serving throughput speedup with modest quality degradation on LM1B and OpenWebText, using ~100M-parameter Transformers.
|
|
| ## Checkpoints |
|
|
| This repository contains four checkpoints: |
|
|
| | File | Model | Dataset | Parameters | Description | |
| |------|-------|---------|------------|-------------| |
| | `ar_openwebtxt.ckpt` | AR | OpenWebText | ~100M | Autoregressive teacher model (GPT-2 tokenizer, seq_len=1024) | |
| | `ar_best_lm1b.ckpt` | AR | LM1B | ~100M | Autoregressive teacher model (custom tokenizer, seq_len=128) | |
| | `pflm_owt_k4.ckpt` | PFLM (k=4) | OpenWebText | ~100M | Push-forward LM, decodes 4 tokens per forward pass | |
| | `pflm_lm1b_k4.ckpt` | PFLM (k=4) | LM1B | ~100M | Push-forward LM, decodes 4 tokens per forward pass | |
|
|
| All models share a 12-layer causal Transformer backbone (768 hidden dim, 12 heads), following the architecture from [MDLM](https://arxiv.org/abs/2406.07524) (Sahoo et al., 2024). |
|
|
| ## Download |
|
|
| ```python |
| from huggingface_hub import hf_hub_download |
| |
| # Download a specific checkpoint |
| ckpt_path = hf_hub_download( |
| repo_id="zwave/K-Forcing", |
| filename="pflm_owt_k4.ckpt", # or: ar_openwebtxt.ckpt, ar_best_lm1b.ckpt, pflm_lm1b_k4.ckpt |
| ) |
| ``` |
|
|
| Or download all checkpoints at once: |
| ```python |
| from huggingface_hub import snapshot_download |
| |
| snapshot_download(repo_id="zwave/K-Forcing", local_dir="./checkpoints") |
| ``` |
|
|
| Or via CLI: |
| ```bash |
| huggingface-cli download zwave/K-Forcing --local-dir ./checkpoints |
| ``` |
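
When scripting downloads, a small lookup can help avoid typos in the checkpoint filenames listed in the Checkpoints table. The sketch below is a convenience helper, not part of the repository: `CHECKPOINTS` and `checkpoint_filename` are hypothetical names, the `ar`/`pflm` and `owt` keys mirror the `--model` and `--task` values used by the inference commands in this README, and `lm1b` as a task key is an assumption.

```python
# Filenames copied from the Checkpoints table; the keys are (model, task) pairs.
CHECKPOINTS = {
    ("ar", "owt"): "ar_openwebtxt.ckpt",
    ("ar", "lm1b"): "ar_best_lm1b.ckpt",      # "lm1b" task key is an assumption
    ("pflm", "owt"): "pflm_owt_k4.ckpt",
    ("pflm", "lm1b"): "pflm_lm1b_k4.ckpt",    # "lm1b" task key is an assumption
}


def checkpoint_filename(model: str, task: str) -> str:
    """Return the checkpoint filename for a (model, task) pair, or raise ValueError."""
    key = (model.lower(), task.lower())
    if key not in CHECKPOINTS:
        known = ", ".join(f"{m}/{t}" for m, t in sorted(CHECKPOINTS))
        raise ValueError(f"unknown model/task {model}/{task}; expected one of: {known}")
    return CHECKPOINTS[key]


if __name__ == "__main__":
    print(checkpoint_filename("pflm", "owt"))  # pflm_owt_k4.ckpt
```

The returned string can be passed as the `filename` argument of `hf_hub_download` shown earlier in this section.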
|
|
| ## Usage |
|
|
| Clone the [K-Forcing repository](https://github.com/alibaba-damo-academy/K-Forcing) and follow the setup instructions there:
|
|
| ```bash |
| git clone https://github.com/alibaba-damo-academy/K-Forcing.git |
| cd K-Forcing |
| |
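| # Set up the environment: fetch the prebuilt flash-attn 2.5.6 wheel
| # (CUDA 12.2, PyTorch 2.2, CPython 3.9, Linux x86_64), then install dependencies with uv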
| mkdir -p wheels |
| wget -P wheels https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.6/flash_attn-2.5.6+cu122torch2.2cxx11abiFALSE-cp39-cp39-linux_x86_64.whl |
| uv sync |
| ``` |
|
|
| ### AR Inference |
|
|
| ```bash |
| python batch_inference_with_prefix.py \ |
| --model ar --task owt \ |
| --ckpt_path ./checkpoints/ar_openwebtxt.ckpt \ |
| --prefix_file assets/prefix_owt_examples.jsonl \ |
| --batch_size 4 --n_per_prefix 1 |
| ``` |
|
|
| ### PFLM Inference (K=2 tokens per forward pass) |
|
|
| ```bash |
| python batch_inference_with_prefix.py \ |
| --model pflm --task owt \ |
| --ckpt_path ./checkpoints/pflm_owt_k4.ckpt \ |
| --prefix_file assets/prefix_owt_examples.jsonl \ |
| --batch_size 4 --n_per_prefix 1 --K 2 --freq_penalty 0.3 |
| ``` |
|
|
| The PFLM checkpoint was trained with k=4 and supports inference with any K ≤ 4, which is why the command above can pass `--K 2` to `pflm_owt_k4.ckpt`.
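
As a rough guide to what the choice of K means, the sketch below counts decoding forward passes for a given number of new tokens. It is plain arithmetic under the assumption that every pass emits exactly K tokens and that prefill is ignored; `forward_passes` and its `max_k` parameter are hypothetical names, not options of the inference script. The pass count is not a throughput prediction, since each pass also processes K noise tokens.

```python
import math


def forward_passes(num_new_tokens: int, k: int, max_k: int = 4) -> int:
    """Number of decoding passes needed to emit num_new_tokens, k tokens per pass."""
    if not 1 <= k <= max_k:
        raise ValueError(f"K must be between 1 and {max_k}, got {k}")
    if num_new_tokens < 0:
        raise ValueError("num_new_tokens must be non-negative")
    # The final pass may be only partially used when k does not divide the length.
    return math.ceil(num_new_tokens / k)


if __name__ == "__main__":
    for k in (1, 2, 4):
        print(k, forward_passes(128, k))  # 1 128, then 2 64, then 4 32
```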
|
|
| ## Architecture |
|
|
| - **Backbone**: 12-layer causal Transformer (~100M params), 768 hidden dim, 12 heads |
| - **Noise encoder**: sinusoidal + MLP, encodes each Uniform(0,1) noise variable into a token embedding |
| - **Fully causal design**: noise tokens attend causally, so each zⱼ sees the context plus z₁, …, zⱼ
| - **Shared prediction head**: same linear head as AR, applied at each noise-token position |
| - **Training**: progressive self-forcing distillation (AR → k=1 → k=2 → k=4) |
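
Two of the components listed above can be sketched in a few lines of plain Python. The first function computes sinusoidal features of a scalar noise value (the MLP that follows them is omitted), and the second builds the causal visibility pattern for a context followed by k noise tokens. Both are illustrations under stated assumptions rather than the repository's code: the function names, the feature dimension, and the Transformer-style geometric frequency schedule with `max_period` are all hypothetical choices.

```python
import math


def sinusoidal_features(u: float, dim: int = 8, max_period: float = 10000.0):
    """Sin/cos features of a scalar u; the frequency schedule is an assumption."""
    half = dim // 2
    freqs = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
    return [math.sin(u * f) for f in freqs] + [math.cos(u * f) for f in freqs]


def causal_mask(context_len: int, k: int):
    """mask[row][col] is True when position row may attend to position col.

    Positions 0..context_len-1 hold context tokens; the next k hold z_1..z_k.
    """
    n = context_len + k
    return [[col <= row for col in range(n)] for row in range(n)]


if __name__ == "__main__":
    print(len(sinusoidal_features(0.37)))  # 8
    mask = causal_mask(context_len=3, k=2)
    # z_1 (row 3) sees the three context tokens and itself, but not z_2.
    print(mask[3])  # [True, True, True, True, False]
```

The mask is the ordinary lower-triangular causal pattern, which is consistent with the statement that each zⱼ sees the context plus z₁ through zⱼ and with the compatibility with standard KV caching noted in the overview.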
|
|
| ## Citation |
|
|
| ```bibtex |
| @misc{tang2026kforcingjointnextktokendecoding, |
| title={K-Forcing: Joint Next-K-Token Decoding via Push-Forward Language Modeling}, |
| author={Zhiwei Tang and Yuanyu He and Yizheng Han and Wangbo Zhao and Jiasheng Tang and Fan Wang and Bohan Zhuang}, |
| year={2026}, |
| eprint={2606.10820}, |
| archivePrefix={arXiv}, |
| primaryClass={cs.LG}, |
| url={https://arxiv.org/abs/2606.10820}, |
| } |
| ``` |
|
|
| ## License |
|
|
| This project is licensed under the MIT License. |
|
|