用語集

バッチ・ノーマライゼーション

バッチ正規化でディープラーニングのパフォーマンスを向上!このテクニックがAIモデルの学習速度、安定性、精度をどのように向上させるかを学びましょう。

Ultralytics HUB で
を使ってYOLO モデルをシンプルにトレーニングする。

さらに詳しく

バッチ正規化は、学習プロセスを安定させ、ディープニューラルネットワークの学習を大幅に高速化するために、ディープラーニングで広く使用されているテクニックである。Sergey IoffeとChristian Szegedyが2015年の論文「Batch Normalization」で紹介した:Batch Normalization:Accelerating Deep Network Training by Reducing Internal Covariate Shift(バッチ正規化:内部共変量シフトの低減によるディープネットワーク学習の高速化)」で紹介され、内部共変量シフトとして知られる、ネットワークの深い層への入力の分布が学習中に変化する問題に対処する。ミニバッチごとに各層への入力を正規化することで、バッチ正規化は活性化値のより安定した分布を維持し、よりスムーズで高速な収束につながります。

バッチ正規化の仕組み

トレーニング中、バッチ正規化は各ミニバッチのレイヤへの入力を標準化します。これは、ミニバッチ全体の活性化の平均と分散を計算し、これらの活性化を正規化します。重要なのは、このテクニックは、活性化チャンネルごとに学習可能な2つのパラメータ(スケール(ガンマ)パラメータとシフト(ベータ)パラメータ)を導入することである。これらのパラメータにより、ネットワークは正規化された入力の最適なスケールと平均を学習することができる。このプロセスは、活性を合理的な範囲に保つことで、消失勾配や 爆発勾配のような問題に対処するのに役立つ。推論中、平均と分散は、通常、学習中に推定された母集団の統計量を使用して固定されます。

バッチ正規化のメリット

ニューラルネットワークにバッチ正規化を適用すると、いくつかの重要な利点がある:

  • トレーニングの高速化:多くの場合、学習率が大幅に向上するため、トレーニングプロセスの収束が早まります。その他の最適化戦略については、モデルトレーニングのヒントを参照してください。
  • 改善された勾配フロー:活性化分布を安定化させることで、勾配が消失したり爆発したりする問題を緩和し、特に非常に深いネットワークにおいて、より安定した学習につながります。
  • 正規化効果:バッチ正規化は、ミニバッチ統計により、レイヤ入力にわずかなノイズ成分を加える。これは一種の正則化として機能し、ドロップアウトのような他のテクニックの必要性を減らす可能性がある。
  • 初期化に対する感度の低減:バッチ正規化を使用したネットワークは、多くの場合、学習開始前に選択された初期重みの影響を受けにくい。
  • より深いネットワークを可能にする:ディープアーキテクチャのトレーニングに関連する問題に対処することで、より深いモデルのトレーニングを成功に導く。

応用と実例

バッチ正規化は、特にコンピュータ・ビジョンにおいて、多くの最先端のディープラーニング・モデルの定番コンポーネントである。

  1. 画像認識と物体検出: 畳み込みニューラルネットワーク(CNN)では、バッチ正規化は通常、畳み込み層の後、活性化関数ReLUなど)の前に適用される。ResNetのようなモデルはこれに大きく依存している。物体検出モデルでは Ultralytics YOLOバッチ正規化は、学習の安定化、精度の向上、収束のスピードアップに役立ち、COCOのような複雑なデータセットでの効果的な検出を可能にします。クロスミニバッチ正規化(CmBN)のようなバリエーションは、YOLOv4のようなモデルで使用され、パフォーマンスをさらに向上させました。
  2. 生成的逆数ネットワーク(GAN):バッチ正規化は、敵対的学習プロセスを安定させるために、GANの生成器ネットワークと識別器ネットワークでしばしば使用されるが、アーチファクトを避けるためには慎重な実装が必要である。これはモード崩壊を防ぎ、よりスムーズな学習ダイナミクスを保証するのに役立ちます。

関連概念とバリエーション

バッチ正規化は広く使用されているが、関連する正規化技術がいくつか存在し、それぞれが異なるシナリオに適している:

考察と実装

バッチ正規化の重要な考慮点は、学習中のミニバッチサイズに依存することです。バッチサイズが小さすぎる場合(例えば、1や2)、バッチ統計量が母集団統計量のノイズの多い推定値になるため、性能が低下する可能性があります。さらに、訓練(バッチ統計量を使用)と推論(推定母集団統計量を使用)で動作が異なる。標準的なディープラーニングフレームワーク PyTorch (torch.nn.BatchNorm2dそして TensorFlow (tf.keras.layers.BatchNormalization)がロバストな実装を提供している。代替手段があるにもかかわらず、バッチ正規化は、多くの最新のディープラーニングモデルを効果的にトレーニングするための基本的なテクニックであり続けている。

すべて読む