打造社交APP人物动漫化:通义万相wan2.x训练优化指南
来源: 阿里云技术博客
1. 需求场景:AI特效生成
本项目旨在为社交类APP集成AIGC驱动的个人宣传视频生成功能,通过AI技术将用户上传的真人图像,转化为具有动漫风格的个性化短视频,尤其聚焦于“真人变身跳舞动漫仙女”的特定场景。项目采用通义万相系列AIGC模型,结合定制化训练与推理优化,打造高效、高质量、可商业落地的视频生成解决方案。
项目需求
客户希望在社交APP中新增一个功能模块:用户上传真人图片,系统自动生成一段具有动漫风格的跳舞短视频,用于个人形象展示、社交传播等场景。该功能需具备以下核心能力:
- 人物动漫形象化转换:将真人形象转化为动漫风格角色;
- 动态动作生成:支持舞蹈等复杂动作序列;
- 高质量视频输出:支持720p分辨率,帧率稳定,画面细腻;
- 风格一致性控制:确保生成视频在风格、色彩、动作上保持统一;
痛点问题
虽然当前市面上已有多个主流AIGC视频生成模型(如Stable Diffusion 3、Runway ML、Pika Labs等),但在本项目场景下存在以下关键痛点:
- 动态动作生成不稳定:现有模型在复杂动作(如舞蹈)生成中容易出现动作不连贯、肢体穿透、帧抖动等问题;
- 动漫风格控制能力弱:难以精准实现“仙女”类动漫风格,风格一致性差;
- 视频质量在低分辨率下下降明显:纹理丢失严重,细节表现不佳;
- 推理速度慢,不满足生产部署需求:无法在主流消费级显卡上实现高效推理;
- 缺乏个性化定制能力:无法针对“真人→动漫仙女”的特定场景进行模型强化;
针对以上痛点问题,本项目决定:
1. 进行模型选型,选择在动漫领域生成表现效果好的AIGC模型做对比,选择合适的模型;
2. 同时针对这一特定场景,专门做模型后训练、采用多种训练策略来强化该场景下的表现效果;
3. 更进一步为了提升训练效率节省成本,对训练过程和推理过程做性能和显存的优化,推动方案实际生产落地;
后续按照该顺序介绍整体的流程;
2. 模型、算力选型+对比验证
2.1 模型选型
生产级主流的视频生成模型选择wan2.1、wan2.2,Wan2.2-I2V-A14B、Wan2.1-I2V-14B-720P;
这两个模型在模型尺寸上是相匹配的,都是14b尺寸的大模型,在功能上是专门用于图片生成5s视频,符合客户实际的场景需求;5b模型过小,生成效果不理想;文生视频的模型不符合客户需求;
wan2.1该版本已经能够生成多种艺术风格的图像,如写实、卡通、水墨、油画、赛博朋克等,细节生成能力增强人物面部、光影效果、纹理细节等方面有了显著提升,能够生成更逼真、更具艺术感的图像;支持用户通过关键词、风格标签、构图控制等方式更精细地控制生成结果,提升创作的可操作性
wan2.2 支持更高分辨率图像的生成,同时在细节刻画、色彩表现、光影渲染等方面更加自然,接近专业艺术作品水平,该版本引入了更先进的风格编码机制,能够实现多种风格的融合与创新,用户可以自由组合不同风格元素,创造出电影级别的视觉效果,WAN2.2 在理解文本描述方面更加精准,能够根据复杂语义生成符合逻辑的场景构图,包括多对象布局、空间关系、动态动作等;支持图像局部编辑、修复、重绘等功能,用户可以在生成图像的基础上进行再创作,提升创作灵活性和实用性
根据实际场景的特点,选定这两个模型用于对比效果训练,选择效果更好的模型用于实际场景;
2.2 算力选型
根据wan2.1、wan2.2模型的算力需求和显存占用,综合分析单卡和多卡训练和推理场景的算力。
推理场景

训练场景

机型1和机型2因显存不符合推理和训练的要求,已排除,机型3的显存足够,但整个机器的算力成本非常高,导致训练的性价比较低;
结合算力成本,综合分析决定采用机型4作为训练和部署推理的机器;因为训练视频文件需要的显存非常大,和帧数、fps相关,需要预留足够多的显存。
2.3 数据集构建
构建多模态训练集,包含:50组小样本数据集和5000组全量数据集,数据集由提示词文本、首帧图片、控制视频、vace视频;小样本用于本地效果验证,大样本用于生产级模型训练;同时对数据集随机切分,按9:1的形式分割训练集和测试集,用测试集评估生成质量;
以下是数据集的构建方法:
以下是数据集的组织格式:

本数据集中使用提示词、训练视频首帧、训练视频传入wan模型做训练;
微调训练方法:
采用LoRA+全参训练的对比训练方法:
在相同数据集上,对比训练wan2.1和wan2.2,lora训练模块选择注意力机制的qkv和前馈神经网络的前两层,全训练模块选择全部;

