from __future__ import annotations

import argparse
import json
import sys
from collections.abc import Iterable
from pathlib import Path

import numpy as np
import onnxruntime as ort
import soundfile as sf
from scipy.signal import resample_poly
|
|
|
|
# Default model location: a "model.onnx" file next to this script.
MODEL_PATH = Path(__file__).with_name("model.onnx")
# Sample rate (Hz) that every input is resampled to before inference.
SAMPLE_RATE = 16_000
# Samples per inference chunk; 160_000 samples is 10 s at SAMPLE_RATE.
CHUNK_FRAMES = 160_000
# Lower-case file suffixes picked up when scanning a directory. Audio is
# decoded with soundfile (libsndfile), which has no AAC/M4A decoder, so
# ".m4a" is intentionally not listed: such files would always fail to load.
# ".mp3" decoding requires libsndfile >= 1.1.0.
SUPPORTED_EXTENSIONS = {".wav", ".flac", ".mp3", ".ogg"}
|
|
|
|
def load_audio(path: Path) -> tuple[np.ndarray, int]:
    """Decode an audio file and downmix it to a single channel.

    Args:
        path: File to decode with soundfile.

    Returns:
        A ``(waveform, sample_rate)`` pair: a C-contiguous 1-D float32 array
        holding the per-frame average across all channels, and the file's
        native sample rate as a Python ``int``.
    """
    # always_2d=True yields a (frames, channels) array even for mono input,
    # so averaging over axis 1 works for any channel count.
    frames, native_rate = sf.read(path, dtype="float32", always_2d=True)
    mono = np.mean(frames, axis=1)
    return np.ascontiguousarray(mono, dtype=np.float32), int(native_rate)
|
|
|
|
def resample_audio(
    waveform: np.ndarray, source_rate: int, target_rate: int
) -> np.ndarray:
    """Resample ``waveform`` from ``source_rate`` Hz to ``target_rate`` Hz.

    Polyphase filtering is used, with the up/down factors reduced by the
    greatest common divisor of the two rates.

    Returns:
        A C-contiguous float32 array.
    """
    divisor = np.gcd(source_rate, target_rate)
    up_factor = target_rate // divisor
    down_factor = source_rate // divisor
    resampled = resample_poly(waveform, up_factor, down_factor)
    return np.ascontiguousarray(resampled.astype(np.float32))
