kasanoma_ASR / api.py
Kennethdot's picture
Update api.py
0a660ef verified
Raw
History Blame Contribute Delete
9.54 kB
import asyncio
import json
import os
import shutil
import threading
import uuid
from datetime import datetime, timezone
from pathlib import Path

import librosa
import numpy as np
import torch
from dotenv import load_dotenv
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.responses import FileResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from huggingface_hub import HfApi, create_repo, upload_file
from transformers import WhisperForConditionalGeneration, WhisperProcessor
# ── Env ────────────────────────────────────────────────────────────────────
# Load a local .env file (when present) into os.environ before reading the
# HuggingFace settings below.
load_dotenv()
HF_TOKEN = os.environ.get("HF_TOKEN")  # unset -> get_hf_api() yields no client
DATASET_REPO = os.environ.get("HF_DATASET_REPO")
# ── Paths ──────────────────────────────────────────────────────────────────
# Everything lives next to this module: the frontend under static/, and the
# collected dataset (audio clips + JSONL manifest) under dataset/.
BASE_DIR = Path(__file__).parent
STATIC_DIR = BASE_DIR.joinpath("static")
DATASET_DIR = BASE_DIR.joinpath("dataset")
AUDIO_DIR = DATASET_DIR.joinpath("audio")
MANIFEST = DATASET_DIR.joinpath("transcripts.jsonl")
# mkdir() is called without parents=True, so DATASET_DIR must be created
# before AUDIO_DIR.
STATIC_DIR.mkdir(exist_ok=True)
DATASET_DIR.mkdir(exist_ok=True)
AUDIO_DIR.mkdir(exist_ok=True)
# ── HuggingFace setup ──────────────────────────────────────────────────────
# Lazily created client, reused by every background upload.
_hf_api: HfApi | None = None


def get_hf_api() -> HfApi | None:
    """Return the shared HuggingFace client, creating it on first use.

    On the first call that creates the client, this also makes sure the
    dataset repository ``DATASET_REPO`` exists (``exist_ok=True``). If that
    check fails, the error is printed and the client is still returned.

    Returns:
        The ``HfApi`` client, or ``None`` when either ``HF_TOKEN`` or
        ``DATASET_REPO`` is not configured (callers then skip uploading).
    """
    global _hf_api
    # Without a target repo every create/upload call would fail with
    # repo_id=None, so a missing HF_DATASET_REPO disables uploads exactly
    # like a missing token does.
    if not HF_TOKEN or not DATASET_REPO:
        return None
    if _hf_api is None:
        _hf_api = HfApi(token=HF_TOKEN)
        try:
            create_repo(
                repo_id=DATASET_REPO,
                repo_type="dataset",
                exist_ok=True,
                token=HF_TOKEN,
            )
        except Exception as e:
            print(f"[HF] Dataset repo check: {e}")
    return _hf_api
# ── Model (lazy-loaded on first request) ──────────────────────────────────
_MODEL_ID = "Kennethdot/kasanoma_whisper"
_MODEL: WhisperForConditionalGeneration | None = None
_PROCESSOR: WhisperProcessor | None = None
_DEVICE: torch.device | None = None
# Guards the one-time load so concurrent first calls load the model only once.
_MODEL_LOCK = threading.Lock()


def get_model():
    """Return ``(model, processor, device)``, loading them on the first call.

    The model goes to CUDA in float16 when a GPU is available, otherwise to
    CPU in float32, and is put in eval mode. The globals are only published
    once both the model and the processor loaded successfully. So if loading
    fails part-way, the next call retries instead of returning a ``None``
    processor.
    """
    global _MODEL, _PROCESSOR, _DEVICE
    if _MODEL is None:
        with _MODEL_LOCK:
            # Re-check under the lock: another thread may have finished loading.
            if _MODEL is None:
                print(f"Loading {_MODEL_ID} …")
                device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
                model = WhisperForConditionalGeneration.from_pretrained(
                    _MODEL_ID,
                    torch_dtype=torch.float16 if device.type == "cuda" else torch.float32,
                ).to(device)
                model.eval()
                processor = WhisperProcessor.from_pretrained(_MODEL_ID)
                # Publish processor/device before _MODEL: the unlocked fast path
                # above only checks _MODEL, so they must already be set.
                _PROCESSOR, _DEVICE = processor, device
                _MODEL = model
                print(f"Model ready on {device}.")
    return _MODEL, _PROCESSOR, _DEVICE
# ── App ────────────────────────────────────────────────────────────────────
app = FastAPI(version="2.1.0", title="Kasanoma ASR")
# Despite its name, this lock guards the JSON Lines MANIFEST: it is held while
# /save appends an entry and while the background upload reads the file.
_csv_lock = threading.Lock()
# ── Helpers ────────────────────────────────────────────────────────────────
def audio_duration(path: Path) -> float:
    """Return the duration of the audio file at *path* in seconds.

    The file is decoded as mono at its native sample rate, and the length is
    rounded to two decimals. This is best-effort: any failure is logged and
    ``0.0`` is returned, so an undecodable clip does not prevent saving.
    """
    try:
        samples, sample_rate = librosa.load(str(path), sr=None, mono=True)
        return round(len(samples) / sample_rate, 2)
    except Exception as e:
        # Deliberately non-fatal, but record why instead of failing silently.
        print(f"[Duration error] {path}: {e}")
        return 0.0
