-- Views
July 03, 25
スライド概要
DL輪読会資料
DEEP LEARNING JP [DL Papers] Learning Transformer-based World Models with Contrastive Predictive Coding Eri Kuroda, Matsuo-Iwasawa Lab http://deeplearning.jp/
書誌情報 • Learning Transformer-based World Models with Contrastive Predictive Coding • Authors: Maxime Burchi, Radu Timofte • Conference: ICLR2025 poster (spotlight) • arXiv: https://arxiv.org/abs/2503.04416 • github: https://github.com/burchim/TWISTER • 行動条件付きcontrastive predictive coding(AC-CPC)を導入したTransformerベースの 世界モデル TWISTER を提案 • 長期的な特徴予測によってAtari 100kベンチマークで従来の手法を上回るサンプル効率と 性能を達成 ※ 本スライドの図表は元論文より引用 2
背景|深層強化学習 • 深層強化学習アルゴリズムの著しい発展 • ハードウェアシステムの計算力向上により、高次元データ(動画や画像)から強力なエージェント を学習するとき、DNNを使うことで大きく進展 隣接する潜在状態の • CNNはコンピュータビジョンのようなパターン認識能力が得意 cos類似度が高い • CNNを視覚ベースのRLにも応用 • Atariゲームや囲碁、マインクラフトなどでも人間以上の性能を達成 • 逆伝播を用いて世界モデルを学習するアプローチが増加 • RNNベース(Dreamerなど)から始まり、Transformerベースの世界モデルが提案 • これまでは時系列の関係性を捉えるという意味でRNNが多かった • TransformerはRNNよりも訓練効率がよく、大規模化に適している • Transformerベースは訓練効率が向上しても、性能への影響が限定的 • ビデオの連続フレーム間の違いが小さい → 複雑なモデルでなくても予測できる(cos類似度の高さ) • 長期予測のようなより複雑な表現を学習させることで、性能がもっとよくなるのでは 3
関連研究|モデルベースRL • モデルベースRL • エージェントが環境内でどのように行動するかをシミュレートできる環境モデル(world model)を使って学習 • 汎化性能、サンプル効率、計画による意思決定能力 が向上 • 多くの場合、勾配逆伝播(gradient backpropagation)で学習 • 初期:自己感覚のタスク(proprioceptive tasks)などの低次元観測 • 最近:高次元観測(画像、動画)から世界モデルを学習 • SimPLe • Kaiser et al., 2020 • CAEを用いてAtariゲームにおけるピクセル空間で学習 • 過去の観測フレームと選択された行動から、次のフレームと環境報酬を予測 • PlaNet • Hafner et al., 2019 • RSSMを使って潜在空間における世界モデルを学習 • GRU(Gated Recurrent Unit)を用い、確率的および決定論的な状態遷移を学習し、モデル予測制御(MPC) を使ってプランニング • VAEによるピクセル再構成損失を用いて、観測データを確率的な状態表現にエンコード 4
関連研究|モデルベースRL • Dreamer • Hafner et al., 2020 • 世界モデル表現からアクター(方策ネットワーク)とバリュー(価値ネットワーク)を学習 • DreamerV2 • Hafner et al., 2021 • Atariゲームへの適用が可能になり、世界モデル内でStraight-Through Gradientによるカテゴリカル潜在状態を 用いて性能を向上 • DreamerV3 • Hafner et al., 2023 • 異なるドメインでも同じハイパーパラメータで安定して学習できるようアーキテクチャを改良 • Symlog変換による報酬と価値関数のスケール調整 • レイヤー正規化 • リターンと価値関数の正規化 • MuZero, EfficientZero • Schrittwieser et al., 2020; Ye et al., 2021 • モンテカルロ木探索(MCTS)とWM 5
関連研究|TransformerベースWM • TransDreamer • Chen et al., 2022 • DreamerV3のRSSMをMasked Self-Attentionを使うTransformer状態空間モデル(TSSM)に置 き換え、未来の軌跡(trajectory)を想像するモデルを提案 • 長期記憶や推論能力が求められるHidden Order Discoveryタスクで評価 • Visual DeepMind Control、AtariなどのタスクでDreamerV2とほぼ同等の性能 • TWM • Robine et al., 2023 • Transformerベースの世界モデル • 状態・行動・報酬をそれぞれ独立した連続トークン列としてTransformerに入力する autoregressive(逐次予測型)手法 • デコーダは過去の隠れ状態を使わずに画像を再構成 • 画像再構成における時間的文脈情報を無視 6
関連研究|TransformerベースWM • STORM • Zhang et al., 2024 • Atari 100kベンチマークでDreamerV3に匹敵する性能、より高い学習効率 • 状態と行動を1つのトークンに融合 → トレーニング効率と性能が向上 • TWMが状態・行動・報酬を別トークンで処理 • IRIS • Micheli et al., 2023 • 入力画像を離散トークンに変換するVQ-VAEと、将来トークンを予測するautoregressive Transformerから構成 • データ量が少ない状況でもよい性能を示した • Δ-IRIS • Micheli et al., 2024 • 時刻間の確率的変化(Δ, デルタ)をエンコードする手法 • 前ステップの行動と画像を条件にして、エンコーダーとデコーダーがデルタを学習 • VQ-VAEの圧縮率と画像再構成性能が向上 • CrafterベンチマークでSOTAを達成、Atari 100kでもよい結果 7
関連研究|contrastive predictive coding: CPC • Contrastive Predictive Codeing (CPC) • Oord et al., 2018 https://arxiv.org/abs/1807.03748 • 時系列信号を隠れ表現にエンコードし、自己回帰モデルが将来のエンコードされた表現と 出力特徴量との間の相互情報量(Mutual Information)を最大化するように学習 • Noise-Contrastive Estimation(Gutmann & Hyvärinen, 2010)にもとづくInfoNCE損失 • 音声の音素分類、画像分類、テキスト分類、DeepMind Lab 3D環境の強化学習の4つの 異なるドメインで有用な表現を学習し、よい性能を発揮 • 音声、画像、テキスト分野:事前学習用の事前タスクとしてCPC利用 • DeepMind Lab:A3C(Mnih et al., 2016)エージェントの補助損失(auxiliary loss)として使用 • 本研究:Action-Conditioned CPC (AC-CPC) • CPC予測を将来の行動系列に条件付け(condition) することで、世界モデルの予測精度向上と、 より高品質な特徴表現の学習 8
手法|TWISTER • TWISTER • 行動条件付きContrastive Predictive Coding(AC-CPC)を用いたTransformerベースの強化 学習アルゴリズム • • 高レベルな特徴表現の学習とエージェント性能の向上を目的 3つの主要なニューラルネットワークから構成 • 世界モデル(World Model) • 画像観測を離散的かつ確率的な状態表現に変換し、環境のシミュレーションによって仮想軌道を生成 • アクターネットワーク(Actor Network) • WMが生成した潜在空間内の軌道をもとに、将来報酬の期待値を最大化する行動(Action)を選択するように学習 • クリティックネットワーク(Critic Network) • 潜在空間内で、行動の価値(Value)を 評価するために学習 9
手法|TWISTER – World Model
•
潜在空間(latent space)での世界モデルの学習
• 入力画像観測𝑜! をconvVAEで隠れ表現にエンコード
• 隠れ表現:32カテゴリ 32クラス/カテゴリ のロジットに線形投影 → 離散・確率的な𝑧!
•
Transformer State-Space Model(TSSM)
• Chen et al., 2022
• 過去の状態𝑧":! , 行動𝑎":!
→ 次の確率的状態𝑧!$" を予測
• Transformerネットワークの出力である隠れ状態ℎ! と
確率状態𝑧! を連結 → 最終的なモデル状態𝑠! = {ℎ! , 𝑧! }
•
予測するもの
• 環境報酬𝑟!
• エピソード継続確率予測𝑐!̂
• AC-CPC特徴量𝑒̂!%
• シンプルな多層パーセプトロンで実装
10
手法|TWISTER – World Model • Loss • バッチサイズ B • 系列調 T • パラメタ 𝜙 • 𝐿"#$ :報酬予測損失 • Dreamer v3のSymlogクロスエントロピーloss • 𝐿%&' :エピソード継続フラグの予測損失 • 続くなら1、終わるなら0 • 𝐿"#% :入力画像観測の再構成損失 • 二乗誤差 • 𝐿()' :次状態𝑧! の予測損失 • KLダイバージェンス • Transformerからの予測分布h→z、実際の次状態o→z • stop gradient • 𝛽!"# = 0.5, 𝛽$%& = 0.1 (regは正則化損失) • 𝐿%*% :AC-CPCコントラスト損失 • K:ステップ数(=10) 11
手法|TWISTER – Agent Behavior Learning • Criticネットワーク、Actorネットワーク • 世界モデルから生成された仮想軌跡を使って訓練 • DreamerV3のエージェント行動学習設定を採用 • サンプル軌跡 • モデル状態は、サンプリングされたシーケンスのバッチ次元と時間次元を平坦化 • 𝐵&'( = 𝐵×𝑇 • self-attentionのkey, value • エージェント行動学習フェーズで再利用 • 過去の文脈情報(context)を保持 • Transformerネットワークとダイナミクス予測ネットワークの ヘッド(h→z)を用いて、H=15ステップの軌跡を想像 • 各ステップの行動は、actorネットワークのカテゴリカル分布 からサンプリングして選択 12
手法|TWISTER – Agent Behavior Learning • Critic Learning • 世界モデルが予測した報酬と継続フラグを用いて得られた𝜆リターン(𝑅() )を目標値として Symlogクロスエントロピー損失を最小化するように学習 • ターゲットネットワークを使用せず、自身の予測結果だけにもとづいて、 予測範囲を超えたリターン(報酬)を推定 • Actor Learning • 予測されるリターン(報酬の合計)を最大化するような行動を選ぶように学習 • 学習中やデータ収集中の行動の多様性を保つために、ポリシーのエントロピーも最大化 13
実験|Atari 100k benchmark • Atari 100kベンチマークを使った実験 • SimPLe • Dreamer v3 • Transformerベースのモデル(TWM, IRIS, STORM, Δ-IRIS) • Atari 100kベンチマーク • Kaiser et al., 2020 • 限られたデータ量での強化学習エージェントの性能を評価するときに使われることが多い • 26種類 • 400k(40万フレーム、環境との対話回数10万ステップ) • 人間が約2時間(1.85時間)プレイしたのと同程度の環境ステップ • EfficientZero V2 • Wang et al., 2024 • 現時点のSoTA 14
実験|Atari 100k benchmark - 結果 • human-normalized(人間基準)スコアを使い、26ゲーム全体の平均値と中央値を比較 • +,#'! -%&"# ."+'(&/ -%&"# 𝑛𝑜𝑟𝑚𝑒 𝑠𝑐𝑜𝑟𝑒 = 01/+' -%&"# ."+'(&/ -%&"# 平均値と中央値の信頼区間 15
実験|Atari 100k benchmark – 考察 • TWISTER:平均162%、中央値77% • 探索(look-ahead search)を用いないモデルベースで最も高い • 得意なゲームの傾向 • 報酬に関する重要なオブジェクトが多いゲーム(Amidar、Bank Heist、Gopher、Ms Pacmanなど) • Breakout、Pong、Asterixのような小さな移動オブジェクトを含むゲーム • 理由 • AC-CPCが将来を正確に予測するために「ボールの位置」などに注目せざるを得なくなるため • → 小さい物体も無視せず予測でき、再構成損失(Lrec)だけで学習していた従来の欠点を回避 • IRIS、Δ-IRIS • IRIS系の手法は、画像を高品質に再構成することを重視 • 画像を空間的な離散表現(VQ-VAE)に変換し小さい物体の情報を捉える • → 再構成誤差が小さくなり、小さい物体の扱いが重要なゲームでもよい成績が出せた 16
実験|Atari 100k benchmark – ablation studies • 4つのablation studies • CPC予測ステップ数 • WMのアーキテクチャ • CPCの将来への行動シーケンス への条件付けの有無 • データ拡張の効果 17
実験|Atari 100k benchmark – ablation studies • CPCのステップ数 • CPCで何ステップ未来を予測するか、を変えて実験 • 10ステップ(約0.67秒)先:最もよい平均スコア • ステップ 増:平均値・中央値ともにスコア向上 → 15ステップ先を予測すると逆に精度が低下、特に中程度の難易度のゲームで下がる • WMアーキテクチャ(TSSM vs. RSSM) • TWISTERのTSSMを、DreamerV3のRSSMに置き換えて比較 • AC-CPCを使わない場合:両者の性能に大きな差なし • AC-CPCを使用する場合:Transformer(TSSM)の方が明確に性能が向上 • 理由 • Transformerは再帰なしで長期依存関係を捉えるself-attentionを持っており、階層的で高次な特徴表現を 学びやすい • RNNは長いシーケンスで勾配消失しやすく、収束も遅い 18
実験|Atari 100k benchmark – ablation studies • CPCの将来の行動シーケンスへの条件付けの有無 • CPCの予測を将来の行動列にもとづいて条件付けするかどうかを比較 • CPC予測精度と特徴表現の品質が向上 • 条件付けをなくすと、CPCは正のサンプルをうまく予測できなくなり学習が困難 • 複数ステップ先の予測では、条件付けがないと精度が下がり、性能向上が見られなくなった training・validatiopnの損失と予測精度 19
実験|Atari 100k benchmark – ablation studies • データ拡張の効果 • 過去の研究(Kharitonov et al., 2021):音声CPCでのデータ拡張が有効 • 本研究:画像でのデータ拡張 → 学習を難しくすることでよりよい特徴表現になるか • 拡張方法 • ランダムクロップ+リサイズ • ランダムシフト(上下左右最大4px移動) → 学習への影響は小さい • ランダムクロップ+リサイズが最も高いスコア向上に寄与 • 拡張なしだとAC-CPCの効果が弱まり、平均・中央値ともにスコアが下がる データ拡張が与える影響 20
結論 • TWISTERの提案 • 新しい強化学習エージェント • Transformerベースのモデルを用いて、行動条件付きコントラスト予測学習(AC-CPC)によって、 高次の時間的特徴表現(temporal feature representations)を学習 • 探索を使わないモデルベース手法の中で、Atari 100kベンチマークにおいてSOTAを達成 • 平均値162%、中央値77% • Transformerベースの世界モデルにおける、コントラスト表現学習(CPC)の効果の検証 • AC-CPCがエージェントの性能向上に大きく寄与 • データ拡張や将来の行動シーケンスによる条件付けが、AC-CPCの目的をより複雑にし、 高精度な未来予測と質の高い表現学習に重要 21