import math
import types
from copy import deepcopy
from typing import Any, Dict

import torch
import torch.cuda.amp as amp
import torch.nn as nn
from diffusers.configuration_utils import register_to_config
from diffusers.utils import is_torch_version
from einops import rearrange

from ..dist import (get_sequence_parallel_rank,
                    get_sequence_parallel_world_size, get_sp_group,
                    usp_attn_s2v_forward)
from .attention_utils import attention
from .wan_audio_injector import (AudioInjector_WAN, CausalAudioEncoder,
                                 FramePackMotioner, MotionerTransformers,
                                 rope_precompute)
from .wan_transformer3d import (Wan2_2Transformer3DModel, WanAttentionBlock,
                                WanLayerNorm, WanSelfAttention,
                                sinusoidal_embedding_1d)

def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module

def torch_dfs(model: nn.Module, parent_name='root'):
    """
    Depth-first traversal of a module tree.

    Returns all sub-modules (including `model` itself) together with their
    dotted names rooted at `parent_name`.
    """
    module_names, modules = [], []
    current_name = parent_name if parent_name else 'root'
    module_names.append(current_name)
    modules.append(model)

    for name, child in model.named_children():
        if parent_name:
            child_name = f'{parent_name}.{name}'
        else:
            child_name = name
        child_modules, child_names = torch_dfs(child, child_name)
        module_names += child_names
        modules += child_modules
    return modules, module_names

@amp.autocast(enabled=False)
@torch.compiler.disable()
def s2v_rope_apply(x, grid_sizes, freqs, start=None):
    # x: [B, L, num_heads, head_dim]; `freqs` holds one precomputed complex
    # rotation per token (see rope_precompute).
    n, c = x.size(2), x.size(3) // 2

    output = []
    for i, _ in enumerate(x):
        s = x.size(1)
        # View channel pairs as complex numbers and rotate them by the
        # precomputed per-token frequencies.
        x_i = torch.view_as_complex(
            x[i, :s].to(torch.float64).reshape(s, n, -1, 2))
        freqs_i = freqs[i, :s]

        x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
        x_i = torch.cat([x_i, x[i, s:]])

        output.append(x_i)
    return torch.stack(output).float()

def s2v_rope_apply_qk(q, k, grid_sizes, freqs):
    q = s2v_rope_apply(q, grid_sizes, freqs)
    k = s2v_rope_apply(k, grid_sizes, freqs)
    return q, k

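# Note: the `freqs` consumed by s2v_rope_apply_qk is not the raw frequency
# table; it is a per-token rotation table produced by `rope_precompute` in
# Wan2_2Transformer3DModel_S2V.forward. Precomputing the rotations this way
# appears to be what lets reference and motion tokens be assigned custom
# positions on the space-time grid.
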
class WanS2VSelfAttention(WanSelfAttention):

    def forward(self, x, seq_lens, grid_sizes, freqs, dtype=torch.bfloat16,
                t=0):
        """
        Args:
            x(Tensor): Shape [B, L, C]
            seq_lens(Tensor): Shape [B]
            grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
            freqs(Tensor): Precomputed per-token RoPE rotations (see `rope_precompute`)
        """
        b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim

        def qkv_fn(x):
            q = self.norm_q(self.q(x)).view(b, s, n, d)
            k = self.norm_k(self.k(x)).view(b, s, n, d)
            v = self.v(x).view(b, s, n, d)
            return q, k, v

        q, k, v = qkv_fn(x)

        q, k = s2v_rope_apply_qk(q, k, grid_sizes, freqs)

        x = attention(
            q.to(dtype),
            k.to(dtype),
            v=v.to(dtype),
            k_lens=seq_lens,
            window_size=self.window_size)
        x = x.to(dtype)

        x = x.flatten(2)
        x = self.o(x)
        return x

class WanS2VAttentionBlock(WanAttentionBlock):

    def __init__(self,
                 cross_attn_type,
                 dim,
                 ffn_dim,
                 num_heads,
                 window_size=(-1, -1),
                 qk_norm=True,
                 cross_attn_norm=False,
                 eps=1e-6):
        super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size,
                         qk_norm, cross_attn_norm, eps)
        self.self_attn = WanS2VSelfAttention(dim, num_heads, window_size,
                                             qk_norm, eps)

    def forward(self, x, e, seq_lens, grid_sizes, freqs, context, context_lens,
                dtype=torch.bfloat16, t=0):
        # `e` is a pair [modulation, seg_idx]: the sequence is split at
        # `seg_idx` and each segment gets its own set of the six AdaLN
        # parameters (see the note after this class).
        seg_idx = e[1].item()
        seg_idx = min(max(0, seg_idx), x.size(1))
        seg_idx = [0, seg_idx, x.size(1)]
        e = e[0]
        modulation = self.modulation.unsqueeze(2)
        e = (modulation + e).chunk(6, dim=1)
        e = [element.squeeze(1) for element in e]

        # self-attention
        norm_x = self.norm1(x).float()
        parts = []
        for i in range(2):
            parts.append(norm_x[:, seg_idx[i]:seg_idx[i + 1]] *
                         (1 + e[1][:, i:i + 1]) + e[0][:, i:i + 1])
        norm_x = torch.cat(parts, dim=1)

        y = self.self_attn(norm_x, seq_lens, grid_sizes, freqs)
        with amp.autocast(dtype=torch.float32):
            z = []
            for i in range(2):
                z.append(y[:, seg_idx[i]:seg_idx[i + 1]] * e[2][:, i:i + 1])
            y = torch.cat(z, dim=1)
            x = x + y

        # cross-attention & ffn
        def cross_attn_ffn(x, context, context_lens, e):
            x = x + self.cross_attn(self.norm3(x), context, context_lens)
            norm2_x = self.norm2(x).float()
            parts = []
            for i in range(2):
                parts.append(norm2_x[:, seg_idx[i]:seg_idx[i + 1]] *
                             (1 + e[4][:, i:i + 1]) + e[3][:, i:i + 1])
            norm2_x = torch.cat(parts, dim=1)
            y = self.ffn(norm2_x)
            with amp.autocast(dtype=torch.float32):
                z = []
                for i in range(2):
                    z.append(y[:, seg_idx[i]:seg_idx[i + 1]] *
                             e[5][:, i:i + 1])
                y = torch.cat(z, dim=1)
                x = x + y
            return x

        x = cross_attn_ffn(x, context, context_lens, e)
        return x

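# Note on the `e` argument consumed by WanS2VAttentionBlock.forward: the model
# packs it as [modulation, seg_idx], where `modulation` carries two sets of the
# six AdaLN parameters along an extra axis and `seg_idx` is the number of video
# tokens. Tokens before `seg_idx` (the denoised video tokens) use set 0, while
# the appended reference/motion tokens use set 1, which is built from a zero
# timestep when `zero_timestep` is enabled (see
# Wan2_2Transformer3DModel_S2V.forward below).
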
class Wan2_2Transformer3DModel_S2V(Wan2_2Transformer3DModel):
    """
    Audio-driven (S2V) variant of Wan2_2Transformer3DModel.

    On top of the base video DiT it adds a causal audio encoder with per-layer
    audio cross-attention injection, an optional condition (e.g. pose) encoder,
    a reference-image branch, and motion-frame conditioning implemented by
    either a Motioner transformer, FramePack-style packing, or plain
    patch-embedded motion latents.
    """

    @register_to_config
    def __init__(
        self,
        cond_dim=0,
        audio_dim=5120,
        num_audio_token=4,
        enable_adain=False,
        adain_mode="attn_norm",
        audio_inject_layers=[0, 4, 8, 12, 16, 20, 24, 27],
        zero_init=False,
        zero_timestep=False,
        enable_motioner=True,
        add_last_motion=True,
        enable_tsm=False,
        trainable_token_pos_emb=False,
        motion_token_num=1024,
        enable_framepack=False,
        framepack_drop_mode="drop",
        model_type='s2v',
        patch_size=(1, 2, 2),
        text_len=512,
        in_dim=16,
        dim=2048,
        ffn_dim=8192,
        freq_dim=256,
        text_dim=4096,
        out_dim=16,
        num_heads=16,
        num_layers=32,
        window_size=(-1, -1),
        qk_norm=True,
        cross_attn_norm=True,
        eps=1e-6,
        in_channels=16,
        hidden_size=2048,
        *args,
        **kwargs
    ):
        super().__init__(
            model_type=model_type,
            patch_size=patch_size,
            text_len=text_len,
            in_dim=in_dim,
            dim=dim,
            ffn_dim=ffn_dim,
            freq_dim=freq_dim,
            text_dim=text_dim,
            out_dim=out_dim,
            num_heads=num_heads,
            num_layers=num_layers,
            window_size=window_size,
            qk_norm=qk_norm,
            cross_attn_norm=cross_attn_norm,
            eps=eps,
            in_channels=in_channels,
            hidden_size=hidden_size
        )

        assert model_type == 's2v'
        self.enbale_adain = enable_adain
        self.adain_mode = adain_mode
        self.zero_timestep = zero_timestep
        self.enable_motioner = enable_motioner
        self.add_last_motion = add_last_motion
        self.enable_framepack = enable_framepack

        self.blocks = nn.ModuleList([
            WanS2VAttentionBlock("cross_attn", dim, ffn_dim, num_heads,
                                 window_size, qk_norm, cross_attn_norm, eps)
            for _ in range(num_layers)
        ])

        all_modules, all_modules_names = torch_dfs(
            self.blocks, parent_name="root.transformer_blocks")
        if cond_dim > 0:
            self.cond_encoder = nn.Conv3d(
                cond_dim,
                self.dim,
                kernel_size=self.patch_size,
                stride=self.patch_size)
        # Token-type embedding: 0 = video, 1 = reference, 2 = motion tokens.
        self.trainable_cond_mask = nn.Embedding(3, self.dim)
        self.casual_audio_encoder = CausalAudioEncoder(
            dim=audio_dim,
            out_dim=self.dim,
            num_token=num_audio_token,
            need_global=enable_adain)
        self.audio_injector = AudioInjector_WAN(
            all_modules,
            all_modules_names,
            dim=self.dim,
            num_heads=self.num_heads,
            inject_layer=audio_inject_layers,
            root_net=self,
            enable_adain=enable_adain,
            adain_dim=self.dim,
            need_adain_ont=adain_mode != "attn_norm",
        )

        if zero_init:
            self.zero_init_weights()

        if enable_motioner and enable_framepack:
            raise ValueError(
                "enable_motioner and enable_framepack are mutually exclusive, "
                "please set one of them to False")
        if enable_motioner:
            motioner_dim = 2048
            self.motioner = MotionerTransformers(
                patch_size=(2, 4, 4),
                dim=motioner_dim,
                ffn_dim=motioner_dim,
                freq_dim=256,
                out_dim=16,
                num_heads=16,
                num_layers=13,
                window_size=(-1, -1),
                qk_norm=True,
                cross_attn_norm=False,
                eps=1e-6,
                motion_token_num=motion_token_num,
                enable_tsm=enable_tsm,
                motion_stride=4,
                expand_ratio=2,
                trainable_token_pos_emb=trainable_token_pos_emb,
            )
            self.zip_motion_out = torch.nn.Sequential(
                WanLayerNorm(motioner_dim),
                zero_module(nn.Linear(motioner_dim, self.dim)))

            self.trainable_token_pos_emb = trainable_token_pos_emb
            if trainable_token_pos_emb:
                d = self.dim // self.num_heads
                x = torch.zeros([1, motion_token_num, self.num_heads, d])
                x[..., ::2] = 1

                gride_sizes = [[
                    torch.tensor([0, 0, 0]).unsqueeze(0).repeat(1, 1),
                    torch.tensor([
                        1, self.motioner.motion_side_len,
                        self.motioner.motion_side_len
                    ]).unsqueeze(0).repeat(1, 1),
                    torch.tensor([
                        1, self.motioner.motion_side_len,
                        self.motioner.motion_side_len
                    ]).unsqueeze(0).repeat(1, 1),
                ]]
                token_freqs = s2v_rope_apply(x, gride_sizes, self.freqs)
                token_freqs = token_freqs[0, :, 0].reshape(
                    motion_token_num, -1, 2)
                token_freqs = token_freqs * 0.01
                self.token_freqs = torch.nn.Parameter(token_freqs)

        if enable_framepack:
            self.frame_packer = FramePackMotioner(
                inner_dim=self.dim,
                num_heads=self.num_heads,
                zip_frame_buckets=[1, 2, 16],
                drop_mode=framepack_drop_mode)

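    # Motion conditioning: exactly one of three paths is used. With
    # `enable_motioner`, previous frames are compressed by MotionerTransformers
    # into a fixed number of motion tokens; with `enable_framepack`, the
    # FramePackMotioner packs the motion history into multi-resolution buckets;
    # otherwise `process_motion` simply patch-embeds the raw motion latents.
    # The resulting tokens are appended to the sequence by `inject_motion`
    # further below.
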
    def enable_multi_gpus_inference(self,):
        self.sp_world_size = get_sequence_parallel_world_size()
        self.sp_world_rank = get_sequence_parallel_rank()
        self.all_gather = get_sp_group().all_gather
        for block in self.blocks:
            block.self_attn.forward = types.MethodType(
                usp_attn_s2v_forward, block.self_attn)

    def process_motion(self, motion_latents, drop_motion_frames=False):
        """Patch-embed the raw motion latents and precompute their RoPE table."""
        if drop_motion_frames or motion_latents[0].shape[1] == 0:
            return [], []
        self.lat_motion_frames = motion_latents[0].shape[1]
        mot = [self.patch_embedding(m.unsqueeze(0)) for m in motion_latents]
        batch_size = len(mot)

        mot_remb = []
        flattern_mot = []
        for bs in range(batch_size):
            height, width = mot[bs].shape[3], mot[bs].shape[4]
            flat_mot = mot[bs].flatten(2).transpose(1, 2).contiguous()
            # Motion frames are placed at negative frame indices, i.e. before
            # frame 0 of the clip being denoised.
            motion_grid_sizes = [[
                torch.tensor([-self.lat_motion_frames, 0,
                              0]).unsqueeze(0).repeat(1, 1),
                torch.tensor([0, height, width]).unsqueeze(0).repeat(1, 1),
                torch.tensor([self.lat_motion_frames, height,
                              width]).unsqueeze(0).repeat(1, 1)
            ]]
            motion_rope_emb = rope_precompute(
                flat_mot.detach().view(1, flat_mot.shape[1], self.num_heads,
                                       self.dim // self.num_heads),
                motion_grid_sizes,
                self.freqs,
                start=None)
            mot_remb.append(motion_rope_emb)
            flattern_mot.append(flat_mot)
        return flattern_mot, mot_remb

    def process_motion_frame_pack(self,
                                  motion_latents,
                                  drop_motion_frames=False,
                                  add_last_motion=2):
        flattern_mot, mot_remb = self.frame_packer(motion_latents,
                                                   add_last_motion)
        if drop_motion_frames:
            return [m[:, :0] for m in flattern_mot], \
                   [m[:, :0] for m in mot_remb]
        else:
            return flattern_mot, mot_remb

    def process_motion_transformer_motioner(self,
                                            motion_latents,
                                            drop_motion_frames=False,
                                            add_last_motion=True):
        batch_size = len(motion_latents)
        height = motion_latents[0].shape[2] // self.patch_size[1]
        width = motion_latents[0].shape[3] // self.patch_size[2]

        freqs = self.freqs
        device = self.patch_embedding.weight.device
        if freqs.device != device:
            freqs = freqs.to(device)
        if self.trainable_token_pos_emb:
            with amp.autocast(dtype=torch.float64):
                token_freqs = self.token_freqs.to(torch.float64)
                token_freqs = token_freqs / token_freqs.norm(
                    dim=-1, keepdim=True)
                freqs = [freqs, torch.view_as_complex(token_freqs)]

        # Optionally append the most recent motion frame as explicit tokens.
        if not drop_motion_frames and add_last_motion:
            last_motion_latent = [u[:, -1:] for u in motion_latents]
            last_mot = [
                self.patch_embedding(m.unsqueeze(0))
                for m in last_motion_latent
            ]
            last_mot = [m.flatten(2).transpose(1, 2) for m in last_mot]
            last_mot = torch.cat(last_mot)
            gride_sizes = [[
                torch.tensor([-1, 0, 0]).unsqueeze(0).repeat(batch_size, 1),
                torch.tensor([0, height,
                              width]).unsqueeze(0).repeat(batch_size, 1),
                torch.tensor([1, height,
                              width]).unsqueeze(0).repeat(batch_size, 1)
            ]]
        else:
            last_mot = torch.zeros([batch_size, 0, self.dim],
                                   device=motion_latents[0].device,
                                   dtype=motion_latents[0].dtype)
            gride_sizes = []

        # Compress the full motion history into a fixed set of motion tokens.
        zip_motion = self.motioner(motion_latents)
        zip_motion = self.zip_motion_out(zip_motion)
        if drop_motion_frames:
            zip_motion = zip_motion * 0.0
        zip_motion_grid_sizes = [[
            torch.tensor([-1, 0, 0]).unsqueeze(0).repeat(batch_size, 1),
            torch.tensor([
                0, self.motioner.motion_side_len, self.motioner.motion_side_len
            ]).unsqueeze(0).repeat(batch_size, 1),
            torch.tensor([
                1 if not self.trainable_token_pos_emb else -1, height, width
            ]).unsqueeze(0).repeat(batch_size, 1),
        ]]

        mot = torch.cat([last_mot, zip_motion], dim=1)
        gride_sizes = gride_sizes + zip_motion_grid_sizes

        motion_rope_emb = rope_precompute(
            mot.detach().view(batch_size, mot.shape[1], self.num_heads,
                              self.dim // self.num_heads),
            gride_sizes,
            freqs,
            start=None)
        return [m.unsqueeze(0) for m in mot], \
               [r.unsqueeze(0) for r in motion_rope_emb]

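    # A note on the grid-size entries used above and in `forward`: each entry
    # is a triple of (F, H, W) tensors whose first two elements give the start
    # and end coordinates of a token block on the space-time grid (negative
    # frame indices for motion frames, frames 30..31 for the reference image,
    # frames 0..F for the video). The exact role of the third element is
    # defined by `rope_precompute` in wan_audio_injector; here it appears to
    # supply the extent used to lay the block's positions out on the grid.
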
    def inject_motion(self,
                      x,
                      seq_lens,
                      rope_embs,
                      mask_input,
                      motion_latents,
                      drop_motion_frames=False,
                      add_last_motion=True):
        # Build the motion tokens with the configured motion path and append
        # them (plus their RoPE table and mask ids) to the token sequence.
        if self.enable_motioner:
            mot, mot_remb = self.process_motion_transformer_motioner(
                motion_latents,
                drop_motion_frames=drop_motion_frames,
                add_last_motion=add_last_motion)
        elif self.enable_framepack:
            mot, mot_remb = self.process_motion_frame_pack(
                motion_latents,
                drop_motion_frames=drop_motion_frames,
                add_last_motion=add_last_motion)
        else:
            mot, mot_remb = self.process_motion(
                motion_latents, drop_motion_frames=drop_motion_frames)

        if len(mot) > 0:
            x = [torch.cat([u, m], dim=1) for u, m in zip(x, mot)]
            seq_lens = seq_lens + torch.tensor([r.size(1) for r in mot],
                                               dtype=torch.long)
            rope_embs = [
                torch.cat([u, m], dim=1) for u, m in zip(rope_embs, mot_remb)
            ]
            # Newly appended motion tokens get mask id 2.
            mask_input = [
                torch.cat([
                    m, 2 * torch.ones([1, u.shape[1] - m.shape[1]],
                                      device=m.device,
                                      dtype=m.dtype)
                ],
                          dim=1) for m, u in zip(mask_input, x)
            ]
        return x, seq_lens, rope_embs, mask_input

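    # The mask ids produced by `inject_motion` above and by `forward` below
    # index the token-type embedding `self.trainable_cond_mask`: 0 for video
    # tokens, 1 for reference-image tokens, and 2 for motion tokens. The
    # embedding is added to the hidden states before the transformer blocks run.
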
    def after_transformer_block(self, block_idx, hidden_states):
        if block_idx in self.audio_injector.injected_block_id.keys():
            audio_attn_id = self.audio_injector.injected_block_id[block_idx]
            audio_emb = self.merged_audio_emb
            num_frames = audio_emb.shape[1]

            if self.sp_world_size > 1:
                hidden_states = self.all_gather(hidden_states, dim=1)

            # Only the video tokens attend to the audio; reshape them so that
            # each latent frame attends to its own audio tokens.
            input_hidden_states = hidden_states[:, :self.original_seq_len].clone()
            input_hidden_states = rearrange(
                input_hidden_states, "b (t n) c -> (b t) n c", t=num_frames)

            if self.enbale_adain and self.adain_mode == "attn_norm":
                audio_emb_global = self.audio_emb_global
                audio_emb_global = rearrange(audio_emb_global,
                                             "b t n c -> (b t) n c")
                adain_hidden_states = self.audio_injector.injector_adain_layers[
                    audio_attn_id](
                        input_hidden_states, temb=audio_emb_global[:, 0])
                attn_hidden_states = adain_hidden_states
            else:
                attn_hidden_states = self.audio_injector.injector_pre_norm_feat[
                    audio_attn_id](input_hidden_states)
            audio_emb = rearrange(audio_emb, "b t n c -> (b t) n c",
                                  t=num_frames)
            attn_audio_emb = audio_emb
            context_lens = torch.ones(
                attn_hidden_states.shape[0],
                dtype=torch.long,
                device=attn_hidden_states.device) * attn_audio_emb.shape[1]
            residual_out = self.audio_injector.injector[audio_attn_id](
                x=attn_hidden_states,
                context=attn_audio_emb,
                context_lens=context_lens)
            residual_out = rearrange(
                residual_out, "(b t) n c -> b (t n) c", t=num_frames)
            hidden_states[:, :self.original_seq_len] = (
                hidden_states[:, :self.original_seq_len] + residual_out)

            if self.sp_world_size > 1:
                hidden_states = torch.chunk(
                    hidden_states, self.sp_world_size,
                    dim=1)[self.sp_world_rank]

        return hidden_states

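    # Audio injection (performed by `after_transformer_block` above): after
    # each block listed in `audio_inject_layers`, the video tokens of every
    # latent frame cross-attend to that frame's audio tokens
    # (`merged_audio_emb`) and the result is added back as a residual. With
    # `enable_adain` and `adain_mode == "attn_norm"`, the tokens are first
    # normalized with AdaIN conditioned on the global audio embedding. Under
    # sequence parallelism the full sequence is gathered before the injection
    # and re-chunked afterwards.
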
    def forward(
        self,
        x,
        t,
        context,
        seq_len,
        ref_latents,
        motion_latents,
        cond_states,
        audio_input=None,
        motion_frames=[17, 5],
        add_last_motion=2,
        drop_motion_frames=False,
        cond_flag=True,
        *extra_args,
        **extra_kwargs
    ):
        """
        x: A list of videos, each with shape [C, T, H, W].
        t: [B].
        context: A list of text embeddings, each with shape [L, C].
        seq_len: A list of video token lengths; not needed by this model.
        ref_latents: A list of reference images, one per video, each with shape [C, 1, H, W].
        motion_latents: A list of motion frames, one per video, each with shape [C, T_m, H, W].
        cond_states: A list of condition frames (e.g. pose), each with shape [C, T, H, W].
        audio_input: The input audio embedding, shape [B, num_wav2vec_layer, C_a, T_a].
        motion_frames: The numbers of motion frames and of motion latent frames encoded by the VAE, e.g. [17, 5].
        add_last_motion: For the motioner, if add_last_motion > 0, the most recent (last) motion frame is added.
            For frame packing, the behavior depends on the value of add_last_motion:
                add_last_motion = 0: only the farthest part of the latents (i.e., clean_latents_4x) is included.
                add_last_motion = 1: both clean_latents_2x and clean_latents_4x are included.
                add_last_motion = 2: all motion-related latents are used.
        drop_motion_frames: bool, whether to drop the motion-frame information.
        """
        device = self.patch_embedding.weight.device
        dtype = x.dtype
        add_last_motion = self.add_last_motion * add_last_motion

        # Patch-embed the noisy video latents.
        x = [self.patch_embedding(u.unsqueeze(0)) for u in x]

        # Encode the audio; the first audio frame is repeated so that the
        # motion frames are covered as well, and the part corresponding to the
        # motion latents is dropped afterwards.
        audio_input = torch.cat(
            [audio_input[..., 0:1].repeat(1, 1, 1, motion_frames[0]),
             audio_input],
            dim=-1)
        audio_emb_res = self.casual_audio_encoder(audio_input)
        if self.enbale_adain:
            audio_emb_global, audio_emb = audio_emb_res
            self.audio_emb_global = audio_emb_global[:,
                                                     motion_frames[1]:].clone()
        else:
            audio_emb = audio_emb_res
        self.merged_audio_emb = audio_emb[:, motion_frames[1]:, :]

        # Add the encoded condition (e.g. pose) frames to the video tokens.
        cond = [self.cond_encoder(c.unsqueeze(0)) for c in cond_states]
        x = [x_ + pose for x_, pose in zip(x, cond)]

        grid_sizes = torch.stack(
            [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
        x = [u.flatten(2).transpose(1, 2) for u in x]
        seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)

        original_grid_sizes = deepcopy(grid_sizes)
        grid_sizes = [[torch.zeros_like(grid_sizes), grid_sizes, grid_sizes]]

        # Append the reference-image tokens after the video tokens.
        ref = [self.patch_embedding(r.unsqueeze(0)) for r in ref_latents]
        batch_size = len(ref)
        height, width = ref[0].shape[3], ref[0].shape[4]
        ref = [r.flatten(2).transpose(1, 2) for r in ref]
        x = [torch.cat([u, r], dim=1) for u, r in zip(x, ref)]

        self.original_seq_len = seq_lens[0]
        seq_lens = seq_lens + torch.tensor([r.size(1) for r in ref],
                                           dtype=torch.long)
        # The reference image is assigned a distant frame index (30..31) for
        # its RoPE positions.
        ref_grid_sizes = [[
            torch.tensor([30, 0, 0]).unsqueeze(0).repeat(batch_size, 1),
            torch.tensor([31, height,
                          width]).unsqueeze(0).repeat(batch_size, 1),
            torch.tensor([1, height,
                          width]).unsqueeze(0).repeat(batch_size, 1),
        ]]
        grid_sizes = grid_sizes + ref_grid_sizes

        # Precompute the rotary embeddings for the video + reference tokens.
        x = torch.cat(x)
        b, s, n, d = (x.size(0), x.size(1), self.num_heads,
                      self.dim // self.num_heads)
        self.pre_compute_freqs = rope_precompute(
            x.detach().view(b, s, n, d), grid_sizes, self.freqs, start=None)
        x = [u.unsqueeze(0) for u in x]
        self.pre_compute_freqs = [
            u.unsqueeze(0) for u in self.pre_compute_freqs
        ]

        # Token-type mask: 0 = video tokens, 1 = reference tokens (motion
        # tokens appended in inject_motion get id 2).
        mask_input = [
            torch.zeros([1, u.shape[1]], dtype=torch.long, device=x[0].device)
            for u in x
        ]
        for i in range(len(mask_input)):
            mask_input[i][:, self.original_seq_len:] = 1

        self.lat_motion_frames = motion_latents[0].shape[1]
        x, seq_lens, self.pre_compute_freqs, mask_input = self.inject_motion(
            x,
            seq_lens,
            self.pre_compute_freqs,
            mask_input,
            motion_latents,
            drop_motion_frames=drop_motion_frames,
            add_last_motion=add_last_motion)
        x = torch.cat(x, dim=0)
        self.pre_compute_freqs = torch.cat(self.pre_compute_freqs, dim=0)
        mask_input = torch.cat(mask_input, dim=0)

        # Add the token-type embedding.
        x = x + self.trainable_cond_mask(mask_input).to(x.dtype)

        # Pad every sequence to the same length (a multiple of the
        # sequence-parallel world size when sequence parallelism is enabled).
        seq_len = seq_lens.max()
        if self.sp_world_size > 1:
            seq_len = int(math.ceil(
                seq_len / self.sp_world_size)) * self.sp_world_size
        assert seq_lens.max() <= seq_len
        x = torch.cat([
            torch.cat(
                [u.unsqueeze(0),
                 u.new_zeros(1, seq_len - u.size(0), u.size(1))],
                dim=1) for u in x
        ])

        # Time embeddings. With zero_timestep enabled, an extra zero timestep
        # is appended; its modulation parameters are used for the
        # reference/motion segment of the sequence.
        if self.zero_timestep:
            t = torch.cat([t, torch.zeros([1], dtype=t.dtype,
                                          device=t.device)])
        with amp.autocast(dtype=torch.float32):
            e = self.time_embedding(
                sinusoidal_embedding_1d(self.freq_dim, t).float())
            e0 = self.time_projection(e).unflatten(1, (6, self.dim))
            assert e.dtype == torch.float32 and e0.dtype == torch.float32
        if self.zero_timestep:
            e = e[:-1]
            zero_e0 = e0[-1:]
            e0 = e0[:-1]
            token_len = x.shape[1]

            e0 = torch.cat(
                [
                    e0.unsqueeze(2),
                    zero_e0.unsqueeze(2).repeat(e0.size(0), 1, 1, 1)
                ],
                dim=2)
            e0 = [e0, self.original_seq_len]
        else:
            e0 = e0.unsqueeze(2).repeat(1, 1, 2, 1)
            e0 = [e0, 0]

        context_lens = None
        context = self.text_embedding(
            torch.stack([
                torch.cat(
                    [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
                for u in context
            ]))

        if self.sp_world_size > 1:
            # Split the (padded) sequence across the sequence-parallel ranks
            # and shift the segment boundary into rank-local coordinates.
            x = torch.chunk(x, self.sp_world_size, dim=1)
            sq_size = [u.shape[1] for u in x]
            sq_start_size = sum(sq_size[:self.sp_world_rank])
            x = x[self.sp_world_rank]

            sp_size = x.shape[1]
            seg_idx = e0[1] - sq_start_size
            e0[1] = seg_idx

            self.pre_compute_freqs = torch.chunk(
                self.pre_compute_freqs, self.sp_world_size, dim=1)
            self.pre_compute_freqs = self.pre_compute_freqs[self.sp_world_rank]

        if self.teacache is not None:
            if cond_flag:
                if t.dim() != 1:
                    modulated_inp = e0[0][:, -1, :]
                else:
                    modulated_inp = e0[0]
                skip_flag = self.teacache.cnt < self.teacache.num_skip_start_steps
                if skip_flag:
                    self.should_calc = True
                    self.teacache.accumulated_rel_l1_distance = 0
                else:
                    if cond_flag:
                        rel_l1_distance = self.teacache.compute_rel_l1_distance(
                            self.teacache.previous_modulated_input,
                            modulated_inp)
                        self.teacache.accumulated_rel_l1_distance += \
                            self.teacache.rescale_func(rel_l1_distance)
                    if self.teacache.accumulated_rel_l1_distance < self.teacache.rel_l1_thresh:
                        self.should_calc = False
                    else:
                        self.should_calc = True
                        self.teacache.accumulated_rel_l1_distance = 0
                self.teacache.previous_modulated_input = modulated_inp
                self.teacache.should_calc = self.should_calc
            else:
                self.should_calc = self.teacache.should_calc

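        # TeaCache: the transformer blocks are skipped (and the cached residual
        # from the previous step is reused below) whenever the accumulated
        # relative L1 change of the modulated timestep embedding stays below
        # `rel_l1_thresh`. The decision is made on the conditional pass and
        # reused for the unconditional one.
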
        if self.teacache is not None:
            if not self.should_calc:
                previous_residual = (
                    self.teacache.previous_residual_cond
                    if cond_flag else self.teacache.previous_residual_uncond)
                x = x + previous_residual.to(x.device)[-x.size(0):]
            else:
                ori_x = x.clone().cpu() if self.teacache.offload else x.clone()

                for idx, block in enumerate(self.blocks):
                    if torch.is_grad_enabled() and self.gradient_checkpointing:

                        def create_custom_forward(module):

                            def custom_forward(*inputs):
                                return module(*inputs)

                            return custom_forward

                        ckpt_kwargs: Dict[str, Any] = (
                            {"use_reentrant": False}
                            if is_torch_version(">=", "1.11.0") else {})
                        x = torch.utils.checkpoint.checkpoint(
                            create_custom_forward(block),
                            x,
                            e0,
                            seq_lens,
                            grid_sizes,
                            self.pre_compute_freqs,
                            context,
                            context_lens,
                            dtype,
                            t,
                            **ckpt_kwargs,
                        )
                        x = self.after_transformer_block(idx, x)
                    else:
                        kwargs = dict(
                            e=e0,
                            seq_lens=seq_lens,
                            grid_sizes=grid_sizes,
                            freqs=self.pre_compute_freqs,
                            context=context,
                            context_lens=context_lens,
                            dtype=dtype,
                            t=t)
                        x = block(x, **kwargs)
                        x = self.after_transformer_block(idx, x)

                if cond_flag:
                    self.teacache.previous_residual_cond = (
                        x.cpu() - ori_x if self.teacache.offload else x - ori_x)
                else:
                    self.teacache.previous_residual_uncond = (
                        x.cpu() - ori_x if self.teacache.offload else x - ori_x)
        else:
            for idx, block in enumerate(self.blocks):
                if torch.is_grad_enabled() and self.gradient_checkpointing:

                    def create_custom_forward(module):

                        def custom_forward(*inputs):
                            return module(*inputs)

                        return custom_forward

                    ckpt_kwargs: Dict[str, Any] = (
                        {"use_reentrant": False}
                        if is_torch_version(">=", "1.11.0") else {})
                    x = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(block),
                        x,
                        e0,
                        seq_lens,
                        grid_sizes,
                        self.pre_compute_freqs,
                        context,
                        context_lens,
                        dtype,
                        t,
                        **ckpt_kwargs,
                    )
                    x = self.after_transformer_block(idx, x)
                else:
                    kwargs = dict(
                        e=e0,
                        seq_lens=seq_lens,
                        grid_sizes=grid_sizes,
                        freqs=self.pre_compute_freqs,
                        context=context,
                        context_lens=context_lens,
                        dtype=dtype,
                        t=t)
                    x = block(x, **kwargs)
                    x = self.after_transformer_block(idx, x)

        if self.sp_world_size > 1:
            x = self.all_gather(x.contiguous(), dim=1)

        # Keep only the video tokens; reference/motion tokens are dropped.
        x = x[:, :self.original_seq_len]

        # Head and unpatchify back to the latent video shape.
        x = self.head(x, e)
        x = self.unpatchify(x, original_grid_sizes)
        x = torch.stack(x)
        if self.teacache is not None and cond_flag:
            self.teacache.cnt += 1
            if self.teacache.cnt == self.teacache.num_steps:
                self.teacache.reset()
        return x

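    # The stacked output of `forward` has shape [B, out_dim, F, H_lat, W_lat],
    # matching the spatial-temporal resolution of the input latents, since
    # `unpatchify` below folds the patch dimensions back in.
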
    def unpatchify(self, x, grid_sizes):
        """
        Reconstruct video tensors from patch embeddings.

        Args:
            x (List[Tensor]):
                List of patchified features, each with shape
                [L, C_out * prod(patch_size)]
            grid_sizes (Tensor):
                Original spatial-temporal grid dimensions before patching,
                shape [B, 3] (the 3 dimensions correspond to F_patches,
                H_patches, W_patches)

        Returns:
            List[Tensor]:
                Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
        """
        c = self.out_dim
        out = []
        for u, v in zip(x, grid_sizes.tolist()):
            u = u[:math.prod(v)].view(*v, *self.patch_size, c)
            u = torch.einsum('fhwpqrc->cfphqwr', u)
            u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
            out.append(u)
        return out

    def zero_init_weights(self):
        with torch.no_grad():
            self.trainable_cond_mask = zero_module(self.trainable_cond_mask)
            if hasattr(self, "cond_encoder"):
                self.cond_encoder = zero_module(self.cond_encoder)

            for i in range(len(self.audio_injector.injector)):
                self.audio_injector.injector[i].o = zero_module(
                    self.audio_injector.injector[i].o)
                if self.enbale_adain:
                    self.audio_injector.injector_adain_layers[i].linear = \
                        zero_module(
                            self.audio_injector.injector_adain_layers[i].linear)