【DL輪読会】Stop Regressing: Training Value Functions via Classification for Scalable Deep RL

1.4K Views

May 02, 24

スライド概要

シェア

またはPlayer版

埋め込む »CMSなどでJSが使えない場合

関連スライド

各ページのテキスト
1.

Stop Regressing: Training Value Functions via Classification for Scalable Deep RL DL輪読会用資料 萩原 誠

2.

概要 課題 目的関数に平均二乗誤差などを使う回帰を行う価値関数ベースの強化学習をスケーリングするのは難しい 解決方法 価値関数の学習に回帰の代わりに分類タスクで使うカテゴリカルクロスエントロピーロスによる学習を実施 結果 スケーリング能力と性能が向上 検証方法 多様なドメインでの実験Atari 2600, Atrai(マルチタスクRL), robot制御, 探索をしないChess, language agent Wordle なぜうまくいった? ノイズや非定常性などに起因する価値ベースRL固有の問題を緩和するから

3.

次のような < , , ℛ, > を考える エージェントは行動At ∈ : × → を現在の状態St ∈ に従い、遷移先の状態St+1 ∈ でとって報酬Rt+1 ∈ ℛを受け取り、環境の遷移確率 に遷移する エージェントの目的 期待リターンGtを最大化する方策π : リターン(割引付き報酬和):Gt = ∞ ∑ k=0 → ℱ( )を学習すること γ kRt+k+1 where γ ∈ [0,1) 行動価値関数 qπ(s, a) = Eπ[Gt ∣ St = s, At = a]は方策πに従って状態で行動をとった場合に得られる価値関数の期待値をと 𝒮 𝒫 𝒜 𝒜 𝒜 𝒮 𝒮 𝒮 𝒜 𝒮 ってくることができる関数 𝒮 𝒫 強化学習の定式化(簡略化)

4.

論文のターゲット 価値関数ベース actor-critic オンラインRL ○(DQN) ○ オフラインRL ○(CQL) ○ オンラインRL:環境との相互作用がありデータセットが可変 オフラインRL:環境との相互作用がなく固定されたデータから学習する 価値関数ベース:方策が得られる価値を表現する関数を学習し、方策の関数は明示的に学習せず価値関数を使用して行動を決定する。 actor-critic:方策の関数と方策に従う場合に得られる価値の関数をそれぞれ明示的に学習する。

5.

𝒟 オンラインRL オンラインRL:環境との相互作用がありデータセットが可変 DQN 行動価値関数をQ(s, a; θ) ≈ qπ*(s, a)の形でパラメータθで表現される近似最適状態行動価値関数を学習する データセット からサンプルした遷移情報のサンプル(St, At, Rt+1, St+1)の平均二乗TD誤差TDMSE(θ)を最小化する TDMSE(θ) = E [(( ̂ Q)(St, At; θ −) − Q(St, At; θ))2] ターゲットネットワーク:モデルの更新の際に、パラメータθよりも更新を遅らせたパラメータθ −を持つモデルを更新に活用する ことで学習の安定化を図る手法(θ −を持つモデルの方をパラメータθをn回更新した際に同期することで実現する) ( ̂ Q)(St, At; θ −) = Rt+1 + γ max Q(St+1, a′; θ −) ∣ St = s, At = a a′ スカラーの回帰の目標値として定義されるベルマン最適作要素のサンプリングベースのバージョンに相当する。 大半の深層強化学習アルゴリズムは平均二乗TD誤差を最小化する方法の派生版で価値関数を学習する  𝒯 𝒯 𝒟  目標となる価値関数のネットワークから取得した推定値を回帰に利用する(ブートストラッピング)

6.

オフラインRL オフラインRL:環境との相互作用がなく固定されたデータから学習する CQL(オフラインRLにおける典型的な手法) 以下の目的関数を使用する min α(E[log( θ ∑  𝒟  a′ exp(Q(St+1, a′; θ)))] − E [Q(St, At; θ)]) + TDMSE(θ)

7.

