ARIS (Auto Research in Sleep) + ARIS-in-AI-Offer 工作流生成 · Continuous DLM Frontier · 2026-05

A Blog on Continuous DLM
2026 上半年 6 篇集中工作综述 — 以 ELF 为锚,5 家 FM 横向对比 + 字节 Cola-DLM

Ruofeng Yang(杨若峰) · Shanghai Jiao Tong University · 2026-05
锚论文:ELF arXiv 2605.10938 (Hu, Qiu, Lu, Zhao, Li, Kim, Andreas, He · MIT · 2026-05-11,非本文作者)· Code: lillian039/ELF · 代码审过 PyTorch port @ pytorch_elf

📌 TL;DR

ELF(Embedded Language Flows, MIT; Hu/Qiu et al., senior 作者 Kim/Andreas/He, arXiv 2605.10938, 2026-05-11) = 在冻结的 T5-small contextual embedding 空间里做连续时间 Flow Matching,只在最后一步离散化; denoiser 和 final-step decoder 共享同一个 Transformer 权重。 一句话定位:"continuous DLM 的瓶颈可能不在'连续'本身——而在过去工作把 encoder 放进训练联合学。ELF 把它冻结,diffusion 只学如何在已有几何中 transport,在 ablation 中显著优于 learnable encoder 选项。"

核心五个数(全部来自 paper):

Gen-PPL = 模型自己采样 → GPT-2 Large 给打分的 perplexity,低 = 生成像自然语言;通常配 entropy 看防 mode collapse。

三条可带走的 takeaway(读完这篇 paper 应该被 update 的认知):

历史定位:2026 上半年 continuous-DLM 出现 6+ 篇集中工作。FM 家族按发布时间排: CFM (Feb)、FLM/FMLM (Feb)、DFM (Apr)、LangFlow (Apr)、ELF (May); 外加 ByteDance Cola-DLM (May 末) 的 latent-VAE 路线。 ELF 在 FM 家族里是 已知 size 最小(105M)、设计最干净(无 distillation / 无 latent VAE)、 且无蒸馏 32 步即接近 dataset reference PPL;Cola 是最大尺度、最复杂、reasoning-focused。 详见 §3 (FM 家族 5-way) 和 §9 (Cola 单独详细对比)。

1 · 离散 DLM 的已有问题

1.1 离散 DLM 五年演进路线

年份工作主要贡献评测尺度
2021D3PM (Austin et al. NeurIPS)定义离散扩散框架:mask / uniform / embedding transitiontext8, LM1B 小尺度
2024SEDD (Lou et al. ICML Best Paper)Score entropy loss 给离散空间一个 clean 损失OWT scale
2024MDLM (Sahoo et al. NeurIPS)Masked diffusion ELBO ≡ weighted CE — 极大简化训练OWT scale
2025LLaDA (Nie et al.)第一个 8B-scale 离散 DLM,证明 scaling 可行8B params
2025Dream 7B (Ye et al.)大规模 diffusion LM(具体机制细节见原文)7B params
2026外部 landscapediscrete:BD3-LM (semi-AR block) / ReMDM / PRISM(test-time remasking);continuous-latent:VADD / LADD / HDLM / CADD;inference-time coupling:CoDD-style PC layer各种 scale

1.2 共同特征 + 共同瓶颈

这套路线都把 D-LLM 定义为:mask/uniform 离散腐蚀 → 逐位置独立 softmax reverse。 随着工作越做越多,发现它撞上四堵墙:

① 表示层 — Quantization Wall

Token 在嵌入前是孤立范畴点,几何邻近性需要从零学。 对比连续扩散在 image / video 上,模型可以利用像素几何("红"和"暗红"邻近), 但离散 token 空间"猫"和"狗"是几何上无关的两个 one-hot vector。 模型必须把所有词之间的语义距离从训练数据慢慢"学"出来。
结果:参数效率低 — 一个 7B 离散 DLM 学完 token 几何后剩下的容量才用来学语言。

② 建模层 — 因式分解瓶颈(Factorization Gap)

标准 D-LLM 的 reverse 是每个 token 独立的 softmax(可类比为图模型的 fully factorized reverse head,treewidth-0)。 真实联合后验 P(X0|Xt) 的所有 token 间依赖被一刀切掉。具体症状(来自论文 App C 描述的 poor generation 区域):

  • 重复:"the the the"、"sample sample"
  • 退化句(repetitive/degenerate generations):开头一句生成完美,接下来一段全是重复模板
  • 长序列上累积误差:序列越长 factorization 误差越容易累积

③ 优化层 — 参数化别扭

Categorical score / transition 在数学上没有干净的 x / v / ε 对应物。 Image diffusion 圈两年发展出来的工具——self-conditioning, classifier-free guidance, flow matching, rectified-flow, EDM-style noise schedule——没有一个能直接搬到离散空间。 每个都要重新设计,速度被拖慢。

④ 推理层 — 离散错误累积

每步都要 unmask / 取 argmax / round;离散误差不可微地累积。 经验上需要很多步才能 recover:ELF Fig 7(a) 对 MDLM / Duo 用了 1024-step baseline 才达到与 ELF-B 32-step 持平的 PPL。 即使做 distillation(MDLM+SDTT, Duo+DCD),few-step variant 也只能到 32-64 步且 PPL 退化。 作为对比,image diffusion 已经发展出 consistency model / mean-flow / flow-distillation 等 1-2 步直采路线。

1.3 过去一年的"修"路线

方法在哪个接口"修"问题
CoDD用 tractable probabilistic / circuit layer 替换或增广 factorized output训练时还可能是 factorized;inference-time coupling 不改根本
CoDARContinuous latent diffusion + fixed encoder + separate / contextual AR decoder
(ELF Tab 2 把它列在 latent diffusion 类)
引入 AR decoder,部分丢了 DLM 全并行性
E2D2 / 类 block diffusionSemi-autoregressive:block 内 joint denoising,block 间 ARblock-AR 牺牲并行性
VADD / LADD / HDLM / CADD (外部)加 latent / hierarchical 结构架构越来越复杂,没有 clean theoretical story
ReMDM / PRISM-styletest-time remasking / search训练阶段不变,只动推理

这些都是 ——保持"离散 + factorized"这个根基不变,在边缘 patching。 ELF 的回答更激进:跳出离散空间。如果 token 嵌入已经被 T5 学好了, 为什么 diffusion 还要在离散符号上跑?为什么不直接在 T5 那个连续 + 有几何结构的空间里跑 flow?

2 · ELF 的核心想法 — 一个 "已知未知" 的连续空间

ELF Fig 1 hero plot
Fig 1 (paper p.1): ELF 在 OWT 上 32 步达到 Gen-PPL≈24,明显优于 MDLM/Duo/FLM/LangFlow,且<strong>无蒸馏</strong>、训练 token 远少。

2.1 两个"连续",一个"冻结"

ELF 在两个意义上是 continuous(paper §2 明确说 "continuous in two senses"):

但 ELF 还有一个关键的"冻结":encoder 不学。这是它和过去同类工作的最大区别。

💡 ELF 到底 denoise 在哪个 tensor 上?—— 位置 vs 词表 vs 维度

很容易混淆"每个 token 位置对应一个 vector"是不是指"每个词有一个 vector"。不是。先理清三个数字:

数字含义
1024序列长度——位置数量("句子里有 1024 个 token 位")
512T5-small contextual embedding 维度——每个位置的 vector 有多少维
32128词表大小——T5 SentencePiece 总共 32128 种不同的 token

ELF 的 tensor 流(OWT, L=1024, D=512):

input_ids:        [B, 1024]               ← 1024 个 token ID(每个是 0~32127 之间的整数)
                       ↓ T5-small frozen encoder
T5 contextual emb: [B, 1024, 512]         ← 每个"位置"得到一个 512-d vector
                       ↓ + noise
z_t (noisy):       [B, 1024, 512]         ← ELF denoise 就在这个 tensor 上
                       ↓ 32 步 Flow Matching denoising
clean embedding:   [B, 1024, 512]         ← 终点 ≈ 干净的 contextual embedding
                       ↓ decoder head(512 → 32128 logits)
vocab logits:      [B, 1024, 32128]       ← 只在 t=1 这一步才出现
                       ↓ per-position argmax
output token IDs:  [B, 1024]

关键澄清

对比离散 DLM(MDLM / LLaDA)

离散 DLM (MDLM)ELF
Denoise 状态token IDs [B, 1024](离散)contextual embedding [B, 1024, 512](连续)
每步输出空间32128 vocab 上的 softmax 分布512-d 连续 velocity
词表 32128 出现频率1024 步每步都做 vocab classification只在 t=1 一次

这就是 §2.4 / §9 里说的"ELF 跳到 continuous embedding,但仍然每个位置一个 vector,最后做 per-position CE"—— ELF 的连续性已经改了(中间不再 round 到 vocab),但 factorization 在 final CE 那一步仍然 per-position 独立(每个位置独立 softmax,没有 joint posterior 跨位置耦合)。

2.2 为什么"frozen contextual embedding"是这条路的关键

过去 5 年的 continuous DLM 路线(paper Tab 2 p.15 给了完整 landscape)分两条线:

共同问题:

  1. "鸡和蛋":embedding 学习和 diffusion 学习互相依赖,早期训练不稳定
  2. 多目标拉扯:embedding 既要满足 reconstruction,又要满足 diffusion topology,又要不 collapse — 容易陷入 trivial 解
  3. Per-step token CE:Tab 2 显示许多老工作在中间步就强加 token-level CE 监督("Train per-step discr.: Yes"),这相当于把"连续 transport"和"离散 classification"两件事强行耦合,部分把问题①(quantization wall)带回来了

ELF 在 Tab 2 里是唯一一行 "fix enc / no train per-step discr / no infer per-step discr / no separate decoder" 同时为否的——架构上最干净的设计点。

ELF 的 framing:把"语义几何"和"transport 动力学"两件事在架构上分离

💡 类比:摄像机标定 vs 拍电影

过去的方法:拍电影 + 现场标定摄像机 + 同时调灯光 — 三件事一起做,一砸就一片乱。
ELF:用已标定好的摄像机 (T5),专心拍电影 (diffusion transport)。 "语言什么样"的问题已经被 Google 用 T5 预训练(~1T tokens)解决了, ELF 不再重复学这个,把所有训练算力都集中在 transport 上。 45B vs 524B 训练 token 的差距就是这个 framing 的直接结果。

⚠️ 但字节 Cola-DLM 给了对立答案:joint training 更好

ELF 上面这套"冻结 encoder 才稳"的论证,在 2 周后字节的 Cola-DLM(2026-05 末)那里被正面反驳。 Cola 选择joint training——VAE encoder 和 DiT prior 一起从 Stage 2 开始联合优化,冻结 encoder。 他们的 RQ2/RQ3 ablation 关键发现:

所以 frozen-vs-joint 这个设计选择没有一个干净的答案。两边都对,只是 regime 不同:

规模谁的 ablation 占优解读
~100M-650M 参数 + ~45B tokensELF frozenencoder 容量已够,再 unfreeze 反而拖慢 transport 学习
~2B 参数 + ~2000 EFLOPsCola joint大算力下 encoder 也有 headroom,joint co-adapt 释放更多潜力

注意两边都没测对方的 regime——ELF 没在 2B+ 规模 sweep joint,Cola 也没在 100M 规模测 frozen。所以这个 tension 严格说是open question。详细分析见 §9 Cola-DLM 对比

2.3 T5-small encoder 为什么是好选择

方面T5-small (ELF 用)其它候选
大小35M (encoder-only)BERT-base 110M、RoBERTa 125M、Sentence-BERT 110M
训练任务Span-corruption denoising 预训练;后续 text-to-text multi-task transferMLM (BERT) / 对比 (Sentence-BERT)
表示性质contextual(同一个 token 在不同句子里不同向量)BERT 也 contextual;static embedding (word2vec/GloVe) 则不
几何性质span corruption 训练让模型学到"什么 span 合理",几何比较平滑BERT MLM 也类似
VocabSentencePiece 32128BERT WordPiece 30K
是否 generativeencoder-only 用法

论文没有 sweep encoder 选择(只 ablate T5-small / base / large 三个 size,App C.1)。 这是 ELF 最容易被 reviewer 攻击的地方:能不能换成 BERT? Sentence-BERT? CLIP text? 一个 LLaMA-7B hidden state? 猜测:T5 的 span corruption 让 hidden state 几何特别平滑,恰好适合 flow matching。 但这是猜测,不是 paper claim。

2.4 "Contextual" vs "Static" embedding —— 借 T5 的语义流形当坐标系

一句话总结:ELF 没有自己学语言的几何,它直接搭便车坐在 T5 已经画好的"文本流形"上做 transport。

2.4.1 Static (非 contextual) embedding

最早的词向量做法(word2vec / GloVe / Diffusion-LM 用的 learned embedding matrix):

所以无论是 "I deposited money in the bank"(金融机构)还是 "We walked along the river bank"(河岸),`bank` 都被映射到同一个固定向量。模型自己没法从 embedding 那一层看出"这个 bank 是哪个意思"——只能靠后面的 transformer 自己消歧。

2.4.2 Contextual embedding

预训练 encoder(BERT / T5 encoder 等)的做法:跑一遍 transformer encoder over the whole sequence,每个 token 位置的输出向量都受整句话影响。

input_ids: [B, 1024]    一个句子,1024 个位置
   ↓ T5 encoder (6 层 self-attention + MLP)
output:    [B, 1024, 512]    每个位置一个 512-d 向量,
                              且这个向量是 attention over 所有位置算出来的

所以同样是 "bank":

2.4.3 几何对比

Static embeddingContextual embedding
存储lookup table [V, D]encoder forward [L, D] per sequence
"bank" 在不同句子里同一个向量不同向量
信息量V 个固定点(vocab=32128)几乎无穷多个点(每句话每位置都不同)
几何结构学到的"词类别"聚类语义+句法消歧后的"上下文状态空间"
来源训练时学一张表跑一遍 frozen encoder

2.4.4 对 diffusion 来说为什么重要 —— 借坐标系,不画地图只走路

T5 预训练(~1T tokens)已经把文本数据组织成了一个有结构的 512-d 流形:语义相近的句子 / 位置在这个空间里几何上也相近,上下文消歧、词性、句法关系都被编码进了向量的位置和方向。

ELF 把 T5 encoder 冻住当"坐标系"——diffusion 的工作变成了"在这张已经画好的语义地图上做 transport",而不是"先画地图再 transport"。

对比之下:

2.4.5 这就是为什么 105M + 45B tokens 能打过 1B+ 的 LLaDA

因为 ELF 实际能用的"语言知识"远不止 45B,而是 45B + T5 那 1T tokens 的预训练迁移。Codex GPT-5.5 xhigh 在 cross-model review 里也直接点了这个:

"45B 'tokens' exclude pretrained T5-small prior; frozen encoder doing real work. What at 2B from scratch?"

—— traces/research-review/2026-05-26_run01/codex_elf_vs_cola.md

代价:上限被 T5-small 锁死。换 7B encoder 还有效吗?scale 到 LLaMA-70B 行不行?这是 ELF 的命门,也是 Gemini auto-gemini-3 review 标的"T5 的生成式插件而非通用架构"那个点(详见 §10 与 ELF 的命门讨论)。

2.5 最后一步离散化 — 共享 Transformer 的 decoder head

整个 flow 从 ε(标准高斯 × noise_scale=2.0)→ clean embedding 都在连续空间。 但 t=1 时模型必须输出离散 token。ELF 的做法:

  1. 主 transformer forward 已经预测 clean embedding x̂([B, L, 512])
  2. 最后另外跑一次 forward,decoder_step_active=True,这次 transformer 输出经过 factored decoder head(768 → 512 GELU → 32128)映射到 vocab logits
  3. 取 argmax 得到 token

关键:denoise 主干 + 离散化 decoder 是同一份 transformer 权重,靠 4 个 mode token 切换。 论文 App C.4 显示这种 in-context conditioning 比 adaLN-Zero 略好且省 43M 参数(详见 §5)。

ELF method overview Fig 2
Fig 2 (paper p.2): 离散 DLM vs 连续 DLM 的对比图。离散版每步都做 vocab classification;连续版整条轨迹在 embedding 空间里平滑流动。

3 · 相关工作 A — Flow-Matching 家族 5 paper 对比(同期)

2026 年 2-5 月,一共出现了 5 篇基于 Flow Matching 的语言模型工作, 其中 ELF 是 paper 自己 Tab 2 显示的最后一行。这一节把这 5 篇放在一起对比,明确 ELF 在 FM 家族里的独特定位

3.1 五篇 paper 速写

CFM — Categorical Flow Maps (Roos et al., UvA/Oxford, arXiv 2602.12233, Feb 2026)。 核心决策:把 categorical generation 写成"endpoint-prediction flow map", 模型预测 simplex 上的 endpoint 分布 πs,t,再用 self-distillation 做少步生成。 评测 Text8 NLL + LM1B Gen-PPL,单步 LM1B Gen-PPL 274.87

