Instructions to use anerjy/Step-3.7-Flash-MLX-adapter with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- MLX
How to use anerjy/Step-3.7-Flash-MLX-adapter with MLX:
# Make sure mlx-vlm is installed # pip install --upgrade mlx-vlm from mlx_vlm import load, generate from mlx_vlm.prompt_utils import apply_chat_template from mlx_vlm.utils import load_config # Load the model model, processor = load("anerjy/Step-3.7-Flash-MLX-adapter") config = load_config("anerjy/Step-3.7-Flash-MLX-adapter") # Prepare input image = ["http://images.cocodataset.org/val2017/000000039769.jpg"] prompt = "Describe this image." # Apply chat template formatted_prompt = apply_chat_template( processor, config, prompt, num_images=1 ) # Generate output output = generate(model, processor, formatted_prompt, image) print(output) - Notebooks
- Google Colab
- Kaggle
- Local Apps
- LM Studio
Step 3.7 Flash β MLX Adapter (text + vision)
First MLX-native runtime adapter for stepfun-ai/Step-3.7-Flash β Step Robotics' multimodal model based on the step3p5 MoE backbone (Qwen3.6-A3B family) with a 47-layer PerceptionEncoder vision tower.
Lets you run Step 3.7 Flash on Apple Silicon through mlx-vlm (multimodal, recommended) or mlx_lm (text-only, fallback).
This is the adapter code only β no model weights are shipped here. You download or quantize Step-3.7-Flash separately (instructions below).
What this includes
| Path | Purpose |
|---|---|
mlx_vlm/models/step3p7/ |
mlx-vlm adapter β 6 files, multimodal (text + vision) |
mlx_lm/models/step3p7.py |
mlx-lm adapter β text-only |
model_dir_patches/chat_template.jinja |
patched Jinja template (multi-crop image expansion + reasoning_effort=none) β drop into your model dir |
scripts/extract_vision_weights.py |
reproducer for the vision-shard pre-transpose step (needed for MLX NHWC convention) |
Architecture
Step 3.7 Flash = step3p5 backbone (45-layer Qwen3.6-A3B MoE, 12 full-attn + 33 sliding-window=512 layers) + 47-layer PerceptionEncoder ViT + 2 stride-2 Conv2d downsamplers + Linear(6144β4096) projector.
Vision pipeline:
- ImagePatcher (pure PIL/numpy port of upstream
ImagePatcher): square-pad extreme aspect ratios, resize somax(W, H) β€ 3024, then either single-view (long β€ 728) or sliding-window 504Β² sub-patches. - Vision encoder dual-pass: global 728Β² β 169 tokens (52Γ52 grid β 2Γ downsample β 13Β² = 169), each 504Β² sub-patch β 81 tokens (36Γ36 grid β 9Β² = 81). Same ViT weights, two grid sizes via interpolated positional embeddings.
- Placeholder layout per image (chat-template + processor):
[<patch_start>{81Γ<im_patch>}<patch_end> + opt <patch_newline>] Γ N + <im_start>{169Γ<im_patch>}<im_end>.encode_imagereturns features in matching order; merge viamx.cumsum+mx.where.
Bugs caught + fixed during the port
MLX
nn.Upsampleapplies 2-tuplescale_factorto dims 1+2 of the input β not the last 2 spatial dims like PyTorch'sF.interpolatedoes on NCHW. So 504Β²β36Γ36 positional embedding interpolation needs NHWC layout natively (no NCHW transpose). The naive port silently shrank the 1536 hidden dim instead of the spatial 52Β² dims, producing a tensor with 1,989,936 elements that couldn't reshape to(1296, 1536).mx.load + mx.save_safetensorsround-trip writes zeros for lazy values. When extracting vision weights from upstream bf16 shards, you mustmx.eval(v)each tensor before storing it in the output dict, or the safetensors file ends up shape-correct but full of zeros. The format=mlxmetadata is required formlx-vlmto load the file correctly.Vision feature cache must be ignored on multi-crop path. Engines that pre-cache vision features via
model.encode_image(pixel_values)only get the global 169 tokens; multi-crop prompts have169 + NΓ81placeholder positions, so the cached features no longer match. The naive zero-pad fallback then trashes every patch slot with zeros. Fix: detectpatch_pixel_valuesinkwargsand recompute everything when present.Vision weight provenance matters more than vision architecture. Step3-VL-10B and Step-3.7-Flash share the same
StepRoboticsVisionEncoderarchitecture, but their projector weights were trained against different text-side LMs (Qwen3-8B vs step3p5 MoE) β using Step3-VL-10B's vision weights with Step-3.7-Flash's LM gives mathematically correct vision features that the LM can't decode (model collapses to "white" / "a man in a suit" for every image). Always pullmodel-vit-00001/00002.safetensorsfrom the matching upstream model.
Installation
Drop the adapter into your mlx_vlm (and optionally mlx_lm) install:
# Discover the install paths from Python β works for pip / uv / conda /
# bundled frameworks alike.
MLX_VLM_DIR=$(python -c "import os, mlx_vlm; print(os.path.dirname(mlx_vlm.__file__))")
MLX_LM_DIR=$(python -c "import os, mlx_lm; print(os.path.dirname(mlx_lm.__file__))")
cp -r mlx_vlm/models/step3p7 "$MLX_VLM_DIR/models/"
cp mlx_lm/models/step3p7.py "$MLX_LM_DIR/models/"
You also need a 1-line patch to mlx_vlm/prompt_utils.py:
# In MESSAGE_FORMATS dict, add:
"step3p7": MessageFormat.LIST_WITH_IMAGE_FIRST,
Preparing the model weights
You will need quantized text shards + 1 vision shard with pre-transposed Conv2d kernels.
# 1. Get Step-3.7-Flash bf16 from upstream (very large)
hf download stepfun-ai/Step-3.7-Flash --local-dir Step-3.7-Flash
# 2. Quantize text shards to 4-bit (or use an existing community quant)
# e.g. https://huggingface.co/<some>/Step-3.7-Flash-4bit/
# 3. Extract vision weights with the NHWC Conv2d transpose applied
python scripts/extract_vision_weights.py \
--src Step-3.7-Flash \
--dst Step-3.7-Flash-4bit/model-00023-of-00023.safetensors
# 4. Patch chat_template.jinja into the model dir
cp model_dir_patches/chat_template.jinja Step-3.7-Flash-4bit/
# 5. Update model.safetensors.index.json to include the new vision shard keys.
# (Index update logic depends on your text quant β see `processing_step3.py` upstream for the canonical layout.)
Usage (via mlx-vlm)
from mlx_vlm import load, apply_chat_template, generate
from PIL import Image
model, processor = load("Step-3.7-Flash-4bit")
img = Image.open("photo.jpg")
prompt = apply_chat_template(processor, model.config, [
{"role": "user", "content": [
{"type": "image"},
{"type": "text", "text": "What is in this image?"},
]}
], num_images=1)
output = generate(model, processor, prompt, image=img, max_tokens=200)
print(output)
Or call any OpenAI-compatible MLX engine with model="Step-3.7-Flash-4bit" and image content blocks.
Tested on
- Mac Studio M3 Ultra 512 GB
mlx==0.30.x,mlx-vlm==0.5.0- Text decode: ~44 tok/s warm
- Vision: 4-corner color/shape recognition on 2000Γ1500 images, sun detection on 1920Γ1080 landscape β all correct
- Long-context: 182K-token needle-in-haystack retrieval β correct answer, ~11 min wall (Step 3.7 4-bit is slow-but-correct vs the same generation's Huihui 8-bit Opus-distill at ~4.5 min)
Multi-Token Prediction (MTP) speculative decoding
Step 3.7 ships 3 MTP heads (vLLM Step3p5MTP) trained against the upstream
BF16 backbone. We port the MTPModule + MTPLayer classes to MLX
(mlx_vlm/models/step3p7/language.py) and wire mtp_forward /
make_mtp_cache / return_hidden into both LanguageModel and the outer
Model so the oMLX MTP draft/verify cycle (qwen35 PR 990-derived
batch_generator patch) engages automatically when mtp_enabled: True is
set in the model settings.
Status: works. 73-80% accept rate, lossless at temp=0, ~15% net speedup.
| Config (Step-3.7-Flash-4bit, M3 Ultra, 300 tok warm, 5 iters) | Overall TPS | Accept |
|---|---|---|
MTP off (turboquant_kv: True) |
45.6 | n/a |
| MTP on (per-MTP-layer lm_head + BF16 embed + 7 norm shifts) | 52.6 | 73-80% |
Per-prompt: short_code 60, physics 57, Chinese translation 49 tok/s. Greedy output is byte-identical to MTP-off (verified on 4 prompts).
What it took to get here
The hard part was identifying which weight-loading conventions Hikari07jp's BF16 extraction expected. With the wrong conventions, accept rate stuck at 0-4% (slowdown vs baseline); with the right conventions, 75%.
- Per-MTP-layer
shared_head.outputis NOT tied to backbonelm_headβ biggest finding. Each of the 3 MTP layers in the upstream BF16 ships its ownshared_head.outputweight (different per-layer means 0.000008/0.000029/0.000034). We were initially calling the backbone's 4-bit quantizedlm_headfrommtp_forward, which produced ~0% useful draft tokens. Per-MTP-layer BF16 head jumped accept from 2-4% to 70-80%. Detected by inspectingmodel.layers.{45,46,47}.transformer.shared_head.outputin the BF16 source and noticing the three rows aren't equal. - MTP-side
embed_tokensis also separate from backbone β vLLMStep3p5AMultiTokenPredictorkeeps its ownVocabParallelEmbedding. We add annn.Embeddingdirectly onMTPModuleand route the BF16model.embed_tokens.weightfrom the MTP shard there. (This alone didn't move the needle, but it removes the mixed-precision drift frommtp_forward's input lookup.) - +1 norm shift on extract β vLLM
GemmaRMSNorm.forward_nativedoesweight = stored + 1.0; rms_norm(x, weight, eps). MLXZeroCenteredRMSNormis plainmx.fast.rms_norm(x, weight, eps)(no +1 baked in despite the class name). We pre-add 1 to the 7 Gemma norms in the MTP shard (enorm,hnorm,mtp_block.input_layernorm,mtp_block.post_attention_layernorm,mtp_block.self_attn.q_norm,mtp_block.self_attn.k_norm,shared_head_norm) at offline extract time. - Pre-norm hidden state passed to MTP β
LanguageModel.__call__returns the residual stream beforebody.normwhen called withreturn_hidden=True. vLLMStep3p5Model.forward()does the same. Empirically pre-norm gives 73-80% accept; post-norm gave 1-2%.
The mismatched-precision hypothesis (4-bit backbone vs BF16 MTP) wasn't the bottleneck β BF16 embed_tokens didn't move accept on its own. The hidden state precision wasn't either. The bottleneck was using the wrong lm_head.
How to use
The MTP path activates automatically when these settings are set on
Step-3.7-Flash-4bit (or your equivalent model id) in
your engine's model settings file:
{
"mtp_enabled": true,
"turboquant_kv_enabled": false
}
mtp_enabled and turboquant_kv_enabled are mutually exclusive per oMLX
(TurboQuant patches the attention path that MTP relies on). Bench results
above are with TurboQuant OFF. Restart oMLX after changing settings:
restart your engine process. Verify activation in the log:
[engine] Native MTP patch applied for ... (model_type=step3p7, active)
[batch_generator] MTP path activated for uid=N (model has mtp_forward, batch=1)
[batch_generator] MTP[N] finish=length tokens=200 cycles=117 accept=82/117 (70.1%) ...
To reproduce the shard rewrite from scratch:
hf download Hikari07jp/Step-3.7-Flash-MTP-draft --local-dir /tmp/mtp-src
python scripts/rewrite_mtp_shard.py \
/tmp/mtp-src/model.safetensors \
Step-3.7-Flash-4bit/model-00024-of-00024.safetensors
# Then update model.safetensors.index.json so the 52 new keys point
# at the new shard filename (49 layer + 3 shared_head_output entries).
Acknowledgements
- Hikari07jp/Step-3.7-Flash-MTP-draft
β extracted BF16 MTP-draft layers from upstream
stepfun-ai/Step-3.7-Flash; the same weights run at ~80% accept in vLLM. - oMLX β
patches/mlx_lm_mtp/ports mlx-lm PR 990's MTP draft+verify batch_generator dispatch to oMLX's continuous-batching scheduler. We addstep3p5/step3p7to its compatible list. - mlx-optiq/Gemma 4 spec dec post β the structurally identical 0%β3%β33% accept-rate fix narrative for Gemma 4 spec dec on MLX was a useful template for diagnosing this one.
To experiment, set mtp_enabled: True and turboquant_kv_enabled: False
(they're mutually exclusive per oMLX) in your engine's model settings file
for Step-3.7-Flash-4bit, restart oMLX, and check
grep "MTP\[" <your-engine-log> for accept rate. To run the
shard rewrite yourself:
hf download Hikari07jp/Step-3.7-Flash-MTP-draft --local-dir /tmp/mtp-src
python scripts/rewrite_mtp_shard.py \
/tmp/mtp-src/model.safetensors \
Step-3.7-Flash-4bit/model-00024-of-00024.safetensors
# Then update model.safetensors.index.json so the 48 new keys point
# at the new shard filename.
The MLX_LM_MTP patch from oMLX 0.3.10+ exposes its own
_is_mtp_compatible allowlist; ours patches that list to recognise
step3p5 and step3p7 model_types (see
patches/mlx_lm_mtp_compat.diff for the one-line addition).
License
Adapter code: Apache-2.0 (see LICENSE). Upstream Step-3.7-Flash weights are licensed separately by Step Robotics β see their model card before redistributing.
Credits
Architecture port + bug catches were done end-to-end across 7 sessions of MLX debugging on Apple Silicon, with PyTorch reference parity tests used to localize the four bugs above. References that helped (none of them validate end-to-end vision on Step-3.7-Flash specifically):
monyschuk/Huihui-Step3-VL-10B-abliterated-mlxβ MLX port for Step3-VL-10B (omits LayerScale + positional embedding, incompatible with stock weights)AITRADER/Huihui-Step3-VL-10B-abliterated-mlx-fp16β explicitly marks vision inference "out of scope"- Upstream
stepfun-ai/Step3-VL-10Bβ PyTorch reference for the vision encoder
- Downloads last month
- 140
Model tree for anerjy/Step-3.7-Flash-MLX-adapter
Base model
stepfun-ai/Step-3.7-Flash