# Self-Healing Training System (SHTS)
Fully autonomous debugging and error recovery for Hugging Face TRL trainers. Add one callback, wrap with
SelfHealingTrainer, and cut debugging costs to near zero.
## The Problem
ML training fails constantly:
- CUDA OOM kills jobs at step 847/1000 → restart from scratch
- NaN loss silently corrupts models → discovered hours later
- Loss spikes cascade into divergence → manual intervention required
- DPO plateau at 0.693 loss (= random chance) → wasted GPU hours
- No postmortem → "what step did it die on?"
Each failure costs developer time + GPU credits + schedule delay. At scale, this is millions in wasted compute.
## The Solution
SHTS wraps any Hugging Face TRL trainer with four autonomous layers:
```
┌──────────────────────────────────────────┐
│ LAYER 4: ORCHESTRATION                   │
│   SelfHealingTrainer retry loop          │
│   while not converged: try → recover     │
├──────────────────────────────────────────┤
│ LAYER 3: RECOVERY                        │
│   HealingActions: rollback, halve LR,    │
│   halve batch, reclip, clear cache       │
├──────────────────────────────────────────┤
│ LAYER 2: DIAGNOSIS                       │
│   Root-cause classifier: NaN/divergence/ │
│   OOM/data/API, with literature refs     │
├──────────────────────────────────────────┤
│ LAYER 1: DETECTION                       │
│   SelfHealingCallback: loss, gradients,  │
│   memory, ZClip adaptive clipping        │
└──────────────────────────────────────────┘
```
## Quick Start
```bash
pip install git+https://huggingface.co/ScottzillaSystems/self-healing-training
```

```python
from self_healing import SelfHealingTrainer, HealingConfig
from trl import SFTTrainer, SFTConfig

# Your normal training setup
trainer = SFTTrainer(
    model=model,
    args=SFTConfig(
        output_dir="./output",
        learning_rate=2e-5,
        per_device_train_batch_size=4,
    ),
    train_dataset=dataset,
    tokenizer=tokenizer,
)

# Wrap with self-healing; that's it!
sh = SelfHealingTrainer(
    trainer,
    HealingConfig(
        max_recovery_attempts=5,
        zclip_enabled=True,
    ),
)

# Optional: dry-run to catch config errors before full training
sh.dry_run(num_steps=2)

# Train with full autonomy
result = sh.train()
```
## What Handles What
| Failure | Detection | Recovery | Paper |
|---|---|---|---|
| NaN loss | `math.isnan(loss)` after each step | Rollback → halve LR → enable grad clip | ZClip arxiv:2504.02507 |
| CUDA OOM | `on_exception` catches `OutOfMemoryError` | Halve batch (preserve effective via GA) → gradient checkpointing → clear cache | Unicron arxiv:2401.00134 |
| Loss spike | Loss > 5× running mean over window | ZClip adaptive gradient clipping → emergency checkpoint | ZClip arxiv:2504.02507 |
| Divergence | Loss increasing for N consecutive steps | Rollback → halve LR | Pioneer Agent arxiv:2604.09791 |
| Gradient explosion | grad_norm > 100 | ZClip → enable `max_grad_norm=1.0` | AdaGC arxiv:2502.11034 |
| DPO plateau | loss ≈ 0.693 (random chance) | Increase LR 2-5× → check data quality | Rafailov et al. (2023) |
| Overfitting | eval_loss - train_loss > 2.0 | Alert with actionable recommendation | Standard practice |
| API errors | Exception with "api/network/timeout" | Exponential backoff (30s → 60s → 120s → ...) | Standard pattern |
| Data errors | Exception with "shape/dimension/index" | Skip batch → log bad sample | Deep Researcher arxiv:2604.05854 |
| Crash postmortem | Always | `postmortem.json` with exit reason, last step, metrics, recovery history | PTT pattern |
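The loss-spike rule in the table can be sketched in a few lines. This is an illustrative reimplementation, not the library's actual code: the `is_loss_spike` helper, the warm-up length of 5 steps, and the window size are all assumptions.

```python
from collections import deque

def is_loss_spike(loss, history, factor=5.0):
    """Flag `loss` if it exceeds `factor` times the running mean of
    recent losses; otherwise record it in the sliding window.
    Hypothetical helper mirroring 'Loss > 5x running mean'."""
    if len(history) >= 5 and loss > factor * (sum(history) / len(history)):
        return True          # spike: do not pollute the window with it
    history.append(loss)
    return False

history = deque(maxlen=50)   # sliding window of recent losses
for step_loss in [2.1, 2.0, 1.9, 1.8, 1.9]:
    is_loss_spike(step_loss, history)

print(is_loss_spike(40.0, history))  # True: 40 > 5 * ~1.94
print(is_loss_spike(2.0, history))   # False: normal step
```

Excluding the spiked value from the window keeps the running mean anchored to healthy steps, so one outlier cannot desensitize the detector.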
## Crash Postmortem
Every training interruption produces a postmortem.json:
```json
{
  "exit_reason": "exception",
  "exception_type": "OutOfMemoryError",
  "last_step": 847,
  "timestamp": "2026-04-30T15:26:04Z",
  "final_metrics": {"loss": 2.15, "grad_norm": 42.3},
  "recovery_actions": [
    {
      "failure": "oom",
      "diagnosis": "CUDA Out of Memory. Batch size exceeds GPU capacity.",
      "actions": ["halve_batch_size", "enable_gradient_checkpointing", "clear_cache"]
    }
  ],
  "running_time_seconds": 1847.3
}
```
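Because the postmortem is plain JSON, triage scripts are trivial to write. This sketch parses a trimmed copy of the sample above inline; the field names come from the example, while the one-line summary format is our own:

```python
import json

# Trimmed copy of the sample postmortem above.
postmortem = json.loads("""
{
  "exit_reason": "exception",
  "exception_type": "OutOfMemoryError",
  "last_step": 847,
  "recovery_actions": [
    {"failure": "oom",
     "actions": ["halve_batch_size", "enable_gradient_checkpointing", "clear_cache"]}
  ]
}
""")

# One-line triage summary: what died, where, and what was attempted.
summary = (f"{postmortem['exception_type']} at step {postmortem['last_step']}; "
           f"tried: {', '.join(postmortem['recovery_actions'][0]['actions'])}")
print(summary)
```

In practice you would load the file from the run directory instead of an inline string.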
## Trackio Integration
Set `report_to="trackio"` in your training args. SHTS emits:

- Alerts at every decision point (INFO/WARN/ERROR)
- Metrics: `healing/recovery_attempts`, `healing/nan_count`, `healing/loss_spike_ratio`, `healing/eval_gap`
- ZClip metrics: `zclip/raw_grad_norm`, `zclip/clipped_grad_norm`, `zclip/z_score`, `zclip/total_clips`

Dashboard URL: https://huggingface.co/spaces/<username>/<trackio-space>
## HealingConfig Presets
```python
# Aggressive: for unstable training, low tolerance
config = HealingConfig.aggressive()
# nan_patience=1, zclip_z_threshold=2.0, max_recovery_attempts=10

# Conservative: only intervene on clear failures
config = HealingConfig.conservative()
# nan_patience=10, loss_spike_factor=10.0, zclip_z_threshold=4.0, max_recovery_attempts=2

# Custom
config = HealingConfig(
    nan_patience=5,
    loss_spike_factor=8.0,
    divergence_patience=100,
    max_recovery_attempts=3,
    zclip_enabled=True,
    zclip_z_threshold=3.0,
)
```
## Compatibility
| Trainer | Status | Notes |
|---|---|---|
| `SFTTrainer` (TRL) | ✅ Full | All metrics captured |
| `DPOTrainer` (TRL) | ✅ Full | DPO plateau detection (loss ≈ 0.693) |
| `GRPOTrainer` (TRL) | ✅ Full | Group reward monitoring |
| `PPOTrainer` (TRL) | ✅ Full | KL divergence tracking |
| `ORPOTrainer` (TRL) | ✅ Full | Odds ratio monitoring |
| `KTOTrainer` (TRL) | ✅ Full | Desirable/undesirable logps |
| `CPOTrainer` (TRL) | ✅ Full | Contrastive preference |
| `Trainer` (Transformers) | ✅ Full | Standard ML training |
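The 0.693 plateau flagged for `DPOTrainer` is not an arbitrary threshold: the DPO loss is the negative log-sigmoid of the preference margin, and when the model cannot separate chosen from rejected completions the margin is 0, so the loss sits at -log(sigmoid(0)) = -log(0.5) = ln 2. A quick check:

```python
import math

# DPO loss at zero margin: -log(sigmoid(0)) = -log(0.5) = ln 2
sigmoid_zero = 1 / (1 + math.exp(0))  # sigmoid(0) = 0.5
plateau = -math.log(sigmoid_zero)
print(round(plateau, 3))  # 0.693
```

A training run stuck at this value is therefore predicting preferences no better than a coin flip, which is why SHTS treats it as a failure rather than slow convergence.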
## Architecture
```
SelfHealingTrainer.train()
│
├── dry_run()              → Validate setup first
│
└── while not converged:
      │
      ├── trainer.train()  → Run training
      │     │
      │     ├── on_step_end  → Detect NaN, spikes, divergence
      │     ├── on_log       → Monitor gradients (ZClip)
      │     ├── on_evaluate  → Check overfitting
      │     └── on_exception → Catch OOM, API, data errors
      │
      ├── [recovery needed?]
      │     ├── diagnose → Classify failure type
      │     ├── heal     → Apply recovery actions
      │     └── retry    → resume_from_checkpoint=True
      │
      └── [converged] → Done!
```
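The outer loop above is, at heart, a try/heal/retry pattern. The following is a simplified stand-in, not the package's actual internals: `run_training`, `diagnose`, and `heal` are placeholder callables.

```python
def train_with_healing(run_training, diagnose, heal, max_recovery_attempts=5):
    """Retry-loop sketch: run training, and on failure diagnose + heal,
    then resume; give up after max_recovery_attempts failures."""
    attempts = 0
    while True:
        try:
            return run_training()            # e.g. trainer.train(...)
        except Exception as exc:
            attempts += 1
            if attempts > max_recovery_attempts:
                raise                        # budget exhausted: re-raise
            heal(diagnose(exc))              # e.g. rollback, halve LR

# Toy run: fail twice with OOM-like errors, then succeed.
calls = {"n": 0}
def flaky():
    calls["n"] += 1
    if calls["n"] < 3:
        raise MemoryError("CUDA out of memory")
    return "converged"

print(train_with_healing(flaky,
                         diagnose=lambda e: type(e).__name__,
                         heal=lambda d: None))  # converged
```

The real system additionally resumes from the last checkpoint between attempts (`resume_from_checkpoint=True`) so no completed steps are repeated.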
## References
| Paper | ID | Contribution |
|---|---|---|
| Unicron | arxiv:2401.00134 | Cost-aware self-healing at cluster scale, error taxonomy (4 types), elastic scaling |
| ZClip | arxiv:2504.02507 | Z-score adaptive gradient clipping, eliminates catastrophic loss spikes |
| AdaGC | arxiv:2502.11034 | Per-tensor adaptive gradient clipping, optimizer-agnostic |
| Pioneer Agent | arxiv:2604.09791 | Structured decision tree by score buckets for autonomous iteration |
| Deep Researcher | arxiv:2604.05854 | Dry-run validation, zero-cost monitoring, constant-size memory |
| CheckFree | arxiv:2506.15461 | Pipeline-parallel recovery via neighbor averaging |
| DPO | Rafailov et al. (2023) | DPO plateau at 0.693 = random chance (Section 4.2) |
| PTT | post-training-toolkit | DiagnosticsCallback + postmortem pattern |
## License
MIT: use freely, attribution appreciated.
Built autonomously by ML Intern. Questions? Open an issue on the Hub.