Spaces:
Running
Running
| import json | |
| import os | |
| import uuid | |
| import shutil | |
| import threading | |
| import numpy as np | |
| from datetime import datetime, timezone | |
| from pathlib import Path | |
| import librosa | |
| import torch | |
| from dotenv import load_dotenv | |
| from fastapi import FastAPI, File, Form, UploadFile, HTTPException | |
| from fastapi.responses import FileResponse, JSONResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from huggingface_hub import HfApi, create_repo, upload_file | |
| from transformers import WhisperForConditionalGeneration, WhisperProcessor | |
# ── Environment ─────────────────────────────────────────────────────────────
# Merge variables from a local .env file (if present) into the process
# environment before the settings below are read.
load_dotenv()

# Hub access token; when unset, get_hf_api() returns None and uploads are skipped.
HF_TOKEN = os.environ.get("HF_TOKEN")
# Id of the Hub dataset repo that receives audio + transcripts (None if unset).
DATASET_REPO = os.environ.get("HF_DATASET_REPO")
# ── Paths ───────────────────────────────────────────────────────────────────
BASE_DIR = Path(__file__).parent
STATIC_DIR = BASE_DIR / "static"          # frontend assets (index.html, ...)
DATASET_DIR = BASE_DIR / "dataset"        # local copy of the collected dataset
AUDIO_DIR = DATASET_DIR / "audio"         # one file per saved recording
MANIFEST = DATASET_DIR / "transcripts.jsonl"  # one JSON object per line

# Create the writable directories at import time. Order matters: AUDIO_DIR is
# nested inside DATASET_DIR and mkdir() is called without parents=True.
for _dir in (STATIC_DIR, DATASET_DIR, AUDIO_DIR):
    _dir.mkdir(exist_ok=True)
del _dir
# ── Hugging Face setup ──────────────────────────────────────────────────────
_hf_api: HfApi | None = None
# Guards one-time client construction: get_hf_api() is called from the
# daemon threads that run _upload_in_background.
_hf_api_lock = threading.Lock()


def get_hf_api() -> HfApi | None:
    """Return the shared Hub client, creating it (and the dataset repo) on first use.

    Returns:
        The cached ``HfApi`` instance, or ``None`` when either ``HF_TOKEN`` or
        ``HF_DATASET_REPO`` is not configured. In that case dataset syncing is
        disabled instead of failing on every upload with ``repo_id=None``.

    A failure while ensuring the dataset repo exists is printed and does not
    prevent the client from being returned (``exist_ok=True`` makes the call
    a no-op for an existing repo).
    """
    global _hf_api
    if not (HF_TOKEN and DATASET_REPO):
        return None
    with _hf_api_lock:
        if _hf_api is None:
            api = HfApi(token=HF_TOKEN)
            try:
                create_repo(
                    repo_id=DATASET_REPO,
                    repo_type="dataset",
                    exist_ok=True,
                    token=HF_TOKEN,
                )
            except Exception as e:
                print(f"[HF] Dataset repo check: {e}")
            _hf_api = api
    return _hf_api
# ── Model (lazy-loaded on first request) ────────────────────────────────────
_MODEL_ID = "Kennethdot/kasanoma_whisper"
_MODEL: WhisperForConditionalGeneration | None = None
_PROCESSOR: WhisperProcessor | None = None
_DEVICE: torch.device | None = None
# Ensures the checkpoint is downloaded/loaded at most once even if several
# requests arrive concurrently before the first load finishes.
_MODEL_LOCK = threading.Lock()


def get_model():
    """Return ``(model, processor, device)``, loading them on the first call.

    The model is placed on CUDA in float16 when a GPU is available, otherwise
    on CPU in float32, and switched to eval mode.

    All three objects are loaded into locals and only published to the module
    globals once every step has succeeded. The earlier version set ``_MODEL``
    before loading the processor, so a processor failure left it permanently
    returning ``processor=None``. ``_MODEL`` is assigned last because it is
    the "already loaded" flag checked on the fast path.
    """
    global _MODEL, _PROCESSOR, _DEVICE
    if _MODEL is not None:
        return _MODEL, _PROCESSOR, _DEVICE
    with _MODEL_LOCK:
        if _MODEL is None:
            print(f"Loading {_MODEL_ID} …")
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            model = WhisperForConditionalGeneration.from_pretrained(
                _MODEL_ID,
                torch_dtype=torch.float16 if device.type == "cuda" else torch.float32,
            ).to(device)
            model.eval()
            processor = WhisperProcessor.from_pretrained(_MODEL_ID)
            _DEVICE = device
            _PROCESSOR = processor
            _MODEL = model
            print(f"Model ready on {device}.")
    return _MODEL, _PROCESSOR, _DEVICE
# ── App ─────────────────────────────────────────────────────────────────────
# Despite its name, this lock guards the JSONL manifest (MANIFEST): it is held
# around manifest appends in save() and manifest uploads in
# _upload_in_background().
_csv_lock = threading.Lock()

