【グラフニューラルネットワーク】7.2

392 Views

July 08, 24

スライド概要

profile-image

AI・機械学習を勉強したい学生たちが集まる、京都大学の自主ゼミサークルです。私たちのサークルに興味のある方はX(Twitter)をご覧ください!

シェア

またはPlayer版

埋め込む »CMSなどでJSが使えない場合

関連スライド

各ページのテキスト
1.

2024前期輪読会#10 – グラフニューラルネットワーク7.2節 過平滑化の対策 京都⼤学理学部 地球物理学教室 B4 松⽥ 拓⺒ 0

2.

アジェンダ ❶ 過平滑化の測定と正則化 ❷ 対策︓辺の削除 ❸ 対策︓スキップ接続 ❹ 過平滑化以外の問題の対策 ❺ 本当に深くする必要があるのか︖ 1

3.

アジェンダ ❶ 過平滑化の測定と正則化 ❷ 対策︓辺の削除 ❸ 対策︓スキップ接続 ❹ 過平滑化以外の問題の対策 ❺ 本当に深くする必要があるのか︖ 2

4.

過平滑化の測定と正則化 平滑化の度合いは,埋め込みどうしの距離で測ることができる 平滑指標 (頂点埋め込みの距離によるもの) 精 と 標 指 滑 平 1 # 𝑑(𝒛" , 𝒛$ ) ! 𝑛 層数(横軸)と正解率・平滑指標(縦軸)の関係[1] 低 が 度 ",$ 下 𝑛︓頂点の個数 𝑧! ︓𝑖番⽬の頂点埋め込み 𝑑︓距離関数(cos⾮類似度など) 平滑指標が⼩さい =埋め込みの値が似ている GNNの層数が増えるほど平滑指標が減少し, =過平滑化が起きている 同時に精度も低下することが知られている [1] Liu, Meng, Hongyang Gao, and Shuiwang Ji. “Towards deeper graph neural networks.” Proceedings of the 26th ACM SIGKDD international conference on knowledge discovery & data mining. 2020. のFigure4に加筆 3

5.
[beta]
過平滑化の測定と正則化

“適度に”平滑かどうかは,MADGapという指標で測ることができる
適度に平滑なグラフは…

Mean Average Distance Gap

MADGap = MAD)*+ − MAD,-.
MAD#$# = average
"∈'

𝑑"
近い頂点は埋め込みが似ているが,
遠い頂点では埋め込みが⼤きく異なる

()# = 9
𝐷%"

#$#

= average 𝐷%"
%∈'

𝑑 𝑧% , 𝑧" ,
0,

𝑑* 𝑢, 𝑣 ≥ 8
𝑑* 𝑢, 𝑣 < 8

𝑑 𝑧% , 𝑧" ,
0,

𝑑* 𝑢, 𝑣 ≤ 3
𝑑* 𝑢, 𝑣 > 3

+,- = 9
𝐷%"

全⾴の平滑化指標が単に⼤きければ良い
というわけではない︕︕

#$#

𝑑"

#$#

近い頂点どうしの距離と遠い頂点どうしの距離
の差が⼤きければ理想
4

6.

過平滑化の測定と正則化 MADGapをペナルティに加えて学習する⼿法をMADRegという MADRegの損失関数 MADRegの効果[2] ℒ!"#$%& 𝜽 = ℒ'() 𝜽 − 𝜆 ⋅ MADGap(𝜽) ℒ!"# 𝜽 ︓通常の損失関数 𝜆︓正則化の強さ 罰則項を⼊れることで MADGapが⼤きくなるように 学習が進むようになる 層数が多いときほどMADRegの効果が⼤きい [2] Chen, Deli, et al. “Measuring and relieving the over-smoothing problem for graph neural networks from the topological view.” Proceedings of the AAAI conference on artificial intelligence. Vol. 34. No. 04. 2020. のTable 5から抜粋 5

7.

アジェンダ ❶ 過平滑化の測定と正則化 ❷ 対策︓辺の削除 ❸ 対策︓スキップ接続 ❹ 過平滑化以外の問題の対策 ❺ 本当に深くする必要があるのか︖ 6

