# Provenance: uploaded by mie237 with huggingface_hub
# ("Upload folder using huggingface_hub"), commit 924c3c9 (verified).
import math
import torch
import torch.nn as nn
from .modules import (
film_modulate,
unpatchify,
PatchEmbed,
PE_wrapper,
TimestepEmbedder,
FeedForward,
RMSNorm,
)
from .attention import Attention
class AdaLN(nn.Module):
    """Produce the six adaptive-LayerNorm modulation vectors for a DiT block.

    ``forward`` returns a tensor of shape ``(B, 6, dim)``; callers chunk it
    along dim 1 into (shift, scale, gate) for attention and (shift, scale,
    gate) for the MLP.

    Modes:
        ``'ada'``: a per-block ``Linear(dim, 6 * dim)`` applied to
            ``time_token``.
        ``'ada_single'``: a learnable ``(6, dim)`` table (zero-initialised)
            added to a shared, precomputed ``time_ada``.
        ``'ada_sola'``: a low-rank correction
            ``lora_b(lora_a(time_token)) * alpha / r`` added to the shared
            ``time_ada``.
        ``'ada_sola_bias'``: as ``'ada_sola'`` plus a learnable ``(6, dim)``
            table.

    Args:
        dim: Feature dimension of the modulated tokens.
        ada_mode: One of the modes above.
        r: Low-rank size (the LoRA bottleneck is ``6 * r``); required and must
            be positive for the ``ada_sola*`` modes.
        alpha: LoRA scaling numerator; required for the ``ada_sola*`` modes.

    Raises:
        ValueError: If an ``ada_sola*`` mode is requested without a positive
            ``r`` or without ``alpha``.
        NotImplementedError: If ``ada_mode`` is not supported.
    """

    def __init__(self, dim, ada_mode='ada', r=None, alpha=None):
        super().__init__()
        self.ada_mode = ada_mode
        # Optional learnable (6, dim) offset; stays None when the mode has none.
        self.scale_shift_table = None
        if ada_mode == 'ada':
            self.time_ada = nn.Linear(dim, 6 * dim, bias=True)
        elif ada_mode == 'ada_single':
            self.scale_shift_table = nn.Parameter(torch.zeros(6, dim))
        elif ada_mode in ('ada_sola', 'ada_sola_bias'):
            # Fail early with a clear message instead of a TypeError /
            # ZeroDivisionError from `r * 6` or `alpha / r`.
            if r is None or r <= 0 or alpha is None:
                raise ValueError(
                    f"{ada_mode!r} requires a positive rank r and an alpha, "
                    f"got r={r!r}, alpha={alpha!r}"
                )
            self.lora_a = nn.Linear(dim, r * 6, bias=False)
            self.lora_b = nn.Linear(r * 6, dim * 6, bias=False)
            self.scaling = alpha / r
            if ada_mode == 'ada_sola_bias':
                self.scale_shift_table = nn.Parameter(torch.zeros(6, dim))
        else:
            raise NotImplementedError(f"unsupported ada_mode: {ada_mode!r}")

    def forward(self, time_token=None, time_ada=None):
        """Return modulation parameters reshaped to ``(B, 6, dim)``.

        Args:
            time_token: Conditioning vector for the per-block projection
                (``'ada'``) or the LoRA branch (``ada_sola*``). Presumably
                ``(B, dim)``; TODO confirm against callers.
            time_ada: Shared precomputed modulation with ``6 * dim`` features
                per sample. Must be ``None`` in ``'ada'`` mode.
        """
        if self.ada_mode == 'ada':
            assert time_ada is None
            B = time_token.shape[0]
            time_ada = self.time_ada(time_token).reshape(B, 6, -1)
        elif self.ada_mode == 'ada_single':
            B = time_ada.shape[0]
            time_ada = time_ada.reshape(B, 6, -1)
            time_ada = self.scale_shift_table[None] + time_ada
        elif self.ada_mode in ('ada_sola', 'ada_sola_bias'):
            B = time_ada.shape[0]
            time_ada_lora = self.lora_b(self.lora_a(time_token)) * self.scaling
            time_ada = (time_ada + time_ada_lora).reshape(B, 6, -1)
            if self.scale_shift_table is not None:
                time_ada = self.scale_shift_table[None] + time_ada
        else:
            raise NotImplementedError(f"unsupported ada_mode: {self.ada_mode!r}")
        return time_ada
