じんべえざめのノート

仮想生物と強化学習と深層学習に興味がある大学院生のブログ。

論文紹介 Understanding Batch Normalization

今回は、NIPS2018に投稿されたUnderstanding Batch Normalizationという論文を読んだので、紹介していきたいと思います。この論文は、なぜバッチ正規化(Batch Normalization)が学習に効果的なのかを実証的なアプローチで検証した論文です。

この記事は、ニューラルネットワークの基礎(全結合層や畳み込み層)を理解している方を対象にしています。また、可能な限りバッチ正規化を知らない方でも理解できるようにしています。この記事を読み終わるころには、「なぜバッチ正規化が学習に効果的なのか」が分かるようになっています。

ニューラルネットの基礎は以下の記事で紹介しています。

この記事は論文を要約したものに説明を加えたものとなっています。記事内で1人称で語られている文章については、多くが論文の主張となっています。しかし、あくまで論文の私なりの解釈であるため、間違っている可能性も大いに考えられます。鵜呑みにはせず、参考程度にして頂けると嬉しいです。

Understanding Batch Normalization [Johan Bjorck, Carla Gomes, Bart Selman, Kilian Q. Weinberger, NIPS2018 Poster, arXiv, 2018/09]

論文での主張のまとめ

時間が無い方も多いと思いますので、こちらで先にこの論文での主張をまとめておきます。詳細については続きをご確認ください。

  • バッチ正規化によって設定可能になる高い学習率は、SGDのノイズを増加させる原因となるため、正則化の効果と精度が向上する。
  • 正規化なしでは、層が深くなるにつれてチャンネルの平均と分散が指数関数的に増加するが、バッチ正規化によって一定に保つことが可能となる。
  • チャンネル・ユニットの値の平均値に偏りがある場合、出力層・畳み込み層の勾配の大きさが入力依存ではなくなるが、バッチ正規化によって偏りが緩和され、勾配が入力に高い依存性を持つようになる。
  • 高い学習率を設定した場合、正規化なしでは学習が進むにつれて出力層側での誤差が発散するが、バッチ正規化によって各層の出力が正規化されるため、誤差が発散しなくなる。
  • 正規化されていない場合、畳み込み層の重みの更新量は対応するイン/アウトチャンネルによっては一貫して小さくなる場合があるが、バッチ正規化によってその現象を抑えることが出来る。
  • 重みの初期化は、各チャンネルの分散が一定であるべきという考えから設計されているが、実際は層が深くなるにつれて分散が大きくなる。バッチ正規化はこの重みの初期化の影響を緩和する。

Batch Normalizationのオリジナルの論文では、バッチ正規化の効果は「内部共変量シフトを緩和すること」だと主張されていますが、こちらの論文では「その効果が存在しないとは主張しないが、バッチ正規化の効果はそれなしで説明できると考える」と主張されています。

バッチ正規化

バッチ正規化とは、ニューラルネットワークの各層への入力を正規化する手法です。これによって、高い学習率を設定可能になる・正則化がかかる・精度を向上させる・収束が早くなるなどの効果があります。

バッチ正規化では、正規化を行うためにバッチ内の要素(チャンネルやユニット)ごとに平均と標準偏差を求めます。このとき要素とは、畳み込み層のときはチャンネル、全結合層のときはユニットとなります。

畳み込み層の場合、平均と標準偏差は下図のようにして求めます。チャンネルごとにバッチ × x座標 × y座標の要素の平均と標準偏差を求めています。(正確には違いますが)バッチ正規化では、各チャンネルの値を対応するチャンネルの平均で引き標準偏差で割ることで、正規化を行っています。

f:id:jinbeizame007:20180924132105p:plain

このとき、バッチ正規化の入力と出力はそれぞれ4次元テンソル(バッチ、チャンネル、x座標、y座標)で表現することが出来ます。ここで、チャンネルの次元を特徴次元、各座標の次元を空間次元と呼びます。

