5.4K Views
April 16, 24
スライド概要
DL輪読会資料
DEEP LEARNING JP [DL Papers] “DENOISING DIFFUSION IMPLICIT MODELS” Gouki Minegishi, Matsuo Lab http://deeplearning.jp/ 1
書誌情報 Diffusionのサンプリングを高速化する技術 いろんなところで応用されているけど,難しそうなのでわからないので読みます 2
背景 • DDPMすごいけど遅い – e.g., 50kの画像⽣成に1000hoursかかる • GANとかの暗黙的⽣成モデルとのギャップを埋めるためにDDIMを提 案する • DDIMは – DDPMと同じ⽬的関数で学習できる – DDPMを⼀般化した⾮マルコフ連鎖の確率過程 – ⾮マルコフ連鎖の拡散過程としてマルコフ連鎖の拡散過程を短くできる – ODEとの関係(あんまわかんなかった) – EmpiricalにもDDIMは有効 3
DDPM • Forward Process • Reverse Process • Objective function(ELBO) 4
Non-Markovian Forward Process • Key Observation of DDPM Objectives – DDPMの⽬的関数𝐿! はq(𝑥" |𝑥# )に依存していて, q(𝑥$:& |𝑥# )の同時分布は関係ない • Non-Markovian Forward Process – 以下のような同時確率分布q ' (𝑥$:& |𝑥# )を導⼊(𝜎 ∈ ℝ&≧# ) – こうするとq(𝑥" |𝑥# )は変わらない – 拡散過程はbayseʼs ruleより以下のようになり,⾮マルコフ連鎖である – 𝜎はこの拡散過程がどれほどstochasticかをコントロールする • Extrame caseで𝜎 = 0を考えると𝑥!, 𝑥" がgivenの時, 𝑥"#$ は決定論的 5
Generative Process and Unified Objectivbe • Generative Process( 𝑝)" (𝑥"*$ |𝑥" ) ) – Niosyな観測𝑥" , 観測𝑥# がgivenの時に𝑥"*$ は, – 𝑥!は以下の𝑓% で予測できる – よって⽣成過程は以下となる • Unified Variational Inference Objective – 導⼊した同時確率分布q& 𝑥$:( 𝑥! でDDPMの⽬的関数を書いてみる – これをごちゃごちゃ式変形すると𝐿) +定数になる – つまりDDIMの最適化はDDPMと同じ⽬的関数で最適化できる • サンプリング時にあたかも別の推論(拡散)モデルを使った場合の逆拡散過程をたどるようにサンプリングする 6
Sampling from Generalized Generalization Process • DDIM – 𝑞! の逆拡散過程をたどるようにしてサンプリング – 特に の時,DDPMと同じ • つまりDDPMを特殊ケースとして⼀般化している – 𝜎 = 0の時,決定的な過程になる→暗黙的⽣成モデル(⽣成モデルをなんかの 確率モデルに仮定しない) 𝜎 = 0をDDIMと呼ぶ 7
Accelerated Generalization Process • 𝑞(𝑥" |𝑥# )が変わらなければ,objectiveは変わらないので,導⼊した𝑞' (𝑥" |𝑥# ) でより少ないTの forward processがあるかもしれない – 𝜏を[1, … 𝑇]の⻑さ𝑆で昇順のサブシーケンスを考える • 𝜏は補集合 ̅ – それぞれの要素は以下 – サブシーケンスのみの逆拡散過程は以下 – Objectiveを書き換えると以下になる(𝐿! と同様) 8
Implementation 実装も簡単でサンプリングの関数を以下にするだけ. def generalized_steps(x, seq, model, b, **kwargs): with torch.no_grad(): n = x.size(0) seq_next = [-1] + list(seq[:-1]) x0_preds = [] xs = [x] for i, j in zip(reversed(seq), reversed(seq_next)): t = (torch.ones(n) * i).to(x.device) next_t = (torch.ones(n) * j).to(x.device) at = compute_alpha(b, t.long()) at_next = compute_alpha(b, next_t.long()) xt = xs[-1].to('cuda') et = model(xt, t) x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt() x0_preds.append(x0_t.to('cpu')) c1 = ( kwargs.get("eta", 0) * ((1 - at / at_next) * (1 - at_next) / (1 - at)).sqrt() ) c2 = ((1 - at_next) - c1 ** 2).sqrt() xt_next = at_next.sqrt() * x0_t + c1 * torch.randn_like(x) + c2 * et xs.append(xt_next.to('cpu')) return xs, x0_preds 9
Experiments • いろんな𝜎で試す – の時,DDPMと同じ – Randomnessの⾼いDDPMも⽤意 – 𝜎 = 0の時,DDIMと呼ぶ(決定論的なので) • Results – 𝜎 = 0が⼀番いい – Tが⼤きい時は流⽯に勝てない 10
ODEとの関係 • DDIMの逆拡散過程を𝛿𝑡で表すと • 𝛿𝑡を⼩さくすると(連続にすると)以下のODEとして表せる – 𝜎 = 1 − 𝛼/ 𝛼, 𝑥̅ = 𝑥, / 𝛼とすると • ⼗分にstepsが⼤きかったら𝑥! から𝑥" もこのODEでシミュレートできる – これはDDIMがDDPMと異なり𝑥- として観測をエンコードしているとみなせる – ダウンストリームタスクとかに有効かもしれないらしい • そしてこれがVariance-Exploding SDEと対応しているらしい – AppendixB 11
なんで早くサンプリンできるの • あんまよくわかんなかった • gDDIM(ICLR2023) – DDIMがサンプル効率がいいのはempiricalに⽰されているが,よくわかんない – 1)why does solving probability flow ODE provide much higher sample quality than solving SDEs, when the number of steps is small? – 2) It is shown that stochastic DDIM reduces to marginal-equivalent SDE (Zhang & Chen, 2022), but its discretization scheme and mechanism of acceleration are still unclear. – 3) Can we generalize DDIMs to other DMs and achieve similar or even better acceleration results? – ざっくり • データ分布がデータが1点しかないディラック・デルタ分布の場合,サンプリング経路上で, 真のスコア関数に対応するデノイジング関数は定数であると証明.つまり⼀回のサンプリンで 復元できる • そしてこの時の経路はDDIMの復元経路と⼀致することが⽰されている • しかし,ディラック・デルタ分布以外ではわかってない 12