【DL輪読会】Stochastic Taylor Derivative Estimator: Efficient amortization for arbitrary differential operators

144 Views

February 20, 25

スライド概要

シェア

またはPlayer版

埋め込む »CMSなどでJSが使えない場合

関連スライド

各ページのテキスト
1.

DEEP LEARNING JP ”Stochastic Taylor Derivative Estimator: Efficient amortization for arbitrary differential operators” [DL Papers] Kensuke Wakasugi, Panasonic Holdings Corporation. http://deeplearning.jp/ 1

2.

2 書誌情報 ◼タイトル: Stochastic Taylor Derivative Estimator: Efficient amortization for arbitrary differential operator ◼著者:Shi, Z., Hu, Z., Lin, M., & Kawaguchi, K. ◼所属:National University of Singapore、Sea AI Lab ◼選書理由 • NeurIPS2024 Best Paper • 自身の業務(マテリアルズ・インフォマティクス)において、 物理法則を満たした予測が重要となるため ※特に記載しない限り、図表は上記論文からの引用です。

3.

3 概要 損失関数に微分演算子を含む最適化問題を考える 入力 NNの出力 微分演算子 n階微分 ※ただし、入力がd次元の場合、 どの次元で微分するかで場合の数が多数ある 微分テンソルサイズ O(dk) ■応用例: • 偏微分方程式の求解 • Physics-informed machine learning(PINN) • 拡散モデルにおける入力変数の更新 • Adversarial Attack • attentionの可視化

4.

Contribution 汎用的で効率的な高次元自動微分手法を提案 1. 効率的な高次元自動微分方法を提案 2. 既存手法では主にラプラス演算に焦点を当てていたが、 本手法では一般の偏微分方程式に適用可能 3. 既存手法のSDGD(偏微分の確率的計算)や HTE(トレースの確率的計算)を一般化 4. 100万次元の偏微分方程式を8分で解けることを、実験的に示した @NVIDIA A100 40GB GPU 4

5.

5 事前知識:偏微分の計算方向 偏微分の計算方法として、前向き・後向きがある Forward mode AD • 提案手法はこちらに該当 • 計算グラフ長 O(L)、必要メモリ O(max(d,h)) Backward mode AD • Back Propagationはこちら • 計算グラフ長 O(2L)、必要メモリ O(2(d+(L-1)h))

6.

高階微分の非効率性 Backwardの場合、微分するごとに計算量が2倍に Forwardで高階微分を求めたい 6

7.

Forward mode AD 再考 高次微分をNNの層毎の演算で求める ◼ ポイント • Fについて合成関数に拡張できる → 層毎の計算で全体微分を計算 • 右辺をNNで計算する → 左辺が求まる • 上記を高次微分にも拡張したい 7

8.

8 合成関数の偏微分とテイラー展開 適当な関数g(t)を置くと、一変数合成関数の微分をみなせる x y t g(t) F(x) 0階微分、1階微分に拡張 k階微分に拡張

9.

9 合成関数の偏微分とテイラー展開 合成関数の微分が、関数g(t)のt微分とF(x)のx微分の積であらわされる Faa di Bruno’s formula:一変数合成関数の高次微分の公式 合成関数F○gのk階t微分がF,gそれぞれの微分で表現される。 x y t g(t) F(x) t [F○g](t) y jet表現を1層進める Jg={g(t), g’(t)} JF g ={[F○g](t), ○ 𝜕 𝜕𝑡 [F○g](t)}

10.

補足:合成関数の微分の一般項 一変数合成関数の微分に関する公式 10

11.

11 合成関数の偏微分とテイラー展開 再掲 適当な関数g(t)を置くと、一変数合成関数の微分をみなせる x y t g(t) F(x) 0階微分、1階微分に拡張 k階微分に拡張

12.

12 提案手法 Forward計算で高次微分を伝搬

13.

13 実用例 求めたい微分項に応じて、g(t)を設計

14.

14 実用例 二次の偏微分方程式

15.

15 実用例 浅水波の方程式

16.

16 補足 複数回の演算を組み合わせて算出する場合も • 三階微分の要素の一つを計算する例 • 任意の微分項に対する、jetの設計は自明ではないが、求めることはできる?

17.

17 提案手法 再掲 Forward計算で高次微分を伝搬

18.

関連研究:Stochastic Dimension Gradient Descent 入力次元をミニバッチ化し、計算効率を改善 ミニバッチ学習と類似した考え方 入力次元d → J と置き換わりは計算量など軽減できるが、 微分次数 kへの依存性は減らない。高階微分における計算量軽減が望まれる 18

19.

ランダム次元削減との併用 計算対象の次元を限定することで、計算量を削減 先行手法(再掲) 提案手法 サンプリングした次元に関しての期待値が 全体で求めたい微分の値に一致(不偏推定) 19

20.

Stochastic Taylor Derivative Estimator (STDE)のメリット Forward計算で微分計算実現し、計算量の発散を抑制 1. 任意の次元、次数の微分演算子に適用可能 2. データ次元dと微分次数kのスケーリング問題を同時に解決 3. サンプリングJに関して、並列計算が可能 ■Backwardと比較 メモリ:O(2k−1(d + (L − 1)h))→O(kd) 計算量:O(2k(dh + (L − 1)h2))→ O(k2dL)

21.

補足:jet形式のsparseとdense sparseなjetと、denseなjetを設計可能(場合による) • 基本はスパースな(標準ベクトルと0で構成)jet • 場合によりdenseなjetを作れる。 下記例では、ガウス分布からサンプリングしたvを利用 21

22.

参考:Hutchinson Trace Estimation 引用:Hutchinson Trace Estimation — BackPACK 1.2.0 documentation 22

23.

Experiments Physics-informed neural networksで実験 ◼ PDEs(ただし、境界条件は先行文献利用し、自動的に満たすように設計) ◼ Amortized PINNs ◼ 方程式の例 23

24.

Ablation study on the performance gain 一定の誤差範囲の下での計算時間を比較 24

25.

Ablation study on the performance gain メモリ量の比較 25

26.

比較結果 ベースラインとなるSDGDを最適化知したうえでも、STDEが大幅に優位 ◼ JAX vs PyTorch • SDGDはPythorchのため、JAXも実装。JAXだと~15倍速、~4倍メモリ効率 ◼ Parallelization • 元のSDGDは非並列のため、並列化も検証 • ~15倍速、ピークメモリの低減 ◼ Forward Laplacian • 入力次元が100までは最良だが、ランダマイズがないのでスケールしない • 1000次元を超えるとSDGDが有利 ◼ STDE(提案手法) • 最良のSDGDと比較して、10倍速、4倍メモリ効率 26

27.

27 まとめ 偏微分方程式を含む問題における、汎用性の高い手法を提案 ◼ Applicability • 偏微分方程式に汎用的に適用できる • 敵対的攻撃、特徴貢献度解析、メタ学習 ◼ Limitations • 分散提言は考慮できておらず、future works • ランダムバッチサイズを小さくすると、速度・メモリが改善するが、分散との兼ね合いは不明 • ネットワークパラメータの学習には適していない ◼ Future works • AD とランダム化数値線形代数のつながりを示唆 • 高次元ラプラシアンを計算する必要がある多体シュレーディンガー方程式や 数理金融で多数の用途がある高次元ブラック ショールズ方程式への応用を期待

28.

28 所感 ◼ 現実的に大きなk階微分を使うことは少なそうではあるが、汎用性の高さが際立つ ◼ 実用面では、g(t)の設計に苦労しそうではあるが、 よく使われる方程式であれば、ハードルは低そう。使い勝手に期待。