"""Checkpoint discovery, resume-state validation, saving, and optional Hub upload."""

from __future__ import annotations

import json
import os
import time
from pathlib import Path

import torch

from configs import cfg

# Relative priority of checkpoint kinds when picking the "latest" one:
# any epoch checkpoint outranks any step checkpoint; unknown names rank last.
_CHECKPOINT_PRIORITY = {"epoch": 2, "step": 1}


def _split_checkpoint_name(path: str) -> tuple[str, int | None]:
    """Split a checkpoint directory name like ``epoch_3`` into ``("epoch", 3)``.

    The path is normalized first so a trailing separator (``out/epoch_3/``)
    does not make ``basename`` return an empty string. The numeric part is
    ``None`` when the text after the first underscore is not an integer.
    """
    name = os.path.basename(os.path.normpath(path))
    prefix, _, raw_value = name.partition("_")
    try:
        return prefix, int(raw_value)
    except ValueError:
        return prefix, None


def _dir_size_mb(directory: str) -> float:
    """Return the total size of all regular files under *directory*, in MB (1e6 bytes)."""
    return sum(f.stat().st_size for f in Path(directory).rglob("*") if f.is_file()) / 1e6


def _hub_setting(name: str, default=None):
    """Read ``cfg.hub.<name>``, tolerating a missing ``cfg.hub`` or attribute."""
    return getattr(getattr(cfg, "hub", None), name, default)


def _abort(log, message: str) -> None:
    """Log *message* at error level and terminate with exit status 1."""
    log.error(message)
    raise SystemExit(1)


def checkpoint_rank(path: str) -> tuple[int, int]:
    """Return a sort key ``(kind_priority, number)`` for a checkpoint path.

    ``epoch_*`` gets priority 2, ``step_*`` priority 1, anything else 0.
    The number is the integer after the first underscore, or -1 when it
    cannot be parsed.
    """
    prefix, value = _split_checkpoint_name(path)
    return (_CHECKPOINT_PRIORITY.get(prefix, 0), -1 if value is None else value)


def find_latest_training_checkpoint(output_dir: str) -> str | None:
    """Return the highest-ranked ``epoch_*`` / ``step_*`` directory in *output_dir*.

    Ranking follows :func:`checkpoint_rank`. Returns ``None`` when no matching
    directory exists.
    """
    root = Path(output_dir)
    candidates = [
        str(path)
        for pattern in ("epoch_*", "step_*")
        for path in root.glob(pattern)
        if path.is_dir()
    ]
    return max(candidates, key=checkpoint_rank, default=None)


def load_trainer_state(checkpoint_dir: str, log) -> dict:
    """Load the resume state for *checkpoint_dir*.

    Prefers ``trainer_state.json`` inside the directory. When that file is
    missing, unreadable, not valid JSON/UTF-8, or not a JSON object, a minimal
    state is derived from the directory name instead (``epoch_N`` or
    ``step_N``). Returns an empty dict when neither source yields a state.

    Args:
        checkpoint_dir: Path to the checkpoint directory.
        log: Logger-like object; only ``warning`` is used.
    """
    state_path = os.path.join(checkpoint_dir, "trainer_state.json")
    if os.path.exists(state_path):
        try:
            with open(state_path, "r", encoding="utf-8") as f:
                state = json.load(f)
        except (OSError, ValueError) as exc:
            # ValueError covers both json.JSONDecodeError and UnicodeDecodeError,
            # so a corrupt file falls back to name-based state instead of crashing.
            log.warning(f"Could not read trainer_state.json from {checkpoint_dir}: {exc}")
        else:
            if isinstance(state, dict):
                return state
            log.warning(
                f"Ignoring trainer_state.json from {checkpoint_dir}: "
                f"expected a JSON object, got {type(state).__name__}"
            )

    prefix, parsed = _split_checkpoint_name(checkpoint_dir)
    value = 0 if parsed is None else parsed
    if prefix == "epoch":
        return {
            "checkpoint_type": "epoch",
            "start_epoch": value,
            "global_step": 0,
            "micro_step_global": 0,
            "next_batch_in_epoch": 0,
        }
    if prefix == "step":
        return {
            "checkpoint_type": "step",
            "start_epoch": 0,
            "global_step": value,
            "micro_step_global": 0,
            "next_batch_in_epoch": 0,
        }
    return {}