3. lora微调+全量训练
3.1 训练过程
基于PAI DSW进行小样本数据集的验证测试;
lora微调,在实际测试中发现,因wan2.1、wan2.2本身对显存的占用不同,wan2.1的显存占用42GB,wan2.2占用达到51GB,因此能够被训练的视频文件中的帧数是不一样的,本数据集中全部采用5s视频,fps15,总帧数75,wan2.1能够参与训练的最大帧数是前60帧,也就是视频的前4s,wan2.2能够参与训练的最大帧数是前45帧,也就是视频的前3s;wan2.1、wan2.2采用相同的优化策略;但由于训练帧数因显存原因无法对齐,后续在优化策略中会解决这一问题;而全参数训练占用的显存更加大,wan2.1只能训练到前41帧,wan2.2 25帧;

镜像选择DSW官方镜像:
dsw-registry-vpc.cn-wulanchabu.cr.aliyuncs.com/pai/modelscope:1.29.0-pytorch2.6.0-gpu-py311-cu124-ubuntu22.04
以下是训练的环境依赖,实际测试中dsw的官方镜像支持直接进行训练和推理:
torch>=2.0.0
torchvision
transformers
imageio
imageio[ffmpeg]
safetensors
einops
sentencepiece
protobuf
modelscope
ftfy
pynvml
pandas
accelerate
peft以下是训练命令,用deepspeed训练框架:
wan2.1 lora
accelerate launch examples/wanvideo/model_training/train.py \
--dataset_base_path train_dataset \
--dataset_metadata_path train_dataset/metadata.csv \
--height 720 \
--width 1280 \
--num_frames 60 \
--dataset_repeat 10 \
--model_paths '[
[
"wan21/diffusion_pytorch_model-00001-of-00007.safetensors",
"wan21/diffusion_pytorch_model-00002-of-00007.safetensors",
"wan21/diffusion_pytorch_model-00003-of-00007.safetensors",
"wan21/diffusion_pytorch_model-00004-of-00007.safetensors",
"wan21/diffusion_pytorch_model-00005-of-00007.safetensors",
"wan21/diffusion_pytorch_model-00006-of-00007.safetensors",
"wan21/diffusion_pytorch_model-00007-of-00007.safetensors"
],
"wan21/models_t5_umt5-xxl-enc-bf16.pth",
"wan21/Wan2.1_VAE.pth",
"wan21/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"
]' \
--learning_rate 1e-4 \
--num_epochs 2 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/Wan2.1-I2V-14B-720P_lora" \
--lora_base_model "dit" \
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
--lora_rank 32 \
--extra_inputs "input_image" \
--use_gradient_checkpointing_offloadwan2.1 full
accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
--dataset_base_path data/example_video_dataset \
--dataset_metadata_path data/example_video_dataset/metadata.csv \
--height 720 \
--width 1280 \
--num_frames 41 \
--dataset_repeat 10 \
--model_id_with_origin_paths "Wan-AI/Wan2.1-I2V-14B-720P:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-I2V-14B-720P:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-I2V-14B-720P:Wan2.1_VAE.pth,Wan-AI/Wan2.1-I2V-14B-720P:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \
--learning_rate 1e-5 \
--num_epochs 2 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/Wan2.1-I2V-14B-720P_full" \
--trainable_models "dit" \
--extra_inputs "input_image" \
--use_gradient_checkpointing_offloadwan2.2 lora
accelerate launch DiffSynth-Studio/examples/wanvideo/model_training/train.py \
--dataset_base_path train_dataset \
--dataset_metadata_path train_dataset/metadata.csv \
--height 480 \
--width 832 \
--num_frames 45 \
--dataset_repeat 10 \
--model_paths '[
[
"wan22/high_noise_model/diffusion_pytorch_model-00001-of-00006-bf16.safetensors",
"wan22/high_noise_model/diffusion_pytorch_model-00002-of-00006-bf16.safetensors",
"wan22/high_noise_model/diffusion_pytorch_model-00003-of-00006-bf16.safetensors",
"wan22/high_noise_model/diffusion_pytorch_model-00004-of-00006-bf16.safetensors",
"wan22/high_noise_model/diffusion_pytorch_model-00005-of-00006-bf16.safetensors",
"wan22/high_noise_model/diffusion_pytorch_model-00006-of-00006-bf16.safetensors"
],
[
"wan22/low_noise_model/diffusion_pytorch_model-00001-of-00006-bf16.safetensors",
"wan22/low_noise_model/diffusion_pytorch_model-00002-of-00006-bf16.safetensors",
"wan22/low_noise_model/diffusion_pytorch_model-00003-of-00006-bf16.safetensors",
"wan22/low_noise_model/diffusion_pytorch_model-00004-of-00006-bf16.safetensors",
"wan22/low_noise_model/diffusion_pytorch_model-00005-of-00006-bf16.safetensors",
"wan22/low_noise_model/diffusion_pytorch_model-00006-of-00006-bf16.safetensors"
],
"wan22/models_t5_umt5-xxl-enc-bf16.pth",
"wan22/Wan2.1_VAE.pth"
]' \
--learning_rate 1e-4 \
--num_epochs 2 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/Wan2.2-I2V-A14B_high_noise_lora" \
--lora_base_model "dit" \
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
--lora_rank 32 \
--extra_inputs "input_image" \
--max_timestep_boundary 0.358 \
--min_timestep_boundary 0
# boundary corresponds to timesteps [900, 1000]
accelerate launch DiffSynth-Studio/examples/wanvideo/model_training/train.py \
--dataset_base_path train_dataset \
--dataset_metadata_path train_dataset/metadata.csv \
--height 480 \
--width 832 \
--num_frames 25 \
--dataset_repeat 10 \
--model_paths '[
[
"wan22/high_noise_model/diffusion_pytorch_model-00001-of-00006-bf16.safetensors",
"wan22/high_noise_model/diffusion_pytorch_model-00002-of-00006-bf16.safetensors",
"wan22/high_noise_model/diffusion_pytorch_model-00003-of-00006-bf16.safetensors",
"wan22/high_noise_model/diffusion_pytorch_model-00004-of-00006-bf16.safetensors",
"wan22/high_noise_model/diffusion_pytorch_model-00005-of-00006-bf16.safetensors",
"wan22/high_noise_model/diffusion_pytorch_model-00006-of-00006-bf16.safetensors"
],
[
"wan22/low_noise_model/diffusion_pytorch_model-00001-of-00006-bf16.safetensors",
"wan22/low_noise_model/diffusion_pytorch_model-00002-of-00006-bf16.safetensors",
"wan22/low_noise_model/diffusion_pytorch_model-00003-of-00006-bf16.safetensors",
"wan22/low_noise_model/diffusion_pytorch_model-00004-of-00006-bf16.safetensors",
"wan22/low_noise_model/diffusion_pytorch_model-00005-of-00006-bf16.safetensors",
"wan22/low_noise_model/diffusion_pytorch_model-00006-of-00006-bf16.safetensors"
],
"wan22/models_t5_umt5-xxl-enc-bf16.pth",
"wan22/Wan2.1_VAE.pth"
]' \
--learning_rate 1e-4 \
--num_epochs 2 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/Wan2.2-I2V-A14B_low_noise_lora" \
--lora_base_model "dit" \
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
--lora_rank 32 \
--extra_inputs "input_image" \
--max_timestep_boundary 1 \
--min_timestep_boundary 0.358 \
# boundary corresponds to timesteps [0, 900wan2.2 full
accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
--dataset_base_path data/example_video_dataset \
--dataset_metadata_path data/example_video_dataset/metadata.csv \
--height 480 \
--width 832 \
--num_frames 25 \
--dataset_repeat 10 \
--model_id_with_origin_paths "Wan-AI/Wan2.2-I2V-A14B:high_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-I2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-I2V-A14B:Wan2.1_VAE.pth" \
--learning_rate 1e-5 \
--num_epochs 2 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/Wan2.2-I2V-A14B_high_noise_full" \
--trainable_models "dit" \
--extra_inputs "input_image" \
--use_gradient_checkpointing_offload \
--max_timestep_boundary 0.358 \
--min_timestep_boundary 0
# boundary corresponds to timesteps [900, 1000]
accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
--dataset_base_path data/example_video_dataset \
--dataset_metadata_path data/example_video_dataset/metadata.csv \
--height 480 \
--width 832 \
--num_frames 49 \
--dataset_repeat 100 \
--model_id_with_origin_paths "Wan-AI/Wan2.2-I2V-A14B:low_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-I2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-I2V-A14B:Wan2.1_VAE.pth" \
--learning_rate 1e-5 \
--num_epochs 2 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/Wan2.2-I2V-A14B_low_noise_full" \
--trainable_models "dit" \
--extra_inputs "input_image" \
--use_gradient_checkpointing_offload \
--max_timestep_boundary 1 \
--min_timestep_boundary 0.358
# boundary corresponds to timesteps [0, 900)3.2 训练性能数据
在总体的小样本数据集上,50组训练视频,各自取一定的帧数,再乘重复训练次数(epoch),就是总的训练步长,例如wan2.1 lora里取前60帧,那么总步长=60*50*10=30000;总步长决定训练的效率,每次梯度迭代会经历一定的步长;

