| """Sage-T2I: Custom Diffusion Transformer for Text-to-Image generation.""" |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import math |
|
|
def modulate(x, shift, scale):
    """Apply a per-sample affine modulation to ``x``.

    ``shift`` and ``scale`` each receive a new singleton axis at position 1
    (via ``unsqueeze(1)``) so that they broadcast along axis 1 of ``x``.

    Returns:
        ``x * (1 + scale) + shift`` evaluated with that broadcasting.
    """
    scale_b = scale.unsqueeze(1)
    shift_b = shift.unsqueeze(1)
    return x * (1 + scale_b) + shift_b
|
|
class TimestepEmbedder(nn.Module):
    """Embed a batch of scalar timesteps into ``hidden_size``-wide vectors.

    A sinusoidal embedding of width ``freq_embed_size`` is computed first and
    then passed through a two-layer MLP (Linear -> SiLU -> Linear).

    Args:
        hidden_size: Output width of the MLP.
        freq_embed_size: Width of the sinusoidal embedding fed to the MLP.
    """

    def __init__(self, hidden_size, freq_embed_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(freq_embed_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.freq_embed_size = freq_embed_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """Return sinusoidal embeddings of shape ``(len(t), dim)``.

        Args:
            t: 1-D tensor of timesteps (indexed as ``t[:, None]``).
            dim: Width of the returned embedding.
            max_period: Controls the lowest frequency of the sinusoids.

        Returns:
            A float32 tensor whose first ``dim // 2`` columns are cosines and
            next ``dim // 2`` columns are sines. When ``dim`` is odd, one
            trailing zero column is appended so the width is exactly ``dim``.
        """
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period)
            * torch.arange(half, dtype=torch.float32, device=t.device)
            / half
        )
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            # cos/sin halves only cover 2 * (dim // 2) columns; pad the last
            # one so the result matches the requested width.
            padding = embedding.new_zeros(embedding.shape[0], 1)
            embedding = torch.cat([embedding, padding], dim=-1)
        return embedding

    def forward(self, t):
        """Map timesteps ``t`` to embeddings of width ``hidden_size``."""
        t_freq = self.timestep_embedding(t, self.freq_embed_size)
        # The sinusoidal features are always float32; align them with the MLP
        # parameters so the module also works after .half()/.double().
        t_freq = t_freq.to(self.mlp[0].weight.dtype)
        return self.mlp(t_freq)
|
|
class CaptionEmbedder(nn.Module):
    """Linear projection of caption features followed by an activation.

    Args:
        in_channels: Size of the last dimension of the incoming features.
        hidden_size: Size of the last dimension after projection.
        act_layer: Zero-argument factory for the activation module.
    """

    def __init__(self, in_channels, hidden_size, act_layer=nn.SiLU):
        super().__init__()
        self.linear = nn.Linear(in_channels, hidden_size, bias=True)
        self.act = act_layer()

    def forward(self, x):
        """Project ``x`` with the linear layer, then apply the activation."""
        projected = self.linear(x)
        return self.act(projected)
