#!/usr/bin/env python3 """Push a locally trained CodeBERT model to Hugging Face Hub.""" from __future__ import annotations import argparse import json import os import sys from pathlib import Path PROJECT_ROOT = Path(__file__).resolve().parent.parent sys.path.insert(0, str(PROJECT_ROOT)) from huggingface_hub import HfApi, create_repo from src.codebert_labels import load_codebert_labels DEFAULT_MODEL_DIR = PROJECT_ROOT / "models" / "codebert-cross-encoder" MODEL_CARD = PROJECT_ROOT / "hub" / "CODEBERT_MODEL_CARD.md" def push( model_dir: Path = DEFAULT_MODEL_DIR, repo_id: str = "", private: bool = False, token: str | None = None, ) -> str: if not model_dir.exists(): raise FileNotFoundError(f"Model not found at {model_dir}. Train first.") if not repo_id: raise ValueError("--repo-id required, e.g. nishu08/sql-codebert-classifier") token = token or os.getenv("HF_TOKEN") or os.getenv("HUGGING_FACE_HUB_TOKEN") if not token: raise ValueError("Set HF_TOKEN environment variable") # Ensure label config exists for inference label_config = model_dir / "label_config.json" if not label_config.exists(): with open(label_config, "w") as f: json.dump( { "labels": load_codebert_labels(), "model_name": "microsoft/codebert-base", "architecture": "codebert-cross-encoder", "threshold": 0.5, "max_length": 512, }, f, indent=2, ) api = HfApi(token=token) create_repo(repo_id, repo_type="model", private=private, exist_ok=True, token=token) print(f"Uploading {model_dir} → {repo_id} ...") api.upload_folder( folder_path=str(model_dir), repo_id=repo_id, repo_type="model", token=token, commit_message="Upload CodeBERT SQL error classifier", ) if MODEL_CARD.exists(): api.upload_file( path_or_fileobj=str(MODEL_CARD), path_in_repo="README.md", repo_id=repo_id, repo_type="model", token=token, ) url = f"https://huggingface.co/{repo_id}" print(f"Done: {url}") return url def main() -> None: parser = argparse.ArgumentParser(description="Push CodeBERT model to HF Hub") parser.add_argument("--model-dir", type=Path, default=DEFAULT_MODEL_DIR) parser.add_argument("--repo-id", type=str, required=True) parser.add_argument("--private", action="store_true") parser.add_argument("--token", type=str, default=None) args = parser.parse_args() push( model_dir=args.model_dir, repo_id=args.repo_id, private=args.private, token=args.token, ) if __name__ == "__main__": main()