def save_entry(entry: dict) -> None:
    """Append *entry* to the manifest as a single UTF-8 JSON line.

    No locking happens here; the /save route wraps the call in ``_csv_lock``.
    """
    serialized = json.dumps(entry, ensure_ascii=False)
    with open(MANIFEST, mode="a", encoding="utf-8") as handle:
        handle.write(f"{serialized}\n")
def load_manifest() -> list[dict]:
if not MANIFEST.exists():
return []
entries = []
with MANIFEST.open(encoding="utf-8") as f:
for line in f:
line = line.strip()
if line:
entries.append(json.loads(line))
return entries
# Serialises manifest uploads so an older snapshot from a slower thread can
# never overwrite a newer one on the Hub.
_manifest_upload_lock = threading.Lock()


def _upload_in_background(audio_path: Path, relative_audio_path: str) -> None:
    """Push one audio clip and the current manifest to the HF dataset repo.

    Meant to run on the daemon thread started by /save. It does nothing when
    ``get_hf_api()`` returns ``None``. Errors are printed and never raised.

    Args:
        audio_path: Local path of the clip to upload.
        relative_audio_path: Destination path of the clip inside the repo.
    """
    api = get_hf_api()
    if api is None:
        return
    try:
        upload_file(
            path_or_fileobj=str(audio_path),
            path_in_repo=relative_audio_path,
            repo_id=DATASET_REPO,
            repo_type="dataset",
            token=HF_TOKEN,
        )
        with _manifest_upload_lock:
            # Hold _csv_lock only long enough to snapshot the manifest, not for
            # the network round-trip: /save takes that lock on the event-loop
            # thread, so holding it during an upload stalls the whole server.
            with _csv_lock:
                manifest_bytes = MANIFEST.read_bytes()
            upload_file(
                path_or_fileobj=manifest_bytes,
                path_in_repo="transcripts.jsonl",
                repo_id=DATASET_REPO,
                repo_type="dataset",
                token=HF_TOKEN,
            )
    except Exception as e:
        print(f"[Background upload error] {e}")
def transcribe_path(audio_path: Path) -> tuple[str, str]:
    """Transcribe the audio file at *audio_path* with the fine-tuned Whisper model.

    Returns:
        ``(transcription, language)``. The language is always the fixed label
        ``"tw"``; it is not detected from the audio.

    Raises:
        ValueError: if the decoded audio contains no samples.
    """
    model, processor, device = get_model()
    # librosa decodes the file (supported containers depend on the installed
    # audio backends, e.g. ffmpeg) and returns mono float32 resampled to 16 kHz.
    audio_data, _ = librosa.load(str(audio_path), sr=16000, mono=True)
    if audio_data.size == 0:
        # np.max() below would fail with an opaque "zero-size array" error.
        raise ValueError("Audio contains no samples.")
    # Peak-normalise to [-1, 1]; silent clips (peak 0) are left unchanged.
    peak = np.max(np.abs(audio_data))
    if peak > 0:
        audio_data = audio_data / peak
    # Feature extraction — use the full processor so we also get attention_mask.
    inputs = processor(
        audio_data,
        sampling_rate=16000,
        return_tensors="pt",
        return_attention_mask=True,
    )
    input_features = inputs.input_features.to(device)
    attention_mask = inputs.attention_mask.to(device)
    # get_model() loads the weights in fp16 on CUDA, so the inputs must match.
    if device.type == "cuda":
        input_features = input_features.half()
    with torch.no_grad():
        generated_ids = model.generate(
            input_features,
            attention_mask=attention_mask,
            task="transcribe",
            # Whisper has no Twi language token; this passes the Yoruba ("yo")
            # token. NOTE(review): presumably the token this checkpoint was
            # fine-tuned with — confirm before changing it.
            language="yo",
            temperature=0.0,
            forced_decoder_ids=None,  # avoids duplicate logits processor warnings
        )
    transcription = processor.batch_decode(
        generated_ids, skip_special_tokens=True
    )[0].strip()
    return transcription, "tw"
# ── Routes ─────────────────────────────────────────────────────────────────
@app.get("/", include_in_schema=False)
async def root():
    """Serve the single-page frontend from ``static/index.html``.

    Responds with 404 when the file has not been deployed.
    """
    index_html = STATIC_DIR.joinpath("index.html")
    if index_html.exists():
        return FileResponse(index_html)
    raise HTTPException(404, "Frontend not found. Place index.html in static/")
# Inference runs on a worker thread (see transcribe). This lock keeps it to one
# request at a time, as when it ran on the event loop, which also keeps the
# lazy first-call model load in get_model() single-threaded.
_inference_lock = threading.Lock()