|
|
class SelfAttention(nn.Module):
    """Multi-head self-attention over a ``(B, N, C)`` token sequence.

    Args:
        hidden_size: Channel width ``C`` of the input and output.
        num_heads: Number of attention heads; must divide ``hidden_size``.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not evenly
            divide ``hidden_size``.
    """

    def __init__(self, hidden_size, num_heads):
        super().__init__()
        # Without this check the floor division below silently drops the
        # remainder and forward() later fails with an opaque reshape error.
        if num_heads <= 0 or hidden_size % num_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by "
                f"a positive num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.qkv = nn.Linear(hidden_size, hidden_size * 3, bias=True)
        self.proj = nn.Linear(hidden_size, hidden_size, bias=True)

    def forward(self, x):
        """Attend every token of ``x`` to every other token.

        Args:
            x: Tensor of shape ``(B, N, C)``.

        Returns:
            Tensor of shape ``(B, N, C)``.
        """
        B, N, C = x.shape
        # (B, N, 3*C) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        # Scaled dot-product scores, normalised over the key axis.
        attn = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim))
        attn = attn.softmax(dim=-1)
        # Merge heads back into the channel dimension.
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        return x
|
|
class CrossAttention(nn.Module):
    """Multi-head attention from a token sequence to a context sequence.

    Queries come from ``x``; keys and values come from ``context``.

    Args:
        hidden_size: Channel width of ``x`` and of the output.
        context_dim: Channel width of ``context``.
        num_heads: Number of attention heads; must divide ``hidden_size``.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not evenly
            divide ``hidden_size``.
    """

    def __init__(self, hidden_size, context_dim, num_heads):
        super().__init__()
        # Without this check the floor division below silently drops the
        # remainder and forward() later fails with an opaque reshape error.
        if num_heads <= 0 or hidden_size % num_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by "
                f"a positive num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.q = nn.Linear(hidden_size, hidden_size, bias=True)
        self.k = nn.Linear(context_dim, hidden_size, bias=True)
        self.v = nn.Linear(context_dim, hidden_size, bias=True)
        self.proj = nn.Linear(hidden_size, hidden_size, bias=True)

    def forward(self, x, context):
        """Attend each token of ``x`` to all tokens of ``context``.

        Args:
            x: Tensor of shape ``(B, N, C)``.
            context: Tensor of shape ``(B, M, context_dim)``.

        Returns:
            Tensor of shape ``(B, N, C)``.
        """
        B, N, C = x.shape
        _, M, _ = context.shape
        # Split channels into heads: (B, L, C) -> (B, heads, L, head_dim).
        q = self.q(x).reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k(context).reshape(B, M, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v(context).reshape(B, M, self.num_heads, self.head_dim).transpose(1, 2)
        # Scaled dot-product scores, normalised over the context axis.
        attn = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim))
        attn = attn.softmax(dim=-1)
        # Merge heads back into the channel dimension.
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        return x
|
|
class FeedForward(nn.Module):
    """Two-layer MLP with a tanh-approximated GELU in between.

    Args:
        hidden_size: Input and output width.
        mlp_ratio: The inner width is ``int(hidden_size * mlp_ratio)``.
    """

    def __init__(self, hidden_size, mlp_ratio=4.0):
        super().__init__()
        inner_dim = int(hidden_size * mlp_ratio)
        self.fc1 = nn.Linear(hidden_size, inner_dim, bias=True)
        self.fc2 = nn.Linear(inner_dim, hidden_size, bias=True)
        self.act = nn.GELU(approximate="tanh")

    def forward(self, x):
        """Expand ``x``, apply the activation, and project back."""
        expanded = self.fc1(x)
        activated = self.act(expanded)
        return self.fc2(activated)
|
|
class DiTBlock(nn.Module):
    """Transformer block: self-attention, cross-attention, then feed-forward.

    The self-attention and feed-forward branches are modulated (shift/scale)
    and gated by six vectors produced from the conditioning ``c``. The
    cross-attention branch is added as a plain residual.

    Args:
        hidden_size: Channel width of the token stream.
        num_heads: Number of heads for both attention modules.
        context_dim: Channel width of the cross-attention context.
        mlp_ratio: Expansion ratio forwarded to ``FeedForward``.
    """

    def __init__(self, hidden_size, num_heads, context_dim, mlp_ratio=4.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.self_attn = SelfAttention(hidden_size, num_heads)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.cross_attn = CrossAttention(hidden_size, context_dim, num_heads)
        self.norm3 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.ff = FeedForward(hidden_size, mlp_ratio)
        # Produces 6 * hidden_size values that forward() splits into six chunks.
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True),
        )

    def forward(self, x, c, context):
        """Run the block.

        Args:
            x: Token tensor; the modulation vectors are broadcast along its
                axis 1.
            c: Conditioning tensor fed to ``adaLN_modulation`` and split into
                six chunks along ``dim=1``.
            context: Sequence passed to the cross-attention as keys/values.

        Returns:
            The updated token tensor.
        """
        modulation = self.adaLN_modulation(c)
        (
            shift_attn,
            scale_attn,
            gate_attn,
            shift_ff,
            scale_ff,
            gate_ff,
        ) = modulation.chunk(6, dim=1)

        # Gated, modulated self-attention residual.
        attn_in = modulate(self.norm1(x), shift_attn, scale_attn)
        x = x + gate_attn.unsqueeze(1) * self.self_attn(attn_in)

        # Plain cross-attention residual (no modulation or gate).
        x = x + self.cross_attn(self.norm2(x), context)

        # Gated, modulated feed-forward residual.
        ff_in = modulate(self.norm3(x), shift_ff, scale_ff)
        x = x + gate_ff.unsqueeze(1) * self.ff(ff_in)
        return x
|
|
class FinalLayer(nn.Module):
    """Output head: modulated LayerNorm followed by a linear projection.

    Args:
        hidden_size: Channel width of the incoming tokens.
        patch_size: The projection emits ``patch_size * patch_size *
            in_channels`` values per token.
        in_channels: See ``patch_size``.
    """

    def __init__(self, hidden_size, patch_size, in_channels):
        super().__init__()
        out_features = patch_size * patch_size * in_channels
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, out_features, bias=True)
        # Produces 2 * hidden_size values that forward() splits into
        # a shift half and a scale half.
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True),
        )

    def forward(self, x, c):
        """Normalise ``x``, modulate it with ``c``, and project it."""
        modulation = self.adaLN_modulation(c)
        shift, scale = modulation.chunk(2, dim=1)
        normed = self.norm_final(x)
        return self.linear(modulate(normed, shift, scale))
