12
Overcoming catastrophic forge2ng in neural networks Yusuke Uchida@DeNA

Overcoming Catastrophic Forgetting in Neural Networks読んだ

Embed Size (px)

Citation preview

Overcomingcatastrophicforge2nginneuralnetworks

YusukeUchida@DeNA

なにこれ?

•  ニューラルネットワークが持つ⽋陥「破滅的忘却」を回避するアルゴリズムをDeepMindが開発http://gigazine.net/news/20170315-elastic-weight-consolidation/

•  なんか凄そうだし、論⽂中に汎⽤⼈⼯知能とか書いてあるけど、やっていることはシンプル

•  端的にいうと、NNのパラメータのフィッシャー情報⾏列をパラメータの重要度として利⽤する

背景

•  汎⽤⼈⼯知能は多数の異なるタスクをこなすことが求められる

•  これらのタスクは明⽰的にラベル付けされていなかったり、突然⼊れ替わったり、⻑い間再び発⽣しなかったりする

•  連続的に与えられるタスクを、以前に学習したワスクを忘れることなく学習するContinual Learning(継続学習)が重要となる

•  なぜならNNは現在のタスク(e.g. task B)に関する情報を扱うと、以前のタスク(e.g. task A)に関する情報を急に失ってしまう=catastrophic forgetting

背景

•  現状のcatastrophic forgettingに対するアプローチは、全てのタスクに関するデータを予め揃え、同時にすべてのテスクを学習する(各タスクのデータを細切れに並べて学習させる)multitask learning

•  もしタスクが逐次的にしか与えられない場合、データを⼀時的に記憶し、学習時に再⽣する(system-level consolidation; システムレベルの記憶固定)しかない

•  タスクが多いと⾮現実的

※この辺の⽤語は脳神経科学系?

背景•  継続学習はtask-specific synaptic consolidationにより実現されていると解釈することができる(?)=以前のタスクに対する知識は、⾮可塑的になったシナプスの割合に⽐例して⻑持ちする

•  このsynaptic consolidationに着想を得たelastic weight consolidation (EWC) を提案する

> synaptic consolidationは、前回お話ししたLTPを起因とするシナプスの構造変化のことです。> system consolidationは、もっとマクロなレベルの変化のことで、脳全体の異なる脳領域間で起こるゆるやかな再編成です。 http://ameblo.jp/neuroscience2013/entry-11860086069.html

EWCの導出

•  EWCは、タスクAのエラーがなるべく⼩さくなるパラメータ空間内でタスクBのパラメータを学習=重要なパラメータは変化させない(イメージ)

•  L2は、元のパラメータからのL2距離をタスクBの学習時に加える=全パラメータの重要度が同じと仮定

•  No penaltyはタスクAを無視

パラメータ空間

タスクA学習後の最良パラメータ

EWCの導出•  今、データDが与えられとき、その事後確率を求めたいとする

•  対数をとると

•  Dが、タスクAのデータDAと、タスクBのデータDBから構成され、互いに独⽴に⽣成されると仮定する

•  変形すると

•  つまり…?全体のデータセットに対して学習することを考える際に、タスクAに関しては、上記の事後確率に全ての情報が含まれている

どのパラメータが重要か

EWCの導出•  真の事後分布を求めることは不可能なので、

MackayのLaplace approximationに従い、この事後確率を平均がθA

*、対⾓の精度がフィッシャー情報⾏列Fの対⾓成分で与えられる多変量ガウス分布で近似する(対⾓じゃないと、パラメータの⼆乗個の値が必要)

•  これにより、EWCの⽬的関数は下記となるθA*

精度高フィッシャー情報量大

精度低フィッシャー情報量小

直感的理解•  結局やりたいことは、タスクAで学習したパラメータθA

*をなるべく維持しながらタスクBにも適⽤するようにパラメータを修正する

•  なるべく維持=L2距離を⼩さくでは駄⽬ •  なぜならパラメータは多様体上にあり、適切な計量を考慮した測地線で距離を測らないといけない(e.g. 地図上の2点を結ぶ直線が最短ルートではない)

•  確率分布における計量=フィッシャー情報⾏列

•  KLダイバージェンスのテイラー展開の2次の項がフィッシャー情報⾏列

•  局所的には、KLダイバージェンスの観点から、θA*

から⼤きく外れないようにタスクBを学習する

どうやって求めるの?

•  フィッシャー情報⾏列を対⾓と仮定しているので、単に各パラメータのフィッシャー情報量

•  すなわち、対数尤度関数(ロス関数)の分散→⼀定数の学習サンプル毎に勾配を求め、それらの⼆乗和の平均

余談

•  結局フィッシャー情報⾏列のFiiが⼤きい=θiが重要

•  TOWARDS THE LIMIT OF NETWORK QUANTIZATION, ICLR’17.では、ロス関数のHessian=フィッシャー情報⾏列を考慮したパラメータの量⼦化を⾏っている

参考

•  丁寧な記事(数式はこちらから引⽤) –  https://rylanschaeffer.github.io/content/research/

overcoming_catastrophic_forgetting/main.html •  これも良い

–  http://www.inference.vc/comment-on-overcoming-catastrophic-forgetting-in-nns-are-multiple-penalties-needed-2/

•  TensorFlow実装 –  https://github.com/ariseff/overcoming-catastrophic