gemma4-jax — Flax NNX port of google/gemma-4-12B (full multimodal)
A from-scratch, faithful Flax NNX implementation of Google's Gemma 4 12B
"Unified" model — the text decoder and the encoder-free vision + audio
embedders — plus a HuggingFace safetensors → NNX weight converter.
- 💻 Code / GitHub: https://github.com/mlnomadpy/gemma4-jax
- 🧬 Base model: google/gemma-4-12B
This repository hosts the port (code) only. It does not redistribute the
weights: at load time it reads the official google/gemma-4-12B safetensors
directly (the tensor names match; `nnx.Linear` kernels are simply the transposed
PyTorch weights). You must accept the Gemma terms on the base-model page to
download the weights.
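The transpose is the standard PyTorch-to-Flax layout change. A PyTorch `nn.Linear` stores `weight` as `[out_features, in_features]` and computes `x @ weight.T`. `nnx.Linear` stores `kernel` as `[in_features, out_features]` and computes `x @ kernel`. Below is a minimal NumPy sketch of why a plain transpose is enough (bias is omitted for brevity). `hf_linear_to_nnx_kernel` is a hypothetical helper and is not the function in `gemma4_jax.convert`:

```python
import numpy as np


def hf_linear_to_nnx_kernel(weight: np.ndarray) -> np.ndarray:
    """Map a PyTorch-style Linear weight [out, in] to an NNX-style kernel [in, out].

    Hypothetical helper for illustration; the repo's converter may differ.
    """
    if weight.ndim != 2:
        raise ValueError(f"expected a 2-D Linear weight, got shape {weight.shape}")
    return weight.T


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 5, 16)).astype(np.float32)  # [B, S, in]
w_hf = rng.standard_normal((8, 16)).astype(np.float32)  # PyTorch layout [out, in]

kernel = hf_linear_to_nnx_kernel(w_hf)                  # NNX layout [in, out]
y_torch_style = x @ w_hf.T                              # what torch.nn.Linear computes (no bias)
y_nnx_style = x @ kernel                                # what nnx.Linear computes (no bias)

print(kernel.shape)                                     # (16, 8)
print(np.allclose(y_torch_style, y_nnx_style))          # True
```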
Why
Gemma 4 12B is not (yet) in the official google-deepmind/gemma JAX library.
This port makes the architecture an open, editable NNX module, so the global
attention layers, the encoder-free projectors, and the RoPE/norm details become
clean swap points for research (e.g. linear-attention surgery on the 8 global
layers).
What's implemented — verified exact to the parameter
| Component | State |
|---|---|
| Text decoder (48 layers, dual sliding/global attention) | ✅ |
| Dual attention: sliding GQA (head dim 256, 8 KV heads) / global MQA (head dim 512, 1 KV head) | ✅ |
| `attention_k_eq_v` (V reuses the pre-norm K on full/global layers, no RoPE on V) | ✅ |
| Per-head QK-norm, sandwich norm, `layer_scalar`, embed scaling, logit softcap | ✅ |
| Proportional / partial RoPE (zeroed-tail inv_freq) + default RoPE | ✅ |
| Vision (encoder-free: raw 48×48×3 patches → LN→Dense→LN→+2D-posemb→norm→proj) | ✅ |
| Audio (encoder-free: raw 640-sample frames → RMSNorm→proj) | ✅ |
| Multimodal splice (soft-token scatter + bidirectional-vision mask) | ✅ |
| safetensors → NNX converter (text + vision + audio) | ✅ |
| KV cache for fast decode | ❌ reference loop recomputes prefix |
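The proportional (zeroed-tail) RoPE can be sketched in NumPy. Two assumptions are not read from the repo. The first is the split-half (`rotate_half`) pairing used by the HF Gemma implementations. The second is that the kept frequencies use the standard exponent `base**(-2i/head_dim)` over the full head dim before the tail is zeroed. The numbers follow the global layers described in this README: head dim 512, the first 64 of 256 frequencies live, base 1e6. The function names are hypothetical.

```python
import numpy as np


def proportional_inv_freq(head_dim, nonzero, base):
    """Full-length inv_freq (head_dim // 2 entries) whose tail is zeroed (NoPE).

    Assumption: the kept entries use the standard exponent over the full head_dim.
    """
    inv_freq = base ** (-np.arange(0, head_dim, 2, dtype=np.float64) / head_dim)
    inv_freq[nonzero:] = 0.0  # angle 0 -> cos 1, sin 0 -> these dims are not rotated
    return inv_freq


def apply_rope(x, positions, inv_freq):
    """Split-half (rotate_half) RoPE. x: [S, head_dim]; positions: [S]."""
    angles = positions[:, None] * inv_freq[None, :]      # [S, head_dim // 2]
    cos = np.concatenate([np.cos(angles)] * 2, axis=-1)  # [S, head_dim]
    sin = np.concatenate([np.sin(angles)] * 2, axis=-1)
    half = x.shape[-1] // 2
    rotated = np.concatenate([-x[:, half:], x[:, :half]], axis=-1)
    return x * cos + rotated * sin


head_dim, nonzero = 512, 64  # global layers: head dim 512, 64 live frequencies
inv_freq = proportional_inv_freq(head_dim, nonzero, base=1e6)
x = np.random.default_rng(0).standard_normal((10, head_dim))
y = apply_rope(x, np.arange(10, dtype=np.float64), inv_freq)

half = head_dim // 2
tail = np.r_[nonzero:half, half + nonzero:head_dim]  # dims paired with zero frequencies
print(int(np.count_nonzero(inv_freq)))      # 64
print(np.allclose(y[:, tail], x[:, tail]))  # True
```

Zeroing entries of `inv_freq` instead of slicing the head lets every layer share one code path. A zero frequency gives angle 0, so `cos = 1` and `sin = 0`, and those dimensions pass through untouched (the NoPE tail).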
Param count matches the published checkpoint exactly: text
11,907,350,320 + multimodal 52,379,904 = 11,959,730,224 (0 diff).
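The parameter check can be reproduced from tensor shapes alone. This is a minimal sketch with two helpers. `count_params` sums element counts over a name → shape mapping. `checkpoint_shapes` is hypothetical and assumes the `safetensors` `safe_open` / `get_slice(...).get_shape()` API; it collects shapes from one or more checkpoint files without loading the tensors. For a sharded checkpoint, pass every shard.

```python
import math
from typing import Dict, Iterable, Tuple

Shape = Tuple[int, ...]


def count_params(shapes: Dict[str, Shape]) -> int:
    """Total number of scalar parameters across named tensors."""
    return sum(math.prod(shape) for shape in shapes.values())


def checkpoint_shapes(paths: Iterable[str]) -> Dict[str, Shape]:
    """Read tensor shapes from one or more .safetensors files (hypothetical helper)."""
    from safetensors import safe_open  # imported lazily so the rest runs without it

    shapes: Dict[str, Shape] = {}
    for path in paths:
        with safe_open(path, framework="numpy") as f:
            for name in f.keys():
                shapes[name] = tuple(f.get_slice(name).get_shape())
    return shapes


# The README's totals: text decoder + multimodal embedders = full checkpoint.
text, multimodal = 11_907_350_320, 52_379_904
print(f"{text + multimodal:,}")  # 11,959,730,224

# Toy check of count_params on a made-up state dict.
toy = {"embed": (10, 4), "proj.kernel": (4, 3), "norm.scale": (3,)}
print(count_params(toy))  # 55
```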
The multimodal converter is verified against the real safetensors; text and
multimodal smoke tests pass (forward, causality drift 0, softcap bounds, splice).
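Two of the smoke-test properties are easy to state in code:

- **Softcap bounds.** Logit softcapping maps logits through `cap * tanh(logits / cap)`, so every output stays within `[-cap, cap]`.
- **Causality drift.** This is the largest change in outputs at positions `<= t` when the tokens after `t` are replaced. For a causal model it should be 0.

The NumPy sketch below uses a one-layer toy model in place of the real decoder. `cap = 30.0` and all function names are hypothetical, not values from the Gemma 4 config.

```python
import numpy as np


def softcap(logits: np.ndarray, cap: float) -> np.ndarray:
    """Smoothly squash logits into [-cap, cap]; near-identity for |logits| << cap."""
    return cap * np.tanh(logits / cap)


def toy_causal_model(emb: np.ndarray, cap: float = 30.0) -> np.ndarray:
    """Stand-in for a decoder: one causal self-attention step + softcapped readout.

    emb: [S, D] token embeddings. Returns [S, D] capped "logits".
    """
    s = emb.shape[0]
    scores = np.where(np.tri(s, dtype=bool), emb @ emb.T, -np.inf)  # causal mask
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return softcap(100.0 * (weights @ emb), cap)  # x100 so the cap is exercised


def causality_drift(model, emb: np.ndarray, t: int, rng: np.random.Generator) -> float:
    """Max |change| in outputs at positions <= t after replacing positions > t."""
    perturbed = emb.copy()
    perturbed[t + 1:] = rng.standard_normal(perturbed[t + 1:].shape)
    before = model(emb)[: t + 1]
    after = model(perturbed)[: t + 1]
    return float(np.abs(before - after).max())


rng = np.random.default_rng(0)
emb = rng.standard_normal((8, 16))
out = toy_causal_model(emb)
print(bool(np.all(np.abs(out) <= 30.0)))               # True
print(causality_drift(toy_causal_model, emb, 3, rng))  # 0.0
```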
Non-obvious details baked in (vs. Gemma 2/3)
- RMSNorm is plain `x·w` (not `(1+w)`); eps sits inside the rsqrt; internals run in fp32.
- Attention `scaling = 1.0`: the magnitude is set by the per-head `q_norm`, not by `1/√d`.
- `k_eq_v`: global layers have no `v_proj`; V reuses the K projection output (pre-norm, pre-RoPE) plus a scale-free `v_norm`, and V is not rotated.
- Proportional RoPE on global layers: a full-length `inv_freq` with only the first 64 frequencies nonzero (NoPE tail), base 1e6.
- Vision is encoder-free: no SigLIP; raw merged pixel patches project straight into the 3840-d decoder space. Audio likewise: no mel/conformer.
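The first three notes can be combined into one NumPy sketch of a global attention layer. It is an illustration under assumptions, not the repo's NNX module:

- The weight names (`wq`, `wk`, `wo`, `q_scale`, `k_scale`), `eps = 1e-6`, the toy sizes and the shared `[hd]` norm weights are hypothetical.
- RoPE is left out. It would be applied to `q` and `k` after their norms, never to `v`.
- Only a causal mask is used.

```python
import numpy as np


def rms_norm(x, scale=None, eps=1e-6):
    """RMSNorm with fp32 internals, eps inside the rsqrt, and a plain x * w scale.

    scale=None gives the scale-free variant used here for v_norm.
    """
    x32 = x.astype(np.float32)
    y = x32 / np.sqrt(np.mean(x32 * x32, axis=-1, keepdims=True) + eps)
    if scale is not None:
        y = y * np.asarray(scale, dtype=np.float32)  # plain w, not (1 + w)
    return y.astype(x.dtype)


def global_attention(x, wq, wk, wo, q_scale, k_scale):
    """One k_eq_v global layer with a single KV head (MQA) and scaling = 1.0.

    x: [S, D]; wq: [D, H, hd]; wk: [D, hd] (there is no wv); wo: [H * hd, D].
    RoPE would be applied to q and k after their norms; it is omitted here.
    """
    s = x.shape[0]
    q = rms_norm(np.einsum("sd,dhk->shk", x, wq), q_scale)     # per-head q_norm
    k_raw = x @ wk                                             # [S, hd], shared by all heads
    k = rms_norm(k_raw, k_scale)                               # k_norm (RoPE omitted)
    v = rms_norm(k_raw)                                        # V: pre-norm K + scale-free v_norm, not rotated
    scores = np.einsum("shk,tk->hst", q, k) * 1.0              # no 1/sqrt(hd) factor
    scores = np.where(np.tri(s, dtype=bool), scores, -np.inf)  # causal mask
    probs = np.exp(scores - scores.max(axis=-1, keepdims=True))
    probs /= probs.sum(axis=-1, keepdims=True)
    out = np.einsum("hst,tk->shk", probs, v)                   # [S, H, hd]
    return out.reshape(s, -1) @ wo                             # [S, D]


rng = np.random.default_rng(0)
S, D, H, hd = 6, 32, 4, 16
x = rng.standard_normal((S, D)).astype(np.float32)
wq = (rng.standard_normal((D, H, hd)) / np.sqrt(D)).astype(np.float32)
wk = (rng.standard_normal((D, hd)) / np.sqrt(D)).astype(np.float32)
wo = (rng.standard_normal((H * hd, D)) / np.sqrt(H * hd)).astype(np.float32)
ones = np.ones(hd, dtype=np.float32)

out = global_attention(x, wq, wk, wo, ones, ones)
print(out.shape)                                               # (6, 32)
print(np.allclose(rms_norm(x, np.zeros(D, np.float32)), 0.0))  # True: w = 0 zeroes the output
```

With unit norm weights, every normalized query and key has Euclidean norm at most `√hd`, so `|q·k| ≤ hd` by Cauchy-Schwarz. The learned `q_norm` weight, not a fixed `1/√d`, therefore controls how sharp the softmax gets.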
Usage
```bash
pip install jax flax safetensors huggingface_hub tokenizers
git clone https://github.com/mlnomadpy/gemma4-jax && cd gemma4-jax
```
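```python
import jax.numpy as jnp
from tokenizers import Tokenizer
from gemma4_jax.convert import unified_from_safetensors
from gemma4_jax.config import IMAGE_TOKEN_ID

# Point at the official google/gemma-4-12B model.safetensors (accept the terms first).
uni = unified_from_safetensors("path/to/model.safetensors")

# Text: tokenize with the checkpoint's tokenizer.json (placeholder path).
tok = Tokenizer.from_file("path/to/tokenizer.json")
input_ids = jnp.array([tok.encode("The capital of France is").ids])  # [B=1, S]
logits = uni.logits(input_ids)  # [B, S, vocab]

# Vision: pixel_values [B, P, 6912] and image_position_ids [B, P, 2] come from the
# HF image processor; input_ids must then contain IMAGE_TOKEN_ID placeholders
# where the image soft tokens are spliced in.
soft = uni.get_image_features(pixel_values, image_position_ids)  # [B, P, 3840]
h = uni(input_ids, pixel_values=pixel_values, image_position_ids=image_position_ids)

# Audio: input_features [B, T, 640] from the HF audio feature extractor.
h = uni(input_ids, input_features=input_features)
```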
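The multimodal splice behind `uni(input_ids, pixel_values=..., image_position_ids=...)` can be sketched for a single sequence in NumPy. The following assumptions are not taken from the repo:

- `IMAGE_ID` stands in for `IMAGE_TOKEN_ID`.
- There is exactly one placeholder per soft token, and soft tokens fill placeholders in order.
- "Bidirectional vision" means tokens of the same contiguous image run attend to each other in both directions, while all other pairs stay causal.

`splice_soft_tokens` and `vision_mask` are hypothetical names.

```python
import numpy as np

IMAGE_ID = 7  # hypothetical stand-in for gemma4_jax.config.IMAGE_TOKEN_ID


def splice_soft_tokens(embeds, input_ids, soft):
    """Scatter soft tokens into the rows of embeds whose id is IMAGE_ID.

    embeds: [S, D] text embeddings; input_ids: [S]; soft: [P, D] image features.
    """
    positions = np.flatnonzero(input_ids == IMAGE_ID)
    if positions.size != soft.shape[0]:
        raise ValueError(f"{positions.size} placeholders but {soft.shape[0]} soft tokens")
    out = embeds.copy()
    out[positions] = soft  # the i-th placeholder receives the i-th soft token
    return out


def vision_mask(input_ids):
    """Causal mask plus full attention among tokens of the same contiguous image run."""
    s = input_ids.shape[0]
    is_img = input_ids == IMAGE_ID
    run_start = is_img & ~np.concatenate([[False], is_img[:-1]])
    image_idx = np.where(is_img, np.cumsum(run_start), 0)  # 1, 2, ... per image; 0 = text
    same_image = is_img[:, None] & (image_idx[:, None] == image_idx[None, :])
    return np.tri(s, dtype=bool) | same_image  # [S, S], True = may attend


ids = np.array([2, 7, 7, 7, 5, 7, 7, 6])
D = 4
embeds = np.zeros((ids.size, D))
soft = np.arange(5 * D, dtype=float).reshape(5, D)
spliced = splice_soft_tokens(embeds, ids, soft)
mask = vision_mask(ids)
print(mask[1, 3], mask[4, 5], mask[3, 5])  # True False False
```

The real model works on `[B, S]` batches. A batched version would apply the same scatter and mask per row.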
The HF image processor (patchify + 3×3 pool + position ids) and the audio feature
extractor are not ported: feed pre-patchified `pixel_values` and pre-framed
`input_features` exactly as the HF processors emit them.
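Because the processors are not ported, a shape check before calling `uni(...)` catches layout mistakes early. The hypothetical `check_multimodal_inputs` below only encodes shapes stated in this README:

- `48 × 48 × 3 = 6912` values per merged patch.
- `[B, P, 2]` position ids.
- 640-sample audio frames.

Channel order, value ranges and normalization are defined by the HF processors and are not checked. `np.shape` also accepts `jax.Array` inputs.

```python
import numpy as np

PATCH = 48                     # merged patch side, per the vision description
PATCH_DIM = PATCH * PATCH * 3  # 6912 values per patch (RGB)
FRAME = 640                    # samples per audio frame


def check_multimodal_inputs(pixel_values=None, image_position_ids=None, input_features=None):
    """Raise ValueError if inputs do not match the shapes this README documents."""
    if pixel_values is not None:
        pv = np.shape(pixel_values)
        if len(pv) != 3 or pv[-1] != PATCH_DIM:
            raise ValueError(f"pixel_values must be [B, P, {PATCH_DIM}], got {pv}")
        if image_position_ids is None:
            raise ValueError("image_position_ids is required with pixel_values")
        pos = np.shape(image_position_ids)
        if pos != (pv[0], pv[1], 2):
            raise ValueError(f"image_position_ids must be {(pv[0], pv[1], 2)}, got {pos}")
    if input_features is not None:
        af = np.shape(input_features)
        if len(af) != 3 or af[-1] != FRAME:
            raise ValueError(f"input_features must be [B, T, {FRAME}], got {af}")


# Passes silently for well-shaped inputs.
check_multimodal_inputs(np.zeros((1, 4, PATCH_DIM)), np.zeros((1, 4, 2), dtype=np.int32))
check_multimodal_inputs(input_features=np.zeros((2, 10, FRAME)))
```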
License
This code is a clean-room reimplementation. The weights it loads are Google's Gemma 4, governed by the Gemma Terms of Use; your use of the weights is subject to those terms and the Gemma Prohibited Use Policy. See the base-model page.