class DiTBlock(nn.Module):
    """Pre-norm transformer block: self-attention, optional cross-attention, MLP.

    Args:
        dim: Token feature dimension.
        context_dim: If not ``None``, adds a cross-attention layer over a
            context sequence with this feature size.
        num_heads, qkv_bias, qk_scale, qk_norm: Forwarded to ``Attention``.
        mlp_ratio: Hidden-size multiplier forwarded to ``FeedForward``.
        act_layer: Activation name forwarded to ``FeedForward``.
        norm_layer: Normalisation layer class (called with a feature size).
        time_fusion: ``'token'`` or ``'none'`` disable adaptive LayerNorm
            modulation; any ``AdaLN`` mode (``'ada'``, ``'ada_single'``,
            ``'ada_sola'``, ``'ada_sola_bias'``) enables it.
        ada_sola_rank, ada_sola_alpha: LoRA settings for the ``ada_sola*``
            modes.
        skip: If true, the block first fuses a long-skip activation with
            ``Linear(2 * dim, dim)`` over ``cat([x, skip], -1)``.
        skip_norm: Normalise the concatenated skip input before that linear.
        rope_mode: Rotary-embedding mode for self-attention only.
        context_norm: Normalise the context before cross-attention.
        use_checkpoint: Run ``_forward`` under ``torch.utils.checkpoint``.
    """

    def __init__(
        self,
        dim,
        context_dim=None,
        num_heads=8,
        mlp_ratio=4.,
        qkv_bias=False,
        qk_scale=None,
        qk_norm=None,
        act_layer='gelu',
        norm_layer=nn.LayerNorm,
        time_fusion='none',
        ada_sola_rank=None,
        ada_sola_alpha=None,
        skip=False,
        skip_norm=False,
        rope_mode='none',
        context_norm=False,
        use_checkpoint=False
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim=dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            qk_norm=qk_norm,
            rope_mode=rope_mode
        )
        if context_dim is not None:
            self.use_context = True
            self.cross_attn = Attention(
                dim=dim,
                num_heads=num_heads,
                context_dim=context_dim,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                qk_norm=qk_norm,
                rope_mode='none'
            )
            self.norm2 = norm_layer(dim)
            self.norm_context = (
                norm_layer(context_dim) if context_norm else nn.Identity()
            )
        else:
            self.use_context = False
        self.norm3 = norm_layer(dim)
        self.mlp = FeedForward(
            dim=dim, mult=mlp_ratio, activation_fn=act_layer, dropout=0
        )
        # 'none' (the default) must not reach AdaLN, which rejects it with
        # NotImplementedError; treat it like 'token' (no modulation).
        self.use_adanorm = time_fusion not in ('token', 'none')
        if self.use_adanorm:
            self.adaln = AdaLN(
                dim,
                ada_mode=time_fusion,
                r=ada_sola_rank,
                alpha=ada_sola_alpha
            )
        if skip:
            self.skip_norm = norm_layer(2 * dim) if skip_norm else nn.Identity()
            self.skip_linear = nn.Linear(2 * dim, dim)
        else:
            self.skip_linear = None
        self.use_checkpoint = use_checkpoint

    def forward(
        self,
        x,
        time_token=None,
        time_ada=None,
        skip=None,
        context=None,
        x_mask=None,
        context_mask=None,
        extras=None
    ):
        """Apply the block, optionally under activation checkpointing.

        Args:
            x: Token sequence; ``_forward`` unpacks its ``shape[-1]`` as the
                feature axis for the skip concat.
            time_token: Conditioning for ``AdaLN`` (``'ada'`` / ``ada_sola*``).
            time_ada: Shared precomputed modulation for ``ada_single`` /
                ``ada_sola*``.
            skip: Long-skip activation; required when the block has a skip.
            context: Context sequence for cross-attention.
            x_mask: Mask forwarded to self-attention as ``context_mask``.
            context_mask: Mask forwarded to cross-attention.
            extras: Number of prepended non-patch tokens, forwarded to the
                attention layers.
        """
        if self.use_checkpoint:
            from torch.utils.checkpoint import checkpoint
            return checkpoint(
                self._forward,
                x, time_token, time_ada, skip, context, x_mask, context_mask,
                extras,
                use_reentrant=False
            )
        return self._forward(
            x, time_token, time_ada, skip, context, x_mask, context_mask,
            extras
        )

    def _forward(
        self,
        x,
        time_token=None,
        time_ada=None,
        skip=None,
        context=None,
        x_mask=None,
        context_mask=None,
        extras=None
    ):
        """Skip fusion -> self-attention -> cross-attention -> MLP."""
        if self.skip_linear is not None:
            assert skip is not None
            cat = torch.cat([x, skip], dim=-1)
            cat = self.skip_norm(cat)
            x = self.skip_linear(cat)

        if self.use_adanorm:
            time_ada = self.adaln(time_token, time_ada)
            (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp,
             gate_mlp) = time_ada.chunk(6, dim=1)

        # Self-attention; the residual is scaled by (1 - gate) under adaLN.
        if self.use_adanorm:
            x_norm = film_modulate(
                self.norm1(x), shift=shift_msa, scale=scale_msa
            )
            x = x + (1 - gate_msa) * self.attn(
                x_norm, context=None, context_mask=x_mask, extras=extras
            )
        else:
            x = x + self.attn(
                self.norm1(x),
                context=None,
                context_mask=x_mask,
                extras=extras
            )

        # Cross-attention (unmodulated).
        if self.use_context:
            assert context is not None
            x = x + self.cross_attn(
                x=self.norm2(x),
                context=self.norm_context(context),
                context_mask=context_mask,
                extras=extras
            )

        # Feed-forward.
        if self.use_adanorm:
            x_norm = film_modulate(
                self.norm3(x), shift=shift_mlp, scale=scale_mlp
            )
            x = x + (1 - gate_mlp) * self.mlp(x_norm)
        else:
            x = x + self.mlp(self.norm3(x))
        return x