8.

対策︓辺の削除 辺の数が多いほど過平滑化が起きやすくなる グラフ中の辺の数が多い 信号が混ざりやすく,各頂点の値が均質化しやすい グラフ中の辺の数が少ない 信号が混ざりにくく,各頂点の値が均質化しにくい 7

9.

対策︓辺の削除 DropEdge︓⼀様ランダムに辺を削除する l 訓練中に⼀部の辺をランダムに選んで削除する 元のグラフ 各辺を 確率𝑝で削除 Dropout probability 3 𝑝= 14 8

10.

対策︓辺の削除 AdaEdge︓ラベルごとに連結成分ができるように辺を削除 学習 推論 通常通り学習する 通常通り推論する 辺の削除 再学習 異なるラベルと予測された 頂点間の辺を削除 辺を削除した グラフで学習する [2] Chen, Deli, et al. “Measuring and relieving the over-smoothing problem for graph neural networks from the topological view.” Proceedings of the AAAI conference on artificial intelligence. Vol. 34. No. 04. 2020. のTable 3から抜粋 9

11.

アジェンダ ❶ 過平滑化の測定と正則化 ❷ 対策︓辺の削除 ❸ 対策︓スキップ接続 ❹ 過平滑化以外の問題の対策 ❺ 本当に深くする必要があるのか︖ 10

12.

対策︓スキップ接続 スキップ接続とは,変換の結果に⼊⼒を⾜し合わせる機構のこと 通常のニューラルネットワーク スキップ接続を導⼊したネットワーク Input 𝑿 Input 𝑿 全結合層 𝑾 全結合層 𝑾 Output 𝒚 Output 𝒚 𝑦 = 𝜎(𝑊𝑋) 𝑦 = 𝜎 𝑊𝑋 + X ※上の例のように1層スキップするだけでなく,複数層スキップさせる場合も多い 11

13.

対策︓スキップ接続 スキップ接続により,より多層なNNの学習が可能になる l 層が深いとLossが序盤の層まで伝播しにくくなる(掛け算が積み重なり勾配が0に近くなる) 通常のニューラルネットワーク Input 𝑿 全結合層 𝑾 𝑍 Output 𝒚 𝜕𝐿 𝜕𝐿 𝜕𝑍 𝜕𝐿 𝜕 𝑊𝑋 𝜕𝐿 . = ⋅ = ⋅ = 𝑊 𝜕𝑋 𝜕𝑍 𝜕𝑋 𝜕𝑍 𝜕𝑋 𝜕𝑍 スキップ接続を導⼊したネットワーク Input 𝑿 全結合層 𝑾 Output 𝒚 𝜕𝐿 𝜕𝐿 𝜕 𝑊𝑋 + 𝑋 𝜕𝐿 . 𝜕𝐿 = ⋅ = 𝑊 + 𝜕𝑋 𝜕𝑍 𝜕𝑋 𝜕𝑍 𝜕𝑍 スキップ接続により,上流の勾配が減衰せずに保持される 12

14.

