| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from __future__ import annotations |
|
|
| import importlib.util |
| import math |
| from collections.abc import Sequence |
|
|
| import torch |
| import torch.nn.functional as F |
| from monai.networks.blocks import Convolution, MLPBlock |
| from monai.networks.layers.factories import Pool |
| from monai.utils import ensure_tuple_rep |
| from torch import nn |
|
|
| |
# xformers is optional: it is only needed when a block is built with use_flash_attention=True, so its absence must
# not prevent this module from being imported. When missing, ``xformers`` is bound to None.
has_xformers = importlib.util.find_spec("xformers") is not None
if has_xformers:
    import xformers
    import xformers.ops
else:
    xformers = None

# Public API of this module (the class is expected to be defined later in this module).
__all__ = ["DiffusionModelUNet"]
|
|
|
|
def zero_module(module: nn.Module) -> nn.Module:
    """
    Set every parameter of ``module`` to zero in place and hand the same module back.

    In this file it wraps the last convolution of residual branches (e.g. ``proj_out`` and ``conv2``), so those
    branches start out contributing nothing to the residual sum.

    Args:
        module: module whose parameters are overwritten with zeros.

    Returns:
        The very same ``module`` object, for convenient inline use.
    """
    for param in module.parameters():
        # nn.init.zeros_ performs the in-place fill under torch.no_grad(), so autograd does not record it.
        nn.init.zeros_(param)
    return module
|
|
|
|
class CrossAttention(nn.Module):
    """
    Multi-head attention layer whose keys/values may come from a separate context sequence.

    ``forward(x, context)`` projects queries from ``x`` and keys/values from ``context``. When ``context`` is None it
    falls back to ``x``, i.e. plain self-attention. Heads are folded into the batch axis, so the attention itself is
    computed with batched matrix products.

    Args:
        query_dim: number of channels in the query.
        cross_attention_dim: number of channels in the context. Defaults to ``query_dim`` when None.
        num_attention_heads: number of heads to use for multi-head attention.
        num_head_channels: number of channels in each head.
        dropout: dropout probability applied after the output projection.
        upcast_attention: if True, compute the attention scores and softmax in float32 (non-xformers path only) and
            cast the probabilities back to the input dtype.
        use_flash_attention: if True, use ``xformers.ops.memory_efficient_attention`` (requires xformers).
    """

    def __init__(
        self,
        query_dim: int,
        cross_attention_dim: int | None = None,
        num_attention_heads: int = 8,
        num_head_channels: int = 64,
        dropout: float = 0.0,
        upcast_attention: bool = False,
        use_flash_attention: bool = False,
    ) -> None:
        super().__init__()
        context_dim = query_dim if cross_attention_dim is None else cross_attention_dim
        projected_dim = num_attention_heads * num_head_channels

        self.use_flash_attention = use_flash_attention
        # Scaled dot-product temperature: 1 / sqrt(channels per head).
        self.scale = 1 / math.sqrt(num_head_channels)
        self.num_heads = num_attention_heads
        self.upcast_attention = upcast_attention

        # Bias-free q/k/v projections; the output projection keeps its bias and is followed by dropout.
        self.to_q = nn.Linear(query_dim, projected_dim, bias=False)
        self.to_k = nn.Linear(context_dim, projected_dim, bias=False)
        self.to_v = nn.Linear(context_dim, projected_dim, bias=False)
        self.to_out = nn.Sequential(nn.Linear(projected_dim, query_dim), nn.Dropout(dropout))

    def reshape_heads_to_batch_dim(self, x: torch.Tensor) -> torch.Tensor:
        """
        Split the channel axis into heads and fold the heads into the batch axis.

        (batch, seq, heads * d) -> (batch * heads, seq, d).
        """
        batch_size, seq_len, dim = x.shape
        head_dim = dim // self.num_heads
        per_head = x.reshape(batch_size, seq_len, self.num_heads, head_dim).permute(0, 2, 1, 3)
        return per_head.reshape(batch_size * self.num_heads, seq_len, head_dim)

    def reshape_batch_dim_to_heads(self, x: torch.Tensor) -> torch.Tensor:
        """
        Inverse of ``reshape_heads_to_batch_dim``.

        (batch * heads, seq, d) -> (batch, seq, heads * d).
        """
        stacked, seq_len, head_dim = x.shape
        batch_size = stacked // self.num_heads
        grouped = x.reshape(batch_size, self.num_heads, seq_len, head_dim).permute(0, 2, 1, 3)
        return grouped.reshape(batch_size, seq_len, head_dim * self.num_heads)

    def _memory_efficient_attention_xformers(
        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
    ) -> torch.Tensor:
        """Attention via xformers; inputs are made contiguous first. Uses xformers' own default softmax scale."""
        q, k, v = (tensor.contiguous() for tensor in (query, key, value))
        return xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)

    def _attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
        """Explicit softmax(scale * q @ k^T) @ v on (batch * heads, seq, d) tensors."""
        out_dtype = query.dtype
        if self.upcast_attention:
            query, key = query.float(), key.float()

        # With beta=0 the uninitialised buffer is ignored; the result is alpha * (query @ key^T).
        scores = torch.baddbmm(
            torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
            query,
            key.transpose(-1, -2),
            beta=0,
            alpha=self.scale,
        )
        probs = scores.softmax(dim=-1).to(dtype=out_dtype)
        return torch.bmm(probs, value)

    def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor:
        """
        Args:
            x: query sequence, (batch, seq, query_dim) as implied by ``reshape_heads_to_batch_dim``.
            context: optional key/value sequence with ``cross_attention_dim`` channels; ``x`` is used when None.

        Returns:
            Tensor with the same leading dims as ``x`` and ``query_dim`` channels.
        """
        kv_source = x if context is None else context

        query = self.reshape_heads_to_batch_dim(self.to_q(x))
        key = self.reshape_heads_to_batch_dim(self.to_k(kv_source))
        value = self.reshape_heads_to_batch_dim(self.to_v(kv_source))

        attend = self._memory_efficient_attention_xformers if self.use_flash_attention else self._attention
        merged = self.reshape_batch_dim_to_heads(attend(query, key, value)).to(query.dtype)

        return self.to_out(merged)
|
|
|
|
class BasicTransformerBlock(nn.Module):
    """
    Pre-norm transformer block.

    It applies three residual sub-layers in turn, each to a LayerNorm-ed copy of its input:
    self-attention, attention over ``context``, and a GEGLU MLP.

    Args:
        num_channels: number of channels in the input and output.
        num_attention_heads: number of heads to use for multi-head attention.
        num_head_channels: number of channels in each attention head.
        dropout: dropout probability to use.
        cross_attention_dim: size of the context vector for cross attention.
        upcast_attention: if True, upcast attention operations to full precision.
        use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
    """

    def __init__(
        self,
        num_channels: int,
        num_attention_heads: int,
        num_head_channels: int,
        dropout: float = 0.0,
        cross_attention_dim: int | None = None,
        upcast_attention: bool = False,
        use_flash_attention: bool = False,
    ) -> None:
        super().__init__()
        shared_attention_kwargs = {
            "query_dim": num_channels,
            "num_attention_heads": num_attention_heads,
            "num_head_channels": num_head_channels,
            "dropout": dropout,
            "upcast_attention": upcast_attention,
            "use_flash_attention": use_flash_attention,
        }
        # attn1 gets no cross_attention_dim and forward() never passes it a context, so it is self-attention.
        self.attn1 = CrossAttention(**shared_attention_kwargs)
        # Feed-forward width is 4x the model width.
        self.ff = MLPBlock(hidden_size=num_channels, mlp_dim=num_channels * 4, act="GEGLU", dropout_rate=dropout)
        self.attn2 = CrossAttention(cross_attention_dim=cross_attention_dim, **shared_attention_kwargs)
        self.norm1, self.norm2, self.norm3 = (nn.LayerNorm(num_channels) for _ in range(3))

    def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor:
        """
        Args:
            x: token sequence with ``num_channels`` features in the last axis.
            context: optional conditioning sequence for ``attn2``; when None, ``attn2`` attends over ``x`` itself.

        Returns:
            Tensor with the same shape as ``x``.
        """
        h = self.attn1(self.norm1(x)) + x
        h = self.attn2(self.norm2(h), context=context) + h
        return self.ff(self.norm3(h)) + h