|
|
|
|
def layer_norm(waveform: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    """Standardize a waveform to zero mean and unit variance.

    Mean and variance are accumulated in float64; ``eps`` is added to the
    variance so constant (e.g. silent) input does not divide by zero.

    Returns:
        The normalized waveform as float32.
    """
    centre = np.mean(waveform, dtype=np.float64)
    spread = np.sqrt(np.var(waveform, dtype=np.float64) + eps)
    normalized = (waveform - centre) / spread
    return normalized.astype(np.float32)
|
|
|
|
def chunk_waveform(waveform: np.ndarray, chunk_frames: int) -> np.ndarray:
    """Split a 1-D waveform into a batch of equal-length chunks.

    If ``chunk_frames`` is not positive, or the waveform already fits in one
    chunk, the waveform is returned as a one-row view of shape
    ``(1, len(waveform))`` with its original dtype and no padding.

    Otherwise it is cut into consecutive ``chunk_frames``-sample pieces. The
    last piece is right-padded with zeros so every row has the same length,
    and the result is a float32 array of shape ``(num_chunks, chunk_frames)``.
    """
    total = waveform.size
    if chunk_frames <= 0 or total <= chunk_frames:
        return waveform[None, :]

    # Ceiling division: enough rows to hold every sample.
    num_chunks = -(-total // chunk_frames)
    flat = np.zeros(num_chunks * chunk_frames, dtype=np.float32)
    flat[:total] = waveform
    return flat.reshape(num_chunks, chunk_frames)
|
|
|
|
def softmax(logits: np.ndarray) -> np.ndarray:
    """Convert logits to probabilities that sum to one over all elements.

    Computation is done in float64, and the maximum is subtracted before
    exponentiating to avoid overflow.
    """
    shifted = logits.astype(np.float64)  # astype copies, so the input is untouched
    shifted -= shifted.max()
    exponentials = np.exp(shifted)
    return exponentials / np.sum(exponentials)
|
|
|
|
class TTSSuitabilityClassifier:
    """Two-class TTS-suitability classifier backed by an ONNX Runtime session.

    ``predict`` loads an audio file as mono, resamples it to ``SAMPLE_RATE``,
    normalizes it, runs the model on ``CHUNK_FRAMES``-sample chunks and
    averages the per-chunk logits into a single prediction.
    """

    def __init__(
        self,
        model_path: str | Path = MODEL_PATH,
        provider: str = "auto",
        cuda_device_id: int = 0,
    ) -> None:
        """Create the inference session.

        Args:
            model_path: Path to the ONNX model file.
            provider: ``"auto"`` (CUDA when ONNX Runtime reports it available,
                otherwise CPU), ``"cpu"``, or ``"cuda"``.
            cuda_device_id: GPU ordinal handed to the CUDA execution provider.

        Raises:
            ValueError: ``provider`` is not one of the accepted values.
            RuntimeError: ``provider="cuda"`` was requested but CUDA is not
                available, or the CUDA execution provider failed to initialize.
        """
        available = set(ort.get_available_providers())
        requested = provider

        if provider == "auto":
            provider = "cuda" if "CUDAExecutionProvider" in available else "cpu"

        if provider == "cuda":
            if "CUDAExecutionProvider" not in available:
                raise RuntimeError(
                    "CUDAExecutionProvider is unavailable. Install onnxruntime-gpu "
                    "or use provider='cpu'."
                )
            providers = [
                ("CUDAExecutionProvider", {"device_id": cuda_device_id}),
                "CPUExecutionProvider",
            ]
        elif provider == "cpu":
            providers = ["CPUExecutionProvider"]
        else:
            raise ValueError("provider must be one of: auto, cpu, cuda")

        self.session = ort.InferenceSession(str(model_path), providers=providers)

        # ONNX Runtime may list CUDA as available yet fail to create it (for
        # example, missing CUDA/cuDNN libraries). In that case it only warns and
        # runs on CPU. An explicit provider="cuda" should fail loudly instead;
        # "auto" keeps the CPU fallback.
        if (
            requested == "cuda"
            and "CUDAExecutionProvider" not in self.session.get_providers()
        ):
            raise RuntimeError(
                "CUDAExecutionProvider failed to initialize and ONNX Runtime fell "
                "back to CPU. Check the CUDA/cuDNN installation or use "
                "provider='cpu'."
            )

        self.input_name = self.session.get_inputs()[0].name
        self.output_names = [output.name for output in self.session.get_outputs()]

    def predict(self, audio_path: str | Path) -> dict[str, object]:
        """Classify a single audio file.

        Args:
            audio_path: Path to an audio file that ``load_audio`` can decode.

        Returns:
            A JSON-serializable dict containing:

            - ``path``: the resolved path.
            - ``label``: ``"tts"`` for class 1, otherwise ``"not_tts"``.
            - ``predicted_class``: the predicted class index.
            - ``p_not_tts`` / ``p_tts``: the class probabilities.
            - ``logits``: the chunk-averaged logits.

        Raises:
            ValueError: The file decodes to zero samples.
        """
        path = Path(audio_path).expanduser().resolve()
        waveform, sample_rate = load_audio(path)

        # Without this check, an empty file gives NaN mean/variance in
        # layer_norm and a (1, 0) model input, instead of a clear error.
        if waveform.size == 0:
            raise ValueError(f"Audio file contains no samples: '{path}'.")

        if sample_rate != SAMPLE_RATE:
            waveform = resample_audio(waveform, sample_rate, SAMPLE_RATE)

        waveform = layer_norm(waveform)
        batch = chunk_waveform(waveform, CHUNK_FRAMES)
        # Only the first model output is used, so request only that one.
        (chunk_logits,) = self.session.run(
            self.output_names[:1], {self.input_name: batch}
        )
        # One row of logits per chunk (presumably (num_chunks, 2) — the
        # indexing below assumes two classes); average into one prediction.
        logits = chunk_logits.mean(axis=0)
        probabilities = softmax(logits)
        predicted_class = int(probabilities.argmax())

        return {
            "path": str(path),
            "label": "tts" if predicted_class == 1 else "not_tts",
            "predicted_class": predicted_class,
            "p_not_tts": float(probabilities[0]),
            "p_tts": float(probabilities[1]),
            "logits": [float(value) for value in logits],
        }
|
|
|
|
def collect_audio_paths(
    path: str | Path, *, extensions: Iterable[str] | None = None
) -> list[Path]:
    """Expand the CLI ``audio`` argument into the list of files to classify.

    A file is returned as a one-element list (resolved to an absolute path)
    whatever its suffix. A directory is searched recursively. Files whose
    lower-cased suffix appears in ``extensions`` are returned as sorted,
    absolute paths.

    Args:
        path: A file or directory; ``~`` is expanded.
        extensions: Suffixes to accept when scanning a directory, compared
            case-insensitively. Defaults to ``SUPPORTED_EXTENSIONS``.

    Raises:
        FileNotFoundError: ``path`` does not exist. Without this check, a
            mistyped directory would silently look like a directory with no
            audio in it.
    """
    path = Path(path).expanduser().resolve()
    if path.is_file():
        return [path]
    if not path.exists():
        raise FileNotFoundError(f"No such file or directory: '{path}'")

    accepted = SUPPORTED_EXTENSIONS if extensions is None else extensions
    allowed = {suffix.lower() for suffix in accepted}
    return sorted(
        child
        for child in path.rglob("*")
        if child.is_file() and child.suffix.lower() in allowed
    )
|
|
|
|
def main() -> None:
    """Command-line entry point.

    Classifies the file given as ``audio``, or every supported audio file under
    it when it is a directory, and prints one JSON object per classified file to
    stdout.

    A file that fails to load or classify is reported on stderr and skipped, so
    one unreadable file does not abort a whole directory run. The process exits
    with status 1 if any file failed. Path problems are reported through argparse
    (exit status 2) before the model is loaded.
    """
    parser = argparse.ArgumentParser(
        description="ONNX inference for the TTS suitability classifier."
    )
    parser.add_argument("audio", type=Path, help="Audio file or directory.")
    parser.add_argument(
        "--model", type=Path, default=MODEL_PATH, help="Path to model.onnx."
    )
    parser.add_argument(
        "--provider", choices=("auto", "cpu", "cuda"), default="auto"
    )
    parser.add_argument("--cuda-device-id", type=int, default=0)
    args = parser.parse_args()

    # Resolve inputs first so a bad path is reported without loading the model.
    try:
        paths = collect_audio_paths(args.audio)
    except FileNotFoundError as err:
        parser.error(str(err))

    if not paths:
        parser.error(f"No supported audio files found at '{args.audio}'.")

    classifier = TTSSuitabilityClassifier(
        args.model, args.provider, args.cuda_device_id
    )

    failures = 0
    for path in paths:
        try:
            result = classifier.predict(path)
        except Exception as err:  # CLI boundary: report, skip, keep going.
            failures += 1
            print(f"error: {path}: {type(err).__name__}: {err}", file=sys.stderr)
            continue
        print(json.dumps(result, ensure_ascii=False))

    if failures:
        print(f"{failures} of {len(paths)} file(s) failed.", file=sys.stderr)
        raise SystemExit(1)


if __name__ == "__main__":
    main()
|
|