【DL輪読会】HarmonyDream: Task Harmonization Inside World Models

4K Views

August 08, 24

スライド概要

シェア

またはPlayer版

埋め込む »CMSなどでJSが使えない場合

関連スライド

各ページのテキスト
1.

DEEP LEARNING JP [DL Papers] HarmonyDream: Task Harmonization Inside World Models Presenter: Masahiro Suzuki, Matsuo Iwasawa Lab 2024/08/08 http://deeplearning.jp/ 1

2.

本発表情報  論文:  HarmonyDream: Task Harmonization Inside World Models(ICML2024)  著者:  Haoyu Ma, Jialong Wu, Ningya Feng, Chenjun Xiao, Dong Li, Jianye Hao, Jianmin Wang, Mingsheng Long  清華大学,Huawei,天津大学  既存の世界モデルのマルチタスク的側面と課題を明らかにし,それを解決するシンプルな手法 (HarmonyDream)を提案.  シンプルだが既存手法を大きく上回る結果. 2

3.

世界モデル  世界モデル(world model):  外界からの限られた観測を元に,世界の構造を近似するように学習するモデル.  観測から潜在表現を推論し,推論した表現から未来や未知のことを予測(生成)する.  将来が予測できることで,長期的な計画や意思決定をすることができる. 世界モデル 環境 近似 推論 観測 表現 予測  モデルベース強化学習の文脈では,世界モデルは外界と少ない相互作用でよい行動(方策)を獲得す るために重要.  1990年頃起源で,2018年頃から本格的な研究が進められている[Ha+ 18]. 3

4.

世界モデルの基本的な構造  基本的な世界モデルの構造  観測𝑜𝑡 から潜在状態表現𝑧𝑡 を推論する部分(エンコーダ)𝑞(𝑧𝑡 |𝑜𝑡 )と,現在の表現と行動から次の時刻 の表現を予測する部分(遷移モデル)  潜在状態表現が良い表現(観測の余分な情報を省いた低次元な表現など)になっていれば,より容易 に将来予測ができたり,良い方策を獲得することができる.  方策と世界モデルは交互に学習する. 行動𝑎𝑡 遷移モデル 𝑝(𝑧𝑡+1 |𝑧𝑡 , 𝑎𝑡 ) 状態表現𝑧𝑡 方策 𝜋(𝑎𝑡 |𝑧𝑡 ) エンコーダ 𝑞(𝑧𝑡 |𝑜𝑡 ) 観測𝑜𝑡 https://github.com/google-research/dreamer [Kaiser+ 20] 4

5.

世界モデルでできること  長期的な将来予測ができる:  ある時刻にある行動をとるとどうなるのかを想像できる[Hafner+ 21]  環境との少ない相互作用で方策を学習できる:  Minecraftにおけるダイアモンド収集タスクで,人間のデモンストレーションを使わずに達成 [Hafner+ 23]  ロボットの方策を世界モデルによる「夢」だけで学習することができる[Wu+ 22]  自動運転への応用[Hu+23] 5

6.

動画生成AIは世界モデルか?  Sora[Brooks+ 24]は,現実世界のシミュレータとして重要な要素が創発したと主張している.  3次元的な一貫性・長時間の一貫性・環境への相互作用・デジタル世界のシミュレーション https://openai.com/index/video-generationmodels-as-world-simulators/  一方で,LeCunはピクセル生成による世界のモデリングは「無駄が多く失敗する」と主張して いる.  It's much more desirable to generate *abstract representations* of those continuations that eliminate details in the scene that are irrelevant to any action we might want to take. https://twitter.com/ylecun/status/1759486703696318935 https://twitter.com/ylecun/status/1758740106955952191 6

7.

世界モデルの学習:観測モデリングvs報酬モデリング  世界モデル(エンコーダと遷移モデル)をどのように学習するか?  大きく2つの方向性がある.  観測モデリング(明示的MBRL):𝑝(𝑜𝑡+1|𝑜𝑡 , 𝑎𝑡 )  主に観測を予測するように世界モデルを学習する.  タスクに依存せずに,環境の遷移そのもののモデル化を目指す.  代表的な手法:Dreamer,Soraなど  報酬モデリング(暗黙的MBRL):𝑝(𝑟𝑡+1|𝑜𝑡 , 𝑎𝑡 )  報酬などタスクの目的を予測するように世界モデルを学習する.  タスクに依存したモデルを学習する(タスクに関係ない環境の遷移は学習しない)  代表的な手法:TD-MPCなど ※観測モデリングの中でも,Dreamerは(後述するように)報酬モデリングを含んでいるが,観測モデリングの割合が大きい. ※DreamerはSoraと異なり,明示的に状態表現を学習している. 7