app = FastAPI(title="Kasanoma ASR", version="2.1.0")
# ── Helpers ─────────────────────────────────────────────────────────────────
def audio_duration(path: Path) -> float:
    """Return the length of the audio file at *path* in seconds (2 decimals).

    The file is decoded at its native sample rate and down-mixed to mono.
    Duration is best-effort metadata: any failure (unreadable file, decoder
    error, zero sample rate) yields 0.0 instead of raising.
    """
    try:
        samples, rate = librosa.load(str(path), sr=None, mono=True)
        seconds = len(samples) / rate
    except Exception:
        seconds = 0.0
    return round(seconds, 2)
def save_entry(entry: dict) -> None:
    """Append *entry* to the manifest as a single JSON line.

    Non-ASCII text (e.g. Twi characters) is written verbatim as UTF-8 rather
    than ``\\u``-escaped. This function does no locking itself; the visible
    caller (save) wraps it in ``_csv_lock``.
    """
    with open(MANIFEST, mode="a", encoding="utf-8") as sink:
        print(json.dumps(entry, ensure_ascii=False), file=sink)
| def load_manifest() -> list[dict]: | |
| if not MANIFEST.exists(): | |
| return [] | |
| entries = [] | |
| with MANIFEST.open(encoding="utf-8") as f: | |
| for line in f: | |
| line = line.strip() | |
| if line: | |
| entries.append(json.loads(line)) | |
| return entries | |
# Serialises manifest uploads so that a snapshot taken later is always uploaded
# later, and an older snapshot never overwrites a newer one on the Hub.
_manifest_upload_lock = threading.Lock()


def _upload_in_background(audio_path: Path, relative_audio_path: str) -> None:
    """Push one audio file, then the current manifest, to the HF dataset repo.

    Meant to run on a daemon thread. It is best-effort: it returns immediately
    when no Hub client is configured, and prints (rather than raises) any
    upload error.

    The manifest bytes are snapshotted while holding ``_csv_lock``, but the
    network upload runs after that lock is released. The async ``save``
    handler acquires ``_csv_lock`` synchronously, so holding it across a slow
    upload would block the whole event loop.
    """
    api = get_hf_api()
    if api is None:
        return
    try:
        upload_file(
            path_or_fileobj=str(audio_path),
            path_in_repo=relative_audio_path,
            repo_id=DATASET_REPO,
            repo_type="dataset",
            token=HF_TOKEN,
        )
        with _manifest_upload_lock:
            with _csv_lock:
                manifest_bytes = MANIFEST.read_bytes()
            upload_file(
                path_or_fileobj=manifest_bytes,
                path_in_repo="transcripts.jsonl",
                repo_id=DATASET_REPO,
                repo_type="dataset",
                token=HF_TOKEN,
            )
    except Exception as e:
        print(f"[Background upload error] {e}")
def transcribe_path(audio_path: Path) -> tuple[str, str]:
    """Transcribe the audio file at *audio_path* with the lazily loaded Whisper model.

    Returns:
        ``(transcription, language_tag)``. The tag is the constant ``"tw"``;
        it is not detected from the audio.

    Raises:
        ValueError: if the decoded audio contains no samples.
        Decoder and model errors from librosa / transformers propagate.
    """
    model, processor, device = get_model()

    # librosa decodes the file and returns mono float32 samples resampled to
    # 16 kHz, the sampling rate passed to the processor below.
    audio_data, _ = librosa.load(str(audio_path), sr=16000, mono=True)
    if audio_data.size == 0:
        # np.max() on an empty array would raise an opaque "zero-size array" error.
        raise ValueError("Audio contains no samples.")

    # Peak-normalise so the loudest sample has magnitude 1; all-silent input
    # (peak == 0) is left unchanged to avoid dividing by zero.
    peak = np.max(np.abs(audio_data))
    if peak > 0:
        audio_data = audio_data / peak

    # Call the full processor (not just the tokenizer) so an attention_mask is
    # returned alongside the log-mel input features.
    # NOTE(review): Whisper feature extraction pads/truncates to 30 s by
    # default, so longer clips are likely cut off — confirm if long audio matters.
    inputs = processor(
        audio_data,
        sampling_rate=16000,
        return_tensors="pt",
        return_attention_mask=True,
    )
    input_features = inputs.input_features.to(device)
    attention_mask = inputs.attention_mask.to(device)

    # get_model() loads the weights in float16 on CUDA; match that dtype.
    if device.type == "cuda":
        input_features = input_features.half()

    with torch.no_grad():
        generated_ids = model.generate(
            input_features,
            attention_mask=attention_mask,
            task="transcribe",
            # Whisper has no Twi language token, so "yo" (Yoruba) is passed
            # deliberately. It is presumably the token this checkpoint was
            # fine-tuned with — confirm against the training setup before
            # changing it. An unsupported code such as "tw" is likely rejected
            # by generate().
            language="yo",
            temperature=0.0,
            forced_decoder_ids=None,  # let language/task build the prompt; avoids duplicate-processor warnings
        )
    transcription = processor.batch_decode(
        generated_ids, skip_special_tokens=True
    )[0].strip()
    return transcription, "tw"
