Instructions to use dn6/RFDiffusion-3 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use dn6/RFDiffusion-3 with Diffusers:
pip install -U diffusers transformers accelerate
```python
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for apple devices
pipe = DiffusionPipeline.from_pretrained("dn6/RFDiffusion-3", dtype=torch.bfloat16, device_map="cuda")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]
```
- Notebooks
- Google Colab
- Kaggle
| # Copyright 2025 Dhruv Nair. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| from typing import List, Tuple, Union | |
| import torch | |
| from diffusers.utils import logging | |
| from diffusers.modular_pipelines import ModularPipeline, ModularPipelineBlocks, PipelineState | |
| from diffusers.modular_pipelines.modular_pipeline_utils import ComponentSpec, InputParam, OutputParam | |
# Module-level logger created through diffusers' `logging` utilities.
logger = logging.get_logger(__name__)
def _parse_int_span(spec: str, segment: str) -> Tuple[int, int]:
    """Parse ``"a"`` or ``"a-b"`` into ``(a, b)``, rejecting malformed or reversed spans."""
    try:
        if "-" in spec:
            low, high = map(int, spec.split("-"))
        else:
            low = high = int(spec)
    except ValueError as err:
        # Covers non-numeric text, empty bounds ("10-") and extra dashes ("1-2-3").
        raise ValueError(f"Invalid contig segment {segment!r}") from err
    if high < low:
        # A reversed span would otherwise produce a negative segment length.
        raise ValueError(f"Invalid contig segment {segment!r}: end {high} is smaller than start {low}")
    return low, high


def parse_contig_string(contig_str: str) -> Tuple[int, List[Tuple[int, int]]]:
    """
    Parse a contig specification string into a total length and motif ranges.

    The string is a ``/``-separated list of segments. Each segment is one of:

    - ``"100"``: 100 residues to design.
    - ``"50-100"``: a designed segment whose length is the midpoint of the
      range, ``(50 + 100) // 2 == 75``. No random sampling is performed.
    - ``"A10-25"``: a motif from chain ``A`` residues 10-25 inclusive, which
      contributes 16 fixed positions. A single residue (``"A10"``) is also
      accepted. The chain ID is the run of leading letters. It only marks the
      segment as a motif and is not returned.

    Empty segments (e.g. from a trailing ``/``) are skipped. Whitespace
    around each segment is ignored.

    Args:
        contig_str: Contig specification, e.g. ``"A10-25/50"``.

    Returns:
        ``(total_length, motif_ranges)``:

        - ``total_length`` is the total protein length.
        - ``motif_ranges`` lists half-open ``(start, end)`` index ranges
          (0-indexed, in designed-protein coordinates) occupied by motif
          residues.

    Raises:
        ValueError: If a segment is malformed, or a range is reversed
            (end smaller than start).
    """
    total_length = 0
    motif_ranges: List[Tuple[int, int]] = []

    for raw_part in contig_str.split("/"):
        part = raw_part.strip()
        if not part:
            continue

        if part[0].isalpha():
            # Motif segment: skip the chain ID (leading letters). Only the
            # residue span after it determines how many fixed positions to add.
            chain_len = next((i for i, ch in enumerate(part) if not ch.isalpha()), len(part))
            start, end = _parse_int_span(part[chain_len:], part)
            motif_len = end - start + 1
            motif_ranges.append((total_length, total_length + motif_len))
            total_length += motif_len
        else:
            # Designed segment: a fixed count, or a min-max range resolved to
            # its midpoint so the parsed length is deterministic.
            min_len, max_len = _parse_int_span(part, part)
            total_length += (min_len + max_len) // 2

    return total_length, motif_ranges