|
|
|
|
class SpatialTransformer(nn.Module):
    """
    Transformer block for image-like data.

    The input is group-normalised and projected with a pointwise convolution. It is then flattened into a
    (batch, tokens, features) sequence and run through a stack of ``BasicTransformerBlock``s. Finally it is reshaped
    back, projected to the input channel count with a zero-initialised pointwise convolution, and added to the input.

    Args:
        spatial_dims: number of spatial dimensions (``forward`` handles 2 and 3).
        in_channels: number of channels in the input and output.
        num_attention_heads: number of heads to use for multi-head attention.
        num_head_channels: number of channels in each attention head.
        num_layers: number of layers of Transformer blocks to use.
        dropout: dropout probability to use.
        norm_num_groups: number of groups for the normalization.
        norm_eps: epsilon for the normalization.
        cross_attention_dim: number of context dimensions to use.
        upcast_attention: if True, upcast attention operations to full precision.
        use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        num_attention_heads: int,
        num_head_channels: int,
        num_layers: int = 1,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        cross_attention_dim: int | None = None,
        upcast_attention: bool = False,
        use_flash_attention: bool = False,
    ) -> None:
        super().__init__()
        self.spatial_dims = spatial_dims
        self.in_channels = in_channels
        hidden_dim = num_attention_heads * num_head_channels

        self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=norm_eps, affine=True)

        # Pointwise (kernel 1, no padding) projection into the transformer width.
        self.proj_in = Convolution(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=hidden_dim,
            strides=1,
            kernel_size=1,
            padding=0,
            conv_only=True,
        )

        self.transformer_blocks = nn.ModuleList(
            BasicTransformerBlock(
                num_channels=hidden_dim,
                num_attention_heads=num_attention_heads,
                num_head_channels=num_head_channels,
                dropout=dropout,
                cross_attention_dim=cross_attention_dim,
                upcast_attention=upcast_attention,
                use_flash_attention=use_flash_attention,
            )
            for _ in range(num_layers)
        )

        # Zero-initialised, so the transformer branch initially adds nothing to the residual.
        self.proj_out = zero_module(
            Convolution(
                spatial_dims=spatial_dims,
                in_channels=hidden_dim,
                out_channels=in_channels,
                strides=1,
                kernel_size=1,
                padding=0,
                conv_only=True,
            )
        )

    def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor:
        """
        Args:
            x: feature map (batch, channels, H, W) for 2D or (batch, channels, H, W, D) for 3D.
            context: optional conditioning sequence passed to the cross-attention of every transformer block.

        Returns:
            ``x`` plus the transformer branch output (same shape as ``x``).
        """
        skip = x

        batch = channel = height = width = depth = -1
        if self.spatial_dims == 2:
            batch, channel, height, width = x.shape
        if self.spatial_dims == 3:
            batch, channel, height, width, depth = x.shape

        x = self.proj_in(self.norm(x))
        features = x.shape[1]

        # Move channels last and flatten the spatial axes into one token axis: (batch, tokens, features).
        if self.spatial_dims == 2:
            x = x.permute(0, 2, 3, 1).reshape(batch, height * width, features)
        if self.spatial_dims == 3:
            x = x.permute(0, 2, 3, 4, 1).reshape(batch, height * width * depth, features)

        for block in self.transformer_blocks:
            x = block(x, context=context)

        # Restore the spatial axes and move channels back to axis 1.
        if self.spatial_dims == 2:
            x = x.reshape(batch, height, width, features).permute(0, 3, 1, 2).contiguous()
        if self.spatial_dims == 3:
            x = x.reshape(batch, height, width, depth, features).permute(0, 4, 1, 2, 3).contiguous()

        return self.proj_out(x) + skip
|
|
|
|
class AttentionBlock(nn.Module):
    """
    Self-attention block in which every spatial position of a 2D or 3D feature map attends to every other position.

    The input is group-normalised and flattened to a (batch, positions, channels) sequence. Three linear layers
    project it to queries/keys/values, which go through multi-head scaled dot-product attention. The result is
    reshaped back to the input layout and added to the (un-normalised) input.

    Args:
        spatial_dims: number of spatial dimensions (``forward`` handles 2 and 3).
        num_channels: number of input channels.
        num_head_channels: number of channels in each attention head. If None, a single head spanning all channels
            is used.
        norm_num_groups: number of groups for the group normalisation layer. ``num_channels`` must be divisible by it.
        norm_eps: epsilon value to use for the normalisation.
        use_flash_attention: if True, use xformers' memory-efficient attention instead of the explicit softmax path.
    """

    def __init__(
        self,
        spatial_dims: int,
        num_channels: int,
        num_head_channels: int | None = None,
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        use_flash_attention: bool = False,
    ) -> None:
        super().__init__()
        self.use_flash_attention = use_flash_attention
        self.spatial_dims = spatial_dims
        self.num_channels = num_channels

        heads = 1 if num_head_channels is None else num_channels // num_head_channels
        self.num_heads = heads
        # Softmax temperature: 1 / sqrt(channels per head).
        self.scale = 1 / math.sqrt(num_channels / heads)

        self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=num_channels, eps=norm_eps, affine=True)

        # Created in q, k, v order, then assigned to to_q, to_k, to_v respectively.
        self.to_q, self.to_k, self.to_v = (nn.Linear(num_channels, num_channels) for _ in range(3))

        # NOTE(review): proj_attn is constructed (and holds parameters) but forward() never applies it.
        self.proj_attn = nn.Linear(num_channels, num_channels)

    def reshape_heads_to_batch_dim(self, x: torch.Tensor) -> torch.Tensor:
        """(batch, seq, heads * d) -> (batch * heads, seq, d)."""
        batch_size, seq_len, dim = x.shape
        head_dim = dim // self.num_heads
        per_head = x.reshape(batch_size, seq_len, self.num_heads, head_dim).permute(0, 2, 1, 3)
        return per_head.reshape(batch_size * self.num_heads, seq_len, head_dim)

    def reshape_batch_dim_to_heads(self, x: torch.Tensor) -> torch.Tensor:
        """(batch * heads, seq, d) -> (batch, seq, heads * d); inverse of ``reshape_heads_to_batch_dim``."""
        stacked, seq_len, head_dim = x.shape
        batch_size = stacked // self.num_heads
        grouped = x.reshape(batch_size, self.num_heads, seq_len, head_dim).permute(0, 2, 1, 3)
        return grouped.reshape(batch_size, seq_len, head_dim * self.num_heads)

    def _memory_efficient_attention_xformers(
        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
    ) -> torch.Tensor:
        """Attention via xformers (contiguous inputs, xformers' default softmax scale)."""
        q, k, v = (tensor.contiguous() for tensor in (query, key, value))
        return xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)

    def _attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
        """Explicit softmax(scale * q @ k^T) @ v in the input dtype."""
        # beta=0: the uninitialised buffer is ignored and only alpha * (query @ key^T) is kept.
        scores = torch.baddbmm(
            torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
            query,
            key.transpose(-1, -2),
            beta=0,
            alpha=self.scale,
        )
        return torch.bmm(scores.softmax(dim=-1), value)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (batch, channels, H, W) for 2D or (batch, channels, H, W, D) for 3D.

        Returns:
            Tensor with the same shape as ``x``: the attention output plus ``x``.
        """
        residual = x

        batch = channel = height = width = depth = -1
        if self.spatial_dims == 2:
            batch, channel, height, width = x.shape
        if self.spatial_dims == 3:
            batch, channel, height, width, depth = x.shape

        tokens = self.norm(x)
        # Flatten the spatial axes and put channels last: (batch, positions, channels).
        if self.spatial_dims == 2:
            tokens = tokens.view(batch, channel, height * width).transpose(1, 2)
        if self.spatial_dims == 3:
            tokens = tokens.view(batch, channel, height * width * depth).transpose(1, 2)

        query, key, value = (
            self.reshape_heads_to_batch_dim(projection(tokens)) for projection in (self.to_q, self.to_k, self.to_v)
        )

        if self.use_flash_attention:
            attended = self._memory_efficient_attention_xformers(query, key, value)
        else:
            attended = self._attention(query, key, value)

        attended = self.reshape_batch_dim_to_heads(attended).to(query.dtype)

        # Back to channels-first with the original spatial layout.
        if self.spatial_dims == 2:
            attended = attended.transpose(-1, -2).reshape(batch, channel, height, width)
        if self.spatial_dims == 3:
            attended = attended.transpose(-1, -2).reshape(batch, channel, height, width, depth)

        return attended + residual
|
|
|
|
def get_timestep_embedding(timesteps: torch.Tensor, embedding_dim: int, max_period: int = 10000) -> torch.Tensor:
    """
    Build sinusoidal timestep embeddings following Ho et al. "Denoising Diffusion Probabilistic Models"
    https://arxiv.org/abs/2006.11239.

    With ``h = embedding_dim // 2`` and frequencies ``f_j = exp(-log(max_period) * j / h)`` for ``j < h``, row ``i``
    is ``[cos(t_i * f_0), ..., cos(t_i * f_{h-1}), sin(t_i * f_0), ..., sin(t_i * f_{h-1})]``. One zero column is
    appended when ``embedding_dim`` is odd.

    Args:
        timesteps: a 1-D Tensor of N indices, one per batch element.
        embedding_dim: the dimension of the output.
        max_period: controls the minimum frequency of the embeddings.

    Returns:
        float32 tensor of shape (N, embedding_dim) on the device of ``timesteps``.

    Raises:
        ValueError: if ``timesteps`` is not 1-D.
    """
    if timesteps.ndim != 1:
        raise ValueError("Timesteps should be a 1d-array")

    half = embedding_dim // 2
    steps = torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device)
    frequencies = torch.exp(-math.log(max_period) * steps / half)

    phases = timesteps.float().unsqueeze(1) * frequencies.unsqueeze(0)
    embedding = torch.cat((phases.cos(), phases.sin()), dim=-1)

    if embedding_dim % 2:
        # Odd width: append a single zero column on the right of the last axis.
        embedding = torch.nn.functional.pad(embedding, (0, 1, 0, 0))

    return embedding
|
|
|
|
class Downsample(nn.Module):
    """
    Downsampling layer with stride 2.

    Args:
        spatial_dims: number of spatial dimensions.
        num_channels: number of input channels.
        use_conv: if True, downsample with a stride-2 3x3 convolution; otherwise with 2x2 average pooling (stride 2).
            Pooling cannot change the channel count, so ``out_channels`` must then equal ``num_channels``.
        out_channels: number of output channels. Defaults to ``num_channels``.
        padding: implicit zero-padding of the convolution on both sides of each spatial dimension.

    Raises:
        ValueError: if ``use_conv`` is False and ``out_channels`` differs from ``num_channels``.
    """

    def __init__(
        self, spatial_dims: int, num_channels: int, use_conv: bool, out_channels: int | None = None, padding: int = 1
    ) -> None:
        super().__init__()
        self.num_channels = num_channels
        self.out_channels = out_channels or num_channels
        self.use_conv = use_conv

        if not use_conv:
            if self.out_channels != self.num_channels:
                raise ValueError("num_channels and out_channels must be equal when use_conv=False")
            self.op = Pool[Pool.AVG, spatial_dims](kernel_size=2, stride=2)
            return

        self.op = Convolution(
            spatial_dims=spatial_dims,
            in_channels=self.num_channels,
            out_channels=self.out_channels,
            strides=2,
            kernel_size=3,
            padding=padding,
            conv_only=True,
        )

    def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor:
        """
        Args:
            x: feature map with ``num_channels`` channels on axis 1.
            emb: ignored; accepted so this layer can be called like a ``ResnetBlock`` downsampler, i.e. ``(x, temb)``.

        Raises:
            ValueError: if ``x`` does not have ``num_channels`` channels.
        """
        del emb
        channels_in = x.shape[1]
        if channels_in == self.num_channels:
            return self.op(x)
        raise ValueError(
            f"Input number of channels ({channels_in}) is not equal to expected number of channels "
            f"({self.num_channels})"
        )
|
|
|
|
class Upsample(nn.Module):
    """
    Upsampling layer: 2x nearest-neighbour interpolation, optionally followed by a convolution.

    Args:
        spatial_dims: number of spatial dimensions.
        num_channels: number of input channels.
        use_conv: if True, apply a stride-1 3x3 convolution after the interpolation; otherwise only interpolate.
        out_channels: number of output channels of the convolution. Defaults to ``num_channels``.
        padding: implicit zero-padding of the convolution on both sides of each spatial dimension.
    """

    def __init__(
        self, spatial_dims: int, num_channels: int, use_conv: bool, out_channels: int | None = None, padding: int = 1
    ) -> None:
        super().__init__()
        self.num_channels = num_channels
        self.out_channels = out_channels or num_channels
        self.use_conv = use_conv

        self.conv = (
            Convolution(
                spatial_dims=spatial_dims,
                in_channels=self.num_channels,
                out_channels=self.out_channels,
                strides=1,
                kernel_size=3,
                padding=padding,
                conv_only=True,
            )
            if use_conv
            else None
        )

    def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor:
        """
        Args:
            x: feature map with ``num_channels`` channels on axis 1.
            emb: ignored; accepted so the layer shares the ``(x, emb)`` call signature of the other blocks.

        Returns:
            Tensor with every spatial dimension doubled, in the dtype of ``x``.

        Raises:
            ValueError: if ``x`` does not have ``num_channels`` channels.
        """
        del emb
        if x.shape[1] != self.num_channels:
            raise ValueError("Input channels should be equal to num_channels")

        source_dtype = x.dtype
        # bfloat16 inputs are interpolated in float32 and cast back afterwards
        # (presumably because nearest interpolation lacks bf16 support in some PyTorch versions).
        round_trip = source_dtype == torch.bfloat16

        upsampled = F.interpolate(x.to(torch.float32) if round_trip else x, scale_factor=2.0, mode="nearest")
        if round_trip:
            upsampled = upsampled.to(source_dtype)

        return self.conv(upsampled) if self.use_conv else upsampled
|
|
|
|
class ResnetBlock(nn.Module):
    """
    Residual block with timestep conditioning.

    The main branch runs:
        GroupNorm -> SiLU -> (optional 2x up/down-sampling) -> dilated 3x3 conv
        -> + projected timestep embedding
        -> GroupNorm -> SiLU -> dilated 3x3 conv (zero-initialised).
    It is added to ``skip_connection(x)``, which is the identity when the channel count is unchanged and a 1x1
    convolution otherwise. Both 3x3 convolutions are padded so they preserve spatial size, which the residual sum
    requires.

    Args:
        spatial_dims: The number of spatial dimensions.
        in_channels: number of input channels.
        temb_channels: number of timestep embedding channels.
        out_channels: number of output channels. Defaults to ``in_channels``.
        up: if True, performs nearest-neighbour upsampling (no convolution) before ``conv1``.
        down: if True (and ``up`` is False), performs average-pooling downsampling before ``conv1``.
        norm_num_groups: number of groups for the group normalization.
        norm_eps: epsilon for the group normalization.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        temb_channels: int,
        out_channels: int | None = None,
        up: bool = False,
        down: bool = False,
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
    ) -> None:
        super().__init__()
        self.spatial_dims = spatial_dims
        self.channels = in_channels
        self.emb_channels = temb_channels
        self.out_channels = out_channels or in_channels
        self.up = up
        self.down = down

        # A kernel of size k with dilation d spans d * (k - 1) + 1 positions, so "same" padding is d * (k - 1) // 2.
        # A fixed padding of 1 would shrink each spatial dim by 4 in conv1 and by 2 in conv2, and the residual sum
        # in forward() would then fail with a shape mismatch.
        kernel_size = 3
        conv1_dilation = 3
        conv2_dilation = 2

        self.norm1 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=norm_eps, affine=True)
        self.nonlinearity = nn.SiLU()
        self.conv1 = Convolution(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=self.out_channels,
            strides=1,
            kernel_size=kernel_size,
            padding=conv1_dilation * (kernel_size - 1) // 2,
            conv_only=True,
            dilation=conv1_dilation,
        )

        self.upsample = self.downsample = None
        if self.up:
            self.upsample = Upsample(spatial_dims, in_channels, use_conv=False)
        elif down:
            self.downsample = Downsample(spatial_dims, in_channels, use_conv=False)

        self.time_emb_proj = nn.Linear(temb_channels, self.out_channels)

        self.norm2 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=self.out_channels, eps=norm_eps, affine=True)
        # Zero-initialised, so the block starts out as (approximately) skip_connection(x).
        self.conv2 = zero_module(
            Convolution(
                spatial_dims=spatial_dims,
                in_channels=self.out_channels,
                out_channels=self.out_channels,
                strides=1,
                kernel_size=kernel_size,
                padding=conv2_dilation * (kernel_size - 1) // 2,
                conv_only=True,
                dilation=conv2_dilation,
            )
        )

        if self.out_channels == in_channels:
            self.skip_connection = nn.Identity()
        else:
            # 1x1 projection so the skip path matches out_channels.
            self.skip_connection = Convolution(
                spatial_dims=spatial_dims,
                in_channels=in_channels,
                out_channels=self.out_channels,
                strides=1,
                kernel_size=1,
                padding=0,
                conv_only=True,
                dilation=1,
            )

    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: feature map with ``in_channels`` channels (2D or 3D spatial layout).
            emb: timestep embedding. It is projected by a Linear layer to ``out_channels`` values per sample and
                then broadcast spatially, so it is presumably (batch, temb_channels).

        Returns:
            ``skip_connection(x) + h`` with ``out_channels`` channels. The spatial size is that of ``x`` after the
            optional up/down-sampling.
        """
        h = x
        h = self.norm1(h)
        h = self.nonlinearity(h)

        if self.upsample is not None:
            # NOTE(review): presumably a workaround for a nearest-interpolation issue with large batches - confirm.
            if h.shape[0] >= 64:
                x = x.contiguous()
                h = h.contiguous()
            x = self.upsample(x)
            h = self.upsample(h)
        elif self.downsample is not None:
            x = self.downsample(x)
            h = self.downsample(h)

        h = self.conv1(h)

        # Broadcast the per-sample timestep projection over every spatial position.
        if self.spatial_dims == 2:
            temb = self.time_emb_proj(self.nonlinearity(emb))[:, :, None, None]
        else:
            temb = self.time_emb_proj(self.nonlinearity(emb))[:, :, None, None, None]
        h = h + temb

        h = self.norm2(h)
        h = self.nonlinearity(h)
        h = self.conv2(h)

        return self.skip_connection(x) + h
|
|
|
|
class DownBlock(nn.Module):
    """
    UNet encoder stage: a stack of residual blocks followed by an optional downsampling layer.

    Args:
        spatial_dims: The number of spatial dimensions.
        in_channels: number of input channels (consumed by the first residual block only).
        out_channels: number of output channels of every residual block and of the downsampler.
        temb_channels: number of timestep embedding channels.
        num_res_blocks: number of residual blocks.
        norm_num_groups: number of groups for the group normalization.
        norm_eps: epsilon for the group normalization.
        add_downsample: if True, append a downsampling layer after the residual blocks.
        resblock_updown: if True, downsample with ``ResnetBlock(down=True)``; otherwise with a strided
            convolutional ``Downsample``.
        downsample_padding: padding of the convolutional ``Downsample`` (unused when ``resblock_updown`` is True).
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        num_res_blocks: int = 1,
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        add_downsample: bool = True,
        resblock_updown: bool = False,
        downsample_padding: int = 1,
    ) -> None:
        super().__init__()
        self.resblock_updown = resblock_updown

        self.resnets = nn.ModuleList(
            [
                ResnetBlock(
                    spatial_dims=spatial_dims,
                    in_channels=out_channels if idx > 0 else in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                )
                for idx in range(num_res_blocks)
            ]
        )

        if not add_downsample:
            self.downsampler = None
        elif resblock_updown:
            self.downsampler = ResnetBlock(
                spatial_dims=spatial_dims,
                in_channels=out_channels,
                out_channels=out_channels,
                temb_channels=temb_channels,
                norm_num_groups=norm_num_groups,
                norm_eps=norm_eps,
                down=True,
            )
        else:
            self.downsampler = Downsample(
                spatial_dims=spatial_dims,
                num_channels=out_channels,
                use_conv=True,
                out_channels=out_channels,
                padding=downsample_padding,
            )

    def forward(
        self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None
    ) -> tuple[torch.Tensor, list[torch.Tensor]]:
        """
        Args:
            hidden_states: input feature map.
            temb: timestep embedding forwarded to every residual block (and to the downsampler).
            context: ignored; accepted so all down blocks share one call signature.

        Returns:
            The final hidden states, plus a list of intermediate outputs: one per residual block, and the downsampler
            output when present (presumably consumed as skip connections by the decoder).
        """
        del context
        collected: list[torch.Tensor] = []

        for resnet in self.resnets:
            hidden_states = resnet(hidden_states, temb)
            collected.append(hidden_states)

        if self.downsampler is not None:
            hidden_states = self.downsampler(hidden_states, temb)
            collected.append(hidden_states)

        return hidden_states, collected
