【DL輪読会】Multi Time Scale World Models

3.4K Views

October 24, 24

スライド概要

シェア

またはPlayer版

埋め込む »CMSなどでJSが使えない場合

関連スライド

各ページのテキスト
1.

Multi Time Scale World Models Tomoshi Iiyama, Matsuo Lab 1

2.

書誌情報 Multi Time Scale World Models ● Vaisakh Shaj, Saleh Gholam Zadeh, Ozan Demir, Luiz Ricardo Douat, Gerhard Neumann ● カールスルーエ工科大学,Bosch (ドイツ) ● NeurIPS 2023 Spotlight (top 3%) ● 論文 : https://arxiv.org/abs/2310.18534 実装 : https://github.com/ALRhub/MTS3

3.

書誌情報 Multi Time Scale World Models 概要 ● 複数の時間スケールで予測や推論を行う世界モデル ”MTS3” を提案 ● 高速/低速で動作する2つのSSMを組み合わせた確率的なフレームワーク ● 長期予測タスク(数秒後) や不確実性の推定において従来手法を上回る性能を達成

4.

背景 世界モデル ● 環境の遷移を予測する生成モデル ● 人間の小脳に形成されているといわれている内部モデルにヒントを得ている ● David Haの “World Models” [Ha+ 18] が火付け役となり、研究が進められている ● 強化学習と組み合わせた “Dreamer” [Hafner+ 19] などの手法が有名 現在時刻 次時刻 (想像) もしパドルを左に動かしたら どうなる...? 自分の行動で条件付けた未来を予測

5.

背景 従来の世界モデルの課題 ● ミリ秒単位の細かい周期で動作している → データの長期的な傾向やパターンを捉えられない ● 効率的に長期の予測や計画を行うには? ○ 複数の時間的抽象化レベルで予測できるモデルが必要 大域的な遷移をとらえる 高レベル (長い周期で動作) 低レベル (短い周期で動作) 時間

6.

背景 線形ガウス状態空間モデル 遷移モデル 潜在状態 エージェントの行動 観測モデル ● 観測→潜在観測へのエンコーダ (非線形)

7.

提案手法 MTS3: Multi Time Scale State Space Model ● 遅い時間スケールのSSM (赤色) ● 速い時間スケールのSSM (緑色) の2つで構成される 提案モデルが満たす性質 ● 複数の時間スケールでダイナミクスをモデル化できる ● 正確な長期予測&不確実性の推定ができる ● 確率的な定式化に基づきつつも、学習と推論がスケーラブル

8.

提案手法 MTS3: Multi Time Scale State Space Model 全体像

9.

提案手法 MTS3: Multi Time Scale State Space Model 全体像 遅いSSMの遷移 速いSSMの遷移

10.

提案手法 MTS3: Multi Time Scale State Space Model 全体像 ステップごとに抽象化

11.

提案手法 MTS3: Multi Time Scale State Space Model 全体像 : 時間ウィンドウ

12.

提案手法 速い時間スケールのSSM (低レベル) : タスク記述子 ● ○ ○ ステップの間固定される 後述の遅いSSMによって決定される

13.

提案手法 速い時間スケールのSSM (低レベル) ● 学習するパラメータ ○

14.

提案手法 遅い時間スケールのSSM (高レベル) : 抽象行動(低レベルの行動系列をエンコード) ● ○ : 抽象観測(低レベルの観測系列をエンコード) ● ○

15.

提案手法 遅い時間スケールのSSM (高レベル) ● 学習するパラメータ ○

16.

提案手法 MTS3の学習 対数尤度の最大化 ● 再構成ロス ○ 高レベルの遷移に基づいて低レベルの予測を行い、その潜在状態を元に観測を再構成

17.

提案手法 MTS3の学習 長期予測のための工夫 ● このロスは時刻 ○ ● このままでは1ステップ先の予測はできるようになっても、長期の予測には失敗してしまう そこで、長期予測の問題を 「欠損値」 問題として捉える ○ ● までの観測が全て手に入る前提になっている 未来の時刻の観測が 「欠損している」 ものとして扱う 観測の一部をランダムにマスクし、欠落した観測を補完するように学習させる

