sql-error-classifier / scripts /push_codebert_to_hub.py
nishu08's picture
Deploy CodeBERT inference Space
8a3099e verified
Raw
History Blame Contribute Delete
2.86 kB
#!/usr/bin/env python3
"""Push a locally trained CodeBERT model to Hugging Face Hub."""
from __future__ import annotations
import argparse
import json
import os
import sys
from pathlib import Path
# Repository root: two levels above this (symlink-resolved) script file.
# It goes at the front of sys.path *before* the project-local ``src`` import
# that follows, so that import resolves when the script is run directly.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, os.fspath(PROJECT_ROOT))
from huggingface_hub import HfApi, create_repo
from src.codebert_labels import load_codebert_labels
# Model directory that push() uploads when no --model-dir is given.
DEFAULT_MODEL_DIR = PROJECT_ROOT.joinpath("models", "codebert-cross-encoder")
# Markdown model card; push() uploads it as the repo's README.md if it exists.
MODEL_CARD = PROJECT_ROOT.joinpath("hub", "CODEBERT_MODEL_CARD.md")
def push(
    model_dir: Path = DEFAULT_MODEL_DIR,
    repo_id: str = "",
    private: bool = False,
    token: str | None = None,
) -> str:
    """Upload a locally trained CodeBERT model directory to the Hugging Face Hub.

    If ``model_dir`` has no ``label_config.json``, one is generated there first
    (labels from ``load_codebert_labels()`` plus fixed model metadata), so the
    uploaded folder includes it. If ``MODEL_CARD`` exists, it is then uploaded
    as the repo's ``README.md``.

    Args:
        model_dir: Directory containing the trained model files. A ``str`` path
            is also accepted and converted to ``Path``.
        repo_id: Target Hub repo id, e.g. ``"nishu08/sql-codebert-classifier"``.
        private: Passed to ``create_repo``. Because ``exist_ok=True`` is used,
            NOTE(review): this presumably does not change the visibility of an
            already existing repo -- confirm against huggingface_hub docs.
        token: Hub access token. Falls back to the ``HF_TOKEN`` and then the
            ``HUGGING_FACE_HUB_TOKEN`` environment variables.

    Returns:
        The ``https://huggingface.co/<repo_id>`` URL of the repo.

    Raises:
        FileNotFoundError: ``model_dir`` is not an existing directory.
        ValueError: ``repo_id`` is empty, or no token was given or found in
            the environment.
    """
    # Accept plain string paths from programmatic callers, not only Path objects.
    model_dir = Path(model_dir)
    if not model_dir.is_dir():
        raise FileNotFoundError(f"Model not found at {model_dir}. Train first.")
    if not repo_id:
        raise ValueError("--repo-id required, e.g. nishu08/sql-codebert-classifier")

    # Explicit argument wins; otherwise try the two supported env variables.
    token = token or os.getenv("HF_TOKEN") or os.getenv("HUGGING_FACE_HUB_TOKEN")
    if not token:
        raise ValueError("Set HF_TOKEN environment variable")

    # Ensure label config exists for inference. An existing file is left as-is.
    label_config = model_dir / "label_config.json"
    if not label_config.exists():
        # Build and serialize everything *before* touching the file. If
        # load_codebert_labels() or JSON encoding fails, no empty or truncated
        # label_config.json is left behind. Such a file would pass the
        # exists() check on every later run and be uploaded as-is.
        payload = json.dumps(
            {
                "labels": load_codebert_labels(),
                "model_name": "microsoft/codebert-base",
                "architecture": "codebert-cross-encoder",
                "threshold": 0.5,
                "max_length": 512,
            },
            indent=2,
        )
        label_config.write_text(payload, encoding="utf-8")

    api = HfApi(token=token)
    # exist_ok=True lets repeated pushes reuse an existing repo instead of failing.
    create_repo(repo_id, repo_type="model", private=private, exist_ok=True, token=token)

    print(f"Uploading {model_dir} -> {repo_id} ...")
    api.upload_folder(
        folder_path=str(model_dir),
        repo_id=repo_id,
        repo_type="model",
        token=token,
        commit_message="Upload CodeBERT SQL error classifier",
    )

    if MODEL_CARD.exists():
        # Uploaded after the folder, so the curated model card becomes the
        # repo's README.md even if model_dir contained one.
        api.upload_file(
            path_or_fileobj=str(MODEL_CARD),
            path_in_repo="README.md",
            repo_id=repo_id,
            repo_type="model",
            token=token,
        )

    url = f"https://huggingface.co/{repo_id}"
    print(f"Done: {url}")
    return url
def main() -> None:
    """Parse command-line options and hand them to ``push``.

    NOTE: a token given via ``--token`` is visible in shell history and process
    listings; setting the ``HF_TOKEN`` environment variable avoids that.
    """
    cli = argparse.ArgumentParser(description="Push CodeBERT model to HF Hub")
    cli.add_argument("--model-dir", type=Path, default=DEFAULT_MODEL_DIR)
    cli.add_argument("--repo-id", type=str, required=True)
    cli.add_argument("--private", action="store_true")
    cli.add_argument("--token", type=str, default=None)

    # argparse derives the dests model_dir / repo_id / private / token from the
    # flags above; they match push()'s keyword parameters one-to-one.
    options = vars(cli.parse_args())
    push(**options)
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()