|
|
|
|
class AttnDownBlock(nn.Module):
    """
    UNet encoder stage: (residual block -> self-attention block) pairs followed by an optional downsampling layer.

    Args:
        spatial_dims: The number of spatial dimensions.
        in_channels: number of input channels (consumed by the first residual block only).
        out_channels: number of output channels.
        temb_channels: number of timestep embedding channels.
        num_res_blocks: number of residual/attention pairs.
        norm_num_groups: number of groups for the group normalization.
        norm_eps: epsilon for the group normalization.
        add_downsample: if True, append a downsampling layer.
        resblock_updown: if True, downsample with ``ResnetBlock(down=True)``; otherwise with a strided
            convolutional ``Downsample``.
        downsample_padding: padding of the convolutional ``Downsample``.
        num_head_channels: number of channels in each attention head.
        use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        num_res_blocks: int = 1,
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        add_downsample: bool = True,
        resblock_updown: bool = False,
        downsample_padding: int = 1,
        num_head_channels: int = 1,
        use_flash_attention: bool = False,
    ) -> None:
        super().__init__()
        self.resblock_updown = resblock_updown

        resnet_list: list[nn.Module] = []
        attention_list: list[nn.Module] = []
        # Residual and attention blocks are created interleaved (resnet_0, attn_0, resnet_1, ...).
        for idx in range(num_res_blocks):
            block_in = in_channels if idx == 0 else out_channels
            resnet_list.append(
                ResnetBlock(
                    spatial_dims=spatial_dims,
                    in_channels=block_in,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                )
            )
            attention_list.append(
                AttentionBlock(
                    spatial_dims=spatial_dims,
                    num_channels=out_channels,
                    num_head_channels=num_head_channels,
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                    use_flash_attention=use_flash_attention,
                )
            )

        # Registered attentions first, then resnets (this fixes the parameter/state_dict ordering).
        self.attentions = nn.ModuleList(attention_list)
        self.resnets = nn.ModuleList(resnet_list)

        if not add_downsample:
            self.downsampler = None
        elif resblock_updown:
            self.downsampler = ResnetBlock(
                spatial_dims=spatial_dims,
                in_channels=out_channels,
                out_channels=out_channels,
                temb_channels=temb_channels,
                norm_num_groups=norm_num_groups,
                norm_eps=norm_eps,
                down=True,
            )
        else:
            self.downsampler = Downsample(
                spatial_dims=spatial_dims,
                num_channels=out_channels,
                use_conv=True,
                out_channels=out_channels,
                padding=downsample_padding,
            )

    def forward(
        self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None
    ) -> tuple[torch.Tensor, list[torch.Tensor]]:
        """
        Args:
            hidden_states: input feature map.
            temb: timestep embedding forwarded to the residual blocks and the downsampler.
            context: ignored; accepted so all down blocks share one call signature.

        Returns:
            The final hidden states, plus a list with the output of every residual/attention pair and, when present,
            the downsampler output.
        """
        del context
        collected: list[torch.Tensor] = []

        for resnet, attn in zip(self.resnets, self.attentions):
            hidden_states = attn(resnet(hidden_states, temb))
            collected.append(hidden_states)

        if self.downsampler is not None:
            hidden_states = self.downsampler(hidden_states, temb)
            collected.append(hidden_states)

        return hidden_states, collected
|
|
|
|
class CrossAttnDownBlock(nn.Module):
    """
    UNet encoder stage: (residual block -> spatial transformer) pairs followed by an optional downsampling layer.

    Args:
        spatial_dims: number of spatial dimensions.
        in_channels: number of input channels (consumed by the first residual block only).
        out_channels: number of output channels.
        temb_channels: number of timestep embedding channels.
        num_res_blocks: number of residual/transformer pairs.
        norm_num_groups: number of groups for the group normalization.
        norm_eps: epsilon for the group normalization.
        add_downsample: if True, append a downsampling layer.
        resblock_updown: if True, downsample with ``ResnetBlock(down=True)``; otherwise with a strided
            convolutional ``Downsample``.
        downsample_padding: padding of the convolutional ``Downsample``.
        num_head_channels: number of channels in each attention head (heads = out_channels // num_head_channels).
        transformer_num_layers: number of layers of Transformer blocks to use.
        cross_attention_dim: number of context dimensions to use.
        upcast_attention: if True, upcast attention operations to full precision.
        use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        num_res_blocks: int = 1,
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        add_downsample: bool = True,
        resblock_updown: bool = False,
        downsample_padding: int = 1,
        num_head_channels: int = 1,
        transformer_num_layers: int = 1,
        cross_attention_dim: int | None = None,
        upcast_attention: bool = False,
        use_flash_attention: bool = False,
    ) -> None:
        super().__init__()
        self.resblock_updown = resblock_updown

        resnet_list: list[nn.Module] = []
        transformer_list: list[nn.Module] = []
        # Created interleaved (resnet_0, transformer_0, resnet_1, ...).
        for idx in range(num_res_blocks):
            block_in = in_channels if idx == 0 else out_channels
            resnet_list.append(
                ResnetBlock(
                    spatial_dims=spatial_dims,
                    in_channels=block_in,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                )
            )
            transformer_list.append(
                SpatialTransformer(
                    spatial_dims=spatial_dims,
                    in_channels=out_channels,
                    num_attention_heads=out_channels // num_head_channels,
                    num_head_channels=num_head_channels,
                    num_layers=transformer_num_layers,
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                    cross_attention_dim=cross_attention_dim,
                    upcast_attention=upcast_attention,
                    use_flash_attention=use_flash_attention,
                )
            )

        # Registered attentions first, then resnets (this fixes the parameter/state_dict ordering).
        self.attentions = nn.ModuleList(transformer_list)
        self.resnets = nn.ModuleList(resnet_list)

        if not add_downsample:
            self.downsampler = None
        elif resblock_updown:
            self.downsampler = ResnetBlock(
                spatial_dims=spatial_dims,
                in_channels=out_channels,
                out_channels=out_channels,
                temb_channels=temb_channels,
                norm_num_groups=norm_num_groups,
                norm_eps=norm_eps,
                down=True,
            )
        else:
            self.downsampler = Downsample(
                spatial_dims=spatial_dims,
                num_channels=out_channels,
                use_conv=True,
                out_channels=out_channels,
                padding=downsample_padding,
            )

    def forward(
        self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None
    ) -> tuple[torch.Tensor, list[torch.Tensor]]:
        """
        Args:
            hidden_states: input feature map.
            temb: timestep embedding forwarded to the residual blocks and the downsampler.
            context: optional conditioning sequence forwarded to every spatial transformer.

        Returns:
            The final hidden states, plus a list with the output of every residual/transformer pair and, when
            present, the downsampler output.
        """
        collected: list[torch.Tensor] = []

        for resnet, transformer in zip(self.resnets, self.attentions):
            hidden_states = transformer(resnet(hidden_states, temb), context=context)
            collected.append(hidden_states)

        if self.downsampler is not None:
            hidden_states = self.downsampler(hidden_states, temb)
            collected.append(hidden_states)

        return hidden_states, collected