8.

観測モデリングの例:Dreamer  Dreamer[Hafner+ 20]  世界モデルとして以下を用意  エンコーダ:𝑧𝑡 ∼ 𝑞𝜃 𝑧𝑡 ∣ 𝑧𝑡−1, 𝑎𝑡−1, 𝑜𝑡  遷移モデル:𝑧ƶ𝑡 ∼ 𝑝𝜃 𝑧ƶ𝑡 ∣ 𝑧𝑡−1 , 𝑎𝑡−1  観測モデル:𝑜ƶ 𝑡 ∼ 𝑝𝜃 𝑜ƶ 𝑡 ∣ 𝑧𝑡  報酬モデル:𝑟ƶ𝑡 ∼ 𝑝𝜃 𝑟ƶ𝑡 ∣ 𝑧𝑡  エンコーダや遷移モデルにはRNNを用いて過去の情報を含めている(recurrent state space model)  行動選択のために,価値関数𝑣(𝑧𝑡 )と方策𝑞 𝑎𝑡 𝑧𝑡 を用意  Variational autoencoderの枠組みで,以下の目的を最小化するように世界モデルを学習. 𝑇 −𝔼𝑞𝜃 𝑧1:𝑇 ∣𝑎1:𝑇 ,𝑜1:𝑇 ෍ ln 𝑝𝜃 𝑜𝑡 ∣ 𝑧𝑡 + ln 𝑝𝜃 𝑟𝑡 ∣ 𝑧𝑡 − 𝐷𝐾𝐿 𝑞𝜃 𝑧𝑡 ∣ 𝑧𝑡−1, 𝑎𝑡−1, 𝑜𝑡 ∥ 𝑝𝜃 𝑧ƶ𝑡 ∣ 𝑧𝑡−1 , 𝑎𝑡−1 𝑡=1 観測損失 報酬損失 遷移損失 8

9.

観測モデリングの例:Dreamer  行動選択:世界モデルからのサンプリング上でactor-criticで方策𝑞 𝑎𝑡 𝑧𝑡 と価値関数𝑣(𝑧𝑡 )を学習.  全て微分可能な世界モデル上の想像に基づいて計算されるので,方策の勾配が計算できる.  難しい視覚的制御において,高いサンプル効率や性能を発揮.  Dreamer v2[Hafner+ 21]  Dreamerの状態表現を離散表現にして,正則化の学習部分を工夫する(Priorの学習率を大きくする)こと で,Atariにおいてモデルフリーを大幅に上回る結果を出した.  Dreamer v3[Hafner+ 23]  モデルを大規模にし, symlog関数による予測対象の正規化などの工夫を入れることで, Minecraft のダイ ヤモンド収集タスクのような複雑なタスクを(人間のデモ等を使わずに)初めて解いた. 9

10.

報酬モデリングの例:TD-MPC  TD-MPC[Hafner+ 20]  世界モデルとして以下を用意  エンコーダ:𝑧𝑡 = ℎ𝜃 𝑜𝑡  遷移モデル:𝑧ƶ𝑡 = 𝑑𝜃 𝑧𝑡−1 , 𝑎𝑡−1  報酬モデル:𝑟ƶ𝑡 = 𝑅𝜃 𝑧𝑡 , 𝑎𝑡  Dreamerと違い観測モデルがない(task-oriented latent dynamics model)  行動選択のために,価値関数𝑄𝜃 (𝑧𝑡 , 𝑎𝑡 )と方策𝜋𝜃 (𝑧𝑡 )を用意  以下の目的を最小化するように学習 2 𝑐1 𝑅𝜃 z𝑖 , a𝑖 − 𝑟𝑖 2 + 𝑐2 報酬損失 𝑄𝜃 z𝑖 , a𝑖 − 𝑟𝑖 + 𝛾𝑄𝜃− z𝑖+1 , 𝜋𝜃 z𝑖+1 TD誤差 2 2 + 𝑐3 𝑑𝜃 z𝑖 , a𝑖 − ℎ𝜃− s𝑖+1 2 2 遷移モデルとエンコーダに よる状態の一貫性損失 10

11.

