# Hugging Face Hub file-page header: uploaded by NikiPshg, commit 6b0fab3 (verified), "Upload folder using huggingface_hub", file size 5.14 kB.
from __future__ import annotations
import argparse
import json
from pathlib import Path
import numpy as np
import onnxruntime as ort
import soundfile as sf
from scipy.signal import resample_poly
# Default model location: a file named "model.onnx" in the same directory as
# this script.
MODEL_PATH = Path(__file__).with_name("model.onnx")
# Sample rate (Hz) that audio is resampled to before inference.
SAMPLE_RATE = 16_000
# Maximum number of samples per chunk fed to the model; at SAMPLE_RATE this
# corresponds to 10 seconds of audio.
CHUNK_FRAMES = 160_000
# Lower-cased file suffixes picked up when a directory is scanned for audio.
SUPPORTED_EXTENSIONS = {".wav", ".flac", ".mp3", ".ogg", ".m4a"}
def load_audio(path: Path) -> tuple[np.ndarray, int]:
    """Read an audio file and downmix it to a mono float32 waveform.

    Args:
        path: Location of the audio file, passed straight to ``sf.read``.

    Returns:
        A ``(waveform, sample_rate)`` pair: the per-frame average over all
        channels as a C-contiguous float32 array, and the sample rate
        reported by the reader converted to ``int``.
    """
    # always_2d=True requests a 2-D array even for single-channel files, so
    # the channel average over axis 1 below applies uniformly.
    samples, rate = sf.read(path, dtype="float32", always_2d=True)
    mono = samples.mean(axis=1)
    return np.ascontiguousarray(mono, dtype=np.float32), int(rate)
def resample_audio(
    waveform: np.ndarray, source_rate: int, target_rate: int
) -> np.ndarray:
    """Resample a waveform from ``source_rate`` to ``target_rate``.

    Args:
        waveform: Samples to resample.
        source_rate: Sample rate of ``waveform``.
        target_rate: Desired sample rate of the result.

    Returns:
        The polyphase-resampled signal as a C-contiguous float32 array.
    """
    # Reduce the rate ratio to lowest terms so the up/down factors handed to
    # resample_poly are as small as possible.
    divisor = np.gcd(source_rate, target_rate)
    up_factor = target_rate // divisor
    down_factor = source_rate // divisor
    resampled = resample_poly(waveform, up_factor, down_factor)
    return np.ascontiguousarray(resampled.astype(np.float32))
