
Self-Healing Training System (SHTS)

Fully autonomous debugging and error recovery for Hugging Face TRL trainers. Add one callback, wrap with SelfHealingTrainer, and routine failures (NaN loss, CUDA OOM, divergence) are detected, diagnosed, and recovered from without human intervention.

License: MIT


The Problem

ML training fails constantly:

  • CUDA OOM kills jobs at step 847/1000 - restart from scratch
  • NaN loss silently corrupts models - discovered hours later
  • Loss spikes cascade into divergence - manual intervention required
  • DPO plateau at 0.693 loss (= random chance) - wasted GPU hours
  • No postmortem - "what step did it die on?"

Each failure costs developer time + GPU credits + schedule delay. At scale, this is millions in wasted compute.

The Solution

SHTS wraps any Hugging Face TRL trainer with four autonomous layers:

┌──────────────────────────────────────────┐
│  LAYER 4: ORCHESTRATION                  │
│  SelfHealingTrainer retry loop           │
│  while not converged: try → recover      │
├──────────────────────────────────────────┤
│  LAYER 3: RECOVERY                       │
│  HealingActions: rollback, halve LR,     │
│  halve batch, reclip, clear cache        │
├──────────────────────────────────────────┤
│  LAYER 2: DIAGNOSIS                      │
│  Root-cause classifier: NaN/divergence/  │
│  OOM/data/API - with literature refs     │
├──────────────────────────────────────────┤
│  LAYER 1: DETECTION                      │
│  SelfHealingCallback: loss, gradients,   │
│  memory, ZClip adaptive clipping         │
└──────────────────────────────────────────┘
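
Layer 1 plugs into the standard transformers callback API. As a rough illustration of the detection pattern only (not the SHTS source; the class name and thresholds below are made up for the example), a minimal NaN-and-spike detector looks like this:

import math
from collections import deque

from transformers import TrainerCallback


class NaNAndSpikeDetector(TrainerCallback):
    """Illustrative Layer-1 detector: flags NaN loss and loss spikes.

    A simplified sketch of the detection pattern, not the SHTS
    implementation. SelfHealingCallback additionally tracks
    gradients, memory, and ZClip statistics.
    """

    def __init__(self, spike_factor=5.0, window=50):
        self.spike_factor = spike_factor
        self.losses = deque(maxlen=window)  # running window of recent losses

    def on_log(self, args, state, control, logs=None, **kwargs):
        loss = (logs or {}).get("loss")
        if loss is None:
            return
        if math.isnan(loss):
            # Hand off to the recovery layers by stopping training
            control.should_training_stop = True
            return
        if self.losses:
            mean = sum(self.losses) / len(self.losses)
            if loss > self.spike_factor * mean:
                print(f"loss spike at step {state.global_step}: {loss:.3f}")
        self.losses.append(loss)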

Quick Start

pip install git+https://huggingface.co/ScottzillaSystems/self-healing-training

from self_healing import SelfHealingTrainer, HealingConfig
from trl import SFTTrainer, SFTConfig

# Your normal training setup
trainer = SFTTrainer(
    model=model,
    args=SFTConfig(
        output_dir="./output",
        learning_rate=2e-5,
        per_device_train_batch_size=4,
    ),
    train_dataset=dataset,
    tokenizer=tokenizer,
)

# Wrap with self-healing - that's it!
sh = SelfHealingTrainer(
    trainer,
    HealingConfig(
        max_recovery_attempts=5,
        zclip_enabled=True,
    ),
)

# Optional: dry-run to catch config errors before full training
sh.dry_run(num_steps=2)

# Train with full autonomy
result = sh.train()

What Handles What

| Failure | Detection | Recovery | Paper |
|---|---|---|---|
| NaN loss | math.isnan(loss) after each step | Rollback → halve LR → enable grad clip | ZClip arxiv:2504.02507 |
| CUDA OOM | on_exception catches OutOfMemoryError | Halve batch (preserve effective batch via gradient accumulation) → gradient checkpointing → clear cache | Unicron arxiv:2401.00134 |
| Loss spike | Loss > 5× running mean over window | ZClip adaptive gradient clipping (sketched below) → emergency checkpoint | ZClip arxiv:2504.02507 |
| Divergence | Loss increasing for N consecutive steps | Rollback → halve LR | Pioneer Agent arxiv:2604.09791 |
| Gradient explosion | grad_norm > 100 | ZClip → enable max_grad_norm=1.0 | AdaGC arxiv:2502.11034 |
| DPO plateau | loss ≈ 0.693 (= ln 2, random chance) | Increase LR 2-5× → check data quality | Rafailov et al. (2023) |
| Overfitting | eval_loss - train_loss > 2.0 | Alert with actionable recommendation | Standard practice |
| API errors | Exception matching "api/network/timeout" | Exponential backoff (30s → 60s → 120s → ...) | Standard pattern |
| Data errors | Exception matching "shape/dimension/index" | Skip batch → log bad sample | Deep Researcher arxiv:2604.05854 |
| Crash postmortem | Always | postmortem.json with exit reason, last step, metrics, recovery history | PTT pattern |
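
Z-score clipping is simple to picture. Here is a rough sketch in the spirit of ZClip (arxiv:2504.02507): keep an EMA of the gradient-norm mean and variance, and scale gradients down when the current norm's z-score exceeds a threshold. The EMA constants, warmup, and scaling rule are illustrative, not the paper's exact algorithm or SHTS's implementation:

import torch


class ZScoreClipper:
    """Illustrative z-score gradient clipping (ZClip-style sketch)."""

    def __init__(self, z_threshold=3.0, ema_alpha=0.98, warmup_steps=25):
        self.z_threshold = z_threshold
        self.alpha = ema_alpha
        self.warmup_steps = warmup_steps
        self.steps = 0
        self.mean = None  # EMA of total grad norm
        self.var = 0.0    # EMA of its variance

    def maybe_clip(self, model):
        params = [p for p in model.parameters() if p.grad is not None]
        norm = torch.norm(torch.stack([p.grad.norm() for p in params])).item()
        self.steps += 1
        if self.mean is None:
            self.mean = norm
            return norm
        std = max(self.var ** 0.5, 1e-6)
        z = (norm - self.mean) / std
        if self.steps > self.warmup_steps and z > self.z_threshold:
            # Scale gradients so the norm lands at mean + threshold * std
            scale = (self.mean + self.z_threshold * std) / norm
            for p in params:
                p.grad.mul_(scale)
        # Update running statistics (EMA mean/variance)
        delta = norm - self.mean
        self.mean += (1 - self.alpha) * delta
        self.var = self.alpha * (self.var + (1 - self.alpha) * delta * delta)
        return norm

Call maybe_clip(model) after backward() and before optimizer.step(); the warmup avoids clipping on cold-start statistics.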

Crash Postmortem

Every training interruption produces a postmortem.json:

{
  "exit_reason": "exception",
  "exception_type": "OutOfMemoryError",
  "last_step": 847,
  "timestamp": "2026-04-30T15:26:04Z",
  "final_metrics": {"loss": 2.15, "grad_norm": 42.3},
  "recovery_actions": [
    {
      "failure": "oom",
      "diagnosis": "CUDA Out of Memory. Batch size exceeds GPU capacity.",
      "actions": ["halve_batch_size", "enable_gradient_checkpointing", "clear_cache"]
    }
  ],
  "running_time_seconds": 1847.3
}
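
Since the postmortem is plain JSON, you can inspect it with the standard library. A short example, assuming the schema shown above:

import json

with open("postmortem.json") as f:
    pm = json.load(f)

print(f"exited: {pm['exit_reason']} at step {pm['last_step']}")
for action in pm["recovery_actions"]:
    print(f"  {action['failure']}: {', '.join(action['actions'])}")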

Trackio Integration

Set report_to="trackio" in your training args. SHTS emits:

  • Alerts at every decision point (INFO/WARN/ERROR)
  • Metrics: healing/recovery_attempts, healing/nan_count, healing/loss_spike_ratio, healing/eval_gap
  • ZClip metrics: zclip/raw_grad_norm, zclip/clipped_grad_norm, zclip/z_score, zclip/total_clips

Dashboard URL: https://huggingface.co/spaces/<username>/<trackio-space>
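
Enabling it is one argument on the training config from the Quick Start; everything else stays the same:

from trl import SFTConfig

args = SFTConfig(
    output_dir="./output",
    learning_rate=2e-5,
    per_device_train_batch_size=4,
    report_to="trackio",  # SHTS alerts and healing/* metrics flow here
)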

HealingConfig Presets

# Aggressive - for unstable training, low tolerance
config = HealingConfig.aggressive()
# nan_patience=1, zclip_z_threshold=2.0, max_recovery_attempts=10

# Conservative - only intervene on clear failures
config = HealingConfig.conservative()
# nan_patience=10, loss_spike_factor=10.0, zclip_z_threshold=4.0, max_recovery_attempts=2

# Custom
config = HealingConfig(
    nan_patience=5,
    loss_spike_factor=8.0,
    divergence_patience=100,
    max_recovery_attempts=3,
    zclip_enabled=True,
    zclip_z_threshold=3.0,
)

Compatibility

| Trainer | Status | Notes |
|---|---|---|
| SFTTrainer (TRL) | ✅ Full | All metrics captured |
| DPOTrainer (TRL) | ✅ Full | DPO plateau detection (loss ≈ 0.693); example below |
| GRPOTrainer (TRL) | ✅ Full | Group reward monitoring |
| PPOTrainer (TRL) | ✅ Full | KL divergence tracking |
| ORPOTrainer (TRL) | ✅ Full | Odds ratio monitoring |
| KTOTrainer (TRL) | ✅ Full | Desirable/undesirable logps |
| CPOTrainer (TRL) | ✅ Full | Contrastive preference |
| Trainer (Transformers) | ✅ Full | Standard ML training |
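
The wrap pattern is identical regardless of trainer. A sketch for DPO, assuming a preference dataset with prompt/chosen/rejected columns and the imports from the Quick Start (DPOTrainer argument names vary slightly across TRL versions):

from self_healing import SelfHealingTrainer, HealingConfig
from trl import DPOTrainer, DPOConfig

trainer = DPOTrainer(
    model=model,
    args=DPOConfig(output_dir="./output-dpo", learning_rate=5e-7),
    train_dataset=preference_dataset,  # prompt / chosen / rejected columns
    tokenizer=tokenizer,
)

# Same one-line wrap; plateau detection (loss ≈ 0.693) is automatic
sh = SelfHealingTrainer(trainer, HealingConfig(zclip_enabled=True))
result = sh.train()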

Architecture

SelfHealingTrainer.train()
  │
  ├── dry_run()              ← Validate setup first
  │
  └── while not converged:
      │
      ├── trainer.train()    ← Run training
      │     │
      │     ├── on_step_end  ← Detect NaN, spikes, divergence
      │     ├── on_log       ← Monitor gradients (ZClip)
      │     ├── on_evaluate  ← Check overfitting
      │     └── on_exception ← Catch OOM, API, data errors
      │
      ├── [recovery needed?]
      │     ├── diagnose     ← Classify failure type
      │     ├── heal         ← Apply recovery actions
      │     └── retry        ← resume_from_checkpoint=True
      │
      └── [converged]        ← Done!
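
In code, the Layer-4 loop reduces to roughly the following shape. This is a pseudocode-level sketch: diagnose and heal are stand-ins for SHTS internals, not its public API.

def train_with_healing(trainer, config):
    """Illustrative retry loop in the shape of Layer 4 (not SHTS source)."""
    attempts = 0
    resume = False
    while True:
        try:
            return trainer.train(resume_from_checkpoint=resume)
        except Exception as exc:
            attempts += 1
            if attempts > config.max_recovery_attempts:
                raise  # postmortem.json is written before re-raising
            failure = diagnose(exc)   # classify: oom / nan / data / api ...
            heal(trainer, failure)    # e.g. halve batch, halve LR, clear cache
            resume = True             # retry from the last good checkpoint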

References

| Paper | ID | Contribution |
|---|---|---|
| Unicron | arxiv:2401.00134 | Cost-aware self-healing at cluster scale, error taxonomy (4 types), elastic scaling |
| ZClip | arxiv:2504.02507 | Z-score adaptive gradient clipping, eliminates catastrophic loss spikes |
| AdaGC | arxiv:2502.11034 | Per-tensor adaptive gradient clipping, optimizer-agnostic |
| Pioneer Agent | arxiv:2604.09791 | Structured decision tree by score buckets for autonomous iteration |
| Deep Researcher | arxiv:2604.05854 | Dry-run validation, zero-cost monitoring, constant-size memory |
| CheckFree | arxiv:2506.15461 | Pipeline-parallel recovery via neighbor averaging |
| DPO | Rafailov et al. (2023) | DPO plateau at 0.693 = random chance (Section 4.2) |
| PTT | post-training-toolkit | DiagnosticsCallback + postmortem pattern |

License

MIT - use freely, attribution appreciated.


Built autonomously by ML Intern. Questions? Open an issue on the Hub.
