# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
"""Wan image-to-video diffusion backbone with a frame-level self-attention KV
cache, text/image cross-attention, and an audio cross-attention branch."""
import copy
import logging
import math
import os

import numpy as np
import torch
import torch.cuda.amp as amp
import torch.nn as nn
import torch.nn.functional as F
from diffusers import ModelMixin
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import PeftAdapterMixin
from einops import rearrange
from torch.utils.checkpoint import checkpoint

from .attention import flash_attention, SingleStreamAttention, sdpa_attention, flex_attention
from fp8_gemm import FP8Linear

try:
    from sageattention import sageattn
    USE_SAGEATTN = True
    logging.info("Using sageattn")
except Exception:  # optional accelerator: any import failure falls back to sdpa_attention
    USE_SAGEATTN = False

__all__ = ['WanModel']


def _rank_cuda_device():
    """Return ``'cuda:<RANK>'`` using the ``RANK`` environment variable (default 0)."""
    return f'cuda:{int(os.getenv("RANK", 0))}'


def sinusoidal_embedding_1d(dim, position):
    """Sinusoidal embedding of a 1-D tensor of positions / timesteps.

    Args:
        dim (int): Embedding size; must be even.
        position (Tensor): 1-D tensor of positions (cast to float64).

    Returns:
        Tensor: float64 tensor of shape ``[len(position), dim]``; the first half
        holds cosines and the second half sines.
    """
    assert dim % 2 == 0
    half = dim // 2
    position = position.type(torch.float64)

    sinusoid = torch.outer(
        position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
    x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
    return x


def rope_params(max_seq_len, dim, theta=10000):
    """Build a complex rotary-embedding table.

    Args:
        max_seq_len (int): Number of positions.
        dim (int): Rotary dimension; must be even.
        theta (float): Frequency base.

    Returns:
        Tensor: complex tensor of shape ``[max_seq_len, dim // 2]`` holding
        unit-magnitude phasors ``exp(i * pos * theta ** (-2j / dim))``.
    """
    assert dim % 2 == 0
    freqs = torch.outer(
        torch.arange(max_seq_len),
        1.0 / torch.pow(theta,
                        torch.arange(0, dim, 2).to(torch.float64).div(dim)))
    freqs = torch.polar(torch.ones_like(freqs), freqs)
    return freqs


def causal_rope_apply(x, grid_sizes, freqs, start_frame=0):
    """Apply 3-D (frame, height, width) rotary embeddings to ``x``.

    The frame count is derived from the sequence length of ``x``
    (``x.size(1) // (h * w)``) rather than from ``grid_sizes``, and temporal
    frequencies start at ``start_frame`` so a chunk can sit at an absolute
    frame offset.

    Args:
        x (Tensor): ``[B, S, N, D]`` query or key tensor; ``D`` must be even.
        grid_sizes (Tensor): ``[B, 3]`` patch grid; only height and width are read.
        freqs (Tensor): Complex RoPE table whose columns are split into
            temporal / height / width parts.
        start_frame (int): Temporal offset into ``freqs``.

    Returns:
        Tensor: rotated ``[B, S, N, D]`` tensor in float64 (callers cast back).
    """
    s, n, c = x.size(1), x.size(2), x.size(3) // 2
    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)

    output = []
    for i, (_, h, w) in enumerate(grid_sizes.tolist()):
        seq_len = s
        f = int(seq_len // (h * w))
        x_i = torch.view_as_complex(
            x[i, :seq_len].to(torch.float64).reshape(seq_len, n, -1, 2))
        freqs_i = torch.cat([
            freqs[0][start_frame:start_frame + f].view(f, 1, 1, -1).expand(f, h, w, -1),
            freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
            freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
        ], dim=-1).reshape(seq_len, 1, -1)
        freqs_i = freqs_i.to(device=x_i.device)
        x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
        # Tokens past seq_len are passed through untouched (empty while seq_len == S).
        x_i = torch.cat([x_i, x[i, seq_len:]])
        output.append(x_i)
    return torch.stack(output)


def rope_apply(x, grid_sizes, freqs, f_list=(), rope_list=()):
    """Apply 3-D RoPE to selected frame ranges of ``x[0]`` with remapped positions.

    Each ``(start_f, end_f)`` in ``f_list`` selects frames of the first sample
    of ``x``; the matching ``(start_r, end_r)`` in ``rope_list`` gives the
    temporal positions used for them. The rotated ranges are concatenated in
    order; frames not covered by ``f_list`` are not included in the output.

    Args:
        x (Tensor): ``[B, S, N, D]`` tensor; only ``x[0]`` is used.
        grid_sizes (Tensor): ``[B, 3]`` patch grid; height/width of row 0 are used.
        freqs (Tensor): Complex RoPE table (see :func:`rope_params`).
        f_list (Sequence[tuple[int, int]]): Frame ranges to rotate.
        rope_list (Sequence[tuple[int, int]]): Temporal position ranges, one per
            entry in ``f_list``.

    Returns:
        Tensor: float64 tensor of shape ``[1, total_tokens, N, D]``.
    """
    n, c = x.size(2), x.size(3) // 2
    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
    _, h, w = grid_sizes.tolist()[0]
    tokens_per_frame = h * w

    output = []
    for (start_f, end_f), (start_r, end_r) in zip(f_list, rope_list):
        f = end_f - start_f
        seq_len = f * tokens_per_frame
        x_i = torch.view_as_complex(
            x[0, start_f * tokens_per_frame:end_f * tokens_per_frame]
            .to(torch.float64).reshape(seq_len, n, -1, 2))
        freqs_i = torch.cat([
            freqs[0][start_r:end_r].view(f, 1, 1, -1).expand(f, h, w, -1),
            freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
            freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
        ], dim=-1).reshape(seq_len, 1, -1)
        freqs_i = freqs_i.to(device=x_i.device)
        x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
        output.append(x_i)
    return torch.concat(output, dim=0).unsqueeze(0)


class WanRMSNorm(nn.Module):
    """RMS normalization over the last dimension with a learned per-channel scale."""

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        r"""
        Args:
            x(Tensor): Shape [B, L, C]

        Returns:
            Tensor: normalized ``x`` in its original dtype, scaled by ``weight``.
        """
        return self._norm(x.float()).to(dtype=x.dtype) * self.weight.to(dtype=x.dtype)

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)