以下是训练wan2.1的过程:

训练后推理的时长:

3.3 训练后测试推理的命令
以wan2.1 lora为例,以下是加载lora权重和测试推理的代码,注意num_frames参数是控制生成的帧数,如果和训练时指定的帧数相同,则最大化训练的效果,但本训练数据中因显卡最大显存限制,无法训练全部的75帧,所以只取前一部分帧数做训练,num_frames可以不指定具体的值,默认取5s视频的帧数;
import torch
from PIL import Image
from diffsynth import save_video, VideoData
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(path=[
"wan21/diffusion_pytorch_model-00001-of-00007.safetensors",
"wan21/diffusion_pytorch_model-00002-of-00007.safetensors",
"wan21/diffusion_pytorch_model-00003-of-00007.safetensors",
"wan21/diffusion_pytorch_model-00004-of-00007.safetensors",
"wan21/diffusion_pytorch_model-00005-of-00007.safetensors",
"wan21/diffusion_pytorch_model-00006-of-00007.safetensors",
"wan21/diffusion_pytorch_model-00007-of-00007.safetensors",]),
ModelConfig(path="wan21/models_t5_umt5-xxl-enc-bf16.pth"),
ModelConfig(path="wan21/Wan2.1_VAE.pth"),
ModelConfig(path="wan21/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"),
],
use_usp=True,
)
pipe.load_lora(pipe.dit, "models/train/Wan2.1-I2V-14B-720P_lora/epoch-2.safetensors", alpha=1)
pipe.enable_vram_management()
image = Image.open("1.jpg")
prompt='''
Flower-Fairy. A radiant metamorphosis unfolds as the character, encircled by shimmering butterflies, rises from a whirling vortex of whimsical 2D halos. The initial attire transforms into a resplendent emerald green tunic adorned with golden embellishments and a multi-layered tulle skirt, reminiscent of Tinker Bell's iconic outfit, while preserving the original hairstyle. Luminous halos sparkle beneath her feet as the backdrop bursts into a vivid, enchanted forest teeming with bioluminescent plants and flickering fireflies.
'''
# Image-to-video
video = pipe(
prompt=prompt,
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
input_image=image,
seed=0, tiled=True,
height=720, width=1280,
#num_frames=25
)
save_video(video, "video_lora.mp4", fps=15, quality=5)wan2.2的效果更加写实风格,wan2.1偏向动漫风格,更适合本场景需求;
4. 训练优化加速
优化加速主要是为了让模型的训练和推理能够提升速度,同时尽量保持原有的性能水平,尽量减少使用算力的时间成本,使训练和推理能够真正的在消费级显卡上使用,成为一个可落地的解决方案;
一般来说,优化加速可以从以下方面去考虑:
- 训练速度提升;
- 推理速度提升;
- 训练时显存占用减少;
训练速度提升是在模型训练过程中,通过高效注意力计算机制、多卡并行训练策略、模型量化等方法,加速模型反向传播计算的时间,同时尽量减少对最优参数的选择的影响,从而提升训练效率;
推理速度提升是在模型训练结束后的推理验证中,通过多卡并行推理策略、缓存机制、高效注意力计算机制等数学方法,加速前向传播的计算时间,同时尽量减少对最优参数的选择的影响,从而提升推理效率;
显存占用减少是模型训练或者推理过程中,通过分块组计算、非核心参数卸载等方法,减少实际在GPU中参与计算的参数数量,很大程度上减少GPU的显存占用,从而在算力不变的情况下能够训练更大的模型、更多的数据;
以下是本项目中实际采用的优化方法:

