Что нужно сделать с learning rate при увеличении размера batch?

Ответ

При увеличении размера батча (batch size) learning rate (LR) обычно следует пропорционально увеличить.

Причина: Больший батч дает более точную (менее шумную) оценку градиента функции потерь. Поскольку шум в оценке градиента уменьшается, мы можем позволить себе делать более крупные и уверенные шаги в направлении антиградиента, не рискуя расходимостью из-за шума.

Эмпирическое правило (linear scaling rule): Новый LR = Базовый LR * (Новый размер батча / Базовый размер батча)

Практический пример на PyTorch:

import torch
import torch.nn as nn

# Исходные параметры
base_batch_size = 32
base_lr = 0.01

# Новые параметры
new_batch_size = 256
scaled_lr = base_lr * (new_batch_size / base_batch_size)  # 0.08

model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=scaled_lr)

Важные нюансы:

  1. Это правило хорошо работает для SGD и его модификаций (например, SGD с моментом) в широком диапазоне размеров батча.
  2. Для очень больших батчей (тысячи и более) часто требуется дополнительная техника warmup — плавное увеличение LR от малого значения до целевого в начале обучения, чтобы стабилизировать процесс.
  3. Это правило — хорошая отправная точка, но окончательный LR всегда требует валидации и тонкой настройки под конкретную задачу и архитектуру.

Ответ 18+ 🔞

Слушай, а вот это реально интересный момент, который многих новичков вгоняет в ступор. Сидят, блядь, с размером батча в 32, LR выставили 0.001, и всё вроде летает. А потом им приспичило ускориться, взяли батч в 256, а скорость обучения оставили прежней. И потом сидят, чешут репу: "Чё за хуйня? Модель нихрена не учится, или учится овердохуища медленно!"

А причина-то простая, как три копейки. Представь, что градиент — это такой совет, куда тебе идти, чтобы найти минимум потерь. Когда батч маленький, этот совет — от одного чувака. Он может быть прав, а может и нахуй послать, потому что его данные кривые. Это шумно, ёпта. И если ты на таком совете сделаешь огромный шаг, то запросто улетишь в другую сторону, в какую-нибудь локальную яму или вообще на хуй с горы.

А когда батч большой — это уже не один советчик, а целый комитет из 256 человек голосует, куда идти. Их коллективное мнение куда точнее и стабильнее. Шума — ноль ебать. И раз совет теперь надёжный, то почему бы не шагать увереннее и шире? Вот и выходит, что с ростом батча надо пропорционально увеличивать learning rate.

Есть даже такое эмпирическое правило, хитрая жопа, но работает:

Новый LR = Старый LR * (Новый размер батча / Старый размер батча)

Смотри, как это в коде выглядит, тут всё просто:

import torch
import torch.nn as nn

# Было у нас
base_batch_size = 32
base_lr = 0.01

# Стало
new_batch_size = 256
# Считаем по правилу: 0.01 * (256 / 32) = 0.08
scaled_lr = base_lr * (new_batch_size / base_batch_size)

model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=scaled_lr) # Вот тут уже ставим увеличенную скорость

Но, бля, есть важные нюансы, а то сейчас накосячите.

  1. Это правило — золотое для обычного SGD или SGD с моментом. Берёшь и масштабируешь, и в 80% случаев всё будет окей.
  2. Если ты совсем ебанулся и взял батч размером с небольшой город (тысячи примеров), то тут уже может быть пиздец. Слишком резкий старт с большой LR может всё развалить. Поэтому для монструозных батчей используют warmup — это когда ты начинаешь обучение с маленькой, почти нулевой скорости, и плавненько, за несколько эпох, наращиваешь её до расчётного значения. Так процесс стабилизируется.
  3. И главное, чувак: это правило — не священное писание, а отправная точка. Ты взял, посчитал, поставил LR=0.08. А модель всё равно ведёт себя как манда с ушами? Значит, надо ковырять дальше: может, 0.06, а может, 0.1. Всегда смотри на кривые обучения и лосс на валидации. Доверия к любой эвристике — ебать ноль, пока сам не проверишь на своих данных.

Короче, суть в том, что большой батч даёт точный градиент, а на точном градиенте можно смело жать на газ. Вот и вся магия.