def layer_norm(waveform: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    """Normalise a waveform to zero mean and (approximately) unit variance.

    The statistics are taken over every element of ``waveform`` and are
    accumulated in float64.

    Args:
        waveform: Samples to normalise.
        eps: Added to the variance before the square root so a constant
            signal does not divide by zero.

    Returns:
        The normalised samples as float32.
    """
    center = waveform.mean(dtype=np.float64)
    spread = np.sqrt(waveform.var(dtype=np.float64) + eps)
    normalised = (waveform - center) / spread
    return normalised.astype(np.float32)
def chunk_waveform(waveform: np.ndarray, chunk_frames: int) -> np.ndarray:
    """Split a 1-D waveform into a batch of fixed-length, zero-padded rows.

    Args:
        waveform: Samples to split.
        chunk_frames: Row length in samples. A non-positive value disables
            chunking.

    Returns:
        ``waveform[None, :]`` when chunking is disabled or the waveform fits
        in a single chunk; otherwise a float32 array with one row per chunk,
        where the final row is right-padded with zeros if it is short.
    """
    total = waveform.size
    if chunk_frames <= 0 or total <= chunk_frames:
        return waveform[None, :]

    # Reaching this point means total > chunk_frames, so the first chunk is
    # always full and the row width is exactly chunk_frames.
    starts = range(0, total, chunk_frames)
    batch = np.zeros((len(starts), chunk_frames), dtype=np.float32)
    for row, start in enumerate(starts):
        piece = waveform[start : start + chunk_frames]
        batch[row, : piece.size] = piece
    return batch
def softmax(logits: np.ndarray) -> np.ndarray:
    """Convert logits into probabilities that sum to one.

    The computation runs in float64, and both the maximum and the
    normalising sum are taken over the whole array.

    Args:
        logits: Raw scores.

    Returns:
        A float64 array with the same shape as ``logits``.
    """
    scores = logits.astype(np.float64)
    # Subtracting the maximum keeps exp() from overflowing without changing
    # the resulting ratios.
    weights = np.exp(scores - scores.max())
    total = weights.sum()
    return weights / total
class TTSSuitabilityClassifier:
    """ONNX Runtime wrapper that labels an audio file as ``tts`` or ``not_tts``.

    Attributes are set once in ``__init__``:
        session: The ``ort.InferenceSession`` built from the model file.
        input_name: Name of the model's first input.
        output_names: Names of all model outputs, in model order.
    """

    def __init__(
        self,
        model_path: str | Path = MODEL_PATH,
        provider: str = "auto",
        cuda_device_id: int = 0,
    ) -> None:
        """Create the inference session on the requested execution provider.

        Args:
            model_path: Path to the ONNX model file.
            provider: ``"auto"``, ``"cpu"`` or ``"cuda"``. ``"auto"`` picks
                CUDA when ONNX Runtime reports it as available, else CPU.
            cuda_device_id: Device index passed to the CUDA provider.

        Raises:
            RuntimeError: If ``"cuda"`` is requested but
                ``CUDAExecutionProvider`` is not available.
            ValueError: If ``provider`` is not one of the accepted names.
        """
        available = set(ort.get_available_providers())
        if provider == "auto":
            provider = "cuda" if "CUDAExecutionProvider" in available else "cpu"

        if provider == "cuda":
            if "CUDAExecutionProvider" not in available:
                raise RuntimeError(
                    "CUDAExecutionProvider is unavailable. Install onnxruntime-gpu "
                    "or use provider='cpu'."
                )
            # The CPU provider is listed after CUDA in the provider list.
            providers = [
                ("CUDAExecutionProvider", {"device_id": cuda_device_id}),
                "CPUExecutionProvider",
            ]
        elif provider == "cpu":
            providers = ["CPUExecutionProvider"]
        else:
            raise ValueError("provider must be one of: auto, cpu, cuda")

        self.session = ort.InferenceSession(str(model_path), providers=providers)
        self.input_name = self.session.get_inputs()[0].name
        self.output_names = [output.name for output in self.session.get_outputs()]

    def predict(self, audio_path: str | Path) -> dict[str, object]:
        """Classify a single audio file.

        The file is loaded as mono, resampled to ``SAMPLE_RATE`` if needed,
        normalised, split into chunks of ``CHUNK_FRAMES`` samples and run
        through the model as one batch. The first model output is averaged
        over the batch axis before the softmax.

        Args:
            audio_path: Path to the audio file; ``~`` is expanded and the
                path is resolved.

        Returns:
            A dict with the resolved ``path``, the ``label`` (``"tts"`` when
            the predicted class is 1, otherwise ``"not_tts"``), the
            ``predicted_class`` index, the probabilities ``p_not_tts`` and
            ``p_tts`` (classes 0 and 1), and the averaged ``logits`` as a
            list of floats.

        Raises:
            ValueError: If the audio file contains no samples.
        """
        path = Path(audio_path).expanduser().resolve()
        waveform, sample_rate = load_audio(path)
        # A zero-length waveform would produce NaN statistics in layer_norm
        # and a (1, 0) input tensor; fail early with a clear message instead.
        if waveform.size == 0:
            raise ValueError(f"Audio file '{path}' contains no samples.")
        if sample_rate != SAMPLE_RATE:
            waveform = resample_audio(waveform, sample_rate, SAMPLE_RATE)
        waveform = layer_norm(waveform)
        batch = chunk_waveform(waveform, CHUNK_FRAMES)

        outputs = self.session.run(self.output_names, {self.input_name: batch})
        # Average the per-chunk logits so each file gets a single prediction.
        logits = outputs[0].mean(axis=0)
        probabilities = softmax(logits)
        predicted_class = int(probabilities.argmax())
        return {
            "path": str(path),
            "label": "tts" if predicted_class == 1 else "not_tts",
            "predicted_class": predicted_class,
            "p_not_tts": float(probabilities[0]),
            "p_tts": float(probabilities[1]),
            "logits": [float(value) for value in logits],
        }
def collect_audio_paths(path: Path) -> list[Path]:
    """Resolve a file or directory argument into a list of audio files.

    Args:
        path: A file or directory; ``~`` is expanded and the path resolved.

    Returns:
        ``[path]`` when it points at a file (whatever its extension);
        otherwise the sorted list of files found recursively beneath it
        whose lower-cased suffix is in ``SUPPORTED_EXTENSIONS``.
    """
    root = path.expanduser().resolve()
    if root.is_file():
        return [root]

    matches = [
        candidate
        for candidate in root.rglob("*")
        if candidate.is_file()
        and candidate.suffix.lower() in SUPPORTED_EXTENSIONS
    ]
    matches.sort()
    return matches
def main() -> None:
    """Command-line entry point: classify one file or a directory of files.

    Parses the arguments, builds the classifier, gathers the audio paths and
    prints one JSON object per file to standard output.

    Raises:
        RuntimeError: If no audio files are found at the given location.
    """
    cli = argparse.ArgumentParser(
        description="ONNX inference for the TTS suitability classifier."
    )
    cli.add_argument("audio", type=Path, help="Audio file or directory.")
    cli.add_argument(
        "--model", type=Path, default=MODEL_PATH, help="Path to model.onnx."
    )
    cli.add_argument(
        "--provider", choices=("auto", "cpu", "cuda"), default="auto"
    )
    cli.add_argument("--cuda-device-id", type=int, default=0)
    options = cli.parse_args()

    # The session is created before the path scan, so model or provider
    # errors surface first.
    classifier = TTSSuitabilityClassifier(
        options.model, options.provider, options.cuda_device_id
    )

    audio_files = collect_audio_paths(options.audio)
    if not audio_files:
        raise RuntimeError(f"No supported audio files found at '{options.audio}'.")

    for audio_file in audio_files:
        result = classifier.predict(audio_file)
        print(json.dumps(result, ensure_ascii=False))
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()