| import os |
| import sys |
| import json |
| import base64 |
| import tempfile |
| import shutil |
| from typing import Dict, Any, Optional, List |
| import torch |
| import numpy as np |
| from huggingface_hub import snapshot_download, hf_hub_download |
| import logging |
| import subprocess |
| import warnings |
| import cv2 |
| from PIL import Image |
| import requests |
|
|
# Suppress every Python warning for the whole process.
warnings.simplefilter("ignore")

# Module-level logger; the root logger is configured at INFO so handler progress is emitted.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
|
|
class EndpointHandler:
    """
    HuggingFace Inference Endpoint handler for Wav2Lip-based lip sync video generation.

    Wav2Lip is run as a ``wav2lip.inference`` subprocess when its checkpoint is
    available. Otherwise, or when that run fails, an OpenCV fallback draws a
    mouth shape whose opening follows the audio's RMS loudness.
    """

    # Frame rate used for both the Wav2Lip base video and the fallback renderer.
    _FPS = 25

    # Output (width, height) per aspect ratio; unknown ratios fall back to 16:9.
    _TARGET_SIZES = {
        "9:16": (480, 854),
        "1:1": (640, 640),
        "16:9": (854, 480),
    }

    # Where weights are stored unless the WAV2LIP_WEIGHTS_DIR env var overrides it.
    _DEFAULT_WEIGHTS_DIR = "/data/weights"

    # Names reported in the response's "model" field.
    _MODEL_WAV2LIP = "Wav2Lip"
    _MODEL_FALLBACK = "Enhanced Fallback"

    def __init__(self, path=""):
        """
        Initialize the handler: download weights and decide whether Wav2Lip is usable.

        Args:
            path: Repository path passed by the Inference Endpoints runtime. Not
                used; weights go to ``/data/weights`` unless ``WAV2LIP_WEIGHTS_DIR``
                is set.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Initializing Wav2Lip Handler on device: {self.device}")

        # Defaults so every attribute exists even if a later step fails part-way.
        self.use_wav2lip = False
        self.wav2lip_checkpoint: Optional[str] = None
        self.face_detect_path: Optional[str] = None

        self.weights_dir = os.environ.get("WAV2LIP_WEIGHTS_DIR", self._DEFAULT_WEIGHTS_DIR)
        os.makedirs(self.weights_dir, exist_ok=True)

        self._download_wav2lip_model()
        self._initialize_wav2lip()

        logger.info("Wav2Lip Handler initialization complete")

    # ------------------------------------------------------------------ #
    # Small helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _unlink_quietly(path: Optional[str]) -> None:
        """Best-effort removal of a temp file; ``None`` and already-missing files are ignored."""
        if not path:
            return
        try:
            os.unlink(path)
        except OSError:
            # Cleanup must never mask the real result or the original error.
            pass

    @staticmethod
    def _make_temp_path(suffix: str) -> str:
        """Securely create an empty temp file (unlike ``tempfile.mktemp``) and return its path."""
        fd, path = tempfile.mkstemp(suffix=suffix)
        os.close(fd)
        return path

    @classmethod
    def _write_temp_file(cls, chunks, suffix: str) -> str:
        """Write an iterable of byte chunks to a new temp file and return its path.

        The partially written file is removed if writing fails.
        """
        tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
        try:
            with tmp_file:
                for chunk in chunks:
                    tmp_file.write(chunk)
        except BaseException:
            cls._unlink_quietly(tmp_file.name)
            raise
        return tmp_file.name

    @staticmethod
    def _media_extension(media_type: str, mime_hint: str) -> str:
        """Pick a temp-file suffix from a Content-Type header or data-URL header.

        The same rule is used for both sources so that e.g. ``audio/mpeg`` maps to
        ``.mp3`` whether the audio arrives as a data URL or over HTTP.
        """
        hint = (mime_hint or "").lower()
        if media_type == "image":
            return ".jpg" if "jpeg" in hint or "jpg" in hint else ".png"
        return ".mp3" if "mp3" in hint or "mpeg" in hint else ".wav"

    @classmethod
    def _target_size(cls, aspect_ratio: str) -> tuple:
        """Return the output ``(width, height)`` for ``aspect_ratio`` (16:9 if unrecognized)."""
        return cls._TARGET_SIZES.get(aspect_ratio, cls._TARGET_SIZES["16:9"])

    @staticmethod
    def _parse_duration(value):
        """Coerce the ``seconds`` request field to a positive, finite number.

        Integral values come back as ``int`` so the response echoes ``5`` rather
        than ``5.0``.

        Raises:
            ValueError: ``value`` is not numeric, is a bool, or is not in (0, inf).
        """
        if isinstance(value, bool):
            raise ValueError(f"Invalid 'seconds' value: {value!r}")
        try:
            seconds = float(value)
        except (TypeError, ValueError):
            raise ValueError(f"Invalid 'seconds' value: {value!r}") from None
        # The chained comparison also rejects NaN (all NaN comparisons are False).
        if not 0 < seconds < float("inf"):
            raise ValueError(f"'seconds' must be a positive finite number, got {value!r}")
        return int(seconds) if seconds.is_integer() else seconds

    @staticmethod
    def _align_energy(energy, num_frames: int):
        """Trim or zero-pad a per-frame energy curve to exactly ``num_frames`` values.

        Each RMS value already covers one video frame (hop length = sr / fps), so
        values are kept 1:1 rather than stretched. Frames past the end of the audio
        are treated as silence (0.0).
        """
        energy = np.asarray(energy, dtype=float)[:num_frames]
        if len(energy) < num_frames:
            energy = np.pad(energy, (0, num_frames - len(energy)))
        return energy

    # ------------------------------------------------------------------ #
    # Model setup
    # ------------------------------------------------------------------ #

    def _hub_download(self, repo_id: str, filename: str) -> Optional[str]:
        """Download ``filename`` from the Hub repo ``repo_id`` into ``weights_dir``.

        Returns:
            Local file path, or ``None`` if the download failed (the error is logged).
        """
        try:
            return hf_hub_download(
                repo_id=repo_id,
                filename=filename,
                local_dir=self.weights_dir,
                local_dir_use_symlinks=False,
            )
        except Exception as e:  # network/auth/missing-file errors are all non-fatal here
            logger.error(f"Failed to download {filename} from {repo_id}: {e}")
            return None

    def _download_wav2lip_model(self):
        """Download the Wav2Lip checkpoint and the S3FD face-detection weights.

        Failures are logged, never raised. A missing checkpoint makes the handler
        use the OpenCV fallback renderer.
        """
        logger.info("Downloading Wav2Lip models...")

        wav2lip_checkpoint = self._hub_download("camenduru/Wav2Lip", "wav2lip_gan.pth")
        if wav2lip_checkpoint is None:
            logger.info("Trying alternative model source...")
            wav2lip_checkpoint = self._hub_download("commanderx/Wav2Lip-HD", "wav2lip_gan.pth")

        if wav2lip_checkpoint is not None:
            logger.info(f"Downloaded Wav2Lip checkpoint: {wav2lip_checkpoint}")
        else:
            logger.warning("Could not download Wav2Lip models, will use basic implementation")

        s3fd_model = self._hub_download("camenduru/Wav2Lip", "s3fd.pth")
        if s3fd_model is not None:
            logger.info(f"Downloaded face detection model: {s3fd_model}")

    def _initialize_wav2lip(self):
        """Locate downloaded weights and decide whether Wav2Lip can be used.

        Sets ``wav2lip_checkpoint``/``use_wav2lip`` from the presence of
        ``wav2lip_gan.pth`` in ``weights_dir``. Also sets ``face_detect_path`` from
        ``s3fd.pth``, which is stored for reference but not passed to the Wav2Lip
        subprocess by this class.
        """
        logger.info("Initializing Wav2Lip model...")

        # NOTE(review): Wav2Lip runs in a separate process, which does not inherit
        # this interpreter's sys.path; kept only for possible in-process imports.
        if self.weights_dir not in sys.path:
            sys.path.append(self.weights_dir)

        checkpoint_path = os.path.join(self.weights_dir, "wav2lip_gan.pth")
        if os.path.exists(checkpoint_path):
            logger.info(f"Found Wav2Lip checkpoint at {checkpoint_path}")
            self.wav2lip_checkpoint = checkpoint_path
            self.use_wav2lip = True
        else:
            logger.warning("Wav2Lip checkpoint not found, using fallback")
            self.wav2lip_checkpoint = None
            self.use_wav2lip = False

        s3fd_path = os.path.join(self.weights_dir, "s3fd.pth")
        if os.path.exists(s3fd_path):
            logger.info(f"Found face detection model at {s3fd_path}")
            self.face_detect_path = s3fd_path
        else:
            logger.warning("Face detection model not found")
            self.face_detect_path = None

    # ------------------------------------------------------------------ #
    # Input preparation
    # ------------------------------------------------------------------ #

    def _download_media(self, url: str, media_type: str = "image") -> str:
        """Materialize an image/audio input as a local temporary file.

        Args:
            url: An ``http(s)`` URL or a base64 ``data:`` URL.
            media_type: ``"image"``, or anything else (treated as audio). Only
                affects the file suffix and logging.

        Returns:
            Path of a temp file; the caller is responsible for deleting it.

        Raises:
            ValueError: Malformed data URL or invalid base64 payload.
            requests.RequestException: HTTP failure or non-2xx status.
        """
        if url.startswith("data:"):
            logger.info(f"Processing base64 {media_type}")
            header, separator, payload = url.partition(",")
            if not separator:
                raise ValueError(f"Malformed data URL for {media_type}: missing ',' separator")
            ext = self._media_extension(media_type, header)
            return self._write_temp_file([base64.b64decode(payload)], ext)

        # NOTE(review): this fetches an arbitrary caller-supplied URL; restrict
        # schemes/hosts here if the endpoint is exposed to untrusted clients.
        logger.info(f"Downloading {media_type} from URL...")
        with requests.get(url, stream=True, timeout=30) as response:
            response.raise_for_status()
            ext = self._media_extension(media_type, response.headers.get("content-type", ""))
            return self._write_temp_file(response.iter_content(chunk_size=8192), ext)

    def _prepare_image_for_aspect_ratio(self, image_path: str, aspect_ratio: str = "16:9") -> str:
        """Resize the image to the target size for ``aspect_ratio`` and save it as JPEG.

        The image is resized (not cropped), so a source with a different aspect
        ratio is stretched.

        Returns:
            Path of a new temp ``.jpg`` the caller must delete.
        """
        logger.info(f"Preparing image with aspect ratio: {aspect_ratio}")

        width, height = self._target_size(aspect_ratio)
        with Image.open(image_path) as source:
            image = source.convert("RGB")

        logger.info(f"Resizing image to {width}x{height}")
        image = image.resize((width, height), Image.Resampling.LANCZOS)

        output_path = self._make_temp_path(".jpg")
        try:
            image.save(output_path, "JPEG", quality=95)
        except BaseException:
            self._unlink_quietly(output_path)
            raise
        return output_path

    # ------------------------------------------------------------------ #
    # Video generation
    # ------------------------------------------------------------------ #

    def _generate_lip_sync_video(
        self,
        image_path: str,
        audio_path: str,
        aspect_ratio: str = "16:9",
        duration: int = 5
    ) -> str:
        """Generate a lip-synced video using Wav2Lip or the fallback; return its path."""
        video_path, _model_name = self._generate_video_with_model(
            image_path, audio_path, aspect_ratio, duration
        )
        return video_path

    def _generate_video_with_model(self, image_path: str, audio_path: str, aspect_ratio: str, duration) -> tuple:
        """Render the video and report which renderer actually produced it.

        Returns:
            ``(video_path, model_name)``, where ``model_name`` is ``"Wav2Lip"``
            only if the Wav2Lip run succeeded.
        """
        if self.use_wav2lip and self.wav2lip_checkpoint:
            logger.info("Using Wav2Lip for lip sync generation")
            video_path = self._try_wav2lip(image_path, audio_path, aspect_ratio, duration)
            if video_path is not None:
                return video_path, self._MODEL_WAV2LIP
        else:
            logger.info("Using enhanced fallback for lip sync generation")
        fallback_path = self._generate_with_enhanced_fallback(image_path, audio_path, aspect_ratio, duration)
        return fallback_path, self._MODEL_FALLBACK

    def _generate_with_wav2lip(
        self,
        image_path: str,
        audio_path: str,
        aspect_ratio: str,
        duration: int
    ) -> str:
        """Generate video with Wav2Lip, falling back to the OpenCV renderer on failure."""
        video_path = self._try_wav2lip(image_path, audio_path, aspect_ratio, duration)
        if video_path is None:
            video_path = self._generate_with_enhanced_fallback(image_path, audio_path, aspect_ratio, duration)
        return video_path

    def _try_wav2lip(self, image_path: str, audio_path: str, aspect_ratio: str, duration) -> Optional[str]:
        """Run Wav2Lip inference once.

        Returns:
            Path of the Wav2Lip output video, or ``None`` if any step failed. All
            intermediate files (and the output on failure) are removed.
        """
        logger.info("Generating with Wav2Lip model...")
        prepared_image = temp_video = output_video = None
        succeeded = False
        try:
            prepared_image = self._prepare_image_for_aspect_ratio(image_path, aspect_ratio)

            # Wav2Lip's --face takes a video, so loop the still image for `duration` seconds.
            temp_video = self._make_temp_path(".mp4")
            cmd = [
                "ffmpeg", "-loop", "1", "-i", prepared_image,
                "-c:v", "libx264", "-t", str(duration),
                "-pix_fmt", "yuv420p", "-vf", f"fps={self._FPS}",
                "-y", temp_video,
            ]
            result = subprocess.run(cmd, capture_output=True, text=True)
            if result.returncode != 0:
                logger.error(f"FFmpeg failed: {result.stderr}")
                return None

            output_video = self._make_temp_path(".mp4")
            wav2lip_cmd = [
                # Use the current interpreter so the subprocess sees the same environment.
                sys.executable or "python", "-m", "wav2lip.inference",
                "--checkpoint_path", self.wav2lip_checkpoint,
                "--face", temp_video,
                "--audio", audio_path,
                "--outfile", output_video,
                "--resize_factor", "1",
                "--nosmooth",
            ]

            logger.info("Running Wav2Lip inference...")
            result = subprocess.run(wav2lip_cmd, capture_output=True, text=True)
            if result.returncode != 0:
                logger.error(f"Wav2Lip failed: {result.stderr}")
                return None
            # The output path was pre-created empty; an empty file means nothing was written.
            if os.path.getsize(output_video) == 0:
                logger.error("Wav2Lip failed: no output video was written")
                return None

            logger.info("Wav2Lip generation successful")
            succeeded = True
            return output_video
        except Exception as e:
            logger.error(f"Wav2Lip generation error: {e}")
            return None
        finally:
            self._unlink_quietly(prepared_image)
            self._unlink_quietly(temp_video)
            if not succeeded:
                self._unlink_quietly(output_video)

    def _audio_energy(self, audio_path: str, duration, fps: int, num_frames: int):
        """Per-video-frame loudness of the audio, min-max normalized to roughly [0, 1].

        Falls back to random values in [0.3, 0.8) if librosa is unavailable or the
        analysis fails.
        """
        try:
            import librosa  # optional; an ImportError triggers the random fallback below

            audio, sr = librosa.load(audio_path, duration=duration)
            # One RMS value per video frame: hop length = samples per frame.
            hop_length = int(sr / fps)
            energy = librosa.feature.rms(y=audio, hop_length=hop_length)[0]
            if len(energy) > 0:
                energy = (energy - energy.min()) / (energy.max() - energy.min() + 1e-6)
            return self._align_energy(energy, num_frames)
        except Exception as e:
            logger.warning(f"Audio analysis failed: {e}")
            return np.random.random(num_frames) * 0.5 + 0.3

    @staticmethod
    def _write_silent_video(image, energy, output_path: str, fps: int) -> None:
        """Render one frame per ``energy`` value and stream them to an mp4v file.

        A dark half-ellipse "mouth" is drawn at a fixed position when the frame's
        energy exceeds 0.2. A small sine-based horizontal shift is applied on the
        first half of every 30-frame cycle.
        """
        h, w = image.shape[:2]
        writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h))
        if not writer.isOpened():
            raise RuntimeError(f"Could not open video writer for {output_path}")
        try:
            for frame_idx, frame_energy in enumerate(energy):
                frame = image.copy()

                if frame_energy > 0.2:
                    # Fixed heuristic mouth position (lower centre); no face detection is used.
                    mouth_center = (int(w * 0.5), int(h * 0.62))
                    mouth_axes = (
                        int(w * 0.06 * (1 + frame_energy * 0.3)),
                        int(h * 0.03 * frame_energy),
                    )
                    cv2.ellipse(frame, mouth_center, mouth_axes, 0, 0, 180, (40, 30, 30), -1)

                if frame_idx % 30 < 15:
                    shift = np.float32([[1, 0, np.sin(frame_idx * 0.1) * 2], [0, 1, 0]])
                    frame = cv2.warpAffine(frame, shift, (w, h), borderMode=cv2.BORDER_REFLECT_101)

                writer.write(frame)
        finally:
            writer.release()

    def _generate_with_enhanced_fallback(
        self,
        image_path: str,
        audio_path: str,
        aspect_ratio: str,
        duration: int
    ) -> str:
        """Render an energy-driven mouth animation and mux it with the audio.

        Returns:
            Path of the muxed video. If the ffmpeg merge fails, returns the silent
            mp4v video instead of failing the request. Every other intermediate
            file is removed.
        """
        logger.info("Using enhanced fallback for lip sync...")

        fps = self._FPS
        num_frames = max(1, int(round(float(duration) * fps)))

        prepared_image = silent_video = final_video = keep = None
        try:
            prepared_image = self._prepare_image_for_aspect_ratio(image_path, aspect_ratio)
            image = cv2.imread(prepared_image)
            if image is None:
                raise RuntimeError(f"Could not read prepared image: {prepared_image}")

            energy = self._audio_energy(audio_path, duration, fps, num_frames)

            silent_video = self._make_temp_path(".mp4")
            self._write_silent_video(image, energy, silent_video, fps)

            final_video = self._make_temp_path(".mp4")
            cmd = [
                "ffmpeg", "-i", silent_video, "-i", audio_path,
                "-c:v", "libx264", "-c:a", "aac",
                "-shortest", "-y", final_video,
            ]
            result = subprocess.run(cmd, capture_output=True, text=True)

            if result.returncode == 0:
                keep = final_video
            else:
                logger.error(f"Audio merge failed: {result.stderr}")
                keep = silent_video
            return keep
        finally:
            # Remove every intermediate except the file being returned.
            for path in (prepared_image, silent_video, final_video):
                if path != keep:
                    self._unlink_quietly(path)

    # ------------------------------------------------------------------ #
    # Entry point
    # ------------------------------------------------------------------ #

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Process one lip-sync video generation request.

        Args:
            data: Request payload, either the parameters directly or wrapped under
                ``"inputs"``. Required: ``image_url``, ``audio_url``. Optional:
                ``seconds`` (default 5) and ``aspect_ratio`` (default ``"16:9"``).

        Returns:
            On success, a dict with ``success=True``, the base64 ``video`` and
            metadata. On failure, ``{"error": ..., "success": False}``. This method
            does not raise for request or processing errors.
        """
        logger.info("Processing lip sync video generation request")

        image_path = audio_path = video_path = None
        try:
            input_data = data.get("inputs", data) if isinstance(data, dict) else data
            if not isinstance(input_data, dict):
                return {
                    "error": "Request inputs must be a JSON object",
                    "success": False
                }

            image_url = input_data.get("image_url")
            audio_url = input_data.get("audio_url")
            aspect_ratio = input_data.get("aspect_ratio", "16:9")

            if not image_url or not audio_url:
                return {
                    "error": "Missing required parameters: image_url and audio_url",
                    "success": False
                }

            try:
                seconds = self._parse_duration(input_data.get("seconds", 5))
            except ValueError as e:
                return {"error": str(e), "success": False}

            logger.info(f"Generating {seconds}s video with aspect ratio {aspect_ratio}")

            image_path = self._download_media(image_url, "image")
            audio_path = self._download_media(audio_url, "audio")

            video_path, model_name = self._generate_video_with_model(
                image_path, audio_path, aspect_ratio, seconds
            )

            with open(video_path, "rb") as video_file:
                video_bytes = video_file.read()
            video_base64 = base64.b64encode(video_bytes).decode("utf-8")

            video_size = len(video_bytes)
            logger.info(f"Generated video size: {video_size / 1024 / 1024:.2f} MB")

            width, height = self._target_size(aspect_ratio)
            resolution = f"{width}x{height}"

            return {
                "success": True,
                "video": video_base64,
                "format": "mp4",
                "duration": seconds,
                "resolution": resolution,
                "aspect_ratio": aspect_ratio,
                "fps": self._FPS,
                "size_mb": round(video_size / 1024 / 1024, 2),
                "message": f"Generated {seconds}s lip-sync video at {resolution}",
                "model": model_name
            }

        except Exception as e:
            logger.error(f"Request processing failed: {str(e)}", exc_info=True)
            return {
                "error": f"Video generation failed: {str(e)}",
                "success": False
            }
        finally:
            # Inputs and the rendered video are temp files owned by this request.
            for path in (image_path, audio_path, video_path):
                self._unlink_quietly(path)