【DL輪読会】Train for the Worst, Plan for the Best: Understanding Token Ordering in Masked Diffusions

768 Views

July 24, 25

スライド概要

シェア

またはPlayer版

埋め込む »CMSなどでJSが使えない場合

ダウンロード

関連スライド

各ページのテキスト
1.

Train for the Worst, Plan for the Best: Understanding Token Ordering in Masked Diffusions Daiki Miyake, Matsuo Lab 1

2.

書誌情報 • タイトル Train for the Worst, Plan for the Best: Understanding Token Ordering in Masked Diffusions • 著者 Jaeyeon Kim, Kulin Shah, Vasilis Kontonis, Sham M. Kakade, Sitan Chen (ハーバード、テキサス大学) • 書誌 ICML2025 Outstanding Paper ICMLサイト: https://icml.cc/virtual/2025/poster/45990 OpenReview: https://openreview.net/forum?id=DjJmre5IkP arXiv: https://arxiv.org/abs/2502.06768 2

3.

概要 • Masked diffusion models (MDM) と Autoregressive models (ARM) を 比較 1. 学習時には、MDMはARMよりも膨大なマスクパターンを学習する 2. 推論時には、ARMは前から順にしか生成できないが、MDMは好きな順で生成 できる • 貢献 1. 学習困難なマスクパターンが存在することを定量的に評価 2. Top-k サンプリングなどでトークンの生成順を決めることで、数独などの タスクで性能改善 3

4.

選好理由 • MDMの推論方法について興味を持った – MDMでARMと同等の性能を達成するためには、系列長に比例した推論ステッ プが必要 [Nie+ 25, Feng+ 25] – もしこれが真なら、MDMの優位性はない • 画像や動画の連続時間Diffusion Modelでは、推論方法を工夫することでステップ 数を少なくすることが出来る (DDIM, DPM-Solver) • MDMにおいて、推論方法を工夫する余地はあるのか? [Nie+ 25] Shen Nie, et al. Large language diffusion models. ICML2025. [Feng+ 25] Guhao Feng, et al. Theoretical benefit and limitation of diffusion language model. arXiv:2502.09622. 4

5.

Masked Diffusion Models (MDM): 定式化 • 変数 : 語彙サイズ : 系列長 : マスクトークン 0 • 拡散過程 カテゴリカル分布 各トークンは 確率 でそのまま、 確率 でマスクトークン に置き換わる トークンだけ1の one-hotベクトル 非マスクトークン はそのまま • 逆拡散過程 0 0 マスクトークンはそのままか、 に置き換わる 5

6.

Masked Diffusion Models (MDM): 学習 • から を予測するニューラルネットワーク を計算する を用いて、逆拡散過程 • ELBO最大化による学習 0 Cross Entropy 0 6

7.

Masked Diffusion Models (MDM): 推論 • 全てマスクトークンの状態からスタート の遷移時には、マスクトークンのうち • の割合のトークンを ランダムに選び、元トークンを予測する • これを になるまで繰り返す 0 の割合の マスクトークンを選ぶ 0 0 0 0 0 0 0 7

8.

MDMの問題点 • 目的関数は以下のように書き換えることができる (Proposition 2.1) マスクされたトークンの インデックスの集合 {1,2,…,L} の部分集合の集合族 • MDMは、あり得るすべてのマスクパターン( 復元するように学習する • 一方、ARMのマスクパターンは 通り)に対して、元トークンを 通り • MDMがより膨大な空間を学習するメリットはあるのか? (→そもそも学習困難なマスクパターンが存在する) 8

9.

Latent tokens / Observation tokens • 分析のため、単純なトークン列の設定を考える (Definition 3.1) • Latentトークン: 事前分布に従う、独立したトークン • Observationトークン: Latentトークンの関数 で決まるトークン • トークン列のうち、最初の トークンがLatent、残りの Observationとする (必要であれば並び替える) トークンが Latentトークン • (日本, 東京, 富士山, True, True) Observationトークン : 国と首都の組み合わせがあってるか? : 国と最高峰の組み合わせがあってるか? • (日本, パリ, 富士山, False, True) 9

10.