FLM / FMLM — Flow Map Language Models (Lee et al., KAIST + CMU, arXiv 2602.16813, Feb 2026)。 核心决策:在 token-wise one-hot 连续空间建模,simplex-valued denoiser,CE 后验预测; FMLM 是 FLM 蒸馏后的少步版本。LM1B + OWT 评测, FLM 1024-step LM1B Gen-PPL 96.91、OWT 62.23;FMLM 单步 LM1B 119.34、OWT 168.30

DFM — Discrete Flow Maps (Potaptchik et al., Harvard/Oxford/MIT/NYU, arXiv 2604.09784, Apr 2026)。 核心决策:把 flow map 从 average velocity 重参数化为 mean denoiser ψs,t, 让 off-diagonal flow map 自然落在 probability simplex 上。直接批判欧氏 L2 与概率几何的不匹配。 LM1B 评测,DFM-ESD 单步 Gen-PPL 68.11(entropy 3.79);DFM-PSD 单步 94.08(entropy 4.06)

LangFlow — Chen et al., UIUC, arXiv 2604.11748, Apr 2026。 核心决策:回到 learned embedding-space,用 Bregman-divergence 把 token CE 解释为 Flow-Matching posterior matching; 推出 ODE-based NLL upper bound,第一次给 embedding-space DLM 一个可信的 likelihood。 LM1B PPL 30.0 / OWT PPL 24.6(注意:是 held-out NLL upper bound,不是 Gen-PPL); 对应 Gen-PPL 是 LM1B 92.2 / OWT 36.5。

ELF — Embedded Language Flows (Hu/Qiu et al., MIT, arXiv 2605.10938, May 2026, 主角)。 核心决策:在冻结的 T5-small contextual embedding里做 FM,只在 t=1 用共享 transformer 切到 decode mode。 80% MSE + 20% CE,无蒸馏。OWT 32-step SDE Gen-PPL 24.08 ± 0.16

3.2 八维结构对比表 (点击列名 → 高亮该列 + 弹论文卡片)

