Instructions to use ajwestfield/MeiGen-MultiTalk with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use ajwestfield/MeiGen-MultiTalk with Diffusers:
pip install -U diffusers transformers accelerate
import torch
from diffusers import DiffusionPipeline
from diffusers.utils import load_image, export_to_video

# switch to "mps" for apple devices
pipe = DiffusionPipeline.from_pretrained("ajwestfield/MeiGen-MultiTalk", dtype=torch.bfloat16, device_map="cuda")
pipe.to("cuda")

prompt = "A man with short gray hair plays a red electric guitar."
image = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/guitar-man.png"
)

output = pipe(image=image, prompt=prompt).frames[0]
export_to_video(output, "output.mp4")
- Notebooks
- Google Colab
- Kaggle
import base64
import inspect
import io
import json
import logging
import os
import sys
import traceback
from typing import Any, Dict, List

import torch
from PIL import Image
# Logging: configure the root logger at INFO when this module is imported and
# expose a module-scoped logger for the handler's messages.
logging.basicConfig(
    level=logging.INFO,
)
logger = logging.getLogger(
    __name__,
)
class EndpointHandler:
    """Inference handler for the MeiGen-MultiTalk model.

    Loading is attempted in three stages: a diffusers ``DiffusionPipeline``,
    then a transformers ``AutoModel``/``AutoTokenizer`` pair, and finally a
    model-less "test mode" in which requests are only echoed back.
    """

    def __init__(self, path=""):
        """
        Initialize the MultiTalk model handler.

        Args:
            path: Only logged; the weights are always loaded from the
                hard-coded hub id ``MeiGen-AI/MeiGen-MultiTalk``.

        Raises:
            ImportError: If ``diffusers`` cannot be imported.
        """
        logger.info(f"Initializing handler with path: {path}")

        # Define every attribute up front so __call__ can test them safely,
        # whichever loading stage ends up succeeding.
        self.pipeline = None
        self.model = None
        self.tokenizer = None

        # Import required libraries
        try:
            from diffusers import DiffusionPipeline
            logger.info("Successfully imported required libraries")
        except ImportError as e:
            logger.error(f"Failed to import required libraries: {e}")
            raise

        # Set device
        on_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if on_cuda else "cpu")
        logger.info(f"Using device: {self.device}")

        model_id = "MeiGen-AI/MeiGen-MultiTalk"
        load_dtype = torch.float16 if on_cuda else torch.float32

        # Stage 1: diffusion pipeline.
        try:
            logger.info(f"Loading model from: {model_id}")
            # No device_map here: pipelines reject "auto", and a device map
            # cannot be combined with enable_model_cpu_offload() below.
            pipeline = DiffusionPipeline.from_pretrained(
                model_id,
                torch_dtype=load_dtype,
                low_cpu_mem_usage=True,
            )

            # Enable memory optimizations
            if hasattr(pipeline, "enable_attention_slicing"):
                pipeline.enable_attention_slicing()
                logger.info("Enabled attention slicing")
            if hasattr(pipeline, "enable_vae_slicing"):
                pipeline.enable_vae_slicing()
                logger.info("Enabled VAE slicing")

            # Device placement: offloading needs an accelerator, so it is
            # only attempted on CUDA; on CPU the weights stay where they are.
            if on_cuda:
                if hasattr(pipeline, "enable_model_cpu_offload"):
                    pipeline.enable_model_cpu_offload()
                    logger.info("Enabled model CPU offload")
                else:
                    pipeline.to(self.device)

            # Publish the pipeline only once it is fully configured.
            self.pipeline = pipeline
            logger.info("Model loaded successfully")
            return
        except Exception as e:
            logger.error(f"Failed to load model: {e}")

        # Stage 2: plain transformers model + tokenizer.
        try:
            logger.info("Attempting alternative loading method...")
            from transformers import AutoModel, AutoTokenizer

            model = AutoModel.from_pretrained(
                model_id,
                torch_dtype=load_dtype,
                device_map="auto",
                trust_remote_code=True,
            )
            tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
            self.model = model
            self.tokenizer = tokenizer
            logger.info("Model loaded with alternative method")
        except Exception as e2:
            # Stage 3: keep the endpoint alive without a model.
            logger.error(f"Alternative loading also failed: {e2}")
            self.model = None
            self.tokenizer = None
            logger.warning("Running in test mode without actual model")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process the inference request.

        Args:
            data: Input data containing:
                - inputs: The prompt string, or a dict with ``prompt`` and an
                  optional base64-encoded ``image``
                - parameters: Additional generation parameters (may be
                  missing or ``None``)

        Returns:
            Dict containing the generated output, a test-mode echo, or an
            error message with traceback. Exceptions raised while handling
            the request are reported in the dict rather than propagated.
        """
        logger.info(f"Received request with data keys: {data.keys()}")
        try:
            inputs = data.get("inputs", "")
            # An explicit JSON null must behave like an absent key.
            parameters = data.get("parameters") or {}
            logger.info(f"Processing inputs: {type(inputs)}")
            logger.info(f"Parameters: {parameters}")

            prompt, image = self._parse_inputs(inputs)

            # Extract parameters with defaults
            num_inference_steps = parameters.get("num_inference_steps", 25)
            guidance_scale = parameters.get("guidance_scale", 7.5)
            height = parameters.get("height", 480)
            width = parameters.get("width", 640)
            num_frames = parameters.get("num_frames", 16)
            logger.info(f"Generation params: steps={num_inference_steps}, guidance={guidance_scale}, size={width}x{height}, frames={num_frames}")

            if self.pipeline is not None:
                logger.info("Generating with diffusion pipeline...")
                gen_kwargs = {
                    "prompt": prompt,
                    "height": height,
                    "width": width,
                    "num_inference_steps": num_inference_steps,
                    "guidance_scale": guidance_scale,
                }
                if image is not None:
                    gen_kwargs["image"] = image
                # Only video pipelines take num_frames; pass it when declared.
                if self._pipeline_accepts("num_frames"):
                    gen_kwargs["num_frames"] = num_frames
                with torch.no_grad():
                    result = self.pipeline(**gen_kwargs)
                return self._format_pipeline_result(result, prompt)

            if self.model is not None:
                logger.info("Generating with transformer model...")
                if self.tokenizer is None:
                    return {
                        "message": "Model loaded but tokenizer not available",
                        "prompt": prompt,
                    }
                encoded = self.tokenizer(prompt, return_tensors="pt").to(self.device)
                with torch.no_grad():
                    outputs = self.model.generate(**encoded, max_length=100)
                text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
                return {
                    "generated_text": text,
                    "message": "Text generated successfully",
                }

            # Test mode response
            logger.warning("Running in test mode - no actual generation")
            return {
                "message": "Handler is running in test mode",
                "prompt": prompt,
                "parameters": parameters,
                "status": "test_mode",
            }
        except Exception as e:
            logger.exception(f"Error during inference: {e}")
            return {
                "error": str(e),
                "traceback": traceback.format_exc(),
                "message": "Error during generation",
            }

    @staticmethod
    def _parse_inputs(inputs):
        """Return ``(prompt, image)`` extracted from the request's ``inputs``."""
        if isinstance(inputs, str):
            return inputs, None
        if not isinstance(inputs, dict):
            return str(inputs), None

        prompt = inputs.get("prompt", "A person speaking")
        image = None
        if "image" in inputs:
            # Best effort: an undecodable image is logged and dropped.
            try:
                image_data = base64.b64decode(inputs["image"])
                image = Image.open(io.BytesIO(image_data))
                logger.info("Loaded input image")
            except Exception as e:
                logger.error(f"Failed to decode image: {e}")
                image = None
        return prompt, image

    def _pipeline_accepts(self, name):
        """Return True if the pipeline's call signature declares *name*."""
        # inspect.signature follows functools.wraps chains, so parameters of
        # a decorated __call__ are still visible; reading __code__ directly
        # would see the wrapper's (*args, **kwargs) instead.
        try:
            call_params = inspect.signature(self.pipeline).parameters
        except (TypeError, ValueError):
            return False
        return name in call_params

    @staticmethod
    def _encode_png(picture):
        """Return *picture* serialized as PNG and base64-encoded to ``str``."""
        buffered = io.BytesIO()
        picture.save(buffered, format="PNG")
        return base64.b64encode(buffered.getvalue()).decode()

    def _format_pipeline_result(self, result, prompt):
        """Convert a pipeline result into the JSON-friendly response dict."""
        frames = getattr(result, "frames", None)
        if isinstance(frames, list) and frames:
            # A nested list is treated as a batch; only the first clip is used.
            clip = frames[0] if isinstance(frames[0], list) else frames
            encoded_frames = [
                self._encode_png(frame) for frame in clip if isinstance(frame, Image.Image)
            ]
            return {
                "frames": encoded_frames,
                "num_frames": len(encoded_frames),
                "message": "Video generated successfully",
            }

        images = getattr(result, "images", None)
        if images is not None:
            encoded_images = [
                self._encode_png(img) for img in images if isinstance(img, Image.Image)
            ]
            return {
                "images": encoded_images,
                "num_images": len(encoded_images),
                "message": "Images generated successfully",
            }

        # Unknown or unusable result shape: always answer with a dict
        # instead of falling through and returning None.
        return {
            "message": "Generation completed",
            "prompt": prompt,
            "result_type": str(type(result)),
        }