乐于分享
好东西不私藏

FastWAM 源码入门分析

FastWAM 源码入门分析

1. FastWAM 论文的核心做法

FastWAM 也是以世界模型(World Model)作为 backbone 构建的机器人控制模型,其核心理念是: 在训练时保留视频共训练(保持世界建模能力),但在推理时跳过未来预测(直接生成动作),从而解耦这两个因素。

FastWAM 关键设计

关键设计:

  • 训练阶段:联合优化动作生成和视频去噪目标,让视频 DiT 学习物理有意义的运动和交互结构
  • 推理阶段:不显式生成未来视频帧,而是将视频 DiT 作为单次前向世界编码器,直接输出潜在表示供动作专家使用
  • 控制实验:构建多个对照变体来隔离变量:
    • Fast-WAM-Joint:联合去噪未来视频和动作(传统 WAM 范式 A)
    • Fast-WAM-IDM:先生成未来视频,再条件化预测动作(传统 WAM 范式 B)
    • Fast-WAM w.o. video co-train:移除视频共训练目标,作为对照组

主要发现:

  • Fast-WAM(无显式未来生成)与 imagine-then-execute 变体性能相当
  • 移除视频共训练导致性能显著下降,说明视频预测的主要价值在于训练时改善世界表示,而非测试时生成未来观察
  • 推理延迟仅为 190ms,比现有 imagine-then-execute WAMs 快 4 倍以上

模型架构

基础骨干: 基于预训练的 Wan2.2-5B(视频扩散 Transformer),包含视频 DiT、文本编码器和视频 VAE。

架构组件(Mixture-of-Transformer, MoT):

输入层:
├── 文本编码器 (T5):编码任务指令,通过 Cross-Attention 提供给所有 token
├── VAE 编码器:将视觉观察映射为潜在视频 token
└── 动作编码器:处理动作块

核心架构(共享注意力机制):
├── Video DiT(视频分支):
│   └── 处理观察帧的潜在 token(第一帧为干净 token,未来帧为噪声 token,仅训练时使用)
└── Action DiT(动作专家分支,1B 参数):
└── 处理动作 token,维度 d_a = 1024

注意力掩码设计(关键创新):
├── 训练时:
│   ├── 未来视频 token 可在视频分支内双向交互,并可访问第一帧
│   ├── 动作 token 可在动作分支内双向交互,并可访问第一帧
│   └── 动作 token 不能访问未来视频 token(防止未来信息泄露)
└── 推理时:
└── 仅保留第一帧干净 token,通过 Video DiT 单次前向传递 → 生成世界潜在表示 z(o,l) → 供 Action DiT 去噪动作

训练目标:

  • 联合流匹配损失:
    • :动作生成(flow matching on action tokens)
    • :视频共训练(flow matching on future video latents)

推理流程:

  1. 当前观察帧通过 VAE 编码为干净潜在 token 
  2.  单次通过 Video DiT → 获得世界潜在表示 
  3. Action DiT 基于  和语言指令,通过 10 步去噪生成动作块
  4. 完全跳过未来视频帧的生成,实现实时推理(190ms 延迟)

2. FastWAM 代码库的结构

FastWAM 代码库可以按 4 层来理解:

  1. Hydra 配置层
  • 入口是 scripts/train.py和configs/train.yaml。
  • 任务配置会把 data/model/task 拼起来,例如 configs/task/libero_uncond_2cam224_1e-4.yaml 选择 data=libero_2cammodel=fastwam
  1. 训练运行层

训练主链在 src/fastwam/runtime.py 和 src/fastwam/trainer.py

  • run_training() 做三件事:实例化模型、实例化数据集、交给 Wan22Trainer。
  • Wan22Trainer 只优化 model.dit,而在 FastWAM 里 model.dit = model.mot,这说明真正被 finetune 的核心是 MoT 主干,而不是整个 Wan 全量参数。
  1. 模型实现层

最核心的文件是:

  • src/fastwam/models/wan22/fastwam.py
  • src/fastwam/models/wan22/fastwam_joint.py
  • src/fastwam/models/wan22/fastwam_idm.py
  • src/fastwam/models/wan22/mot.py
  • src/fastwam/models/wan22/action_dit.py
  1. 数据与评测层
  • 数据主入口是 src/fastwam/datasets/lerobot/robot_video_dataset.py 和 src/fastwam/datasets/lerobot/processors/fastwam_processor.py
  • LIBERO 评测看 experiments/libero/eval_libero_single.py,RoboTwin 部署看 experiments/robotwin/fastwam_policy/deploy_policy.py

3. FastWAM 模型代码实现

如果只用一句话概括源码结构:FastWAM = Wan 视频专家 + ActionDiT 动作专家 + MoT 共享多层注意力 + VAE + 可选文本编码器

3.1 顶层类 FastWAM

