1.8K Views
July 09, 24
スライド概要
DL輪読会資料
DEEP LEARNING JP [DL Papers] Diffusion Forcing: Next-token Prediction Meets FullSequence Diffusion Yuta Oshima, Matsuo Lab http://deeplearning.jp/
Diffusion Forcing: Next-token Prediction Meets Full-Sequence Diffusion 書誌情報 著者 • Under Review (NeurIPS2024) • Boyuan Chen, et al. 概要 • Diffusion Forcingという新しい系列拡散モデルの学習⽅法を提案 • Next token prediction の可変⻑⽣成の強みと,全シーケンス拡散モデルの持つ系列 全体を望ましい軌跡に誘導する能⼒の強みを融合 2
背景 • 確率的な系列モデリングは多様な場⾯で重要 • 特にNext token predictionでは,⾃⼰回帰により,単⼀から「無限」までの⻑さの 系列を⽣成可能 • しかし教師強制(Teacher Forcing)により学習されるNext token predictionには⼆つ の限界がある • ある特定の⽬的の最⼩化するために,系列全体をガイダンスする機構がない • 誤差が蓄積するため,⻑期の⽣成は実質的に困難 • 訓練時の系列⻑を超える場合は顕著 3
背景 • 全系列を⽣成する最近の拡散モデル(動画⽣成 [Ho et al. 2022] ,Diffuser [Janner et al. 2022] など)も存在 • ⼀定数のトークンの連結を⼀気に⽣成するため,系列 全体のガイダンスの⽋如や,誤差蓄積の問題を回避 • しかし,⾮因果的な⽣成になる上に,決められた系列 ⻑しか⽣成が不可能 • 素朴に全系列⽣成モデルのために,⾃⼰回帰モデルを 学習するナイーブは試みは,時間的に初期のトークン と後期のトークンの不確実性の違いをモデル化できな いため,うまくいかない 4
⽬的 • そこで,本研究ではDiffusion Forcing (DF)を提案 • 各トークンに対して異なるノイズスケジュールでノイズ除去する • トークンにノイズを与えることはマスキングである • DFは様々なノイズの乗ったトークンを「アンマスク」する • ⾃⼰回帰と全系列⽣成モデルの良いとこどりを⽬指す 5
⼿法 部分マスクとしてのノイズ付与 • ノイズ付与は,データの⼀部を隠すマスキングの⼀形態 • 通常の⾃⼰回帰では,時間軸に沿ってマスキング (𝑥!:#$% から 𝑥# を予測) ' • 𝐾ステップにノイズ付与し,ピュアノイズ 𝑥!:& とするの は,ノイズ軸に沿ったマスキング • ( 本研究ではこれらを統⼀的に扱い,(𝑥# ! )!)#)& を,各トーク ンが異なるノイズレベル𝑘# のノイズを持つことと表現する 6
⼿法 Diffusion Forcing Training • RNNによりマルコフ的に 𝑧# を推論する ( • 𝑧# ~ 𝑝* (𝑧# |𝑧#$% , 𝑥# ! , 𝑘# ) • 𝑘# = 0 の時は世界モデルでよくある事後分布, 𝑘# = 𝐾 の時は事前分布に対応 • ( ( また, 𝑥# ! に乗っているノイズ ε# を予測するモデル ε* (𝑧#$% , 𝑥# ! , 𝑘# ) を学習(これがLoss) ( • これは,期待対数尤度 ln(𝑝* ((𝑥# ! )!)#)& )) のELBOの重み付き最適化 • 𝑘%:& ~ [𝐾]& • この最適化により,全てのノイズレベルの系列について最適化が可能 7
⼿法 Diffusion Forcing Sampling • ⼆次元 𝑀×𝑇 のノイズマトリクス K ∈ [𝐾]!×# を⽣成 • ここで 𝑀 はデノイジングのステップ数 • 安定的な⻑期⾃⼰回帰 • データにノイズを乗せることで,BCにおいて⻑期的な エラーを軽減することは知られている(DART[Laskey+ 2017]など).DFも訓練時に過去のノイズ付き観 測から推論 • 他には,未来の不確実性の保存,⻑期の因果的ガイダンスが可能 8
実験 動画予測 • 300 frameで動画を学習し,1000 framesで推論 • フルシーケンスモデルでは動画の連続性が保てない,教師強制では⽣成が発散する • DFではうまく,訓練シーケンスより⻑い動画も⽣成可能 9
実験 プランニング • フルシーケンスプランニングモデルであ るDiffuserと⽐較 • Monte Carlo Tree Guidance (MCTG) を⽤いる ことで,DFは⾼い性能を発揮できる • MCTG:将来のトークン𝑥#+%:& の分布に基 づいて, 𝑥# の予測にガイダンスを加える 10
実験 ロボティクス • 初期位置を覚えていないといけないタスクなので,Diffusion Policy[Cheng+ 2023]のような 短期の 記憶しかないモデルでは失敗するが,DFでは𝑧# に記憶を保持しているので80%の成功率 • ロボットのタスクでも,ノイズや⽋損のある場合においてもうまく挙動できる(4%の性能低下に 留まる) 11
結論 • 限界としては,⼩規模なRNNで検証されているので,スケーリングについて検証 したい • DFという新たな⼿法を提案した,next token predictionとフルシーケンスモデル のいいとこ取りをした • 逐次決定タスクでの性能が劇的に向上 12
感想 興味深かった点 • ノイズを乗せることにより⻑期でも⽣成が壊れないというのは興味深かった • Diffuser, Diffusion Policyに勝てていることで,確かに良いとこどりができているよ うに感じた 考える余地のある点 • ガイダンスなどの⼿法が,⼤きなモデルにスケールするかは微妙 • ⽣成時間についての議論が⾒つからなかったが,だいぶ時間かかりそう • サンプル時の各ノイズレベルの決め⽅がかなり恣意的に感じるので,もう少し詰 めれそう 13