畳み込み層でのバッチ正規化を数式で表すと、以下のようになります。

  •  INPUT(b,c,x,y):入力 (バッチ, チャンネル, x座標, y座標)
  •  OUTPUT(b,c,x,y):出力 (バッチ, チャンネル, x座標, y座標)
  •  μ_c:各チャンネルの平均値
  •  σ_c:各チャンネルの標準偏差
  •  ε:0で割らないための補正項
  •  γ_c, β_c:アフィン変換のパラメータ

\begin{align} OUTPUT(b,c,x,y) = γ_c \frac{INPUT(b,c,x,y) - μ_c}{\sqrt{σ^2_c + ε}} + β_c \end{align}

このように、バッチ正規化とは入力から対応するチャンネルの平均値を引き、標準偏差(0にならないための補正項付き)で割ることで正規化行ったものに、アフィン変換を行ったものです。γ_cβ_cは学習可能なパラメータです。

正規化した後にアフィン変換を行うことで、完全に分布を制限してしまうことを避け、表現可能な分布に自由度を与えています。

バッチ正規化の利点

ここでは、バッチ正規化が実際に性能にどの程度効果があるのかを検証しています。実験には、オリジナルResNetの論文の110層のResNetと同じものを使用しています。

下図は、バッチ正規化をかけたネットワークの学習率を変えたものと、バッチ正規化をかけていないネットワークの学習結果を示しています。バッチ正規化なしの場合は学習率を下げなくてはいけないため、0.0001となっています。

左図が訓練データの認識率で、右図がテストデータの認識率となっています。

  • bn-orig-lr:学習率 0.1, バッチ正規化あり
  • bn-med-lr:学習率 0.003, バッチ正規化あり
  • bn-small-lr:学習率 0.0001, バッチ正規化あり
  • unnorm:学習率 0.0001, バッチ正規化なし

f:id:jinbeizame007:20180924142413p:plain

この図から、正規化されていないネットワークの精度は正規化されているネットワークの精度と比較して精度が大きく下がることが分かります。

また、学習率が低い正規化されたネットワークは訓練とテストの精度の差が大きいため、学習率が高いほど正則化の効果があることが分かります。

学習率と正則化

これらの結果を説明するために、SGDの単純なモデルを考えます。αを学習率、Bをミニバッチ、誤差関数 f(x) = \frac{1}{N} \sum f_i(x)をデータセットの全てのデータに対する誤差の平均とします。このとき、SGDの推定する勾配の式を、以下のように2つに分割することが出来ます。

  •  α:学習率
  •  B:ミニバッチ
  •  f_i(x):i番目のデータに対する誤差
  •  f(x) = \frac{1}{N} \sum f_i(x):全データに対する誤差の平均

\begin{align} α∇f_{S G D} (x) &= \frac{α}{|B|} \sum_{i∈B} ∇f_i(x) \\ &= α∇f(x) + \frac{α}{|B|} \sum_{i∈B} (∇f_i(x) - ∇f(x)) \end{align}

これは、 x = y + (x - y)の変形と似たような変形です。最後の式の右辺の第1項は全データに対する勾配で、第2項はSGDがミニバッチとしてランダムにデータを選択することにより発生するノイズ項です。

データは一様にサンプルされるため、ノイズ項の期待値は E [ \frac{α}{|B|} \sum_{i∈B} (∇f_i(x) - ∇f(x)) ] = 0となります。従って誤差にバイアスはかかりませんが、通常はノイズが発生します。

ここで、ノイズ量をM = E[(∇f_i(x) - ∇f(x))^2]と定義します。すると、SGDによって与えられる勾配推定値のノイズを以下のように表現することが出来ます。

\begin{align} E[|| (∇f(x) - ∇f_{S G D}(x))||^2] &= \frac{α^2}{B} E[(∇f(x) - ∇f_{S G D}(x))^2] \\ &= \frac{α^2}{B} M \end{align}

この式は、バッチサイズと学習率がSGDのノイズ量を調節することを示しています。SGDのノイズは、ニューラルネットワーク正則化するうえで重要な役割を担っていると広く信じられています。