# ── Routes ──────────────────────────────────────────────────────────────────
async def root():
    """Serve the single-page frontend from ``static/index.html``.

    Raises HTTPException(404) when that file has not been deployed.

    NOTE(review): no route decorator is visible for this handler; presumably
    it is registered as ``GET /`` — confirm.
    """
    page = STATIC_DIR / "index.html"
    if page.exists():
        return FileResponse(page)
    raise HTTPException(404, "Frontend not found. Place index.html in static/")
async def transcribe(audio: UploadFile = File(...)):
    """Transcribe an uploaded clip without adding it to the dataset.

    The upload is spooled to a uniquely named scratch file in BASE_DIR. It
    keeps the client's extension (".webm" when none is given), presumably so
    the decoder can pick the right format. The file is transcribed, then
    always deleted. Any failure is printed and returned as a 500 whose detail
    is the error message.

    NOTE(review): model inference runs synchronously inside this async
    handler and will block the event loop for its duration — consider
    offloading it to a worker thread. No route decorator is visible here.
    """
    extension = Path(audio.filename or "audio.webm").suffix or ".webm"
    scratch = BASE_DIR / f"_tmp_{uuid.uuid4().hex}{extension}"
    try:
        with scratch.open("wb") as out:
            shutil.copyfileobj(audio.file, out)
        text, lang = transcribe_path(scratch)
        payload = {
            "transcription": text,
            "detected_language": lang,
        }
        return JSONResponse(payload)
    except Exception as e:
        print(f"[Transcribe error] {e}")
        raise HTTPException(500, detail=str(e))
    finally:
        scratch.unlink(missing_ok=True)
async def save(
    audio: UploadFile = File(...),
    transcription: str = Form(...),
):
    """Persist one (audio, transcription) pair to the local dataset.

    Steps:
      1. Reject a blank transcription with 422.
      2. Write the audio to AUDIO_DIR under a fresh hex id, keeping the
         uploaded extension (".webm" by default).
      3. Append a manifest entry under ``_csv_lock``.
      4. Start a daemon thread that mirrors the audio and the manifest to the
         Hub (best-effort).

    Returns JSON with the new entry id, the manifest size after the append,
    and the clip duration in seconds.

    NOTE(review): the entry stores ``dataset/audio/<file>`` (a path relative
    to BASE_DIR), while the Hub copy is uploaded to ``audio/<file>``. Confirm
    which layout downstream consumers of transcripts.jsonl expect.
    """
    text = transcription.strip()
    if not text:
        raise HTTPException(422, "Transcription must not be empty.")

    entry_id = uuid.uuid4().hex
    extension = Path(audio.filename or "audio.webm").suffix or ".webm"
    audio_filename = f"{entry_id}{extension}"
    audio_path = AUDIO_DIR / audio_filename
    with audio_path.open("wb") as sink:
        shutil.copyfileobj(audio.file, sink)

    duration = audio_duration(audio_path)
    entry = dict(
        id=entry_id,
        audio_file=f"dataset/audio/{audio_filename}",
        transcription=text,
        language="twi_en",
        duration_s=duration,
        created_at=datetime.now(timezone.utc).isoformat(),
    )
    with _csv_lock:
        save_entry(entry)

    uploader = threading.Thread(
        target=_upload_in_background,
        args=(audio_path, f"audio/{audio_filename}"),
        daemon=True,
    )
    uploader.start()

    return JSONResponse({
        "id": entry_id,
        "total_saved": len(load_manifest()),
        "duration_s": duration,
    })
async def dataset_stats():
    """Summarise the manifest: entry count, total seconds (1 decimal), total words.

    Words are counted by whitespace-splitting each transcription. Entries
    without ``duration_s`` contribute 0 seconds.
    """
    entries = load_manifest()
    # With no entries both sums are the int 0, and round(0, 1) stays 0, so the
    # empty manifest needs no special case.
    seconds = sum(e.get("duration_s", 0) for e in entries)
    words = sum(len(e["transcription"].split()) for e in entries)
    return JSONResponse({
        "total": len(entries),
        "total_duration_s": round(seconds, 1),
        "total_words": words,
    })
async def dataset_entries(limit: int = 50, offset: int = 0):
    """Return one page of manifest entries, newest first.

    Args:
        limit: Maximum number of entries to return; must be >= 0.
        offset: Number of newest entries to skip; must be >= 0.

    Returns:
        JSON with ``entries`` (the requested page) and ``total`` (the size
        of the whole manifest).

    Raises:
        HTTPException(422): if ``limit`` or ``offset`` is negative. Python
        slicing would otherwise silently index from the end of the list and
        return a confusing page.
    """
    if limit < 0 or offset < 0:
        raise HTTPException(422, "limit and offset must be non-negative.")
    entries = load_manifest()
    # save() appends in arrival order, so reversing puts the newest first.
    entries.reverse()
    return JSONResponse({
        "entries": entries[offset : offset + limit],
        "total": len(entries),
    })
# ── Static files ────────────────────────────────────────────────────────────
# Expose everything under STATIC_DIR at the /static URL prefix.
_static_files = StaticFiles(directory=str(STATIC_DIR))
app.mount("/static", _static_files, name="static")