class WanLayerNorm(nn.LayerNorm):
    """LayerNorm computed in float32 and cast back to the input dtype."""

    def __init__(self, dim, eps=1e-6, elementwise_affine=False):
        super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        origin_dtype = inputs.dtype
        out = F.layer_norm(
            inputs.float(),
            self.normalized_shape,
            None if self.weight is None else self.weight.float(),
            None if self.bias is None else self.bias.float(),
            self.eps,
        ).to(origin_dtype)
        return out


class WanSelfAttention(nn.Module):
    """Self-attention of the current chunk against a frame-structured KV cache.

    The ``kv_cache`` dict given to :meth:`forward` is read for the keys
    ``"k"``, ``"v"``, ``"offload_cache"`` (bool), ``"mean_memory"`` (bool, only
    when ``update_cache``) and optionally ``"fp8_kv_cache"`` together with
    ``"k_scale"`` / ``"v_scale"``. :meth:`init_kvidx` must be called first so
    ``frame_seqlen`` (tokens per latent frame) is known.
    """

    def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True, eps=1e-6):
        assert dim % num_heads == 0
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.window_size = window_size
        self.qk_norm = qk_norm
        self.eps = eps

        # layers
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)
        self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
        self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
        self.attn_mask = None
        self.frame_seqlen = None
        # Depthwise stride-5 convolutions used to compress cached keys / values.
        self.memory_proj_k = nn.Conv1d(self.dim, self.dim, kernel_size=5, stride=5,
                                       groups=self.dim, bias=False)
        self.memory_proj_v = nn.Conv1d(self.dim, self.dim, kernel_size=5, stride=5,
                                       groups=self.dim, bias=False)

    def post_init(self, device):
        """Recreate both memory convolutions on ``device`` in bfloat16, initialized
        to a uniform 5-tap average (every weight = 1/5)."""
        self.memory_proj_k = nn.Conv1d(self.dim, self.dim, kernel_size=5, stride=5,
                                       groups=self.dim, bias=False).to(device, dtype=torch.bfloat16)
        self.memory_proj_v = nn.Conv1d(self.dim, self.dim, kernel_size=5, stride=5,
                                       groups=self.dim, bias=False).to(device, dtype=torch.bfloat16)
        nn.init.constant_(self.memory_proj_k.weight, 1.0 / 5.0)
        nn.init.constant_(self.memory_proj_v.weight, 1.0 / 5.0)

    def _conv_compress(self, proj, t, n_frame):
        """Run depthwise ``proj`` over the sequence axis of ``t`` ([B, N, H, C]),
        returning ``[B, N // n_frame, H, C]``."""
        B, N, H, C = t.shape
        assert N % n_frame == 0
        T = N // n_frame
        t = t.view(B, N, H * C).transpose(1, 2)
        t = proj(t)
        return t.view(B, H, C, T).permute(0, 3, 1, 2)

    def k_compress(self, k, n_frame=5):
        """Compress keys 5:1 along the sequence axis with ``memory_proj_k``.

        NOTE(review): the convolution groups *consecutive tokens* of the
        flattened sequence, not same-position tokens across frames; confirm
        this matches the intent suggested by the ``n_frame`` name.
        """
        return self._conv_compress(self.memory_proj_k, k, n_frame)

    def v_compress(self, v, n_frame=5):
        """Compress values 5:1 along the sequence axis with ``memory_proj_v``.

        Previously this used ``memory_proj_k``; the value projection is now
        used. Checkpoints trained with the old code should verify that
        ``memory_proj_v`` holds the intended weights (``post_init`` sets both
        projections to the same constant average).
        """
        return self._conv_compress(self.memory_proj_v, v, n_frame)

    def kv_mean(self, kv, n_frame=5):
        """Average every ``n_frame`` consecutive sequence tokens of ``kv`` ([B, N, H, C])."""
        B, N, H, C = kv.shape
        assert N % n_frame == 0
        T = N // n_frame
        kv = kv.view(B, T, n_frame, H, C).mean(dim=2)
        return kv

    def init_kvidx(self, frame_len, world_size):
        """Record the tokens-per-frame count and build the ``kv_idx0`` / ``kv_idx2``
        index tensors on ``cuda:<RANK>`` (they are not read by this module's
        forward; the block offload manager copies them between block copies)."""
        self.frame_seqlen = frame_len
        device = _rank_cuda_device()
        self.kv_idx0 = torch.arange(6 * frame_len // world_size, dtype=torch.long, device=device)
        self.kv_idx2 = torch.arange(14 * frame_len // world_size, dtype=torch.long, device=device)

    def _move_kv_cache_to_device(self, kv_cache, device):
        """Move the cached tensors (and fp8 scales, if present) to ``device``."""
        kv_cache["k"] = kv_cache["k"].to(device=device, non_blocking=True)
        kv_cache["v"] = kv_cache["v"].to(device=device, non_blocking=True)
        if kv_cache.get("k_scale") is not None:
            kv_cache["k_scale"] = kv_cache["k_scale"].to(device=device, non_blocking=True)
        if kv_cache.get("v_scale") is not None:
            kv_cache["v_scale"] = kv_cache["v_scale"].to(device=device, non_blocking=True)

    def _quantize_kv_tensor(self, kv):
        """Quantize to float8_e4m3fn with one float32 absmax scale per last-dim vector.

        Returns:
            tuple[Tensor, Tensor]: quantized tensor and its scales (both contiguous).
        """
        fp8_max = torch.finfo(torch.float8_e4m3fn).max
        scale = kv.detach().abs().amax(dim=-1, keepdim=True).to(torch.float32)
        scale = torch.clamp(scale / fp8_max, min=1e-12)  # avoid division by zero
        q_kv = (kv / scale.to(dtype=kv.dtype)).to(torch.float8_e4m3fn)
        return q_kv.contiguous(), scale.contiguous()

    def _dequantize_kv_tensor(self, q_kv, scale, dtype):
        """Inverse of :meth:`_quantize_kv_tensor`, producing ``dtype``."""
        return q_kv.to(dtype=dtype) * scale.to(device=q_kv.device, dtype=dtype)

    def _load_kv_cache(self, kv_cache, device, dtype):
        """Return working ``(k_cache, v_cache)`` tensors in ``dtype``.

        Moves the cache to ``device`` first when ``offload_cache`` is set. In fp8
        mode fresh dequantized tensors are returned; otherwise the dict's own
        tensors are returned (cast in place in the dict if their dtype differs).
        """
        if kv_cache["offload_cache"]:
            self._move_kv_cache_to_device(kv_cache, device)
        if kv_cache.get("fp8_kv_cache", False):
            k_cache = self._dequantize_kv_tensor(kv_cache["k"], kv_cache["k_scale"], dtype)
            v_cache = self._dequantize_kv_tensor(kv_cache["v"], kv_cache["v_scale"], dtype)
        else:
            if kv_cache["k"].dtype != dtype:
                kv_cache["k"] = kv_cache["k"].to(dtype=dtype)
            if kv_cache["v"].dtype != dtype:
                kv_cache["v"] = kv_cache["v"].to(dtype=dtype)
            k_cache = kv_cache["k"]
            v_cache = kv_cache["v"]
        return k_cache, v_cache

    def _store_kv_cache(self, kv_cache, k_cache, v_cache):
        """Write the working tensors back (quantizing in fp8 mode) and move the
        cache to CPU when ``offload_cache`` is set."""
        if kv_cache.get("fp8_kv_cache", False):
            kv_cache["k"], kv_cache["k_scale"] = self._quantize_kv_tensor(k_cache)
            kv_cache["v"], kv_cache["v_scale"] = self._quantize_kv_tensor(v_cache)
        else:
            kv_cache["k"] = k_cache
            kv_cache["v"] = v_cache
        if kv_cache["offload_cache"]:
            self._move_kv_cache_to_device(kv_cache, 'cpu')

    def _compress_memory(self, kv_cache, k_cache, v_cache, frame_seqlen):
        """Fold older cached frames into compressed memory slots, in place.

        With ``f = frame_seqlen`` tokens per frame:
          * frames [2, 7) are compressed 5:1 into frame slot 2,
          * frames [7, 12) are compressed 5:1 into frame slot 3,
          * frames [12, 14) are copied verbatim into frame slots [4, 6).
        Frames [0, 2) are not modified here. Statement order matters: every
        source span is read before a later statement overwrites it.
        """
        if kv_cache["mean_memory"]:
            k_compress, v_compress = self.kv_mean, self.kv_mean
        else:
            k_compress, v_compress = self.k_compress, self.v_compress

        f = frame_seqlen
        k_cache[:, 2 * f:3 * f].copy_(k_compress(k_cache[:, 2 * f:7 * f]))
        v_cache[:, 2 * f:3 * f].copy_(v_compress(v_cache[:, 2 * f:7 * f]))
        k_cache[:, 3 * f:4 * f].copy_(k_compress(k_cache[:, 7 * f:12 * f]))
        v_cache[:, 3 * f:4 * f].copy_(v_compress(v_cache[:, 7 * f:12 * f]))
        k_cache[:, 4 * f:6 * f].copy_(k_cache[:, 12 * f:14 * f])
        v_cache[:, 4 * f:6 * f].copy_(v_cache[:, 12 * f:14 * f])

    def forward(self, x, seq_lens, grid_sizes, freqs, kv_cache=None, start_idx=None,
                end_idx=None, update_cache=False):
        """Attend the current chunk to the (updated) KV cache.

        Args:
            x (Tensor): ``[B, S, dim]`` chunk tokens.
            seq_lens (Tensor): Key lengths forwarded to ``sdpa_attention``.
            grid_sizes (Tensor): ``[B, 3]`` patch grid used for RoPE.
            freqs (Tensor): Complex RoPE table.
            kv_cache (dict): Per-layer cache; see the class docstring.
            start_idx (int): Token offset of this chunk. ``0`` writes the chunk
                into the first ``6 * frame_seqlen`` cache tokens; otherwise it is
                written from ``6 * frame_seqlen`` onward.
            end_idx (int): Number of leading cache tokens attended to.
            update_cache (bool): Compress older frames before writing the chunk.

        Returns:
            tuple[Tensor, None]: ``[B, S, dim]`` output and a placeholder ``None``.

        Raises:
            RuntimeError: if :meth:`init_kvidx` was not called.
        """
        if kv_cache is None:
            kv_cache = {}
        b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim

        q = self.norm_q(self.q(x)).view(b, s, n, d)
        k = self.norm_k(self.k(x)).view(b, s, n, d)
        v = self.v(x).view(b, s, n, d)

        k_cache, v_cache = self._load_kv_cache(kv_cache, _rank_cuda_device(), torch.bfloat16)
        frame_seqlen = self.frame_seqlen
        if frame_seqlen is None:
            raise RuntimeError("WanSelfAttention.init_kvidx() must be called before forward().")
        current_start_frame = start_idx // frame_seqlen

        if update_cache:
            self._compress_memory(kv_cache, k_cache, v_cache, frame_seqlen)

        # The current chunk's keys/values are always written into the cache.
        if start_idx != 0:
            k_cache[:, 6 * frame_seqlen:] = k
            v_cache[:, 6 * frame_seqlen:] = v
        else:
            k_cache[:, :6 * frame_seqlen] = k
            v_cache[:, :6 * frame_seqlen] = v

        # Queries are rotated at their absolute frame; cached keys from frame 0.
        roped_query = causal_rope_apply(q, grid_sizes, freqs, start_frame=current_start_frame).type_as(v)
        roped_key = causal_rope_apply(k_cache, grid_sizes, freqs, start_frame=0).type_as(v)
        attn_k = roped_key[:, :end_idx, ...]
        attn_v = v_cache[:, :end_idx, ...]

        if USE_SAGEATTN:
            out = sageattn(
                roped_query, attn_k, attn_v,
                tensor_layout="NHD",
                is_causal=False,
            ).type_as(x)
        else:
            out = sdpa_attention(
                q=roped_query,
                k=attn_k,
                v=attn_v,
                k_lens=seq_lens,
                window_size=self.window_size,
                attn_mask=self.attn_mask,
            ).type_as(x)

        self._store_kv_cache(kv_cache, k_cache, v_cache)

        out = self.o(out.flatten(2))
        return out, None


class WanI2VCrossAttention(nn.Module):
    """Cross-attention to text tokens plus a separate branch for image tokens.

    The first 257 tokens of ``context`` are treated as image (CLIP) tokens and
    the rest as text tokens; the two attention outputs are summed before the
    output projection.
    """

    def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True, eps=1e-6):
        assert dim % num_heads == 0
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.window_size = window_size
        self.qk_norm = qk_norm
        self.eps = eps

        # layers
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)
        self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
        self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()

        self.k_img = nn.Linear(dim, dim)
        self.v_img = nn.Linear(dim, dim)
        self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()

    def forward(self, x, context, context_lens, cross_kv_cache=None):
        """
        Args:
            x (Tensor): ``[B, L, dim]`` query tokens.
            context (Tensor): ``[B, 257 + L_text, dim]`` image + text tokens.
            context_lens (Tensor | None): Text key lengths for ``sdpa_attention``.
            cross_kv_cache (dict | None): Accepted for API compatibility; it is
                not read or written (keys/values are recomputed every call).

        Returns:
            Tensor: ``[B, L, dim]``.
        """
        context_img = context[:, :257]
        context = context[:, 257:]
        b, n, d = x.size(0), self.num_heads, self.head_dim

        q = self.norm_q(self.q(x)).view(b, -1, n, d)
        k = self.norm_k(self.k(context)).view(b, -1, n, d)
        v = self.v(context).view(b, -1, n, d)
        k_img = self.norm_k_img(self.k_img(context_img)).view(b, -1, n, d)
        v_img = self.v_img(context_img).view(b, -1, n, d)

        if USE_SAGEATTN:
            img_x = sageattn(q, k_img, v_img, tensor_layout='NHD')
            x = sageattn(q, k, v, tensor_layout='NHD')
        else:
            img_x = sdpa_attention(q, k_img, v_img, k_lens=None)
            x = sdpa_attention(q, k, v, k_lens=context_lens)

        x = x.flatten(2)
        img_x = img_x.flatten(2)
        x = x + img_x
        x = self.o(x)
        return x