class FinalBlock(nn.Module):
    """Project tokens back to patches and reassemble the output signal.

    Drops the first ``extras`` tokens, normalises (with FiLM modulation when
    ``use_adanorm``), linearly maps each token to ``patch_dim`` values,
    unpatchifies, and optionally refines with a 3-wide convolution.

    Args:
        embed_dim: Token feature dimension.
        patch_size: Patch side length (``'2d'``) or length (``'1d'``).
        in_chans: Number of output channels produced.
        img_size: Output spatial size, forwarded to ``unpatchify``.
        input_type: ``'1d'`` or ``'2d'``.
        norm_layer: Normalisation layer class.
        use_conv: Append a ``Conv1d``/``Conv2d`` (kernel 3, same padding);
            otherwise ``nn.Identity``.
        use_adanorm: Modulate the norm with ``time_ada`` in ``forward``.

    Raises:
        NotImplementedError: If ``input_type`` is not ``'1d'`` or ``'2d'``.
    """

    def __init__(
        self,
        embed_dim,
        patch_size,
        in_chans,
        img_size,
        input_type='2d',
        norm_layer=nn.LayerNorm,
        use_conv=True,
        use_adanorm=True
    ):
        super().__init__()
        self.in_chans = in_chans
        self.img_size = img_size
        self.input_type = input_type
        self.norm = norm_layer(embed_dim)
        self.use_adanorm = use_adanorm
        if input_type == '2d':
            self.patch_dim = patch_size ** 2 * in_chans
            conv_cls = nn.Conv2d
        elif input_type == '1d':
            self.patch_dim = patch_size * in_chans
            conv_cls = nn.Conv1d
        else:
            # Previously fell through silently, leaving `linear` and
            # `final_layer` undefined until forward() crashed.
            raise NotImplementedError(f"unsupported input_type: {input_type!r}")
        self.linear = nn.Linear(embed_dim, self.patch_dim, bias=True)
        if use_conv:
            self.final_layer = conv_cls(self.in_chans, self.in_chans, 3, padding=1)
        else:
            self.final_layer = nn.Identity()

    def forward(self, x, time_ada=None, extras=0):
        """Map tokens ``x`` to the output signal.

        Args:
            x: Token sequence, unpacked as ``(B, T, C)``.
            time_ada: Final modulation with ``2 * C`` features per sample
                (reshaped to ``(B, 2, C)`` into shift and scale); used only
                when ``use_adanorm``.
            extras: Number of leading non-patch tokens to discard.
        """
        B, T, C = x.shape
        x = x[:, extras:, :]
        if self.use_adanorm:
            shift, scale = time_ada.reshape(B, 2, -1).chunk(2, dim=1)
            x = film_modulate(self.norm(x), shift, scale)
        else:
            x = self.norm(x)
        x = self.linear(x)
        x = unpatchify(x, self.in_chans, self.input_type, self.img_size)
        return self.final_layer(x)
