# Quintus / src/checkpoints.py
# Published in commit 4fc1bb9 ("release: publish Quintus project files") by iamrahulreddy; 9.33 kB.
from __future__ import annotations
import json
import os
import time
from pathlib import Path
import torch
from configs import cfg
def checkpoint_rank(path: str) -> tuple[int, int]:
    """Return a sort key that ranks a checkpoint directory by its basename.

    The basename is split at its first underscore into ``<prefix>_<suffix>``.
    ``epoch`` prefixes outrank ``step`` prefixes, which outrank anything else.
    Within the same prefix a larger integer suffix ranks higher. A suffix that
    does not parse as an integer counts as ``-1``.
    """
    prefix, _sep, suffix = os.path.basename(path).partition("_")
    priority = {"epoch": 2, "step": 1}.get(prefix, 0)
    try:
        number = int(suffix)
    except ValueError:
        number = -1
    return (priority, number)
def find_latest_training_checkpoint(output_dir: str) -> str | None:
    """Return the best-ranked ``epoch_*`` / ``step_*`` directory inside *output_dir*.

    Candidates are ranked with :func:`checkpoint_rank`. Returns ``None`` when no
    matching directory is found.
    """
    root = Path(output_dir)
    found = [
        str(entry)
        for glob_pattern in ("epoch_*", "step_*")
        for entry in root.glob(glob_pattern)
        if entry.is_dir()
    ]
    return max(found, key=checkpoint_rank) if found else None
def load_trainer_state(checkpoint_dir: str, log) -> dict:
    """Load the resume metadata stored next to a checkpoint.

    Reads ``<checkpoint_dir>/trainer_state.json`` when it exists and holds a
    JSON object. In every other case a minimal state is rebuilt from the
    directory name:

    - the file is missing or unreadable;
    - the file is not valid UTF-8 or JSON;
    - the file holds a non-object JSON value.

    In the rebuilt state, ``epoch_<N>`` sets ``start_epoch=N`` and
    ``step_<N>`` sets ``global_step=N``, with ``N`` falling back to ``0`` when
    it is not an integer. Any other name yields ``{}``.

    Args:
        checkpoint_dir: Path to a checkpoint directory. A trailing path
            separator is tolerated.
        log: Logger-like object; only ``log.warning`` is called.

    Returns:
        The trainer-state dict.
    """
    state_path = os.path.join(checkpoint_dir, "trainer_state.json")
    if os.path.exists(state_path):
        # ValueError covers both json.JSONDecodeError and UnicodeDecodeError, so a
        # corrupt or binary file falls back to the name-based state instead of crashing.
        try:
            with open(state_path, "r", encoding="utf-8") as f:
                state = json.load(f)
        except (OSError, ValueError) as exc:
            log.warning(f"Could not read trainer_state.json from {checkpoint_dir}: {exc}")
        else:
            if isinstance(state, dict):
                return state
            log.warning(
                f"Ignoring trainer_state.json in {checkpoint_dir}: expected a JSON object, "
                f"got {type(state).__name__}"
            )

    # normpath strips a trailing separator; otherwise basename("run/step_5/") == "".
    name = os.path.basename(os.path.normpath(checkpoint_dir))
    prefix, _, raw_value = name.partition("_")
    try:
        value = int(raw_value)
    except ValueError:
        value = 0
    if prefix == "epoch":
        return {
            "checkpoint_type": "epoch",
            "start_epoch": value,
            "global_step": 0,
            "micro_step_global": 0,
            "next_batch_in_epoch": 0,
        }
    if prefix == "step":
        return {
            "checkpoint_type": "step",
            "start_epoch": 0,
            "global_step": value,
            "micro_step_global": 0,
            "next_batch_in_epoch": 0,
        }
    return {}
def packing_checkpoint_metadata(enabled: bool, pack_length: int | None, max_seq_len: int) -> dict[str, int | bool | None]:
return {
"sequence_packing_enabled": bool(enabled),
"sequence_packing_pack_length": int(pack_length) if enabled and pack_length is not None else None,
"data_max_seq_len": int(max_seq_len),
}
def validate_resume_packing_state(
    trainer_state: dict,
    *,
    enabled: bool,
    pack_length: int,
    max_seq_len: int,
    log,
) -> None:
    """Stop the run when a checkpoint's packing settings disagree with the current run.

    Each mismatch, and each stored value that cannot be parsed as an int, is
    reported via ``log.error`` and then ends the process with
    ``SystemExit(1)``. Checks, in order:

    1. packing on/off;
    2. pack length, only when packing is on;
    3. ``data_max_seq_len``, only when the checkpoint recorded it.
    """
    def abort(message: str) -> None:
        # Single exit path: report, then terminate with status 1.
        log.error(message)
        raise SystemExit(1)

    stored_enabled = bool(trainer_state.get("sequence_packing_enabled", False))
    current_enabled = bool(enabled)
    if stored_enabled != current_enabled:
        abort(
            "Checkpoint sequence-packing state does not match the current run: "
            f"checkpoint={stored_enabled}, current={current_enabled}."
        )

    if stored_enabled:
        try:
            stored_pack_length = int(trainer_state.get("sequence_packing_pack_length"))
        except (TypeError, ValueError):
            abort("Checkpoint is missing a valid sequence_packing_pack_length value.")
        current_pack_length = int(pack_length)
        if stored_pack_length != current_pack_length:
            abort(
                "Checkpoint pack length does not match the current run: "
                f"checkpoint={stored_pack_length}, current={current_pack_length}."
            )

    stored_max_seq_len = trainer_state.get("data_max_seq_len")
    if stored_max_seq_len is None:
        # Checkpoints written before this field existed skip the check.
        return
    try:
        stored_max_seq_len = int(stored_max_seq_len)
    except (TypeError, ValueError):
        abort("Checkpoint is missing a valid data_max_seq_len value.")
    current_max_seq_len = int(max_seq_len)
    if stored_max_seq_len != current_max_seq_len:
        abort(
            "Checkpoint max sequence length does not match the current run: "
            f"checkpoint={stored_max_seq_len}, current={current_max_seq_len}."
        )