回帰問題からクラス分類へ x ∈ Rd, Y ∣ x ∼ N(μ = y;̂ θ, σ 2) 入力に対して何らかのガウスノイズがのるような条件付き分布Y ∣ xをモデルに学習させたい σ 2:分散, 推論の関数 ŷ : Rd × Rk → R 関数はθ ∈ Rkでパラメータ化される 尤度を最大化するために、データ{xi, yi}Ni=1に対して平均二乗誤差(MSE)を考える min θ N ∑ i=1 ̂ i; θ) − yi)2 y(x ̂ ̂ θ*) = E[Y ∣ x]を満たすことになる y(x; 最適な関数yが学習できると 条件付き分布Y ∣ xの平均を直接学習する代わりの方法は次のような方法がある。その方法は何らかの目標値に関連する分布(ガウス分布じゃな ̂ Y ∣ xの平均値にyに関して確率密度関数p(y ∣ x)を くても良い)を一つ学習して、次に、その分布から推定値yを計算してやり、最後に、目標分布を 持つ分布として構築することである 例としては分布pを学習して、分布pからサンプリング等を行いy = Ep[Y ∣ x]を計算することで、Y ∣ xの平均値を計算するものがあげられる

8.

𝒵 𝒵 クラス分類として定義する方法 ̂ ∣ x; θ)を学習 回帰問題を目標値である分布p(y ∣ x)に対するKLダイバージェンスを最小化するような、θのパラメータで表現される分布p(y する問題として捉えることができる min θ N ∑∫ i=1 ̂ ∣ xi; θ))dy p(y ∣ xi)log( p(y ̂ θ) = Ep[Y ̂ y(x; ∣ x; θ]として推定値の目標で これがクロスエントロピー誤差による目的関数であり、この目的関数で学習した結果pを用いて ̂ あるY ∣ xの平均が計算される ̂ 分布pを学習する問題を微分可能な損失関数の形で与えられるように、pに対して以下のカテゴリカル分布となるように制約を加える [vmin, vmax]の範囲をm分割して、位置あるいはクラスという形でvmin ≤ z1 < ⋯ < zm ≤ vmaxについて、 ={ m ∑ i=1 piδzi : pi ≥ 0, m ∑ i=1 pi = 1} ここでpiは位置ziに所属する確率、δziは位置ziにおけるDiracのデルタ関数 𝒴 最後にやるべきことは目標の分布Y ∣ xを構築する方法とカテゴリカル分布 上に射映してやる方法を考えることである

9.

𝒵𝒵 価値関数に対応するカテゴリカル分布 価値関数Qをカテゴリカル分布 の期待値として表現する。(行動価値関数が最終的に復元したい対象と言える。) カテゴリカル分布 は確率分布pî (s, a; θ)に従うものとして表現する。 Q(s, a; θ) = E[Z(s, a; θ)], Z(s, a; θ) = m ∑ i=1 pî (s, a; θ) ⋅ δzi, pî (s, a; θ) = exp(li(s, a, ; θ)) m ∑j=1 exp(lj(s, a, ; θ)) z1, ⋯ . zmで表現されるカテゴリカル分布が目的の分布になっている(つまり、あるカテゴリカル分布について、カテゴリ カル分布から価値関数を表現できるものが存在する)必要がある 目標の分布もカテゴリカル分布になっていれば、直接クロスエントロピー誤差が以下の形で表現できる TDCE(θ) = E [ m ∑ i=1 pi(St, At; θ −)log pî (St, At; θ)] 𝒯 𝒟 目標の確率であるpiは m ∑ i=1 pi(St, At; θ −)zi ≈ ( ̂ Q)(St . At; θ −)を満たすように定義される。

10.

one-hotによるカテゴリカル分布の構築 スカラーの目標値である( ̂ Q)(St, At; θ)を{zi}m i=1でサポートされるカテゴリカル分布に投影(projection)する m箇所のbinに分けて、i番目のbinの中央の値をziが表現するようにone-hotで離散化する方法を考える このアプローチで生じるのは、元のスカラーの目標値に対して表現できない数値があるという意味で、情報 が失われる分布であり、Q関数に誤差を挿入することになる one-hot表現で生じた誤差はベルマンバックアップ(現在の状態の推定値の更新に遷移先の状態の推定値を用 いる方法)と呼ばれる更新の操作を行うと誤差が合成されてよりバイアスがかかり推定値の精度が悪化する 上記の問題を回避するために、two-hotのアプローチを利用する 目標値が含まれる区間となる二つのbinの値に対して確率密度を0より大きな値を取るように設定してやるこ とで目標の値を正確に表現する 𝒯 例. binを1~10の整数とした場合に、1.7という値に対して z1 = 1.z2 = 2.0,p1 = 0.3,p2 = 0.7, → 0.3 ⋅ 1.0 + 0.7 ⋅ 2.0 = 1.7

11.

two-hotによるカテゴリカル分布の構築 ziとzi+1をTDの目標値の下界と上界として、以下の式を満たすとする − ̂ zi ≤ ( Q)(St, At, θ ) ≤ zi+1 加えて、これらの位置での確率piとpi+1を以下の以下のように取る − − ̂ ̂ z − ( Q)(S , A ; θ ) ( Q)(S , A ; θ ) − z i+1 t t t t i pi(St, At; θ −) = , pi+1(St, At; θ −) = zi+1 − zi zi+1 − zi 他のすべての位置ではカテゴリカル分布で表現される確率は0になる 𝒯 𝒯 𝒯 two-hotの変換は表現方法が一意に定まり、TDの目標値とカテゴリカル分布の間が可逆な 変換になる

12.

HL-Gauss two-hotの問題点 離散回帰に基づく順序構造(例.アンケートにおける満足~不満)を持っているがそれを考慮しきれていない 順序構造:クラスが独立ではなく、順序があり、互いのクラスがその近傍のクラスに関連性を持つ HL-gauss 目標に近い複数のbinに対して確率密度を0より大きな値を取れるようにした 確率変数Y ∣ St, Atに対する密度関数:fY∣St,At、期待値が( ̂ Q)(St, At, θ −)となる累積分布関数FY∣St,Atを定義する 分布Y ∣ St, Atを幅:ζ = (vmax − vmin)/m, 中央:ziとなるbinを持つヒストグラム上に投影するために以下の積分を行い確率を計算する zi+ 2ζ ζ ζ pi(St, At; θ ) = fY∣St,At(y ∣ St, At)dy = FY∣St,At(zi + ∣ St, At) − FY∣St,At(zi − ∣ St, At) ∫z − ζ 2 2 i − 2 ガウス分布Y ∣ St, At ∼ 𝒯 𝒩 𝒯 用する。 (μ = ( ̂ Q)(St, At; θ −), σ 2)を用いて、σ 2によりlabel smoothingの度合いを制御してカテゴリカル分布に適

13.

distributional RLによる分布の学習 distributional RL:カテゴリカルモデル を用いて、将来のリターンに対する分布を直接モデル化する(もっと表現力を高く分布を学習する) C51:TDの目標値に類似した値に関する分布と予測された に対する分布の間のクロスエントロピー誤差を最小化する上でカテゴリカルな表現を使用 カテゴリカルなリターンの分布をモデル化するために 上におけるstochastic distributional Bellman operatorを定義する ( ̂ Z)(s, a; θ −) ≜ m ∑ i=1 pî (St+1 . At+1; θ −) ⋅ δRt+1+γzi ∣ St = s, At = a, At+1 = arg max Q(St+1, a′) a′ 作要素はカテゴリカル分布 に対して位置ziに対して確率のスケリーング(binのサイズを調整)と0以上の確率をとる位置の範囲を移動(分布を移動 させる)する効果がある この射影は近傍の位置ziに応じて、確率を割り当てる。近傍の位置を識別するために⌊x⌋ = arg max {zi : zi ≤ x}と⌈x⌉ = arg min {zi : zi ≥ x}を定義 して、確率を以下のように書き直せる pi(St, At; θ −) =  𝒵𝒵𝒵 𝒵  𝒯 ξj(x) = m ∑ j=1 x − zj zj+1 − zj pĵ (St+1, At+1; θ −) ⋅ ξj(Rt+1 + γzi) 1{⌊x⌋ = zj} + zj+1 − x zj+1 − zj 1{⌈x⌉ = zj}

14.

実験タスク • Atari 2600 • RLの実験でよく使われるゲームで色々なゲームがある • ロボット操作 • 7DoFのロボットを制御して、17種類の台所のオブジェクトを正確に掴み上げる • チェス • language agent: Wordle • 6回以内に与えられた文字数の単語を当てるゲーム • 使用文字かどうかと文字の場所があってるかのヒント

15.

AtariにおけるSingle-Task RLでの検証 回帰ベースとの性能比較、分布をベースとした学習手法間の性能比較 optimizer Adam 評価指標 95% strati ed bootstrap con dence intervalsした上での四分位数平均(IQM)正規化スコア シード数 n=5 対象タスク オンラインRL: 60種類のゲーム 人間のスコアで正規化されたスコアをベースに計算 fi fi オフラインRL: 17種類のゲーム データセットを作成した方策で正規化されたスコアをベースに計算

16.

スケーリングに対する実験(MoE) • ImPalaの最後から2番目の層をSoftMoEのモジュールで置き換える。 • エキスパートの数は{1,2,4,8}で各エキスパートはImapalaの元のレイヤーと 同じパラメータ構成 • 20種類のAtariのゲームで検証した(MoEベースのRLの既存研究を踏襲)

17.

スケーリングに対する実験(ResNet) • Atariの2種のゲーム(Asteroids(63種類と Space Invaders(29種類))で検証 • ImapalaのアーテクチャをCNN(パラメータ数2M以下)からResNet101(44M パラメータ)までそれぞれ検証した • シード数 n=5

18.

マルチゲームオフラインRLの検証 • 40のゲームについてゲームごとに独立に学習したエージェントのほぼ最適な 行動データを用いて40のゲームをプレイする1つの方策を学習する(下図左) • モデルのパラメータ数の増加に対して性能がスケーリングするかを確認する

19.

Transformersにおける価値ベースRLの検証 • 以下のゲームで検証 • Wordle • 探索なしのチェス • ロボット操作

20.

Wordleでの検証 • GPTみたいなdecoder-onlyのtransformer(パラメータ数125M)を学習する • suboptimalなゲームデータを用いてオフラインRLとして学習する • 20Kの勾配更新をするように学習する • DQNの更新とCQLで使われるデータ外の行動に対して行動価値が下がるよう なアプローチ(式2.4の第一項)を組み合わせた方法を使用する

21.

探索なしのチェスでの検証 • transfomerはアルゴリズムの蒸留を行う上で汎用的に活用できる • stock sh 16と呼ばれるチェスAIが持つ行動価値関数をcausal transformerに蒸留するためにstock sh 16で1000万のチェスのゲームデータに対して価値のアノテーションを行い、150億の手番データを作成 した • パラメータ数の異なる3種類のtransformer(9M, 137M, 270M)を学習 • オンラインチェスサイトのリーチェスにある10000のチェスパズルを解く能力で評価した fi fi • 推論時に400回のMCTSのシミュレーションを実行するAlphazeroと同等の性能を示した

22.

ロボット操作のタスクでの検証 • 4万エピソードの人間が遠隔操作した場合のデータに対して行動のノイズを乗 せて50万エピソードのデモンストレーションを作成した • 60MのQ-trasnformerで学習した

23.

ablation study • softmax関数を適用するとlogitsに相当する出力が出てくるが、効果があるのか?->クロスエント ロピー損失ほど効果はなさそう • クロスエントロピー損失はなぜ効果があるのか? • 近傍のクラスに対して確率を割り当てることで過適合を抑制している可能性がある • 適切なσがbinの数の設定とは独立してるので、HL-gaussは特定の範囲に目標値を絞った場合に よく汎化し、回帰問題における順序の構造をうまく活用できる • 過適合に関する部分のみで改善してるわけではないのでlabel-smoothingが全てでもなさそう

24.

• 2つ目の実験(クロスエントロピー損失はなぜ効果があるのか?)は13のatariの ゲームで検証した • σでラベルをどの程度smoothingするのか調整することになり、近傍に割り 当てる確率も変わってくる

25.

価値ベースRLのどの部分に効いているのか • 報酬にノイズを乗せた環境でより安定性がある • MSEよりも報酬のノイズに対して頑健 • 一定確率(25%)で前回と同じ行動をとる環境で性能が良い • MSEよりも確率的な環境で性能が安定する

26.

RLを分類タスクとして解くことで良い表現を学習するのか? • RLをMSEで解いたモデルの表現を後続タスクなどで活用するのは難しい • 以下の2点は既存研究で報告されている • カテゴリカル分布の使用がより良い表現の学習に繋がる(RLではない) • C51などの分布ベースの手法が表現学習が改善に部分的に寄与する • RLを分類タスクとして解くアプローチの表現学習への改善がリターンの分布の復元処理とクロスエントロピー損失のいずれに 寄与するかはまだよくわかっていない • 提案手法で200Mフレームで学習したモデルの最終的な特徴ベクトルに相当する表現がゼロから方策を再学習するために必要な 情報を保持してるかを確認する。 • 再学習時にはモデルは単一の線型層を持つQ関数を学習する

27.

クラス分類は非定常性に効果的なのか? • 価値ベースのRLでは目標値の計算時に、常に更新されるargmaxの方策や価値 関数を使用するため、非定常性が発生する。 5 • yi = sin(10 ⋅ fθ−(si)) + bを考えて、b ∈ {0,8,16,24,32}の値を取るようにそれぞれ で学習させる。この時bの値が大きいと方策の更新に伴い非定常性も相まって 誤差が悪化する可能性が高い • クラス分類ベースの方が誤差が抑制されており、非定常性による影響が緩和さ れている

28.

クラス分類は非定常性に効果的なのか? • o ine SARSAでは更新対象がデータ内に実在する遷移先の状態行動対の行 動価値関数の値に基づいて計算される • ベルマン最適作要素に基づいて更新する際にQついてmaxを使うアプローチで ffl はないため、MSEに対するHL-Gaussのスコアの差分が消えることから分類タ スクとして解くことの利点が価値ベースRLにおける非定常性を処理できるこ とによる証拠と主張しているが…(ちゃんとは理解できていない)