2.7K Views
August 08, 24
スライド概要
DL輪読会資料
Scalable Wasserstein Gradient Flow for Generative Modeling through Unbalanced Optimal Transport Daiki Miyake, Matsuo Lab (M1) 1
書誌情報 • タイトル Scalable Wasserstein Gradient Flow for Generative Modeling through Unbalanced Optimal Transport https://icml.cc/virtual/2024/poster/33558 • 概要 Wasserstein Gradient Flowの学習にかかる計算量を減らし,CelebA-HQの ような高解像度画像(256x256)での学習を可能にした 2
背景: Wasserstein距離 • 最適輸送問題 (Optimal Transport; OT) ある確率分布 𝑝𝑝 𝑥𝑥 を別の確率分布 𝑞𝑞 𝑦𝑦 に移す総コストを求める問題 • 特にコスト関数 𝐶𝐶 𝑥𝑥, 𝑦𝑦 が 𝐿𝐿 ノルム 𝑥𝑥 − 𝑦𝑦 p-Wasserstein距離と呼ぶ 𝑝𝑝 𝑝𝑝 𝑝𝑝 の時,総コストを 3
背景: Wasserstein Gradient Flow • ある確率分布 𝜌𝜌0 = 𝜇𝜇 を始点として,𝑡𝑡 = 1 で 𝜌𝜌1 = 𝜈𝜈 となる軌道 𝜌𝜌𝑡𝑡 𝑡𝑡∈ 0,1 を考える – 例えば,𝜇𝜇 が標準正規分布で,𝜈𝜈 がデータ分布 • この軌道はWasserstein Gradient Flowとして偏微分方程式で表せる ℱ 𝜌𝜌 は 𝜌𝜌, 𝜈𝜈 間の 𝑓𝑓-divergence (𝜌𝜌 が 𝜈𝜈 にどれだけ近いかを表す,例えばKL-divergenceなど) 4
背景: JKO scheme • Wasserstein Gradient Flowを離散的に計算する方法として,JKO scheme[Jordan et al., 1998] がある • 𝜇𝜇0 = 𝜇𝜇 (ソース分布)として,以下の式で順に 𝜇𝜇𝑘𝑘 を更新 (Implicit Euler methodに対応する) ℎ はステップサイズ • 適当な回数 𝐾𝐾 だけ繰り返すことで,𝜇𝜇𝐾𝐾 が 𝜈𝜈 に漸近する 5
背景: JKO scheme • Mongeの定式化によって,ある凸関数 𝜓𝜓 を用いた最適化問題として 表せる [Fan et al., 2022] 6
背景: JKOの問題点 • 𝑥𝑥𝑘𝑘 ∼ 𝜇𝜇𝑘𝑘 のサンプリングに 𝑂𝑂 𝑘𝑘 の計算量がかかる – 𝑥𝑥0 ∼ 𝜇𝜇0 = 𝜇𝜇 として,𝑥𝑥𝑘𝑘 = 𝑇𝑇𝑘𝑘−1 ∘ 𝑇𝑇𝑘𝑘−2 ∘ ⋯ ∘ 𝑇𝑇0 𝑥𝑥0 • 𝑘𝑘 = 0 から 𝑘𝑘 = 𝐾𝐾 まで順に最適化していくため,学習全体としては 𝑂𝑂 𝐾𝐾 2 の計算量がかかる 7
背景: Unbalanced Optimal Transport • Unbalanced Optimal Transport (UOT) – 𝐷𝐷𝜑𝜑 は 𝑓𝑓-divergence,𝜋𝜋0 (𝑥𝑥), 𝜋𝜋1 (𝑦𝑦) は 𝜋𝜋(𝑥𝑥, 𝑦𝑦) をそれぞれ周辺化したもの • OTの制約条件を緩和したもの 8
背景: Unbalanced Optimal Transport • UOTの双対問題をMongeの定式化とc変換によって書き換える • これは 𝑇𝑇 と 𝑣𝑣 のmin-max問題として考えられる • GANのように,2つのニューラルネット 𝑇𝑇𝜃𝜃 , 𝑣𝑣𝜙𝜙 を交互に最適化する [Choi et al., 2023] 9
提案手法: JKOとUOTの統合 • JKOの各最適化ステップはUOT JKO UOT • JKOの各ステップはGANのように最適化できる 10
提案手法: 計算量の低減 • 従来は 𝑇𝑇𝑘𝑘 を 𝑇𝑇𝑘𝑘 # 𝜇𝜇𝑘𝑘 = 𝜇𝜇𝑘𝑘+1 となるように定義していたところを, 𝑇𝑇𝑘𝑘 # 𝜇𝜇0 = 𝜇𝜇𝑘𝑘+1 となるように定義 – つまり,𝑥𝑥𝑘𝑘 = 𝑇𝑇𝑘𝑘−1 ∘ 𝑇𝑇𝑘𝑘−2 ∘ ⋯ ∘ 𝑇𝑇0 𝑥𝑥0 ではなく,𝑥𝑥𝑘𝑘 = 𝑇𝑇𝑘𝑘−1 𝑥𝑥0 とした 従来手法 提案手法 11
提案手法: 目的関数 • 目的関数 • 𝑇𝑇𝑘𝑘 のパラメータは 𝑇𝑇𝑘𝑘−1 のパラメータで初期化する • 𝑓𝑓 ∘ は 𝐹𝐹 𝜌𝜌 に対応した関数 12
実験結果: 定性評価 • CelebA-HQでの生成結果 13
実験結果: 定量評価 14
実験結果: 学習時間の比較 • 従来手法(上段3つ)よりも 短い学習時間で高品質 • UOTM(提案手法の 𝐾𝐾 = 1 に 対応)よりも高品質 15
まとめ • まとめ – JKOの各ステップをUOTとして最適化できることを示した – ソース分布からのOT写像を定義することで,計算量を 𝑂𝑂 𝐾𝐾 2 から 𝑂𝑂 𝐾𝐾 に減 らした – 既存の画像生成モデルとcomparableな結果を示した • 感想 – Diffusion Modelとの対応 • 𝑓𝑓-divergenceがKLの場合,Wasserstein Gradient FlowはFokker-Plank方程式に対応する らしい – 本当は 𝑇𝑇𝜃𝜃 , 𝑣𝑣𝜙𝜙 に対して条件があるはずだが,無視して良いのか? • 𝑇𝑇𝜃𝜃 は,ある凸関数 𝜓𝜓 に対して 𝑇𝑇 = ∇𝜓𝜓 が成り立つ • 𝑣𝑣𝜙𝜙 は,1-リプシッツ連続 (Appendixを見るとGradient Penaltyを使っている?) 16
Reference • Richard Jordan, David Kinderlehrer, and Felix Otto. The variational formulation of the fokkerplanck equation. SIAM journal on mathematical analysis, 1998, 29(1),1-17. • Jiaojiao Fan, Qinsheng Zhang, Amirhossein Taghvaei, and Yongxin Chen. Variational Wasserstein gradient flow. arXiv preprint arXiv:2112.02424, 2021. • Jaemoo Choi, Jaewoong Choi, and Myungjoo Kang. Generative modeling through the semi-dual formulation of unbalanced optimal transport. Thirty-seventh Conference on Neural Information Processing Systems, 2023. 17