4.1 Sage Attention - 27% Train & Infer Speed-Up
Sage Attention 是一种基于int8量化的高效注意力计算方法,加速Transformer模型的推理过程,同时保持模型精度。其优势包括:
- K矩阵平滑:通过减去token间的平均值缓解K矩阵的通道级异常值,并且不影响softmax分数的计算;
- 混合精度计算:Q和K使用INT8量化,P和V保持FP16并采用FP16累加器;
- 自适应量化策略:根据层对精度的敏感度动态选择量化粒度(per-token或per-block);
- 硬件优化:基于Triton实现的高效内核,利用NVIDIA Tensor Core的INT8和FP16指令加速计算;

sage计算方法如上,K矩阵的平滑计算,特别适合对扩散模型diffusion的视频帧计算,能够很大程度减少相同连续区块的重复计算,因此对AIGC类模型加速效果比flash attn更好,在消费级显卡上,sage attention能实现比flash attention高2-3倍的加速比;但目前语言类模型和视觉理解类模型的加速效果还是采用flash attention效果更好;

以下是使用方法:
基础环境需求

从源码编译安装sage
git clone https://github.com/thu-ml/SageAttention.git
cd SageAttention
export EXT_PARALLEL=4 NVCC_APPEND_FLAGS="--threads 8" MAX_JOBS=32
python setup.py install对一般的Diffusion模型,使用方法是修改模型transformer架构中attention矩阵,用sage替换torch自带的dot-product-attention
from sageattention import sageattn
attn_output = sageattn(q, k, v, tensor_layout="HND", is_causal=False)
F.scaled_dot_product_attention = attn_output但在wan2.1、wan2.2模型已经原生支持sage加速,只需要启动训练命令,自动会按照sage方法计算attention;本项目中,开启sage和不开启的效果对比,wan2.1 lora为例