従って、バッチ正規化が可能にする高い学習率設定はSGDノイズを増加させる原因となり、正則化の効果と精度が向上すると考えられます。

勾配と活性化の大きさ

バッチ正規化の利点は、高い学習率を可能にすることによって発生することが判明しました。ここでは、バッチ正規化がなぜ高い学習率によって発生する大きな更新量での学習を可能にするかを検証します。

下図は、正規化されたネットワークの勾配と正規化されていないネットワークの勾配を示しています。正規化されていないネットワークの勾配は、正規化されたネットワークの勾配と比べて約2桁大きく、長いテールで分散していることが分かります。

  • 左:正規化あり
  • 右:正規化なし

f:id:jinbeizame007:20180924165834p:plain

これは、バッチ正規化によってチャンネルの平均と分散が正規化されることによる影響です。

下図は、層が深くなるにつれてチャンネルの平均と分散がどのように変化するかを示しています。y軸が対数スケールであることに注意してください。正規化されていないネットワークでは層が深くなるにつ入れて平均と分散が指数関数的に増加していますが、正規化されたネットワークでは層が深くなっても平均と分散が一定に保たれていることが分かります。

f:id:jinbeizame007:20180924170537p:plain

出力層における勾配

ネットワークの深さとともにチャンネルの平均値が増加することが分かりました。分類に対応する出力層では、平均の偏りはネットワークが予測するクラスが偏っていることを意味します。

下図は、出力層の各ユニットの勾配を示しています。正規化されていないネットワークでは、ミニバッチ内のほぼ全てのデータについて勾配がほぼ同じであり、勾配が入力依存ではないことが分かります。しかし、正規化されたネットワークは比較的勾配の入力に対する依存性が高いことが分かります。

  • 左:正規化なし
  • 右:正規化あり
  • x軸:各クラスに対応するユニット
  • y軸:ミニバッチ内のデータの番号

f:id:jinbeizame007:20180924171557p:plain

畳み込み層における勾配

同様の理由から、正規化されていないネットワークの畳み込み層の勾配が大きい理由を説明することが出来ます。最初の2つの次元が入出力チャンネルに対応し、後の2つが空間次元(x座標, y座標)に対応する畳み込み層の重み K(c_o, c_i, i, j)を考えます。 また、3×3の畳み込みを行うために、空間次元に S = {-1,0,1} × {-1, 0, 1} を付与します。ここで、畳み込み処理を以下のように表現することが出来ます。

