---
license: mit
tags:
  - language-model
  - multi-token-prediction
  - push-forward-language-model
  - text-generation
  - distillation
datasets:
  - lm1b
  - openwebtext
arxiv: "2606.10820"
---

# K-Forcing: Joint Next-K-Token Decoding via Push-Forward Language Modeling

<p align="center">
  <a href="https://arxiv.org/abs/2606.10820"><img src="https://img.shields.io/badge/arXiv-2606.10820-b31b1b.svg" alt="arXiv"></a>
  <a href="https://github.com/alibaba-damo-academy/K-Forcing"><img src="https://img.shields.io/badge/GitHub-Code-blue?logo=github" alt="GitHub"></a>
</p>

## Overview

K-Forcing distills an autoregressive (AR) language model into a **push-forward language model (PFLM)** that generates **k tokens in one forward pass**. It maps k independent uniform noise variables to k future tokens jointly via an inverse-CDF construction, enabling fixed-length multi-token decoding that is fully compatible with standard KV-cache batch serving.

**Key results**: ~2.4–3.5× batch-serving throughput speedup with modest quality degradation on LM1B and OpenWebText, using ~100M-parameter Transformers.
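
As a rough illustration of the inverse-CDF idea mentioned in the overview, the sketch below maps uniform noise values to tokens using a toy conditional distribution. The names `inverse_cdf_sample`, `toy_next_token_probs`, and `push_forward` are hypothetical and do not come from the K-Forcing code. The loop is sequential, whereas the PFLM is a network that emits all k tokens in one forward pass, so this shows only the noise-to-token mapping, not the model itself.

```python
from bisect import bisect_right
from itertools import accumulate


def inverse_cdf_sample(probs, u):
    """Map a uniform draw u in [0, 1) to a category index via the inverse CDF."""
    cdf = list(accumulate(probs))
    # Index of the first cumulative probability strictly greater than u.
    idx = bisect_right(cdf, u)
    # Guard against floating-point rounding when the CDF ends just below 1.
    return min(idx, len(probs) - 1)


def toy_next_token_probs(prefix):
    """Hypothetical stand-in for an AR model over a 3-token vocabulary."""
    table = {0: [0.2, 0.5, 0.3], 1: [0.6, 0.1, 0.3], 2: [0.3, 0.3, 0.4]}
    last = prefix[-1] if prefix else 0
    return table[last]


def push_forward(prefix, noise):
    """Deterministically map k noise values to k tokens, one per noise value."""
    tokens = list(prefix)
    for u in noise:
        tokens.append(inverse_cdf_sample(toy_next_token_probs(tokens), u))
    return tokens[len(prefix):]


# The same noise vector always yields the same k tokens.
print(push_forward([0], [0.1, 0.25, 0.9, 0.5]))
```

Once the noise is fixed, the output is a deterministic function of the context and the noise vector, which is what makes the mapping something a single forward pass can be trained to reproduce.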

## Checkpoints

This repository contains four checkpoints:

| File | Model | Dataset | Parameters | Description |
|------|-------|---------|------------|-------------|
| `ar_openwebtxt.ckpt` | AR | OpenWebText | ~100M | Autoregressive teacher model (GPT-2 tokenizer, seq_len=1024) |
| `ar_best_lm1b.ckpt` | AR | LM1B | ~100M | Autoregressive teacher model (custom tokenizer, seq_len=128) |
| `pflm_owt_k4.ckpt` | PFLM (k=4) | OpenWebText | ~100M | Push-forward LM, decodes 4 tokens per forward pass |
| `pflm_lm1b_k4.ckpt` | PFLM (k=4) | LM1B | ~100M | Push-forward LM, decodes 4 tokens per forward pass |

All models share a 12-layer causal Transformer backbone (768 hidden dim, 12 heads), following the architecture from [MDLM](https://arxiv.org/abs/2406.07524) (Sahoo et al., 2024).

## Download

```python
from huggingface_hub import hf_hub_download

# Download a specific checkpoint
ckpt_path = hf_hub_download(
    repo_id="zwave/K-Forcing",
    filename="pflm_owt_k4.ckpt",  # or: ar_openwebtxt.ckpt, ar_best_lm1b.ckpt, pflm_lm1b_k4.ckpt
)
```

Or download all checkpoints at once:
```python
from huggingface_hub import snapshot_download

snapshot_download(repo_id="zwave/K-Forcing", local_dir="./checkpoints")
```

Or via CLI:
```bash
huggingface-cli download zwave/K-Forcing --local-dir ./checkpoints
```

## Usage

Clone the [K-Forcing repository](https://github.com/alibaba-damo-academy/K-Forcing) and follow the setup instructions there:

```bash
git clone https://github.com/alibaba-damo-academy/K-Forcing.git
cd K-Forcing

# Setup environment
mkdir -p wheels
wget -P wheels https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.6/flash_attn-2.5.6+cu122torch2.2cxx11abiFALSE-cp39-cp39-linux_x86_64.whl
uv sync
```

### AR Inference

```bash
python batch_inference_with_prefix.py \
    --model ar --task owt \
    --ckpt_path ./checkpoints/ar_openwebtxt.ckpt \
    --prefix_file assets/prefix_owt_examples.jsonl \
    --batch_size 4 --n_per_prefix 1
```

### PFLM Inference (K=2 tokens per forward pass)

```bash
python batch_inference_with_prefix.py \
    --model pflm --task owt \
    --ckpt_path ./checkpoints/pflm_owt_k4.ckpt \
    --prefix_file assets/prefix_owt_examples.jsonl \
    --batch_size 4 --n_per_prefix 1 --K 2 --freq_penalty 0.3
```

The PFLM checkpoint trained with k=4 supports inference with any K ≤ 4.
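
To get a feel for what the choice of K means, the following sketch counts decoding forward passes for a given number of new tokens. `decode_forward_passes` is a hypothetical helper, not part of the repository, and it assumes every pass emits exactly K tokens and ignores prefix processing. It counts passes only; it does not predict wall-clock speedup.

```python
import math


def decode_forward_passes(num_new_tokens, K, max_k=4):
    """Number of decoding passes if each pass emits K tokens (assumed)."""
    if not 1 <= K <= max_k:
        raise ValueError(f"K must be between 1 and {max_k}, got {K}")
    return math.ceil(num_new_tokens / K)


for K in (1, 2, 4):
    print(K, decode_forward_passes(128, K))
```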

## Architecture

- **Backbone**: 12-layer causal Transformer (~100M params), 768 hidden dim, 12 heads
- **Noise encoder**: sinusoidal + MLP, encodes each Uniform(0,1) noise variable into a token embedding
- **Fully causal design**: noise tokens attend causally — each zⱼ sees context + z₁..zⱼ
- **Shared prediction head**: same linear head as AR, applied at each noise-token position
- **Training**: progressive self-forcing distillation (AR → k=1 → k=2 → k=4)
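
The causal visibility pattern described in the list above can be sketched as a plain lower-triangular mask. This assumes the k noise tokens are appended after the context in a single sequence, which is an assumption about the layout rather than something confirmed here; `causal_mask` and `visible_to_noise` are hypothetical names.

```python
def causal_mask(context_len, k):
    """mask[row][col] is True when position `row` may attend to position `col`."""
    total = context_len + k
    return [[col <= row for col in range(total)] for row in range(total)]


def visible_to_noise(context_len, k, j):
    """Positions visible to the j-th noise token (1-indexed), under the assumed layout."""
    row = causal_mask(context_len, k)[context_len + j - 1]
    return [pos for pos, allowed in enumerate(row) if allowed]


# With 3 context tokens and k=2, z1 sees the context plus itself,
# and z2 additionally sees z1.
print(visible_to_noise(3, 2, 1))
print(visible_to_noise(3, 2, 2))
```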

## Citation

```bibtex
@misc{tang2026kforcingjointnextktokendecoding,
      title={K-Forcing: Joint Next-K-Token Decoding via Push-Forward Language Modeling},
      author={Zhiwei Tang and Yuanyu He and Yizheng Han and Wangbo Zhao and Jiasheng Tang and Fan Wang and Bohan Zhuang},
      year={2026},
      eprint={2606.10820},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2606.10820},
}
```

## License

This project is licensed under the MIT License.