class RFDiffusionInputStep(ModularPipelineBlocks):
    """
    Input processing step for RFDiffusion.

    Parses the contig specification into the total protein length and a
    boolean mask of motif (fixed) positions. Any user-supplied motif
    coordinates are forwarded unchanged.
    """

    model_name = "rfdiffusion"

    # The block interface is exposed through properties, matching the diffusers
    # ModularPipelineBlocks API (the base class reads e.g. `self.inputs`).
    @property
    def description(self) -> str:
        return (
            "Input processing step that:\n"
            " 1. Parses contig specification to determine protein length and design regions\n"
            " 2. Generates masks for motif positions\n"
        )

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "contigs",
                required=True,
                type_hint=Union[str, List[str]],
                description="Contig specification defining design regions (e.g., '100' or 'A10-25/50-100')",
            ),
            InputParam(
                "input_xyz",
                type_hint=torch.Tensor,
                description="Input coordinates for motif residues [N_motif, 3]",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "motif_mask",
                type_hint=torch.Tensor,
                description="Boolean mask for motif (fixed) positions",
            ),
            OutputParam(
                "motif_xyz",
                type_hint=torch.Tensor,
                description="Coordinates for motif residues",
            ),
            OutputParam(
                "L",
                type_hint=int,
                description="Total length of the protein being designed",
            ),
            OutputParam(
                "batch_size",
                type_hint=int,
                description="Batch size (typically 1 for RFDiffusion)",
            ),
            OutputParam(
                "dtype",
                type_hint=torch.dtype,
                description="Data type for tensors",
            ),
        ]

    def check_inputs(self, components, block_state):
        """
        Validate user inputs before parsing.

        Raises:
            ValueError: If ``contigs`` is not provided.
        """
        if block_state.contigs is None:
            raise ValueError("`contigs` must be provided to specify protein design regions")

    @torch.no_grad()
    def __call__(self, components: ModularPipeline, state: PipelineState) -> Tuple[ModularPipeline, PipelineState]:
        """
        Parse ``contigs`` and store the results in the block state.

        Writes ``motif_mask``, ``motif_xyz``, ``L``, ``batch_size`` and
        ``dtype``. Returns the ``(components, state)`` pair.

        Raises:
            ValueError: If ``contigs`` is missing, malformed, or describes a
                protein of length 0.
        """
        block_state = self.get_block_state(state)
        self.check_inputs(components, block_state)

        contigs = block_state.contigs
        # A list of contigs is joined with "/" and parsed as one contig string.
        contig_str = "/".join(contigs) if isinstance(contigs, list) else contigs

        L, motif_ranges = parse_contig_string(contig_str)
        if L == 0:
            # Without this check, downstream steps would receive empty tensors.
            raise ValueError(f"`contigs` {contigs!r} describe a protein of length 0")

        motif_mask = torch.zeros(L, dtype=torch.bool)
        for start, end in motif_ranges:
            motif_mask[start:end] = True

        motif_xyz = block_state.input_xyz
        if motif_xyz is not None:
            num_motif = int(motif_mask.sum())
            if motif_xyz.shape[0] != num_motif:
                # Not fatal: rows are matched to motif positions in order
                # downstream, so surface the mismatch rather than failing.
                logger.warning(
                    "`input_xyz` has %d rows but `contigs` define %d motif residues; "
                    "surplus rows may be ignored and unmatched motif positions left at the origin.",
                    motif_xyz.shape[0],
                    num_motif,
                )

        block_state.motif_mask = motif_mask
        block_state.motif_xyz = motif_xyz
        block_state.L = L
        block_state.batch_size = 1
        block_state.dtype = torch.float32
        self.set_block_state(state, block_state)
        return components, state
class RFDiffusionSetTimestepsStep(ModularPipelineBlocks):
    """
    Set up the EDM noise schedule for RFDiffusion3.

    The schedule comes from the pipeline's ``scheduler`` component when one is
    available. Otherwise a linear fallback schedule is built.
    """

    model_name = "rfdiffusion"

    # The block interface is exposed through properties, matching the diffusers
    # ModularPipelineBlocks API.
    @property
    def description(self) -> str:
        return "Sets up the EDM noise schedule matching the original inference sampler."

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("scheduler", description="RFDiffusion3 EDM scheduler"),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "num_inference_steps",
                default=None,
                type_hint=int,
                description="Number of denoising steps (default: use scheduler config)",
            ),
            InputParam("L", required=True, type_hint=int, description="Protein length"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "noise_schedule",
                type_hint=torch.Tensor,
                description="EDM noise schedule [num_timesteps] from high to low noise",
            ),
            OutputParam(
                "num_inference_steps",
                type_hint=int,
                description="Number of inference steps",
            ),
        ]

    @torch.no_grad()
    def __call__(self, components: ModularPipeline, state: PipelineState) -> Tuple[ModularPipeline, PipelineState]:
        """
        Store ``noise_schedule`` and its length (``num_inference_steps``) in the block state.

        The user's ``num_inference_steps`` sizes the fallback schedule. The
        scheduler's ``get_noise_schedule()`` is called without arguments. If a
        requested step count differs from the length of the schedule it
        returns, a warning is logged and the scheduler's schedule is used.

        Raises:
            ValueError: If ``num_inference_steps`` is given and is less than 1.
        """
        block_state = self.get_block_state(state)
        requested_steps = block_state.num_inference_steps
        if requested_steps is not None and requested_steps < 1:
            raise ValueError(f"`num_inference_steps` must be a positive integer, got {requested_steps}")

        scheduler = getattr(components, "scheduler", None)
        if scheduler is not None:
            noise_schedule = scheduler.get_noise_schedule()
            if requested_steps is not None and requested_steps != len(noise_schedule):
                logger.warning(
                    "`num_inference_steps=%d` was requested but the scheduler produced %d steps; "
                    "using the scheduler's schedule.",
                    requested_steps,
                    len(noise_schedule),
                )
        else:
            # Fallback: linear interpolation from 160 * 16 down to 4e-4 * 16.
            # NOTE(review): evenly spaced, not the power-law spacing typical of
            # EDM schedules; confirm this is intended only as a last resort.
            num_steps = 200 if requested_steps is None else requested_steps
            noise_schedule = torch.linspace(160.0 * 16.0, 4e-4 * 16.0, num_steps)

        block_state.noise_schedule = noise_schedule
        block_state.num_inference_steps = len(noise_schedule)
        self.set_block_state(state, block_state)
        return components, state
