ELECTRAの解説
今日はBERTの派生モデルであるELECTRAを解説しようと思います。
BERTに関しては、過去にもブログ記事を投稿していますので良かったら御覧ください。
- BERTについて勉強したことまとめ (1) BERTとは? その特徴と解決しようとした問題、及び予備知識
- BERTについて勉強したことまとめ (2)モデル構造について
- BERTについて勉強したことまとめ (3) 自己教師学習と汎用性について
- BERTのモデル構造をもう少し詳しく
ELECTRAは、BERTと同様にGoogle Researchが発表した自然言語処理のための機械学習モデルです。
- BERTと同様にTransformerベース
- 事前学習とファインチューニングがあるのも同様
- 事前学習の手法を改善したもの
参考リンク
事前学習タスクの改善
BERTの事前学習タスクには、MaskedLMとNSP(Next Sentence Prediction)の2つがありました。
後の研究で、NSPはそれほど重要でないことが指摘されています(RoBERTaの論文で指摘されています)
これをうけてELECTRAでも同様にNSPは使われていません。
ELECTRAでは、MaskedLMの代わりの異なるタスクを提案しています。それがReplaced Token Detectionです。
Replaced Token Detection
直訳すれば、「置き換えたトークンの探知」ですね。
このタスクのために、以下のようなGANに似たモデルを用意します。
GeneratorもDiscriminatorも従来のBERTと同様にTransformerのEncoderのスタックです。
入力に対して、一部のトークンにMASKをかけます。
Generatorがマスクの復元をします。ここまではMaskedLMと同じですね。この復元作業を「トークン置換」として利用します。
Discriminatorによって「どのトークンがオリジナルのままで、どのトークンが置き換えられたものなのか」を判別します。
言い換えると、各トークンごとに、originalかreplacedかの2値判定をします。
タスク自体はとてもシンプルで分かりやすいですね。
GANと違うところ
以上のモデルは、GANのしくみを真似て作られていますが、異なる点もあります。
GANでは、Generator側が、後工程のDiscriminatorが識別しにくいように出力をするように学習をします。
ただしELECTRAではそれは行わずに、普通に最尤推定(一番確率が高い単語で置き換えるだけ)でMASKを復元します。すなわち、敵対的ではないです。
論文では、敵対的な場合でも実験したが、Generator側が強すぎてうまく学習ができなかった点が述べられています。
補足:ファインチューニングで使うのはDiscriminatorのみ
ELECTRAは以上の通り、事前学習のタスクを改善したもので、ファインチューニングに関してはBERTと同様です。
その際に使われるのはDiscriminatorのみで、Generatorは事前学習でしか使われません。
なので、ファインチューニングの際に捨ててしまうGeneratorのパラメータが多いとおもったいないので、サイズは小さいほうがいいと論文で述べられています。
なんでこのタスクがいいいの?
サンプルに対して効率的(sample-efficient)だからです。
従来のBERTでは、MASKを復元して正しく戻せたらlossが少なくなるわけですが、そもそもMASKが書けられるのが、全トークンの15%程度です。
すなわち、用意した全サンプルのうちの15%しか、フィードバックのために使えなかったのです。
Replaced Token Detectionならば、すべてのトークンに対して、originalかreplacedかの2値の判定をするので、全トークンをlossの計算に回せるわけです。
性能
小さいサイズでも、大きいサイズでも性能が改善しています。詳しくは公式の情報を見てください。
特に、ELECTRA-Smallサイズで、BERT-Baseとほぼ同等の性能が出ている点が強いと思います。
BERT-Baseはだいたい180GPU日が必要でしたが、ELECTRA-Smallは数GPU日で、個人レベルでも事前学習を回せるほどです。
感想
Replaced Token Detectionは、説明を見てしまえばその効率の良さが直ちに分かるものです。きっと他のBERT派生でも積極的に使われる(使われている)と思います。
あと最近だとAttentionの計算量の改善が注目されているように思います。これらを組み合わせたらさらに小さく性能の高いモデルができるのかなあと期待が高まりますね。
コメントを残す