def packing_checkpoint_metadata(
    enabled: bool,
    pack_length: int | None,
    max_seq_len: int,
) -> dict[str, int | bool | None]:
    """Build the sequence-packing fields stored alongside a checkpoint.

    The pack length is recorded only when packing is enabled and a length is
    given; otherwise it is stored as ``None``.
    """
    record_length = enabled and pack_length is not None
    return {
        "sequence_packing_enabled": bool(enabled),
        "sequence_packing_pack_length": int(pack_length) if record_length else None,
        "data_max_seq_len": int(max_seq_len),
    }


def validate_resume_packing_state(
    trainer_state: dict,
    *,
    enabled: bool,
    pack_length: int,
    max_seq_len: int,
    log,
) -> None:
    """Abort the run when the checkpoint's packing settings differ from the current ones.

    Checks, in order: the packing on/off flag, the pack length (only when
    packing is enabled), and the max sequence length (only when the checkpoint
    recorded one). Any mismatch or unparsable value is logged via
    ``log.error`` before exiting.

    Raises:
        SystemExit: With status 1 on the first mismatch or invalid value.
    """
    checkpoint_enabled = bool(trainer_state.get("sequence_packing_enabled", False))
    if checkpoint_enabled != bool(enabled):
        _abort(
            log,
            "Checkpoint sequence-packing state does not match the current run: "
            f"checkpoint={checkpoint_enabled}, current={bool(enabled)}.",
        )

    if checkpoint_enabled:
        checkpoint_pack_length = trainer_state.get("sequence_packing_pack_length")
        try:
            checkpoint_pack_length = int(checkpoint_pack_length)
        except (TypeError, ValueError):
            _abort(log, "Checkpoint is missing a valid sequence_packing_pack_length value.")
        if checkpoint_pack_length != int(pack_length):
            _abort(
                log,
                "Checkpoint pack length does not match the current run: "
                f"checkpoint={checkpoint_pack_length}, current={int(pack_length)}.",
            )

    # A state without this key skips the check entirely instead of failing.
    checkpoint_max_seq_len = trainer_state.get("data_max_seq_len")
    if checkpoint_max_seq_len is not None:
        try:
            checkpoint_max_seq_len = int(checkpoint_max_seq_len)
        except (TypeError, ValueError):
            _abort(log, "Checkpoint is missing a valid data_max_seq_len value.")
        if checkpoint_max_seq_len != int(max_seq_len):
            _abort(
                log,
                "Checkpoint max sequence length does not match the current run: "
                f"checkpoint={checkpoint_max_seq_len}, current={int(max_seq_len)}.",
            )