它的构造函数里挂了这些核心成员:

  • self.video_expert:视频分支,类型是 WanVideoDiT
  • self.action_expert:动作分支,类型是 ActionDiT
  • self.mot:把两路 expert 拼起来做混合注意力
  • self.vae:视频 VAE,用来把像素视频编码到 latent
  • self.text_encoder / self.tokenizer: 文本编码器,训练时通常不加载,改用预缓存 context
  • self.proprio_encoder: 把机器人状态映射到文本维度,作为额外一个 conditioning token
  • train/infer video scheduler
  • train/infer action scheduler

还有一个很关键的小设计:self.dit = self.mot,这是为了兼容训练器,因为训练器默认只训练 model.dit。在 FastWAM 里这意味着真正被训练的是 MoT,而不是把整个 Wan 模型当黑盒端到端训练。

3.2 视频专家 WanVideoDiT 代码

代码位于 wan_video_dit.py 中,它本质上是一个视频 latent Transformer,组成如下:

  • patch_embedding:Conv3d,把 VAE latent 切成 patch token
  • text_embedding: 把文本 token 映射到 Transformer hidden dim
  • time_embedding + time_projection:生成时间步调制向量t_mod
  • blocks: 多层 DiTBlock
  • head: 把 token 还原回视频 latent patch

它的输入不是 RGB,而是 VAE latent,shape 是 [B, C, T, H, W],然后通过 patchify() 变成 token 序列:

x_tokens: [B, S, D]
S = f * h * w

这里的 f,h,w 是 latent patch 网格大小。

这样,最终视频会被切成 token 序列 S = num_frames * tokens_per_frame,tokens_per_frame 表示一帧 latent 图像对应多少个 patch token。


pre_dit() 是这个类最重要的函数之一,它负责做 5 件事:

  1. 检查输入 shape
  2. 把视频 latent patchify 成 token
  3. 生成 3D RoPE 位置编码 freqs
  4. 生成时间调制 t_mod
  5. 把文本 context 也投影到 hidden dim

最终返回一个字典:

  • tokens
  • freqs
  • t_mod
  • context
  • context_mask
  • meta.grid_size
  • meta.tokens_per_frame

这里的 tokens_per_frame 非常重要,后面做 attention mask 时会用它判断“哪些 token 属于第一帧”。

3.3 DiTBlock 每层在干什么

DiTBlock 是一个单层结构,每层都是标准 DiT 风格:

  • norm1 -> self-attn
  • norm3 -> cross-attn(text/proprio)
  • norm2 -> FFN

但它不是普通 residual,而是用 time modulation 做 FiLM 风格调制:

  • shift_msa, scale_msa, gate_msa
  • shift_mlp, scale_mlp, gate_mlp

也就是每层都被 diffusion timestep 控制。这个调制不是在顶层做一次,而是每层都做。

timestep 作为调制信号

扩散模型里的 timestep t 表示“当前噪声有多大”。网络在不同的 t 下,应该需要做不同的事情:

  • 噪声很大时,要更偏向粗粒度恢复
  • 噪声很小时,要更偏向细节修正

在代码实现上,先在 pre_dit() 里把 timestep 编码成向量:

  1. sinusoidal_embedding_1d(self.freq_dim, timestep)
  2. 过 time_embedding
  3. 再过 time_projection

最后得到 t_mod

对于视频分支来说,t_mod 被整理成和每层调制对应的结构,大致可理解为:

  • 每层需要 6 组调制参数
  • 分别给 attention 和 MLP 用

在 DiTBlock.forward() 里,这行代码最关键:

shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ...

这 6 个量来自:

  • block 自己的可学习参数 self.modulation
  • 加上 timestep 生成的 t_mod

所以每一层都会得到一套“与当前 timestep 相关”的控制参数。


这些参数怎么作用到 block 里?主要有两处。

第一处,控制 self-attention 之前的输入:

input_x = modulate(self.norm1(x), shift_msa, scale_msa)
x = self.gate(x, gate_msa, self.self_attn(...))

也就是 timestep 会决定:

  • attention 前输入怎么缩放、平移
  • attention 输出残差该加多少

第二处,控制 MLP 之前的输入:

input_x = modulate(self.norm2(x), shift_mlp, scale_mlp)
x = self.gate(x, gate_mlp, self.ffn(input_x))

也就是 timestep 还会决定:

  • MLP 前特征怎么调制
  • MLP 残差强度有多大

所以不是简单地“把 t 拼到输入里”,而是:

  • timestep -> 生成调制参数
  • 调制参数 -> 改写每层 attention/MLP 的工作方式

这就是 DiT 里常见的 AdaLN / FiLM 式 conditioning。

Video Expert 中的注意力模式

在视频专家里,self-attention 的 mask 可以切换模式,见 build_video_to_video_mask():

  • bidirectional:所有视频 token 都可以互相看见,不区分过去、当前、未来。如果你想把第一帧当作“测试时可用观测”,这种模式不适合 FastWAM 的核心设定。
  • per_frame_causal:以“帧”为单位做因果 masking,每个 frame 只能看到以前的 frames,看不到未来 frames,同一 frame 内的所有 patch tokens 彼此全可见。
  • first_frame_causal:第一帧是特殊锚点,第一帧 token 不能看后面的帧,后面的帧可以看第一帧,也可以彼此互看。这是 FastWAM 默认用的模式。

