3.9K Views
June 04, 25
スライド概要
DL輪読会資料
DEEP LEARNING JP [DL Papers] Theoretical Benefit Language Model and Limitation of Diffusion Gouki Minegishi, Matsuo Lab http://deeplearning.jp/ 1
書誌説明 • おそらくICML2025(合否分からず) • Masked Diffusion Modelsの理論的なBenefitとLimitation – MDMが流行っているので,ARとの違いや限界がないか知りたかった – MDMの論文を初めて読んで,分からないところ結構あった • Youtube link: https://www.youtube.com/watch?v=qt4Q9LTgIdY • FirstはB3!? 2
背景 • LLMはAutoRegregssive(AR)なモデルが一般的(e.g., LLaMA) – 推論の効率が悪く,生成に時間ががかる • Diffusion Language Modelの登場 – Masked Diffusion Model (MDM)が特に流行る – 複数トークンを同時に生成できて早し,一部タスクでARの性能を上回る https://deepmind.google/models/gemini-diffusion/ https://ml-gsai.github.io/LLaDA-demo/ 3
Research Question • MDMは,ARより速くかつ同等品質で文章を生成できるのか? • もし可能/不可能なら,その理論的条件と限界は何か? ざっくりまとめ 1. Token-levelの誤り率(PPL)ならMDMはARよりも早くて高品質 (具体的には系列長に依存しないオーダーのステップ数で高品質な生成可能) 2. Sequence-levelの誤り率は,MDMは無限ステップあれば高品質 3. ただステップ数が系列長に対して線形オーダーの場合,MDMのSequence-level の誤り率が大きくなる設定がある 4
MDM: Forward Process • MDMでは,データ(𝒙𝟎 )を全てマスクの系列(𝒙𝟏 )にする forward 1 forward • 𝑥𝑡𝑖 について確率𝛼𝑡 でデータ(𝑥0𝑖 ) 0 𝑡 • 確率1 − 𝛼𝑡 でデータ(𝑥1𝑖 = [𝑚]) • 𝛼0 = 1, 𝛼1 = 0 𝑥01 𝑥𝑡1 [𝑚] 𝑥02 𝑥𝑡2 [𝑚] 𝑥𝑡𝑖 … … 𝑥0𝑖 … … … 系列長𝐿 [𝑚] … … … 𝑥0𝐿 𝑥𝑡𝐿 [𝑚] 5
MDM: Reverse Process • 𝒙𝟎 , 𝒙𝒕 が与えられた時の𝑥𝑠𝑖 (𝑠 < 𝑡)の確率は→ – 𝑥𝑡𝑖 = 𝑥0𝑖 の時, 𝑞𝑠|𝑡 𝑥𝑠𝑖 = 𝑥01 𝒙𝒕 , 𝒙𝟎 = 1 1−𝛼 𝛼 −𝛼 𝑠 𝑡 – 𝑥𝑡𝑖 ≠ 𝑥0𝑖 つまり𝑥𝑡𝑖 = [𝑚]の時, 𝑞𝑠|𝑡 𝑥𝑠𝑖 = [𝑚] 𝒙𝒕 , 𝒙𝟎 = 1−𝛼𝑠 , 𝑞𝑠|𝑡 𝑥𝑠𝑖 = 𝑥0𝑖 𝒙𝒕 , 𝒙𝟎 = 1−𝛼 𝑡 • 上の式を𝒙𝟎 で周辺化すると以下 • この𝑞𝑠|𝑡 (𝑥𝑠𝑖 |𝒙𝒕 )を𝑝𝜃 でモデル化し, 推論時にはフルマスク(𝒙𝟏 )から 予測していく reverse 𝑡 0 𝑠 𝑡 1 𝑥01 𝑥𝑠1 𝑥𝑡1 [𝑚] 𝑥02 𝑥𝑠2 𝑥𝑡2 [𝑚] 𝑥𝑡𝑖 … … 𝑥𝑠𝑖 … … … … … 𝑥0𝑖 [𝑚] … … … … 𝑥0𝐿 𝑥𝑠𝐿 𝑥𝑡𝐿 [𝑚] 6
Task: Hidden Markov Model and n-gram • Hidden Markov Model (HMM) 𝐴𝑠1 ,𝑠2 隠れ状態 𝑠1 𝑠2 遷移確率 Unembedding行列 𝐴𝑠1 ,𝑠3 𝑠3 • 𝑠1 からスタートし, L回HMMを動かしてL個の隠れ状態の系列を得る • トークンの空間に変換し,𝒙を得てモデルはこれを観測できる • • n-gramはHMMの隠れ状態を直前のn-1トークンに固定したものとして考えられる 一般にHMMの方が隠れ状態があるので長期依存や階層構造を吸収できる HMMやn-gramを真の分布として扱う 7
Token Error Rate / Sequence Error Rate • Token Error Rate (TER) – |𝒙|は系列長のことでPPLと同じ • Sequence Error Rate (SER) – 個々のトークンだけでなく,系列の全体が正しいかを測る – 真の分布𝑞(𝒙)が許容する系列のみ計算する • TER (PPL)は,単語ごとのメトリクス. – 文全体の流暢さなどはこっち • SERは文章ごとのメトリクス. – Reasoningなど文章の全部の単語が合ってないとだめな場合に重要なメトリクス 8
定理4.2: TERについて • 仮定4.1 真の逆拡散仮定に対してKLの意味で十分に学習できたとする • 定理4.2 n-gramモデルに対してTERの上界は以下 → 系列長Lに依存しないサンプリングステップ数でTERが十分下がる (ARモデルの推論ステップはLに依存しているので嬉しい) サンプリングステップ数オーダー 9
定理4.3 : SERについてポジティブ理論 • 定理4.3 任意のHMMに対して仮定4.1がなりたつ時,SERの上界は以下 つまりSERも無限に小さくできる (サンプリングステップ数の制約がなければ,,,) 10
定理4.3: 証明 • Reverse ProcessのInstance (𝜏) Miはi番目のステップでサンプリングされるトークンidの集合 Nステップ分行うと reverse Miは以下を満たす. つまりモレなくダブりなく 0 𝑁 𝑠 𝑡 𝑥01 𝑥𝑠1 𝑥𝑡1 [𝑚] 𝑥02 𝑥𝑠2 𝑥𝑡2 [𝑚] 𝑥𝑡𝑖 … … 𝑥𝑠𝑖 … … … … … 𝑥0𝑖 [𝑚] … … … … 𝑥0𝐿 𝑥𝑠𝐿 𝑥𝑡𝐿 [𝑚] 11
定理4.3: 証明 • 補題C.1: ステップ数が大きいとで同時に複数トークンがサンプルされ る確率は小さくなる あるステップで複数のトークンがサンプリングされる確率を𝑝𝑚𝑢𝑙 とする サンプリングステップ数(N)を無限に大きくれば𝑝𝑚𝑢𝑙 は小さくなる • 補題C.2: MDMは,1つずつ生成すると性能が高い のような設定を考える. つまり,各ステップで1つサンプリングするか,何もしないか この時その文章が全部あっているかどうか𝑝𝑎𝑐𝑐 の下界は以下( ) 補題C.1,C.2をもってして十分にNが大きいとSERが小さくなる(定理4.3) 12
定理4.4: SERについてネガティブ理論 • 定理4.4 あるHMMを考えると,N=CL(Cは定数)の時,SERの下界は以下 つまり,ある真の生成分布でMDMは系列長Lに対して線形オーダーのサ ンプリングステップ数(N)では,どうやっても半分も文章が当たらない その真の生成分布とは...? 13
定理4.4: インターバル設定 𝑥01 𝑥02 𝑥03 𝑥04 … … • インターバル設定 長さ𝐿の系列を𝑀個等間隔に分割(長さ𝑙 = 𝐿/𝑀) この各𝑀個のインターバルにそれぞれ異なるHMMを 考え,インターバル間は独立とする またステップ𝑖でインターバル𝑗でサンプリングされた (j) トークンを𝑥 𝑖 とすると,インターバル間は 独立なので,以下の式が成り立つ 𝑥0𝑖 … … 𝑥0𝐿−1 𝑥0𝐿 14
定理4.4: 証明 • 補題C.5 インターバル𝑗の全てのトークンが異なるステップでサンプルされる確率をℎ𝑗 とする ステップ数(𝑁)が大きくなるとこの確率は大きくなる. • 補題C.6 インターバル内で同時サンプリングが起こって失敗する確率を𝑝𝑒 とすると SERの下界は以下 この上で語彙数=16, 𝑁 = 𝐶𝐿, 𝑙 = 5, の条件を考えると右辺が1/2になる 15
補題C.6: 証明+α • 補題C.5よりインターバル𝑗で同時生成が起きる確率 1 − ℎ𝑗 は • 𝑗 インターバル𝑗が失敗する確率(𝑝𝑒𝑟𝑟𝑜𝑟 )は失敗確率×同時生成確率 • インターバル間は独立のため系列の成功確率は • SERの下界は 16
補題C.6: 証明+α • 𝑝𝑒 /𝑁が十分に小さい時 • テイラー近似 • 例えば,SERを𝛿以下にしたかったら必要なステップ数は →この設定だとステップ数𝑵が系列長𝑳の線形オーダー必要 インターバルの数(𝑴)が多い設定も多くのステップ数が必要 17
実験1: トイ実験 • n-gram(n=2,3,4)とHMMで同じパラメータのMDMとARを実験 • PPLの場合は,ARより早い推論速度で同程度の性能 • SERはサンプリングステップを増やしてもARに及ばない 18
実験2: LLM • PPLの場合は,同程度のパラメータ数のGPT2-mediumより 早い推論速度で勝つ • GSM8Kでは同程度のパラメータのQwenに大敗 – SlimpajamaでpretrainされたMDMをGSM8KでSFT 19
まとめ・議論 • PPLではMDMがARよりも有利・SERは不利という結論 – PPLが必ずしもモデルの本当のPerformanceを反映していないという研究は LLMでいっぱいある • Beyond Perplexity: Examining Temporal Generalization in Large Language Models via Definition Generation • Can Perplexity Reflect Large Language Model's Ability in Long Text Understanding? • Emergent AbilityもPPLとPerformanceが整合しない例 • 今のMDMだと全てのタスクでARを超えることは理論的にできなそう という気にさせる論文 • あとインターバルの設定は実際の言語タスクだとどんな設定に該当す るかが分からなかった – もうちょっとリアルでかつMDMが全然解けないトイタスクは考えられそう – インターバル内は同時にサンプリングされずらいみたいな仕組みは理論的に保 証できそう 20