class RFDiffusionPrepareLatentsStep(ModularPipelineBlocks):
    """
    Prepare initial noised coordinates for RFDiffusion3.

    Matches the original _get_initial_structure:
        noise = c0 * randn(D, L, 3)
        noise[..., is_motif, :] = 0
        X_L = noise + coord_motif
    """

    model_name = "rfdiffusion"

    # The block interface is exposed through properties, matching the diffusers
    # ModularPipelineBlocks API.
    @property
    def description(self) -> str:
        return (
            "Prepares initial coordinates by sampling Gaussian noise scaled by "
            "the first noise schedule value, matching the original sampler."
        )

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("scheduler", description="RFDiffusion3 EDM scheduler"),
            ComponentSpec("transformer", description="RFDiffusion transformer model"),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("generator", type_hint=torch.Generator, description="Random generator for reproducibility"),
            InputParam("diffusion_batch_size", default=1, type_hint=int, description="Number of samples to generate in parallel"),
            InputParam("L", required=True, type_hint=int, description="Protein length"),
            InputParam("motif_mask", required=True, type_hint=torch.Tensor),
            InputParam("motif_xyz", type_hint=torch.Tensor),
            InputParam("noise_schedule", required=True, type_hint=torch.Tensor),
            InputParam("dtype", type_hint=torch.dtype),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam("xyz", type_hint=torch.Tensor, description="Initial noised coordinates [D, L, 3]"),
        ]

    @torch.no_grad()
    def __call__(self, components: ModularPipeline, state: PipelineState) -> Tuple[ModularPipeline, PipelineState]:
        """
        Sample the initial ``xyz`` tensor of shape ``(D, L, 3)`` and store it in the block state.

        Non-motif positions get Gaussian noise scaled by ``noise_schedule[0]``.
        Motif positions get zero noise plus the matching ``motif_xyz`` row.
        """
        block_state = self.get_block_state(state)
        L = block_state.L
        motif_mask = block_state.motif_mask
        motif_xyz = block_state.motif_xyz
        noise_schedule = block_state.noise_schedule
        dtype = block_state.dtype or torch.float32
        generator = block_state.generator
        D = block_state.diffusion_batch_size or 1
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Initial noise scaled by the first noise level (c0), matching the original:
        #   noise = c0 * randn(D, L, 3)
        # torch.randn rejects a generator that lives on a different device, so
        # sample on the generator's device (e.g. a CPU generator on a CUDA
        # machine) and then move the sample to the target device.
        sample_device = generator.device if generator is not None else device
        c0 = noise_schedule[0]
        noise = c0 * torch.randn((D, L, 3), dtype=dtype, device=sample_device, generator=generator).to(device)

        # Zero out noise for motif positions.
        if motif_mask is not None:
            motif_mask = motif_mask.to(device)
            noise[:, motif_mask] = 0.0

        # Build initial coordinates: motif coords + noise.
        coord_motif = torch.zeros((D, L, 3), dtype=dtype, device=device)
        if motif_xyz is not None and motif_mask is not None:
            motif_indices = motif_mask.nonzero(as_tuple=True)[0]
            # Rows of motif_xyz fill motif positions in order. Surplus rows are
            # ignored, and motif positions without a row stay at zero.
            n_fill = min(motif_indices.numel(), motif_xyz.shape[0])
            coord_motif[:, motif_indices[:n_fill]] = motif_xyz[:n_fill].to(dtype=dtype, device=device)

        block_state.xyz = noise + coord_motif
        self.set_block_state(state, block_state)
        return components, state