[DLHacks]PyTorch, PixyzによるGenerative Query Networkの実装

296 Views

December 20, 18

スライド概要

2018/12/13
Deep Learning JP:
http://deeplearning.jp/hacks/

シェア

またはPlayer版

埋め込む »CMSなどでJSが使えない場合

(ダウンロード不可)

関連スライド

各ページのテキスト
1.

PyTorch, Pixyz Generative Query Network 1

2.

• • • • G G – – – • 4 4 N 4 B 4Q 2

3.

. – – – – – R – – – – – ( A VQ D W ) N MP LS C 12 GT ) 3

4.

. – – – – – R – – – – – ( A VQ D W ) N MP LS C 12 GT ) 4

5.

E A C A DA D [DL輪読会]GQNと関連研究,世界モデルとの関係について https://www.slideshare.net/DeepLearningJP2016/dlgqn-111725780 5

6.

Variational Autoencoder ! ! " # E%('|)) " #. " . log / .# # Inference /(.|#) = E%('|)) log " # . . Generation "(#|.) − D34 [/ . # ||" . ] #7 6

7.

Variational Autoencoder !o td n o o y a e – o – ( • f: 9B: s 243i n V !T b df e o e QN b ! dcpe d xT o h B f u "d bd A v Qd e o o e o b – /G B 1A9B: A 6 BA 0 E E. – 5G B BA 6 BA /G B A9B: E E. A B 8E B ( 8E bT A9 T o V l g e o ic ea A9B: Qd o d ea bor b f e ) - , 7

8.

Conditional VAE !(#) • % e , • 0 E+(,|-,/) &('|%) T 1 S I M S A g e & )|% E ! 3 #, 4 ! #|4 log 5 # 3, 4 3, 4 V N. = E+(,|-,/) log ! 3 #, 4 Inference 5(#|3, 4) #, 4 A − D9:[5 # 3, 4 ||! #|4 ] Generation !(3|#, 4) 3= KL Prior 5(#|4) # 8

9.

Conditional VAE Q – G Q ! : e N ( G W )(+, |./..0, +/..0, ., ) • A A – "#$..& "# – D G e '#$..& e A : R ( '# D MN A ( ) 9

10.

DRAW • • R • D !E n W e Q ) " ! # = % "& (!& |#, !-& ) &'( • i AV D D ng ng D D D G A ! AE A G N 10

11.

DRAW log $ % ≥ E( ) % log $ % ) − D,- [/ ) % | $ ) 8 8 = E( ) % log $(%|)) − D,- [4 /5 )5 %, ):5 || 4 $5 )5 ):5 ] 8 = E( ) % log $ % ) 567 567 − < E( ):5 % [D,- /5 )5 %, ):5 $5 )5 ):5 567 8 ≅ E( ) % log $ % ) − < D,- [/5 ()5 |%, ):5 )||$5 ()5 |):5 )] 567 • 11

12.

DRAW for $ = 1 to ( Prior Distribution 01 21 231 = 45 21 ℎ7 Encoder RNN ℎ> = ?@@ >AB C, CE1 , ℎ> , ℎ7 Posterior Sample 21 ~ L1 21 C, 231 = 4M 2 ℎ> Decoder RNN ℎ7 = ?@@ 7>B 21 , ℎ7 KL Divergence DRS [L1 21 C, 231 ||01 21 |231 ] Canvas CE1 = CE1 + 4 Y ℎ7 Likelihood \(C|CE^ ) • • N 2 A 2 D R 12

13.

GQN DCG • E" # $ " , & '..) , $ '..) , & " log - $ " #, & '..) , $ '..) , & " 2 − / D45 [70 #0 $ " , & '..) , $ '..) , & " ||90 (#0 |& '..) , $ '..) , & " )] 01' • @ @ = = ∑) A($ ,& ) @1' G E" # $ " , =, & " log - $ " #, =, & " 2 − / D45 [70 #0 $ " , =, & " ||90 (#0 |=, & " )] 01' A 13

14.

