# Source: sage-t2i/model/dit.py
# Uploaded by itriedcoding using huggingface_hub (commit 2d7087a, verified); file size 7.83 kB.
"""Sage-T2I: Custom Diffusion Transformer for Text-to-Image generation."""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
def modulate(x, shift, scale):
    """Scale and shift ``x`` with per-sample modulation tensors.

    ``shift`` and ``scale`` each get a singleton axis inserted at position 1
    so that they broadcast against ``x`` along that axis. The result is
    ``x * (1 + scale) + shift``; a zero ``scale`` and zero ``shift`` therefore
    leave ``x`` unchanged.
    """
    scale_b = scale.unsqueeze(1)
    shift_b = shift.unsqueeze(1)
    return x * (1 + scale_b) + shift_b
class TimestepEmbedder(nn.Module):
    """Embed scalar timesteps into ``hidden_size``-dimensional vectors.

    Timesteps are first expanded into ``freq_embed_size`` sinusoidal features
    and then passed through a two-layer MLP with a SiLU in between.

    Args:
        hidden_size: Output width of the MLP.
        freq_embed_size: Number of sinusoidal features fed to the MLP.
    """

    def __init__(self, hidden_size, freq_embed_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(freq_embed_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.freq_embed_size = freq_embed_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """Return sinusoidal features of shape ``(len(t), dim)`` in float32.

        Args:
            t: 1-D tensor of timesteps (indexed as ``t[:, None]``).
            dim: Number of output features. The first ``dim // 2`` columns
                are cosines, the next ``dim // 2`` are sines; when ``dim`` is
                odd a trailing zero column is appended.
            max_period: Controls the lowest frequency of the embedding.
        """
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period)
            * torch.arange(half, dtype=torch.float32, device=t.device)
            / half
        )
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            # cos/sin only produce 2 * half columns; pad so the width really
            # equals ``dim`` and matches the first Linear layer's in_features.
            pad = embedding.new_zeros(embedding.shape[0], 1)
            embedding = torch.cat([embedding, pad], dim=-1)
        return embedding

    def forward(self, t):
        """Map timesteps ``t`` to embeddings of width ``hidden_size``."""
        t_freq = self.timestep_embedding(t, self.freq_embed_size)
        # The features are always built in float32; cast them to the MLP's
        # parameter dtype so the module still works after .half()/.double().
        t_freq = t_freq.to(self.mlp[0].weight.dtype)
        return self.mlp(t_freq)
class CaptionEmbedder(nn.Module):
    """Linear projection of caption features followed by an activation.

    Args:
        in_channels: Size of the last dimension of the incoming features.
        hidden_size: Size of the last dimension of the output.
        act_layer: Zero-argument callable that builds the activation module.
    """

    def __init__(self, in_channels, hidden_size, act_layer=nn.SiLU):
        super().__init__()
        self.linear = nn.Linear(in_channels, hidden_size, bias=True)
        self.act = act_layer()

    def forward(self, x):
        """Project ``x`` along its last dimension, then apply the activation."""
        projected = self.linear(x)
        return self.act(projected)
class SelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention over a token sequence.

    Args:
        hidden_size: Channel width of the input and output tokens.
        num_heads: Number of attention heads; must divide ``hidden_size``.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``hidden_size``.
    """

    def __init__(self, hidden_size, num_heads):
        super().__init__()
        # Without this check the floor division below silently drops channels
        # and the failure only shows up as a reshape error inside forward().
        if num_heads <= 0 or hidden_size % num_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by "
                f"num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.qkv = nn.Linear(hidden_size, hidden_size * 3, bias=True)
        self.proj = nn.Linear(hidden_size, hidden_size, bias=True)

    def forward(self, x):
        """Attend every token of ``x`` (shape ``(B, N, C)``) to every other.

        Returns a tensor of the same ``(B, N, C)`` shape.
        """
        B, N, C = x.shape
        # One fused projection, then split into q/k/v of shape (B, heads, N, head_dim).
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)
        attn = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim))
        attn = attn.softmax(dim=-1)
        # Merge the heads back into the channel dimension before projecting.
        merged = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj(merged)
class CrossAttention(nn.Module):
    """Multi-head scaled dot-product attention from tokens to a context.

    Queries are projected from ``x``; keys and values are projected from
    ``context``. No attention mask is applied, so every context position
    contributes to the softmax.

    Args:
        hidden_size: Channel width of the query tokens and of the output.
        context_dim: Channel width of the context sequence.
        num_heads: Number of attention heads; must divide ``hidden_size``.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``hidden_size``.
    """

    def __init__(self, hidden_size, context_dim, num_heads):
        super().__init__()
        # Without this check the floor division below silently drops channels
        # and the failure only shows up as a reshape error inside forward().
        if num_heads <= 0 or hidden_size % num_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by "
                f"num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.q = nn.Linear(hidden_size, hidden_size, bias=True)
        self.k = nn.Linear(context_dim, hidden_size, bias=True)
        self.v = nn.Linear(context_dim, hidden_size, bias=True)
        self.proj = nn.Linear(hidden_size, hidden_size, bias=True)

    def _split_heads(self, tensor, batch, length):
        """Reshape ``(batch, length, C)`` into ``(batch, heads, length, head_dim)``."""
        return tensor.reshape(batch, length, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x, context):
        """Attend tokens ``x`` (``(B, N, C)``) to ``context`` (``(B, M, context_dim)``).

        Returns a tensor of shape ``(B, N, C)``.
        """
        B, N, C = x.shape
        M = context.shape[1]
        q = self._split_heads(self.q(x), B, N)
        k = self._split_heads(self.k(context), B, M)
        v = self._split_heads(self.v(context), B, M)
        attn = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim))
        attn = attn.softmax(dim=-1)
        # Merge the heads back into the channel dimension before projecting.
        merged = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj(merged)
class FeedForward(nn.Module):
    """Two-layer MLP with a tanh-approximated GELU in between.

    Args:
        hidden_size: Width of the input and of the output.
        mlp_ratio: Multiplier applied to ``hidden_size`` (then truncated to an
            int) to obtain the width of the inner layer.
    """

    def __init__(self, hidden_size, mlp_ratio=4.0):
        super().__init__()
        inner_dim = int(hidden_size * mlp_ratio)
        self.fc1 = nn.Linear(hidden_size, inner_dim, bias=True)
        self.fc2 = nn.Linear(inner_dim, hidden_size, bias=True)
        self.act = nn.GELU(approximate="tanh")

    def forward(self, x):
        """Expand, activate, and project ``x`` back to ``hidden_size``."""
        expanded = self.fc1(x)
        activated = self.act(expanded)
        return self.fc2(activated)
class DiTBlock(nn.Module):
    """Transformer block: gated self-attention, cross-attention, gated MLP.

    The conditioning vector ``c`` is mapped (SiLU then Linear) to six
    ``hidden_size``-wide chunks: shift/scale/gate for the self-attention
    branch and shift/scale/gate for the feed-forward branch. The
    cross-attention branch is neither modulated nor gated.

    Args:
        hidden_size: Channel width of the tokens.
        num_heads: Number of heads for both attention modules.
        context_dim: Channel width of the cross-attention context.
        mlp_ratio: Expansion ratio passed to ``FeedForward``.
    """

    def __init__(self, hidden_size, num_heads, context_dim, mlp_ratio=4.0):
        super().__init__()
        # All three norms are parameter-free; scale/shift come from adaLN
        # (branches 1 and 3) or are simply absent (branch 2).
        norm_kwargs = dict(elementwise_affine=False, eps=1e-6)
        self.norm1 = nn.LayerNorm(hidden_size, **norm_kwargs)
        self.self_attn = SelfAttention(hidden_size, num_heads)
        self.norm2 = nn.LayerNorm(hidden_size, **norm_kwargs)
        self.cross_attn = CrossAttention(hidden_size, context_dim, num_heads)
        self.norm3 = nn.LayerNorm(hidden_size, **norm_kwargs)
        self.ff = FeedForward(hidden_size, mlp_ratio)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True),
        )

    def forward(self, x, c, context):
        """Run the three residual branches in order and return the new tokens.

        Args:
            x: Token tensor.
            c: Conditioning tensor; its modulation output is split along dim 1.
            context: Sequence handed to the cross-attention module.
        """
        modulation = self.adaLN_modulation(c)
        (
            shift_msa,
            scale_msa,
            gate_msa,
            shift_mlp,
            scale_mlp,
            gate_mlp,
        ) = modulation.chunk(6, dim=1)

        # Branch 1: modulated, gated self-attention.
        attn_in = modulate(self.norm1(x), shift_msa, scale_msa)
        x = x + gate_msa.unsqueeze(1) * self.self_attn(attn_in)

        # Branch 2: plain residual cross-attention.
        x = x + self.cross_attn(self.norm2(x), context)

        # Branch 3: modulated, gated feed-forward.
        mlp_in = modulate(self.norm3(x), shift_mlp, scale_mlp)
        x = x + gate_mlp.unsqueeze(1) * self.ff(mlp_in)
        return x
class FinalLayer(nn.Module):
    """Output head: adaLN-modulated LayerNorm followed by a linear projection.

    Each token is projected to ``patch_size * patch_size * in_channels``
    values.

    Args:
        hidden_size: Channel width of the incoming tokens.
        patch_size: Side length of a patch.
        in_channels: Channel count used to size the output projection.
    """

    def __init__(self, hidden_size, patch_size, in_channels):
        super().__init__()
        out_features = patch_size * patch_size * in_channels
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, out_features, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True),
        )

    def forward(self, x, c):
        """Normalise ``x``, modulate it with ``c``, and project each token."""
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        normed = self.norm_final(x)
        return self.linear(modulate(normed, shift, scale))
class DiT(nn.Module):
    """Diffusion transformer operating on patchified ``(B, C, H, W)`` inputs.

    The input is cut into ``patch_size x patch_size`` patches, linearly
    embedded, summed with a learned positional embedding, processed by
    ``num_layers`` ``DiTBlock`` modules, and projected back to the input
    shape by ``FinalLayer``.

    Args:
        config: Object exposing the attributes ``in_channels``,
            ``hidden_size``, ``patch_size``, ``num_heads``, ``context_dim``,
            ``image_size``, ``mlp_ratio`` and ``num_layers``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.in_channels = config.in_channels
        self.hidden_size = config.hidden_size
        self.patch_size = config.patch_size
        self.num_heads = config.num_heads
        self.x_embedder = nn.Linear(
            config.patch_size * config.patch_size * config.in_channels,
            config.hidden_size,
            bias=True,
        )
        self.t_embedder = TimestepEmbedder(config.hidden_size)
        self.c_embedder = CaptionEmbedder(config.context_dim, config.hidden_size)
        # NOTE(review): the factor 8 is hard-coded; presumably the spatial
        # downsampling factor of the latent encoder — confirm against the
        # pipeline that produces the inputs.
        latent_size = config.image_size // 8
        num_patches = (latent_size // config.patch_size) ** 2
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches, config.hidden_size), requires_grad=True
        )
        self.blocks = nn.ModuleList([
            DiTBlock(config.hidden_size, config.num_heads, config.context_dim, config.mlp_ratio)
            for _ in range(config.num_layers)
        ])
        self.final_layer = FinalLayer(config.hidden_size, config.patch_size, config.in_channels)
        self.initialize_weights()

    def initialize_weights(self):
        """Initialise embeddings with N(0, 0.02) and zero the modulation/output layers.

        Zeroing every adaLN output layer and the final projection makes the
        gated branches and the model output start at exactly zero.
        """
        nn.init.normal_(self.x_embedder.weight, std=0.02)
        nn.init.normal_(self.pos_embed, std=0.02)
        nn.init.normal_(self.c_embedder.linear.weight, std=0.02)
        for block in self.blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def forward(self, x, t, context):
        """Return a tensor with the same ``(B, C, H, W)`` shape as ``x``.

        Args:
            x: Input of shape ``(B, in_channels, H, W)``; ``H`` and ``W`` must
                be multiples of ``patch_size`` and yield exactly as many
                patches as the positional embedding holds.
            t: Timesteps passed to the timestep embedder (presumably one per
                batch element — confirm with the caller).
            context: Conditioning sequence; it is embedded and averaged over
                dim 1 for the adaLN signal, and also passed unchanged to each
                block's cross-attention.

        Raises:
            ValueError: If ``x`` has the wrong channel count, a spatial size
                not divisible by ``patch_size``, or a patch count that differs
                from the positional embedding.
        """
        B, C, H, W = x.shape
        p = self.patch_size

        # Fail early with a readable message instead of an opaque reshape or
        # matmul error (or a silent broadcast of pos_embed) further down.
        if C != self.in_channels:
            raise ValueError(
                f"expected input with {self.in_channels} channels, got {C}"
            )
        if H % p != 0 or W % p != 0:
            raise ValueError(
                f"input spatial size ({H}, {W}) must be divisible by patch_size ({p})"
            )
        grid_h, grid_w = H // p, W // p
        expected_tokens = self.pos_embed.shape[1]
        if grid_h * grid_w != expected_tokens:
            raise ValueError(
                f"input yields {grid_h * grid_w} patches but the positional "
                f"embedding was built for {expected_tokens}"
            )

        # Patchify: (B, C, H, W) -> (B, grid_h * grid_w, C * p * p).
        x = x.reshape(B, C, grid_h, p, grid_w, p)
        x = x.permute(0, 2, 4, 1, 3, 5).reshape(B, -1, C * p * p)
        x = self.x_embedder(x) + self.pos_embed

        # Global conditioning: timestep embedding plus mean-pooled context.
        t_emb = self.t_embedder(t)
        c_emb = self.c_embedder(context).mean(dim=1)
        c = t_emb + c_emb

        for block in self.blocks:
            x = block(x, c, context)
        x = self.final_layer(x, c)

        # Unpatchify: each token's output is read as (p, p, in_channels).
        x = x.reshape(B, grid_h, grid_w, p, p, self.in_channels)
        x = x.permute(0, 5, 1, 3, 2, 4).reshape(B, self.in_channels, H, W)
        return x