【DL輪読会】Generating and Imputing Tabular Data via Diffusion and Flow-based Gradient-Boosted Trees

8K Views

November 17, 23

スライド概要

シェア

またはPlayer版

埋め込む »CMSなどでJSが使えない場合

関連スライド

各ページのテキスト
1.

DEEP LEARNING JP [DL Papers] Generating and Imputing Tabular Data via Diffusion and Flow-based Gradient-Boosted Trees Takeyuki Nakae http://deeplearning.jp/ 1

2.

書誌情報 • 論⽂タイトル(URL: https://arxiv.org/abs/2309.09968) • 概要 勾配ブースティングを⽤いて表形式のデータを⽣成することができる • Github(実装) https://github.com/SamsungSAILMontreal/ForestDiffusion 2

3.

拡散モデル • 拡散モデル ノイズから実データを⽣成することができる⽣成モデル ノイズを徐々に除去する「ノイズ除去モデル」を利⽤することで、本物の画像(データ)に近 いデータを⽣成することができる。 学習は、ノイズ付与後の画像からノイズ付与前の画像がどのようなものだったかを予測する ことで学習できる。ただし完全なノイズからいきなり画像を予測できないので、予測しやす いように徐々にノイズを付与する。このためどれくらいノイズを付与したかのノイズレベル というのが存在する。 有名な例: Stable-diffusion 3

4.

決定⽊・勾配ブースティング • 決定⽊ 「特徴①≦5は左、5<特徴①は右」のようにデータを分割 して予測を⾏う機械学習アルゴリズム。 分割時の予測値 の損失が最⼩になるように学習。 • 勾配ブースティング 決定⽊(弱学習器)を構築した後に、その予測の誤差を修正 するように新たな決定⽊を構築して予測を⾏うモデル。 • Xgboost 2次テイラー展開・⾼速並列化/分散システム・⾼速分位 点分割・スパース性を考慮した分割探索を⽤いた定番の 勾配ブースティング。 4

5.

研究背景① テーブルデータ⽣成は重要なデータの匿名化のために重要 しかしデータ⽣成には以下の課題が存在する。 • データ数が少ない • ⽋損値がある →⽋損値はNN系のモデルでは対応できない。 A B C D 12 catA 0.1 100 23 None 0.55 1000 None catA None 500 10 catB 0.25 None これまでの⽣成モデルの対策として、⽋損値の削除と補完があった。 • ⽋損値の削除: df.dropna() 課題: 削除するとデータ数が減少する • ⽋損値の補完: df.fillna() 課題:モデルにバイアスがかかる可能性 5

6.

研究背景② これまでの⽣成モデルの対策として、⽋損値の削除と補完があった。 • ⽋損値の削除: df.dropna() 課題: 削除するとデータ数が減少する • ⽋損値の補完: df.fillna() 課題:モデルにバイアスがかかる可能性 今までの⽣成モデル(データの近似)はNNがベース →NNは⽋損値を扱うことができない 勾配ブースティングは⽋損値にも対応→勾配ブースティングでいけるのでは? 実データ 生成データ A B C D 12 catA 0.1 100 23 catB 0.55 1000 データを近似するモデル NN ↓ 勾配ブースティング A B C D 10 catB 0.122 10 32 catB 0.501 500 6

7.

本研究の特徴(メリット) • 他のテーブルデータ⽣成モデルと⽐べても⽣成の性能が⾼い 24データセット・4つの観点からも良い評価 • ⽋損値があってもデータを学習できる 従来の研究では⽋損値の対処が必要 • GPUを⽤いずとも⾼速な⽣成が可能 CPUを並列で利⽤してトレーニングするように実装 GPUも利⽤可能 • ⽋損値の補完(インピュテーション)も可能 性能は既存研究には劣る 7

8.

Method 概要 • DNN(深層ニューラルネットワーク)を⽤いたノイズ除去の⽅法 ノイズを⼊⼒し、単⼀のノイズ除去モデルでノイズの除去何度か繰り返して、 ⽣成データを作成する。 しかし、勾配ブースティングはDNNとは事情が異なるため、さまざまな⾯で異なる ⽅法を取る必要がある。 何度か繰り返す 生成データ ノイズ ノイズ除去モデル NN A B C D 12 catA 0.1 100 23 catB 0.55 1000 8

9.

Method 概要 勾配ブースティングでノイズ除去を⾏うために、以下の⽅法について説明する。 1. 1. Duplicating the data to produce many data-noise pairs ノイズと実データの対応⽅法 2. 2. Training different models per noise level ノイズ除去の⽅法について 3. 5. Data processing and regularizations 学習する前の処理について 4. 3. Choice of GBT データを近似する勾配ブースティングを選択 5. 4. Imputation Imputationについて 適宜紹介: Algorithm アルゴリズム詳細 9

10.

⼿法 • 勾配ブースティングの場合 全データと全データと同サイズのノイズ を⽤意する。 性能を上げるために、上記のデータを 100個複製※してモデルを学習する。 サンプル データ モデル データ ノイズ データ データ 100 ノイズ ノイズ モデル 1. Duplicating the data to produce many data-noise pairs • DNNの場合 データをサンプリングしてバッジ学習が できる ※多ければ⼤きいほど良いがメモリの都合上100を上限する 10

11.

⼿法 2. Training different models per noise level • DNNの場合 ノイズをそのままノイズ除去モデルに⼊⼒することで、 徐々にノイズ除去ができる • 勾配ブースティングの場合 学習するノイズに新たな特徴量ノイズレベルtを⼊⼒するこ とで、ノイズレベルに応じた学習が可能。 しかし、決定⽊の分割にはランダム性がある。 →特徴選択によってはノイズレベルtを選択しない可能性 このため、 ノイズレベルごとに異なるモデルを訓練する。 ノイズ + ノイズレベルt変数 ノ イ ズ t 0.9 0.24 -1.2 50 1.1 0.1 -0.7 40 1.12 0.012 -0.5 30 11

12.

⼿法 5. Data processing and regularizations • 学習前のデータ変換 カテゴリ変数: ダミー変数に変換 連続的数値: -1~1に変換 catA catB catC catD catE 1 0 0 0 0 0 0 0 1 0 0 1 0 0 0 0 0 1 0 0 0 0 0 0 1 • 復元 カテゴリ変数: 確率が最も⾼かったカテゴリの要素に変換 連続的数値: 逆変換(整数値は近い値に丸める) catA catB catC catD catE cat 0.99 0.002 0.002 0.003 0.003 0.1 0.2 0.5 0.1 0.1 catC 0.01 0.05 0.025 0.9 0.025 catD 変換 catA 12

13.

⼿法 5. Data processing and regularizations • ノイズ除去モデルについて ノイズの除去は、各特徴量⼀つ⼀つに対して推定する。 (Xgboostは⼀つの変数に対してしか予測ができないため。) またハイパーパラメータは、L2正則化を0にした以外はデフォルト スコア関数の推定が⽬的なのでモデルの過学習の⼼配をする必要はない。 • ノイズ除去モデルの学習 数値データ: 平均⼆乗誤差を損失関数に使⽤ カテゴリデータ: tabddpmで提案されたスコア関数を損失関数に使⽤ 損失関数: 13

14.

⼿法(アルゴリズム) 6. Algorithm ForestDiffusionの学習の流れについて 1. 実データ𝑋を𝑛!"#$% 倍複製する(𝑋 & ) 2. 𝑋 & と同サイズのノイズを作成する(𝑍) 3. Forwardで𝑋 & と𝑍をノイズデータの作 成(𝑋 𝑡 , 𝑌(𝑡))と、ノイズ除去を⾏う ⽬的変数を作成する 4. Xgboostでモデルを学習する。なお⼊ ⼒は実データとノイズを混ぜた𝑋 𝑡 を⼊⼒し、除去すべきノイズである 𝑌(𝑡)を予測する 5. 3.~4.を𝑛' 回繰り返し、ノイズ除去モ デルf 𝑡 を作成する 14

15.

⼿法(アルゴリズム) 6. Algorithm ノイズ除去モデルf 𝑡 を⽤意 1. ⽣成したいデータサイズ分のノイズ 𝑋 𝑛' を⽤意 2. Reverseでモデルを元にノイズ除去を 𝑛' 回実⾏ 3. 実データ𝑋 0 が出⼒される 15

16.

⼿法 3. Choice of GBT 経験的にXgboostが最も良かったのでXgboostを利⽤した。 4. Imputation Imputationは画像の⽋損を補完するインペインティングと同じようなもののため、 テーブルデータでも有効。 →拡散モデル⼿法である「REPAINT」を応⽤する。 REPAINTとは? 拡散モデルを⽤いたインペインティング 拡散確率モデル(DDPM)を⽤いて画像を復元する。 ※フロー⽅式でインペインティング する⽅法もあるが⽋損値があると難 しいため本研究では利⽤していない。 16

17.

⼿法(アルゴリズム) 6. Algorithm f 𝑡 はXgboostのノイズ除去モデル 𝑋はn(obs)×pのサイズの実データ ノイズ𝑋 𝑛! も実データと同じサイズ⽤意 𝑀は0と1で表され、⽋損値は0、⽋損無しは 1で表される。 以下の⾏為を𝑛! 回繰り返す 1. 𝑋(𝑡 − 1)の空のデータセットを作成する 2. 𝑋 𝑡 − 1 1 − 𝑀 (⽋損値のあるデータ箇 所)に対してReverseでノイズ除去 3. 𝑋(𝑡 − 1) 𝑀 (⽋損値のないデータ箇所)に 対して、Forwardでノイズ付与 17

18.

Method 概要 勾配ブースティングを⽤いたノイズ除去の⽅法 DNNは単⼀のモデルでノイズ除去を⾏うが、勾配ブースティングは⼀般的に⼀つの 変数のみに対して予測を⾏う。 このため、勾配ブースティングは、 「特徴量の数×ノイズ除去の回数」個のモデルを作成し、ノイズ除去を⾏う 除去する順番に応じてモデルが変わり、これを繰り返して⽣成データを作成する。 何度か繰り返す ノイズ ノイズ除去モデル GradientBoosting モ デ ル モ デ ル ・・・ モ デ ル テーブルデータ モ デ ル 18

19.

実験(セットアップ) 実験は24データセット(分類・回帰)で⾏い、内容は以下の3 つ(訓練80%・テスト 20%) 1. ⽋損値のないデータ⽣成の性能(完全データ⽣成) 2. ⽋損値埋め(インピュテーション) 3. ⽋損値のあるデータ⽣成の性能(不完全データ⽣成) ⽣成データの品質評価は、 • 実データと⽣成データの分布の近さ(ワッサースタイン距離) • ⽣成データの多様性(Coverrage) • ⽣成データでの学習→実データでの予測精度(R2スコア・F1-Score) • 統計的推論(Percent Bias) ※VP-diffusionを使⽤する場合は我々の⼿法をForest-VPとし、条件付きフローマッチングを使⽤する場合は Forest-Flowとする。 19

20.

実験(データ⽣成性能{⽋損なし}) 上の表を確認すると、Forest-flowが(ほぼすべてのメトリクスで)⼀般的に最も ⾼い性能を⽰した。 ※比較手法は、 Gaussian Copula・TVAE・CTGAN・CTAB-GAN+・STaSy・TabDDPM(Oracleは実データ) 20

21.

実験(データ⽣成性能{⽋損あり}) 上の表を確認すると、Forest-flowは、半分の実験結果でTabDDPMより優れてい た。 ※比較手法は、 Gaussian Copula・TVAE・CTGAN・CTAB-GAN+・STaSy・TabDDPM 21

22.

実験(インピュテーション) ※データを20%欠損値に変更す ることで実験をした。 結果は上の表となる。最も優れた⼿法は、MICE-ForestとMissForest。 Forest-Flowは⽐較的性能が良くみえるが⼀概に良いとは⾔えない。 ※比較手法は、kNN-Imputation・ICE・MissForest・MICE-Forest・softimpute・minibatch Sinkhorn divergence・MIDAS・GAIN 22

23.

まとめ 考察 ForestDiffusionの特徴 • 学習データに⽋損値が含まれる場合でも、リアルに近いデータを⽣成できる。 性能は近年の深層学習⼿法と同等かそれ以上。 • 多様なインピュテーションも可能 しかし性能は既存研究のMICE-ForestとMissForest • GPUや深層学習が必須ではない しかし学習にミニバッチを利⽤しないため、メモリコストは他の⼿法(深層学習)よ り⾼い。 さらにデータセットを𝑛!"#$% 倍(デフォルト100倍)複製しているため、データセット が⾮常に⼤きくなるとさらにメモリが悪化する。 23

24.

まとめ 今後の課題 今後の課題 • 既存の性能の良い拡散モデルの⼿法を適⽤する。 例: Multinomial Diffusion Model(TabDDPMで利⽤されている) Elucidating the Design Space of Diffusion-Based Generative Models classifier-free guidance • XGBoostモデルをミニバッチで訓練する⽅法を⾒つける。 メモリの節約ができる • 将来的な応⽤ データ増強、クラス不均衡、ドメイン翻訳タスクに適⽤する。 24

25.

appendix • Forward Process 25

26.

appendix • Reverse Process 26

27.

appendix • Ablation Study 27