|
|
|
|
class AttnMidBlock(nn.Module):
    """
    UNet bottleneck: residual block -> self-attention block -> residual block, all at ``in_channels`` width.

    Args:
        spatial_dims: The number of spatial dimensions.
        in_channels: number of input (and output) channels.
        temb_channels: number of timestep embedding channels.
        norm_num_groups: number of groups for the group normalization.
        norm_eps: epsilon for the group normalization.
        num_head_channels: number of channels in each attention head.
        use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        temb_channels: int,
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        num_head_channels: int = 1,
        use_flash_attention: bool = False,
    ) -> None:
        super().__init__()
        # Both residual blocks keep the channel count unchanged.
        resnet_kwargs = {
            "spatial_dims": spatial_dims,
            "in_channels": in_channels,
            "out_channels": in_channels,
            "temb_channels": temb_channels,
            "norm_num_groups": norm_num_groups,
            "norm_eps": norm_eps,
        }

        self.resnet_1 = ResnetBlock(**resnet_kwargs)
        self.attention = AttentionBlock(
            spatial_dims=spatial_dims,
            num_channels=in_channels,
            num_head_channels=num_head_channels,
            norm_num_groups=norm_num_groups,
            norm_eps=norm_eps,
            use_flash_attention=use_flash_attention,
        )
        self.resnet_2 = ResnetBlock(**resnet_kwargs)

    def forward(
        self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None
    ) -> torch.Tensor:
        """
        Args:
            hidden_states: input feature map with ``in_channels`` channels.
            temb: timestep embedding forwarded to both residual blocks.
            context: ignored; accepted so this block shares the mid-block call signature.

        Returns:
            Tensor with ``in_channels`` channels.
        """
        del context
        hidden_states = self.attention(self.resnet_1(hidden_states, temb))
        return self.resnet_2(hidden_states, temb)
|
|
|
|
class CrossAttnMidBlock(nn.Module):
    """
    Unet's mid block: resnet -> spatial transformer (cross-attention) -> resnet, all keeping ``in_channels``.

    Args:
        spatial_dims: The number of spatial dimensions.
        in_channels: number of input (and output) channels.
        temb_channels: number of timestep embedding channels.
        norm_num_groups: number of groups for the group normalization.
        norm_eps: epsilon for the group normalization.
        num_head_channels: number of channels in each attention head; the transformer uses
            ``in_channels // num_head_channels`` heads.
        transformer_num_layers: number of layers of Transformer blocks to use.
        cross_attention_dim: number of context dimensions to use.
        upcast_attention: if True, upcast attention operations to full precision.
        use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        temb_channels: int,
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        num_head_channels: int = 1,
        transformer_num_layers: int = 1,
        cross_attention_dim: int | None = None,
        upcast_attention: bool = False,
        use_flash_attention: bool = False,
    ) -> None:
        super().__init__()

        # Both resnets map in_channels -> in_channels, so they share one configuration.
        resnet_config = dict(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=in_channels,
            temb_channels=temb_channels,
            norm_num_groups=norm_num_groups,
            norm_eps=norm_eps,
        )

        # Submodules are built and registered in the order resnet_1, attention, resnet_2.
        self.resnet_1 = ResnetBlock(**resnet_config)
        self.attention = SpatialTransformer(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            num_attention_heads=in_channels // num_head_channels,
            num_head_channels=num_head_channels,
            num_layers=transformer_num_layers,
            norm_num_groups=norm_num_groups,
            norm_eps=norm_eps,
            cross_attention_dim=cross_attention_dim,
            upcast_attention=upcast_attention,
            use_flash_attention=use_flash_attention,
        )
        self.resnet_2 = ResnetBlock(**resnet_config)

    def forward(
        self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None
    ) -> torch.Tensor:
        """
        Args:
            hidden_states: input feature map.
            temb: timestep embedding used by both resnets.
            context: optional conditioning passed to the spatial transformer.
        """
        features = self.resnet_1(hidden_states, temb)
        features = self.attention(features, context=context)
        return self.resnet_2(features, temb)