class WanAttentionBlock(nn.Module):
    """Transformer block: modulated self-attention, text/image cross-attention,
    optional audio cross-attention, and a modulated feed-forward network."""

    def __init__(self,
                 cross_attn_type,
                 dim,
                 ffn_dim,
                 num_heads,
                 window_size=(-1, -1),
                 qk_norm=True,
                 cross_attn_norm=False,
                 eps=1e-6,
                 output_dim=768,
                 norm_input_visual=True,
                 class_range=24,
                 class_interval=4):
        super().__init__()
        self.dim = dim
        self.ffn_dim = ffn_dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.qk_norm = qk_norm
        self.cross_attn_norm = cross_attn_norm
        self.eps = eps

        # layers (creation order kept stable for reproducible initialization)
        self.norm1 = WanLayerNorm(dim, eps)
        self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm, eps)
        self.norm3 = WanLayerNorm(
            dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
        self.cross_attn = WanI2VCrossAttention(dim, num_heads, (-1, -1), qk_norm, eps)
        self.norm2 = WanLayerNorm(dim, eps)
        self.ffn = nn.Sequential(
            nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'),
            nn.Linear(ffn_dim, dim))

        # modulation: shift/scale/gate for attention and FFN (6 vectors)
        self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim ** 0.5)

        # audio cross-attention module
        self.audio_cross_attn = SingleStreamAttention(
            dim=dim,
            encoder_hidden_states_dim=output_dim,
            num_heads=num_heads,
            qk_norm=False,
            qkv_bias=True,
            eps=eps,
            norm_layer=WanRMSNorm,
        )
        self.norm_x = WanLayerNorm(dim, eps, elementwise_affine=True) if norm_input_visual else nn.Identity()

    def forward(
        self,
        x,
        e,
        seq_lens,
        grid_sizes,
        freqs,
        context,
        context_lens,
        kv_cache=None,
        start_idx=None,
        end_idx=None,
        update_cache=False,
        cross_kv_cache=None,
        audio_embedding=None,
        ref_target_masks=None,
        human_num=None,
        skip_audio=False,
    ):
        """
        Args:
            x (Tensor): ``[B, L, dim]`` tokens.
            e (Tensor): Time modulation. A 3-D tensor is added to
                ``self.modulation`` and chunked along dim 1; otherwise it is
                added to ``self.modulation.unsqueeze(-2)``, row 0 is taken and
                chunked along dim 0.
            kv_cache, start_idx, end_idx, update_cache: forwarded to
                :class:`WanSelfAttention`.
            cross_kv_cache: forwarded to the text/image cross-attention.
            audio_embedding: encoder states for the audio cross-attention.
            ref_target_masks, human_num: accepted but not used by this block.
            skip_audio (bool): Skip the audio cross-attention branch.

        Returns:
            Tensor: updated ``x`` in its input dtype.
        """
        if kv_cache is None:
            kv_cache = {}
        if cross_kv_cache is None:
            cross_kv_cache = {}
        dtype = x.dtype

        if len(e.shape) == 3:
            e = (self.modulation.to(e.device) + e).chunk(6, dim=1)
        else:
            e = (self.modulation.unsqueeze(-2).to(e.device) + e)[0].chunk(6, dim=0)

        # self-attention
        y, _ = self.self_attn(
            (self.norm1(x).float() * (1 + e[1]) + e[0]).type_as(x), seq_lens, grid_sizes,
            freqs, kv_cache=kv_cache, start_idx=start_idx, end_idx=end_idx,
            update_cache=update_cache)
        x = x + y * e[2]
        x = x.to(dtype)

        # cross-attention of text / image
        x = x + self.cross_attn(self.norm3(x), context, context_lens, cross_kv_cache=cross_kv_cache)

        # cross-attention of audio
        if not skip_audio:
            frame_seqlen = self.self_attn.frame_seqlen
            start_f = start_idx // frame_seqlen
            x_a = self.audio_cross_attn(self.norm_x(x), encoder_hidden_states=audio_embedding,
                                        frame_seqlen=frame_seqlen, start_f=start_f,
                                        USE_SAGEATTN=USE_SAGEATTN)
            # For the first chunk, the first frame receives no audio contribution.
            if start_f == 0:
                x_a[:, :frame_seqlen] = 0
            x = x + x_a

        y = self.ffn((self.norm2(x).float() * (1 + e[4]) + e[3]).to(dtype))
        x = x + y * e[5]
        x = x.to(dtype)
        return x