|
|
class DiT(nn.Module):
    """Diffusion transformer operating on patchified ``(B, C, H, W)`` inputs.

    The input is cut into ``patch_size x patch_size`` patches, linearly
    embedded, summed with a learned positional embedding, processed by a
    stack of ``DiTBlock`` modules conditioned on timestep and context, and
    finally projected back to a ``(B, in_channels, H, W)`` tensor.

    Args:
        config: Object exposing the attributes ``in_channels``,
            ``hidden_size``, ``patch_size``, ``num_heads``, ``context_dim``,
            ``image_size``, ``mlp_ratio`` and ``num_layers``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.in_channels = config.in_channels
        self.hidden_size = config.hidden_size
        self.patch_size = config.patch_size
        self.num_heads = config.num_heads

        self.x_embedder = nn.Linear(
            config.patch_size * config.patch_size * config.in_channels,
            config.hidden_size,
            bias=True,
        )
        self.t_embedder = TimestepEmbedder(config.hidden_size)
        self.c_embedder = CaptionEmbedder(config.context_dim, config.hidden_size)
        # NOTE(review): the fixed factor 8 presumably is the downsampling of
        # the latent encoder feeding this model - confirm against the caller.
        latent_size = config.image_size // 8
        num_patches = (latent_size // config.patch_size) ** 2
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches, config.hidden_size), requires_grad=True
        )

        self.blocks = nn.ModuleList([
            DiTBlock(config.hidden_size, config.num_heads, config.context_dim, config.mlp_ratio)
            for _ in range(config.num_layers)
        ])
        self.final_layer = FinalLayer(config.hidden_size, config.patch_size, config.in_channels)
        self.initialize_weights()

    def initialize_weights(self):
        """Initialise embeddings with N(0, 0.02) and zero the output paths.

        The last linear layer of every ``adaLN_modulation`` and the final
        projection are zeroed, so all modulation vectors (including the
        block gates) and the model output start at exactly zero.
        """
        nn.init.normal_(self.x_embedder.weight, std=0.02)
        nn.init.normal_(self.pos_embed, std=0.02)
        nn.init.normal_(self.c_embedder.linear.weight, std=0.02)
        for block in self.blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def forward(self, x, t, context):
        """Run the model.

        Args:
            x: Input of shape ``(B, in_channels, H, W)``; ``H`` and ``W``
                must be multiples of ``patch_size`` and yield exactly as many
                patches as the positional embedding holds.
            t: 1-D tensor of timesteps, one per batch element.
            context: Conditioning sequence of shape ``(B, M, context_dim)``.

        Returns:
            Tensor of shape ``(B, in_channels, H, W)``.

        Raises:
            ValueError: If the channel count, spatial size, or resulting
                number of patches does not match the model configuration.
        """
        B, C, H, W = x.shape
        p = self.patch_size

        # Fail early with a clear message instead of an opaque reshape or
        # broadcast error (or a silent broadcast when one token count is 1).
        if C != self.in_channels:
            raise ValueError(
                f"expected {self.in_channels} input channels, got {C}"
            )
        if H % p != 0 or W % p != 0:
            raise ValueError(
                f"input size ({H}, {W}) must be divisible by patch_size ({p})"
            )
        grid_h, grid_w = H // p, W // p
        expected_tokens = self.pos_embed.shape[1]
        if grid_h * grid_w != expected_tokens:
            raise ValueError(
                f"input yields {grid_h * grid_w} patches but the positional "
                f"embedding was built for {expected_tokens}"
            )

        # Patchify: (B, C, H, W) -> (B, grid_h * grid_w, C * p * p).
        x = x.reshape(B, C, grid_h, p, grid_w, p)
        x = x.permute(0, 2, 4, 1, 3, 5).reshape(B, -1, C * p * p)
        x = self.x_embedder(x) + self.pos_embed

        # Global conditioning: timestep embedding plus mean-pooled context.
        t_emb = self.t_embedder(t)
        c_emb = self.c_embedder(context).mean(dim=1)
        c = t_emb + c_emb

        for block in self.blocks:
            x = block(x, c, context)
        x = self.final_layer(x, c)

        # Unpatchify: (B, tokens, p * p * C_out) -> (B, C_out, H, W).
        x = x.reshape(B, grid_h, grid_w, p, p, self.in_channels)
        x = x.permute(0, 5, 1, 3, 2, 4).reshape(B, self.in_channels, H, W)
        return x
|
|