対策︓スキップ接続 スキップ接続をGNNにも適⽤して精度低下を防⽌ GCNにおけるスキップ接続 𝐻(WXY) = 𝜎 𝐴8 [\* 𝐻 W 𝑊 WXY + 𝐻(W) 8層以上での精度低下を防⽌できたものの, 層を深くしても精度は上がっていない [3] Kipf, Thomas N., and Max Welling. “Semi-supervised classification with graph convolutional networks.” arXiv preprint arXiv:1609.02907 (2016). のFigure 5から抜粋 13

15.

対策︓スキップ接続 JKNetでは,各層の中間表現を”集約”したものを埋め込みとする 単純なスキップ接続 JKNet (Jumping Knowledge Networks) ( 012𝑯 . 𝑾 ./( 𝑯(./() = 𝜎 𝑨 + 𝑯(.) 集約の⽅法②︓LSTM - Attention 集約の⽅法①︓Concatenateする 1層⽬での 2層⽬での 中間表現 中間表現 (&) 𝐻$ 𝑍$ L層⽬での 中間表現 ・・・・・・ (() 𝐻$ (+) 𝒁' = 𝑓 𝑯'( , 𝑯') , ⋯ , 𝑯' 𝛼& 𝛼( ()) 𝐻$ (&) 𝐻′$ (() (&) (() 𝐻$ ・・・・・・ ()) 𝐻′$ 𝐻′$ ・・・ 1層⽬〜L層⽬での中間表現をすべてつなげる 𝐻$ 𝛼) ()) 𝐻$ ※論⽂では上記に加えて,Max-poolingによる集約⽅法も提案されている (&) 𝐻$ (() 𝐻$ ()) 𝐻$ 14

16.

対策︓スキップ接続 DeepGCNsでは,以前のすべての中間表現をconcatする 単純なスキップ接続 DeepGCNs ( 012𝑯 . 𝑾 ./( 𝑯(./() = 𝜎 𝑨 + 𝑯(.) 𝑯(./() = Concat 𝑿, 𝑯(() , ⋯ , 𝑯(.) CNN系モデルのDenseNetを模倣 [4] https://cvinvolution.medium.com/why-isnt-densenet-adopted-as-extensive-as-resnet-1bee84101160 15

17.

対策︓スキップ接続 GCNIIでは,初期残差接続と恒等写像を組み合わせて64層のGNNが可能に 単純なスキップ接続 $ 45)𝑯 0 𝑾 012 𝑯(012) = 𝜎 𝑨 GCNII + 𝑯(0) 𝑯("#$) = 𝜎 1 &'(𝑯 " + 𝛼" 𝑯 ) 1 − 𝛼" 𝑨 1 − 𝛽" 𝑰* + 𝛽" 𝑾 "#$ なんでこの式になったのかの経緯 ※イメージ 初期残差接続 スキップ元を前の層→第1層にする ちょっと変形… 恒等写像 深い層ほど変換⾏列𝑾を𝑰に近く 重みづけ… ( 012𝑯 . 𝑾 ./( 𝑯(./() = 𝜎 𝑨 +𝑯 6 𝑯(./() = 𝜎 ( 012𝑯 . + 𝑯 6 𝑾 ./( 𝑨 𝑯(./() = 𝜎 ( 012𝑯 . + 𝑯 6 𝑨 𝑯(./() = 𝜎 ( 012𝑯 . + 𝛼. 𝑯 6 1 − 𝛼. 𝑨 𝛼! = 0.1くらいの定数 𝛽0 = 𝜆 𝑙 ※𝜆はハイパラ 1 − 𝛽. 𝑰7 + 𝛽. 𝑾 ./( 1 − 𝛽. 𝑰7 + 𝛽. 𝑾 ./( 16

18.

対策︓スキップ接続 GCNIIでは層を深くすると,若⼲精度が向上した • GCNIIは層が深くなっても精度を維持し, 若⼲ではあるが精度も良くなっている • 初期残差接続と恒等写像を組み合わせて はじめて,精度向上が実現できている [5] Chen, Ming, et al. “Simple and deep graph convolutional networks.” International conference on machine learning. PMLR, 2020. のFigure 2から抜粋 17

19.

アジェンダ ❶ 過平滑化の測定と正則化 ❷ 対策︓辺の削除 ❸ 対策︓スキップ接続 ❹ 過平滑化以外の問題の対策 ❺ 本当に深くする必要があるのか︖ 18

20.

過平滑化以外の問題の対策 層数を増やすことで,以下の4つの問題が⽣じる 問題点 メモリ⾼消費 計算量⼤ メモリ消費量や計算量が増⼤する なぜ起こるのか ミニバッチの計算に必要な近傍頂点の個数が, 層数が増えるにつれて爆発的に増⼤するため 過学習 パラメータ数が増えて過学習する パラメータ数が⼤きいわりに訓練に使⽤でき るラベル付き頂点が少なく,パラメータを⼗ 分に最適化することができないため 最適化が困難 パラメータを最適化するのが難しい 情報集約と特徴変換を交互に繰返す構造のせ いで,勾配法による最適化が難しくなるため 過圧縮 固定⻑のベクトルに情報を押し込む ことができない 指数関数的に多い頂点の情報を固定⻑のベク トルに押し込める必要があるため 19

21.

過平滑化以外の問題の対策 層数を増やすことで,以下の4つの問題が⽣じる 問題点 メモリ⾼消費 計算量⼤ メモリ消費量や計算量が増⼤する なぜ起こるのか ミニバッチの計算に必要な近傍頂点の個数が, 層数が増えるにつれて爆発的に増⼤するため 過学習 パラメータ数が増えて過学習する パラメータ数が⼤きいわりに訓練に使⽤でき るラベル付き頂点が少なく,パラメータを⼗ 分に最適化することができないため 最適化が困難 パラメータを最適化するのが難しい 情報集約と特徴変換を交互に繰返す構造のせ いで,勾配法による最適化が難しくなるため 過圧縮 固定⻑のベクトルに情報を押し込む ことができない 指数関数的に多い頂点の情報を固定⻑のベク トルに押し込める必要があるため 20

22.
[beta]
過平滑化以外の問題の対策

𝒍 + 𝟏層⽬から𝒍層⽬の中間表現を計算できる構造でメモリ削減
Grouped Reversible GNNs[6]

頂点埋め込み𝑯(0) ∈ ℝ<×> を𝐶個のグループ
(0)
(0)
(0)
0
𝑯2 , 𝑯8 , ⋯ , 𝑯9 に分ける. 𝑯!

層の順伝搬計算 𝑯 0 → 𝑯 012

9
012

𝑯6

.

<×/

∈ℝ

𝑯!

は右式で⾏う.

012

= + 𝑯!
!78

0
012

= 𝑓:- 𝑯!;2 , 𝐴, 𝑈 + 𝑯!

0

変形

ある頂点𝑣の埋め込みベクトル

0

𝑯! = 𝑯!

𝒉 ∈ ℝ*

012

012

− 𝑓:- 𝑯!;2 , 𝐴, 𝑈

𝑙 + 1層⽬から𝑙層⽬の中間表現を計算できる︕︕
*
𝒉& ∈ ℝ+

*
𝒉( ∈ ℝ+

・・・

*
𝒉+ ∈ ℝ+

ということは

𝐿層⽬の出⼒だけを保持しておけばよいので,
層が深くなってもメモリ消費量は⼀定︕

𝐶個に分割する
[6] Li, Guohao, et al. "Training graph neural networks with 1000 layers." International conference on machine learning. PMLR, 2021.

21

23.

過平滑化以外の問題の対策 層数を増やすことで,以下の4つの問題が⽣じる 問題点 メモリ⾼消費 計算量⼤ メモリ消費量や計算量が増⼤する なぜ起こるのか ミニバッチの計算に必要な近傍頂点の個数が, 層数が増えるにつれて爆発的に増⼤するため 過学習 パラメータ数が増えて過学習する パラメータ数が⼤きいわりに訓練に使⽤でき るラベル付き頂点が少なく,パラメータを⼗ 分に最適化することができないため 最適化が困難 パラメータを最適化するのが難しい 情報集約と特徴変換を交互に繰返す構造のせ いで,勾配法による最適化が難しくなるため 過圧縮 固定⻑のベクトルに情報を押し込む ことができない 指数関数的に多い頂点の情報を固定⻑のベク トルに押し込める必要があるため 22

24.

過平滑化以外の問題の対策 特徴変換と情報集約を切り分けて,学習しやすくする 通常のGCN 特徴変換と情報集約を切り分けたGCN (6) 𝑯' = 𝑓; (𝒙' ) ( 𝑯 . 𝑾 ./( 𝑯(./() = 𝜎 𝑨 (𝑯 . 𝑯(./() = 𝑨 特徴変換と情報集約を交互に繰り返す ︓特徴変換は最初だけ ︓以後は情報集約のみ + (.𝑯 6 𝒁𝒗 = ; 𝛼. ⋅ 𝑨 さらに,スキップ接続を導⼊することもできる︓ 𝑿 特 徴 変 換 𝑯(") 情 報 集 約 𝑯($) 情 報 集 約 𝑯(%) .=6 情 報 集 約 ・・・ 情 報 集 約 ' 𝑯(&) 𝒁 𝛼? 𝛼6 𝛼2 23

25.

過平滑化以外の問題の対策 スキップ接続の重みづけ⽅法には様々なバリエーションがある LightGCN DAGNN + (.𝑯 6 𝒁𝒗 = ; 𝛼. ⋅ 𝑨 .=6 (Deep Adaptive GNN) 𝛼, = 1 𝐿+1 注意機構(Attention)で𝛼, も学習する ' PPNP (Personalized Propagation of Neural Prediction) Approximate PPNP 𝐿 → ∞の極限を考える. G -& 𝑯(.) 𝒁=𝛼 𝑰− 1−𝛼 𝑨 𝐿 → ∞ではなく𝐿 = 10程度で打ち切る ことでPPNPの逆⾏列の計算を避ける 24

26.

過平滑化以外の問題の対策 層数を増やすことで,以下の4つの問題が⽣じる 問題点 メモリ⾼消費 計算量⼤ メモリ消費量や計算量が増⼤する なぜ起こるのか ミニバッチの計算に必要な近傍頂点の個数が, 層数が増えるにつれて爆発的に増⼤するため 過学習 パラメータ数が増えて過学習する パラメータ数が⼤きいわりに訓練に使⽤でき るラベル付き頂点が少なく,パラメータを⼗ 分に最適化することができないため 最適化が困難 パラメータを最適化するのが難しい 情報集約と特徴変換を交互に繰返す構造のせ いで,勾配法による最適化が難しくなるため 過圧縮 固定⻑のベクトルに情報を押し込む ことができない 指数関数的に多い頂点の情報を固定⻑のベク トルに押し込める必要があるため 25

27.

過平滑化以外の問題の対策 ⼊⼒グラフの辺を配線し直して過圧縮の問題を解決 グラフ拡散 畳み込み (GDC) ⼊⼒グラフGをもとに,各頂点から各頂点への影響度を 計算する. その影響度をもとに新しく辺集合を定義する. [7] Gasteiger, Johannes, Stefan Weißenberger, and Stephan Günnemann. "Diffusion improves graph learning." Advances in neural information processing systems 32 (2019). 拡散 グラフ伝搬 (EGP) ⼊⼒グラフGと情報の流れやすいグラフHを⽤意する. Gでの情報集約とHでの情報集約を交互に繰り返す. [8] Deac, Andreea, Marc Lackenby, and Petar Veličković. "Expander graph propagation." Learning on Graphs Conference. PMLR, 2022. 動的再配線 メッセージ 伝達 (DRew) 第l層において,直接の近傍𝒩(𝑣),2ホップ先の集合 𝒩8 𝑣 ,…,lホップ先の集合𝒩0 (𝑣)から情報の集約を⾏う. [9] Gutteridge, Benjamin, et al. "Drew: Dynamically rewired message passing with delay." International Conference on Machine Learning. PMLR, 2023. 26

28.

アジェンダ ❶ 過平滑化の測定と正則化 ❷ 対策︓辺の削除 ❸ 対策︓スキップ接続 ❹ 過平滑化以外の問題の対策 ❺ 本当に深くする必要があるのか︖ 27

29.

本当に深くする必要があるのか︖ タスクやデータの性質によって必要な層数を検討すべき l あまり遠くの頂点の情報を考える必要がない場合が多い→そこまで深くする必要はない︖ l どこまで深くするか︖→問題を解く上で必要な情報の範囲に応じて考える l 深くしてもそこまで精度が上がらない • l 画像の場合とは対照的 画像や点群のように,考慮すべき頂点の個数が多い場合は深いGNNが有効になることも. [10] Li, Guohao, et al. “Deepgcns: Can gcns go as deep as cnns?.” Proceedings of the IEEE/CVF international conference on computer vision. 2019. のFigure 5 28