class Head(nn.Module):
    """Output head: modulated LayerNorm followed by a linear projection to
    ``prod(patch_size) * out_dim`` channels."""

    def __init__(self, dim, out_dim, patch_size, eps=1e-6):
        super().__init__()
        self.dim = dim
        self.out_dim = out_dim
        self.patch_size = patch_size
        self.eps = eps

        # layers
        out_dim = math.prod(patch_size) * out_dim
        self.norm = WanLayerNorm(dim, eps)
        self.head = nn.Linear(dim, out_dim)

        # modulation
        self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim ** 0.5)

    def forward(self, x, e):
        r"""
        Args:
            x(Tensor): Shape [B, L1, C]
            e(Tensor): Shape [B, C]
        """
        e = (self.modulation.to(e.device) + e.unsqueeze(1)).chunk(2, dim=1)
        x = self.head(self.norm(x) * (1 + e[1]) + e[0])
        return x


class MLPProj(torch.nn.Module):
    """LayerNorm -> Linear -> GELU -> Linear -> LayerNorm projection of image features."""

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.proj = torch.nn.Sequential(
            torch.nn.LayerNorm(in_dim), torch.nn.Linear(in_dim, in_dim),
            torch.nn.GELU(), torch.nn.Linear(in_dim, out_dim),
            torch.nn.LayerNorm(out_dim))

    def forward(self, image_embeds):
        clip_extra_context_tokens = self.proj(image_embeds)
        return clip_extra_context_tokens


