tts-batch-builder / tts_script.py
3outeille's picture
3outeille HF Staff
Sync tts_script.py
e965984 verified
Raw
History Blame Contribute Delete
4.03 kB
import argparse
import time
import os
import json
import torch
from datasets import load_dataset
from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
from soundfile import write
# Demo sentence used as the default for the positional `text` CLI argument.
DEFAULT_PROMPT = (
    "Hugging Face Jobs make cloud compute "
    "incredibly straightforward."
)
def run_batch_inference(prompts, run_id=None):
    """Synthesize each prompt with SpeechT5 and save WAV files plus a metrics JSON.

    Outputs are written to ``<OUTPUT_DIR>/<run_id>/``: one
    ``audio_sample_<idx>.wav`` per prompt and a single
    ``inference_metrics.json`` summarizing the run.

    Args:
        prompts: Iterable of sentences to synthesize. A single ``str`` is
            treated as one prompt rather than being iterated per character.
            An empty iterable is allowed and produces a metrics file with
            zero samples.
        run_id: Name of the per-run subdirectory. When falsy, falls back to
            the ``JOB_ID`` environment variable, then to ``"local"``.

    Environment:
        OUTPUT_DIR: Base output directory (default ``/data/output``).
        JOB_ID: Default run identifier when ``run_id`` is not given.
    """
    # Normalize the input: a bare string must not be split into characters,
    # and generators must be materialized because the summary reports a count.
    if isinstance(prompts, str):
        prompts = [prompts]
    else:
        prompts = list(prompts)

    # 1. Set up input/output paths (we map this directory via HF Jobs)
    base_output_dir = os.environ.get("OUTPUT_DIR", "/data/output")
    # Each run gets its own subdirectory so concurrent jobs don't collide in the bucket.
    run_id = run_id or os.environ.get("JOB_ID", "local")
    output_dir = os.path.join(base_output_dir, run_id)
    os.makedirs(output_dir, exist_ok=True)

    print("🚀 Initializing TTS Model...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
    model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts").to(device)
    vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan").to(device)

    # Load speaker embeddings for voice styling (via parquet to avoid deprecated dataset script)
    parquet_url = "https://huggingface.co/datasets/Matthijs/cmu-arctic-xvectors/resolve/refs%2Fconvert%2Fparquet/default/validation/0000.parquet"
    embeddings_dataset = load_dataset("parquet", data_files=parquet_url, split="train")
    # A fixed row index means every sample in the run uses the same x-vector.
    speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0).to(device)

    metrics = []
    # Accumulate unrounded durations so the average is not biased by per-sample rounding.
    total_generation_time = 0.0
    print(f"🎙️ Starting batch inference on {device}...")
    for idx, text in enumerate(prompts):
        # perf_counter is monotonic, so the measured duration cannot be
        # skewed by system clock adjustments.
        start_time = time.perf_counter()

        # Tokenize and generate audio
        inputs = processor(text=text, return_tensors="pt").to(device)
        speech = model.generate_speech(inputs["input_ids"], speaker_embeddings, vocoder=vocoder)

        # CUDA kernels run asynchronously; wait for them so the timer covers
        # the actual GPU work rather than only the kernel launches.
        if device == "cuda":
            torch.cuda.synchronize()
        generation_time = time.perf_counter() - start_time
        total_generation_time += generation_time

        audio_filename = f"audio_sample_{idx}.wav"
        audio_path = os.path.join(output_dir, audio_filename)

        # Save audio file
        # NOTE(review): 16000 Hz is presumably the vocoder's output rate — confirm
        # against the vocoder config if the checkpoint ever changes.
        write(audio_path, speech.cpu().numpy(), samplerate=16000)

        # Track metrics
        metrics.append({
            "sample_id": idx,
            "text": text,
            "generation_time_seconds": round(generation_time, 3),
            "audio_file": audio_filename,
        })
        print(f"✅ Generated sample {idx} in {generation_time:.2f}s")

    # 2. Save your metrics JSON
    # Guard the division: with no prompts there is nothing to average.
    average_generation_time = (
        round(total_generation_time / len(metrics), 3) if metrics else 0.0
    )
    summary_metrics = {
        "total_samples": len(prompts),
        "hardware_used": device,
        "average_generation_time": average_generation_time,
        "detailed_runs": metrics,
    }
    metrics_path = os.path.join(output_dir, "inference_metrics.json")
    with open(metrics_path, "w", encoding="utf-8") as f:
        json.dump(summary_metrics, f, indent=4)
    print(f"📊 Metrics and audio saved to {output_dir}")
def parse_args(argv=None):
    """Parse the command-line arguments for the TTS script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default), argparse reads ``sys.argv[1:]``, which matches
            the previous zero-argument behavior.

    Returns:
        argparse.Namespace with three attributes:
            run_id: Value of ``--run-id`` or ``None``.
            model_id: Value of ``--model-id`` or ``None``.
            text: The positional sentence, or ``DEFAULT_PROMPT`` if omitted.
    """
    parser = argparse.ArgumentParser(
        description="Generate speech from a text prompt using SpeechT5.",
    )
    parser.add_argument(
        "--run-id",
        default=None,
        help="Subdirectory name under OUTPUT_DIR for this run's outputs. "
             "Defaults to $JOB_ID when set, else 'local'.",
    )
    parser.add_argument(
        "--model-id",
        default=None,
        help="Hub model id requested by the caller. Currently ignored — the "
             "script always runs SpeechT5. Wire branching here when adding "
             "support for more TTS models.",
    )
    # nargs="?" makes the positional optional so the script runs with no arguments.
    parser.add_argument(
        "text",
        nargs="?",
        default=DEFAULT_PROMPT,
        help="Sentence to synthesize. Defaults to a built-in demo prompt.",
    )
    return parser.parse_args(argv)
if __name__ == "__main__":
    # Script entry point: read the CLI options, then synthesize the one
    # requested sentence as a single-item batch.
    cli = parse_args()
    requested_model = cli.model_id
    if requested_model:
        # --model-id is accepted but not acted on; say so explicitly.
        print(f"ℹ️ Received --model-id={requested_model} (ignored; running SpeechT5).")
    run_batch_inference([cli.text], run_id=cli.run_id)