def save_checkpoint(
    model,
    tokenizer,
    output_dir: str,
    tag: str,
    logger,
    *,
    scheduler=None,
    trainer_state: dict | None = None,
) -> str:
    """Write model config, tokenizer, weights and optional state to ``output_dir/tag``.

    Weights are saved as ``model.safetensors`` when the ``safetensors`` package
    can be imported, otherwise as ``pytorch_model.bin`` via ``torch.save``.
    The scheduler state goes to ``scheduler.pt`` and the trainer state to
    ``trainer_state.json`` when they are provided.

    Args:
        model: Model to save; wrappers exposing ``.module`` and/or ``._orig_mod``
            are unwrapped first (presumably DDP-style and torch.compile
            wrappers respectively; confirm against the caller).
        tokenizer: Object with a ``save_pretrained(dir)`` method.
        output_dir: Parent directory for checkpoints.
        tag: Sub-directory name, e.g. ``epoch_1``.
        logger: Logger-like object; only ``info`` is used.
        scheduler: Optional object with ``state_dict()``.
        trainer_state: Optional JSON-serializable mapping; it is copied, and
            ``tag`` / ``saved_at`` are filled in only if absent.

    Returns:
        The checkpoint directory path.
    """
    save_dir = os.path.join(output_dir, tag)
    os.makedirs(save_dir, exist_ok=True)
    # Monotonic clock: elapsed time stays correct if the wall clock is adjusted.
    save_start = time.perf_counter()
    logger.info(f"[CKPT] Saving {tag} -> {save_dir}/")

    model_to_save = model.module if hasattr(model, "module") else model
    if hasattr(model_to_save, "_orig_mod"):
        model_to_save = model_to_save._orig_mod

    model_to_save.config.save_pretrained(save_dir)
    tokenizer.save_pretrained(save_dir)

    # Only the import decides the format; errors raised while serializing
    # must propagate instead of silently triggering the torch.save fallback.
    try:
        from safetensors.torch import save_file
    except ImportError:
        save_file = None

    if save_file is not None:
        state_dict = {k: v.contiguous().cpu() for k, v in model_to_save.state_dict().items()}
        save_file(state_dict, os.path.join(save_dir, "model.safetensors"))
        logger.info("[CKPT] Saved via safetensors")
    else:
        torch.save(model_to_save.state_dict(), os.path.join(save_dir, "pytorch_model.bin"))
        logger.info("[CKPT] Saved via torch.save")

    if scheduler is not None:
        torch.save(scheduler.state_dict(), os.path.join(save_dir, "scheduler.pt"))

    if trainer_state is not None:
        # Copy so the caller's dict is not mutated by the defaults below.
        trainer_state = dict(trainer_state)
        trainer_state.setdefault("tag", tag)
        trainer_state.setdefault("saved_at", time.strftime("%Y-%m-%d %H:%M:%S %Z"))
        with open(os.path.join(save_dir, "trainer_state.json"), "w", encoding="utf-8") as f:
            json.dump(trainer_state, f, indent=2)

    size_mb = _dir_size_mb(save_dir)
    save_elapsed = time.perf_counter() - save_start
    logger.info(f"[CKPT] {tag} -> {save_dir}/ ({size_mb:.0f} MB, {save_elapsed:.1f}s)")
    return save_dir


def read_env_flag(name: str, default: bool = False) -> bool:
    """Interpret environment variable *name* as a boolean.

    Returns *default* when the variable is unset. Otherwise returns ``True``
    only for ``1``, ``true``, ``yes`` or ``on`` (case-insensitive, surrounding
    whitespace ignored).
    """
    raw = os.environ.get(name)
    if raw is None:
        return default
    return raw.strip().lower() in {"1", "true", "yes", "on"}


def hub_upload_strict() -> bool:
    """Return whether a failed Hub upload should raise instead of only warning.

    ``cfg.hub.hub_upload_strict`` wins when it is set (not ``None``); otherwise
    the ``QUINTUS_HUB_UPLOAD_STRICT`` environment flag is used, defaulting to
    ``False``.
    """
    strict = _hub_setting("hub_upload_strict")
    if strict is None:
        return read_env_flag("QUINTUS_HUB_UPLOAD_STRICT", False)
    return bool(strict)


def should_upload_checkpoint_tag(tag: str) -> bool:
    """Decide whether the checkpoint named *tag* should be uploaded to the Hub.

    Each switch is on when either the ``cfg.hub`` attribute is truthy or the
    matching ``QUINTUS_*`` environment flag is set:

    * ``step_*`` tags follow the step-checkpoint switch.
    * ``epoch_*`` and ``best`` follow the regular (KD) checkpoint switch.
    * ``last`` follows the last-checkpoint switch or the regular switch.
    * Any other tag is never uploaded.
    """
    upload_regular = bool(
        _hub_setting("upload_kd_checkpoints", False)
        or read_env_flag("QUINTUS_UPLOAD_KD_CHECKPOINTS", False)
    )
    upload_steps = bool(
        _hub_setting("upload_step_checkpoints", False)
        or read_env_flag("QUINTUS_UPLOAD_STEP_CHECKPOINTS", False)
    )
    upload_last = bool(
        _hub_setting("upload_last_checkpoint", False)
        or read_env_flag("QUINTUS_UPLOAD_LAST_CHECKPOINT", False)
    )

    if tag.startswith("step_"):
        return upload_steps
    if tag.startswith("epoch_") or tag == "best":
        return upload_regular
    if tag == "last":
        return upload_last or upload_regular
    return False


