Elastic Weight Consolidation解説

導入

Catastrophic Forgettingを防ぐものとしてElastic Weight Consolidation(EWC)が提案されています。 Overcoming catastrophic forgetting in neural networks (2016)が元の論文です。

Catastrophic Forgettingとは

人工ニューラルネットにおいて異なるタスクを連続的に学習すると前のタスクについて忘れてしまうことがあります。Catastrophic Forgettingは厳密には少し忘れることではなく完全に忘れることを意味するはずですが、最近の論文ではForgettingと同じ意味合いで使われている気がします。 元の意味合いではCatastrophic forgetting in connectionist networks(1999)とかを読めばいいかな。

記事を書いた動機

解説記事や実装等は複数あるのですが、中身の数式の変形についてちゃんと言及しているものがなかったので後の自分にためにメモとして残しておきます。

概要

基本アイデア

タスクAを学習した後にタスクBを学習することを考えます。 タスクBの学習をするときの損失関数に正則化項を入れます。ここで重要なのが単純に二次正則化を入れるのではなく、タスクAで学習した後の個々のパラメーターの重要さに依存して係数の変わる正則化項を入れることです。以下のような形で入れます。変数の意味は詳細で説明します。

L(\theta) = L_B (\theta) + \frac{\lambda}{2} \sum_i F_i (\theta - \theta_{A, i}^{*})^{2}

この効果によって、タスクAの性能を出すのに必要なパラメーターはほとんど変化しないので、タスクAについての性能が落ちない(落ちにくい)ことが実現できます。

詳細

変数

まず、タスクAのデータを D_{A}(これの中身は入力 x_iに対して出力y_iが与えられた (x_i, y_i)です。以下では、簡単のために (x_i, y_i)x_iと書きます。)、タスクBのデータを D_B、学習したいパラメーターを \thetaとします。

パラメーター \thetaを学習するということは以下のように表せます。

\theta^{*} = argmax_{\theta} \log p (D | \theta)

 L(\theta) = - \log p (D | \theta)とおけば、これはソフトマックス函数を最終層にかました時などの損失関数に対応します。

本題

今回最大化したい尤度は  p(\theta | D) = p(\theta | D_A, D_B) です。今回やりたいことはタスクBの学習の際にタスクAのデータを使わないことなので、ベイズ的に分離してみます。ちなみにこれからは対数尤度で議論します。 \begin{eqnarray} \log p(\theta | D ) &=& \log p (D_B, D_A | \theta) + \log p(\theta) - \log p(D_A, D_B)\\ &=&\log p (D_B | \theta ) + \log p(D_A | \theta) + \log p(\theta) - \log p(D_A) - \log p(D_B)\\ &=&\log p(D_B | \theta ) + \log p(\theta | D_A ) - \log p(D_B) \end{eqnarray} これでタスクAとタスクBについて分離できました。タスクBの要素について注目すると、タスクBだけが与えられた時の損失関数に対応していることがわかります。

次にタスクAの成分に注目します。タスクAについてデータごとに分離して2次までテイラー展開を行います。(ラプラス近似) \begin{eqnarray} \log p(\theta | D_A ) &=& \sum_i \log p(x_i | \theta) + \log p(\theta) - \log p(D_A)\\ &\sim& \sum_i [ \log p(x_i | \theta^{*}_{A}) + \frac{1}{2} (\theta_{j} - \theta_{A, j}^{*})^{T} \frac{\partial^{2} \log p (x_i | \theta )}{\partial \theta_j \partial \theta_k} |_{\theta = \theta_{A}^{*}} (\theta_k - \theta_{A, k}^{*} ) ] + \log p (\theta) - \log p (D_A) \end{eqnarray}

二回微分のところはデータ数がたくさんあればフィッシャーの情報行列に近似できるので、次のように書き換えられます。 \begin{eqnarray} \log p(\theta | D_A ) &= & \log p(D_A | \theta^{*}_{A}) + \frac{1}{2} \sum_i (\theta_{j} - \theta_{A, j}^{*})^{T} \frac{\partial^{2} \log p (x_i | \theta )}{\partial \theta_j \partial \theta_k} |_{\theta = \theta_{A}^{*}} (\theta_k - \theta_{A, k}^{*} ) + \log p (\theta) - \log p (D_A)\\ &\sim& \log p(D_A | \theta^{*}_{A}) - \frac{N}{2} (\theta_{j} - \theta_{A, j}^{*})^{T} F_{jk} (\theta_{k} - \theta_{A, k}^{*} ) + \log p (\theta) - \log p (D_{A}) \end{eqnarray}

この近似した式を元の式に代入します。 \begin{eqnarray} \log p(\theta | D ) &=& \log p(D_B | \theta ) +\log p(D_A | \theta^{*}_{A}) - \frac{N}{2} (\theta_{j} - \theta_{A, j}^{*})^{T} F_{jk} (\theta_{k} - \theta_{A, k}^{*} ) + \log p (\theta) - \log p (D_{A})- \log p(D_B) \end{eqnarray} タスクBについての項を左辺に一部移項すると、 \begin{eqnarray} \log p(\theta | D ) - \log p (\theta) + \log p(D_B) &=& \log p(D_B | \theta ) +\log p(D_A | \theta^{*}_{A}) - \frac{N}{2} (\theta_{j} - \theta_{A, j}^{*})^{T} F_{jk} (\theta_{k} - \theta_{A, k}^{*} ) - \log p (D_{A})\\ \log p(D_B | \theta, D_A ) &=& \log p(D_B | \theta ) - \frac{N}{2} (\theta_{j} - \theta_{A, j}^{*})^{T} F_{jk} (\theta_{k} - \theta_{A, k}^{*} )+\log p(D_A | \theta^{*}_{A}) - \log p (D_{A}) \end{eqnarray} この左辺はタスクAについてgivenな条件付き確率になっています。つまり、今回最適化したい、タスクAについての学習後における対数尤度に対応していることがわかります。また、右辺の \log p(D_A | \theta^{*}_{A})  -  \log p (D_{A})についてですが、タスクAについてのデータ D_Aが真のパラメーター\theta_{A}^{*}から生成されているならば、(データサイズが無限に大きければそうなるはず) \log p(D_A | \theta^{*}_{A})  -  \log p (D_{A}) = \log p(D_A | \theta^{*}_{A})  -  \log p (D_{A}| \theta^{*}_{A}) = 0となります。さらに、タスクBについての部分は、タスクBだけを与えられた時の損失関数 L_B (\theta) =\log p(D_B | \theta ) になっています。

以上から今回最適化したい損失関数 L_B (\theta | D_A) = \log p(D_B | \theta, D_A )は \begin{eqnarray} L_B (\theta | D_A ) = L_B (\theta) + \frac{N}{2} (\theta_{j} - \theta_{A, j}^{*})^{T} F_{jk} (\theta_{k} - \theta_{A, k}^{*} ) \end{eqnarray} となります。

利用

参考文献においては計算量を落とすために、フィッシャーの情報行列の対角成分だけを利用しています。タスクBの学習の際にタスクAのパラメーターの解に引っ張られて大きくずれられない状況になっています。実際の学習の際には正則化部分の係数 Nはサンプルサイズではなくそれよりも小さい値を用います。おそらく学習率が無限小であればサンプルサイズでいいかと思いますが、ある程度の大きさのものだとタスクAの解空間から大きくずれてしまうため正則化部分の計算がうまくいかなるはずです。

さらに追加のタスクを行う際には正則化項が追加されていきます。