TinyStories-10M-JAX

A ~14.5M-parameter (≈6.3M non-embedding) decoder-only transformer trained from scratch in JAX / Flax NNX on TinyStories, reproducing the setup of Eldan & Li (2023, arXiv:2305.07759).

Results — held-out TinyStories validation (4.64M tokens)

metric                   value
val loss (nats/token)    1.680
perplexity               5.364
bits/token               2.423
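
The three metrics are the same quantity in different units: perplexity is the exponential of the loss in nats, and bits/token is the loss divided by ln 2. The sketch below recomputes both from the rounded loss in the table; the variable names are illustrative only. Small differences in the last digit are expected because the reported loss is rounded to three decimals.

```python
import math

# Reported validation loss from the table, in nats per token.
val_loss_nats = 1.680

# Perplexity is exp(cross-entropy in nats).
perplexity = math.exp(val_loss_nats)

# Converting nats to bits divides by ln(2).
bits_per_token = val_loss_nats / math.log(2)

print(f"perplexity ~ {perplexity:.3f}")
print(f"bits/token ~ {bits_per_token:.3f}")
```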

Architecture

Llama/Mistral-style building blocks at a reduced scale:

  • d_model 256 · 6 layers · 8 heads · d_ff 1024 · context 512
  • RoPE · RMSNorm · SwiGLU FFN · tied input/output embeddings
  • vocab 32,000 (byte-level BPE trained on TinyStories)
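
The headline parameter counts can be roughly cross-checked from these hyperparameters. The sketch below is an estimate, not the repo's own accounting. It assumes bias-free linear layers, full multi-head attention with separate Q, K, V and output projections of size d_model × d_model, a three-matrix SwiGLU FFN, two RMSNorm weight vectors per layer plus one final norm, and a single shared (tied) embedding matrix. None of these details beyond the listed hyperparameters are confirmed by the card.

```python
# Hyperparameters taken from the architecture list above.
VOCAB = 32_000
D_MODEL = 256
N_LAYERS = 6
D_FF = 1024

# Tied input/output embeddings: one matrix, counted once.
embedding = VOCAB * D_MODEL

# Assumption: Q, K, V, O projections, each d_model x d_model, no biases.
attention = 4 * D_MODEL * D_MODEL

# Assumption: SwiGLU uses gate, up, and down projections, no biases.
ffn = 3 * D_MODEL * D_FF

# Assumption: two RMSNorm scale vectors per layer (pre-attention, pre-FFN).
norms_per_layer = 2 * D_MODEL

per_layer = attention + ffn + norms_per_layer

# One extra RMSNorm before the output projection (assumed).
non_embedding = N_LAYERS * per_layer + D_MODEL
total = embedding + non_embedding

print(f"embedding:     {embedding / 1e6:.2f}M")
print(f"non-embedding: {non_embedding / 1e6:.2f}M")
print(f"total:         {total / 1e6:.2f}M")
```

Under these assumptions the estimate lands at about 6.3M non-embedding and 14.5M total parameters, which agrees with the figures quoted at the top of the card.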

Training

  • AdamW (β₁ = 0.9, β₂ = 0.95, weight decay 0.1), gradient clipping at 1.0
  • 1k-step warmup → cosine decay, peak LR 6e-4
  • 20,000 steps · batch 32 · context 512 · single Colab T4
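
The learning-rate schedule and token budget implied by these settings can be sketched in plain Python. This is an illustration rather than the training code. It assumes the warmup is linear from zero, the cosine decays to a final learning rate of zero (the card states neither), and "batch 32" means 32 sequences of 512 tokens per step. `lr_at` and the constant names are hypothetical.

```python
import math

PEAK_LR = 6e-4
WARMUP_STEPS = 1_000
TOTAL_STEPS = 20_000
MIN_LR = 0.0  # assumption: the final learning rate is not stated in the card


def lr_at(step: int) -> float:
    """Linear warmup to PEAK_LR, then cosine decay to MIN_LR."""
    if step < WARMUP_STEPS:
        # Assumption: warmup starts at zero and rises linearly.
        return PEAK_LR * step / WARMUP_STEPS
    progress = (step - WARMUP_STEPS) / (TOTAL_STEPS - WARMUP_STEPS)
    progress = min(progress, 1.0)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return MIN_LR + (PEAK_LR - MIN_LR) * cosine


# Assumption: each step processes 32 sequences of 512 tokens.
TOKENS_PER_STEP = 32 * 512
TOTAL_TOKENS = TOTAL_STEPS * TOKENS_PER_STEP

for step in (0, 500, 1_000, 10_500, 20_000):
    print(step, f"{lr_at(step):.2e}")
print(f"tokens seen: {TOTAL_TOKENS / 1e6:.1f}M")
```

With these assumptions the run processes roughly 328M tokens in total.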

Usage

Weights are in model.safetensors. To load them, rebuild the model with the model code from the GitHub repo and call load_safetensors() (see sample.py). The tokenizer is in tokenizer.json.
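
Before wiring the weights into the model code, it can help to check what the checkpoint contains. The sketch below does not use the repo's load_safetensors() helper, whose signature is not documented here. Instead it reads the file with `safetensors.numpy.load_file` from the `safetensors` package, which must be installed separately. `summarize_checkpoint` and `total_parameters` are hypothetical helper names.

```python
import math


def summarize_checkpoint(path: str = "model.safetensors") -> dict:
    """Map each tensor name to its (shape, dtype) without building the model."""
    # Imported lazily so the helper below works without the package installed.
    from safetensors.numpy import load_file

    tensors = load_file(path)
    return {name: (tuple(arr.shape), str(arr.dtype)) for name, arr in tensors.items()}


def total_parameters(summary: dict) -> int:
    """Sum the element counts of all tensors in a summary dict."""
    return sum(math.prod(shape) for shape, _dtype in summary.values())


if __name__ == "__main__":
    import os

    if os.path.exists("model.safetensors"):
        summary = summarize_checkpoint()
        for name, (shape, dtype) in sorted(summary.items()):
            print(f"{name:50s} {shape} {dtype}")
        print("total parameters:", total_parameters(summary))
```

If the total matches the parameter count quoted in this card, the file was downloaded intact and can be passed to the repo's loader.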

Limitations

Trained only on synthetic children's stories. The model produces coherent short English narratives but has weak long-range consistency and no factual or world knowledge. Not intended for general use.

Safetensors: model size 14.5M params, tensor type F32