def maybe_upload_checkpoint(checkpoint_dir: str, tag: str, logger) -> None:
    """Upload *checkpoint_dir* to the Hugging Face Hub when enabled for *tag*.

    Does nothing unless :func:`should_upload_checkpoint_tag` approves the tag.
    The target is a private repository of type ``dataset``; the folder lands at
    ``<base_path>/<tag>`` inside it. Settings come from ``cfg.hub`` first and
    from environment variables second.

    Args:
        checkpoint_dir: Local directory to upload.
        tag: Checkpoint tag, used for the upload decision, the path in the
            repo and the commit message.
        logger: Logger-like object; ``info`` and ``warning`` are used.

    Raises:
        RuntimeError: In strict mode (see :func:`hub_upload_strict`) when the
            token is missing or the upload fails. In non-strict mode both
            cases only log a warning.
    """
    if not should_upload_checkpoint_tag(tag):
        return

    strict = hub_upload_strict()

    # Guarded lookup, consistent with every other cfg.hub access in this file.
    token = os.environ.get("HF_TOKEN") or _hub_setting("token")
    if not token:
        msg = "HF checkpoint upload requested, but HF_TOKEN/cfg.hub.token is missing"
        if strict:
            raise RuntimeError(msg)
        logger.warning(f"[CKPT] {msg}; continuing without remote backup")
        return

    # The username/repo_name fallback is only evaluated when neither an explicit
    # repo id nor the environment variable is available.
    repo_id = (
        _hub_setting("repo_id")
        or os.environ.get("QUINTUS_HUB_REPO_ID")
        or f"{cfg.hub.username}/{cfg.hub.repo_name}"
    )
    base_path = _hub_setting("ckpt_path_in_repo") or os.environ.get(
        "KD_CKPT_PATH_IN_REPO",
        "models/online_kd_3b_05b_ep3_B200_20260601",
    )
    base_path = base_path.strip("/")
    path_in_repo = f"{base_path}/{tag}"
    commit_prefix = _hub_setting("commit_message_prefix") or os.environ.get(
        "KD_COMMIT_MESSAGE_PREFIX",
        "Online KD 8B->1.7B Run",
    )
    commit_message = os.environ.get("KD_COMMIT_MESSAGE") or f"{commit_prefix}: upload {tag}"

    upload_start = time.perf_counter()
    size_mb = _dir_size_mb(checkpoint_dir)
    logger.info(
        f"[CKPT] Uploading {tag} -> {repo_id}/{path_in_repo} "
        f"({size_mb:.0f} MB, strict={strict})"
    )
    logger.info(f"[CKPT] Commit: {commit_message}")

    # Broad catch is deliberate: in non-strict mode no upload problem (including
    # a missing huggingface_hub package) may interrupt training.
    try:
        from huggingface_hub import HfApi

        api = HfApi(token=token)
        api.create_repo(repo_id=repo_id, repo_type="dataset", private=True, exist_ok=True)
        api.upload_folder(
            folder_path=checkpoint_dir,
            repo_id=repo_id,
            path_in_repo=path_in_repo,
            repo_type="dataset",
            commit_message=commit_message,
            ignore_patterns=["*.tmp", "*.log", "__pycache__/*"],
        )
        upload_elapsed = time.perf_counter() - upload_start
        logger.info(f"[CKPT] Uploaded {tag} to HF Hub in {upload_elapsed / 60:.1f}m")
    except Exception as exc:
        msg = f"HF checkpoint upload failed for {tag}: {exc}"
        if strict:
            raise RuntimeError(msg) from exc
        logger.warning(f"[CKPT] {msg}; continuing because hub upload strict mode is disabled")