---
license: mit
tags:
- language-model
- multi-token-prediction
- push-forward-language-model
- text-generation
- distillation
datasets:
- lm1b
- openwebtext
arxiv: "2606.10820"
---
# K-Forcing: Joint Next-K-Token Decoding via Push-Forward Language Modeling
<p align="center">
<a href="https://arxiv.org/abs/2606.10820"><img src="https://img.shields.io/badge/arXiv-2606.10820-b31b1b.svg" alt="arXiv"></a>
<a href="https://github.com/alibaba-damo-academy/K-Forcing"><img src="https://img.shields.io/badge/GitHub-Code-blue?logo=github" alt="GitHub"></a>
</p>
## Overview
K-Forcing distills an autoregressive (AR) language model into a **push-forward language model (PFLM)** that generates **k tokens in one forward pass**. It maps k independent uniform noise variables to k future tokens jointly via an inverse-CDF construction, enabling fixed-length multi-token decoding that is fully compatible with standard KV-cache batch serving.
**Key results**: ~2.4–3.5× batch-serving throughput speedup with modest quality degradation on LM1B and OpenWebText, using ~100M-parameter Transformers.
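To make the inverse-CDF idea concrete, here is a minimal pure-Python sketch of how uniform noise values can be mapped deterministically to tokens. It is an illustration only, not the repository's implementation: the helper names `inverse_cdf_sample` and `push_forward` and the toy conditional table `cond_probs` are hypothetical, and the loop applies one inverse-CDF step per noise value, whereas a PFLM is trained to produce all k tokens in a single forward pass.

```python
from bisect import bisect_right
from itertools import accumulate

def inverse_cdf_sample(probs, u):
    """Map a uniform draw u in [0, 1) to a token index via the inverse CDF."""
    cdf = list(accumulate(probs))        # cumulative probabilities
    idx = bisect_right(cdf, u)           # first index whose CDF value exceeds u
    return min(idx, len(probs) - 1)      # guard against floating-point round-off

def push_forward(cond_probs, start, noise):
    """Deterministically map noise values to tokens, one inverse-CDF step each."""
    tokens, prev = [], start
    for u in noise:
        prev = inverse_cdf_sample(cond_probs[prev], u)
        tokens.append(prev)
    return tokens

# Hypothetical toy "teacher": next-token distribution given the previous token.
cond_probs = {
    0: [0.1, 0.6, 0.3],
    1: [0.5, 0.25, 0.25],
    2: [0.2, 0.2, 0.6],
}

tokens = push_forward(cond_probs, start=0, noise=[0.5, 0.9, 0.1])
print(tokens)  # [1, 2, 0]
```

Because the mapping is deterministic given the noise, the same noise values always yield the same tokens; all randomness lives in the uniform draws.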
## Checkpoints
This repository contains four checkpoints:
| File | Model | Dataset | Parameters | Description |
|------|-------|---------|------------|-------------|
| `ar_openwebtxt.ckpt` | AR | OpenWebText | ~100M | Autoregressive teacher model (GPT-2 tokenizer, seq_len=1024) |
| `ar_best_lm1b.ckpt` | AR | LM1B | ~100M | Autoregressive teacher model (custom tokenizer, seq_len=128) |
| `pflm_owt_k4.ckpt` | PFLM (k=4) | OpenWebText | ~100M | Push-forward LM, decodes 4 tokens per forward pass |
| `pflm_lm1b_k4.ckpt` | PFLM (k=4) | LM1B | ~100M | Push-forward LM, decodes 4 tokens per forward pass |
All models share a 12-layer causal Transformer backbone (768 hidden dim, 12 heads), following the architecture from [MDLM](https://arxiv.org/abs/2406.07524) (Sahoo et al., 2024).
## Download
```python
from huggingface_hub import hf_hub_download
# Download a specific checkpoint
ckpt_path = hf_hub_download(
    repo_id="zwave/K-Forcing",
    filename="pflm_owt_k4.ckpt",  # or: ar_openwebtxt.ckpt, ar_best_lm1b.ckpt, pflm_lm1b_k4.ckpt
)
```
Or download all checkpoints at once:
```python
from huggingface_hub import snapshot_download
snapshot_download(repo_id="zwave/K-Forcing", local_dir="./checkpoints")
```
Or via CLI:
```bash
huggingface-cli download zwave/K-Forcing --local-dir ./checkpoints
```
## Usage
Clone the [K-Forcing repository](https://github.com/alibaba-damo-academy/K-Forcing) and follow the setup instructions there:
```bash
git clone https://github.com/alibaba-damo-academy/K-Forcing.git
cd K-Forcing
# Setup environment
mkdir -p wheels
wget -P wheels https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.6/flash_attn-2.5.6+cu122torch2.2cxx11abiFALSE-cp39-cp39-linux_x86_64.whl
uv sync
```
### AR Inference
```bash
python batch_inference_with_prefix.py \
    --model ar --task owt \
    --ckpt_path ./checkpoints/ar_openwebtxt.ckpt \
    --prefix_file assets/prefix_owt_examples.jsonl \
    --batch_size 4 --n_per_prefix 1
```
### PFLM Inference (K=2 tokens per forward pass)
```bash
python batch_inference_with_prefix.py \
    --model pflm --task owt \
    --ckpt_path ./checkpoints/pflm_owt_k4.ckpt \
    --prefix_file assets/prefix_owt_examples.jsonl \
    --batch_size 4 --n_per_prefix 1 --K 2 --freq_penalty 0.3
```
The PFLM checkpoint trained with k=4 supports inference with any K ≤ 4.
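As a rough illustration of how the choice of K affects decoding, the sketch below counts the forward passes needed to emit a fixed number of new tokens when each pass produces K tokens. The function name `num_forward_passes` and the 128-token target are assumptions made for this example. The pass count alone does not determine throughput, so these ratios should not be read as the measured speedup.

```python
import math

def num_forward_passes(n_new_tokens, k):
    """Decoding steps needed when each forward pass emits k tokens."""
    if k < 1:
        raise ValueError("k must be a positive integer")
    return math.ceil(n_new_tokens / k)

# Compare the K values supported by a checkpoint trained with k=4.
for k in (1, 2, 4):
    print(f"K={k}: {num_forward_passes(128, k)} forward passes")
```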
## Architecture
- **Backbone**: 12-layer causal Transformer (~100M params), 768 hidden dim, 12 heads
- **Noise encoder**: sinusoidal + MLP, encodes each Uniform(0,1) noise variable into a token embedding
- **Fully causal design**: noise tokens attend causally, so each zⱼ sees the context plus z₁..zⱼ
- **Shared prediction head**: same linear head as AR, applied at each noise-token position
- **Training**: progressive self-forcing distillation (AR → k=1 → k=2 → k=4)
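The fully causal attention pattern described in the list above can be sketched as a boolean visibility mask. This is a minimal illustration that assumes the k noise tokens are appended directly after the context tokens in one sequence; the actual layout in the repository may differ, and `visibility_mask` is a hypothetical helper.

```python
def visibility_mask(n_context, k):
    """Return mask[i][j] = True if position i may attend to position j.

    Positions 0..n_context-1 are context tokens (assumed layout);
    the next k positions are the noise tokens z_1..z_k.
    """
    n = n_context + k
    return [[j <= i for j in range(n)] for i in range(n)]

mask = visibility_mask(n_context=3, k=2)

# Print one row per position; "x" marks a visible position.
for row in mask:
    print("".join("x" if visible else "." for visible in row))
```

With three context tokens and k=2, the row for z₁ (index 3) covers the context and z₁ itself, while the row for z₂ (index 4) additionally covers z₁.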
## Citation
```bibtex
@misc{tang2026kforcingjointnextktokendecoding,
  title={K-Forcing: Joint Next-K-Token Decoding via Push-Forward Language Modeling},
  author={Zhiwei Tang and Yuanyu He and Yizheng Han and Wangbo Zhao and Jiasheng Tang and Fan Wang and Bohan Zhuang},
  year={2026},
  eprint={2606.10820},
  archivePrefix={arXiv},
  primaryClass={cs.LG},
  url={https://arxiv.org/abs/2606.10820},
}
```
## License
This project is licensed under the MIT License.