Keiro

Domain-Adaptive Sparse Mixture-of-Experts Injection for Qwen2.5-3B

A Top-2 dynamic router activates 2 of 8 LoRA experts per transformer block, expanding effective capacity while keeping active compute identical to the dense baseline


This is an experimental proof-of-concept release aimed at validating the MoE injection and routing mechanism. A future release will train on a significantly larger and more diverse dataset using a larger base model.

Overview

Keiro retrofits a Sparse Mixture-of-Experts architecture into Qwen2.5-3B. Every transformer block's MLP is replaced by a SparseMoELayer: at each token, a LinearRouter scores all 8 experts and selects the top-2 by softmax probability. Expert outputs are LoRA residuals applied on top of the frozen base Qwen2MLP, preserving dense representations while allowing experts to learn specialised corrections.

Switch loss combined with additive noise encourages uniform expert utilisation and prevents router collapse.
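
The routing step is easiest to see in code. The following is a minimal sketch of a top-2 noisy router with a switch-style load-balancing loss, written against the configuration above; names such as router_weight are illustrative, and the entropy regularisation term (weight 0.01) is omitted for brevity.

import torch
import torch.nn.functional as F

def route_top2(x, router_weight, noise_std=0.1, num_experts=8, top_k=2):
    # x: (tokens, d_model); router_weight: (num_experts, d_model)
    logits = x @ router_weight.T                      # (tokens, num_experts)
    if noise_std > 0:
        logits = logits + noise_std * torch.randn_like(logits)
    probs = F.softmax(logits, dim=-1)

    top_p, top_idx = probs.topk(top_k, dim=-1)        # (tokens, top_k)
    gates = top_p / top_p.sum(dim=-1, keepdim=True)   # renormalise over the winners

    # Switch-style auxiliary loss: fraction of tokens whose top-1 choice is each
    # expert, times the mean router probability for that expert. Uniform usage
    # minimises this product, discouraging router collapse.
    dispatch = F.one_hot(top_idx[:, 0], num_experts).float().mean(dim=0)
    importance = probs.mean(dim=0)
    aux_loss = num_experts * torch.sum(dispatch * importance)

    return top_idx, gates, aux_loss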

Model Details

Property Value
Base model Qwen/Qwen2.5-3B
Architecture Sparse Mixture-of-Experts (LoRA residual)
Precision BFloat16
d_model 2048
Layers injected 36 transformer blocks (all)
Experts per layer 8 total (top-2 active per token)
Expert type BatchedLoRAExperts
Expert capacity 1.25
LoRA rank 16
LoRA scaling 2.0
Expert dropout 0.05
Router LinearRouter with switch loss
Router noise 0.1
Router entropy regularisation 0.01
Trainable parameters ~19.46 M (0.63% of total)
Chat template ChatML
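
The trainable-parameter figure is consistent with the table above. Assuming the only trainable tensors are the stacked LoRA expert weights and a bias-free linear router in each of the 36 injected blocks, a quick check reproduces ~19.46 M:

num_layers, num_experts, d_model, rank = 36, 8, 2048, 16

lora_per_layer   = num_experts * 2 * rank * d_model   # stacked lora_A + lora_B
router_per_layer = num_experts * d_model              # LinearRouter weight
trainable = num_layers * (lora_per_layer + router_per_layer)

print(f"{trainable / 1e6:.2f} M")                     # 19.46 M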

Evaluation

Benchmarked against the dense Qwen2.5-3B baseline using the EleutherAI lm-evaluation-harness.

Benchmark Report

Benchmark Dense Baseline Keiro (MoE) Δ
HellaSwag (10-shot) 74.60% 74.47% -0.13%
ARC-Challenge (25-shot) 56.57% 56.40% -0.17%
GSM8K (5-shot, exact match) 69.83% 66.64% -3.19%

Key takeaways

  • The -0.13% and -0.17% deltas on HellaSwag and ARC-Challenge indicate that the router did not trigger catastrophic forgetting on linguistic or factual reasoning pathways.
  • The 3.19% drop on GSM8K (retaining 95.4% of math reasoning capability) is the primary validation of the architecture: a dynamic routing mechanism can be retrofitted into a dense transformer without breaking multi-step autoregressive reasoning chains.

Summary metric Value
Trainable parameters 19.46 M (0.63% of total)
Avg. knowledge delta (HellaSwag + ARC) -0.15%
Math reasoning retained (GSM8K) 95.4%
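
The summary figures are straightforward arithmetic over the benchmark table above:

hellaswag_delta, arc_delta = -0.13, -0.17
gsm8k_dense, gsm8k_moe     = 69.83, 66.64

avg_knowledge_delta = (hellaswag_delta + arc_delta) / 2   # -0.15%
math_retained       = 100 * gsm8k_moe / gsm8k_dense       # ~95.4%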

Architecture Deep Dive

Dual-Dispatch System

The SparseMoELayer uses two distinct forward paths optimised for different regimes:

Path When Strategy
_forward_padded_loop Inside torch.compile Compile-friendly padded loop over experts with capacity buffers. Compatible with torch.compile static graph tracing.
_forward_dynamic Eager mode (default) Vectorised scatter/gather using torch.bmm on only the active expert subset. Bypasses capacity buffers entirely.

The runtime dispatch is automatic: _batched_sparse_forward checks torch.compiler.is_compiling() and routes accordingly. During normal model.generate() (eager mode), the fast dynamic path is always used.
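
The dispatch decision itself is a single runtime check. A minimal sketch of the pattern (the argument list is illustrative, not the exact signature in layers.py):

import torch

def _batched_sparse_forward(self, x, top_idx, gates):
    # Under torch.compile, use the padded, capacity-buffered loop so the traced
    # graph keeps static shapes; otherwise take the fast dynamic path.
    if torch.compiler.is_compiling():
        return self._forward_padded_loop(x, top_idx, gates)
    return self._forward_dynamic(x, top_idx, gates)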

Single-Token Fast Path

The dynamic path was engineered specifically to solve the kernel launch bottleneck observed during autoregressive generation:

  • Problem: The original vectorised path built a (num_experts, capacity, d_model) buffer and ran torch.bmm across all 8 experts. For a single token with top-2 routing, 6 of 8 expert slots were zero-filled, wasting 75% of compute on no-op matrix multiplications.
  • Solution: Use torch.unique to identify only the 2 active experts, gather their specific lora_A and lora_B weight slices, and run a targeted torch.bmm on a (2, 1, d_model) tensor (see the sketch after this list).
  • Result: CUDA kernel launches per layer dropped from ~12 to ~3. GPU utilisation during generation improved from 34% to 72%+.
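
A condensed sketch of that fast path, under the shapes described above (function and argument names are illustrative, not the repo's exact code):

import torch

def lora_residual_active_only(x, top_idx, gates, lora_A, lora_B, scaling):
    # x: (1, d_model) single decoded token; top_idx / gates: (1, top_k)
    # lora_A: (num_experts, rank, d_model); lora_B: (num_experts, d_model, rank)
    active = torch.unique(top_idx)                     # e.g. 2 of the 8 expert ids
    A = lora_A[active]                                 # (n_active, rank, d_model)
    B = lora_B[active]                                 # (n_active, d_model, rank)

    xs = x.unsqueeze(0).expand(active.numel(), 1, -1)  # (n_active, 1, d_model)
    h = torch.bmm(xs, A.transpose(1, 2))               # (n_active, 1, rank)
    out = torch.bmm(h, B.transpose(1, 2))              # (n_active, 1, d_model)

    # Weight each active expert's residual by its gate and sum.
    residual = torch.zeros_like(x)
    for k in range(top_idx.shape[1]):
        pos = (active == top_idx[0, k]).nonzero(as_tuple=True)[0]
        residual = residual + gates[0, k] * out[pos, 0]
    return scaling * residual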

Expert Architecture

BatchedLoRAExperts stores all expert weights in stacked tensors lora_A: (num_experts, rank, d_model) and lora_B: (num_experts, d_model, rank), enabling batched matrix multiplications instead of sequential expert loops. The LoRA residual is added on top of the frozen Qwen2MLP output:

output = frozen_FFN(x) + scaling * (x @ lora_A.T @ lora_B.T)

where scaling = lora_alpha / lora_rank (32 / 16 = 2.0).
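
A compressed sketch of what such a stacked-weight expert bank looks like; the released experts.py additionally applies expert dropout (0.05) and capacity handling, which are omitted here.

import torch
import torch.nn as nn

class StackedLoRAExperts(nn.Module):
    # Illustrative stand-in for BatchedLoRAExperts: all expert weights live in
    # two stacked tensors so every expert can be applied with one bmm call.
    def __init__(self, num_experts=8, d_model=2048, rank=16, lora_alpha=32):
        super().__init__()
        self.lora_A = nn.Parameter(torch.randn(num_experts, rank, d_model) * 0.02)
        self.lora_B = nn.Parameter(torch.zeros(num_experts, d_model, rank))  # zero-init: residual starts at 0
        self.scaling = lora_alpha / rank                                     # 32 / 16 = 2.0

    def forward(self, x):                                   # x: (num_experts, capacity, d_model)
        h = torch.bmm(x, self.lora_A.transpose(1, 2))       # (num_experts, capacity, rank)
        return self.scaling * torch.bmm(h, self.lora_B.transpose(1, 2))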

Usage

Keiro uses a custom routing architecture. Architecture source files must be downloaded alongside the weights before the model can be loaded. Do not attempt to load this model using a standard AutoModelForCausalLM.from_pretrained() call directly against the repo.

Installation

pip install huggingface_hub transformers torch accelerate

Step 1: Download weights and architecture

import os, sys, torch
from huggingface_hub import hf_hub_download

REPO_ID   = "iamrahulreddy/Keiro"
CACHE_DIR = "./keiro_cache"
device    = "cuda" if torch.cuda.is_available() else "cpu"

arch_files = [
    "architecture/sparse_moe/__init__.py",
    "architecture/sparse_moe/analysis.py",
    "architecture/sparse_moe/baselines.py",
    "architecture/sparse_moe/config.py",
    "architecture/sparse_moe/datasets.py",
    "architecture/sparse_moe/evaluation.py",
    "architecture/sparse_moe/experts.py",
    "architecture/sparse_moe/injection.py",
    "architecture/sparse_moe/layers.py",
    "architecture/sparse_moe/prompts.py",
    "architecture/sparse_moe/reporting.py",
    "architecture/sparse_moe/routing.py",
    "architecture/sparse_moe/stage_runner.py",
    "architecture/sparse_moe/sys_profiler.py",
    "architecture/sparse_moe/trainer.py",
    "architecture/sparse_moe/utils.py",
    "architecture/sparse_moe/visualization.py",
]

for filepath in arch_files:
    hf_hub_download(repo_id=REPO_ID, filename=filepath, local_dir=CACHE_DIR)

weights_path = hf_hub_download(repo_id=REPO_ID, filename="moe-best.pt", local_dir=CACHE_DIR)

Step 2: Register architecture on sys.path

arch_root = os.path.abspath(os.path.join(CACHE_DIR, "architecture"))

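# Keep only plain string entries so the membership check below is reliable,
# then put the downloaded architecture package at the front of sys.path.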
sys.path = [p for p in sys.path if isinstance(p, str)]
if arch_root not in sys.path:
    sys.path.insert(0, arch_root)

from sparse_moe import ProjectConfig, RouterConfig, ExpertConfig, inject_moe_layers

Step 3: Load base model and tokenizer

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer  = AutoTokenizer.from_pretrained(REPO_ID)
base_model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-3B",
    dtype=torch.bfloat16,
    device_map=device,
)

Step 4: Inject MoE layers and load Keiro weights

moe_cfg = ProjectConfig(
    routing=RouterConfig(num_experts=8, top_k=2),
    expert=ExpertConfig(lora_rank=16),
)

moe_model = inject_moe_layers(
    base_model,
    expert_config=moe_cfg.expert,
    router_config=moe_cfg.routing,
)

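# strict=False: keys absent from the checkpoint (the frozen base weights already
# loaded in Step 3) are simply left as they are.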
state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)
moe_model.load_state_dict(state_dict, strict=False)
moe_model.to(device).eval()

Step 5: Run inference

def generate_response(prompt: str, max_new_tokens: int = 256) -> str:
    messages   = [{"role": "user", "content": prompt}]
    input_text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    inputs = tokenizer(input_text, return_tensors="pt").to(device)

    terminators = [
        tokenizer.eos_token_id,
        tokenizer.convert_tokens_to_ids("<|im_end|>"),
    ]

    with torch.no_grad():
        output_ids = moe_model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            repetition_penalty=1.1,
            pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
            eos_token_id=terminators,
        )

    new_tokens = output_ids[0][inputs.input_ids.shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()


print(generate_response("Explain the core benefits of a Mixture-of-Experts architecture."))

Training Data

Keiro was trained on a small mixed dataset of approximately 500 samples spanning multiple domains, assembled from three sources.

Source Content
WikiText Sampled factual and encyclopedic prose
Alpaca-LoRA Instruction-following and reasoning pairs
Synthetic data Custom-generated domain-diverse examples

Hardware Requirements

Configuration Minimum VRAM
BF16 inference (recommended) 8 GB
FP32 inference 16 GB
Fine-tuning (batch=16, gradient checkpointing) 24 GB
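
These figures are consistent with simple weight-memory arithmetic (illustrative only; real usage also depends on batch size, sequence length, and KV-cache growth):

base_params = 3.09e9      # approximate Qwen2.5-3B parameter count
moe_params  = 19.46e6     # injected expert + router parameters

bf16_gb = (base_params + moe_params) * 2 / 1024**3   # ~5.8 GB of weights
fp32_gb = (base_params + moe_params) * 4 / 1024**3   # ~11.6 GB of weights
# The headroom up to 8 GB / 16 GB covers activations and the KV cache.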

Limitations & Known Issues

  • Repetition collapse under greedy decoding: Without repetition_penalty, the MoE model occasionally enters repetition loops on open-ended generation tasks. Use repetition_penalty=1.1 to mitigate this; expect a slight drop in math, reasoning, and logic performance when the penalty is applied.
  • Expert load imbalance in early layers: Layers 0–5 exhibit mild expert collapse where 2–3 experts handle the majority of tokens. This is a known pathology: early transformer layers have less differentiated hidden states for the router to distinguish. Increasing aux_loss_weight from 0.05 to 0.10 and training on larger datasets reduces this effect.
  • Inference overhead vs dense: Sparse routing adds per-token overhead from router scoring and expert dispatch. The dual-dispatch system minimises this, but MoE inference will always be moderately slower than a pure dense forward pass of equivalent active parameters.
  • Small training dataset: This release was trained on only ~500 samples per domain as a proof of concept. A future release with 10k+ samples and multi-epoch training is planned and should yield stronger domain specialisation.

Repository Structure

iamrahulreddy/Keiro/
├── moe-best.pt                    # Fine-tuned MoE weights
├── tokenizer_config.json          # Qwen2.5-3B tokenizer with ChatML template
├── benchmark_report.png
└── architecture/
    └── sparse_moe/
        ├── __init__.py            # Public API: ProjectConfig, inject_moe_layers
        ├── config.py              # ProjectConfig, RouterConfig, ExpertConfig
        ├── experts.py             # BatchedLoRAExperts (stacked weight tensors)
        ├── injection.py           # inject_moe_layers() entry point
        ├── layers.py              # SparseMoELayer with dual-dispatch forward
        ├── routing.py             # LinearRouter with switch loss + noise
        ├── trainer.py             # MoETrainer with early stopping
        ├── evaluation.py          # Per-domain perplexity & MRR metrics
        ├── visualization.py       # Dashboard plots & expert heatmaps
        ├── analysis.py            # Expert specialisation (JS divergence)
        ├── baselines.py           # Dense LoRA baseline for fair comparison
        └── ...

Technical References

Credits

  • PyTorch Team: For the world-class documentation and foundational deep learning primitives used to build the custom MoE layers.
  • Qwen Team (Alibaba Cloud): For open-sourcing Qwen2.5-3B, providing an exceptional dense foundation for sparse injection experiments.
  • Hugging Face: For the transformers, datasets, and hub ecosystems that streamline model orchestration.
  • EleutherAI: For the lm-evaluation-harness used for few-shot benchmarking.

Contributing

If you encounter any inconsistencies, technical errors, or issues, please feel free to open a Pull Request or an Issue. Feedback and improvements are welcome!

Citation

@misc{keiro2026,
  author       = {Muskula Rahul},
  title        = {Keiro: Custom Sparse Mixture-of-Experts Injection into Qwen2.5-3B},
  year         = {2026},
  publisher    = {Hugging Face},
  howpublished = {\url{https://huggingface.co/iamrahulreddy/Keiro}}
}