FastWAM 默认配置用的是 first_frame_causal。这很关键,因为它让第一帧成为一个特殊的信息锚点:

  • 训练时第一帧对应“真实当前观测”
  • 未来帧是噪声 latent,需要被建模
  • 推理时只保留第一帧,所以训练时也希望第一帧表征不要依赖未来帧,避免信息泄露
  • 同时又希望未来帧在训练里能充分交互,学到更强的视频世界表示

所以它本质上是一个“保护首帧因果性,但放松未来帧之间约束”的折中设计。

3.4 动作专家 ActionDiT 是怎么写的?

动作专家的代码实现在 action_dit.py 中。它和视频专家非常像,但简化很多:

  • 输入是动作 token [B, T_action, action_dim]
  • action_encoder 先升维到 hidden dim
  • 也有 text_embedding
  • 也有 time_embedding / time_projection
  • 也有一串 DiTBlock
  • 最后 head 直接回归动作噪声

最关键的约束是它必须和视频专家“层数、头数、head dim 完全一致”,因为 MoT 要在每一层把两边的 Q/K/V 拼起来混合注意力。这个检查在 FastWAM.from_wan22_pretrained() 里写得很明确。

3.5 MoT 到底做了什么

MoT 在 mot.py 中实现,是最核心的“共享世界表征”实现。

它不是简单地把两个网络串联,而是:

  1. 对 video expert 当前层,算出 q_v, k_v, v_v
  2. 对 action expert 当前层,算出 q_a, k_a, v_a
  3. 把两路在序列维拼起来:
    • q_cat = [q_v, q_a]
    • k_cat = [k_v, k_a]
    • v_cat = [v_v, v_a]
  4. 用一个联合的 attention_mask 做一次 mixed attention
  5. 再把输出切回 video/action 两支,分别走各自 block 的后半段

所以论文里的 “Mixture-of-Transformer with shared attention” 在代码里就是这套 “每层 QKV 拼接 + 联合 mask” 的实现。

你可以把它理解成:

  • token embedding、输出头、MLP、cross-attn 都还是 expert-specific
  • 只有 self-attention 的 token mixing 是共享的

3.6 FastWAM 主方法的 forward 逻辑

论文主方法对应 FastWAM.training_loss() 和 FastWAM.infer_action()

训练时

  1. build_inputs(),把 video/action/proprio/context 读进来,并用 VAE 编码视频
  2. 视频 latent 加噪,得到 latents
  3. 动作序列加噪,得到 noisy_action
  4. video_expert.pre_dit(...) 和 action_expert.pre_dit(...) 分别把两路 token 变成 Transformer 输入。
  5. _build_mot_attention_mask(...) 决定谁能看谁
  6. self.mot(...) 做共享 mixed attention
  7. video_expert.post_dit(...) 和 action_expert.post_dit(...) 分别回到视频预测和动作预测
  8. 算 loss_video + lambda * loss_action

这里最关键的是 _build_mot_attention_mask(),在 FastWAM 主方法里规则是:

  • video -> video: 按视频分支自己的 mask
  • action -> action: 全可见
  • action -> video: 只能看第一帧 token

这正是论文里“训练时有视频共训练,但动作不直接偷看未来视频”的代码落点。

推理时

FastWAM 相比于以往的 WAM 方法,推理速度加快了。他的推理逻辑是:

  1. 只把当前图像编码成 first_frame_latents
  2. 用 video_expert.pre_dit() 得到视频 token
  3. 用 mot.prefill_video_cache() 一次性把视频分支每一层的 K/V 全缓存下来
  4. 后续动作去噪每一步,只跑 action branch
  5. action branch 通过 forward_action_with_video_cache() 去访问缓存的视频 K/V
  6. 完全不 rollout 未来视频 latent

也就是说,FastWAM 在测试时不是“先生成未来,再预测动作”,而是:首帧 -> 世界 latent cache -> 动作去噪

这个 cache 机制就在 prefill_video_cache() 和 forward_action_with_video_cache()。

3.7 scheduler 在模型里的角色

扩散/flow matching 调度器在 scheduler_continuous.py。

视频和动作各自有一套:

  • sample_training_t(): 训练时采 timestep
  • add_noise(): 加噪
  • training_target(): 目标是 noise - sample
  • build_inference_schedule(): 推理时生成步长
  • step(): 更新 latent

所以 FastWAM 是“视频 diffusion + 动作 diffusion”双调度器并行存在,只不过在主方法测试时只真正迭代动作那一路。

四、部署模型

部署模型的相关代码在 deploy_policy.py 中,它展示了真实控制时的接口边界:

  • 把当前 observation 转成拼接图像
  • 归一化机器人状态为 proprio
  • 调 model.infer_action(...)
  • 再把输出反归一化成环境动作