4.2 TeaCache - 28% Infer Speed-Up
在视频生成中,很多帧之间是高度相似的,比如:摄像机缓慢移动,背景不变、只有前景物体移动,人物说话时背景不变;在这种情况下,如果每一帧都从头开始重新生成,会浪费大量计算资源;TeaCache 是一种 用于视频生成任务的缓存加速机制,通过缓存帧之间的相似内容,减少重复计算,从而提升视频生成速度并降低计算资源消耗。
TeaCache 主要包含以下几个步骤:
1. 在生成第 t 帧时,模型会与前一帧 t−1 做对比,判断哪些区域变化较小。使用 L1 距离(像素级差异) 或 特征空间差异(latent)来衡量两帧之间的差异。如果某个区域的差异小于阈值 ,则认为该区域“变化不大”,可以使用缓存。
2. 缓存的内容通常是 Latent 空间中的中间表示。缓存区域的坐标(bounding box)、缓存区域的 latent 特征、缓存区域的时间戳(用于判断是否过期)
3. 在生成下一帧时:对于变化较小的区域,直接从缓存中提取之前计算好的 latent 特征,避免重复计算。只对变化较大的区域进行完整的扩散模型计算。
4. 在每一帧生成完成后,会更新缓存:更新缓存区域的 latent 特征、移除过期的缓存条目(例如超过一定帧数)、添加新生成的缓存区域;
tea_cache_l1_thresh是teacache的关键参数,判断帧间相似度的阈值,0.02-0.1取值,值越小,缓存帧数越多,视频生成越快,质量损失大,值越大,缓存帧数越少,视频生成越慢,质量损失小;
在wan模型中如何集成:
在 WanVideoPipeline 的推理流程中,TeaCache 模块被嵌入到每一帧的生成过程中:
for frame_idx in range(total_frames):
text_emb = encode_text(prompt)
noise = get_initial_noise()
if frame_idx > 0:
cache_mask = tea_cache.get_cache_mask(prev_latent, current_latent)
latent = diffusion_model(noise, text_emb, cache_mask=cache_mask)
tea_cache.update_cache(latent, frame_idx)
frame = vae.decode(latent)
以下是使用方法:
基础环境需求
accelerate>0.17.0
bs4
click
colossalai==0.4.0
diffusers==0.30.0
einops
fabric
ftfy
imageio
imageio-ffmpeg
matplotlib
ninja
numpy<2.0.0
omegaconf
packaging
psutil
pydantic
ray
rich
safetensors
sentencepiece
timm
torch>=1.13
tqdm
peft==0.13.2
transformers==4.39.3从源码编译安装
git clone https://github.com/ali-vilab/TeaCache.git
cd TeaCache
python setup.py install代码使用,已内置在wanvideo pipeline中:
video = pipe(
prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
seed=0, tiled=True,
# TeaCache parameters
tea_cache_l1_thresh=0.05, # The larger this value is, the faster the speed, but the worse the visual quality.
tea_cache_model_id="Wan2.1-I2V-14B-720P",
)效果对比

4.3 XDIT Sequence Parallelism - 400% Train & Infer Speed-Up
xDiT是一种面向扩散模型的大规模并行推理框架,其核心目标是通过创新的并行化技术和编译优化技术,实现高效的大规模扩散模型训练与推理。

xDiT的核心原理
1. PipeFusion:分块流水线并行,将图像分割为多个patch,用pipeline在不同卡上并行处理,最后整合结果,有效减少单卡上的显存占用;
2. USP:统一序列并行,针对扩散模型中的长序列生成任务,将序列的不同维度拆分到不同卡,实现多卡的序列并行计算,最后整合向量。
3. CFG Parallel:在分类器自由指导过程中,将正向和负向样本的计算分配到不同卡,降低单次推理的计算负载。
4. DistVAE:分布式VAE模块,对扩散模型中的VAE模块进行分块并行处理,避免显存溢出。
5. 编译优化技术:xDiT通过编译器技术优化GPU执行效率,主要依赖:Torch.compile:PyTorch 2.0的JIT编译器,通过融合
算子、消除冗余计算提升性能;OneDiff:针对扩散模型的专用编译优化工具,支持内核融合、内存复用等高级特性。
以下是使用方法:
从源码编译
git clone https://github.com/xdit-project/xDiT.git
cd xDiT
pip install -e .
# Or optionally, with flash attention
pip install -e ".[flash-attn]"使用方法:
多卡并行训练
accelerate launch examples/wanvideo/model_training/train.py \
--dataset_base_path data/example_video_dataset \
--dataset_metadata_path data/example_video_dataset/metadata.csv \
--height 720 \
--width 1280 \
--num_frames 49 \
--dataset_repeat 100 \
--model_id_with_origin_paths "Wan-AI/Wan2.1-I2V-14B-720P:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-I2V-14B-720P:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-I2V-14B-720P:Wan2.1_VAE.pth,Wan-AI/Wan2.1-I2V-14B-720P:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \
--learning_rate 1e-4 \
--num_epochs 5 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/Wan2.1-I2V-14B-720P_lora" \
--lora_base_model "dit" \
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
--lora_rank 32 \
--extra_inputs "input_image" \
--use_gradient_checkpointing_offload
多卡并行推理
## 首先pipeline启用usp
pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
use_usp=True,
model_configs=[
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
],
)
pipe.enable_vram_management()
## 然后shell启动多卡推理
torchrun --standalone --nproc_per_node=8 test.py效果对比

