gemma4-jax — Flax NNX port of google/gemma-4-12B (full multimodal)

A from-scratch, faithful Flax NNX implementation of Google's Gemma 4 12B "Unified" model — the text decoder and the encoder-free vision + audio embedders — plus a HuggingFace safetensors → NNX weight converter.

This repository hosts only the port's code. It does not redistribute the weights: at load time it reads the official google/gemma-4-12B safetensors directly (the tensor names match; nnx.Linear kernels are simply the checkpoint's linear weights, transposed). You must accept the Gemma terms on the base-model page before you can download the weights.
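
The transpose comes from a layout difference: PyTorch-style checkpoints store a linear weight as [out_features, in_features], while nnx.Linear keeps its kernel as [in_features, out_features]. A minimal sketch of that single step follows; the helper name hf_weight_to_nnx_kernel is hypothetical and is not taken from this repository's converter.

```python
import jax.numpy as jnp
import numpy as np


def hf_weight_to_nnx_kernel(weight):
    """Hypothetical helper: [out, in] checkpoint weight -> [in, out] kernel."""
    return jnp.asarray(weight).T


rng = np.random.default_rng(0)
w_hf = rng.standard_normal((4, 3)).astype(np.float32)  # [out=4, in=3]
x = rng.standard_normal((2, 3)).astype(np.float32)     # [batch=2, in=3]

kernel = hf_weight_to_nnx_kernel(w_hf)                  # [in=3, out=4]

# PyTorch convention: y = x @ W^T.  Flax convention: y = x @ kernel.
y_hf = x @ w_hf.T
y_nnx = jnp.asarray(x) @ kernel
```

Both products should agree up to float rounding, which is why no other change to the linear weights is needed.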

Why

Gemma 4 12B is not (yet) in the official google-deepmind/gemma JAX library. This port exposes the architecture as an open, editable NNX module, so the global attention layers, the encoder-free projectors, and the RoPE/norm details become clean swap points for research (e.g. linear-attention surgery on the 8 global layers).

What's implemented — verified exact to the parameter

Component State
Text decoder (48 layers, dual sliding/global attention)
Dual attention: sliding GQA (hd 256, 8 KV) / global MQA (hd 512, 1 KV)
attention_k_eq_v (V reuses pre-norm K on global (full-attention) layers, no RoPE on V)
Per-head QK-norm, sandwich norm, layer_scalar, embed scaling, logit softcap
Proportional / partial RoPE (zeroed-tail inv_freq) + default RoPE
Vision (encoder-free: raw 48×48×3 patches → LN→Dense→LN→+2D-posemb→norm→proj)
Audio (encoder-free: raw 640-sample frames → RMSNorm→proj)
Multimodal splice (soft-token scatter + bidirectional-vision mask)
safetensors → NNX converter (text + vision + audio)
KV cache for fast decode: ❌ not implemented (the reference loop recomputes the prefix)
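
The multimodal splice row can be illustrated with a small sketch. It assumes that each image placeholder in input_ids receives the next soft token in order, and that all image positions in a row may attend to each other on top of the causal mask. The function names and the placeholder id 9 are hypothetical; the repository's actual splice (and its handling of several images per row) may differ.

```python
import jax.numpy as jnp


def splice_soft_tokens(text_emb, input_ids, soft, image_token_id):
    """Overwrite placeholder-token embeddings with soft tokens, in order.

    text_emb: [B, S, D], input_ids: [B, S], soft: [B, P, D].
    """
    is_img = input_ids == image_token_id                          # [B, S]
    # The k-th placeholder in a row receives the k-th soft token of that row.
    idx = jnp.cumsum(is_img.astype(jnp.int32), axis=1) - 1
    idx = jnp.clip(idx, 0, soft.shape[1] - 1)
    gathered = jnp.take_along_axis(soft, idx[..., None], axis=1)  # [B, S, D]
    return jnp.where(is_img[..., None], gathered, text_emb)


def vision_bidirectional_mask(input_ids, image_token_id):
    """Causal mask, relaxed so image positions also see each other: [B, S, S]."""
    seq_len = input_ids.shape[1]
    causal = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
    is_img = input_ids == image_token_id
    both_img = is_img[:, :, None] & is_img[:, None, :]
    return causal[None] | both_img


IMG = 9  # hypothetical placeholder id, for illustration only
input_ids = jnp.array([[1, IMG, IMG, 2, 3]])
text_emb = jnp.zeros((1, 5, 2))
soft = jnp.array([[[10.0, 10.0], [20.0, 20.0]]])

spliced = splice_soft_tokens(text_emb, input_ids, soft, IMG)
mask = vision_bidirectional_mask(input_ids, IMG)
```

In this toy case the two placeholder positions pick up the two soft tokens, and position 1 is allowed to attend forward to position 2 because both are image tokens.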

Param count matches the published checkpoint exactly: text 11,907,350,320 + multimodal 52,379,904 = 11,959,730,224 (0 diff). The multimodal converter is verified against the real safetensors; text and multimodal smoke tests pass (forward, causality drift 0, softcap bounds, splice).
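
A parameter count of this kind can be reproduced by summing the sizes of all leaves in a parameter pytree. The sketch below uses a toy dictionary rather than the real model; for an NNX module the tree would typically come from something like nnx.state(model, nnx.Param), which is an assumption about how one would wire it up, not a description of this repository's test.

```python
import jax
import jax.numpy as jnp


def count_params(tree):
    """Sum the element counts of every array leaf in a pytree."""
    return sum(int(leaf.size) for leaf in jax.tree_util.tree_leaves(tree))


# Toy stand-in for a parameter tree (not the real model).
toy = {
    "dense": {"kernel": jnp.zeros((3, 4)), "bias": jnp.zeros((4,))},
    "scale": jnp.zeros((3,)),
}
toy_count = count_params(toy)  # 12 + 4 + 3

# The totals quoted above, checked as plain integer arithmetic.
text_params = 11_907_350_320
multimodal_params = 52_379_904
total_params = text_params + multimodal_params
```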

Non-obvious details baked in (vs. Gemma 2/3)

  • RMSNorm is plain x·w (not (1+w)); eps inside the rsqrt; fp32 internals.
  • Attention scaling = 1.0 — magnitude set by per-head q_norm, not 1/√d.
  • k_eq_v: global layers have no v_proj; V reuses the K projection output (pre-norm, pre-RoPE) + a scale-free v_norm, and V is not rotated.
  • Proportional RoPE on global layers: full-length inv_freq with only the first 64 frequencies nonzero (NoPE tail), base 1e6.
  • Vision is encoder-free: no SigLIP; raw merged pixel patches project straight into the 3840-d decoder space. Audio likewise — no mel/conformer.
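
Two of the points above, the plain-weight RMSNorm and the zeroed-tail inv_freq, can be sketched as follows. This is an illustration under stated assumptions, not the repository's code: the eps value of 1e-6 and the exponent formula 2i/head_dim are assumptions, while head_dim 512, 64 nonzero frequencies, and base 1e6 come from the list.

```python
import jax
import jax.numpy as jnp


def rms_norm(x, w, eps=1e-6):
    """RMSNorm with a plain x*w scale (not 1+w), eps inside the rsqrt.

    The eps default is an assumption. Internals run in fp32 and the
    result is cast back to the input dtype.
    """
    x32 = x.astype(jnp.float32)
    mean_sq = jnp.mean(jnp.square(x32), axis=-1, keepdims=True)
    y = x32 * jax.lax.rsqrt(mean_sq + eps)
    return (y * w.astype(jnp.float32)).astype(x.dtype)


def proportional_inv_freq(head_dim=512, n_rotary=64, base=1e6):
    """Full-length inv_freq whose tail is zeroed (those dims get no rotation).

    Assumes the usual RoPE exponent 2i/head_dim for the rotated part.
    """
    n = head_dim // 2
    exponents = jnp.arange(n, dtype=jnp.float32) * 2.0 / head_dim
    inv_freq = 1.0 / (base ** exponents)
    keep = jnp.arange(n) < n_rotary
    return jnp.where(keep, inv_freq, 0.0)


x = jnp.ones((2, 8), dtype=jnp.bfloat16)
w = jnp.full((8,), 2.0)
normed = rms_norm(x, w)

inv_freq = proportional_inv_freq()
```

A zero frequency yields a rotation angle of zero at every position, so the tail dimensions pass through RoPE unchanged, which is what makes them behave as NoPE.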

Usage

pip install jax flax safetensors huggingface_hub tokenizers
git clone https://github.com/mlnomadpy/gemma4-jax && cd gemma4-jax

import jax.numpy as jnp
from gemma4_jax.convert import unified_from_safetensors
from gemma4_jax.config import IMAGE_TOKEN_ID

# point at the official google/gemma-4-12B model.safetensors (accept terms first)
uni = unified_from_safetensors("path/to/model.safetensors")

# text
logits = uni.logits(input_ids)                       # [B, S, vocab]

# vision: pixel_values [B,P,6912], image_position_ids [B,P,2]
soft = uni.get_image_features(pixel_values, image_position_ids)   # [B,P,3840]
h = uni(input_ids, pixel_values=pixel_values, image_position_ids=image_position_ids)

# audio: input_features [B,T,640]
h = uni(input_ids, input_features=input_features)

The HF image processor (patchify + 3×3 pool + position ids) and audio feature extractor are not ported — feed pre-patchified pixel_values / pre-framed input_features exactly as the HF processors emit them.
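
Because the processors are not ported, a cheap shape check before calling the model can catch mismatched inputs early. The helper below is a hypothetical sketch, not part of gemma4_jax; it only encodes the shapes stated in the usage comments (6912 = 48×48×3 values per patch, 2 position ids per patch, 640 samples per audio frame).

```python
import jax.numpy as jnp

PATCH_DIM = 48 * 48 * 3  # 6912 values per merged patch
FRAME_LEN = 640          # raw samples per audio frame


def check_multimodal_inputs(pixel_values=None, image_position_ids=None,
                            input_features=None):
    """Hypothetical helper: raise ValueError on unexpected input shapes."""
    if pixel_values is not None:
        if pixel_values.ndim != 3 or pixel_values.shape[-1] != PATCH_DIM:
            raise ValueError(f"pixel_values must be [B, P, {PATCH_DIM}]")
        expected = (*pixel_values.shape[:2], 2)
        if image_position_ids is None or image_position_ids.shape != expected:
            raise ValueError(f"image_position_ids must be {expected}")
    if input_features is not None:
        if input_features.ndim != 3 or input_features.shape[-1] != FRAME_LEN:
            raise ValueError(f"input_features must be [B, T, {FRAME_LEN}]")
    return True


# Placeholder arrays with the documented shapes (contents are dummies).
pixel_values = jnp.zeros((1, 4, PATCH_DIM))
image_position_ids = jnp.zeros((1, 4, 2), dtype=jnp.int32)
input_features = jnp.zeros((1, 10, FRAME_LEN))

ok = check_multimodal_inputs(pixel_values, image_position_ids, input_features)
```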

License

This code is a clean-room reimplementation. The weights it loads are Google's Gemma 4, governed by the Gemma Terms of Use; your use of the weights is subject to those terms and the Gemma Prohibited Use Policy. See the base model.