報酬モデリングの例:TD-MPC  行動選択:世界モデル上でのモデル予測制御(MPC)と方策𝜋𝜃 (𝑧𝑡 )の組み合わせ  MPC:ある時刻の行動選択の際に,有限期間までの予測に基づく軌道最適化をした上で求める.  入力からタスク埋め込みを推論し,世界モデルに条件づけることで,多様なタスクを単一モデ ルで学習することができる(TD-MPC2)[Hansen+ 23]  現状のモデルベース強化学習の中では最も性能の良い手法の一つ. 11

12.

世界モデルの学習:観測モデリングvs報酬モデリング  観測モデリングと報酬モデリングのどちらがいいのか?  TD-MPCの例から,報酬モデリングが良さそう?  本研究では,観測モデリング(Dreamer)をベースとして,モデリングの違いを分析する.  Dreamerはマルチタスク的な学習しているが,各タスクの重みづけでどのように性能が変わるのか? 𝑤𝑜 ℒ𝑜 𝜃 + 𝑤𝑟 ℒ𝑟 𝜃 + 𝑤𝑑 ℒ𝑑 𝜃 観測損失 報酬損失 遷移損失  上記の考察に基づき,シンプルな改善手法を提案する(HarmonyDream) 12

13.

世界モデルにおけるマルチタスクの分析  Finding 1:報酬損失の係数の調整によってサンプル効率に大きな影響がでる. 𝑤𝑜 ℒ𝑜 𝜃 + 𝑤𝑟 ℒ𝑟 𝜃 + 𝑤𝑑 ℒ𝑑 𝜃  Meta-worldの操作タスクで検証  タスクによって異なるが,適切に係数を大きく調整すればサンプル効率が向上する. 13

14.

世界モデルにおけるマルチタスクの分析  Finding 2:観測モデリングを大きくすると,報酬予測が誤ったことに気づかずに,偽の相互作 用を確立してしまう.  赤の報酬が誤った報酬予測.  観測モデリングの影響が大きいと,レバーの予測に失敗する(ハルシネーション) 14

15.

世界モデルにおけるマルチタスクの分析  Finding 3:観測なしで報酬のみで学習することは,サンプル効率の良いモデルベース学習を実 現する上では不十分.  Meta-worldの操作タスクの再掲  観測損失を0にすると,サンプル効率が悪化する(表現学習能力の劣化).  その他の報酬ベースモデリングではどうなのか?->TD-MPCと同じくらい(後述) 15

16.

HarmonyDream  ここまでの検証から,各損失のバランスを適切に取ることで,観測モデリングと報酬モデリング の利点を活かせそう.  観測モデリング:表現学習能力を促進する.  報酬モデリング:タスク依存の表現を高める.  単純なアプローチとして,各項のスケールで係数を正規化 𝑤𝑜 ℒ𝑜 𝜃 + 𝑤𝑟 ℒ𝑟 𝜃 + 𝑤𝑑 ℒ𝑑 𝜃 観測損失 報酬損失 遷移損失 1 𝑤𝑖 = sg ,𝑖 ∈ 𝑜 𝑟 𝑑 ℒ𝑖  実際にはミニバッチから計算するだけなので,異常値の影響を受けやすく,不安定になる可能性がある.  遷移損失も上記の方法でスケーリングした方がいいことがappendixの実験で示されている(通常は0.1 に設定) 16

17.

Harmonious Loss  変分法による以下の目的を導入する(harmonious loss). ℒ 𝜃, 𝜎𝑜 , 𝜎𝑟 , 𝜎𝑑 = ෍ ℋ ℒ𝑖 𝜃 , 𝜎𝑖 = 𝑖∈{𝑜,𝑟,𝑑} 1 ෍ ℒ𝑖 𝜃 + log 𝜎𝑖 𝜎𝑖 𝑖∈{𝑜,𝑟,𝑑}  動的かつスムーズに異なる損失をリスケールできる.  𝜎𝑖 > 0は学習可能パラメータで,期待損失𝔼[ℋ(ℒ, 𝜎)]の最適解𝜎 ∗ は以下を満たす. 𝜎∗ = 𝔼 ℒ ℒ 𝔼 ∗ =1 𝜎 17

18.

Rectified Harmonious Loss  最適化の際の符号の制約をなくすため,以下のように設定. 𝜎𝑖 = exp 𝑠𝑖 > 0  𝜎𝑖 が小さくなり係数が大きくなるのを防ぐため,正規化項に定数を足す(rectification) 1 ෍ ℒ𝑖 𝜃 + log(1 + 𝜎𝑖 ) 𝜎𝑖 𝑖∈{𝑜,𝑟,𝑑}  これによって,それぞれの項が大きくなりすぎるのを防げる. 𝔼 ℒΤ𝜎 ∗ = 2 1 + 1 + 4Τ𝔼 ℒ <1 18