class UDiT(nn.Module):
    """U-shaped Diffusion Transformer over 1-D or 2-D patch sequences.

    ``depth // 2`` input blocks feed a middle block and ``depth // 2`` output
    blocks. When ``skip`` is true, each output block receives the activation
    of its mirrored input block (popped in reverse order).

    Timestep conditioning (``time_fusion``):
        ``'token'``: the timestep (and optional class) embedding is prepended
            to the sequence as extra tokens.
        ``'ada'`` / ``'ada_single'`` / ``'ada_sola'`` / ``'ada_sola_bias'``:
            adaptive LayerNorm modulation (see ``AdaLN``).

    Context conditioning (``context_fusion``, used when ``context_dim`` is not
    ``None``):
        ``'concat'`` / ``'joint'``: projected context tokens are prepended to
            the sequence; ``context`` must have exactly
            ``context_max_length`` tokens.
        ``'cross'``: every block cross-attends to the projected context.

    Prepended tokens are counted in ``self.extras`` and stripped again by
    ``FinalBlock``. ``img_size`` is a pair for ``'2d'`` and an int for
    ``'1d'``.

    Raises:
        NotImplementedError: For an unsupported ``input_type``,
            ``time_fusion``, ``context_fusion`` or ``norm_layer``.
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        input_type='2d',
        out_chans=None,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.,
        qkv_bias=False,
        qk_scale=None,
        qk_norm=None,
        act_layer='gelu',
        norm_layer='layernorm',
        context_norm=False,
        use_checkpoint=False,
        time_fusion='token',
        ada_sola_rank=None,
        ada_sola_alpha=None,
        cls_dim=None,
        context_dim=768,
        context_fusion='concat',
        context_max_length=128,
        context_pe_method='sinu',
        pe_method='abs',
        rope_mode='none',
        use_conv=True,
        skip=True,
        skip_norm=True
    ):
        super().__init__()
        # NOTE: submodule creation order below is kept stable because
        # initialize_weights() consumes the RNG in registration order.
        self.num_features = self.embed_dim = embed_dim
        self.in_chans = in_chans
        self.input_type = input_type
        if input_type == '2d':
            num_patches = (img_size[0] // patch_size) * (img_size[1] // patch_size)
        elif input_type == '1d':
            num_patches = img_size // patch_size
        else:
            raise NotImplementedError(f"unsupported input_type: {input_type!r}")
        self.patch_embed = PatchEmbed(
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            input_type=input_type
        )
        out_chans = in_chans if out_chans is None else out_chans
        self.out_chans = out_chans
        self.rope = rope_mode
        self.x_pe = PE_wrapper(
            dim=embed_dim, method=pe_method, length=num_patches
        )

        # --- timestep / class conditioning ---
        self.time_embed = TimestepEmbedder(embed_dim)
        self.time_fusion = time_fusion
        self.use_adanorm = False
        if cls_dim is not None:
            self.cls_embed = nn.Sequential(
                nn.Linear(cls_dim, embed_dim, bias=True),
                nn.SiLU(),
                nn.Linear(embed_dim, embed_dim, bias=True),
            )
        else:
            self.cls_embed = None
        if time_fusion == 'token':
            # One time token, plus one class token when cls_embed exists.
            self.extras = 2 if self.cls_embed is not None else 1
            self.time_pe = PE_wrapper(
                dim=embed_dim, method='abs', length=self.extras
            )
        elif time_fusion in ('ada', 'ada_single', 'ada_sola', 'ada_sola_bias'):
            self.use_adanorm = True
            self.time_act = nn.SiLU()
            self.extras = 0
            self.time_ada_final = nn.Linear(
                embed_dim, 2 * embed_dim, bias=True
            )
            if time_fusion in ('ada_single', 'ada_sola', 'ada_sola_bias'):
                # Shared projection feeding every block's AdaLN.
                self.time_ada = nn.Linear(embed_dim, 6 * embed_dim, bias=True)
            else:
                self.time_ada = None
        else:
            raise NotImplementedError(f"unsupported time_fusion: {time_fusion!r}")

        # --- context conditioning ---
        self.use_context = False
        self.context_cross = False
        self.context_max_length = context_max_length
        self.context_fusion = 'none'
        if context_dim is not None:
            self.use_context = True
            self.context_embed = nn.Sequential(
                nn.Linear(context_dim, embed_dim, bias=True),
                nn.SiLU(),
                nn.Linear(embed_dim, embed_dim, bias=True),
            )
            self.context_fusion = context_fusion
            if context_fusion in ('concat', 'joint'):
                self.extras += context_max_length
                self.context_pe = PE_wrapper(
                    dim=embed_dim,
                    method=context_pe_method,
                    length=context_max_length
                )
                # Context is concatenated, so blocks get no cross-attention.
                context_dim = None
            elif context_fusion == 'cross':
                self.context_pe = PE_wrapper(
                    dim=embed_dim,
                    method=context_pe_method,
                    length=context_max_length
                )
                self.context_cross = True
                # Blocks cross-attend to the already-projected context.
                context_dim = embed_dim
            else:
                raise NotImplementedError(
                    f"unsupported context_fusion: {context_fusion!r}"
                )

        self.use_skip = skip
        if norm_layer == 'layernorm':
            norm_layer = nn.LayerNorm
        elif norm_layer == 'rmsnorm':
            norm_layer = RMSNorm
        else:
            raise NotImplementedError(f"unsupported norm_layer: {norm_layer!r}")

        # --- transformer blocks (in -> mid -> out) ---
        block_kwargs = dict(
            dim=embed_dim,
            context_dim=context_dim,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            qk_norm=qk_norm,
            act_layer=act_layer,
            norm_layer=norm_layer,
            time_fusion=time_fusion,
            ada_sola_rank=ada_sola_rank,
            ada_sola_alpha=ada_sola_alpha,
            rope_mode=self.rope,
            context_norm=context_norm,
            use_checkpoint=use_checkpoint,
        )
        half_depth = depth // 2
        self.in_blocks = nn.ModuleList([
            DiTBlock(skip=False, skip_norm=False, **block_kwargs)
            for _ in range(half_depth)
        ])
        self.mid_block = DiTBlock(skip=False, skip_norm=False, **block_kwargs)
        # Only output blocks consume long skips.
        self.out_blocks = nn.ModuleList([
            DiTBlock(skip=skip, skip_norm=skip_norm, **block_kwargs)
            for _ in range(half_depth)
        ])

        self.use_conv = use_conv
        self.final_block = FinalBlock(
            embed_dim=embed_dim,
            patch_size=patch_size,
            img_size=img_size,
            in_chans=out_chans,
            input_type=input_type,
            norm_layer=norm_layer,
            use_conv=use_conv,
            use_adanorm=self.use_adanorm
        )
        self.initialize_weights()

    def _all_blocks(self):
        """Return every transformer block in input -> middle -> output order."""
        return [*self.in_blocks, self.mid_block, *self.out_blocks]

    def _init_ada(self):
        """Zero the adaLN projections and LoRA up-projections.

        For ``'ada'`` each block's own ``time_ada`` linear is zeroed. For the
        shared modes the model-level ``time_ada`` is zeroed. For ``ada_sola*``
        each block's ``lora_a`` gets Kaiming-uniform init and ``lora_b`` is
        zeroed. ``time_ada_final`` is zeroed in every mode.
        """
        if self.time_fusion == 'ada':
            nn.init.constant_(self.time_ada_final.weight, 0)
            nn.init.constant_(self.time_ada_final.bias, 0)
            for block in self._all_blocks():
                nn.init.constant_(block.adaln.time_ada.weight, 0)
                nn.init.constant_(block.adaln.time_ada.bias, 0)
        elif self.time_fusion in ('ada_single', 'ada_sola', 'ada_sola_bias'):
            for layer in (self.time_ada, self.time_ada_final):
                nn.init.constant_(layer.weight, 0)
                nn.init.constant_(layer.bias, 0)
            if self.time_fusion != 'ada_single':
                for block in self._all_blocks():
                    nn.init.kaiming_uniform_(
                        block.adaln.lora_a.weight, a=math.sqrt(5)
                    )
                    nn.init.constant_(block.adaln.lora_b.weight, 0)

    def initialize_weights(self):
        """Initialise parameters.

        Applies Xavier-uniform init with zero bias to every ``nn.Linear`` and
        to the flattened patch-embedding projection. It then zeroes:

        - the adaLN projections;
        - the cross-attention output projections (``'cross'`` fusion only);
        - the last class-embedding layer (under adaLN only).

        Finally it re-initialises the final convolution.
        """
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)
        w = self.patch_embed.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.patch_embed.proj.bias, 0)
        if self.use_adanorm:
            self._init_ada()
        if self.context_cross:
            for block in self._all_blocks():
                nn.init.constant_(block.cross_attn.proj.weight, 0)
                nn.init.constant_(block.cross_attn.proj.bias, 0)
        if self.cls_embed is not None and self.use_adanorm:
            nn.init.constant_(self.cls_embed[-1].weight, 0)
            nn.init.constant_(self.cls_embed[-1].bias, 0)
        if self.use_conv:
            nn.init.xavier_uniform_(self.final_block.final_layer.weight)
            nn.init.constant_(self.final_block.final_layer.bias, 0)

    def _concat_x_context(self, x, context, x_mask=None, context_mask=None):
        """Prepend ``context`` tokens to ``x`` and build the joint boolean mask.

        Missing masks default to all-true. Returns ``(x, x_mask)`` with context
        tokens first.
        """
        assert context.shape[-2] == self.context_max_length
        B = x.shape[0]
        if x_mask is None:
            x_mask = torch.ones(B, x.shape[-2], device=x.device).bool()
        if context_mask is None:
            context_mask = torch.ones(
                B, context.shape[-2], device=context.device
            ).bool()
        x_mask = torch.cat([context_mask, x_mask], dim=1)
        x = torch.cat((context, x), dim=1)
        return x, x_mask

    def forward(
        self,
        x,
        timesteps,
        context,
        x_mask=None,
        context_mask=None,
        cls_token=None,
        controlnet_skips=None,
    ):
        """Run the denoiser.

        Args:
            x: Input signal, fed to ``PatchEmbed``.
            timesteps: Per-sample timesteps, or a 0-d tensor that is
                broadcast to the batch (cast to ``torch.long``).
            context: Conditioning sequence for ``context_embed``; ignored when
                the model has no context.
            x_mask: Optional boolean mask over patch tokens.
            context_mask: Optional boolean mask over context tokens.
            cls_token: Class features; required when ``cls_dim`` was set.
            controlnet_skips: Optional sequence of residuals, consumed from
                the end: one per output block. Each is added to the long skip
                when ``use_skip``, otherwise to ``x``. The caller's
                sequence is not modified.
        """
        # Allow a scalar timestep (e.g. during sampling).
        if timesteps.dim() == 0:
            timesteps = timesteps.expand(x.shape[0]).to(x.device, dtype=torch.long)
        x = self.patch_embed(x)
        x = self.x_pe(x)
        B = x.shape[0]

        if self.use_context:
            context_token = self.context_embed(context)
            context_token = self.context_pe(context_token)
            if self.context_fusion in ('concat', 'joint'):
                x, x_mask = self._concat_x_context(
                    x=x,
                    context=context_token,
                    x_mask=x_mask,
                    context_mask=context_mask
                )
                context_token, context_mask = None, None
        else:
            context_token, context_mask = None, None

        time_token = self.time_embed(timesteps)
        if self.cls_embed is not None:
            cls_token = self.cls_embed(cls_token)
        time_ada = None
        time_ada_final = None
        if self.use_adanorm:
            if self.cls_embed is not None:
                time_token = time_token + cls_token
            time_token = self.time_act(time_token)
            time_ada_final = self.time_ada_final(time_token)
            if self.time_ada is not None:
                time_ada = self.time_ada(time_token)
        else:
            # Token fusion: prepend time (and class) tokens to the sequence.
            time_token = time_token.unsqueeze(dim=1)
            if self.cls_embed is not None:
                cls_token = cls_token.unsqueeze(dim=1)
                time_token = torch.cat([time_token, cls_token], dim=1)
            time_token = self.time_pe(time_token)
            x = torch.cat((time_token, x), dim=1)
            if x_mask is not None:
                x_mask = torch.cat([
                    torch.ones(B, time_token.shape[1],
                               device=x_mask.device).bool(), x_mask
                ], dim=1)
            time_token = None

        # Pop from a private copy so the caller's list is not drained.
        controlnet_skips = list(controlnet_skips) if controlnet_skips else []
        block_kwargs = dict(
            time_token=time_token,
            time_ada=time_ada,
            context=context_token,
            x_mask=x_mask,
            context_mask=context_mask,
            extras=self.extras,
        )

        skips = []
        for blk in self.in_blocks:
            x = blk(x=x, skip=None, **block_kwargs)
            if self.use_skip:
                skips.append(x)
        x = self.mid_block(x=x, skip=None, **block_kwargs)
        for blk in self.out_blocks:
            if self.use_skip:
                skip = skips.pop()
                if controlnet_skips:
                    skip = skip + controlnet_skips.pop()
            else:
                skip = None
                if controlnet_skips:
                    x = x + controlnet_skips.pop()
            x = blk(x=x, skip=skip, **block_kwargs)
        return self.final_block(x, time_ada=time_ada_final, extras=self.extras)
class LayerFusionDiTBlock(DiTBlock):
    """``DiTBlock`` that also fuses a time-aligned context after self-attention.

    Fusion modes (``ta_context_fusion``):
        ``'add'``: ``x + Linear(norm(ta_context))``, where the projection maps
            ``ta_context_dim -> dim`` with no bias.
        ``'concat'``: ``x = Linear(norm(cat([x, ta_context], -1)))`` mapping
            ``ta_context_dim + dim -> dim``. This replaces ``x``; it does not
            add a residual.

    ``ta_context_norm`` selects ``norm_layer`` vs ``nn.Identity`` for the norm.
    Any other fusion value leaves the time-aligned context unused. Under
    adaLN the self-attention residual is gated by ``tanh(1 - gate_msa)``,
    whereas ``DiTBlock`` uses ``(1 - gate_msa)``.
    """

    def __init__(
        self,
        dim,
        ta_context_dim,
        ta_context_norm=False,
        context_dim=None,
        num_heads=8,
        mlp_ratio=4.,
        qkv_bias=False,
        qk_scale=None,
        qk_norm=None,
        act_layer='gelu',
        norm_layer=nn.LayerNorm,
        ta_context_fusion='add',
        time_fusion='none',
        ada_sola_rank=None,
        ada_sola_alpha=None,
        skip=False,
        skip_norm=False,
        rope_mode='none',
        context_norm=False,
        use_checkpoint=False
    ):
        super().__init__(
            dim=dim,
            context_dim=context_dim,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            qk_norm=qk_norm,
            act_layer=act_layer,
            norm_layer=norm_layer,
            time_fusion=time_fusion,
            ada_sola_rank=ada_sola_rank,
            ada_sola_alpha=ada_sola_alpha,
            skip=skip,
            skip_norm=skip_norm,
            rope_mode=rope_mode,
            context_norm=context_norm,
            use_checkpoint=use_checkpoint
        )
        self.ta_context_fusion = ta_context_fusion
        # Holds the bool flag until replaced by the norm module below.
        self.ta_context_norm = ta_context_norm
        if self.ta_context_fusion == "add":
            self.ta_context_projection = nn.Linear(
                ta_context_dim, dim, bias=False
            )
            self.ta_context_norm = norm_layer(
                ta_context_dim
            ) if self.ta_context_norm else nn.Identity()
        elif self.ta_context_fusion == "concat":
            self.ta_context_projection = nn.Linear(ta_context_dim + dim, dim)
            self.ta_context_norm = norm_layer(
                ta_context_dim + dim
            ) if self.ta_context_norm else nn.Identity()

    @staticmethod
    def _left_pad_to(seq, length):
        """Zero-pad ``seq`` at the front of its second-to-last dim to ``length``.

        Used to leave the tokens the model prepends (time/class/context)
        untouched. A longer ``seq`` is returned unchanged.
        """
        missing = length - seq.size(1)
        if missing > 0:
            seq = nn.functional.pad(seq, (0, 0, missing, 0))
        return seq

    def forward(
        self,
        x,
        time_aligned_context,
        time_token=None,
        time_ada=None,
        skip=None,
        context=None,
        x_mask=None,
        context_mask=None,
        extras=None
    ):
        """Apply the block, optionally under activation checkpointing.

        ``time_aligned_context`` may be shorter than ``x`` along dim 1; it is
        zero-padded at the front to ``x``'s length before fusion. Other
        arguments are as in ``DiTBlock.forward``.
        """
        if self.use_checkpoint:
            from torch.utils.checkpoint import checkpoint
            return checkpoint(
                self._forward,
                x, time_aligned_context, time_token, time_ada, skip, context,
                x_mask, context_mask, extras,
                use_reentrant=False
            )
        return self._forward(
            x, time_aligned_context, time_token, time_ada, skip, context,
            x_mask, context_mask, extras,
        )

    def _forward(
        self,
        x,
        time_aligned_context,
        time_token=None,
        time_ada=None,
        skip=None,
        context=None,
        x_mask=None,
        context_mask=None,
        extras=None
    ):
        """Skip fusion -> self-attn -> time-aligned fusion -> cross-attn -> MLP."""
        if self.skip_linear is not None:
            assert skip is not None
            cat = torch.cat([x, skip], dim=-1)
            cat = self.skip_norm(cat)
            x = self.skip_linear(cat)

        if self.use_adanorm:
            time_ada = self.adaln(time_token, time_ada)
            (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp,
             gate_mlp) = time_ada.chunk(6, dim=1)

        if self.use_adanorm:
            x_norm = film_modulate(
                self.norm1(x), shift=shift_msa, scale=scale_msa
            )
            tanh_gate_msa = torch.tanh(1 - gate_msa)
            x = x + tanh_gate_msa * self.attn(
                x_norm, context=None, context_mask=x_mask, extras=extras
            )
        else:
            x = x + self.attn(
                self.norm1(x),
                context=None,
                context_mask=x_mask,
                extras=extras
            )

        # Pad by the actual shortfall rather than a fixed single token, so the
        # fusion also works when several tokens were prepended (class token,
        # concatenated context).
        if self.ta_context_fusion == "add":
            time_aligned_context = self.ta_context_projection(
                self.ta_context_norm(time_aligned_context)
            )
            time_aligned_context = self._left_pad_to(
                time_aligned_context, x.size(1)
            )
            x = x + time_aligned_context
        elif self.ta_context_fusion == "concat":
            time_aligned_context = self._left_pad_to(
                time_aligned_context, x.size(1)
            )
            cat = torch.cat([x, time_aligned_context], dim=-1)
            cat = self.ta_context_norm(cat)
            x = self.ta_context_projection(cat)

        if self.use_context:
            assert context is not None
            x = x + self.cross_attn(
                x=self.norm2(x),
                context=self.norm_context(context),
                context_mask=context_mask,
                extras=extras
            )

        if self.use_adanorm:
            x_norm = film_modulate(
                self.norm3(x), shift=shift_mlp, scale=scale_mlp
            )
            x = x + (1 - gate_mlp) * self.mlp(x_norm)
        else:
            x = x + self.mlp(self.norm3(x))
        return x
class LayerFusionAudioDiT(UDiT):
    """``UDiT`` variant whose blocks also fuse a time-aligned context.

    Construction mirrors ``UDiT.__init__`` but builds ``LayerFusionDiTBlock``s.
    ``nn.Module.__init__`` is called directly so that ``UDiT``'s blocks are
    never created. The ``ta_context_*`` arguments are forwarded to every
    block. Weight init and ``_concat_x_context`` are inherited from ``UDiT``.

    Raises:
        NotImplementedError: For an unsupported ``input_type``,
            ``time_fusion``, ``context_fusion`` or ``norm_layer``.
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        input_type='2d',
        out_chans=None,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=False,
        qk_scale=None,
        qk_norm=None,
        act_layer='gelu',
        norm_layer='layernorm',
        context_norm=False,
        use_checkpoint=False,
        time_fusion='token',
        ada_sola_rank=None,
        ada_sola_alpha=None,
        cls_dim=None,
        ta_context_dim=768,
        ta_context_fusion='concat',
        ta_context_norm=True,
        context_dim=768,
        context_fusion='concat',
        context_max_length=128,
        context_pe_method='sinu',
        pe_method='abs',
        rope_mode='none',
        use_conv=True,
        skip=True,
        skip_norm=True
    ):
        # Skip UDiT.__init__ on purpose: it would build plain DiTBlocks.
        nn.Module.__init__(self)
        # NOTE: submodule creation order is kept stable because
        # initialize_weights() consumes the RNG in registration order.
        self.num_features = self.embed_dim = embed_dim
        self.in_chans = in_chans
        self.input_type = input_type
        if input_type == '2d':
            num_patches = (img_size[0] // patch_size) * (img_size[1] // patch_size)
        elif input_type == '1d':
            num_patches = img_size // patch_size
        else:
            raise NotImplementedError(f"unsupported input_type: {input_type!r}")
        self.patch_embed = PatchEmbed(
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            input_type=input_type
        )
        out_chans = in_chans if out_chans is None else out_chans
        self.out_chans = out_chans
        self.rope = rope_mode
        self.x_pe = PE_wrapper(
            dim=embed_dim, method=pe_method, length=num_patches
        )

        # --- timestep / class conditioning ---
        self.time_embed = TimestepEmbedder(embed_dim)
        self.time_fusion = time_fusion
        self.use_adanorm = False
        if cls_dim is not None:
            self.cls_embed = nn.Sequential(
                nn.Linear(cls_dim, embed_dim, bias=True),
                nn.SiLU(),
                nn.Linear(embed_dim, embed_dim, bias=True),
            )
        else:
            self.cls_embed = None
        if time_fusion == 'token':
            self.extras = 2 if self.cls_embed is not None else 1
            self.time_pe = PE_wrapper(
                dim=embed_dim, method='abs', length=self.extras
            )
        elif time_fusion in ('ada', 'ada_single', 'ada_sola', 'ada_sola_bias'):
            self.use_adanorm = True
            self.time_act = nn.SiLU()
            self.extras = 0
            self.time_ada_final = nn.Linear(
                embed_dim, 2 * embed_dim, bias=True
            )
            if time_fusion in ('ada_single', 'ada_sola', 'ada_sola_bias'):
                self.time_ada = nn.Linear(embed_dim, 6 * embed_dim, bias=True)
            else:
                self.time_ada = None
        else:
            raise NotImplementedError(f"unsupported time_fusion: {time_fusion!r}")

        # --- context conditioning ---
        self.use_context = False
        self.context_cross = False
        self.context_max_length = context_max_length
        self.context_fusion = 'none'
        if context_dim is not None:
            self.use_context = True
            self.context_embed = nn.Sequential(
                nn.Linear(context_dim, embed_dim, bias=True),
                nn.SiLU(),
                nn.Linear(embed_dim, embed_dim, bias=True),
            )
            self.context_fusion = context_fusion
            if context_fusion in ('concat', 'joint'):
                self.extras += context_max_length
                self.context_pe = PE_wrapper(
                    dim=embed_dim,
                    method=context_pe_method,
                    length=context_max_length
                )
                context_dim = None
            elif context_fusion == 'cross':
                self.context_pe = PE_wrapper(
                    dim=embed_dim,
                    method=context_pe_method,
                    length=context_max_length
                )
                self.context_cross = True
                context_dim = embed_dim
            else:
                raise NotImplementedError(
                    f"unsupported context_fusion: {context_fusion!r}"
                )

        self.use_skip = skip
        if norm_layer == 'layernorm':
            norm_layer = nn.LayerNorm
        elif norm_layer == 'rmsnorm':
            norm_layer = RMSNorm
        else:
            raise NotImplementedError(f"unsupported norm_layer: {norm_layer!r}")

        # --- transformer blocks (in -> mid -> out) ---
        block_kwargs = dict(
            dim=embed_dim,
            ta_context_dim=ta_context_dim,
            ta_context_fusion=ta_context_fusion,
            ta_context_norm=ta_context_norm,
            context_dim=context_dim,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            qk_norm=qk_norm,
            act_layer=act_layer,
            norm_layer=norm_layer,
            time_fusion=time_fusion,
            ada_sola_rank=ada_sola_rank,
            ada_sola_alpha=ada_sola_alpha,
            rope_mode=self.rope,
            context_norm=context_norm,
            use_checkpoint=use_checkpoint,
        )
        half_depth = depth // 2
        self.in_blocks = nn.ModuleList([
            LayerFusionDiTBlock(skip=False, skip_norm=False, **block_kwargs)
            for _ in range(half_depth)
        ])
        self.mid_block = LayerFusionDiTBlock(
            skip=False, skip_norm=False, **block_kwargs
        )
        self.out_blocks = nn.ModuleList([
            LayerFusionDiTBlock(skip=skip, skip_norm=skip_norm, **block_kwargs)
            for _ in range(half_depth)
        ])

        self.use_conv = use_conv
        self.final_block = FinalBlock(
            embed_dim=embed_dim,
            patch_size=patch_size,
            img_size=img_size,
            in_chans=out_chans,
            input_type=input_type,
            norm_layer=norm_layer,
            use_conv=use_conv,
            use_adanorm=self.use_adanorm
        )
        self.initialize_weights()

    def forward(
        self,
        x,
        timesteps,
        time_aligned_context,
        context,
        x_mask=None,
        context_mask=None,
        cls_token=None,
        controlnet_skips=None,
    ):
        """Run the denoiser with an extra time-aligned context.

        Arguments are as in ``UDiT.forward``. ``time_aligned_context`` is
        passed unchanged to every block, which projects it and zero-pads it
        at the front to the full token length. ``controlnet_skips`` is
        consumed from a private copy; the caller's sequence is not modified.
        """
        # Allow a scalar timestep (e.g. during sampling).
        if timesteps.dim() == 0:
            timesteps = timesteps.expand(x.shape[0]).to(x.device, dtype=torch.long)
        x = self.patch_embed(x)
        x = self.x_pe(x)
        B = x.shape[0]

        if self.use_context:
            context_token = self.context_embed(context)
            context_token = self.context_pe(context_token)
            if self.context_fusion in ('concat', 'joint'):
                x, x_mask = self._concat_x_context(
                    x=x,
                    context=context_token,
                    x_mask=x_mask,
                    context_mask=context_mask
                )
                context_token, context_mask = None, None
        else:
            context_token, context_mask = None, None

        time_token = self.time_embed(timesteps)
        if self.cls_embed is not None:
            cls_token = self.cls_embed(cls_token)
        time_ada = None
        time_ada_final = None
        if self.use_adanorm:
            if self.cls_embed is not None:
                time_token = time_token + cls_token
            time_token = self.time_act(time_token)
            time_ada_final = self.time_ada_final(time_token)
            if self.time_ada is not None:
                time_ada = self.time_ada(time_token)
        else:
            # Token fusion: prepend time (and class) tokens to the sequence.
            time_token = time_token.unsqueeze(dim=1)
            if self.cls_embed is not None:
                cls_token = cls_token.unsqueeze(dim=1)
                time_token = torch.cat([time_token, cls_token], dim=1)
            time_token = self.time_pe(time_token)
            x = torch.cat((time_token, x), dim=1)
            if x_mask is not None:
                x_mask = torch.cat([
                    torch.ones(B, time_token.shape[1],
                               device=x_mask.device).bool(), x_mask
                ], dim=1)
            time_token = None

        # Pop from a private copy so the caller's list is not drained.
        controlnet_skips = list(controlnet_skips) if controlnet_skips else []
        block_kwargs = dict(
            time_aligned_context=time_aligned_context,
            time_token=time_token,
            time_ada=time_ada,
            context=context_token,
            x_mask=x_mask,
            context_mask=context_mask,
            extras=self.extras,
        )

        skips = []
        for blk in self.in_blocks:
            x = blk(x=x, skip=None, **block_kwargs)
            if self.use_skip:
                skips.append(x)
        x = self.mid_block(x=x, skip=None, **block_kwargs)
        for blk in self.out_blocks:
            if self.use_skip:
                skip = skips.pop()
                if controlnet_skips:
                    skip = skip + controlnet_skips.pop()
            else:
                skip = None
                if controlnet_skips:
                    x = x + controlnet_skips.pop()
            x = blk(x=x, skip=skip, **block_kwargs)
        return self.final_block(x, time_ada=time_ada_final, extras=self.extras)