class AudioProjModel(ModelMixin, ConfigMixin):
    """Project windowed audio features to ``context_tokens`` tokens per frame.

    The first frame uses a window of ``seq_len`` and later frames a window of
    ``seq_len_vf``, each with its own first projection; the rest is shared.
    """

    def __init__(
        self,
        seq_len=5,
        seq_len_vf=12,
        blocks=12,
        channels=768,
        intermediate_dim=512,
        output_dim=768,
        context_tokens=32,
        norm_output_audio=False,
    ):
        super().__init__()

        self.seq_len = seq_len
        self.blocks = blocks
        self.channels = channels
        self.input_dim = seq_len * blocks * channels
        self.input_dim_vf = seq_len_vf * blocks * channels
        self.intermediate_dim = intermediate_dim
        self.context_tokens = context_tokens
        self.output_dim = output_dim

        # define multiple linear layers
        self.proj1 = nn.Linear(self.input_dim, intermediate_dim)
        self.proj1_vf = nn.Linear(self.input_dim_vf, intermediate_dim)
        self.proj2 = nn.Linear(intermediate_dim, intermediate_dim)
        self.proj3 = nn.Linear(intermediate_dim, context_tokens * output_dim)
        self.norm = nn.LayerNorm(output_dim) if norm_output_audio else nn.Identity()

    def forward(self, audio_embeds, audio_embeds_vf):
        """
        Args:
            audio_embeds (Tensor): ``[B, F1, w, blocks, channels]`` first-frame features.
            audio_embeds_vf (Tensor): ``[B, F2, w_vf, blocks, channels]`` later-frame features.

        Returns:
            Tensor: ``[B, F1 + F2, context_tokens, output_dim]``.
        """
        video_length = audio_embeds.shape[1] + audio_embeds_vf.shape[1]
        B = audio_embeds.shape[0]

        # flatten each first-frame window
        audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c")
        batch_size, window_size, blocks, channels = audio_embeds.shape
        audio_embeds = audio_embeds.view(batch_size, window_size * blocks * channels)

        # flatten each later-frame window
        audio_embeds_vf = rearrange(audio_embeds_vf, "bz f w b c -> (bz f) w b c")
        batch_size_vf, window_size_vf, blocks_vf, channels_vf = audio_embeds_vf.shape
        audio_embeds_vf = audio_embeds_vf.view(batch_size_vf, window_size_vf * blocks_vf * channels_vf)

        # first projection (separate weights for the two window sizes)
        audio_embeds = torch.relu(self.proj1(audio_embeds))
        audio_embeds_vf = torch.relu(self.proj1_vf(audio_embeds_vf))
        audio_embeds = rearrange(audio_embeds, "(bz f) c -> bz f c", bz=B)
        audio_embeds_vf = rearrange(audio_embeds_vf, "(bz f) c -> bz f c", bz=B)
        audio_embeds_c = torch.concat([audio_embeds, audio_embeds_vf], dim=1)
        batch_size_c, N_t, C_a = audio_embeds_c.shape
        audio_embeds_c = audio_embeds_c.view(batch_size_c * N_t, C_a)

        # second projection
        audio_embeds_c = torch.relu(self.proj2(audio_embeds_c))

        context_tokens = self.proj3(audio_embeds_c).reshape(
            batch_size_c * N_t, self.context_tokens, self.output_dim)

        # normalization and reshape
        context_tokens = self.norm(context_tokens)
        context_tokens = rearrange(context_tokens, "(bz f) m c -> bz f m c", f=video_length)

        return context_tokens