18.

実験 実験 長期予測において評価 1. 決定論的予測 ○ 2. 確率的予測 ○ 3. MTS3は長期の決定論的予測 (平均の推定) を正確に行うことができるか? MTS3は長期の確率的予測 (分散の推定) を正確に行うことができるか? チョイスの検証 ○ モデル設計時の仮定や学習方法はどのぐらい重要か?

19.

実験 データセット ① D4RL ● オフライン強化学習用のデータセット ● 3つの環境で検証 ○ HalfCheetah(6秒) ○ Franka Kitchen(3秒) ○ Medium Maze(4秒)

20.

実験 データセット ② マニピュレーション ● 実機から収集した2つのデータセット ● 2つの環境 ○ 掘削機(12秒) ○ Panda ロボット(2秒)

21.

実験 データセット ③ モバイルロボティクス ● 4輪ロボットのシミュレーション環境 ● 凸凹な地面を移動(3秒)

22.

実験 データセット 部分観測問題設定 ● 全てのデータセットにおいて、エージェントや物体の位置情報のみを観測として使用 ● 速度情報はマスクして消去

23.

実験 ベースライン RNN系 ● LSTM ● GRU RSSM系 ● RKN (Recurrent Kalman Networks) [Becker+ 2019] ● HiP-RSSM (Hidden Parameter Recurrent State Space Model) [Shaj+ 2022] Transformer系 ● 自己回帰型 ● Multi-step prediction [Zhou+ 2021]

24.

実験 ① 決定論的予測性能 ● 縦軸: 二乗平均平方根誤差 (RMSE) 横軸: 予測のhorizon ● 提案手法: 赤色 ● 全てのデータセットで一貫して良好な予測性能を発揮 ● 世界モデルで広く使用されているRNNは長期予測ができていない ● 通常のTransformer (自己回帰) は誤差が蓄積してしまう

25.

実験 ① 決定論的予測性能 ● 掘削機のデータセットにおいて予測された軌道 ● 黒線: 正解 青色: モデルの予測

26.

実験 ② 確率的予測性能 ● 不確実性を推定具合を負の対数尤度で比較 ● ほぼ全てのデータセットで最も正確に不確実性の表現を学習 ● ✗ : 値が高すぎる もしくは NaN

27.

実験 ③ チョイスの検証 3つの項目でアブレーション ● 青色: 抽象行動を用いなかった場合 ● 赤色: 潜在状態を「観測可能部分」と「メモリ部分」に分けなかった場合 ● 橙色: 観測のランダムマスクを行わなかった場合 ← 特に重要

28.

実験 ④ 抽象化の時間幅 ● の役割 が大きくなる → 高レベルの時間スケールが遅くなる ● 小さい区切り(2,3,5, 10ステップ) だと 性能が著しく劣化 ● 大きすぎる値でも悪化してしまう (75ステップ)

29.

実験 ④ 抽象化の時間幅 の役割 ● 掘削機データセットにおいて予測された軌道 ● 高レベルの状態の更新により、低レベルのダイナミクスが変化しているのがわかる (ジャンプ) ● 大きすぎる区切り(75ステップ)だと低レベルの誤差が蓄積してしまい、性能が低下してしまう

30.

結論 まとめと感想 ● 複数時間スケールで予測する世界モデル MTS3 を提案 ● ガウス状態空間モデルの確率論的な枠組みで定式化 ● 提案モデル(線形)の性能が、大規模なTransformerに匹敵することが示された ● 実験で扱っている観測は低次元の状態 → 高次元の観測(画像)だとどこまでできる? ● 高速&低速の2つのレベルで検証している → レベル数を増やせば性能も上がる? ● 階層的な潜在状態を学習している → 高レベルの潜在状態はどのような表現になっている?