19.

実験タスク  HarmonyDreamはDreamerV2をベース.  比較手法: DreamerV2(重み係数が1)  それ以外のMBRL(DreamerV3, DreamerProなど)とも比較 19

20.

実験結果: Meta-world & RLBench  Dreamer V2と比較すると,いずれのタスクもサンプル効率や性能が高い 20

21.

実験結果: DMC Remastered  タスク: Cheetah Run, Walker Run, Cartpole Balance  視覚の多様性のために,各試行でランダムに背景画像が生成されている[Grigsby+ 20]  Dreamer V2よりもサンプル効率や性能が高い 21

22.

実験結果: DMC Remastered  DreamerV2では,遷移損失が学習が進むことによって大きくなっていることを確認.  Dreamerでは観測モデリングが困難なため学習過程が阻害されている.  提案手法はタスク依存の遷移を学習しているので,より学習しやすくなっている. 22

23.

他のモデルベースRLでの検証  DreamerV3[Hafner+ 23]とDreamerPro[Deng+ 22]で検証  DreamerProでは報酬重みが1000に設定されている.  提案法自体はこれらと直交しているので,それらに提案法を導入した場合の性能と比較.  DreamerV3(左)  Harmony lossはDreamerV3でも有効.  DreamerPro(右)  既存の報酬重み(1000)よりも今回の自動調整の方が良い結果. 23

24.

他のモデルベースRLでの検証  Atari100KでのDreamerV3との比較  Hermony lossによって性能は大幅に改善.  26タスク中23タスクで上回っている. 24

25.

他のモデルベースRLでの検証  Minecraft(MineDojo[Fan+ 22]でのHunt Cowタスク)でのDreamerV3との比較  提案手法は1M interaction以内で概ね学習できている. 25

26.

暗黙的MBRLとの比較  報酬モデリング手法であるTD-MPC[Hansen+ 22]と比較.  HarmonyDreamの方が,TD-MPCよりも良い性能->観測モデリングによる表現学習の効果 26

27.

Dreamerベースのタスク依存の手法  Denoised MDP[Wang+22]  報酬依存でコントロール可能な表現を分離するように学習  RePo[Zhu+23]  情報ボトルネックに基づき遷移と報酬の予測を最大化するように学習しつつ,テスト時適応を行う.  実験結果(左:Denoised MDP,右:RePo)  通常の背景だとRePoと提案手法は同じくらいだが,背景がランダムになると提案手法が高い性能. 27

28.

マルチタスク学習手法での検証  マルチタスク学習では,タスクのバランスを取るための工夫が提案されている.  損失ベースと勾配ベースがあるが,ここでは損失ベースの既存手法と比較  Uncertainty Weighting[Kendall+18]:出力の不確実性を測定してバランスする.  ビクセルごとに見るので,次元サイズの差を見落とす可能性.  Dynamics Weight Average[Liu+19]:タスクの学習進捗に応じてバランスする.  リプレイバッファが増加して非定常なので,学習進捗を正確に測定できない可能性.  NashMTL[Navon+ 22]:最適化の方向を個別の勾配方向の射影でバランスする.  実装が複雑な上,実験的に最適化パラメータの慎重な調整が必要であることを確認.  提案手法が全体的に最も良い結果 28

29.

Ablation studies  遷移損失のスケーリングは効いているのか  手動の重み調整と比較して良い結果なのか 29

30.

Ablation studies  Rectification( log(1 + 𝜎𝑖 ) )は効いているのか. 30

31.

まとめ  観測モデリングと報酬モデリングをマルチタスクと捉えてバランスすることで,単純だが効果 的な世界モデル学習手法を提案.  観測モデリングが表現学習のために重要だが,観測のみに重みづけが大きくなっているタスクに有効.  細かいタスク関連の観測や,タスクに関係ない変化する観測の場合など.  その他世界モデルに関するICMLの論文  AD3: Implicit Action is the Key for World Models to Distinguish the Diverse Visual Distractors  Hieros: Hierarchical Imagination on Structured State Space Sequence World Models  Learning Latent Dynamic Robust Representations for World Models  Improving Token-Based World Models with Parallel Observation Prediction  Do Transformer World Models Give Better Policy Gradients? 31