Spaces:
Paused
Paused
| import argparse | |
| import time | |
| import os | |
| import json | |
| import torch | |
| from datasets import load_dataset | |
| from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan | |
| from soundfile import write | |
# Sentence synthesized when no positional text argument is supplied on the command line.
DEFAULT_PROMPT = "Hugging Face Jobs make cloud compute incredibly straightforward."
def run_batch_inference(prompts, run_id=None):
    """Synthesize every prompt with SpeechT5 and save audio plus timing metrics.

    For each prompt a WAV file named ``audio_sample_<idx>.wav`` (written at a
    16000 Hz sample rate) is stored in ``<OUTPUT_DIR>/<run_id>/``. Afterwards
    an ``inference_metrics.json`` summary is written to the same directory.

    Args:
        prompts: Iterable of sentences to synthesize. It may be empty, in
            which case no audio is produced and the summary reports zero
            samples with an average generation time of 0.0.
        run_id: Name of the per-run output subdirectory. When falsy, the
            ``JOB_ID`` environment variable is used, falling back to
            ``"local"``.
    """
    # 1. Set up input/output paths (we map this directory via HF Jobs)
    base_output_dir = os.environ.get("OUTPUT_DIR", "/data/output")
    # Each run gets its own subdirectory so concurrent jobs don't collide in the bucket.
    run_id = run_id or os.environ.get("JOB_ID", "local")
    output_dir = os.path.join(base_output_dir, run_id)
    os.makedirs(output_dir, exist_ok=True)

    print("π Initializing TTS Model...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
    model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts").to(device)
    vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan").to(device)

    # Load speaker embeddings for voice styling (via parquet to avoid deprecated dataset script)
    parquet_url = "https://huggingface.co/datasets/Matthijs/cmu-arctic-xvectors/resolve/refs%2Fconvert%2Fparquet/default/validation/0000.parquet"
    embeddings_dataset = load_dataset("parquet", data_files=parquet_url, split="train")
    speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0).to(device)

    metrics = []
    print(f"ποΈ Starting batch inference on {device}...")
    for idx, text in enumerate(prompts):
        # perf_counter is monotonic, so the interval cannot be skewed by
        # wall-clock adjustments the way time.time() can.
        start_time = time.perf_counter()

        # Tokenize and generate audio
        inputs = processor(text=text, return_tensors="pt").to(device)
        speech = model.generate_speech(inputs["input_ids"], speaker_embeddings, vocoder=vocoder)
        if device == "cuda":
            # CUDA kernels are launched asynchronously; wait for them so the
            # measured time covers the actual generation work.
            torch.cuda.synchronize()
        generation_time = time.perf_counter() - start_time

        audio_filename = f"audio_sample_{idx}.wav"
        audio_path = os.path.join(output_dir, audio_filename)

        # Save audio file
        write(audio_path, speech.cpu().numpy(), samplerate=16000)

        # Track metrics
        metrics.append({
            "sample_id": idx,
            "text": text,
            "generation_time_seconds": round(generation_time, 3),
            "audio_file": audio_filename,
        })
        print(f"β Generated sample {idx} in {generation_time:.2f}s")

    # 2. Save your metrics JSON
    total_time = sum(m["generation_time_seconds"] for m in metrics)
    # Guard against an empty batch: dividing by len(metrics) would raise
    # ZeroDivisionError when no prompts were supplied.
    average_time = round(total_time / len(metrics), 3) if metrics else 0.0
    summary_metrics = {
        # len(metrics) equals the number of prompts processed and, unlike
        # len(prompts), also works when prompts is a generator.
        "total_samples": len(metrics),
        "hardware_used": device,
        "average_generation_time": average_time,
        "detailed_runs": metrics,
    }
    metrics_path = os.path.join(output_dir, "inference_metrics.json")
    with open(metrics_path, "w", encoding="utf-8") as f:
        json.dump(summary_metrics, f, indent=4)
    print(f"π Metrics and audio saved to {output_dir}")
def parse_args():
    """Parse the command-line arguments of the SpeechT5 demo script.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``, carrying
        ``run_id`` and ``model_id`` (``None`` unless given) and ``text``
        (``DEFAULT_PROMPT`` when the positional argument is omitted).
    """
    cli = argparse.ArgumentParser(
        description="Generate speech from a text prompt using SpeechT5.",
    )

    # (name, keyword options) pairs; registered in this order so the
    # generated --help output lists the arguments the same way.
    argument_specs = [
        (
            "--run-id",
            {
                "default": None,
                "help": "Subdirectory name under OUTPUT_DIR for this run's outputs. "
                        "Defaults to $JOB_ID when set, else 'local'.",
            },
        ),
        (
            "--model-id",
            {
                "default": None,
                "help": "Hub model id requested by the caller. Currently ignored β the "
                        "script always runs SpeechT5. Wire branching here when adding "
                        "support for more TTS models.",
            },
        ),
        (
            "text",
            {
                "nargs": "?",
                "default": DEFAULT_PROMPT,
                "help": "Sentence to synthesize. Defaults to a built-in demo prompt.",
            },
        ),
    ]
    for name, options in argument_specs:
        cli.add_argument(name, **options)

    return cli.parse_args()
if __name__ == "__main__":
    # Script entry point: read the CLI options, then synthesize the single
    # requested sentence as a one-element batch.
    args = parse_args()

    # --model-id is accepted for caller compatibility but has no effect yet;
    # tell the user so the request is not silently dropped.
    if args.model_id:
        print(f"βΉοΈ Received --model-id={args.model_id} (ignored; running SpeechT5).")

    run_batch_inference(prompts=[args.text], run_id=args.run_id)