2.1K Views
December 07, 23
スライド概要
自然言語処理の基礎の輪読会第8回の発表スライドです。
2023年12月7日(木) 18:30~
AI・機械学習を勉強したい学生たちが集まる、京都大学の自主ゼミサークルです。私たちのサークルに興味のある方はX(Twitter)をご覧ください!
自然言語処理の基礎 Transformer 6章 京都大学工学部電気電子工学科4回生 三宅大貴 0
論文情報 Attention Is All You Need (2017) https://arxiv.org/abs/1706.03762 特に記載がない限り、スライドの図は全て上記論文から引用してい ます 1
Transformerの利点 RNNやLSTMは隠れ状態のせいでN番目の単語を計算するのにN回の 計算が必要だった (そしてこれは並列化できない!) TransformerではN番目の単語を計算するのに1回の計算でよいので, 学習が効率的になる 2
Transformerの全体像 構成要素は6つ 1. 2. 3. 4. 5. 6. QKV Attention Multi-head Attention Feed Forward Positional Encoding Residual Connection Layer Normalization 3
QKV Attention 𝑄𝑄 ∈ ℝ𝑇𝑇×𝑑𝑑 , 𝐾𝐾 ∈ ℝ𝑆𝑆×𝑑𝑑 , 𝑉𝑉 ∈ ℝ𝑆𝑆×𝑑𝑑 として、Attentionは次で計算される ただし 𝑇𝑇, 𝑆𝑆 は系列長(文の長さ)、𝑑𝑑 は特徴量次元数 Attention 𝑄𝑄, 𝐾𝐾, 𝑉𝑉 = softmax 𝑄𝑄𝑄𝑄 𝑇𝑇 𝑑𝑑 𝑉𝑉 ※𝑄𝑄𝐾𝐾 𝑇𝑇 ∈ ℝ𝑇𝑇×𝑆𝑆 に対して,softmaxは 𝐾𝐾 の系列長方向にとる ( 𝑆𝑆 個の 和をとる) ※各要素が𝒩𝒩 0, 1 に従うとすると、 𝑄𝑄𝑄𝑄 𝑇𝑇 の分散は 𝑑𝑑 になる 4
QKV Attention 𝑋𝑋 ∈ ℝ𝑇𝑇×𝑑𝑑 , 𝑊𝑊 𝑥𝑥 ∈ ℝ𝑑𝑑×𝑑𝑑 として、 𝑄𝑄 = 𝑋𝑋𝑊𝑊 𝑄𝑄 , 𝐾𝐾 = 𝑋𝑋𝑊𝑊 𝐾𝐾 , 𝑉𝑉 = 𝑋𝑋𝑊𝑊 𝑉𝑉 としたのがSelf-Attention 𝑌𝑌 ∈ ℝ𝑆𝑆×𝑑𝑑 として 𝑄𝑄 = 𝑋𝑋𝑊𝑊 𝑄𝑄 , 𝐾𝐾 = 𝑌𝑌𝑊𝑊 𝐾𝐾 , 𝑉𝑉 = 𝑌𝑌𝑊𝑊 𝑉𝑉 としたのがCross-Attention 5
QKV Attention 私 𝑑𝑑 I は ate 朝 a に バナナ 𝑇𝑇 banana in を the 食べた morning 𝑆𝑆 6
QKV Attention 各クエリに対応するキーを探し、そのバリューを代入するイメージ 私 は 朝 に バナナ を 食べた I 私 ate 食べた a バナナ banana in the morning × = バナナ に 朝 朝 7
Multi-head Attention softmaxはexpを含むため、僅かな差が増大されてしまう →結合係数(attention)は均等に割り振られない →Attentionを ℎ 通り行う MultiHead 𝑄𝑄, 𝐾𝐾, 𝑉𝑉 = Concat head1 , … , headℎ 𝑊𝑊 𝑂𝑂 𝑄𝑄 head𝑖𝑖 = Attention 𝑄𝑄𝑊𝑊𝑖𝑖 , 𝐾𝐾𝑊𝑊𝑖𝑖𝐾𝐾 , 𝑉𝑉𝑊𝑊𝑖𝑖𝑉𝑉 𝑄𝑄 重みは 𝑊𝑊𝑖𝑖 ∈ ℝ𝑑𝑑×𝑑𝑑𝑘𝑘 , 𝑊𝑊𝑖𝑖𝐾𝐾 ∈ ℝ𝑑𝑑×𝑑𝑑𝑘𝑘 , 𝑊𝑊𝑖𝑖𝑉𝑉 ∈ ℝ𝑑𝑑×𝑑𝑑𝑣𝑣 , 𝑊𝑊 𝑂𝑂 = ℝℎ𝑑𝑑𝑣𝑣 ×𝑑𝑑 の 3ℎ + 1 個 𝑑𝑑𝑘𝑘 = 𝑑𝑑𝑣𝑣 = 𝑑𝑑/ℎ として、計算量がほぼ変わらないようにする 8
Multi-head Attention Single-head Attention Multi-head Attention Attn Attn x Attn x Attn Attn 9
Multi-head Attention making(上段)にかかるAttentionの可視化 1番目のheadはmoreとdifficultに同程度、 2番目のheadはmoreにより強く、 3,4番目のheadはdifficultにより強くかかっている 10
Feed Forward ReLUを挟んだ2層の全結合 FFN 𝑥𝑥 = max 0, 𝑥𝑥𝑊𝑊1 + 𝑏𝑏1 𝑊𝑊2 + 𝑏𝑏2 ただし、トークンごとに適用される(トークン同士で影響を与えな い) 11
Positional Encoding Multi-head AttentionとFeed-Forwardを組み合わせた関数 𝑓𝑓 では、 トークンの並び替えに対して不変性が成り立つ つまり、並べ替えの操作を 𝑇𝑇 として、𝑓𝑓 𝑥𝑥 = 𝑇𝑇 −1 𝑓𝑓 𝑇𝑇 𝑥𝑥 立つ が成り これは「語順は各トークンの意味に影響を与えない」ということに なる? 12
Positional Encoding 𝑉𝑉 ∈ ℝ𝑆𝑆×𝑑𝑑 について、次のようにPEが足される 𝑝𝑝𝑝𝑝𝑝𝑝 𝑉𝑉𝑝𝑝𝑝𝑝𝑝𝑝,2𝑖𝑖 += sin 100002𝑖𝑖/𝑑𝑑 𝑝𝑝𝑝𝑝𝑝𝑝 𝑉𝑉𝑝𝑝𝑝𝑝𝑝𝑝,2𝑖𝑖+1 += cos 100002𝑖𝑖/𝑑𝑑 𝑝𝑝𝑝𝑝𝑝𝑝 𝑖𝑖 の小大 →周波数の高低 →周期成分を特定できれば位置がわかる 𝑖𝑖 13
Relative Positional Encoding 学習データに含まれないような系列長のデータを推論した時,Positional Encodingが上手く働かない →相対位置を表すような仕組みを使う 従来のAttentionを 𝐸𝐸 = と変更する 𝑄𝑄𝐾𝐾 𝑇𝑇 , 𝛼𝛼 = softmax 𝐾𝐾 𝐸𝐸𝑖𝑖𝑖𝑖 = 𝑄𝑄𝑖𝑖 𝐾𝐾𝑗𝑗𝑇𝑇 + 𝑄𝑄𝑖𝑖 𝑎𝑎𝑖𝑖𝑖𝑖 𝑇𝑇 , 𝐸𝐸 𝑑𝑑 , 𝑍𝑍 = 𝛼𝛼𝛼𝛼 と表すと, 𝑉𝑉 𝑍𝑍𝑖𝑖 = � 𝛼𝛼𝑖𝑖𝑖𝑖 (𝑉𝑉𝑗𝑗 + 𝑎𝑎𝑖𝑖𝑖𝑖 ) 𝑗𝑗 𝐾𝐾 𝑉𝑉 ただし,𝑎𝑎𝑖𝑖𝑖𝑖 , 𝑎𝑎𝑖𝑖𝑖𝑖 ∈ ℝ𝑑𝑑 の要素は埋め込み 𝑤𝑤 𝐾𝐾 , 𝑤𝑤 𝑉𝑉 ∈ ℝ𝑑𝑑 によって決まる 𝐾𝐾 ( 𝑗𝑗 − 𝑖𝑖 < 𝑛𝑛) 𝑤𝑤𝑗𝑗−𝑖𝑖 𝐾𝐾 𝐴𝐴𝐾𝐾 (𝑗𝑗 − 𝑖𝑖 > 𝑛𝑛) 𝑖𝑖𝑖𝑖 � 𝑤𝑤𝑛𝑛 𝐾𝐾 𝑗𝑗 − 𝑖𝑖 < −𝑛𝑛 𝑤𝑤−𝑛𝑛 (Masked-Attentionの場合は 𝑗𝑗 − 𝑖𝑖 > 0 の部分は考えなくてよい) Self-Attention with Relative Position Representations 14
Relative Positional Encoding 実際にはメモリ節約のためにSkewアルゴリズムというものが使われる 興味がある人は以下の論文の3.4を読んでみてください Music Transformer 15
Residual Connection AttentionやFeed-Forwardの入力を出力に足し合わせる 勾配は 𝑦𝑦 = 𝑓𝑓 𝑥𝑥 + 𝑥𝑥 𝜕𝜕𝜕𝜕 𝜕𝜕𝜕𝜕 𝜕𝜕𝑦𝑦 𝜕𝜕𝜕𝜕 ′ 𝑓𝑓 𝑥𝑥 + 1 = = 𝜕𝜕𝜕𝜕 𝜕𝜕𝑦𝑦 𝜕𝜕𝜕𝜕 𝜕𝜕𝑦𝑦 となるため,勾配消失を防ぐ効果がある 16
Layer Normalization 層の出力に対して正規化を行う ′ 𝑥𝑥𝑝𝑝𝑝𝑝𝑝𝑝,𝑘𝑘 = 𝑎𝑎𝑘𝑘 𝑥𝑥𝑝𝑝𝑝𝑝𝑝𝑝,𝑘𝑘 − 𝜇𝜇𝑥𝑥 + 𝑏𝑏𝑘𝑘 𝜎𝜎𝑥𝑥 + 𝜖𝜖 平均と分散は特徴量次元について計算される 1 𝜇𝜇𝑥𝑥 = � 𝑥𝑥𝑝𝑝𝑝𝑝𝑝𝑝,𝑘𝑘 𝑑𝑑 𝜎𝜎𝑥𝑥 = 𝑘𝑘=1,…,𝑑𝑑 1 𝑑𝑑 � 𝑘𝑘=1,…,𝑑𝑑 𝑥𝑥𝑝𝑝𝑝𝑝𝑝𝑝,𝑘𝑘 − 𝜇𝜇𝑥𝑥 2 17
Masked Attention 文章生成の時には,先の単語が見えない →Attentionの計算時も先の単語がないものとして扱う必要がある →Masked Attention その単語以降をそもそも入力しなければよい? →計算効率化のために固定長にしたい また,単語が存在しなくても<EOS>トークンや<PAD>トークン が存在するため,そこにもAttentionがかかってしまう 18
Masked Attention 𝑄𝑄𝐾𝐾 𝑇𝑇 の計算結果に対して,マスクしたいトークンの部分に-infを代入 する softmax 𝑄𝑄𝐾𝐾𝑇𝑇 𝑑𝑑 の結果,その部分は必ず0になり,かつ総和は1になる (灰色の部分が0) 19
Label Smoothing 交差エントロピー誤差(真の分布 𝑞𝑞 𝑘𝑘 推定した分布 𝑝𝑝 𝑥𝑥 ) 𝐿𝐿 = − � 𝑞𝑞 𝑘𝑘 log 𝑝𝑝 𝑘𝑘 𝑘𝑘 では,普通 𝑞𝑞 𝑘𝑘 を経験分布(one-hot)で仮定する これを 1 − 𝜖𝜖 𝑘𝑘 が正解クラスの場合 𝑞𝑞 𝑘𝑘 = � 𝜖𝜖 それ以外 𝐾𝐾 − 1 で置き換える正則化 (perplexityは下がるが,精度は上がるらしい) 20
Warmup 学習率を一定ではなく上下させる 学習初期は勾配が大きいため学習率は小さく,また学習終盤も過学 習しないように学習率は小さくしたい 21
Model Averaging 学習終盤にいくつかのチェックポイントを保存しておき,推論時は それらのチェックポイントでアンサンブルする パラメータのEMAを取るのと似た効果? 22
まとめ まとめ1 まとめ2 まとめ3 TransformerはAttentionによって学習を効率化できる Multi-head Attentionを使うことで多様な関係を学習できる TransformerはAttentionの他にも,Feed Forward,Layer Normalization,Positional Encodingといったモジュールを持つ 23
24