\begin{align} OUTPUT(b,c,x,y) = \sum_{c'} \sum_{i, j∈S} INPUT(b,c', x+i, y+j)K(c,c',i,j) \end{align}

ここで、畳み込み層のパラメータK(c_o,c_i,i,j)の勾配は、以下の式によって与えられる。

\begin{align} \frac{∂L}{∂K(c_o,c_i,i,j)} = \sum_{b,x,y} d_{c_o c_i i j}^{d x y} \end{align}

\begin{align} d_{c_o c_i i j}^{d x y} = \frac{∂L}{∂OUTPUT(b,c_o,x,y)} INPUT(b,c_i,x+i,y+i) \end{align}

下の表は、勾配の絶対値の和・勾配の和の絶対値などをまとめた表です。ここでは、以下の要素が一致するかどうかを検証しています。

  •  \sum_{b x y} |d_{c_o c_i i j}^{b x y}|:勾配の絶対値の和
  •  \sum_b |\sum_{x y} d_{c_o c_i i j}^{b x y}|:勾配の空間次元の和の絶対値の和
  •  \sum_{x y} | \sum_b d_{c_o c_i i j}^{b x y}|:勾配のバッチの和の絶対値の和
  •  | \sum_{b x y} d_{c_o c_i i j}^{b x y}|:勾配の和の絶対値

f:id:jinbeizame007:20180924172929p:plain

また、正規化されていないネットワークでは、勾配の和の絶対値は絶対値の和にほぼ等しいが、正規化されたネットワークでは、和の絶対値と絶対値の和では10^2のスケールの差があることが分かります。

これらの結果は、正規化されていないネットワークでは、勾配は空間次元内とバッチ内の両方で同じ符号を持ち、入力依存でもなく空間次元にも依存しないことを示している。

勾配のスケールを超えるバッチ正規化

上の表から、バッチ正規化によって勾配が約2桁小さくなることが分かりました。しかし、バッチ正規化によって学習率を約4桁大きくすることが出来ることが分かっているため、バッチ正規化が大きな学習率の設定を可能にする理由を説明することは出来ていません。

正規化されていないネットワークでは、学習率を高く設定した場合、最初の数回のミニバッチで誤差が爆発します。ここで、発散の定義をミニバッチでの損失が10^3を超えた場合とします。実験では、誤差がその範囲に達した場合に、ネットワークは一度も元通りになりませんでした。

下図では、正規化されていないネットワークでの勾配の更新に伴うミニバッチの誤差の遷移が示されています。

f:id:jinbeizame007:20180924183950p:plain:w500

この図から、勾配の更新が進むにつれて誤差が発散することが分かります。

下図は、正規化されていないネットワークの勾配の更新によって、チャンネルの平均と分散がどのように遷移するかを示しています。この図から、出力層側の層の平均と分散が特に発散していることが分かります(Layer 44の平均、分散のスケールがそれぞれ1e9, 1e19であることに注意してください)。

f:id:jinbeizame007:20180924210319p:plain f:id:jinbeizame007:20180924210335p:plain

バッチ正規化は、各層の出力を正規化することによって出力の指数関数的な増加を阻害することは明らかです。これによって、大きな更新量での学習を行っても、ネットワークの出力が正規化されたパラメータ空間の領域に収まることが保証されます。これは、バッチ正規化がより高い学習率の設定を可能にする第2の主要なメカニズムだと推測されています。

さらなる観測

先ほどの表は、正規化されていないネットワークでは、勾配はバッチ内および空間次元内にわたって類似していること示唆しています。しかし、それらが入力/出力チャンネルc_i, c_oによってどのように変化するかは不明です。

下図は、45層目の各イン/アウトチャンネルに対応するパラメータの平均勾配の絶対値を示しています。正規化されていないネットワークでは一部のイン/アウトチャンネルは一貫して小さな勾配を持っており、パラメータの更新量が非常に小さいことが分かります。しかし、この現象はバッチ正規化によって抑えることが出来ます。

f:id:jinbeizame007:20180924190700p:plain

重みの初期化の影響

上図は、勾配や活性化のスケールに差があることを示しています。ニューラルネットワークの初期化手法は、ランダムに重みを設定した際に、チャンネルの分散が一定であるべきという考えから設計されています。そこで、最新の初期化手法の結果がどのようになるのかを検証します。

全結合層で構成される単純なニューラルネットワークを考えます。ここで、入力をx、出力をy、重み行列をA_iとしたとき、出力はy = A_t... A_2 A_1 xとなります。

**下図は、ランダムな正方行列と重み行列の積(y)の特異値の分布を示しています。

  • One matrix: y = A_1 x
  • Product of 5 matrix: y = A_5 ... A_2 A_1 x
  • Product of 25 matrix: y = A_{25} ... A_2 A_1 x

f:id:jinbeizame007:20180924192637p:plain

この図から、特異値の分布はより多くの行列が掛け合わされるにつれて、より長いテールを持つことが分かります。これは、特異値の最大の値と最小の値の比が、層が深くなるにつれて増加することを意味します。

これらの結果から、掛け合わせる行列数が増加することで 1)収束が遅くなる、2)小さい学習率が必要とされる、3)異なる部分空間内の勾配の比が増加、などの影響があることが考えられます。

バッチ正規化は、このような重みの初期化の影響を緩和します(実際に、2個前の図ではバッチ正規化によってすべての勾配が同様に増加・減少することが示されています)。