Что такое Batch Normalization и как она работает?

Ответ

Batch Normalization (BN) — это метод, который стабилизирует и ускоряет обучение глубоких нейронных сетей за счёт нормализации активаций каждого слоя по мини-батчу во время обучения.

Проблема, которую она решает: Internal Covariate Shift — изменение распределения входных данных для каждого последующего слоя в процессе обучения, что замедляет сходимость и требует осторожного подбора learning rate.

Алгоритм (на этапе обучения): Для каждого признака (канала/нейрона) в активациях слоя:

  1. Вычислить среднее и дисперсию по текущему мини-батчу. (muB = frac{1}{m} sum{i=1}^{m} x_i) (sigmaB^2 = frac{1}{m} sum{i=1}^{m} (x_i - mu_B)^2)
  2. Нормализовать данные: (hat{x}_i = frac{x_i - mu_B}{sqrt{sigma_B^2 + epsilon}})
  3. Масштабировать и сдвинуть с помощью обучаемых параметров (gamma) и (beta): (y_i = gamma hat{x}_i + beta). Это позволяет сети при необходимости "отключить" нормализацию.

На этапе инференса используются скользящие средние значений (mu) и (sigma), накопленные во время обучения.

Реализация в PyTorch:

import torch.nn as nn

# Для полносвязных слоев (после активации или перед? Чаще перед.)
model_fc = nn.Sequential(
    nn.Linear(784, 256),
    nn.BatchNorm1d(256),  # Нормализует по батчу для 256 признаков
    nn.ReLU(),
    nn.Linear(256, 10)
)

# Для сверточных слоев
model_conv = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3),
    nn.BatchNorm2d(64),  # Нормализует по батчу для 64 каналов
    nn.ReLU(),
    nn.MaxPool2d(2)
)

Преимущества:

  • Позволяет использовать более высокие learning rates.
  • Снижает чувствительность к инициализации весов.
  • В некоторой степени действует как регуляризатор (за счёт шума от статистик батча).

Недостатки/нюансы:

  • Поведение ухудшается при очень маленьких размерах батча (неточная оценка статистик).
  • Увеличивает потребление памяти.
  • В таких архитектурах, как RNN или Transformer, чаще используют Layer Normalization.

Ответ 18+ 🔞

А, слушай, вот эта штука — Batch Normalization. Ну, в общем, спасение для чуваков, которые нейронки тренируют. Представь себе: у тебя сеть глубокая, слоёв дохуя, и данные между ними гуляют как пьяные по коридору. Это и есть та самая проблема — Internal Covariate Shift, ёпта. То есть распределение на входе каждого нового слоя пляшет так, что терпения ебать ноль. Из-за этого learning rate надо подбирать с лупой и молитвой, а сходится всё медленнее, чем черепаха в сиропе.

А теперь сам фокус. Берём мы наш текущий мини-батч и для каждого канала или нейрона делаем три шага, которые проще, чем кажутся.

Шаг первый: Считаем, где наша стая собралась. Берём среднее (μ) и разброс (σ²) по всем примерам в батче. Формулы вот эти: μ_B = 1/m ∑ x_i и σ_B² = 1/m ∑ (x_i - μ_B)². Всё честно, по-пацански.

Шаг второй: Приводим всех к общему знаменателю. Каждое значение вычитанием среднего и делением на корень из дисперсии (плюс крошечный эпсилон, чтоб на ноль не делить) превращаем в этакого стандартизированного зэка. x̂_i = (x_i - μ_B) / √(σ_B² + ε). Теперь у нас данные — как солдаты на плацу, стоят ровненько.

Шаг третий, самый хитрый: А что, если нормализация — это перебор? Вот для этого и есть два обучаемых параметра — γ (гамма) и β (бета). Мы умножаем наш нормализованный x̂_i на γ и прибавляем β: y_i = γ * x̂_i + β. Гениальность в том, что если сети эта нормализация нахуй не сдалась, она может выучить γ ≈ стандартному отклонению, а β ≈ среднему, и всё вернётся как было! То есть сеть сама решает, насколько ей этот BN нужен. Удивление пиздец, правда?

А вот на инференсе — отдельная песня. Тут мы уже не по батчу считаем, а используем те скользящие средние μ и σ, которые тихо копили во время всей тренировки. Чтобы быстро работало.

Как в PyTorch это выглядит? Да элементарно, чувак.

import torch.nn as nn

# Для обычных плотных слоёв (BatchNorm1d)
model_fc = nn.Sequential(
    nn.Linear(784, 256),
    nn.BatchNorm1d(256),  # Нормализует по батчу для 256 признаков
    nn.ReLU(),
    nn.Linear(256, 10)
)

# Для свёрточных сетей (BatchNorm2d)
model_conv = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3),
    nn.BatchNorm2d(64),  # Нормализует по батчу для 64 каналов
    nn.ReLU(),
    nn.MaxPool2d(2)
)

Чем хорош? Да всем, ядрёна вошь!

  • Learning rate можно ставить побольше, не боясь, что всё разлетится к хуям.
  • Не так страшна кривая инициализация весов.
  • Немного регуляризует, потому что статистики батча каждый раз чуть разные — это как лёгкий шум, который мешает сети заучить данные наизусть.

Но и подводных камней овердохуища:

  • Маленький размер батча — это пиздец. Статистики считаются криво, и всё идёт по пизде. BN начинает хуже работать, чем вообще без него.
  • Памяти жрёт больше, это да.
  • Не для всех архитектур годится. В тех же RNN или Transformer'ах чаще используют Layer Normalization, потому что там своя, хитрая жопа, с последовательностями.

Короче, инструмент мощный, но не серебряная пуля. Без него сейчас ни одна уважающая себя свёрточная сеть не обходится, это факт.