|
|
|
|
class UpBlock(nn.Module):
    """
    Unet's up block: resnet blocks followed by an optional upsampler.

    Before each resnet, the most recently stored skip connection is concatenated onto the
    hidden state along dim 1.

    Args:
        spatial_dims: The number of spatial dimensions.
        in_channels: number of channels of the skip connection consumed by the last resnet.
        prev_output_channel: number of channels of the incoming hidden state (input of the first resnet).
        out_channels: number of output channels.
        temb_channels: number of timestep embedding channels.
        num_res_blocks: number of residual blocks.
        norm_num_groups: number of groups for the group normalization.
        norm_eps: epsilon for the group normalization.
        add_upsample: if True add an upsample block.
        resblock_updown: if True use residual blocks for upsampling.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        prev_output_channel: int,
        out_channels: int,
        temb_channels: int,
        num_res_blocks: int = 1,
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        add_upsample: bool = True,
        resblock_updown: bool = False,
    ) -> None:
        super().__init__()
        self.resblock_updown = resblock_updown

        last_idx = num_res_blocks - 1
        self.resnets = nn.ModuleList(
            [
                ResnetBlock(
                    spatial_dims=spatial_dims,
                    # hidden-state channels + skip-connection channels (concatenated in forward)
                    in_channels=(out_channels if idx > 0 else prev_output_channel)
                    + (in_channels if idx == last_idx else out_channels),
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                )
                for idx in range(num_res_blocks)
            ]
        )

        if not add_upsample:
            self.upsampler = None
        elif resblock_updown:
            self.upsampler = ResnetBlock(
                spatial_dims=spatial_dims,
                in_channels=out_channels,
                out_channels=out_channels,
                temb_channels=temb_channels,
                norm_num_groups=norm_num_groups,
                norm_eps=norm_eps,
                up=True,
            )
        else:
            self.upsampler = Upsample(
                spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels
            )

    def forward(
        self,
        hidden_states: torch.Tensor,
        res_hidden_states_list: list[torch.Tensor],
        temb: torch.Tensor,
        context: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states: input feature map.
            res_hidden_states_list: stored skip connections; one is taken from the end per resnet.
                The caller's sequence is not modified.
            temb: timestep embedding.
            context: unused by this block.
        """
        del context
        pending = list(res_hidden_states_list)
        for resnet in self.resnets:
            # skip connections are consumed last-in, first-out
            merged = torch.cat([hidden_states, pending.pop()], dim=1)
            hidden_states = resnet(merged, temb)

        if self.upsampler is not None:
            hidden_states = self.upsampler(hidden_states, temb)

        return hidden_states
|
|
|
|
class AttnUpBlock(nn.Module):
    """
    Unet's up block: resnet blocks each followed by self-attention, then an optional upsampler.

    Before each resnet, the most recently stored skip connection is concatenated onto the
    hidden state along dim 1.

    Args:
        spatial_dims: The number of spatial dimensions.
        in_channels: number of channels of the skip connection consumed by the last resnet.
        prev_output_channel: number of channels of the incoming hidden state (input of the first resnet).
        out_channels: number of output channels.
        temb_channels: number of timestep embedding channels.
        num_res_blocks: number of residual blocks.
        norm_num_groups: number of groups for the group normalization.
        norm_eps: epsilon for the group normalization.
        add_upsample: if True add an upsample block.
        resblock_updown: if True use residual blocks for upsampling.
        num_head_channels: number of channels in each attention head.
        use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        prev_output_channel: int,
        out_channels: int,
        temb_channels: int,
        num_res_blocks: int = 1,
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        add_upsample: bool = True,
        resblock_updown: bool = False,
        num_head_channels: int = 1,
        use_flash_attention: bool = False,
    ) -> None:
        super().__init__()
        self.resblock_updown = resblock_updown

        resnet_list = []
        attention_list = []
        for idx in range(num_res_blocks):
            # channels of the hidden state entering this resnet, before concatenation
            hidden_ch = out_channels if idx > 0 else prev_output_channel
            # channels of the skip connection concatenated onto it
            skip_ch = in_channels if idx == num_res_blocks - 1 else out_channels

            resnet_list.append(
                ResnetBlock(
                    spatial_dims=spatial_dims,
                    in_channels=hidden_ch + skip_ch,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                )
            )
            attention_list.append(
                AttentionBlock(
                    spatial_dims=spatial_dims,
                    num_channels=out_channels,
                    num_head_channels=num_head_channels,
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                    use_flash_attention=use_flash_attention,
                )
            )

        self.resnets = nn.ModuleList(resnet_list)
        self.attentions = nn.ModuleList(attention_list)

        if not add_upsample:
            self.upsampler = None
        elif resblock_updown:
            self.upsampler = ResnetBlock(
                spatial_dims=spatial_dims,
                in_channels=out_channels,
                out_channels=out_channels,
                temb_channels=temb_channels,
                norm_num_groups=norm_num_groups,
                norm_eps=norm_eps,
                up=True,
            )
        else:
            self.upsampler = Upsample(
                spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels
            )

    def forward(
        self,
        hidden_states: torch.Tensor,
        res_hidden_states_list: list[torch.Tensor],
        temb: torch.Tensor,
        context: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states: input feature map.
            res_hidden_states_list: stored skip connections; one is taken from the end per resnet.
                The caller's sequence is not modified.
            temb: timestep embedding.
            context: unused by this block (self-attention only).
        """
        del context
        pending = list(res_hidden_states_list)
        for resnet, attn in zip(self.resnets, self.attentions):
            # skip connections are consumed last-in, first-out
            hidden_states = torch.cat([hidden_states, pending.pop()], dim=1)
            hidden_states = attn(resnet(hidden_states, temb))

        if self.upsampler is not None:
            hidden_states = self.upsampler(hidden_states, temb)

        return hidden_states