GQN for $ = 1 to ( Prior Distribution 01 21 231 = 45 21 ℎ7 Encoder RNN ℎ> = ?@@ >AB C D , F D , G, H1 , ℎ> , ℎ7 Posterior Sample 21 ~ O1 21 C D , F D , G, 231 = 4D 2 ℎ> Decoder RNN ℎ7 = ?@@ 7>B F D , G, 21 , ℎ7 KL Divergence DTU [O1 21 C, 231 ||01 21 |231 ] Canvas H1 = H1 + Δ ℎ7 Likelihood ^(C D |4 ` (Ha )) • • FD, G AW D R A () ) D C A N 14

15.

• N 2 n N ! r –( ,10 )21, 20 1 l !" = $(& " , ( " ) - ! = * !" "+, – – ,10 P ,10 . ((N N – – G g u x o R R s 8 tPe N Pa o t ,10 Nu c 8 15

16.

( ) GQN G N – L 2 l N c2 – ( M G • c c c – ( u u A 2 s GL 2 a G c L e2 u M a c2 r L G ) • e2 M o e2 PL M M ( L LQ c L G M M 16

17.

. – – – – – R – – – – – ( A VQ D W ) N MP LS C 12 GT ) 17

18.

GQN . – n – v – – / T a P a . . / – . a – – :gfhk wu – ro cmib t . - e:l / / sp / – – qy P 18

19.

Pixyz • P • sw an vr Klb T e F P K )() ( L T E P g i E P to an d P v EL c D 19

20.

• • z P P E L E P co l a s i T s hc P E • Pr , m L ( a yE PAA I m T x I 20

21.

• r zu Ras N – ! 10 26 4D D D – -,, nb ", ℎ 10 26 4D D – - . 20 0 Wa % 10 26 4D – • -,, ki l –R 4 gehdwc 1 0 vW Wa_ o 84 4 -,, Ra x Ra 4 D a t x 260 4 34 D W D D p D 3 0: 4 A d m 0: 4 gehN-,, nb W 033 x d 21

22.

dataset/convert2torch.py • • • • .. :: . - 1 /1: - D -.. T .: . S . b P T 1 - / - : .: c a MDF c a 22

23.

gqn_dataset.py – _ 1 a 23

24.

representation.py 論文で提案されている3種類を実装 Poolが一番いいらしい 24

25.

conv_lstm.py 25

26.

Core , 入力のサイズを合わせてConvLSTMに入れる 26

27.

Distribution (Pixyz) ネットワークを確率分布で隠蔽するクラス ここでは正規分布なので、平均 (loc) と標準 偏差 (scale) を辞書形式で返す 初期化の際に、条件づける変数 (cond_var) と 変数 (var) をリストで渡す。 forwardの入力とcond_varを揃えることに注意 27

28.

model.py PyTorch Pixyz *_coreが自己回帰の部分を担うConvolutional LSTM Pixyzではeta_* の代わりにPriorなどのpixyz.distributionsクラスのインスタンスを立てる 28

29.

model.py PyTorch Pixyz Pixyzではネットワークを 確率モデルで隠蔽している ため、q.sampleなどとする だけで分布からのサンプリ ングが可能で、コードが読 みやすくなる! 29

30.

train.py 負の変分下限をロス関数として学習する。 学習率と生成画像の分散のアニーリングはここで行 う。 TensorBoardXを用いてログを保存 30

31.

train.py • loh ( G d R ruG d N G : y R ip G s 2: p c _ 12 G cf :21. 8:2 p fGnt _ P 2 T )) P FU a R G F 12 G eT 2 : 2 _ R e G cf 31

32.

2 1 – – – – [ , r a 0 _ T t _P > r _ lho zx 9 0 . TP a _ 0 r _ _ < -1T] -1 - , _ 0 , – – cm r t – < 1 g y T s [> T P T _ _ T_ TP , 0 >9, Ti 0 a ]T > > . 1 _ >9, ]ep npd _ i g P , a 32

33.

(Shepard-Metzler) Ground Truth Prediction 2週間ほど回し続けた結果(71.5万ステップ) ほぼ見分けがつかないレベルで生成できるようになった 論文では200万ステップ回している(こちらのリソースでは1ヶ月くらいかかる…) 33

34.

(Mazes) Ground Truth Prediction 学習時間約2日 34

35.

• 2B – – ( ) N L1 L 1 ( ! B) • 8 – • 0 B N NK2B 1 A1 i 4 N B! BG X0T 4N – 0 L 失敗するケース zのチャンネル数64 8A A 自己回帰8回で1日回した結果 35