def save_checkpoint(
model,
tokenizer,
output_dir: str,
tag: str,
logger,
*,
scheduler=None,
trainer_state: dict | None = None,
) -> str:
save_dir = os.path.join(output_dir, tag)
os.makedirs(save_dir, exist_ok=True)
save_start = time.time()
logger.info(f"[CKPT] Saving {tag} -> {save_dir}/")
model_to_save = model.module if hasattr(model, "module") else model
if hasattr(model_to_save, "_orig_mod"):
model_to_save = model_to_save._orig_mod
model_to_save.config.save_pretrained(save_dir)
tokenizer.save_pretrained(save_dir)
try:
from safetensors.torch import save_file
state_dict = {k: v.contiguous().cpu() for k, v in model_to_save.state_dict().items()}
save_file(state_dict, os.path.join(save_dir, "model.safetensors"))
logger.info("[CKPT] Saved via safetensors")
except ImportError:
torch.save(model_to_save.state_dict(), os.path.join(save_dir, "pytorch_model.bin"))
logger.info("[CKPT] Saved via torch.save")
if scheduler is not None:
torch.save(scheduler.state_dict(), os.path.join(save_dir, "scheduler.pt"))
if trainer_state is not None:
trainer_state = dict(trainer_state)
trainer_state.setdefault("tag", tag)
trainer_state.setdefault("saved_at", time.strftime("%Y-%m-%d %H:%M:%S %Z"))
with open(os.path.join(save_dir, "trainer_state.json"), "w", encoding="utf-8") as f:
json.dump(trainer_state, f, indent=2)
size_mb = sum(f.stat().st_size for f in Path(save_dir).rglob("*") if f.is_file()) / 1e6
save_elapsed = time.time() - save_start
logger.info(f"[CKPT] {tag} -> {save_dir}/ ({size_mb:.0f} MB, {save_elapsed:.1f}s)")
return save_dir
def read_env_flag(name: str, default: bool = False) -> bool:
    """Interpret environment variable *name* as a boolean switch.

    If the variable is unset, return *default*. If it is set, return ``True``
    only for ``1`` / ``true`` / ``yes`` / ``on``, ignoring case and surrounding
    whitespace. Any other value, including an empty string, gives ``False``.
    """
    value = os.environ.get(name)
    if value is not None:
        return value.strip().lower() in ("1", "true", "yes", "on")
    return default
def hub_upload_strict() -> bool:
    """Return whether a failed Hub upload should raise rather than only warn.

    ``cfg.hub.hub_upload_strict`` takes precedence whenever it is not ``None``.
    Otherwise the ``QUINTUS_HUB_UPLOAD_STRICT`` environment flag decides, with
    a default of off.
    """
    hub_section = getattr(cfg, "hub", None)
    configured = getattr(hub_section, "hub_upload_strict", None)
    if configured is not None:
        return bool(configured)
    return read_env_flag("QUINTUS_HUB_UPLOAD_STRICT", False)
def should_upload_checkpoint_tag(tag: str) -> bool:
    """Decide whether the checkpoint named *tag* should be pushed to the Hub.

    Each switch is read from ``cfg.hub``. When that value is falsy, the
    matching environment flag is used instead.

    - ``step_*``: ``upload_step_checkpoints`` / ``QUINTUS_UPLOAD_STEP_CHECKPOINTS``.
    - ``epoch_*`` and ``best``: ``upload_kd_checkpoints`` /
      ``QUINTUS_UPLOAD_KD_CHECKPOINTS``.
    - ``last``: ``upload_last_checkpoint`` / ``QUINTUS_UPLOAD_LAST_CHECKPOINT``,
      or the ``epoch``/``best`` switch.
    - Any other tag: never uploaded.
    """
    hub_section = getattr(cfg, "hub", None)

    def switch(attr: str, env_name: str):
        return getattr(hub_section, attr, False) or read_env_flag(env_name, False)

    upload_regular = switch("upload_kd_checkpoints", "QUINTUS_UPLOAD_KD_CHECKPOINTS")
    upload_steps = switch("upload_step_checkpoints", "QUINTUS_UPLOAD_STEP_CHECKPOINTS")
    upload_last = switch("upload_last_checkpoint", "QUINTUS_UPLOAD_LAST_CHECKPOINT")

    if tag.startswith("step_"):
        return upload_steps
    if tag.startswith("epoch_") or tag == "best":
        return upload_regular
    if tag == "last":
        return upload_last or upload_regular
    return False