4.4 Gradient Checkpointing Offload - 13% GPU Memory Reduce
梯度检查点卸载(Gradient Checkpointing Offload)是一种以计算换显存的优化技术。在前向传播时,模型不会存储所有中间激活值到GPU显存,并且把激活值卸载到内存中,而是仅保留部分关键层的激活值,并在需要时重新计算中间值。这可以显著减少显存占用,但会增加少量计算时间,还引入额外的 GPU-CPU 数据传输开销。
当GPU的显存比较紧张时,为了实现训练大模型,可以通过这种方法显著减少显存占用,实现训练目标;
实现步骤:
1. 前向传播阶段
计算激活值:模型在前向传播时会计算每一层的输出(即激活值),模型会将部分激活值 从 GPU 显存移动到 CPU 内存。对于未启用梯度检查点的层,激活值通常会保留到显存中。
2. 反向传播阶段
在反向传播计算梯度时,模型需要中间激活值来计算梯度。则需要把激活值从内存加载回 GPU 显存,优化卸载大幅减少显存占用(降低20%左右显存占用)。

实现代码,以下是前向传播时传递的关键参数:
# 在前向传播时传递卸载参数
inputs_shared = {
"use_gradient_checkpointing": self.use_gradient_checkpointing,
"use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
...
}
import torch, os, json
from diffsynth import load_state_dict
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
from diffsynth.trainers.utils import DiffusionTrainingModule, ModelLogger, launch_training_task, wan_parser
from diffsynth.trainers.unified_dataset import UnifiedDataset, LoadVideo, LoadAudio, ImageCropAndResize, ToAbsolutePath
os.environ["TOKENIZERS_PARALLELISM"] = "false"
class WanTrainingModule(DiffusionTrainingModule):
def __init__(
self,
model_paths=None, model_id_with_origin_paths=None, audio_processor_config=None,
trainable_models=None,
lora_base_model=None, lora_target_modules="q,k,v,o,ffn.0,ffn.2", lora_rank=32, lora_checkpoint=None,
use_gradient_checkpointing=True,
use_gradient_checkpointing_offload=False,
extra_inputs=None,
max_timestep_boundary=1.0,
min_timestep_boundary=0.0,
):
super().__init__()
# Load models
model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, enable_fp8_training=False)
if audio_processor_config is not None:
audio_processor_config = ModelConfig(model_id=audio_processor_config.split(":")[0], origin_file_pattern=audio_processor_config.split(":")[1])
self.pipe = WanVideoPipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs, audio_processor_config=audio_processor_config)
# Training mode
self.switch_pipe_to_training_mode(
self.pipe, trainable_models,
lora_base_model, lora_target_modules, lora_rank, lora_checkpoint=lora_checkpoint,
enable_fp8_training=False,
)
# Store other configs
self.use_gradient_checkpointing = use_gradient_checkpointing
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []
self.max_timestep_boundary = max_timestep_boundary
self.min_timestep_boundary = min_timestep_boundary
def forward_preprocess(self, data):
# CFG-sensitive parameters
inputs_posi = {"prompt": data["prompt"]}
inputs_nega = {}
# CFG-unsensitive parameters
inputs_shared = {
# Assume you are usingthis pipeline for inference,
# please fill in the input parameters.
"input_video": data["video"],
"height": data["video"][0].size[1],
"width": data["video"][0].size[0],
"num_frames": len(data["video"]),
# Please donot modify the following parameters
# unless you clearly know what this will cause.
"cfg_scale": 1,
"tiled": False,
"rand_device": self.pipe.device,
"use_gradient_checkpointing": self.use_gradient_checkpointing,
"use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
"cfg_merge": False,
"vace_scale": 1,
"max_timestep_boundary": self.max_timestep_boundary,
"min_timestep_boundary": self.min_timestep_boundary,
}
# Extra inputs
for extra_input in self.extra_inputs:
if extra_input == "input_image":
inputs_shared["input_image"] = data["video"][0]
elif extra_input == "end_image":
inputs_shared["end_image"] = data["video"][-1]
elif extra_input == "reference_image"or extra_input == "vace_reference_image":
inputs_shared[extra_input] = data[extra_input][0]
else:
inputs_shared[extra_input] = data[extra_input]
# Pipeline units will automatically process the input parameters.
for unit in self.pipe.units:
inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega)
return {**inputs_shared, **inputs_posi}
def forward(self, data, inputs=None):
if inputs is None: inputs = self.forward_preprocess(data)
models = {name: getattr(self.pipe, name) for name in self.pipe.in_iteration_models}
loss = self.pipe.training_loss(**models, **inputs)
return loss
if __name__ == "__main__":
parser = wan_parser()
args = parser.parse_args()
dataset = UnifiedDataset(
base_path=args.dataset_base_path,
metadata_path=args.dataset_metadata_path,
repeat=args.dataset_repeat,
data_file_keys=args.data_file_keys.split(","),
main_data_operator=UnifiedDataset.default_video_operator(
base_path=args.dataset_base_path,
max_pixels=args.max_pixels,
height=args.height,
width=args.width,
height_division_factor=16,
width_division_factor=16,
num_frames=args.num_frames,
time_division_factor=4,
time_division_remainder=1,
),
special_operator_map={
"animate_face_video": ToAbsolutePath(args.dataset_base_path) >> LoadVideo(args.num_frames, 4, 1, frame_processor=ImageCropAndResize(512, 512, None, 16, 16)),
"input_audio": ToAbsolutePath(args.dataset_base_path) >> LoadAudio(sr=16000),
}
)
model = WanTrainingModule(
model_paths=args.model_paths,
model_id_with_origin_paths=args.model_id_with_origin_paths,
audio_processor_config=args.audio_processor_config,
trainable_models=args.trainable_models,
lora_base_model=args.lora_base_model,
lora_target_modules=args.lora_target_modules,
lora_rank=args.lora_rank,
lora_checkpoint=args.lora_checkpoint,
use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
extra_inputs=args.extra_inputs,
max_timestep_boundary=args.max_timestep_boundary,
min_timestep_boundary=args.min_timestep_boundary,
)
model_logger = ModelLogger(
args.output_path,
remove_prefix_in_ckpt=args.remove_prefix_in_ckpt
)
launch_training_task(dataset, model, model_logger, args=args)在训练时启用该参数gradient_checkpointing_offload
accelerate launch examples/wanvideo/model_training/train.py \
--dataset_base_path data/example_video_dataset \
--dataset_metadata_path data/example_video_dataset/metadata.csv \
--height 720 \
--width 1280 \
--num_frames 49 \
--dataset_repeat 100 \
--model_id_with_origin_paths "Wan-AI/Wan2.1-I2V-14B-720P:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-I2V-14B-720P:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-I2V-14B-720P:Wan2.1_VAE.pth,Wan-AI/Wan2.1-I2V-14B-720P:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \
--learning_rate 1e-4 \
--num_epochs 5 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/Wan2.1-I2V-14B-720P_lora" \
--lora_base_model "dit" \
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
--lora_rank 32 \
--extra_inputs "input_image" \
--use_gradient_checkpointing_offload
效果对比
gpu显存占用下降了13%,相对应的训练时间延长了9%,牺牲了少量时间,来降低显存占用,从而让更多的帧数可以参与训练,是合理的优化。