|
|
|
|
class CrossAttnUpBlock(nn.Module):
    """
    Unet's up block: resnet blocks each followed by a spatial transformer (cross-attention on
    ``context``), then an optional upsampler.

    Before each resnet, the most recently stored skip connection is concatenated onto the
    hidden state along dim 1.

    Args:
        spatial_dims: The number of spatial dimensions.
        in_channels: number of channels of the skip connection consumed by the last resnet.
        prev_output_channel: number of channels of the incoming hidden state (input of the first resnet).
        out_channels: number of output channels.
        temb_channels: number of timestep embedding channels.
        num_res_blocks: number of residual blocks.
        norm_num_groups: number of groups for the group normalization.
        norm_eps: epsilon for the group normalization.
        add_upsample: if True add an upsample block.
        resblock_updown: if True use residual blocks for upsampling.
        num_head_channels: number of channels in each attention head; each transformer uses
            ``out_channels // num_head_channels`` heads.
        transformer_num_layers: number of layers of Transformer blocks to use.
        cross_attention_dim: number of context dimensions to use.
        upcast_attention: if True, upcast attention operations to full precision.
        use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        prev_output_channel: int,
        out_channels: int,
        temb_channels: int,
        num_res_blocks: int = 1,
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        add_upsample: bool = True,
        resblock_updown: bool = False,
        num_head_channels: int = 1,
        transformer_num_layers: int = 1,
        cross_attention_dim: int | None = None,
        upcast_attention: bool = False,
        use_flash_attention: bool = False,
    ) -> None:
        super().__init__()
        self.resblock_updown = resblock_updown

        resnet_list = []
        transformer_list = []
        for idx in range(num_res_blocks):
            # channels of the hidden state entering this resnet, before concatenation
            hidden_ch = out_channels if idx > 0 else prev_output_channel
            # channels of the skip connection concatenated onto it
            skip_ch = in_channels if idx == num_res_blocks - 1 else out_channels

            resnet_list.append(
                ResnetBlock(
                    spatial_dims=spatial_dims,
                    in_channels=hidden_ch + skip_ch,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                )
            )
            transformer_list.append(
                SpatialTransformer(
                    spatial_dims=spatial_dims,
                    in_channels=out_channels,
                    num_attention_heads=out_channels // num_head_channels,
                    num_head_channels=num_head_channels,
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                    num_layers=transformer_num_layers,
                    cross_attention_dim=cross_attention_dim,
                    upcast_attention=upcast_attention,
                    use_flash_attention=use_flash_attention,
                )
            )

        # NOTE: attentions is registered before resnets. This order is kept so that the
        # parameter / state_dict ordering matches the original layout.
        self.attentions = nn.ModuleList(transformer_list)
        self.resnets = nn.ModuleList(resnet_list)

        if not add_upsample:
            self.upsampler = None
        elif resblock_updown:
            self.upsampler = ResnetBlock(
                spatial_dims=spatial_dims,
                in_channels=out_channels,
                out_channels=out_channels,
                temb_channels=temb_channels,
                norm_num_groups=norm_num_groups,
                norm_eps=norm_eps,
                up=True,
            )
        else:
            self.upsampler = Upsample(
                spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels
            )

    def forward(
        self,
        hidden_states: torch.Tensor,
        res_hidden_states_list: list[torch.Tensor],
        temb: torch.Tensor,
        context: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states: input feature map.
            res_hidden_states_list: stored skip connections; one is taken from the end per resnet.
                The caller's sequence is not modified.
            temb: timestep embedding.
            context: optional conditioning passed to every spatial transformer.
        """
        pending = list(res_hidden_states_list)
        for resnet, transformer in zip(self.resnets, self.attentions):
            # skip connections are consumed last-in, first-out
            hidden_states = torch.cat([hidden_states, pending.pop()], dim=1)
            hidden_states = transformer(resnet(hidden_states, temb), context=context)

        if self.upsampler is not None:
            hidden_states = self.upsampler(hidden_states, temb)

        return hidden_states
|
|
|
|
def get_down_block(
    spatial_dims: int,
    in_channels: int,
    out_channels: int,
    temb_channels: int,
    num_res_blocks: int,
    norm_num_groups: int,
    norm_eps: float,
    add_downsample: bool,
    resblock_updown: bool,
    with_attn: bool,
    with_cross_attn: bool,
    num_head_channels: int,
    transformer_num_layers: int,
    cross_attention_dim: int | None,
    upcast_attention: bool = False,
    use_flash_attention: bool = False,
) -> nn.Module:
    """
    Build the down block for one UNet level.

    ``with_attn`` takes precedence over ``with_cross_attn``. When neither is set, a plain
    ``DownBlock`` is returned, and the attention-related arguments are not used.
    """
    # Arguments accepted by every down-block variant.
    shared = dict(
        spatial_dims=spatial_dims,
        in_channels=in_channels,
        out_channels=out_channels,
        temb_channels=temb_channels,
        num_res_blocks=num_res_blocks,
        norm_num_groups=norm_num_groups,
        norm_eps=norm_eps,
        add_downsample=add_downsample,
        resblock_updown=resblock_updown,
    )

    if with_attn:
        return AttnDownBlock(
            **shared,
            num_head_channels=num_head_channels,
            use_flash_attention=use_flash_attention,
        )

    if with_cross_attn:
        return CrossAttnDownBlock(
            **shared,
            num_head_channels=num_head_channels,
            transformer_num_layers=transformer_num_layers,
            cross_attention_dim=cross_attention_dim,
            upcast_attention=upcast_attention,
            use_flash_attention=use_flash_attention,
        )

    return DownBlock(**shared)
|
|
|
|
def get_mid_block(
    spatial_dims: int,
    in_channels: int,
    temb_channels: int,
    norm_num_groups: int,
    norm_eps: float,
    with_conditioning: bool,
    num_head_channels: int,
    transformer_num_layers: int,
    cross_attention_dim: int | None,
    upcast_attention: bool = False,
    use_flash_attention: bool = False,
) -> nn.Module:
    """
    Build the UNet middle block.

    Returns an ``AttnMidBlock`` (self-attention) when ``with_conditioning`` is falsy. In that case
    ``transformer_num_layers``, ``cross_attention_dim`` and ``upcast_attention`` are not used.
    Otherwise it returns a ``CrossAttnMidBlock``.
    """
    if not with_conditioning:
        return AttnMidBlock(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            temb_channels=temb_channels,
            norm_num_groups=norm_num_groups,
            norm_eps=norm_eps,
            num_head_channels=num_head_channels,
            use_flash_attention=use_flash_attention,
        )

    return CrossAttnMidBlock(
        spatial_dims=spatial_dims,
        in_channels=in_channels,
        temb_channels=temb_channels,
        norm_num_groups=norm_num_groups,
        norm_eps=norm_eps,
        num_head_channels=num_head_channels,
        transformer_num_layers=transformer_num_layers,
        cross_attention_dim=cross_attention_dim,
        upcast_attention=upcast_attention,
        use_flash_attention=use_flash_attention,
    )
|
|
|
|
def get_up_block(
    spatial_dims: int,
    in_channels: int,
    prev_output_channel: int,
    out_channels: int,
    temb_channels: int,
    num_res_blocks: int,
    norm_num_groups: int,
    norm_eps: float,
    add_upsample: bool,
    resblock_updown: bool,
    with_attn: bool,
    with_cross_attn: bool,
    num_head_channels: int,
    transformer_num_layers: int,
    cross_attention_dim: int | None,
    upcast_attention: bool = False,
    use_flash_attention: bool = False,
) -> nn.Module:
    """
    Build the up block for one UNet level.

    ``with_attn`` takes precedence over ``with_cross_attn``. When neither is set, a plain
    ``UpBlock`` is returned, and the attention-related arguments are not used.
    """
    # Arguments accepted by every up-block variant.
    shared = dict(
        spatial_dims=spatial_dims,
        in_channels=in_channels,
        prev_output_channel=prev_output_channel,
        out_channels=out_channels,
        temb_channels=temb_channels,
        num_res_blocks=num_res_blocks,
        norm_num_groups=norm_num_groups,
        norm_eps=norm_eps,
        add_upsample=add_upsample,
        resblock_updown=resblock_updown,
    )

    if with_attn:
        return AttnUpBlock(
            **shared,
            num_head_channels=num_head_channels,
            use_flash_attention=use_flash_attention,
        )

    if with_cross_attn:
        return CrossAttnUpBlock(
            **shared,
            num_head_channels=num_head_channels,
            transformer_num_layers=transformer_num_layers,
            cross_attention_dim=cross_attention_dim,
            upcast_attention=upcast_attention,
            use_flash_attention=use_flash_attention,
        )

    return UpBlock(**shared)