def maybe_upload_checkpoint(checkpoint_dir: str, tag: str, logger) -> None:
    """Upload *checkpoint_dir* to the Hugging Face Hub if *tag* is selected for upload.

    :func:`should_upload_checkpoint_tag` decides whether to upload at all. The
    folder goes to ``<base_path>/<tag>`` inside a private *dataset* repo, which
    is created if missing. Settings are resolved as follows:

    - token: ``HF_TOKEN`` env var, then ``cfg.hub.token``;
    - repo id: ``cfg.hub.repo_id``, then ``QUINTUS_HUB_REPO_ID``, then
      ``<cfg.hub.username>/<cfg.hub.repo_name>``;
    - base path: ``cfg.hub.ckpt_path_in_repo``, then ``KD_CKPT_PATH_IN_REPO``,
      then a built-in default;
    - commit message: ``KD_COMMIT_MESSAGE``, else ``"<prefix>: upload <tag>"``.

    Args:
        checkpoint_dir: Local checkpoint directory to upload.
        tag: Checkpoint name; also the last path component in the repo.
        logger: Logger-like object; ``info`` and ``warning`` are called.

    Raises:
        RuntimeError: In strict mode (see :func:`hub_upload_strict`), when the
            token or repo id is missing or the upload fails. In non-strict mode
            these problems are logged as warnings and the function returns.
    """
    if not should_upload_checkpoint_tag(tag):
        return
    # Guard every cfg.hub access: uploads can be enabled purely through env
    # flags, in which case cfg may have no ``hub`` section at all.
    hub = getattr(cfg, "hub", None)
    strict = hub_upload_strict()

    def config_problem(msg: str) -> None:
        # Misconfiguration is fatal only in strict mode; otherwise warn and skip.
        if strict:
            raise RuntimeError(msg)
        logger.warning(f"[CKPT] {msg}; continuing without remote backup")

    token = os.environ.get("HF_TOKEN") or getattr(hub, "token", None)
    if not token:
        config_problem("HF checkpoint upload requested, but HF_TOKEN/cfg.hub.token is missing")
        return

    repo_id = getattr(hub, "repo_id", None) or os.environ.get("QUINTUS_HUB_REPO_ID")
    if not repo_id:
        username = getattr(hub, "username", None)
        repo_name = getattr(hub, "repo_name", None)
        if not username or not repo_name:
            config_problem(
                "HF checkpoint upload requested, but no repo id is configured "
                "(cfg.hub.repo_id, QUINTUS_HUB_REPO_ID, or cfg.hub.username + cfg.hub.repo_name)"
            )
            return
        repo_id = f"{username}/{repo_name}"

    base_path = getattr(hub, "ckpt_path_in_repo", None) or os.environ.get(
        "KD_CKPT_PATH_IN_REPO", "models/online_kd_3b_05b_ep3_B200_20260601"
    )
    base_path = base_path.strip("/")
    # An empty base path would otherwise produce a leading "/" in the repo path.
    path_in_repo = f"{base_path}/{tag}" if base_path else tag
    commit_prefix = getattr(hub, "commit_message_prefix", None) or os.environ.get(
        "KD_COMMIT_MESSAGE_PREFIX",
        "Online KD 8B->1.7B Run",
    )
    commit_message = os.environ.get("KD_COMMIT_MESSAGE") or f"{commit_prefix}: upload {tag}"

    upload_start = time.time()
    size_mb = sum(p.stat().st_size for p in Path(checkpoint_dir).rglob("*") if p.is_file()) / 1e6
    logger.info(
        f"[CKPT] Uploading {tag} -> {repo_id}/{path_in_repo} "
        f"({size_mb:.0f} MB, strict={strict})"
    )
    logger.info(f"[CKPT] Commit: {commit_message}")
    try:
        from huggingface_hub import HfApi

        api = HfApi(token=token)
        api.create_repo(repo_id=repo_id, repo_type="dataset", private=True, exist_ok=True)
        api.upload_folder(
            folder_path=checkpoint_dir,
            repo_id=repo_id,
            path_in_repo=path_in_repo,
            repo_type="dataset",
            commit_message=commit_message,
            ignore_patterns=["*.tmp", "*.log", "__pycache__/*"],
        )
    except Exception as exc:  # deliberate best-effort boundary (import/auth/network errors)
        msg = f"HF checkpoint upload failed for {tag}: {exc}"
        if strict:
            raise RuntimeError(msg) from exc
        logger.warning(f"[CKPT] {msg}; continuing because hub upload strict mode is disabled")
    else:
        upload_elapsed = time.time() - upload_start
        logger.info(f"[CKPT] Uploaded {tag} to HF Hub in {upload_elapsed / 60:.1f}m")