学習困難なパターン • ARMでは学習可能 – 最初の トークン(Latent)はトークン間で独立な事前分布を学習できればOK – 後ろの トークン(Observation)は関数 を学習できればOK • MDMでは学習困難なマスクパターンが存在する – Latentトークンのみがマスクされていて、 がハッシュ関数である場合など 10

11.

学習困難なパターン: 理論的に証明可能な例 • Latentトークンのうち、 個のトークンに関する条件式 トークンとする (Example 3.2) • ある定数 が存在して、マスク率 の真偽をObservation が : の引数をランダムに 決めた時に真となる確率 を満たすとき、マスクされたLatentトークンは情報理論的には復元可能だが、 多項式時間アルゴリズムで復元することはできない (Proposition 3.3) 11

12.

学習困難なパターン: 実験的に証明可能な例 • MDMの各トークンをマスクする処理を、「トークンを並び替えて後半をマスクす る」処理に置き換える →マスクパターンを並び替え のパターンに置き換える 並び替えの問題に置き換える 従来のMDMs 0 0 0 0 • 「学習が難しいマスクパターン」の代わりに、「学習が簡単な並び替え の パターン」を考える →自然言語の場合、ARMのように文章を前から順に予測する( が恒等写像) のが簡単と考えられる 12

13.

学習困難なパターン: 実験的に証明可能な例 • 並び替え を一様にランダムに 選ぶよりも、恒等写像に近くなる ように選ぶ方(-closer)が、 尤度の観点で良い →自然言語の場合、前から順に マスクされるパターンが簡単 13

14.

提案手法: 適応的サンプリング • MDMの学習においては、学習が困難なマスクパターンが存在する →サンプリングの際にはこうしたパターンを避けたい • 推論時のアンマスクするトークンの選び方を工夫する (従来はランダム) 0 の割合の マスクトークンを選ぶ 0 0 0 0 0 0 0 14

15.

提案手法: 適応的サンプリング • Top-K サンプリング 予測語彙の確率が最大となるトークンをアンマスクする • Top-K margin サンプリング 予測語彙の確率差が最大となるトークンをアンマスクする Vanilla Top-K Top-K margin 0 0 ○ ◎ ○ ◎ 15

16.

実験結果: 数独 • Top-Kサンプリングによってvanillaよりも性能改善 Top-K marginではさらに改善 • ARMで明示的に解く順番の情報を与えた場合よりも より小さなモデルでより良い性能を達成 →解く順番をモデル自身が決められるメリット https://ja.wikipedia.org/wiki/% E6%95%B0%E7%8B%AC 16

17.

実験結果: ゼブラパズル • Top-K, Top-K marginによってvanillaよりも性能改善 • モデルサイズの大きなARMよりも良い性能 https://ja.wikipedia.org/wiki/%E3%82%B C%E3%83%96%E3%83%A9%E3%83%B B%E3%83%91%E3%82%BA%E3%83%A B 17

18.

実験結果: テキスト • 学習済みのLLaDA 8Bで実験 • Top-K, Top-K marginによってvanillaよりも性能改善 • 特にコードや数学のように、「途中過程でのミスが即誤答に繋がるタスク」では Top-K marginが有効 Pythonコード 数学 知識理解 因果関係 18

19.

議論: MDM×Top-K sampling • 最近のMDMでは、Top-K samplingが既に使われつつある [Nie+ 25, Zheng+ 24] – ベースラインとして用いたLLaDAについても、元論文ではTop-K samplingを採用している • 今回の論文では、学習時のマスクパターンの難しさからTop-K samplingを使う べきと理由付けをした • LLaDAの論文ではARMとの比較が 行われているが、Qwen系には 性能が劣る結果 [Nie+ 25] Shen Nie, et al. Large language diffusion models. ICML2025. [Zheng+ 24] Lin Zheng, et al. A reparameterized discrete diffusion model for text generation. ICLR2024. 19

20.

まとめ • MDMにおいて、データによっては学習困難なマスクパターンが存在することを 理論的・実験的に示した • 推論時にそうしたマスクパターンを避けるために、適応的なサンプリング手法を 提案 • 数独、ゼブラパズル、テキスト生成において、従来のサンプリング手法よりも 性能向上 20