def _transcribe_upload(audio: UploadFile) -> tuple[str, str]:
    """Spool *audio* to a temporary file, transcribe it, then delete the file.

    Blocking (disk I/O plus model inference): call it from a worker thread.
    """
    suffix = Path(audio.filename or "audio.webm").suffix or ".webm"
    tmp_path = BASE_DIR / f"_tmp_{uuid.uuid4().hex}{suffix}"
    try:
        with tmp_path.open("wb") as f:
            shutil.copyfileobj(audio.file, f)
        with _inference_lock:
            return transcribe_path(tmp_path)
    finally:
        tmp_path.unlink(missing_ok=True)


@app.post("/transcribe")
async def transcribe(audio: UploadFile = File(...)):
    """Transcribe an uploaded recording without storing it.

    The blocking work is off-loaded with ``asyncio.to_thread`` so that a
    multi-second inference does not freeze the event loop and every other
    request with it.

    Returns:
        JSON ``{"transcription": str, "detected_language": str}``.

    Raises:
        HTTPException: 500 with the error text if decoding or inference fails.
    """
    try:
        transcription, detected_lang = await asyncio.to_thread(
            _transcribe_upload, audio
        )
    except Exception as e:
        print(f"[Transcribe error] {e}")
        raise HTTPException(500, detail=str(e)) from e
    return JSONResponse({
        "transcription": transcription,
        "detected_language": detected_lang,
    })
def _store_upload(audio: UploadFile, audio_path: Path) -> float:
    """Write *audio* to *audio_path* and return its duration in seconds.

    Blocking (disk I/O plus a full audio decode): call it from a worker thread.
    """
    with audio_path.open("wb") as f:
        shutil.copyfileobj(audio.file, f)
    return audio_duration(audio_path)


@app.post("/save")
async def save(
    audio: UploadFile = File(...),
    transcription: str = Form(...),
):
    """Store a recording and its transcription in the local dataset.

    The clip is written to ``AUDIO_DIR`` and an entry is appended to the JSONL
    manifest. Both are then pushed to the HF dataset repo on a daemon thread.

    Returns:
        JSON ``{"id", "total_saved", "duration_s"}``.

    Raises:
        HTTPException: 422 if the transcription is empty or whitespace only.
    """
    text = transcription.strip()
    if not text:
        raise HTTPException(422, "Transcription must not be empty.")
    entry_id = uuid.uuid4().hex
    suffix = Path(audio.filename or "audio.webm").suffix or ".webm"
    audio_filename = f"{entry_id}{suffix}"
    audio_path = AUDIO_DIR / audio_filename
    # Writing the clip and decoding it for its duration are blocking; keep
    # them off the event loop.
    duration = await asyncio.to_thread(_store_upload, audio, audio_path)
    entry = {
        "id": entry_id,
        "audio_file": f"dataset/audio/{audio_filename}",
        "transcription": text,
        "language": "twi_en",
        "duration_s": duration,
        "created_at": datetime.now(timezone.utc).isoformat(),
    }
    with _csv_lock:
        save_entry(entry)
    # NOTE(review): the clip is uploaded to "audio/<file>" in the HF repo,
    # while the manifest records the local path "dataset/audio/<file>".
    # Confirm that consumers of the Hub dataset account for this.
    relative_audio_path = f"audio/{audio_filename}"
    threading.Thread(
        target=_upload_in_background,
        args=(audio_path, relative_audio_path),
        daemon=True,
    ).start()
    return JSONResponse({
        "id": entry_id,
        "total_saved": len(load_manifest()),
        "duration_s": duration,
    })
@app.get("/dataset/stats")
async def dataset_stats():
    """Summarise the local dataset: entry count, total audio seconds, word count.

    With an empty manifest every figure is 0.
    """
    manifest_entries = load_manifest()
    seconds = 0
    words = 0
    for item in manifest_entries:
        seconds += item.get("duration_s", 0)
        words += len(item["transcription"].split())
    return JSONResponse({
        "total": len(manifest_entries),
        "total_duration_s": round(seconds, 1),
        "total_words": words,
    })
@app.get("/dataset/entries")
async def dataset_entries(limit: int = 50, offset: int = 0):
    """Return one page of manifest entries, newest first.

    Args:
        limit: Maximum number of entries to return. Negative values are
            treated as 0.
        offset: Number of newest entries to skip. Negative values are
            treated as 0.

    Returns:
        JSON ``{"entries": [...], "total": <number of entries in manifest>}``.
    """
    # Python slicing would treat negative values as counting from the end,
    # returning an arbitrary slice instead of a page.
    limit = max(limit, 0)
    offset = max(offset, 0)
    # Entries are appended chronologically, so reversing puts the newest first.
    newest_first = load_manifest()[::-1]
    return JSONResponse({
        "entries": newest_first[offset : offset + limit],
        "total": len(newest_first),
    })
# ── Static files ───────────────────────────────────────────────────────────
# Serve everything in STATIC_DIR under the /static URL prefix.
app.mount(path="/static", app=StaticFiles(directory=str(STATIC_DIR)), name="static")