维度 CFM FLM/FMLM DFM LangFlow ELF
State space Probability simplex ΔK(endpoint predictor πs,t one-hot 连续,simplex-valued denoiser Mean denoiser ψs,t: ℝK → ΔK-1 Learned token embedding (V×D matrix) Frozen T5-small contextual embedding (512-d)
Interpolant 几何 Straight-line stochastic interpolant Gaussian/one-hot interpolant + simplex repar Linear interpolant (β-reparameterized) VP γ-path, deterministic ODE Rectified linear: zt = t·x + (1−t)·ε
Training loss CE diagonal + endpoint consistency (CSD/ECLD) FLM: CE posterior; FMLM: KL/CE flow-map distillation CE diagonal + PSD/LSD/ESD KL consistency CE-as-Bregman (FM posterior matching) 80% MSE on velocity + 20% CE on decode
默认步数 1-step 主推(NFE 1-64 sweep) FLM: 1024-step / FMLM: 1-step 1-2-4-8 step 128-step (LM1B) / 1024-step (OWT) 32-step SDE (headline) / 64-step (scaling)
是否蒸馏 ✅ Self-distillation (CSD/ECLD) ✅ FMLM 从 FLM teacher 蒸馏 ✅ Diagonal 1M + off-diagonal 200k/100k ❌ 无 distillation,明说留给未来 无 distillation
Time grid Logit-normal w/ 0.75 diagonal fraction Decoding-error-rate reparameterization argmax-linearized + convex mix β̃(t) Learnable Gumbel scheduler Logit-normal (P_mean=−1.5)
Endpoint 处理 Simplex endpoint π; argmax 或采样 Simplex posterior; argmax Simplex mean denoiser; softmax Argmax over token probability 共享 transformer t=1 decode mode + argmax
Headline eval LM1B 1-step Gen-PPL 274.87 FLM OWT 1024-step 62.23; FMLM 1-step 168.30 LM1B 1-step Gen-PPL 68.11 (ESD) OWT Gen-PPL 36.5 @ 1024-step (PPL upper-bound 24.6) OWT 32-step Gen-PPL 24.08 ± 0.16

3.3 State space 几何 — 什么叫"在 simplex / one-hot 上走 flow"

§3.2 表第一行的 State space 是 5 篇 paper 的核心分歧点。CFM / FLM / DFM 一类说自己"在 simplex 上走 flow", LangFlow 说自己"在 learned embedding 上走",ELF 说自己"在 contextual embedding 上走"——这些到底什么意思?把它讲透。

3.3.1 One-hot 是什么

一个 token 表示成 vocab-size 长的向量,只有一位是 1,其他全是 0:

vocab V = 32128
token "bank" (id=4321) → one_hot 长度 V 的向量
   = [0, 0, ..., 0, 1, 0, ..., 0]
              ↑ 第 4321 位

整个句子(L=64 个 token)→ 形状 [64, 32128],每行是一个 one-hot。

3.3.2 Simplex 是什么

V 维空间里所有"概率分布"组成的集合:长度 V 的非负向量,且各分量之和 = 1。直观理解:

V=3 时的 simplex(直观图):

      (1,0,0)    ← token A 的 one-hot
        /\
       /  \
      / .  \    ← 内部任意点 = 概率混合
     / .. . \
    /________\
(0,1,0)    (0,0,1)
 token B    token C

3.3.3 "走 flow" 在不同 state space 是什么意思

Flow matching 的核心是 zt = t·x + (1−t)·ε——从噪声 ε 走到数据 x。 关键问题是:x 长什么样、ε 长什么样、在哪个空间里走

路线x 长什么样走 flow 的空间哪几篇
One-hot / SimplexV 维 one-hot(或 simplex 内的概率向量)L × V(V ≈ 32k)CFM, FLM/FMLM, DFM
Learned embeddingD 维 static 向量(embedding lookup)L × D(D ≈ 128)LangFlow, Diffusion-LM
Contextual embeddingD 维 context-aware 向量(encoder 输出)L × D(D = 512)ELF

3.3.4 几何含义

想象 V=32128 维空间里数据点 x 长什么样:

路线数据 manifold 形状
ELF(contextual)T5 已组织好的"语义流形",相近词在附近
LangFlow / Diffusion-LM(static)学一张 lookup 表,V 个固定点的离散点云
CFM / FLM / DFM(one-hot/simplex)V 个相互垂直的"角"(汉明距离 = 2,谁离谁都一样远)

最关键的一点:one-hot 空间里,"bank" 和 "dog" 的欧氏距离 = "bank" 和 "shore" 的欧氏距离 = √2。没有任何语义几何——token 之间的相似性必须由 Transformer 自己从头学出来。

3.3.5 4-token 直观例子

假设词表只有 4 个 token:river / bank / money / shore

3.3.6 为什么说 simplex/one-hot 派"连地图都没有,纯靠 Transformer 后处理"

回到 §2.4 那个比喻:

这就解释了为什么 CFM / FLM / DFM 都需要更多 trick:

  1. 大模型容量:Transformer 自己要把"one-hot 空间 → 内部 hidden 语义空间"的映射学出来
  2. 蒸馏 / 重参数化:FLM 蒸馏到 1 步,CFM self-distill,DFM 改 mean denoiser——都是为了弥补"几何信息缺失"
  3. 必须用 CE/KL on simplex:纯 L2 在 one-hot 上不合理(详见 §3.4 ③),所以 loss 必须搬到概率几何上

对照之下 ELF 用 MSE 是合法的——因为 contextual embedding 不是概率分布,是带几何结构的语义向量,欧氏距离反映语义距离。

3.4 四个聚类轴

① State space

  • 纯 simplex 派:CFM + DFM — endpoint / mean denoiser 都约束在 ΔK-1
  • one-hot 派:FLM/FMLM — 状态是 V 维 one-hot,但学习对象重参数到 simplex posterior
  • Learned embedding 派:LangFlow — 学一份 V×D embedding matrix
  • Frozen contextual embedding 派ELF(独占) — 借 T5-small encoder 的语义几何

ELF 是唯一不学 embedding 也不约束在 simplex 上的方案。这是它最独特的设计点。

② Distillation vs base flow

  • 蒸馏派(少步压缩):CFM + FMLM + DFM — 目标都是 1-4 步生成
  • Base flow 派:FLM + LangFlow + ELF — 目标是基础质量,步数靠采样器调

ELF 选 "32-64 步直接采样" 这条路而不蒸馏,是反潮流的—— 当时所有 LM-FM 工作都在卷"少步"。

③ Loss geometry

  • 纯 CE/KL on simplex:CFM, FLM, DFM — "L2 regression 几何不对"
  • Bregman-CE 桥:LangFlow — "embedding 走 L2 / output 走 CE",靠 Bregman 连接
  • Euclidean MSE on embedding + final CEELF — 中间态不是概率分布,是 contextual vector,所以 MSE 合理

ELF 的 MSE 合法性建立在"embedding 不是分布"这个根本前提。这也是它和其它 4 篇的分水岭。

④ 评测范式

  • Gen-PPL(外部 LM 评生成质量):CFM, FLM/FMLM, DFM, ELF
  • Held-out NLL upper bound(likelihood / PPL)LangFlow 独占(同时也报 Gen-PPL)

注意:LangFlow LM1B 30.0 / OWT 24.6 是 ODE NLL upper bound,不是 Gen-PPL。 和 ELF 的 24.08(Gen-PPL)不可直接横比。可比的是 LangFlow OWT Gen-PPL 36.5(1024 步)vs ELF 24.08(32 步)。

3.5 两个 critical reading 问题

Q1: Euclidean vs simplex mismatch 怎么处理?

Q2: 能不能蒸馏到 1 步?

不对称的原因:CFM/FMLM/DFM 的状态在 simplex 或 one-hot 上,flow map 有 clean 的几何对象(mean denoiser)可学;ELF 的状态是 T5 embedding,要做 flow map distillation 需要先证明 T5 embedding 上的 flow map 也有 clean 形式——paper 没做。

3.6 ELF 对每个 sibling 的对比叙述

ELF vs CFM:两者都把离散文本放到连续 FM 里,但CFM 的连续对象是 simplex-valued endpoint,ELF 的是 T5 contextual embedding。 CFM 押注"少步 self-distillation",单步 LM1B 274.87;ELF 押注"无蒸馏 32 步",OWT 24.08。 两条完全互补的路线——蒸馏路 vs 采样路。(点击此标题 → 滚到 §3.2 表 + 高亮 CFM 列)

ELF vs FLM/FMLM:两者都做 continuous flow LM,也都在 OWT 上用 Gen-PPL; 但 FLM 用 V 维 one-hot(vocab 越大状态越大),ELF 用 512 维 T5 embedding + 128 bottleneck。 FLM 的卖点是天然可蒸馏到 FMLM;ELF 的卖点是无蒸馏即少步。 直接对比:FLM 1024 步 OWT Gen-PPL 62.23 vs ELF 32 步 24.08 —— ELF 完胜,但代价是放弃了 vocab-level 状态空间的"天然性"。

ELF vs DFM:两者都认真做 ablation,但变量完全不同—— DFM 在变 PSD/ESD consistency 把 flow map 压到 few-step;ELF 在变 ODE/SDE/γ/SC-CFG 把 base flow 调到 32 步。 DFM 的强点是 1-4 NFE;ELF 的强点是无蒸馏 32 NFE。再次互补。

ELF vs LangFlow:两者都在 embedding-space 做 FM,也都报 OWT 数字;但 LangFlow 的 24.6 是 held-out NLL upper bound,ELF 的 24.08 是 Gen-PPL不能直接比。 可比项:LangFlow OWT Gen-PPL 36.5 (1024 步) vs ELF 24.08 (32 步) —— ELF 在生成质量上明显优。 反过来,LangFlow 有 likelihood story(可信 NLL bound + 4/7 zero-shot benchmark 超过 AR),ELF 目前没有对应 likelihood claim。 两边互补。

3.7 FM 家族一句话总结

这 5 篇本质在回答同一个问题:语言 FM 应该把"连续性"放在哪里? Simplex(CFM/DFM)、one-hot(FLM)、learned embedding(LangFlow),还是 frozen contextual embedding(ELF)
ELF 押注最后一种,并用无蒸馏 32 步 OWT Gen-PPL 24.08 证明这条路在 sample quality 上最有说服力。

4 · 训练 pipeline(ELF 算法核心)

💡 阅读前置:把 ELF 网络当黑盒

本节先把 ELF 网络当成一个抽象函数:

x̂ = netθ(z, t, c, ω, mode)

这一节不依赖具体架构——只要 netθ 是个能接受 (z, t, c, ω, mode) 的可学函数即可。具体 Transformer 实现(T5 encoder、DiT block、RoPE、bottleneck、factored decoder head 等)下一节 §5 展开。

💡 关键澄清:网络直接预测 clean embedding x̂,不是 velocity

ELF 用的是 x-prediction parameterization(App C.1 ablation 选出来的—— x-pred 在高维比 v-pred / ε-pred 稳定,因为"clean text 在低维流形上"):

# sampling_utils.py:115-127
def net_out_to_v_x(net_out, z, t, t_eps=5e-2):
    x = net_out                            # ← 网络输出直接就是 x̂(clean embedding pred)
    denom = torch.clamp(1.0 - t, min=t_eps)
    v = (x - z) / denom                    # ← velocity 是后处理算出来的
    return v, x

但这不代表跳过逐步去噪。推理仍然 32 步 Euler 迭代:

for step i in range(32):
    x̂  = net_θ(z_t, t, c, ω, "denoise")    # 每步预测 clean embedding
    v  = (x̂ - z_t) / (1 - t)               # 由 x̂ 换算瞬时 velocity
    z_t = z_t + dt · v                     # Euler 走一小步
# 最后一步:x̂ = net_θ(z_t≈x, t=1, c, ω, "decode") → unembed → token

三件事要分清:

真值
网络直接输出什么clean embedding x̂(一直在预测最终目标 x0
Flow Matching 框架下的 transport 量velocity v
推理过程仍然 32 步 Euler 迭代;每步用网络的 x̂ 算 v 再走一小步

为什么这么设计

  1. x-pred 的 target 固定(clean embedding x0),不随 t 变;v 和 ε 的 target 都随 t 变化,高维下难学
  2. App C.1 显示 ε-pred 在 768/1024 dim 直接 collapse
  3. 实际 MSE loss 形式上是 ‖v_pred − v_target‖²,但 v_pred / v_target 都从 x_pred / x0 换算,梯度其实在监督 x_pred → x0
App C.1 Fig 10 — Prediction targets
Fig 10 (paper p.21, App C.1): x-pred vs v-pred vs ε-pred 在三种 encoder 维度(T5-small 512 / T5-base 768 / T5-large 1024)下的 Gen-PPL ↔ entropy frontier。<strong>x-pred (橙色) 在全 dim 都稳定 + frontier 接近平行</strong>;v-pred 在 512 ok,768 退化,1024 严重退化;<strong>ε-pred 全 dim 都崩</strong>(落到红色 entropy &lt; 5 或 PPL &gt; 300 区)。

和 consistency model / mean flow 的区别:那些工作目标是真的 1 步从 noise 直接跳到 x(学一个"时间无关的 x-prediction")。ELF 没走那条路——仍是 32 步,但每步的局部预测对象是 x 而非 v。FMLM 就是 FLM 的 consistency-distilled 1-步版本;ELF 的 future work 也提了这个方向。

Training pipeline Fig 9
Fig 9 (paper p.16, App B.1): clean embedding x → corrupt → self-condition → add control → ELF net → MSE 或 CE loss。<strong>同一个 ELF 网络</strong>同时学 denoise 和 decode,靠 model-mode token 切换。
4.1 实际代码执行顺序(≠ 论文叙述顺序)

论文 App B.1 把 Alg 3 / 4 写得像两条独立 pipeline。PyTorch port 实际是这样执行(按 train_step.py):

  1. Label drop mask 应用到 T5 encoder attention mask(XSum/WMT 有,OWT 无)
  2. T5-encode + 归一化input_ids → x₀ ∈ [B, L, 512](x₀−μ)/0.2
  3. 为整个 batch 构造 denoiser corruption:抽 per-sequence t,加噪 z = t·x + (1−t)·ε·2.0
  4. Per-example 抽 decoder/denoiser gate(Bernoulli(0.2))
  5. Decoder 行另构造 per-token corruption:抽 per-token pz̃ = p·x + (1−p)·ε·5.0
  6. Mixz_mixed[row] = decoder_z if decoder_step_active[row]==1 else denoiser_z
  7. 第 1 次 no-grad forward:self-cond 输入 = zeros,用于构造 uncond reference
  8. 主 gradient-tracked forward:self-cond 输入 = stopgrad(uncond x_pred);CE + L2 都来自这次
  9. 第 2 次 no-grad forward:和第 8 步同输入,但用于构造 CFG cond reference
  10. CFG target 组装:v_target = v + (1−1/ω)(v_cond − v_uncond),.detach()
  11. Loss 合并(单一分母)+ grad clip 1.0 + Muon step + EMA

关键认识:CFG target 的第二个 no-grad forward 在主 gradient forward 之后。 数学目标不变(用 v_target 监督 v_pred),但实现顺序不是"两个 no-grad 然后一个 grad"。

4.2 论文 Algorithm 3 (denoiser) + Algorithm 4 (decoder)

💡 MSE 和 CE 在 ELF 里各扮演什么角色?(看 ELF 之前先看这段)

ELF 用两条不同的 loss 学两件不同的事。这是 ELF 整篇 paper 最容易被误解的点:

Loss学什么作用训练频率
LMSE(denoiser) 连续 embedding 空间里 transport 噪声到 clean embedding 学"动力学"——给定 noisy zt 和时间 t,怎么把 zt 沿 flow 推到 x。这是 Flow Matching 的核心。 80%(per-example Bernoulli(0.2) 抽 decoder 模式,剩下都是 denoiser)
LCE(decoder) 离散 token 空间里把 clean embedding 投影回 vocab 学"离散化"——把 flow 终点 (clean embedding ≈ x) 映射到 32128 个 token id。这是 ELF 唯一触及离散 token 的训练步骤。 20%

为什么必须有 CE? 你既然懂 MSE,问题就清楚了: MSE 只能让模型预测 clean embedding,但下游任务要的是生成 token。 如果只有 MSE,推理时拿到 clean embedding 后没办法回到 token(找最近的 T5 embedding 找最近邻?精度差且不可微)。 所以必须额外训练一个 decoder head,把 embedding 映射回 vocab logits —— 这个 head 必须用 CE 训练(因为 vocab 是离散的)。 ELF 的优雅之处是:decoder head 和 denoiser 共享同一份 transformer 权重,只用 4 个 mode token 切换 mode。

所以 ELF 的训练目标本质是:

Ltotal = 𝔼(s, c) [ (1−pdecode) · LMSE(transport) + pdecode · LCE(round-to-token) ]

其中 pdecode = 0.2 是 decoder 分支抽样概率(论文 Tab 4 默认)。

💡 在看 Algorithm 3 之前:什么是 self-conditioning(自条件)?

Self-conditioning(来自 Chen, Zhang, Hinton, "Analog Bits", ICLR 2023,ELF 引用 [9])是 diffusion / flow matching 的一个推理时迭代精修技巧:

普通 diffusion 每一步 forward 只看 ztt = net(zt, t)。 Self-cond 让网络额外接收上一步的预测作为输入:t = net(zt, t, x̂t−1)。 推理时这相当于"我已经有一个 partial 估计,refine 它",比从零预测更稳。

但训练时模型并没有"上一步预测" — 因此训练用以下 trick 模拟:

  1. 50% 概率跑两次 forward:第一次 self-cond 输入填 0,得到 x̂no_sc; 第二次 self-cond 输入 = stopgrad(x̂no_sc),对这次反传梯度。 这样模型学到"给 x̂prev 作 input 时怎么 refine"。
  2. 另外 50% 概率self-cond 输入直接填 0 — 让模型也学到"没有 prior 估计时怎么从 z 直接预测"(推理第 1 步用得到)。

为什么要 stopgrad?因为不希望梯度通过第一次 forward 回流—— 那会让训练目标变成"让 x̂no_sc 也参与优化",破坏 mathematical setup。 ELF 的 self-cond 完全沿用 Chen et al. 这个标准做法。

实现上 self-cond 把网络输入维度从 D 扩到 2D(concatenate [z, x̂prev]), 然后用一个 self_cond_proj: 2D → D 线性层压回 D。 所以你在 Alg 3 看到的 self_cond_proj(concat([z, ...], dim=-1)) 就是这个压回操作;不是 ELF 独创

💡 那 ELF 为什么要用 self-conditioning?为什么不就是 plain Flow Matching?

有 5 个理由,按重要性排序:

  1. 推理时迭代精修 — 同样步数下质量更好
    Plain FM 的 32 步采样 = 32 次独立预测,每步只看 zt。 带 self-cond:每步还看上一步的 x̂i−1,等价于"看着你之前的答案 refine"。 轨迹更平滑,达到同质量需要的总步数少 2-3 倍。这是 self-cond 的原始动机
  2. x-prediction 让 self-cond 特别契合
    ELF 每步直接预测 clean x̂(见上面"关键澄清")。自然的语义是"我这步的 x̂ 比上步更准吗"。 self-cond 显式把 x̂i−1 当 input,让网络有"refine"的语义把手。 v-prediction 没这个直觉——v 是局部量,前一步的 v 跟当前步关系弱。
  3. self-cond 是 SC-CFG (Eq 3) 的 prerequisite ← 这一条最被忽略。
    SC-CFG 的 "C" 就是 self-Conditioning。Eq 3 的 (1−1/ω)·(v_cond − v_uncond) 里:

    • "uncond" = self-cond 输入填 0
    • "cond" = self-cond 输入 = stopgrad(x̂_no_sc)

    没 self-cond 就没 SC-CFG,没 SC-CFG 就没有"推理只跑 1 次 forward"的训练时 CFG 优化(见 §4.4)。 所以整套 Eq 3 训练时 CFG trick 都建立在 self-cond 之上。

  4. Chen et al. 2023 证明 discrete-target diffusion 没 self-cond 不行
    "Analog Bits" 的原始动机:在 discrete data(token / quantized image)上做 diffusion, 没 self-cond 质量差一大截。即使 ELF 在 continuous embedding 空间跑,最终输出仍是离散 token —— ELF 也吃这个红利。
  5. Cost 不大,所以"为啥不用"反而需要理由
    训练 forward 数:1 → 最多 3 次(no-sc + sc + gradient),但训练只跑一次(5 epoch)。 推理 forward 数:完全不变(OWT 无条件 1 forward/step,因 SC-CFG 已 baked-in)。 代价主要是训练时间约 1.5×,换来更稳的轨迹 + 训练时 CFG 优化 + 离散 target 上更好的质量。

所以 Alg 3 里 self-cond 占大位置 ≠ 概念复杂

Alg 3 里 self-cond 看起来占大半篇幅是因为记账复杂,不是概念复杂。剥掉 plumbing 后本质就是 "plain FM + Chen 2023 的 self-cond + Eq 3 的训练时 CFG"。下面这 5 个 step 都是 implementation 细节,不是 5 个独立 ELF 创新:

Alg 3 里的步骤它在干嘛
x_no_sc = net(concat([z, 0]))no-cond reference,self-cond 输入填 0
stopgrad(x_no_sc)不让梯度回流第一次 forward(否则训练目标污染)
x_sc = net(concat([z, stopgrad(x_no_sc)]))self-cond reference forward
(1−1/ω)·(v_sc − v_no_sc)SC-CFG guidance 烤进 v_target(Eq 3)
where(self_cond_mask, ..., ...)50% 概率二选一(让模型也学"没 prior" 的情况,覆盖推理第 1 步)

一句话:self-cond 是 ELF 站在 Chen 2023 + DiT-style 训练时 CFG 这两个肩膀上的"工程接缝"。 拿掉 self-cond 也能跑 plain FM,但你会失去 (a) 32 步达到同质量的能力 (b) Eq 3 训练时 CFG 这两个 ELF 系统级卖点。

论文 App B 写成两个独立算法。PyTorch port 改成 per-example mix,数学上等价(见 §4.6)。两个算法的关键差异:

Denoiser (Alg 3)Decoder (Alg 4)
Mode token gate0(无 mode 信号)1(mode token 激活)
时间 tper-sample t = σ(N(Pm=−1.5, Ps²=0.8²))t = 1(始终终点)
Corruption ratioper-sample tper-token p = σ(N(0.8, 0.8²))(独立!)
Noise scale2.05.0 (OWT) / 1.0 (XSum, WMT)
Self-cond input50% stopgrad(x'); 50% zeros始终 zeros(不学 self-cond)
LossMSE on velocity(base + CFG-augmented target)CE per token
Output headFinalLayer (768→512 flow output)Factored decoder (768→512→32128)

Algorithm 3 伪代码(denoiser,paper App B.1,去 LaTeX):

# Algorithm 3: ELF denoiser training with conditioning and guidance
# net(z, t, c, w, mode): ELF network with in-context conditioning
# self_cond_proj(z): concat-to-original-dim projection
# self_cond_prob: 0.5
# s: discrete token sequence
# c: condition (optional, only for XSum/WMT)

x = encode(s)                               # T5 frozen forward
t = sample_t()                              # logit-normal scalar per sample
w = sample_sc_cfg_scale()                   # ω ∈ [0.5, 5], paper: power-biased
                                            # PyTorch port: shifted log-uniform
e = randn_like(x)                           # standard Gaussian(与 paper 一致)
z = t * x + (1 - t) * e                     # rectified-flow interpolant
v = x - e                                   # base velocity target

# (1) 不带 self-cond 的 forward (no_grad)
z_no_sc = self_cond_proj(concat([z, zeros_like(z)], dim=-1))
x_no_sc = net(z_no_sc, t, c, w, mode="denoise")
v_no_sc = (x_no_sc - z) / (1 - t)

# (2) 带 self-cond 的 forward (no_grad, stopgrad on x_no_sc)
z_sc = self_cond_proj(concat([z, stopgrad(x_no_sc)], dim=-1))
x_sc = net(z_sc, t, c, w, mode="denoise")
v_sc = (x_sc - z) / (1 - t)

# (3) CFG target: post-combination quantity
v_target = v + (1 - 1/w) * (v_sc - v_no_sc)

# Per-example self-cond mask
self_cond_mask = uniform(B) < self_cond_prob
v_pred   = where(self_cond_mask, v_sc, v_no_sc)
v_target = where(self_cond_mask, v_target, v)
v_target = stopgrad(v_target)

loss_denoise = mse_loss(v_pred, v_target)

逐行拆解:v_target 是怎么构造出来的?(Alg 3 最绕的 5 行)

看 Alg 3 最容易迷的就是 (3) 之后 wherestopgrad 那 5 行。它们做了 4 件事:

# (3) CFG target: post-combination quantity
v_target = v + (1 - 1/w) * (v_sc - v_no_sc)     # 计算"带 CFG"的 target

# Per-example self-cond mask
self_cond_mask = uniform(B) < self_cond_prob     # 每个 example 独立抽 Bernoulli(0.5)
v_pred   = where(self_cond_mask, v_sc, v_no_sc)  # gradient forward 走哪条
v_target = where(self_cond_mask, v_target, v)    # target 选哪个
v_target = stopgrad(v_target)                    # 截梯度

关键 insight:这是同一个 batch 同时训练两种模式。 mask 把 batch 切两半:50% 走 self-cond + CFG 路径,50% 走 plain FM 路径。 这两条路径必须用同一个 mask挑 v_pred 和 v_target,否则配对错乱。

mask 值输入有 self-cond?v_pred (gradient)v_target网络学到什么
True (50%) 有,self-cond 输入 = stopgrad(x̂_no_sc) v_sc(带 self-cond 的 forward) (x − ε) + (1 − 1/ω)·(v_sc − v_no_sc) "给 prior 估计 + ω,输出 CFG-amplified velocity"
False (50%) 没有,self-cond 输入 = 0 v_no_sc(不带 self-cond 的 forward) v = (x − ε)(base velocity) "没 prior,给 plain FM velocity"(推理第 1 步用得到)
三个不容易看出来的细节
  1. 为什么 v_pred 和 v_target 必须用同一个 mask?
    它们是配对的。mask=True 的 example:gradient forward 输入带 self-cond,那 target 也必须是"with-self-cond 想达到的样子"。 如果 mask 不对齐——比如 v_pred 是 v_sc 但 target 是 base v——网络会被告知"用 self-cond 输入预测 no-self-cond 的 target",等价于教网络不用 self-cond 输入也行,破坏整个 self-cond 机制。
  2. 为什么 CFG 只 apply 到 mask=True 的 example?
    SC-CFG 的本质是"放大 self-cond 信号"(cond = 有 prior,uncond = 没 prior,按 ω 外推)。 mask=False 的例子连 self-cond 输入都是 0,不存在 cond/uncond 的对偶,没什么可放大。 所以这种 example 的 target 就是 plain FM 的 base velocity (x − ε),没有 CFG 项。
  3. stopgrad 在防范什么?
    v_no_sc 和 v_sc 都在 no_grad 上下文里算出来的(已经无梯度),v = (x − ε) 来自 leaf tensor x 和 ε 也无梯度。 所以理论上 v_target 本身就没有 gradient flow 进网络。 stopgrad 是防御性的——(a) 文档上明确"v_target 是 supervision,不可微";(b) 防止有人 fork 代码后去掉 no_grad 时还能保住语义。 这是好的工程习惯。
那剩下的 loss_denoise 是什么?

就是 mse_loss(v_pred, v_target) 算 L2 距离:

这里 mse_loss 展开是什么?

就是 velocity 上的 L2 距离 + per-token mean,标准 Flow Matching 损失。具体计算:

# v_pred ∈ [B, L, 512],来自 ELF 主干 gradient-tracked forward 后转 velocity:
#   x_pred = ELF_transformer(z_self_cond, t, c, ω, decode_mode=False)
#   v_pred = (x_pred - z) / clamp(1 - t, t_eps)            # paper Eq 4

# v_target ∈ [B, L, 512],由 Eq 3 训练时 CFG 构造(前面 4.4 节详):
#   v_base    = (x_0 - z) / clamp(1 - t, t_eps) = (x - ε·noise_scale)
#   v_target  = v_base + (1 - 1/ω) · (v_cond - v_uncond)
#   v_target  = v_target.detach()                          # 梯度只走 v_pred

l2_per_token = ((v_pred - v_target) ** 2).mean(dim=-1)     # [B, L]  channel-wise mean
l2_per_token *= loss_mask                                  # 排除 padding + cond positions

数学上 ELF 的 MSE 等价于:

LMSE = 𝔼(x, c) 𝔼t, ε [ Σi ∈ valid ‖ vθ(zt, x'prev, t, c, ω)i − vtarget,i22 / D ]

其中:

符号含义
xfrozen T5-small encode 后的 contextual embedding,归一化后(× 1/0.2 = ×5)
ztnoisy embedding:zt = t·x + (1−t)·ε·noise_scale,其中 noise_scale = 2.0
tper-sequence corruption 时间,t ~ σ(N(−1.5, 0.64)),logit-normal
ε标准高斯 ∈ ℝD,D = 512(T5-small encoder dim)
vθELF 主干预测的速度场(实际预测 x_pred,再转 v = (x_pred − z) / (1−t))
vtarget训练时 CFG 烤进的 target velocity(Eq 3,.detach() 截梯度)
x'prevself-conditioning 输入(50% 概率 = stopgrad(uncond x_pred);50% = zeros)
‖·‖22 / Dchannel 维(512)取均值,等价于 per-channel MSE
valid 位置排除 padding + 条件生成的 cond positions

Algorithm 4 伪代码(decoder,paper App B.1):

# Algorithm 4: ELF decoder training with conditioning and guidance
x = encode(s)
p = sample_per_token_p()                    # logit-normal PER TOKEN (different from denoiser)
w = sample_sc_cfg_scale()
e = randn_like(x)                           # standard Gaussian(与 paper 一致)
z = p * x + (1 - p) * e                     # per-token corruption ratio

# decoder always uses zero self-cond input
z = self_cond_proj(concat([z, zeros_like(z)], dim=-1))
h = net(z, t=1, c, w, mode="decode")        # mode token gate = 1
s_pred = unembed(h)                         # factored decoder head

loss_decode = ce_loss(s_pred, s)

这里 ce_loss 展开是什么?

就是标准的 per-token cross-entropy,没有任何 ELF-specific 改造。具体计算(来自 src/train_step.py):

# decoder_logits ∈ [B, L, 32128],由 ELF 主干 forward + factored decoder head 给出:
#   hidden_768 = ELF_transformer(z̃, t=1, c, ω, decode_mode=True)[B, L, 768]
#   hidden_512 = GELU_tanh(hidden_768 @ proj_kernel + proj_bias)     # 768 → 512
#   decoder_logits = hidden_512 @ unembed_kernel + unembed_bias       # 512 → 32128

log_probs    = F.log_softmax(decoder_logits.float(), dim=-1)          # [B, L, 32128]
ce_per_token = -log_probs.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1)
# ⇔ -log p_θ(s_i | z̃_i, t=1, c, ω, mode=decode)                     # [B, L]

# 然后 mask 掉 padding 位置 + 条件生成的 cond 位置,再聚合
ce_per_token *= loss_mask   # loss_mask = (1 - cond_seq_mask) * attention_mask

数学上 ELF 的 CE 等价于:

LCE = − 𝔼(s, c) 𝔼pi, εi [ Σi ∈ valid log pθ(si | z̃i, t=1, c, ω, mode=decode) ]

关键变量含义:

符号含义
si第 i 个位置的真实 token id(ground truth)
i该位置的 per-token corrupted clean embedding: z̃i = pi·xi + (1−pi)·εi·noise_scale, 其中 pi ~ σ(N(0.8, 0.64)) 每个 token 独立抽,noise_scale = 5.0 (OWT) / 1.0 (cond)
t=1decoder 分支永远在终点;时间 token 编码 t=1
c条件输入(XSum / WMT 才有;OWT 无条件 c=∅)
ωSC-CFG scale ∈ [0.5, 5](虽然 decoder 分支不学 SC-CFG guidance,但 ω token 仍 prepend 作为输入)
mode=decode4 个 mode_token gate=1(denoiser 分支时 gate=0,token 被乘 0)
pθ(·)factored decoder head 输出的 softmax 分布(768 → 512 GELU → 32128 → softmax)
valid 位置排除 padding token + 条件生成的 cond positions(cond 不要求模型预测)

三个容易混淆的点

  1. decoder 分支永远 t=1。不像 denoiser 分支 t 从 logit-normal 抽,decoder 总是在"clean 端点"做 final-step decoding。 但 z̃i 仍然有 corruption——只是 corruption ratio pi 从一个独立的 logit-normal 分布抽(per token,不是 per sequence),让 decoder 见过 noisy embedding 也能还原 token。
  2. decoder 分支不带 SC-CFG guidance。Eq 3 的 self-conditioning CFG 只用于 denoiser 分支的 LMSE。 decoder 分支自己只跑一次 forward,输入 self-cond 部分填 0,直接算 CE。
  3. per-token corruption,不是 per-sequence。这是 decoder 和 denoiser 最大的差别—— denoiser 的 t 是 (B,) scalar per 序列;decoder 的 p 是 (B, L, 1) tensor per 位置。 意义:模拟推理时不同 token 位置 reconstruction 质量不同(有些 token 在 ODE 轨迹上接近 clean,有些差远了)。

4.3 Embedding corruption — 两套独立的 logit-normal schedule

两条分支用不同的时间/腐蚀分布。Paper App B.1 + 代码 src/configs/training_configs/train_owt_ELF-B.yml 默认:

分支分布P_meanP_stdNoise scale说明
denoiserper-sequence logit-normal−1.50.82.0t = σ(N(−1.5, 0.64)) → 偏向小 t(噪声多)
decoder (OWT)per-token logit-normal0.80.85.0p = σ(N(0.8, 0.64)) → 偏向大 p(接近 clean);每个 token 独立
decoder (XSum/WMT)per-token logit-normal0.80.81.0条件生成用更小 noise

核心思路:denoiser 训练时多接触噪声大的样本(学习如何 transport), decoder 训练时多接触接近 clean 的样本(学习如何 round 回 token)。 而且 decoder 是每个 token 独立抽 corruption ratio,模拟推理时不同位置的 reconstruction 质量。

💡 两个细节背后的设计意图:(1) "大 p = clean" 的约定 (2) per-sequence vs per-token 的 granularity 差

(1) 为什么"大 p = clean"?

ELF 用 Lipman et al. 2023 rectified-flow 标准插值公式:

z = p · x + (1 − p) · ε · noise_scale

p 值z 长什么样翻译
p = 0z = ε · 5.0 = 全是噪声完全 noisy
p = 0.5z = 0.5·x + 0.5·ε·5.0一半信号一半噪声
p = 1z = x = 完全 clean embedding完全 clean

所以 p 是"信号比例",1−p 是"噪声比例"。p 越大,z 里 clean x 成分越多。 denoiser 的 t 同理:z_t = t·x + (1−t)·ε·2.0,t=1 是 clean 端点,t=0 是噪声端点。

(2) 为什么 denoiser t 是 per-sequence,decoder p 是 per-token?

关键原则:训练分布必须匹配推理分布。两条分支推理时见到的样子完全不同:

denoiserdecoder
推理时跑在哪 整条 32 步 ODE 轨迹,所有 token 位置同步沿 t 推进 32 步完后跑一次 forward,把 zt=1 映射到 token
推理时各 token 位置的 corruption 状态 同一时刻 t,所有位置共享同一个 corruption level 不同位置 ODE 落地质量不均——有的 95% 干净,有的 70% 干净(attention 把噪声去除得不均衡)
训练 sample 单位 per-sequence t ~ logit-normal(−1.5) per-token pi ~ logit-normal(+0.8)
shape t: (B,) → broadcast 到 (B, 1, 1) p: (B, L, 1) 每个位置独立
动机 所有位置同 t 才能让 attention 学到正确的 transport 动力学 per-token p 训练 decoder 在per-position 残余噪声不均时仍能 round 回 token

Denoiser 不能 per-token:如果训练时同一序列内不同位置取不同 t,attention 看到"有的位置很 noisy 有的很 clean"——但推理时根本不会出现这种状态。模型会学到错误的 transport 模式。

Decoder 必须 per-token:推理时 32 步 ODE 落地的 zt=1 不是完美 clean,不同位置残余噪声幅度不同(attention 在 32 步里对各位置的去噪进度不同步)。 Per-token pi 训练 decoder看上下文判断该位置真正的 token——paper App B.1 原话: "encourages the shared-weight decoder mode to recover corrupted embeddings from their surrounding context, making final-step discretization more robust to imperfect embeddings produced by the denoiser at inference time."

注意 p 不作为网络输入

per-token pi 只通过 z̃i 的"信号/噪声混合比例"隐式传递给网络——网络直接知道某个位置的 pi 是多少。网络的时间输入是固定的 t=1 scalar,decoder 通过 attention 看周围上下文自动判断该位置可信度。

对比一图看清

denoiserdecoder
P_mean−1.5(偏小 t)+0.8(偏大 p)
σ(P_mean) 中位~0.18(多 noisy)~0.69(多 clean)
大部分样本落在[0.05, 0.4] noisy 区[0.5, 0.95] clean 区
Granularityper-sequence(match transport 同步性)per-token(match 终点 per-position 不均)
Noise scale2.0OWT 5.0 / 条件 1.0
# sampling_utils.py::add_noise (denoiser)
def add_noise(x0, noise, t, config, cond_seq_mask=None):
    t_expanded = t.reshape(-1, 1, 1)                       # [B, 1, 1]
    z = t_expanded * x0 + (1 - t_expanded) * noise * config.denoiser_noise_scale
    if cond_seq_mask is not None:                          # 条件生成时保留 cond token 不腐蚀
        z = cond_seq_mask * x0 + (1 - cond_seq_mask) * z
    return z

# train_step.py — decoder branch (per-token p sampling)
decoder_z_vals = (
    torch.randn((B * L,), dtype=dtype, device=device)
    * config.decoder_p_std + config.decoder_p_mean         # = N(0.8, 0.8²)
)
decoder_lambda_t = torch.sigmoid(decoder_z_vals).reshape(B, L, 1)  # [B, L, 1] per-token
decoder_noise = torch.randn(x0.shape, dtype=dtype, device=device) * config.decoder_noise_scale
decoder_z = decoder_lambda_t * x0 + (1 - decoder_lambda_t) * decoder_noise

4.4 训练时 CFG(Eq 3,全文最 tricky 的点)

💡 在看 Eq 3 之前:CFG 是什么?为什么 ELF 要做"训练时"版本而不是标准做法?

Classifier-Free Guidance (CFG)(Ho & Salimans, "Classifier-Free Diffusion Guidance", 2022) 是 image diffusion 的一个条件信号放大器

训练时让一个网络同时学有条件 vθ(z, t, c) 和无条件 vθ(z, t, ∅)(10% 概率把 c 置空)。 推理时按系数 ω > 1 把两者外推:

vfinal = vuncond + ω · (vcond − vuncond)

直觉:朝"和无条件的方向"反方向多走一步,等价于更强地遵从条件 c。ω=1 是原始条件 forward;ω=3 时条件信号被显著放大(细节更锐利、和 prompt 更贴)。 代价:推理每步要跑 2 次 forward(cond 一次 + uncond 一次),算力 ×2。 这是 DALL·E / Stable Diffusion / Imagen 等图像 diffusion 的标配。

ELF 的问题:每步 ×2 forward 太贵

ELF OWT 默认 32 步 × 1 forward = 32 forward。如果套标准 CFG → 64 forward,推理算力翻倍。 更要命的是:ELF 把 sampling step 从 baseline 的 1024 步压缩到 32 步是它的核心卖点,再 ×2 就失去了步数效率优势

解决方案:训练时 CFG(Chen et al. "Visual Generation without Guidance", ICML 2025)

这个 trick 是 Chen 等人 2025 年 image diffusion 的工作(ELF App C 引用)。核心想法: 不在推理时做 CFG 组合,而是在训练时就把 CFG 组合烤进 v_target。 让模型直接学 post-combination quantity v_θcfg,推理只需一次 forward 就拿到等价于 CFG 后的 velocity。

标准推理时 CFG训练时 CFG (Chen 2025 / ELF)
训练 forward 数1(只学 v_cond / v_uncond 之一)3(uncond + cond + grad-tracked)
推理 forward 数 / 步2(cond + uncond)1
32 步总推理 forward6432
训练 1.5× ↔ 推理 2× 谁划算训练 5 epoch 1 次,推理跑无数次 — 训练时 CFG 完胜

ELF 把训练时 CFG 应用在哪?— Self-cond CFG,不是 input-cond CFG

注意 ELF 的 c 在 Eq 3 里不是文本 condition,而是self-conditioning 输入(见前面 self-cond 那个 callout):

所以 ELF 实际上把两个独立的 CFG 机制分开处理:

机制训练时还是推理时放大什么OWT 用XSum/WMT 用
SC-CFG(Eq 3)训练时(baked-in)self-conditioning 信号ω=3ω=1(已 baked,不额外推理)
Input-cond CFG(标准)推理时文本 condition 信号—(无条件)ω=2(推理时 ×2 forward)

为什么不把 input-cond CFG 也烤进训练? 因为 input-cond CFG 的 ω 通常需要在推理时 sweep 不同值(找最佳质量),烤进训练就锁死了。 SC-CFG 的 ω 在 ELF 里是固定行为(不用 sweep),所以烤进训练划算。 这是一个 nuanced 的工程权衡。

一句话总结

Eq 3 = Chen 2025 训练时 CFG 这个 image diffusion trick,移植到 ELF 的 self-conditioning 维度, 让 32 步无条件采样的算力优势在加 CFG 后仍然保住。剥掉这个 trick, ELF 要么没 SC-CFG(质量低),要么推理 32 步 ×2 forward(失去步数效率卖点)。

Image diffusion 的 classifier-free guidance 通常是推理时跑两遍 forward 然后线性组合:

推理时 CFG: vfinal = vuncond + ω · (vcond − vuncond)

ELF 把它烤进训练——模型直接学 post-combination quantity,推理只需一次 forward。Paper Eq 3:

vtarget = (x − ε) + (1 − 1/ω) · (vθcfg(zt | t, c, ω) − vθcfg(zt | t, ∅, ω))

📐 Eq 3 的 (1 − 1/ω) 系数怎么推出来?

这个系数不是 ad hoc 写的,是从标准 inference-time CFG 5 步代数推导出来的。Chen et al. ICML 2025 "Visual Generation without Guidance" 的核心贡献就是这个 reparameterization。

Step 1 · 标准 CFG 公式(Ho & Salimans 2022 推理时 CFG):

vfinal(ω) = vu + ω·(vc − vu) = vc + (ω − 1)(vc − vu)

vc, vu 是"不加 CFG"时网络对 cond / uncond 的 logical base 输出。

Step 2 · 训练时 CFG 的网络已经学 post-combination quantity

ELF 网络输出 vθcfg(c, ω) 直接就是 vfinal(推理不再组合)。所以:

两者之差已经被 ω 预放大

vcfgc − vcfgu = [vu + ω(vc−vu)] − vu = ω·(vc − vu)

Step 3 · 反推 logical base 差

vc − vu = (1/ω)·(vcfgc − vcfgu)

Step 4 · 构造 training target

我们要让网络的 vcfgc 收敛到 vfinal。代入 Step 3:

vtarget = vc + (ω − 1)·(vc − vu)
= vc + (ω − 1)·(1/ω)·(vcfgc − vcfgu)
= vc + (1 − 1/ω)·(vcfgc − vcfgu)

所以 (1 − 1/ω) 等于 (ω − 1) × (1/ω) 化简的结果——前一项是标准 CFG"超越 cond 的放大量",后一项是把"已被 ω 预放大的差值"还原回 logical base。

Step 5 · FM target 替换 vc

base FM target 就是 vc = x − ε(rectified-flow 标准 target),代入得 Eq 3:

vtarget = (x − ε) + (1 − 1/ω)·(vcfgc − vcfgu)

边界检查

ω(1 − 1/ω)vtarget含义
10x − ε无 CFG,退化为 plain FM ✓
20.5(x−ε) + 0.5·(vcfgc − vcfgu)中等放大
30.667(x−ε) + 0.667·(...)ELF 默认 SC-CFG=3
50.8(x−ε) + 0.8·(...)SC-CFG 上限
1(x−ε) + (vcfgc − vcfgu)极限放大

ω=1 退化 ✓、ω→∞ 系数趋于 1 ✓、ω=3 对应 ELF OWT 默认 ✓。

这里的 trick 是 self-cond 的"condition"不是输入文本 c,而是 self-cond 输入 x'。 所以"uncond"就是 x'=0,"cond"就是 x'=stopgrad(net1)。CFG scale ω∈[0.5, 5] 是 ELF 自己也学的输入参数(4 个 SC-CFG token 编码)。

实现需要 3 次 forward。代码实际执行顺序是:

  1. v_uncond:no_grad,self-cond 输入 = zeros(最早跑,作 baseline)
  2. v_pred带梯度 forward,self-cond 输入 = stopgrad(uncond 的 x_pred)。L2 loss 用这个 v_pred
  3. v_cond:no_grad,与第 2 次同输入,用于构造 v_target(不是 v_pred)

梯度只过第 2 次。最终 L2 loss = ‖v_pred − stopgrad(v + (1−1/ω)(v_cond − v_uncond))‖²。

# src/train_step.py — Eq 3 的自条件 CFG target 构造(简化版)

def compute_shared_uncond(z, t_input, x_tokens):
    # forward #1: self-cond input = zeros
    z_uncond = restore_cond(torch.zeros_like(z), x_tokens, cond_seq_mask)
    z_input_uncond = torch.cat([z, z_uncond], dim=-1)
    with torch.no_grad(), autocast(bf16):
        net_out_uncond = model(z_input_uncond, t_input,
                               self_cond_cfg_scale=self_cond_cfg_scale)
    return net_out_uncond

def get_sc_cond_and_uncond(z, t_input, cond_mask, x_tokens, shared_net_out_uncond):
    v_uncond, x_uncond = net_out_to_v_x(shared_net_out_uncond, z, t_input, t_eps)
    x_uncond = restore_cond(x_uncond, x_tokens, cond_mask)

    # forward #2: self-cond input = stopgrad(x_uncond)
    z_input_cond = torch.cat([z, x_uncond], dim=-1)         # 注意 stop-grad on x_uncond
    with torch.no_grad(), autocast(bf16):
        net_out_cond = model(z_input_cond, t_input,
                             self_cond_cfg_scale=self_cond_cfg_scale)
    v_cond, _ = net_out_to_v_x(net_out_cond, z, t_input, t_eps)
    return v_cond, v_uncond

def get_sc_guided_v(z, t_input, base_v_target, x_tokens, shared_net_out_uncond):
    v_cond, v_uncond = get_sc_cond_and_uncond(...)
    sc_w = self_cond_cfg_scale.reshape(B, 1, 1)
    sc_guidance = (1 - 1 / sc_w) * (v_cond - v_uncond)      # ← Eq 3 第二项
    sc_guidance = torch.where(use_self_cond_mask.bool(),
                              sc_guidance,
                              torch.zeros_like(sc_guidance))
    return (base_v_target + sc_guidance).detach()           # ← .detach() 是关键

# forward #3 (gradient-tracked) 在主 batch forward 里:
net_out, decoder_logits = model(model_input, t_mixed,
                                self_cond_cfg_scale=self_cond_cfg_scale,
                                decoder_step_active=decoder_step_active)
v_pred, _ = net_out_to_v_x(net_out, denoiser_z, t, t_eps=0.05)   # gradient-tracked
v_final_target = get_sc_guided_v(denoiser_z, t, base_v_target=v_target, ...)
l2_per_token = ((v_pred - v_final_target) ** 2).mean(dim=-1)     # MSE

💡 为什么不直接推理时 CFG?

推理时 CFG 要每步跑两遍 forward(cond + uncond),32 步采样 = 64 次 forward。 ELF 把 SC-CFG 烤进训练后,推理只跑 1 次 forward / step,效率 ×2。代价:训练时 3 次 forward。但训练只跑一次(5 epochs),推理跑无数次。划算。

注意 ELF 还另外保留了输入条件的推理时 CFG(label drop 训练的)。两个机制独立: SC-CFG 用 Eq 3 的 self_cond_cfg_scale(推理时通常 = 1,因为已烤入); input-cond CFG 是标准推理时 cond+uncond 组合(XSum/WMT 默认 = 2)。 label drop 的训练信号只上游改变条件 embedding 状态,不进入 Eq 3 的 SC-CFG 公式。

4.5 Per-example branching — 工程实现 vs 论文算法

Paper Algs 3+4 写成两个独立 training step,按 0.8 / 0.2 概率轮换。PyTorch port 改成:

# src/train_step.py — 关键混合 forward
# 每行独立 Bernoulli(0.2) → decoder mode;否则 → denoiser mode
decoder_step_active = torch.bernoulli(
    torch.full((B,), config.decoder_prob, dtype=torch.float32),
    generator=gen,
).to(device=device, dtype=dtype)             # (B,) 1.0=decode 0.0=denoise
decoder_mask_B11 = decoder_step_active.view(-1, 1, 1)
decoder_mask_B1  = decoder_step_active.view(-1, 1)

# t、z 都按 per-example 混合
denoiser_t = t                                # logit-normal per-sample
decoder_t  = torch.ones_like(t)               # 永远 1
t_mixed = decoder_step_active * decoder_t + (1.0 - decoder_step_active) * t
z_mixed = decoder_mask_B11 * decoder_z + (1.0 - decoder_mask_B11) * denoiser_z

# 单次 forward — mode token gate 也是 per-example
net_out, decoder_logits = model(
    model_input, t_mixed,
    self_cond_cfg_scale=self_cond_cfg_scale,
    decoder_step_active=decoder_step_active,   # (B,) per-row gate
)

# CE / L2 各自用 mask 分流
loss_mask_f = loss_mask.to(ce_per_token.dtype)
ce_mask = loss_mask_f * decoder_mask_B1
l2_mask = loss_mask_f * (1.0 - decoder_mask_B1)

# 关键:单一分母归一化(不是两个分母)
total_sum = (ce_per_token * ce_mask).sum() + (l2_per_token * l2_mask).sum()
loss = total_sum / torch.clamp(loss_mask_f.sum(), min=1.0)
# loss_mask_f = ce_mask + l2_mask,所以这等价于 sum/(ce_mask.sum()+l2_mask.sum())
# 注意:pad_token=="pad" 时 loss_mask 屏蔽 padding;pad_token=="eos" 时全 1

4.6 等价性证明(简版)

论文 Algs 3+4 是两个独立 step,按 (1−p):p 比例轮换(p = decoder_prob = 0.2)。 PyTorch port 是同一个 step 内 per-example 抽 mode。先注意: 每个 example 内所有 token 共享同一个 mode 抽样结果(不是 token-wise i.i.d.)。

记 b ∈ (0, 1) 为 example 的 mode 指示(1 = decode),B0 = denoiser 行数,B1 = decoder 行数。 固定 batch、固定 loss denominator M = loss_mask.sum(),PyTorch port 的 loss:

Lcode = [ Σb=1 行 Σtoken CE + Σb=0 行 Σtoken L2 ] / M

对 mode 抽样取期望(每行 Bernoulli(p)):

𝔼[Lcode] = (p · 𝔼row[Σ CE | decode] + (1−p) · 𝔼row[Σ L2 | denoise]) · (B / M)

= 𝔼[Lpaper] · (per-example token 数加权和) — 等价于 paper Algs 3+4 的固定 batch-size 加权期望

注意

4.7 输入条件 CFG — Label drop 机制(XSum / WMT 才有)

条件生成时还需要另一个 classifier-free guidance,针对 input condition(不是 self-cond)。 10% 概率把 cond sequence 直接 drop(zero out),让模型学到 p(x | ∅) 分布:

# src/train_step.py — label drop for input-condition CFG
if config.label_drop_prob > 0:
    drop = label_drop_mask.to(dtype=torch.float32).reshape(-1, 1, 1)  # (B, 1, 1) 0/1
    cond_mask = cond_seq_mask                                          # (B, S)
    # block_mask: 1 仅在 (non-cond row, cond col)
    # 目的:让 target token 看不到 cond token
    block_mask = (1 - cond_mask).unsqueeze(-1) * cond_mask.unsqueeze(1)
    encoder_attention_mask = encoder_attention_mask * (1 - drop * block_mask)

label drop 实际两阶段

  1. 先改 T5 encoder 的 encoder_attention_mask,让 target token 看不到 cond token (block_mask 只在 non-cond row × cond col 上为 1) — 这样 T5 encode 出来的 x₀ 本身就不含条件信息
  2. Encode 之后再把 dropped 行的 denoiser_zx₀ 在 cond 位置清零torch.where(drop & cond_seq_mask, zeros, denoiser_z))— 匹配 paper "zeroing condition embeddings"
4.8 完整训练超参(论文 Table 4 + PyTorch port 实现细节)
类别参数默认值
Optimizer & Schedule OptimizerMuon(2D 参数走 Newton-Schulz)+ Nesterov-AdamW(其余)
LR (peak)0.002(公式:blr=0.001 × global_batch / 256)
LR scheduleconstant after warmup
Warmup0.5 epoch(5 epochs 总 ~95K steps,对应 ~9.5K warmup steps)
Weight decay0(关闭 — Muon 自带 shape scaling)
Grad clip1.0 (norm)
Batching Global batch size512
Sequence length1024
Grad accumulation1 (硬件够大不用)
Diffusion Denoiser P_mean / P_std / noise scale−1.5 / 0.8 / 2.0
Decoder P_mean / P_std / noise scale0.8 / 0.8 / 5.0 (OWT)
Decoder prob (per example)0.2 (denoiser 0.8)
Self-cond prob0.5(denoiser only;decoder 永远 0)
CFG SC-CFG range[0.5, 5],power-bias sample 偏小值
SC-CFG tokens4
Label drop prob (cond only)0.1(XSum/WMT),0(OWT)
Numerics Precisionbf16 autocast;输出头强制 fp32
EMA decay0.9999
Random seed42(per-rank seed + rank offset,让噪声 desync)
训练量 OWT epochsELF-B: 5 ELF-M: 4 ELF-L: 3(大模型收敛快)
OWT 总 tokens≈ 45.2B(OWT 数据集约 9.04B × 5 ep)
HardwareTPU v5p × 64,1.5h/epoch(ELF-B)
4.9 训练 token 用量对比(Table 5)— ELF 的 12× 数据效率 claim 来源
MethodBase trainingDistillation trainingEffective tokensRatio vs ELF
MDLM512 × 1M × 1024524.3B11.6×
Duo512 × 1M × 1024524.3B11.6×
MDLM + SDTT512 × 1M × 1024512 × 10K × 5 × 1024550.5B12.2×
Duo + DCD512 × 1M × 1024512 × 10K × 5 × 1024550.5B12.2×
FLM512 × 1M × 1024524.3B11.6×
FMLM512 × 1M × 1024512 × 100K × 1024576.7B12.8×
LangFlow512 × 1M × 1024524.3B11.6×
ELF (ours)5 × 9.04B45.2B1.0×

Baseline 估算公式:batch_size × n_steps × seq_length。ELF 用 OWT 总 token × epochs。 精确 ratio 是 11.6× / 12.2× / 12.8×,paper 文字简称"约 12×"。 注意:ELF 的"45.2B"不包括 T5-small 预训练(Google 用了 1T+ tokens 训 T5)—— 这是 paper 最容易被 attack 的地方。

4.10 Muon 优化器简介

Muon = "Momentum + Newton-Schulz orthogonalization",2024 Keller Jordan 推。核心思想:

ELF 的 PyTorch port 用的是 PyPI muon-optimizer 包,加几层 wrapper / patches: (a) Nesterov-bias-corrected Adam update(替换上游 adam_update); (b) NS5 强制 fp32 + eps 1e-8(替换上游 bf16 + 1e-7); (c) Muon update 重写 + shape scaling layout 修正(区分 nn.Linear 和 bare Parameter); (d) _SafeMuonAuxAdam subclass:zero-fill missing grads + distributed all_gather padding 修复。 全部为了匹配 JAX optax.contrib.muon

论文 App C.5 ablation:Muon vs AdamW,SDE 采样下差距尤其大(paper 只定性写"more pronounced under SDE",没给数字)。

⚠️ 复现注意(来自我 + Codex 跨模型审计)

5 · 架构、参数量、shape(最详细一节)

5.1 三种 ELF 大小(论文 Tab 3)
ModelDepthHiddenHeadshead_dimMLP ratioBottleneckParams (DiT)OWT Epochs
ELF-B127681264128105M5
ELF-M2410561666128342M4
ELF-L3212801680128652M3

三个模型共享同一份 frozen T5-small encoder (~35.3M 参数,encoder-only,不参与训练)。条件生成(XSum / WMT14)时 T5-small encoder 也用来编码 source context — 因此条件 ELF-B 报作 "105M+35M"。 Decoder 没有独立训练 — denoise 和 decode 用同一份 Transformer 权重,靠 model-mode token 切换。

5.2 T5-small encoder 规格(来自 HuggingFace t5-small
Vocab size32128(SentencePiece vocabulary)
Layers6 (encoder only)
Hidden d_model512
Attention heads8 (head_dim 64)
FFN d_ff2048
ActivationReLU (非 gated)
Params~35M
训练时整个 encoder 都 requires_grad_(False),作 frozen embedder

训练 ELF 时永远只 forward 这个 encoder 一次(per batch),bf16 autocast 跑,last_hidden_state shape [B, L, 512]。然后做归一化:(x − latent_mean) / latent_std, 默认 latent_std = 0.2 等于把 raw T5 embedding 放大 5×。

5.3 控制 token 设计(12 个,prepend 到序列前)

不在 vocabulary 里,是可学的 nn.Parameter + 输入相关 embedding 的加和

类别个数编码作用
t_emb_tokens4learnable_tokens[1,4,d] + TimestepEmbedder(t)[B,d].unsqueeze(1)把当前扩散时间 t∈[0,1] 注入
self_cond_cfg_tokens4learnable_tokens[1,4,d] + TimestepEmbedder(ω)[B,d].unsqueeze(1)把 self-cond CFG scale ω∈[0.5, 5] 注入(训练时随机抽,推理时固定)
mode_tokens4learnable_tokens[1,4,d] × active_gatedenoise mode → gate=0(token 被乘 0);decode mode → gate=1(token 激活)

三组合起来共 12 个 prefix tokens。最终序列顺序是 [time(4) + sc_cfg(4)] + [mode(4)] + [main_x(L)] —— 代码先 cat([mode, main]),再 prepend [time + sc_cfg](见 5.12 forward 第 4-5 步)。 重要细节:RoPE 在所有 prefix(12 个) 上 cos=1, sin=0(不旋转), 主序列从位置 0 开始正常 RoPE 编码。这样添加 prefix 不会破坏 main token 的相对位置。

5.4 Bottleneck 设计(App C.2 ablation)

T5 embedding 512-d → bottleneck → DiT hidden。直觉:clean 文本数据其实在 低维流形上。 论文 App C.2 sweep 了 bottleneck ∈ (32, 128, 512):

App C.2 Fig 11 — Bottleneck ablation
Fig 11 (paper p.21, App C.2): bottleneck ∈ &#123;32, 128, 512&#125; 在 ODE 和 SDE 下的 Gen-PPL ↔ entropy 曲线。32-d 在 SDE 下能压到最低 PPL 但落到红色 entropy &lt; 5 的退化区;512-d 维持高 entropy 但 PPL 飙升;<strong>128-d 是 frontier 平衡点</strong>,所以是 ELF 默认。
class BottleneckTextProj(nn.Module):
    def __init__(self, text_encoder_dim, hidden_size, bottleneck_dim):
        super().__init__()
        self.proj1 = nn.Linear(text_encoder_dim, bottleneck_dim, bias=False)  # 512 → 128
        self.proj2 = nn.Linear(bottleneck_dim, hidden_size,    bias=True)     # 128 → 768

    def forward(self, x):                       # [B, L, 512]
        return self.proj2(self.proj1(x))        # [B, L, 768]
5.5 RoPE — prefix 不旋转的 trick
class TextRotaryEmbeddingFast(nn.Module):
    def __init__(self, dim, pt_seq_len=512, ft_seq_len=None,
                 theta=10000.0, num_empty_token=0):
        super().__init__()
        ft_seq_len = ft_seq_len or pt_seq_len
        # 标准 RoPE 频率
        freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[:dim//2].float() / dim))
        # 位置缩放支持 fine-tune 时 seq_len 与 pretrain 不同
        pos   = torch.arange(ft_seq_len).float() / ft_seq_len * pt_seq_len
        freqs_main = torch.einsum('..., f -> ... f', pos, freqs)
        freqs_main = repeat(freqs_main, '... n -> ... (n r)', r=2)

        # prefix 的 cos=1, sin=0 → 不旋转
        if num_empty_token > 0:
            cos_prefix = torch.ones((num_empty_token, freqs_main.shape[-1]))
            sin_prefix = torch.zeros_like(cos_prefix)
            freqs_cos = torch.cat([cos_prefix, torch.cos(freqs_main)], dim=0)
            freqs_sin = torch.cat([sin_prefix, torch.sin(freqs_main)], dim=0)
        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
        self.register_buffer("freqs_sin", freqs_sin, persistent=False)

    def forward(self, t):                       # [B, n_heads, L_total, head_dim]
        cos = self.freqs_cos.to(t.dtype)        # 显式 dtype cast 防 bf16 精度漂移
        sin = self.freqs_sin.to(t.dtype)
        return t * cos + rotate_half(t) * sin
5.6 TimestepEmbedder — sinusoidal + 2-layer MLP
class TimestepEmbedder(nn.Module):
    # t (scalar in [0,1]) -> hidden vector (e.g., 768)
    # Init: MLP weights normal(0.02), biases zero
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp_0 = nn.Linear(frequency_embedding_size, hidden_size, bias=True)
        self.mlp_2 = nn.Linear(hidden_size, hidden_size, bias=True)

    @staticmethod
    def timestep_embedding(t, dim=256, max_period=10000):
        half = dim // 2
        freqs = torch.exp(-math.log(max_period) * torch.arange(half).float() / half)
        args  = t[:, None].float() * freqs[None]
        return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)   # [B, 256]

    def forward(self, t):                       # [B]
        emb = self.mlp_0(self.timestep_embedding(t, 256))   # [B, 768]
        return self.mlp_2(F.silu(emb))                       # [B, 768]
5.7 Attention — qk-norm + RoPE 都加上
class Attention(nn.Module):
    def __init__(self, dim=768, num_heads=12, qkv_bias=True, qk_norm=True,
                 attn_drop=0.0, proj_drop=0.0):
        super().__init__()
        self.num_heads = num_heads
        head_dim       = dim // num_heads                  # 64 for ELF-B
        self.qkv       = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm    = RMSNorm(head_dim) if qk_norm else nn.Identity()
        self.k_norm    = RMSNorm(head_dim) if qk_norm else nn.Identity()
        self.proj      = nn.Linear(dim, dim, bias=True)

    def forward(self, x, rope_fn, attention_mask=None):
        B, N, C = x.shape                                   # N = 1036 在 ELF-B 训练时
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
        qkv = qkv.permute(2, 0, 3, 1, 4)                    # [3, B, n_heads, N, head_dim]
        q, k, v = qkv[0], qkv[1], qkv[2]

        q = self.q_norm(q)                                   # qk-norm(论文与官方实现都开)
        k = self.k_norm(k)
        if rope_fn is not None:                              # 应用 RoPE(含 prefix-no-rotation)
            q = rope_fn(q)
            k = rope_fn(k)
        # 实际 layers.py 里包了一层 wrapper:int/float mask -> bool mask 再传 SDPA
        x = scaled_dot_product_attention(q, k, v, attn_mask=attention_mask)
        return self.proj(x.permute(0, 2, 1, 3).reshape(B, N, C))

注意三件事:(1) qk_norm=True 是 ELF 默认开(提升 bf16 训练稳定性,从 Henry et al. EMNLP'20); (2) 注意 F.scaled_dot_product_attention 内部用 PyTorch 2.x flash kernel; 源码 wrapper 在 mask 是 2D/3D 时 reshape 加 head 维并 cast 为 bool; (3) ELF 训练/采样调用模型时都不传 attention_mask(cond+target 全双向 attention), T5 encoder 那一侧才有 cond/target 不对称 mask。

5.8 SwiGLU FFN
class SwiGLUFFN(nn.Module):
    def __init__(self, dim, hidden_dim, drop=0.0):
        super().__init__()
        # SwiGLU 标准做法:把 hidden 缩到 2/3 保持 param count
        hidden_dim_eff = int(hidden_dim * 2 / 3)             # 768*4 = 3072 → 2048
        self.w12 = nn.Linear(dim, 2 * hidden_dim_eff, bias=True)   # 768 → 4096
        self.w3  = nn.Linear(hidden_dim_eff, dim, bias=True)       # 2048 → 768

    def forward(self, x):                                    # [B, N, 768]
        x1, x2 = self.w12(x).chunk(2, dim=-1)                # 各 [B, N, 2048]
        return self.w3(F.silu(x1) * x2)                      # [B, N, 768]
5.9 FinalLayer — zero-init
class FinalLayer(nn.Module):
    # Last layer that maps hidden 768 -> embedding output 512.
    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        self.norm_final = RMSNorm(hidden_size)
        # 关键: kernel + bias 都用 0 初始化(DiT 标配)
        # → 开局时模型预测的 clean x_pred ≡ 0;
        #   velocity 由后处理 v=(x_pred - z)/(1 - t) 计算
        self.linear = nn.Linear(hidden_size, patch_size**2 * out_channels, bias=True)
        nn.init.zeros_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)

    def forward(self, x):                                    # [B, L, 768]
        return self.linear(self.norm_final(x))               # [B, L, 512]
5.10 ELFBlock — 标准 Pre-Norm Transformer block
class ELFBlock(nn.Module):
    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0):
        super().__init__()
        self.norm1 = RMSNorm(hidden_size, eps=1e-6)
        self.attn  = Attention(hidden_size, num_heads, qkv_bias=True, qk_norm=True)
        self.norm2 = RMSNorm(hidden_size, eps=1e-6)
        self.mlp   = SwiGLUFFN(hidden_size, int(hidden_size * mlp_ratio))

    def forward(self, x, rope_fn, attention_mask=None):
        x = x + self.attn(self.norm1(x), rope_fn, attention_mask)
        x = x + self.mlp(self.norm2(x))
        return x

注意是 Pre-Norm + 残差(RMSNorm 在 attn/mlp 之前)。这是 LLaMA / DiT / GPT-J 等长 bf16 训练的常见稳定选择。

5.11 ELF.__init__ — 完整参数声明(删减为关键部分)
class ELF(nn.Module):
    # 源码默认 num_model_mode_tokens=0, vocab_size=0
    # train.py 实际从 config / tokenizer 把 4 / 32128 传进来
    def __init__(self, text_encoder_dim=512, max_length=1024,
                 hidden_size=768, depth=12, num_heads=12,
                 mlp_ratio=4.0, bottleneck_dim=128,
                 num_time_tokens=4, num_self_cond_cfg_tokens=4,
                 num_model_mode_tokens=4, vocab_size=32128):
        super().__init__()
        # Self-conditioning projection: [z; x_pred] (2×512) -> 512
        self.self_cond_proj = nn.Linear(2 * text_encoder_dim, text_encoder_dim, bias=True)

        # Bottleneck text projection (512 -> 128 -> 768)
        self.text_proj = BottleneckTextProj(text_encoder_dim, hidden_size, bottleneck_dim)

        # Prefix tokens + their (input-dependent) embedders
        self.t_embedder      = TimestepEmbedder(hidden_size)
        self.t_emb_tokens    = nn.Parameter(torch.empty(1, num_time_tokens, hidden_size))

        self.self_cond_cfg_embedder = TimestepEmbedder(hidden_size)
        self.self_cond_cfg_tokens   = nn.Parameter(
            torch.empty(1, num_self_cond_cfg_tokens, hidden_size))

        self.mode_tokens = nn.Parameter(torch.empty(1, num_model_mode_tokens, hidden_size))

        # RoPE with prefix no-rotation
        head_dim     = hidden_size // num_heads
        prefix_total = num_time_tokens + num_self_cond_cfg_tokens + num_model_mode_tokens
        self.feat_rope = TextRotaryEmbeddingFast(
            dim=head_dim, pt_seq_len=max_length, num_empty_token=prefix_total)

        # 12 ELFBlocks
        self.blocks = nn.ModuleList([
            ELFBlock(hidden_size, num_heads, mlp_ratio) for _ in range(depth)])

        # Flow-matching output head (zero-init)
        self.final_layer = FinalLayer(hidden_size, patch_size=1, out_channels=text_encoder_dim)

        # Factored decoder unembedding: 768 -> 512 (GELU) -> vocab
        self.proj_kernel    = nn.Parameter(torch.empty(hidden_size, text_encoder_dim))
        self.proj_bias      = nn.Parameter(torch.empty(text_encoder_dim))
        self.unembed_kernel = nn.Parameter(torch.empty(text_encoder_dim, vocab_size))
        self.unembed_bias   = nn.Parameter(torch.empty(vocab_size))
5.12 ELF.forward — 完整 forward(删减+加注释)
def forward(self, x, t, self_cond_cfg_scale=None, decoder_step_active=None):
    # x: [B, L, 512]  if no self-cond, OR  [B, L, 1024]  if self-cond cat([z, x_pred])
    # t: [B]                                            扩散时间 ∈ [0,1]
    # self_cond_cfg_scale: [B] or None                  ω 用 SC-CFG token 编码
    # decoder_step_active: None | True/False | Tensor[B]   控制 model-mode token gate
    B = x.shape[0]

    # ====== 1. self-cond projection: 2C -> C ======
    if x.shape[-1] == 2 * self.text_encoder_dim:       # 训练 & 推理都走这里
        x = self.self_cond_proj(x.float())             # [B, L, 1024] -> [B, L, 512]

    # ====== 2. bottleneck text projection ======
    x = self.text_proj(x.float())                      # [B, L, 512] -> [B, L, 768]

    # ====== 3. build prefix context tokens ======
    time_emb = self.t_embedder(t)                      # [B, 768]
    prefix   = self.t_emb_tokens.expand(B, -1, -1) + time_emb.unsqueeze(1)  # [B, 4, 768]
    if self_cond_cfg_scale is not None:
        sc_emb  = self.self_cond_cfg_embedder(self_cond_cfg_scale)          # [B, 768]
        prefix2 = self.self_cond_cfg_tokens.expand(B, -1, -1) + sc_emb.unsqueeze(1)
        prefix  = torch.cat([prefix, prefix2], dim=1)                       # [B, 8, 768]

    # ====== 4. model-mode tokens with per-example gating ======
    if decoder_step_active is None:
        gate = 0.0                                     # 默认 denoise: tokens 被乘 0
    elif isinstance(decoder_step_active, torch.Tensor):
        gate = decoder_step_active.view(-1, 1, 1)      # [B, 1, 1]  per-example
    else:
        gate = float(decoder_step_active)              # 1.0 in final decode
    mode_tokens = self.mode_tokens.expand(B, -1, -1) * gate                 # [B, 4, 768]
    x = torch.cat([mode_tokens, x], dim=1)                                  # [B, L+4, 768]
    model_mode_offset = self.num_model_mode_tokens     # = 4

    # ====== 5. prepend (time + sc-cfg) tokens ======
    # 最终顺序: [time(4), sc-cfg(4), mode(4), main(L)]
    x = torch.cat([prefix, x], dim=1)                                       # [B, L+12, 768]
    prefix_len = prefix.shape[1]                       # = 8 (time + sc-cfg)

    # ====== 6. 12 ELFBlocks with RoPE ======
    for block in self.blocks:
        x = block(x, rope_fn=self.feat_rope, attention_mask=None)           # full bidi

    # ====== 7. strip prefix (动态 = prefix_len + model_mode_offset) ======
    x = x[:, prefix_len + model_mode_offset:]                                # [B, L, 768]

    # ====== 8. flow-matching head (always computed) ======
    flow_output = self.final_layer(x.float())                               # [B, L, 512]

    # ====== 9. decoder head (only if decoder_step_active) ======
    decoder_logits = None
    if decoder_step_active is not None:
        x_f32  = x.float()
        hidden = F.gelu(x_f32 @ self.proj_kernel + self.proj_bias, approximate="tanh")
        decoder_logits = hidden @ self.unembed_kernel + self.unembed_bias   # [B, L, 32128]

    return flow_output, decoder_logits
5.13 完整 forward shape 走查(ELF-B, B=512, L=1024)
步骤张量shape说明
0input_ids[512, 1024]tokenized text,int64
1T5 encoder output[512, 1024, 512]frozen contextual embedding
2normalized x₀[512, 1024, 512]除以 latent_std=0.2(相当于 ×5)
3noisy z_t[512, 1024, 512]z = t·x₀ + (1−t)·ε·2.0;t = sigmoid(N(−1.5, 0.8²)) logit-normal
4self-cond input[512, 1024, 1024]cat([z_t, x_pred]),channel-wise concat,2×512
5self_cond_proj output[512, 1024, 512]线性投影回 512
6bottleneck proj1[512, 1024, 128]512 → 128(无 bias,强约束)
7bottleneck proj2[512, 1024, 768]128 → 768
8time emb[512, 768]sinusoidal(256) → MLP(768)
9time prefix tokens[512, 4, 768]learnable 加上 time_emb 广播
10sc-cfg prefix tokens[512, 4, 768]learnable 加上 ω embedding
11mode tokens (gated)[512, 4, 768]per-example gate 0 (denoise) 或 1 (decode)
12concat: mode + main[512, 1028, 768]中间态:mode 暂时在最前
13concat: prefix + above[512, 1036, 768]最终顺序 [time(4) + sc-cfg(4) + mode(4) + main(1024)]
14RoPE 应用于 q,k[512, 12, 1036, 64]n_heads=12, head_dim=64;prefix 位置 cos=1, sin=0
15each ELFBlock output[512, 1036, 768]共 12 个 block,shape 不变
16strip prefix[512, 1024, 768]取后 1024 个 token
17FinalLayer (flow head)[512, 1024, 512]RMSNorm + Linear 768→512,预测 clean x₀
18a(decode only) proj[512, 1024, 512]x_f32 @ proj_kernel + proj_bias,再 GELU(tanh)
18b(decode only) logits[512, 1024, 32128]hidden @ unembed_kernel + unembed_bias
5.14 ELF-B 参数量分解(实测 ~105M)
组件计算参数
self_cond_proj1024×512 + 512524,800
BottleneckTextProj.proj1512×128 (no bias)65,536
BottleneckTextProj.proj2128×768 + 76899,072
TimestepEmbedder (×2: time, sc-cfg)(256×768+768) + (768×768+768) ≈ 788K×2 = 1,575,936
t_emb_tokens + sc_cfg_tokens + mode_tokens3 × (4×768)9,216
每个 ELFBlock~7.09M(见下)
  RMSNorm × 22 × 7681,536
  Attention.qkv768 × 2304 + 23041,771,776
  Attention.q_norm + k_norm2 × 64128
  Attention.proj768 × 768 + 768590,592
  SwiGLU.w12768 × 4096 + 40963,149,824
  SwiGLU.w32048 × 768 + 7681,573,632
  子总(一个 block)7,087,488
12 个 ELFBlock12 × 7.09M85,049,856
FinalLayer (norm + linear)768 + 768×512 + 512394,496
proj_kernel + proj_bias (decoder)768×512 + 512393,728
unembed_kernel + unembed_bias512×32128 + 3212816,481,664
合计(trainable)各行精确相加104,594,304 (~105M, 假设 vocab=32128)
+ T5-small encoder (frozen)~35M

关键观察:decoder unembedding 占 16.5M(~16%),这是 32128 词表的主要开销。 而 12 个 transformer block 占 85M(~81%)—— 真正的核心。 self_cond_proj 只占 0.5%,但它是 self-conditioning trick 能 work 的关键 plumbing。

✅ 总结架构哲学

6 · 推理 / 采样 — 完整代码 + 时间网格 + Tab 6/7 全数字

6.1 推理主流程

给定 (B, L) 和采样配置(method, n_steps, cfg, sc_cfg, γ):

# src/utils/generation_utils.py — _generate_samples_single_batch (简化)

@torch.no_grad()
def _generate_samples_single_batch(model, generator, z, t_steps,
                                   cond_seq, cond_seq_mask,
                                   config, sampling_config,
                                   cfg_scale, self_cond_cfg_scale):
    method = sampling_config.sampling_method                # 'ode' or 'sde'
    B, L, d = z.shape                                       # d = 512

    if cond_seq is None:                                    # OWT 无条件生成
        cond_seq = torch.zeros((B, L, d), dtype=z.dtype, device=z.device)
        cond_seq_mask = torch.zeros((B, L), dtype=z.dtype, device=z.device)

    # 在条件位置上把 z 设回 clean cond_seq(cond 不去噪)
    z      = restore_cond(z, cond_seq, cond_seq_mask)
    x_pred = restore_cond(torch.zeros_like(z), cond_seq, cond_seq_mask)

    n = t_steps.shape[0]                                    # = n_steps + 1
    sde_gamma = getattr(sampling_config, "sde_gamma", 0.0)

    use_bf16 = config.use_bf16 and z.is_cuda
    with torch.amp.autocast('cuda', dtype=torch.bfloat16, enabled=use_bf16):
        # n_steps - 2 个中间步用 ODE/SDE
        for i in range(n - 2):
            t       = t_steps[i].item()
            t_next  = t_steps[i + 1].item()
            if method == "sde":
                z, x_pred = _sde_step(z, t, t_next, x_pred, gamma=sde_gamma, ...)
            else:                                           # 'ode'
                z, x_pred = _ode_step(z, t, t_next, x_pred, ...)

        # 最后一步强制 ODE — t 接近 1 不再注入新噪声
        z, x_pred = _ode_step(z, t_steps[-2], t_steps[-1], x_pred, ...)
    return z                                                 # [B, L, 512]


@torch.no_grad()
def _dlm_decode_batch(z, model, t_final_val, config, self_cond_cfg_scale):
    # 最终一步把 latent z (≈ clean embedding) 映射回 token IDs
    B = z.shape[0]
    t_final = torch.full((B,), float(t_final_val), dtype=z.dtype, device=z.device)
    sc_batch = (torch.full((B,), float(self_cond_cfg_scale), dtype=z.dtype, device=z.device)
                if config.num_self_cond_cfg_tokens > 0 else None)

    # 推理时 self-cond 输入永远 = zeros(与训练 decoder 分支一致)
    z_input = torch.cat([z, torch.zeros_like(z)], dim=-1) if config.self_cond_prob > 0 else z

    with torch.amp.autocast('cuda', dtype=torch.bfloat16, enabled=config.use_bf16):
        _, decoder_logits = model(
            z_input, t_final,
            self_cond_cfg_scale=sc_batch,
            decoder_step_active=True,                       # ← mode token gate = 1
        )
    return decoder_logits.argmax(dim=-1)                    # [B, L]

关键点:主循环 + 最终 decode 是两次独立 forward。前者把噪声 ε transport 到接近 clean x; 后者把 clean embedding 投影到 vocab。

6.2 ODE step — 标准 Euler(caller 在 generation.py 里先 init z = randn × denoiser_noise_scale)

def _ode_step(model, z, t, t_next, x_pred_prev,
              config, cfg_scale, self_cond_cfg_scale,
              cond_seq, cond_seq_mask):
    t_batch = torch.full((z.shape[0],), float(t), dtype=z.dtype, device=z.device)
    v_pred, x_pred = _forward_sample(
        model=model, z=z, t_batch=t_batch, x_pred_prev=x_pred_prev,
        config=config, cfg_scale=cfg_scale, self_cond_cfg_scale=self_cond_cfg_scale,
        cond_seq=cond_seq, cond_seq_mask=cond_seq_mask,
    )
    return z + (t_next - t) * v_pred, x_pred                # Euler: z_{i+1} = z_i + dt · v

6.3 SDE step — PyTorch branch(与论文伪码 Alg 6 update 基点略有差异:从 z_back Euler,不是从原 z)

def _sde_step(model, z, t, t_next, x_pred_prev,
              config, cfg_scale, self_cond_cfg_scale,
              cond_seq, cond_seq_mask, gamma, generator):
    h      = float(t_next - t)
    alpha  = max(0.0, min(1.0, 1.0 - gamma * h))            # 信号保留比例 ∈ [0,1]
    t_back = alpha * float(t)                               # 时间往回拉到 α·t

    eps = torch.randn(z.shape, dtype=z.dtype, device=z.device) * config.denoiser_noise_scale
    z_back = restore_cond(alpha * z + (1.0 - alpha) * eps, cond_seq, cond_seq_mask)

    t_batch = torch.full((z.shape[0],), t_back, dtype=z.dtype, device=z.device)
    v_pred, x_pred = _forward_sample(
        model=model, z=z_back, t_batch=t_batch, x_pred_prev=x_pred_prev,
        config=config, cfg_scale=cfg_scale, self_cond_cfg_scale=self_cond_cfg_scale,
        cond_seq=cond_seq, cond_seq_mask=cond_seq_mask,
    )
    # 从 backtracked state 用 Euler 一步推到 t_next
    return z_back + (t_next - t_back) * v_pred, x_pred

三个直觉:

  1. γ = 0: alpha = 1, t_back = t, z_back = z — 退化为 ODE Euler
  2. γ > 0: alpha < 1,z 被 α 缩小并注入新噪声 (1-α)·ε — 相当于把"已经去噪一些"的状态拉回到更早时刻,再 denoise 一次
  3. 因此 SDE 是用额外随机性"修正"早期 denoise 错误。Tab 7 同 CFG 下 SDE 比 ODE PPL 低 12-35%(详见 6.8)

6.4 时间网格 — 训练同分布的 logit-normal

# 函数签名默认参数是 P_mean=-0.8(旧默认)
# 但 caller 在 generation.py 中传入 config.denoiser_p_mean = -1.5(OWT 训练值)
def get_sampling_steps(n_steps, time_schedule="logit_normal",
                       P_mean=-0.8, P_std=0.8, device=None, dtype=torch.float32):
    if time_schedule == "uniform":
        return torch.linspace(0.0, 1.0, n_steps + 1, dtype=dtype, device=device)

    # logit-normal:从训练同分布抽 n_steps - 1 个点
    z = torch.randn((n_steps - 1,), dtype=dtype, device=device) * P_std + P_mean
    steps = torch.sigmoid(z)
    steps = torch.sort(steps).values                        # 升序排列
    # 强制端点 0 / 1
    lo = torch.zeros((1,), dtype=dtype, device=steps.device)
    hi = torch.ones((1,),  dtype=dtype, device=steps.device)
    return torch.cat([lo, steps, hi], dim=0)                # [n_steps + 1]

论文 App B.2:"smaller intervals when t is close to 0 and larger intervals as t approaches 1"。 P_mean=−1.5 使 sigmoid(N(−1.5, 0.64)) 的密度向 t 较小的 noisy 区集中, 所以中间 sample 点偏小 → 排序后noisy 区间隔细密、clean 区间隔粗,与训练 t 同分布匹配。

6.5 _forward_sample — 套两层 CFG

真实的每步 forward要处理两层 CFG:

# sampling_utils.py — 双层 CFG 嵌套

def _forward_sample(model, z, t_batch, x_pred_prev, config,
                    cfg_scale, self_cond_cfg_scale, cond_seq, cond_seq_mask):
    # 内层:带 cond 的 forward(条件 token + SC-CFG)
    v_cond, x_cond = _forward_sample_self_cond(
        model, z, t_batch, x_pred_prev, config,
        self_cond_cfg_scale=self_cond_cfg_scale,
        cond_seq=cond_seq, cond_seq_mask=cond_seq_mask,
    )
    if cfg_scale == 1.0:
        return v_cond, x_cond                       # 无 input-cond CFG,直接返回

    # 外层:input-cond CFG → 再跑一次 uncond
    z_uncond = restore_cond(z, torch.zeros_like(z), cond_seq_mask)
    x_pred_prev_uncond = (None if x_pred_prev is None else
                          restore_cond(x_pred_prev, torch.zeros_like(x_pred_prev), cond_seq_mask))
    v_uncond, x_uncond = _forward_sample_self_cond(
        model, z_uncond, t_batch, x_pred_prev_uncond, config,
        self_cond_cfg_scale=self_cond_cfg_scale,
        cond_seq=torch.zeros_like(cond_seq), cond_seq_mask=cond_seq_mask,
    )
    # 标准 CFG 线性组合
    v_out = v_uncond + cfg_scale * (v_cond - v_uncond)
    x_out = x_uncond + cfg_scale * (x_cond - x_uncond)
    return restore_vx(v_out, x_out, cond_seq, cond_seq_mask)

每步代价(worst case,cond 生成):2 次模型 forward(cond + uncond)。 OWT 无条件生成 cfg_scale=1,每步只1 次 forward。SC-CFG 被烤进训练所以不用 ×2。

6.6 单步算力 vs Baseline

方法每步 forward 数32 步总 forward说明
传统 inference-time CFG 方法(启用时)2 (cond + uncond)常用于 MDLM/LLaDA/Duo 等启用 CFG 的配置
ELF OWT 无条件 (SC-CFG baked-in, cfg=1)132 + 1 decodeSC-CFG 单次 forward 已含 ω 信息
ELF 条件 (input-cond CFG=2)2 (cond + uncond)64 + 1 decodeSC-CFG=1 不额外 ×2,input-cond CFG 才 ×2

6.7 论文 Tab 6 — System-level @ OWT (6 seeds)

论文 page 26 Table 6(6 个 evaluation seed 平均 ± SE):

StepsSC-CFGγGen-PPL ↓Entropy ↑
832.067.32 ± 2.255.14 ± 0.085
1632.033.66 ± 1.095.16 ± 0.026
3231.524.08 ± 0.165.15 ± 0.002

三个观察:

  1. 32 步 SE 极小 (0.16)——结果非常稳定,不是单 seed 运气
  2. γ 从 2.0 → 1.5 — 步数多时不需要那么多噪声注入
  3. 从 8 到 32 步 PPL 砍 ⅔(67.32 → 24.08),但 entropy 几乎不变(~5.15)

6.8 论文 Tab 7 — Scaling × CFG × Sampler(64 步)

论文 page 26 Table 7。三个 size × 两种采样器 × CFG sweep。SDE 全方位优于 ODE。 表内灰色项是 entropy < 5.0(多样性不足,不算 valid):

SamplerSC-CFGELF-B 105M (PPL/Ent)ELF-M 342M (PPL/Ent)ELF-L 652M (PPL/Ent)
SDE (γ=1.0) 0.536.77 / 5.2839.21 / 5.3537.50 / 5.41
1.029.50 / 5.2333.45 / 5.3031.82 / 5.37
1.525.25 / 5.1828.42 / 5.2628.72 / 5.35
2.022.53 / 5.1425.34 / 5.2326.47 / 5.32
3.019.72 / 5.1021.69 / 5.1823.31 / 5.28
3.537.56 / 5.30 ⁱ36.48 / 5.34 ⁱ22.28 / 5.27
4.036.50 / 5.29 ⁱ34.93 / 5.33 ⁱ21.37 / 5.26
ODE 0.5104.29 / 5.5188.51 / 5.5168.27 / 5.52
1.065.30 / 5.4062.47 / 5.4449.72 / 5.45
1.544.85 / 5.3146.71 / 5.3739.97 / 5.40
2.034.65 / 5.2337.66 / 5.3233.72 / 5.36
3.026.62 / 5.1528.80 / 5.2426.57 / 5.29

ⁱ = 论文表中标灰的 cell:CFG > 3 后 ELF-B / ELF-M 出现 PPL 反转/上升, 不是 entropy < 5.0;它们 entropy 仍 ≥5.29。论文 App C 同时用 entropy < 5 或 PPL > 300 作"poor generation"红区。

关键观察(数字均来自 Tab 7 valid 区间内):

6.9 条件生成默认 (XSum / WMT)

src/configs/sampling_configs/cond_sampling_configs.yml

- sampling_method: ode
  num_sampling_steps: [64]
  cfgs: [2]                  # input-cond CFG = 2
  self_cond_cfg_scales: [1]  # SC-CFG = 1(已烤进训练)
  time_schedule: logit_normal

条件生成默认使用 ODE(paper/config 默认;论文未明确解释,可理解为条件任务在标准 CFG=2 下已经稳定,不需要 SDE 额外随机性)。 SC-CFG=1 因为推理时不需要额外 ω 调整;input-cond CFG=2 是标准 image diffusion 推理时 CFG

6.10 PyTorch port 与 JAX 原版的数值漂移承诺

README 承诺:PyTorch port 在 8× L40S / H200 上跑应与 paper TPU v5p-64 数字漂移 ≲ 1 Gen-PPL< 0.5 BLEU/ROUGE。漂移来源:

README 强调用 use_bf16=true(匹配训练 precision)和 use_compile=true(torch.compile ~3-4× speedup)作为推荐 eval flag。

💡 SDE γ 的微妙

论文 Alg 6 写 α = 1 − γ·dt但 dt 不是 1/N(uniform 间隔),是 logit-normal 抽出来的—— 不同步的 dt 跨度差很大(t 靠近 0 时 dt 可能 0.01,靠近 1 时可能 0.4)。 所以同一个 γ 在不同 step 里实际"重置强度"完全不同。代码用 clip(α, 0, 1),仅在 γ·dt ≥ 1 时才把 α clip 到 0: 比如 32-step 默认 γ=1.5,最大 dt 通常 ~0.4,γ·dt = 0.6 不会 clip; 但 8-step γ=2.0 时如果 dt ≥ 0.5 就会 clip。 这是 ELF 实现的关键细节,建议念 paper 时跟代码对照(src/utils/sampling_utils.py:226-251)。

7 · 主结果 — 全部数字

7.1 Scaling: 三个 size 都改善 Gen-PPL / Entropy frontier

Scaling Fig 6
Fig 6 (paper p.8): ELF-B/M/L 在 Gen-PPL ↔ Entropy 平面上的整条 frontier 都改善。同熵下大模型 PPL 更低;同 PPL 下大模型熵更高(多样性更高)。SDE 在所有尺度都比 ODE 更优。

Tab 7 各 size 内自身最低 valid PPL(valid = 落在 entropy 合理区): ELF-B SDE CFG=3 → 19.72;ELF-M SDE CFG=3 → 21.69;ELF-L SDE CFG=4 → 21.37。 绝对最低是 ELF-B 的 19.72;ELF-L 21.37 比 ELF-M 21.69 更低。 但 paper 强调 frontier 而非单点最低:scaling 整体向更优方向推进(同 entropy 下更低 PPL,同 PPL 下更高 entropy)。

7.2 系统级对比(Fig 7)— 三个独立维度全面胜出

System-level Fig 7
Fig 7 (paper p.8): (a) 步数效率 — ELF-B 32 步 ≈ baseline 1024 步;(b) 即使对比蒸馏过的 MDLM+SDTT / Duo+DCD / FMLM,ELF(无蒸馏)依然 win;(c) 训练 token 用量 — ELF 45B (1×) vs baselines 524–578B (12×)。

Fig 7(a) — Gen-PPL vs Sampling Steps

ELF-B 32 步达到 Gen-PPL ≈ 24。从 Fig 7(a) 视觉读数:ELF-B 32-step 已接近或优于若干 baseline 在高 step(如 1024)下的水平, 推理时间 substantially less than prior methods(paper §4.2 原文)。

Fig 7(b) — Distillation 后的 baseline 也输给原版 ELF

三种蒸馏过的 few-step variant:

这些都需要额外蒸馏阶段(10K-100K extra steps),但 32 步 PPL 还是不如未蒸馏的 ELF-B。 即在这套系统配置下,ELF 不加额外 distillation 仍然超过这些 distilled baselines("架构层面的优势"是我对这个现象的解读,非论文逐字 claim)。

Fig 7(c) — Training token 预算(柱状)

详见 4.9 节 Tab 5:ELF 45.2B 总 token,所有 baseline 都在 524-577B 区间, 其中蒸馏 variant 因为还要加蒸馏 epoch,token 更多。精确 ratio 11.6× / 12.2× / 12.8×(paper 简称 ~12×)。 Tab 5 只统计 ELF/baseline 自身训练与蒸馏 token,不包含外部 encoder 预训练成本

7.3 条件生成(Tab 1)— 2 任务 / 4 指标全部表内最佳

Tab 1 conditional results
Table 1 (paper p.9): ELF-B (105M+35M) 在 WMT14 De-En 拿 <strong>BLEU 26.4</strong>,XSum 拿 <strong>R-1 36.0 / R-2 12.2 / R-L 27.8</strong>,超过 MDLM/Duo/E2D2/SeqDiffuSeq/CDCD/AR baseline。
ModelSizeDe-En BLEU ↑XSum R-1 ↑R-2 ↑R-L ↑
AR (Transformer)99M25.230.5 ± 0.1310.2 ± 0.1124.4 ± 0.12
MDLM99M18.433.4 ± 0.1111.6 ± 0.1025.8 ± 0.10
Duo170M+35M21.3 ‡31.4 ± 0.1210.1 ± 0.1025.0 ± 0.12
E2D299M24.828.4 ± 0.118.3 ± 0.0922.0 ± 0.10
SeqDiffuSeq21.319.3 †1.7 †14.1 †
CDCD24.9
ELF-B (ours)105M+35M26.436.0 ± 0.1312.2 ± 0.1127.8 ± 0.12

† = 直接取自该方法原 paper(De-En 数据的默认来源); ‡ = ELF 团队用公开 codebase 重跑(XSum 数据的默认来源);Duo De-En 在 ELF 团队的对比里也是 ‡(重跑)。

重要观察:

7.4 关键 ablation 汇总(App C 全部 7 个 ablation)

论文 App C (pages 20-23) 系统 ablate 7 个设计选择。这是"为什么 ELF 能 work"的实证支撑:

AblationSweep结论 / 默认差距
C.1 Prediction targetx-pred / v-pred / ε-predx-pred 全 dim 稳定ε-pred 全 dim 都 collapse(512/768/1024);v-pred 在 512 dim ok,越高越差
C.2 Bottleneck dim32 / 128 / 512128 最佳 frontier32 偏低 entropy;512 偏高 PPL;128 balance
C.3 Denoiser mode prob0.2 / 0.5 / 0.80.8 (denoise) / 0.2 (decode)0.5 / 0.2 (denoise) PPL/entropy frontier 都明显劣化
C.4 Conditioning stylein-context tokens / adaLN-Zeroin-context 略优 + 省 43M 参数性能 ≈ adaLN,但 ELF-B 148M → 105M
C.5 OptimizerMuon / AdamWMuon 全面优于 AdamWSDE 下差距最显著(paper 定性 "more pronounced under SDE")
C.6 Sampler + time gridODE / SDE × uniform / logit-normalSDE + logit-normalSDE 通常降低 PPL(幅度随 model/CFG 变化,CFG=2 时 21-35%);logit-normal 在各 step 都更优,few-step 时尤其
C.7 Cond CFG scale1 / 2 / 3 / 4CFG=2 最佳1→2 substantially improves;3、4 逐步下降,过强 guidance 反而 degrade

C.1 — 为什么 x-pred 才行?

三种 prediction 是数学等价的,但训练 signal 完全不同。论文用三个 encoder size (T5-small/base/large = 512/768/1024 dim) sweep:

App C.1 Fig 10 — Prediction targets
Fig 10 (paper p.21, App C.1): 三种 prediction target 在不同 encoder dim 下的 Gen-PPL ↔ entropy frontier。详细解读见 §4 黑盒 callout 那张同图。

解释:clean 文本数据在 embedding 空间是低维流形。x-pred 预测的就是这个流形上的点; ε-pred 预测的是高维 Gaussian,模型必须学一个全维度等熵分布——更难。 这条 finding 支持"continuous DLM 的关键不是连续,是 x-prediction" 的 framing。

C.3 — 80/20 denoise/decode 比例

很反直觉:如果 decoder 占比上升到 0.5,按理说 decoder 应该学得更好——但实际整体 frontier 都退化。 解释:decoder 共享 transformer 主干。如果训练时频繁切换 mode,主干被两个目标拉扯;只占 20% 时 decoder 学到的是 "在已经 transport 到 clean 的 embedding 上做最后一步映射"——比例小但效果反而好。

C.6 — Time schedule 在 few-step 时尤其关键

Logit-normal time grid 让 noisy 区间更密——8-step 时这极其重要。 Uniform 8 步 PPL 远高于 logit-normal 8 步。32-step 之后两者差距收窄。
γ sweep(paper 选默认值):8/16 步默认 γ=2.0;32 步 γ=1.5;64 步 γ=1.0。 论文文字说 γ 控制 PPL/entropy trade-off,paper 默认 γ=1.0 作为各 step budget 的 balance; 8/16 步用更大 γ=2.0 是因为粗步长需要更多 stochasticity 修正噪声累积。

7.5 Tab 6 + Tab 7 数据汇总(detail 详见 6.7-6.8)

覆盖最佳 valid Gen-PPL
Tab 6 (system-level, ELF-B, 6 seeds)8/16/32 step32-step SDE γ=1.5 SC-CFG=3: 24.08 ± 0.16
Tab 7 (scaling, 64-step)B/M/L × ODE/SDE × CFG 0.5-4各 size 内最低 valid PPL:ELF-B 19.72 (CFG=3) / ELF-M 21.69 (CFG=3) / ELF-L 21.37 (CFG=4)。CFG>3 部分 cell 标灰

7.6 数据来源 & 实验环境

8 · 定性效果 — 去噪轨迹

Denoising trajectory Fig 17
Fig 17 (paper p.28): 从 t=0 的 gibberish/repetitive token 开始('strength will building building...'),随 t 增长 ELF 渐进地形成语义有意义的短语,最终(t=1)解码为流畅句子。连续轨迹的好处:每一步都是几何渐变,不是离散跳变。

这张图回答了"continuous DLM 到底在做什么"——它在 embedding 空间里描出一条平滑轨迹, 最后一步才把轨迹终点投影到 token 词表。和离散 DLM 每步都做 vocab argmax 完全是两套范式。

9 · 和字节 Cola-DLM 对比 — Field Landscape

Cola-DLM(ByteDance Seed, arXiv 2605.06548, 2026 年 5 月) 是和 ELF 几乎同时(2 周内)冒出来的同类工作。两边都是 continuous DLM,但设计哲学几乎相反: ELF 求简,Cola 求强。下面是我从 Cola 的公开 blog 和 arXiv 摘要整理的对比。

9.0 一句话定位

ELF = "把 encoder 冻结,diffusion 只学 transport"——最小架构、最高数据效率,刷 OWT Gen-PPL。

Cola-DLM = "diffusion 不应该恢复 noisy token observation,应该建模 semantic latent prior"—— 两阶段训练(VAE pre-train → joint VAE+DiT)+ block-wise 推理,~2B 参数,刷 reasoning task average。

9.1 关键差异表

维度ELF (MIT)Cola-DLM (ByteDance)
核心对象Contextual embedding 上 Flow MatchingText VAE latent 上 block-causal FM
EncoderFrozen T5-small (35M)Learnable Text VAE (~500M)
Latent spaceToken-aligned, 512-d (bottleneck 128)Explicit z, d=16 (默认)
Diffusion 目标恢复 clean contextual embedding把噪声 transport 到 learned latent prior
Decoder共享 Transformer (final-step 切换 mode)独立 VAE decoder + KV cache
BackboneDiT,全双向 attentionBlock-causal DiT (intra-block 双向、inter-block 因果)
训练单阶段 80/20 mix两阶段训练 (VAE pretrain → joint VAE+DiT) + block-wise 推理
损失80% MSE + 20% CEStage 2: λVAE·LVAE + λFM·LFM + λref·KL(q‖qref)
参数105M / 342M / 652M~2B 总 (1.8B DiT + 500M VAE)
采样32-64 步 ODE/SDE Euler8-16 步 / block, block-causal + KV cache
评测Gen-PPL, BLEU, ROUGETask Avg (LAMBADA/MMLU/SIQA/RACE/...)
对标MDLM/Duo/FLM/LangFlowAR + LLaDA at 2B

9.2 ELF / Cola 的客观技术差异

维度ELFCola
Diffusion 建模对象Token-aligned contextual embedding(T5 encoder output 上每个 position 一个 512-d vector)压缩后的 semantic latent z(VAE 编码出来的较低维向量)
是否有显式 p(x)=∫p(x|z)p(z)dz无(直接对 contextual embedding 做 FM,最后一步 decode 到 token)有(VAE 提供 p(x|z),diffusion 学 p(z)
Decoder共享 Transformer 主干,mode token 切换 + factored linear head独立 VAE decoder
训练时中间步是否做 token-space loss不做(中间全是 MSE on embedding;只在最后步混 20% CE)不做(diffusion 在 latent 空间,CE 由 VAE 在两阶段训练里分担)
Attention 模式非因果 / 全局 attentionBlock-causal(block 内非因果,block 间因果,可 KV-cache)
需要调的"杠杆"1 套 logit-normal schedule + denoise/decode 比例 + bottleneck dimVAE 质量 + latent dim + logSNR + block size + anti-drift KL + CFG + 评测协议

9.3 两边各自能给出而另一边给不出的 claim

Cola 能给出而 ELF 给不出

  • 显式层次化潜变量 p(x)=∫p(x|z)p(z)dz,diffusion 建 pψ(z)、VAE decoder 建 pθ(x|z)
  • Block-causal attention + KV-cache:扩散 serving 可以像 AR 一样按 block 增量推理
  • 2B 规模 + reasoning benchmark 数字(LAMBADA 50.80 / MMLU 19.30 @ 2000 EFLOPs)

ELF 能给出而 Cola 给不出

  • "中间步不需要 token-space supervision" 的干净对照(中间全程 MSE on contextual embedding,只最后一步 20% CE)
  • 无 distillation、无独立 VAE 栈下的小尺度数字(105M、45B tokens、32 步 Gen-PPL 24.08)
  • Encoder / decoder 主干完全共享(同一份 Transformer 参数 + mode token 切换)的可行性

9.4 Cola-DLM 两阶段训练 + 推理(来自他们 blog)

Cola 不像 ELF 那样单阶段端到端训练。两个训练阶段 + 独立的推理流程:

阶段训练对象损失目标
Stage 1 — VAE pretrainingText VAE encoder + decoder LVAE = −𝔼[log pθ(x|z0)] + β · KL(qφ‖pbase) + λmask·Lmask
(带 BERT-style masking loss)
学一个稳定的 text↔latent 映射,避免 semantic collapse / decoder shortcutting
Stage 2 — Joint VAE + DiTVAE + block-causal DiT Lstage2 = λVAE·LVAE-like + λfm·LFM + λref·KL(qφ‖qφ_ref)
(reference KL 抑制 latent drift)
DiT 在已稳定的 latent 上学 flow matching prior;reference KL 不让 VAE 漂移
Inference (非训练阶段) Block-wise prior transport(DiT 生成 latent block)+ VAE decoder(latent → tokens, KV-cached)

关键超参(来自 blog RQ4 sweet spot 表,released checkpoint 配置)

9.5 Block-Causal Attention(Cola 的核心架构选择)

Cola DiT 在序列维度上做块因果(block-causal)分解:

这种设计的好处:

  1. KV cache 兼容:因为 inter-block 是因果的,已经生成完的早期 block 的 K/V 可以缓存。下一个 block 不重新算前面的 attention,类似 AR 的推理加速。
  2. 训练时仍并行:teacher forcing 的常规做法,把 block-causal mask 加到 attention 矩阵,所有 block 一次性算完。
  3. 生成质量 vs AR 的赌注:保留 diffusion 的"全 batch 同时 denoise"特性(block 内并行),同时获得 AR 的推理时 KV cache 加速

Block size sweep(来自他们 RQ2/RQ3):

9.6 Cola-DLM benchmark 数据(released checkpoint @ 2000 EFLOPs)

Cola 不报 Gen-PPL。他们用"generative few-shot"协议把多选题转成生成任务, 和 AR + LLaDA 在 ~2B 同 scale 下对比。来自 ByteDance-Seed/Cola-DLM GitHub model card 的 released 数字

TaskCola @ 2000 EFLOPs说明
LAMBADA50.80段落补全
MMLU19.3057 类多选 — 数字仍低于 AR 同 scale
SIQA28.90社会场景推理
RACE19.60阅读理解
Story Cloze30.77故事结尾选择
OBQA23.00开放式问答
HellaSwag10.70常识 NLI
SQuAD30.90抽取式问答
Task Avg26.758 任务平均

注意:blog 内 RQ2/RQ3 ablation 表给出更小的训练 budget 下数字(LAMBADA 31.1-34.6 / MMLU 5.4-10.1 / SIQA 11.1-23.6 等), 但这些是消融区间,不是 headline。上面才是 released checkpoint 数字。

他们的 scaling 实验跑到 ~2000 EFLOPs。Blog 说 "Task Avg 在 ~2000 EFLOPs 还在 rising"—— 官方称仍有 headroom,没看到饱和。

9.7 Cola 的关键 ablation 发现(他们 RQ2/RQ3)

消融结论
Fixed-VAE vs Joint-VAE trainingJoint 在大算力下 win。冻结 VAE 训练 (像 ELF 那样) 在小算力 ok,但 scaling 时无法继续涨
All-Scratch baseline从零训 VAE + DiT 始终不如先 Stage 1 pretraining
Interval freezing (Stage 1.5)VAE 在中间阶段冻结一段再放开 — 比一直 joint 差
Sampling steps少步数明显不足;~8-10 步基本恢复;16-32 步饱和
Patch size 2 (压缩 latent)整体比 patch 1 差;但当 prompt length 对齐 patch 边界时反而略好(18.12 vs 17.31)— "Token-level segmentation 不一定最优"

这里有个有意思的对比:Cola 通过 ablation 证明 joint VAE+DiT 训练在大算力下更优;ELF 通过 ablation 证明 frozen encoder 在小数据小算力下更优。两个结论不矛盾—— 它们对应不同的训练规模、不同的"哪个组件更值得优化容量"的判断。

9.8 客观差异汇总

10 · Q&A

10.A 关于 ELF 自身

Q1: ELF 为什么能少用 ~12× 训练数据?

真正的来源是 frozen T5-small 的 contextual embedding 已经包含了语言的几何先验。 ELF 不学"语言是什么",只学"如何在这个空间里 transport"。等效于一种 transfer learning,T5 的预训练成本 (Google C4 上约 1T tokens)没被算进 ELF 的 45.2B tokens 里。Tab 5 只比较了 ELF vs baseline 的自身训练 token,不计 encoder pretrain。 这是这篇论文最容易被 attack 的点。

怎么辩护:把 ELF 看作"在 T5 表示上的 transfer-learning DLM"。MDLM/Duo 也间接用了 word-level tokenizer 的语言先验 (虽然程度不同)。但承认这是"有限的 12×",不是免费午餐。

Q2: 如果换 encoder(GPT-4 hidden state / LLaMA-7B hidden state)会更强吗?

大概率是的,但 paper 没做这个实验(App C.1 只 sweep 了 T5-small / base / large 三个 size)。 这是 ELF 最大的 follow-up 机会,也是它最脆弱的 claim—— ELF 的天花板可能就是 encoder 的天花板。如果有人用 LLaMA-3-8B hidden state + ELF flow 训一个 105M model,能不能逼近 LLaMA 自身 quality? 这个实验没人做过,是 obvious next step。

Q3: 为什么 final step 才做 CE?中间步加 CE 不更好吗?

Paper App C.1 ablate 了:x-prediction 在所有 embedding 维度都稳定,v-prediction 在高维 degrade,ε-prediction 全面崩溃。 背后假设(paper 引用 Li & He 2511.13720):clean 文本数据在 embedding 空间是低维流形。 中间步加 CE 等价于"把噪声大的 z_t 强行映射回 token",这等于把 quantization wall 偷偷搬回来—— 破坏了 ELF 的"only final step rounding"核心 framing。

Q4: 训练时 CFG(Eq 3)为什么不直接推理时跑两遍?

可以,但训练时把 SC-CFG 烤进 vtarget 后,推理只需要一次 forward,省一半算力。 代价是训练时3 次 forward(2 个 no-grad + 1 个 gradient-tracked)。 推理跑无数次,训练跑一次。划算。 这个 trick 是 image diffusion 圈 Chen et al. "Visual Generation without Guidance" (ICML 2025) 的方法,ELF 直接搬过来。 注意 ELF 仍然保留另一个推理时 CFG(input-cond CFG=2,用于 XSum/WMT),跟 SC-CFG 是两个独立机制。

Q5: 32 步 PPL 24.08,比 dataset 自身 PPL 怎么样?

论文用的就是 frozen GPT-2 Large 当 judge。Fig 7 里的 "Dataset" 虚线就是 OWT 真实文本在 GPT-2 Large 下的 reference PPL。 ELF-B 32 步 SDE 24.08 接近这条 reference;ELF-M 在 64 步 SDE / SC-CFG=3 下 21.69,更接近该 reference。 注意 "Dataset reference" 不是严格的理论下限—— 低于它也可能伴随低 entropy(即重复但流畅)。 而且这只是 GPT-2 Large judge 下的"流畅性"指标,不是通用质量下限: 换更强 judge(GPT-4 / Llama-3)数字会变。

Q6: bottleneck 为什么是 128,不是 256/64?

App C.2 sweep 了 {32, 128, 512}(不含 256/64,论文没测)。 128 是 ODE / SDE 双采样下的 frontier balance 点: 32-d 在 SDE 下 PPL 最低但 entropy < 5(多样性不够),512-d 把 PPL 推高了。 背后假设:clean text 在 embedding space 是低维流形,128-d 就足够"覆盖"这个流形。 对比 image diffusion 也常用类似 bottleneck(DiT、SD 都有)。256 这个中间值 paper 没测,按 frontier 趋势插值应该 fine。

Q7: Muon 为什么比 AdamW 强?尤其在 SDE 下?

App C.5。Muon 对 2D 参数用 Newton-Schulz orthogonalization 把梯度先正交化再 step, 这抑制 ill-conditioned 方向,相当于 implicit second-order。 SDE 采样需要更精确的 v 预测(噪声重新注入会放大 v 的误差), Muon 训出来的模型在 v 上更"光滑",所以 SDE 推理时优势更大。 Paper 只定性说 "more pronounced under SDE",没给具体数字。 工程上 Muon 还有 fp32 NS5 + Nesterov bias correction 等 patches(详 §4.10)。

Q8: ELF 怎么处理长序列?1024 已经是上限吗?

论文实验最长 seq=1024(OWT)和 1088(XSum)。当前 checkpoint 和 config 没有验证 8K 以上长度。 Architecture 上没有 causal LM 那种生成方向限制,但: (a) RoPE buffer 按 max_length 预构建(要扩 8K 需重建或加 RoPE scaling); (b) 全双向 self-attention 是 O(L²) 复杂度。 关于 "能 scale 到 8K context 吗"——architecture 上理论可行但 paper 没测; 可能需要 RoPE scaling + linear/sparse attention 改造,是 obvious extension。

10.B 关于 Cola-DLM / 两条路线对比

Q9: Cola-DLM 那种 VAE 路线不是更"正统"吗?为什么 ELF 看起来更干净?

VAE 路线的代价是 encoder 也得训,要解决 posterior collapse、latent drift、reconstruction trade-off 等额外问题。 ELF 用 frozen T5 把 encoder 部分外包给现成的预训练模型,所以 paper 短、ablation 干净、可以专心 ablate 7 个 ELF 自身的超参。 Cola 必须同时管理一大堆 VAE 超参(latent dim, block size, logSNR, KL ratio, anti-drift KL, multi-stage scheduling, ...)。 两边的代价不同:ELF 把复杂度外包给 Google 的 T5 训练;Cola 把复杂度内化到自己的训练 pipeline。

Q10: Cola 报 LAMBADA 50.80、MMLU 19.30,看起来很低?

对,这是 generative few-shot 协议下的数字(把多选题转成生成)。 按 Cola 自己的 model card 和同协议对照(我没单独验证 AR 同 scale 的具体数字), Cola 的 LAMBADA 接近部分 AR baseline,MMLU 明显偏弱。 但 Cola 论文强调的是 scaling 趋势:曲线还在涨,没饱和。 如果 reviewer 说 "Cola 数字一般"——对,它的卖点是 architecture 可行性 + scaling shape,不是当下数字。 这类对比不要直接列 "AR 47-55" 这种具体数字,没 cite 不严谨。

Q11: 为什么 ELF 没用 KV cache?

ELF 的 DiT 是全双向 attention,每步 forward 都要重新算所有位置的 attention。 KV cache 需要 causal mask(前面 K/V 缓存供后面 query 用),ELF 没这个结构。 这是 ELF 推理慢于 AR 的根本原因之一: 即使 ELF 32 步 << AR 1024 token,每步的 attention 复杂度还是 O(L²)。 Cola 用 block-causal 解决了这个——这是 Cola 的核心架构卖点。

Q12: ELF 和 Cola 谁更接近 production-ready?

都没有。但 ELF 更接近"clean scientific demonstration",Cola 更接近"engineering system"。 具体看:

11 · 引用 & 资源

资源链接
ELF paper PDFarXiv 2605.10938
ELF GitHub (官方 JAX)lillian039/ELF
ELF PyTorch portpytorch_elf branch
ELF HF checkpointsembedded-language-flows
Cola-DLM paperarXiv 2605.06548
Cola-DLM GitHubByteDance-Seed/Cola-DLM
Cola-DLM bloghongcanguo blog
Cola-DLM HFByteDance-Seed/Cola-DLM

背景文献(用来支撑 punchline 的 framing)


作者 Ruofeng Yang(杨若峰) (Shanghai Jiao Tong University, 2026-05)。文档由 ARIS (Auto Research in Sleep) 的生态 ARIS-in-AI-Offer 工作流产出,由 Claude Opus 4.7 整合 Codex GPT-5.5 xhigh + Gemini auto-gemini-3 跨模型讨论后撰写。本文是关于 ELF 等 continuous DLM 论文的第三方阅读笔记 / 综述,所有论文内容、图表、代码版权归各自原作者所有。 图片均截自 ELF arXiv PDF v1(2026-05-11)。代码片段来自官方 pytorch_elf 分支 @ b29d883。 中间审计 trace 保存在 .aris/traces/research-review/2026-05-26_run01/