class WanBlockOffloadManager:
    """Run transformer blocks from two CUDA slots while the originals live on
    ``offload_device``.

    Every block is moved to ``offload_device`` (and pinned when that is CPU).
    Two CUDA copies of ``blocks[0]`` serve as compute / prefetch slots;
    :meth:`get_block` copies the requested block's weights into a slot and
    asynchronously prefetches ``block_idx + 1`` into the other slot on a side
    stream. Parameters are copied by name, so this assumes every block has
    the same structure as ``blocks[0]``.
    """

    def __init__(self, blocks, onload_device, offload_device='cpu'):
        self.blocks = blocks
        self.onload_device = torch.device(onload_device)
        self.offload_device = torch.device(offload_device)
        self.prefetch_stream = torch.cuda.Stream(device=self.onload_device)
        self.compute_slot = 0
        self.prefetch_slot = 1
        self.pending_slots = set()
        self.slot_block_indices = [None, None]
        self.cuda_blocks = nn.ModuleList([
            copy.deepcopy(self.blocks[0]).to(self.onload_device),
            copy.deepcopy(self.blocks[0]).to(self.onload_device),
        ])
        for block in self.blocks:
            block.to(self.offload_device)
            self._pin_module_memory(block)

    def _copy_tensor(self, dst, src):
        dst.copy_(src, non_blocking=True)

    def _pin_tensor(self, tensor):
        """Return a pinned copy of a CPU tensor; other tensors are returned as is."""
        if tensor is None or tensor.device.type != 'cpu' or tensor.is_pinned():
            return tensor
        return tensor.pin_memory()

    def _pin_module_memory(self, module):
        """Recursively pin the parameters, buffers and FP8Linear CPU weights of ``module``."""
        for _, param in module.named_parameters(recurse=False):
            if param is not None:
                param.data = self._pin_tensor(param.data)
        for name, buffer in module.named_buffers(recurse=False):
            if buffer is not None:
                module._buffers[name] = self._pin_tensor(buffer)
        if isinstance(module, FP8Linear):
            module._fp16_weight_cpu = self._pin_tensor(module._fp16_weight_cpu)
            module._fp16_bias_cpu = self._pin_tensor(module._fp16_bias_cpu)
        for child in module.children():
            self._pin_module_memory(child)

    def _copy_fp8_linear(self, dst_module, src_module):
        """Copy an FP8Linear's state into the slot copy, reusing device buffers when shapes match."""
        if dst_module.linear is not None and src_module.linear is not None:
            self._copy_module_state(dst_module.linear, src_module.linear)
        if dst_module.bias is not None and src_module.bias is not None:
            self._copy_tensor(dst_module.bias.data, src_module.bias.data)
        dst_module._fp16_weight_cpu = src_module._fp16_weight_cpu
        dst_module._fp16_bias_cpu = src_module._fp16_bias_cpu
        if src_module._fp8_weight is None or src_module._fp8_weight_scale is None:
            dst_module._fp8_weight = None
            dst_module._fp8_weight_scale = None
            dst_module._weight_cache_device = None
            if dst_module._fp16_weight_cpu is not None:
                dst_module.materialize_fp8_weight(self.onload_device)
        else:
            if dst_module._fp8_weight is None or dst_module._fp8_weight.shape != src_module._fp8_weight.shape:
                dst_module._fp8_weight = src_module._fp8_weight.to(device=self.onload_device, non_blocking=True)
            else:
                self._copy_tensor(dst_module._fp8_weight, src_module._fp8_weight)
            if (dst_module._fp8_weight_scale is None
                    or dst_module._fp8_weight_scale.shape != src_module._fp8_weight_scale.shape):
                dst_module._fp8_weight_scale = src_module._fp8_weight_scale.to(
                    device=self.onload_device, non_blocking=True)
            else:
                self._copy_tensor(dst_module._fp8_weight_scale, src_module._fp8_weight_scale)
            dst_module._weight_cache_device = dst_module._cached_fp8_device()
        dst_module._last_weight_version = src_module._last_weight_version

    def _copy_module_state(self, dst_module, src_module):
        """Recursively copy parameters, buffers and attention bookkeeping attributes
        from ``src_module`` into the same-named members of ``dst_module``."""
        if isinstance(dst_module, FP8Linear) and isinstance(src_module, FP8Linear):
            self._copy_fp8_linear(dst_module, src_module)
            return

        src_params = dict(src_module.named_parameters(recurse=False))
        for name, dst_param in dst_module.named_parameters(recurse=False):
            src_param = src_params.get(name)
            if src_param is not None:
                self._copy_tensor(dst_param.data, src_param.data)

        src_buffers = dict(src_module.named_buffers(recurse=False))
        for name, dst_buffer in dst_module.named_buffers(recurse=False):
            src_buffer = src_buffers.get(name)
            if src_buffer is not None:
                self._copy_tensor(dst_buffer, src_buffer)

        src_children = dict(src_module.named_children())
        for name, dst_child in dst_module.named_children():
            src_child = src_children.get(name)
            if src_child is not None:
                self._copy_module_state(dst_child, src_child)

        # Non-parameter state set by WanSelfAttention.init_kvidx().
        for attr in ("frame_seqlen", "kv_idx0", "kv_idx2"):
            if hasattr(src_module, attr):
                setattr(dst_module, attr, getattr(src_module, attr))

    def _load_slot(self, slot_idx, block_idx, async_transfer=False):
        """Copy block ``block_idx`` into slot ``slot_idx``; asynchronously on the
        prefetch stream when ``async_transfer`` (the slot is then marked pending)."""
        def copy_block():
            self._copy_module_state(self.cuda_blocks[slot_idx], self.blocks[block_idx])
            self.slot_block_indices[slot_idx] = block_idx

        if async_transfer:
            with torch.cuda.stream(self.prefetch_stream):
                copy_block()
            self.pending_slots.add(slot_idx)
        else:
            copy_block()
            self.pending_slots.discard(slot_idx)

    def _wait_slot(self, slot_idx):
        """Make the current stream wait for a pending prefetch into ``slot_idx``."""
        if slot_idx in self.pending_slots:
            torch.cuda.current_stream(device=self.onload_device).wait_stream(self.prefetch_stream)
            self.pending_slots.discard(slot_idx)

    def get_block(self, block_idx):
        """Return a CUDA module holding block ``block_idx``'s weights and start
        prefetching ``block_idx + 1`` into the other slot."""
        if self.slot_block_indices[self.compute_slot] == block_idx:
            self._wait_slot(self.compute_slot)
        elif self.slot_block_indices[self.prefetch_slot] == block_idx:
            self._wait_slot(self.prefetch_slot)
            self.compute_slot, self.prefetch_slot = self.prefetch_slot, self.compute_slot
        else:
            self._load_slot(self.compute_slot, block_idx, async_transfer=False)

        next_idx = block_idx + 1
        if next_idx < len(self.blocks) and self.slot_block_indices[self.prefetch_slot] != next_idx:
            # We are about to overwrite self.prefetch_slot on the prefetch stream.
            # Must ensure the compute stream has finished using it from previous steps.
            self.prefetch_stream.wait_stream(torch.cuda.current_stream(device=self.onload_device))
            self._load_slot(self.prefetch_slot, next_idx, async_transfer=True)

        return self.cuda_blocks[self.compute_slot]

    def unload_all(self):
        """Wait for outstanding prefetches and forget which block each slot holds."""
        torch.cuda.current_stream(device=self.onload_device).wait_stream(self.prefetch_stream)
        self.pending_slots.clear()
        self.slot_block_indices = [None, None]


class WanModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
    r"""
    Wan diffusion backbone supporting both text-to-video and image-to-video.
    """

    ignore_for_config = [
        'patch_size', 'cross_attn_norm', 'qk_norm', 'text_dim', 'window_size'
    ]
    _no_split_modules = ['WanAttentionBlock']

    @register_to_config
    def __init__(self,
                 model_type='i2v',
                 patch_size=(1, 2, 2),
                 text_len=512,
                 in_dim=16,
                 dim=2048,
                 ffn_dim=8192,
                 freq_dim=256,
                 text_dim=4096,
                 out_dim=16,
                 num_heads=16,
                 num_layers=32,
                 window_size=(-1, -1),
                 qk_norm=True,
                 cross_attn_norm=True,
                 eps=1e-6,
                 # audio params
                 audio_window=5,
                 intermediate_dim=512,
                 output_dim=768,
                 context_tokens=32,
                 vae_scale=4,  # vae timedownsample scale
                 norm_input_visual=True,
                 norm_output_audio=True,
                 weight_init=True):
        super().__init__()

        assert model_type == 'i2v', 'MultiTalk model requires your model_type is i2v.'
        self.model_type = model_type

        self.patch_size = patch_size
        self.text_len = text_len
        self.in_dim = in_dim
        self.dim = dim
        self.ffn_dim = ffn_dim
        self.freq_dim = freq_dim
        self.text_dim = text_dim
        self.out_dim = out_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.window_size = window_size
        self.qk_norm = qk_norm
        self.cross_attn_norm = cross_attn_norm
        self.eps = eps
        self.gradient_checkpointing = False

        self.norm_output_audio = norm_output_audio
        self.audio_window = audio_window
        self.intermediate_dim = intermediate_dim
        self.vae_scale = vae_scale

        self.return_layers_cosine = False
        self.cos_sims = []
        self.skip_layer = []
        self.block_offload_manager = None
        self.block_offload_enabled = False

        # embeddings
        self.patch_embedding = nn.Conv3d(
            in_dim, dim, kernel_size=patch_size, stride=patch_size)
        self.text_embedding = nn.Sequential(
            nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'),
            nn.Linear(dim, dim))

        self.time_embedding = nn.Sequential(
            nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
        self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))

        # blocks
        cross_attn_type = 'i2v_cross_attn'
        self.blocks = nn.ModuleList([
            WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
                              window_size, qk_norm, cross_attn_norm, eps,
                              output_dim=output_dim, norm_input_visual=norm_input_visual)
            for _ in range(num_layers)
        ])

        # head
        self.head = Head(dim, out_dim, patch_size, eps)

        assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
        self.init_freqs()

        if model_type == 'i2v':
            self.img_emb = MLPProj(1280, dim)
        else:
            raise NotImplementedError('Not supported model type.')

        # init audio adapter
        self.audio_proj = AudioProjModel(
            seq_len=audio_window,
            seq_len_vf=audio_window + vae_scale - 1,
            intermediate_dim=intermediate_dim,
            output_dim=output_dim,
            context_tokens=context_tokens,
            norm_output_audio=norm_output_audio,
        )

        # initialize weights
        if weight_init:
            self.init_weights()

    def init_freqs(self):
        """(Re)build ``self.freqs``: temporal, height and width RoPE tables for
        1024 positions, concatenated along the frequency axis."""
        d = self.dim // self.num_heads
        self.freqs = torch.cat([
            rope_params(1024, d - 4 * (d // 6)),
            rope_params(1024, 2 * (d // 6)),
            rope_params(1024, 2 * (d // 6))
        ], dim=1)

    def enable_block_offload(self, onload_device=None, offload_device='cpu'):
        """Keep transformer blocks on ``offload_device`` and stream them to a CUDA
        device on demand (see :class:`WanBlockOffloadManager`).

        Args:
            onload_device: CUDA device used for compute; defaults to the device
                of ``patch_embedding``.
            offload_device: Where blocks are stored between uses.

        Returns:
            WanModel: ``self``.

        Raises:
            ValueError: if the onload device is not a CUDA device.
        """
        if onload_device is None:
            onload_device = self.patch_embedding.weight.device
        onload_device = torch.device(onload_device)
        if onload_device.type != 'cuda':
            raise ValueError("WanModel block offload requires a CUDA onload device.")
        self.block_offload_manager = WanBlockOffloadManager(
            self.blocks,
            onload_device=onload_device,
            offload_device=offload_device,
        )
        self.block_offload_enabled = True
        torch.cuda.empty_cache()
        return self

    def _build_audio_embedding(self, audio, device, dtype):
        """Split raw audio windows into first-frame / later-frame inputs and project them.

        ``audio`` is ``[B, 1 + n_t * vae_scale, w, s, c]``: the first frame is
        used as is; each later group of ``vae_scale`` frames keeps the window
        halves up to / from the window centre for its first / last frame and only
        the centre sample for frames in between.

        Returns:
            tuple[Tensor, int]: audio tokens with the batch folded into the token
            axis (``[1, F, B * context_tokens, output_dim]``), and ``B``.
        """
        audio_cond = audio.to(device=device, dtype=dtype)

        first_frame_audio_emb_s = audio_cond[:, :1, ...]
        latter_frame_audio_emb = audio_cond[:, 1:, ...]
        latter_frame_audio_emb = rearrange(
            latter_frame_audio_emb, "b (n_t n) w s c -> b n_t n w s c", n=self.vae_scale)
        middle_index = self.audio_window // 2
        latter_first_frame_audio_emb = latter_frame_audio_emb[:, :, :1, :middle_index + 1, ...]
        latter_first_frame_audio_emb = rearrange(
            latter_first_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c")
        latter_last_frame_audio_emb = latter_frame_audio_emb[:, :, -1:, middle_index:, ...]
        latter_last_frame_audio_emb = rearrange(
            latter_last_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c")
        latter_middle_frame_audio_emb = latter_frame_audio_emb[:, :, 1:-1, middle_index:middle_index + 1, ...]
        latter_middle_frame_audio_emb = rearrange(
            latter_middle_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c")
        latter_frame_audio_emb_s = torch.concat(
            [latter_first_frame_audio_emb, latter_middle_frame_audio_emb, latter_last_frame_audio_emb],
            dim=2)

        audio_embedding = self.audio_proj(first_frame_audio_emb_s, latter_frame_audio_emb_s)
        human_num = len(audio_embedding)
        audio_embedding = torch.concat(audio_embedding.split(1), dim=2).to(dtype)
        return audio_embedding, human_num

    @staticmethod
    def _build_token_ref_target_masks(ref_target_masks, N_h, N_w, dtype):
        """Resize pixel-level reference masks to the token grid.

        Returns ``None`` when ``ref_target_masks`` is ``None``; otherwise a
        ``[num_masks, N_h * N_w]`` 0/1 tensor in ``dtype``.
        """
        if ref_target_masks is None:
            return None
        ref_target_masks = ref_target_masks.unsqueeze(0)
        token_ref_target_masks = nn.functional.interpolate(ref_target_masks, size=(N_h, N_w), mode='nearest')
        token_ref_target_masks = token_ref_target_masks.squeeze(0)
        token_ref_target_masks = token_ref_target_masks > 0
        token_ref_target_masks = token_ref_target_masks.view(token_ref_target_masks.shape[0], -1)
        return token_ref_target_masks.to(dtype)

    def forward(
        self,
        x,
        t,
        context,
        seq_len=None,
        clip_fea=None,
        y=None,
        audio=None,
        ref_target_masks=None,
        e0=None,
        kv_cache=None,
        start_idx=None,
        end_idx=None,
        cross_kv_cache=None,
        update_cache=True,
        skip_audio=False,
    ):
        """Denoise one chunk of latent video.

        Args:
            x (list[Tensor]): Latents, each ``[C, T, H, W]``.
            t (Tensor): 1-D timesteps for the sinusoidal time embedding.
            context (list[Tensor]): Text embeddings ``[L_i, text_dim]``, zero-padded
                to ``text_len``.
            seq_len: Unused; kept for API compatibility.
            clip_fea (Tensor): Image features for ``img_emb`` (required).
            y (list[Tensor]): Conditioning latents concatenated to ``x`` along the
                channel axis (required).
            audio (Tensor): ``[B, 1 + n_t * vae_scale, w, s, c]`` audio windows.
            ref_target_masks (Tensor | None): Pixel-level reference masks.
            e0 (Tensor | None): Precomputed block modulation; computed from ``t``
                when ``None``.
            kv_cache (dict | None): Per-block self-attention caches, keyed by block
                index; missing entries are created as empty dicts.
            start_idx, end_idx, update_cache: forwarded to every block's
                self-attention.
            cross_kv_cache (dict | None): Per-block cross-attention caches.
            skip_audio (bool): Skip the audio cross-attention in every block.

        Returns:
            Tensor: stacked unpatchified outputs ``[B, out_dim, F, H, W]``.
        """
        assert clip_fea is not None and y is not None
        if kv_cache is None:
            kv_cache = {}
        if cross_kv_cache is None:
            cross_kv_cache = {}

        _, _, H, W = x[0].shape
        N_h = H // self.patch_size[1]
        N_w = W // self.patch_size[2]

        if y is not None:
            x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
        x[0] = x[0].to(context[0].dtype)

        # embeddings
        x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
        grid_sizes = torch.stack(
            [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
        x = [u.flatten(2).transpose(1, 2) for u in x]
        seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
        x = torch.cat(x)

        # time embeddings (block modulation is derived from them unless supplied)
        e = self.time_embedding(
            sinusoidal_embedding_1d(self.freq_dim, t).float())
        if e0 is None:
            e0 = self.time_projection(e).unflatten(1, (6, self.dim))

        # text embedding
        context_lens = None
        context = self.text_embedding(
            torch.stack([
                torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
                for u in context
            ]))

        # clip embedding (prepended: the cross-attention reads the first 257 tokens as image tokens)
        if clip_fea is not None:
            context_clip = self.img_emb(clip_fea)
            context = torch.concat([context_clip, context], dim=1).to(x.dtype)

        audio_embedding, human_num = self._build_audio_embedding(audio, x.device, x.dtype)
        token_ref_target_masks = self._build_token_ref_target_masks(ref_target_masks, N_h, N_w, x.dtype)

        kwargs = dict(
            e=e0,
            seq_lens=seq_lens,
            grid_sizes=grid_sizes,
            freqs=self.freqs,
            context=context,
            context_lens=context_lens,
            audio_embedding=audio_embedding,
            ref_target_masks=token_ref_target_masks,
            human_num=human_num,
            start_idx=start_idx,
            end_idx=end_idx,
            update_cache=update_cache,
            skip_audio=skip_audio,
        )

        block_offload_manager = self.block_offload_manager if self.block_offload_enabled else None
        use_checkpoint = torch.is_grad_enabled() and self.gradient_checkpointing
        for block_index, block in enumerate(self.blocks):
            if block_offload_manager is not None:
                block = block_offload_manager.get_block(block_index)
            if kv_cache.get(block_index) is None:
                kv_cache[block_index] = {}
            if cross_kv_cache.get(block_index) is None:
                cross_kv_cache[block_index] = {}
            if use_checkpoint:
                x = checkpoint(
                    block, x,
                    kv_cache=kv_cache[block_index],
                    cross_kv_cache=cross_kv_cache[block_index],
                    use_reentrant=False,
                    **kwargs
                )
            else:
                x = block(x, kv_cache=kv_cache[block_index],
                          cross_kv_cache=cross_kv_cache[block_index], **kwargs)

        # head
        x = self.head(x, e)

        # unpatchify
        x = self.unpatchify(x, grid_sizes)

        return torch.stack(x)

    def unpatchify(self, x, grid_sizes):
        r"""
        Reconstruct video tensors from patch embeddings.

        Args:
            x (List[Tensor]):
                List of patchified features, each with shape [L, C_out * prod(patch_size)]
            grid_sizes (Tensor):
                Original spatial-temporal grid dimensions before patching,
                    shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)

        Returns:
            List[Tensor]:
                Reconstructed video tensors with shape
                [C_out, F_patches * p_t, H_patches * p_h, W_patches * p_w]
        """
        c = self.out_dim
        out = []
        for u, v in zip(x, grid_sizes.tolist()):
            u = u[:math.prod(v)].view(*v, *self.patch_size, c)
            u = torch.einsum('fhwpqrc->cfphqwr', u)
            u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
            out.append(u)
        return out

    def init_weights(self):
        r"""
        Initialize model parameters: Xavier-uniform for every Linear (zero bias)
        and the patch embedding, N(0, 0.02) for text/time embedding Linears,
        and a zero output projection.
        """
        # basic init
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

        # init embeddings
        nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
        for m in self.text_embedding.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=.02)
        for m in self.time_embedding.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=.02)

        # init output layer
        nn.init.zeros_(self.head.head.weight)