|
|
|
|
class DiffusionModelUNet(nn.Module):
    """
    Unet network with timestep embedding and attention mechanisms for conditioning based on
    Rombach et al. "High-Resolution Image Synthesis with Latent Diffusion Models" https://arxiv.org/abs/2112.10752
    and Pinaya et al. "Brain Imaging Generation with Latent Diffusion Models" https://arxiv.org/abs/2209.07162

    Args:
        spatial_dims: number of spatial dimensions.
        in_channels: number of input channels.
        out_channels: number of output channels.
        num_res_blocks: number of residual blocks (see ResnetBlock) per level. A single int is repeated
            for every level; otherwise one value per entry of ``num_channels`` is required.
        num_channels: tuple of block output channels.
        attention_levels: list of levels to add attention.
        norm_num_groups: number of groups for the normalization.
        norm_eps: epsilon for the normalization.
        resblock_updown: if True use residual blocks for up/downsampling.
        num_head_channels: number of channels in each attention head. A single int is repeated for every
            level; otherwise one value per entry of ``attention_levels`` is required. Values for levels
            without attention are ignored.
        with_conditioning: if True add spatial transformers to perform conditioning.
        transformer_num_layers: number of layers of Transformer blocks to use.
        cross_attention_dim: number of context dimensions to use.
        num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds`
            classes.
        upcast_attention: if True, upcast attention operations to full precision.
        use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.

    Raises:
        ValueError: if ``with_conditioning`` and ``cross_attention_dim`` are inconsistent; if any entry of
            ``num_channels`` is not a multiple of ``norm_num_groups``; if ``num_channels``,
            ``attention_levels``, ``num_head_channels`` or ``num_res_blocks`` have mismatched lengths; or if
            flash attention is requested without xformers installed or without CUDA.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        num_res_blocks: Sequence[int] | int = (2, 2, 2, 2),
        num_channels: Sequence[int] = (32, 64, 64, 64),
        attention_levels: Sequence[bool] = (False, False, True, True),
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        resblock_updown: bool = False,
        num_head_channels: int | Sequence[int] = 8,
        with_conditioning: bool = False,
        transformer_num_layers: int = 1,
        cross_attention_dim: int | None = None,
        num_class_embeds: int | None = None,
        upcast_attention: bool = False,
        use_flash_attention: bool = False,
    ) -> None:
        super().__init__()
        if with_conditioning is True and cross_attention_dim is None:
            raise ValueError(
                "DiffusionModelUNet expects dimension of the cross-attention conditioning (cross_attention_dim) "
                "when using with_conditioning."
            )
        if cross_attention_dim is not None and with_conditioning is False:
            raise ValueError(
                "DiffusionModelUNet expects with_conditioning=True when specifying the cross_attention_dim."
            )

        # Group normalization needs every channel count to be divisible by the number of groups.
        if any((out_channel % norm_num_groups) != 0 for out_channel in num_channels):
            raise ValueError("DiffusionModelUNet expects all num_channels being multiple of norm_num_groups")

        if len(num_channels) != len(attention_levels):
            raise ValueError("DiffusionModelUNet expects num_channels being same size of attention_levels")

        if isinstance(num_head_channels, int):
            num_head_channels = ensure_tuple_rep(num_head_channels, len(attention_levels))

        if len(num_head_channels) != len(attention_levels):
            raise ValueError(
                "num_head_channels should have the same length as attention_levels. For the i levels without attention,"
                " i.e. `attention_level[i]=False`, the num_head_channels[i] will be ignored."
            )

        if isinstance(num_res_blocks, int):
            num_res_blocks = ensure_tuple_rep(num_res_blocks, len(num_channels))

        if len(num_res_blocks) != len(num_channels):
            raise ValueError(
                "`num_res_blocks` should be a single integer or a tuple of integers with the same length as "
                "`num_channels`."
            )

        if use_flash_attention and not has_xformers:
            raise ValueError("use_flash_attention is True but xformers is not installed.")

        if use_flash_attention is True and not torch.cuda.is_available():
            raise ValueError(
                "torch.cuda.is_available() should be True but is False. Flash attention is only available for GPU."
            )

        self.in_channels = in_channels
        self.block_out_channels = num_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.attention_levels = attention_levels
        self.num_head_channels = num_head_channels
        self.with_conditioning = with_conditioning

        # input convolution
        self.conv_in = Convolution(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=num_channels[0],
            strides=1,
            kernel_size=3,
            padding=1,
            conv_only=True,
        )

        # timestep embedding MLP: num_channels[0] -> 4 * num_channels[0]
        time_embed_dim = num_channels[0] * 4
        self.time_embed = nn.Sequential(
            nn.Linear(num_channels[0], time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim)
        )

        # optional class embedding, added to the time embedding in forward()
        self.num_class_embeds = num_class_embeds
        if num_class_embeds is not None:
            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)

        # down path: every level except the last one downsamples
        self.down_blocks = nn.ModuleList([])
        output_channel = num_channels[0]
        for i in range(len(num_channels)):
            input_channel = output_channel
            output_channel = num_channels[i]
            is_final_block = i == len(num_channels) - 1

            down_block = get_down_block(
                spatial_dims=spatial_dims,
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=time_embed_dim,
                num_res_blocks=num_res_blocks[i],
                norm_num_groups=norm_num_groups,
                norm_eps=norm_eps,
                add_downsample=not is_final_block,
                resblock_updown=resblock_updown,
                with_attn=(attention_levels[i] and not with_conditioning),
                with_cross_attn=(attention_levels[i] and with_conditioning),
                num_head_channels=num_head_channels[i],
                transformer_num_layers=transformer_num_layers,
                cross_attention_dim=cross_attention_dim,
                upcast_attention=upcast_attention,
                use_flash_attention=use_flash_attention,
            )

            self.down_blocks.append(down_block)

        # middle block
        self.middle_block = get_mid_block(
            spatial_dims=spatial_dims,
            in_channels=num_channels[-1],
            temb_channels=time_embed_dim,
            norm_num_groups=norm_num_groups,
            norm_eps=norm_eps,
            with_conditioning=with_conditioning,
            num_head_channels=num_head_channels[-1],
            transformer_num_layers=transformer_num_layers,
            cross_attention_dim=cross_attention_dim,
            upcast_attention=upcast_attention,
            use_flash_attention=use_flash_attention,
        )

        # up path: mirrors the down path; each level has one extra resnet to consume the extra skip
        # connection stored by the corresponding down level
        self.up_blocks = nn.ModuleList([])
        reversed_block_out_channels = list(reversed(num_channels))
        reversed_num_res_blocks = list(reversed(num_res_blocks))
        reversed_attention_levels = list(reversed(attention_levels))
        reversed_num_head_channels = list(reversed(num_head_channels))
        output_channel = reversed_block_out_channels[0]
        for i in range(len(reversed_block_out_channels)):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            input_channel = reversed_block_out_channels[min(i + 1, len(num_channels) - 1)]

            is_final_block = i == len(num_channels) - 1

            up_block = get_up_block(
                spatial_dims=spatial_dims,
                in_channels=input_channel,
                prev_output_channel=prev_output_channel,
                out_channels=output_channel,
                temb_channels=time_embed_dim,
                num_res_blocks=reversed_num_res_blocks[i] + 1,
                norm_num_groups=norm_num_groups,
                norm_eps=norm_eps,
                add_upsample=not is_final_block,
                resblock_updown=resblock_updown,
                with_attn=(reversed_attention_levels[i] and not with_conditioning),
                with_cross_attn=(reversed_attention_levels[i] and with_conditioning),
                num_head_channels=reversed_num_head_channels[i],
                transformer_num_layers=transformer_num_layers,
                cross_attention_dim=cross_attention_dim,
                upcast_attention=upcast_attention,
                use_flash_attention=use_flash_attention,
            )

            self.up_blocks.append(up_block)

        # Output head. A kernel_size=3 / padding=1 / stride=1 convolution with the default dilation
        # keeps the spatial size of its input unchanged. A dilated kernel (dilation=2) with padding=1
        # would shrink every spatial dimension by 2. The final convolution's parameters are zeroed by
        # zero_module.
        self.out = nn.Sequential(
            nn.GroupNorm(num_groups=norm_num_groups, num_channels=num_channels[0], eps=norm_eps, affine=True),
            nn.SiLU(),
            zero_module(
                Convolution(
                    spatial_dims=spatial_dims,
                    in_channels=num_channels[0],
                    out_channels=out_channels,
                    strides=1,
                    kernel_size=3,
                    padding=1,
                    conv_only=True,
                )
            ),
        )

    def forward(
        self,
        x: torch.Tensor,
        timesteps: torch.Tensor,
        context: torch.Tensor | None = None,
        class_labels: torch.Tensor | None = None,
        down_block_additional_residuals: tuple[torch.Tensor, ...] | None = None,
        mid_block_additional_residual: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        Args:
            x: input tensor (N, C, SpatialDims).
            timesteps: timestep tensor (N,).
            context: context tensor (N, 1, ContextDim).
            class_labels: class label tensor (N,); required when the model was built with ``num_class_embeds``.
            down_block_additional_residuals: additional residual tensors for down blocks (N, C, FeatureMapsDims),
                added element-wise, in order, to the stored skip connections.
            mid_block_additional_residual: additional residual tensor for mid block (N, C, FeatureMapsDims).

        Raises:
            ValueError: if ``class_labels`` is missing for a class-conditional model, or if ``context`` is
                given to a model built with ``with_conditioning=False``.
        """
        # 1. time embedding, cast to the input dtype before the MLP
        t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0])
        t_emb = t_emb.to(dtype=x.dtype)
        emb = self.time_embed(t_emb)

        # 2. class embedding
        if self.num_class_embeds is not None:
            if class_labels is None:
                raise ValueError("class_labels should be provided when num_class_embeds > 0")
            class_emb = self.class_embedding(class_labels)
            class_emb = class_emb.to(dtype=x.dtype)
            emb = emb + class_emb

        # 3. initial convolution
        h = self.conv_in(x)

        # 4. down path; every intermediate output is stored as a skip connection
        if context is not None and self.with_conditioning is False:
            raise ValueError("model should have with_conditioning = True if context is provided")
        down_block_res_samples: list[torch.Tensor] = [h]
        for downsample_block in self.down_blocks:
            h, res_samples = downsample_block(hidden_states=h, temb=emb, context=context)
            down_block_res_samples.extend(res_samples)

        # caller-supplied residuals are paired positionally with the stored skips
        # (zip stops at the shorter of the two sequences)
        if down_block_additional_residuals is not None:
            down_block_res_samples = [
                sample + extra for sample, extra in zip(down_block_res_samples, down_block_additional_residuals)
            ]

        # 5. mid block
        h = self.middle_block(hidden_states=h, temb=emb, context=context)
        if mid_block_additional_residual is not None:
            h = h + mid_block_additional_residual

        # 6. up path: each up block consumes as many stored skips as it has resnets, from the end
        for upsample_block in self.up_blocks:
            n_skips = len(upsample_block.resnets)
            res_samples = down_block_res_samples[-n_skips:]
            down_block_res_samples = down_block_res_samples[:-n_skips]
            h = upsample_block(hidden_states=h, res_hidden_states_list=res_samples, temb=emb, context=context)

        # 7. output head
        return self.out(h)