4.5 Tiled VAE - 11% GPU Memory Reduce
视频通常由多帧连续图像组成,直接处理高分辨率视频序列会占用大量显存。分块编码解码技术通过逐块处理单帧图像,降低单次计算的显存需求,能有效减少显存占用,但会略微降低视频生成质量;
视频VAE的latent空间通常为5D张量([B, C, T, H, W],B=batch size, C=通道数, T=帧数, H/W=高度/宽度),先对单帧图像进行空间编码/解码,再通过时间模块(Time-Series Transformer)建模帧间关系。
分块技术的实现步骤:
1. 编码阶段(视频到latent空间)
对输入视频的每一帧独立分块。例如,若单帧分辨率为[H, W],则用kernel划分为多个tile_size大小的向量块;通过tile_stride确保相邻块部分重叠,避免块边界处的语义断裂。
2. 解码阶段(latent空间到视频)
逐帧解码:对潜在空间的每帧独立解码,解码向量块到对应图像块;对重叠区域的像素值进行加权平均,减少块边界伪影。对解码后的多帧图像通过时间滤波增强帧间连续性。
3. 时间一致性优化策略
在解码前,通过时间模块(Time-Series Transformer)对分块后的潜在特征进行全局时间一致性约束;在分块时对齐相邻帧的块位置,确保同一物体在不同帧中的重建区域一致。
4. 显存优化
根据显存限制,动态调整tile_size和tile_stride;用Sequence Parallel同时处理多个块;
以下是使用方法:
#tiled: Whether to enable tiled VAE inference, default is False. Setting to True significantly reduces VRAM usage during VAE encoding/decoding but introduces small errors and slightly increases inference time.
#tile_size: Tile size during VAE encoding/decoding, defaultis(30, 52), only effective when tiled=True.
#tile_stride: Stride of tiles during VAE encoding/decoding, defaultis(15, 26), only effective when tiled=True. Must be less than or equal to tile_size.
video = pipe(
prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
seed=0, tiled=True,
tile_size=[30,52],tile_stride=[15,26]
)效果对比
Tiled VAE有效减少了显存占用,同时速度基本不影响;

4.6 Quantization - 33% Train Speed-Up
在训练模型初期,发现wan2.1、wan2.2模型本身架构不同,wan2.2有高噪声和低噪声两套架构,导致模型大小比wan2.1大一倍左右,因此两个模型的显存占用存在差异,在相同数据集下,导致了能够参与训练的视频的帧数两者不一样,例如lora训练中,wan2.1能训练到前60帧,而wan2.2只有前45帧,理论上来说,帧数越少,那么模型能遵循的视频效果越差,为了再对比训练中尽量保持条件一致,并且要保持wan2.2-I2V-A14B模型的基本架构,最好的办法就是模型量化,用精度更低的wan2.2模型,降低大量显存占用,提高训练帧数,还能加速训练时间。
后面的结果不再采用原始的wan2.2模型,后续的结果全部基于量化wan2.2和wan2.1的对比效果。
wan2.2系列模型族由众多的量化模型,AWQ、GGUF、FP8、FP16、INT8,分别代表不同的参数精度,根据量化模型的尺寸,最终选择wan2.2-fp8模型,模型显存占用基本和wan2.1达到相同,训练帧数也和wan2.1最大帧数相同。
以下是原始的两个模型对比:

模型量化是深度学习模型优化的重要技术,在通过降低模型参数和计算的精度来减少模型大小、内存占用和计算开销,同时尽量保持模型性能。
核心方法是将高精度浮点数(如FP32)转换为低精度表示(如INT8或FP16),从而减少存储和计算需求,有效降低显存占用和计算资源消耗。
量化类型 权重量化:仅对模型权重进行量化,激活值保持高精度。 激活量化:对输入/输出激活值进行量化。 全量化:同时量化权重和激活值。 混合量化:对不同层使用不同精度。
量化方法
对称量化:数值范围对称分布(如-127~127),适用于激活值接近零的情况。 非对称量化:数值范围非对称(如0~255),适用于偏移较大的数据。 动态量化:推理时根据输入动态调整量化参数。 静态量化:训练后通过校准数据集确定量化参数。 量化感知训练:在训练阶段模拟量化误差,提升量化后模型精度。
以下是常见的量化精度:

效果对比

5. 结果
总结训练的全流程链路,一般的模型后训练都可以采用这一套完整的流程,实现系统化、可复用、可扩展的实践:
模型的训练实际上是一个持续迭代优化的过程,没有一蹴而就的完美训练,更多的是要持续的调整训练参数、采用新的训练方法、对比实验等等来不断提升模型的准确率和性能,这需要反复迭代,才能达到理想中的最优效果;

在评价AIGC生成内容的质量时,有一些常用的量化指标,在本场景中,动态表现、相机控制、帧质量、准确性是衡量生成质量的重要参考因素,决定采用客观指标+主观打分的方法来综合评价模型训练的效果;客观指标决定采用行业通用的视频质量评价参数:
峰值信噪比(PSNR)
计算方式:基于均方误差(MSE)的对数转换
评价维度:像素级保真度,反映视频帧与参考帧的噪声差异;
典型阈值:10-50,优质视频≥32

结构相似性指数(SSIM)
计算方式:8*8 kernel的滑动窗口计算
评价维度:动态场景中空间结构与时间连续性的保持能力;
标准范围:0(完全失真)-1(完全匹配),优质生成需≥0.85;

这是设计的评测流程:
PSNR的计算代码:
import cv2
import numpy as np
def calculate_psnr(video_ref_path, video_gen_path):
cap_ref = cv2.VideoCapture(video_ref_path)
cap_gen = cv2.VideoCapture(video_gen_path)
psnr_list = []
while True:
ret_ref, frame_ref = cap_ref.read()
ret_gen, frame_gen = cap_gen.read()
ifnot ret_ref ornot ret_gen:
break
# 确保分辨率一致
if frame_ref.shape != frame_gen.shape:
frame_gen = cv2.resize(frame_gen, (frame_ref.shape[1], frame_ref.shape[0]))
# 计算MSE
mse = np.mean((frame_ref - frame_gen) ** 2)
if mse == 0:
psnr = float('inf')
else:
psnr = 20 * np.log10(255.0 / np.sqrt(mse))
psnr_list.append(psnr)
return np.mean(psnr_list)
psnr_value = calculate_psnr("reference_video.mp4", "generated_video.mp4")
print(f"PSNR: {psnr_value:.2f} dB")SSIM计算代码:
from skimage.metrics import structural_similarity as ssim
import cv2
import numpy as np
def calculate_ssim(video_ref_path, video_gen_path):
cap_ref = cv2.VideoCapture(video_ref_path)
cap_gen = cv2.VideoCapture(video_gen_path)
ssim_list = []
while True:
ret_ref, frame_ref = cap_ref.read()
ret_gen, frame_gen = cap_gen.read()
ifnot ret_ref ornot ret_gen:
break
if frame_ref.shape != frame_gen.shape:
frame_gen = cv2.resize(frame_gen, (frame_ref.shape[1], frame_ref.shape[0]))
# 转换为灰度图
gray_ref = cv2.cvtColor(frame_ref, cv2.COLOR_BGR2GRAY)
gray_gen = cv2.cvtColor(frame_gen, cv2.COLOR_BGR2GRAY)
score, _ = ssim(gray_ref, gray_gen, full=True)
ssim_list.append(score)
return np.mean(ssim_list)
ssim_value = calculate_ssim("reference_video.mp4", "generated_video.mp4")
print(f"SSIM: {ssim_value:.4f}")主观评价由客户划定几个评价指标:动态表现、相机控制、帧质量、目标准确性四个方面分别做人工打分,评分1-5分;

最后综合客观得分+主观得分,综合确定视频生成的效果;计算方式是加权分,总分=主观分/20 * 50% + PSNR/50 * 25% + SSIM/1 *25%;
50组样本里,随机划分5个测试数据,对这5个用于测试的未参与训练的素材图片,测试四个模型的生成效果,再经过自动评分+人工评分,得到如下的得分结果:

最终根据生成视频的总体得分情况,wan2.1 full的生成效果最好,但整体的全量训练所需算力成本比lora大很多,综合比较wan2.1、wan2.2 的训练成本,包括训练时间、所需算力,综合选择性价比最高的wan2.1 lora模型作为上线生产环境的主力模型。
来源 | 阿里云开发者公众号
作者 | 李德