|
|
|
|
| class DiffusionModelEncoder(nn.Module): |
| """ |
| Classification Network based on the Encoder of the Diffusion Model, followed by fully connected layers. This network is based on |
| Wolleb et al. "Diffusion Models for Medical Anomaly Detection" (https://arxiv.org/abs/2203.04306). |
| |
| Args: |
| spatial_dims: number of spatial dimensions. |
| in_channels: number of input channels. |
| out_channels: number of output channels. |
| num_res_blocks: number of residual blocks (see ResnetBlock) per level. |
| num_channels: tuple of block output channels. |
| attention_levels: list of levels to add attention. |
| norm_num_groups: number of groups for the normalization. |
| norm_eps: epsilon for the normalization. |
| resblock_updown: if True use residual blocks for downsampling. |
| num_head_channels: number of channels in each attention head. |
| with_conditioning: if True add spatial transformers to perform conditioning. |
| transformer_num_layers: number of layers of Transformer blocks to use. |
| cross_attention_dim: number of context dimensions to use. |
| num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds` classes. |
| upcast_attention: if True, upcast attention operations to full precision. |
| """ |
|
|
def __init__(
    self,
    spatial_dims: int,
    in_channels: int,
    out_channels: int,
    num_res_blocks: Sequence[int] | int = (2, 2, 2, 2),
    num_channels: Sequence[int] = (32, 64, 64, 64),
    attention_levels: Sequence[bool] = (False, False, True, True),
    norm_num_groups: int = 32,
    norm_eps: float = 1e-6,
    resblock_updown: bool = False,
    num_head_channels: int | Sequence[int] = 8,
    with_conditioning: bool = False,
    transformer_num_layers: int = 1,
    cross_attention_dim: int | None = None,
    num_class_embeds: int | None = None,
    upcast_attention: bool = False,
) -> None:
    """
    Build the encoder: an input convolution, a timestep-embedding MLP, an optional class embedding,
    one down block per entry of ``num_channels`` and a fully connected output head.

    See the class docstring for the meaning of each argument. ``num_res_blocks`` and
    ``num_head_channels`` may each be a single int (applied to every level) or one value per level.

    Raises:
        ValueError: if ``with_conditioning`` and ``cross_attention_dim`` are inconsistent, if any entry of
            ``num_channels`` is not a multiple of ``norm_num_groups``, or if ``attention_levels``,
            ``num_res_blocks`` or ``num_head_channels`` do not have one entry per level of ``num_channels``.
    """
    super().__init__()
    # Cross-attention conditioning needs a context dimension, and a context dimension is only
    # meaningful when conditioning is enabled.
    if with_conditioning and cross_attention_dim is None:
        raise ValueError(
            "DiffusionModelEncoder expects dimension of the cross-attention conditioning (cross_attention_dim) "
            "when using with_conditioning."
        )
    if cross_attention_dim is not None and not with_conditioning:
        raise ValueError(
            "DiffusionModelEncoder expects with_conditioning=True when specifying the cross_attention_dim."
        )

    # Each level's channel count must divide evenly by norm_num_groups, which is handed to every down block.
    if any((out_channel % norm_num_groups) != 0 for out_channel in num_channels):
        raise ValueError("DiffusionModelEncoder expects all num_channels being multiple of norm_num_groups")
    if len(num_channels) != len(attention_levels):
        raise ValueError("DiffusionModelEncoder expects num_channels being same size of attention_levels")

    # A single int means "use this value at every level".
    if isinstance(num_head_channels, int):
        num_head_channels = ensure_tuple_rep(num_head_channels, len(attention_levels))
    if len(num_head_channels) != len(attention_levels):
        raise ValueError(
            "num_head_channels should have the same length as attention_levels. For the i levels without attention,"
            " i.e. `attention_level[i]=False`, the num_head_channels[i] will be ignored."
        )

    # The signature allows an int here too; expand it so the per-level indexing below works.
    if isinstance(num_res_blocks, int):
        num_res_blocks = ensure_tuple_rep(num_res_blocks, len(num_channels))
    if len(num_res_blocks) != len(num_channels):
        raise ValueError(
            "`num_res_blocks` should be a single integer or a tuple of integers with the same length as "
            "`num_channels`."
        )

    self.in_channels = in_channels
    self.block_out_channels = num_channels
    self.out_channels = out_channels
    self.num_res_blocks = num_res_blocks
    self.attention_levels = attention_levels
    self.num_head_channels = num_head_channels
    self.with_conditioning = with_conditioning

    # Input projection to the first level's channel count.
    self.conv_in = Convolution(
        spatial_dims=spatial_dims,
        in_channels=in_channels,
        out_channels=num_channels[0],
        strides=1,
        kernel_size=3,
        padding=1,
        conv_only=True,
    )

    # MLP applied to the timestep embedding; its width is 4x the first level's channels.
    time_embed_dim = num_channels[0] * 4
    self.time_embed = nn.Sequential(
        nn.Linear(num_channels[0], time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim)
    )

    # Optional class conditioning: class embeddings live in the same space as the time embedding
    # so the two can be summed in ``forward``.
    self.num_class_embeds = num_class_embeds
    if num_class_embeds is not None:
        self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)

    # One down block per level. Plain self-attention is used at attention levels without conditioning,
    # cross-attention at attention levels with conditioning.
    self.down_blocks = nn.ModuleList([])
    output_channel = num_channels[0]
    for i, level_channels in enumerate(num_channels):
        input_channel = output_channel
        output_channel = level_channels

        down_block = get_down_block(
            spatial_dims=spatial_dims,
            in_channels=input_channel,
            out_channels=output_channel,
            temb_channels=time_embed_dim,
            num_res_blocks=num_res_blocks[i],
            norm_num_groups=norm_num_groups,
            norm_eps=norm_eps,
            # Every level downsamples, including the last one. The flattened size fed to ``self.out``
            # depends on this.
            add_downsample=True,
            resblock_updown=resblock_updown,
            with_attn=(attention_levels[i] and not with_conditioning),
            with_cross_attn=(attention_levels[i] and with_conditioning),
            num_head_channels=num_head_channels[i],
            transformer_num_layers=transformer_num_layers,
            cross_attention_dim=cross_attention_dim,
            upcast_attention=upcast_attention,
        )

        self.down_blocks.append(down_block)

    # NOTE(review): the head's input size is hard-coded to 4096. ``forward`` flattens the last feature
    # map to (N, -1), so only inputs whose final feature map (channels x remaining spatial size) has
    # exactly 4096 values per sample will work; other input sizes or channel configurations fail at
    # this Linear layer. Confirm the intended input size.
    self.out = nn.Sequential(nn.Linear(4096, 512), nn.ReLU(), nn.Dropout(0.1), nn.Linear(512, self.out_channels))
|
|
def forward(
    self,
    x: torch.Tensor,
    timesteps: torch.Tensor,
    context: torch.Tensor | None = None,
    class_labels: torch.Tensor | None = None,
) -> torch.Tensor:
    """
    Encode ``x`` through the down blocks, flatten the final feature map and apply the output head.

    Args:
        x: input tensor (N, C, SpatialDims).
        timesteps: timestep tensor (N,).
        context: context tensor (N, 1, ContextDim). Only allowed when the model was built with
            ``with_conditioning=True``.
        class_labels: class label tensor (N,). Required when the model was built with ``num_class_embeds``.

    Returns:
        Tensor of shape (N, out_channels) produced by the fully connected head.

    Raises:
        ValueError: if ``class_labels`` is missing for a class-conditional model, or if ``context`` is
            given to a model built without conditioning.
    """
    # Validate the conditioning inputs before doing any work (same checks and order as before).
    if self.num_class_embeds is not None and class_labels is None:
        raise ValueError("class_labels should be provided when num_class_embeds > 0")
    if context is not None and not self.with_conditioning:
        raise ValueError("model should have with_conditioning = True if context is provided")

    # 1. Time embedding.
    t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0])
    # Cast to the input's dtype so the embedding MLP runs in the same precision as ``x``
    # (e.g. fp16). Presumably the timestep embedding is produced in full precision — verify.
    t_emb = t_emb.to(dtype=x.dtype)
    emb = self.time_embed(t_emb)

    # 2. Optional class conditioning, added onto the time embedding.
    if self.num_class_embeds is not None:
        class_emb = self.class_embedding(class_labels)
        class_emb = class_emb.to(dtype=x.dtype)
        emb = emb + class_emb

    # 3. Input projection.
    h = self.conv_in(x)

    # 4. Down path; the per-block residual outputs are not needed by an encoder.
    for downsample_block in self.down_blocks:
        h, _ = downsample_block(hidden_states=h, temb=emb, context=context)

    # 5. Flatten per sample and classify.
    h = h.reshape(h.shape[0], -1)
    output = self.out(h)

    return output
|
|