Как бороться с градиентным затуханием в глубоких нейронных сетях?

«Как бороться с градиентным затуханием в глубоких нейронных сетях?» — вопрос из категории Нейронные сети и Deep Learning, который задают на 26% собеседований Data Scientist / ML Инженер. Ниже — развёрнутый ответ с разбором ключевых моментов.

Ответ

Градиентное затухание возникает, когда градиенты становятся чрезвычайно малыми при обратном распространении через многие слои, что останавливает обучение ранних слоев. Для борьбы с этим я применяю комбинацию методов.

1. Правильная инициализация весов: Использую инициализацию, учитывающую нелинейность активации (Xavier/Glorot для tanh, He/Kaiming для ReLU).

import torch.nn as nn
# Для слоя с ReLU
nn.init.kaiming_normal_(layer.weight, mode='fan_in', nonlinearity='relu')

2. Функции активации: Предпочитаю ReLU и его вариации (LeakyReLU, PReLU, ELU) вместо насыщающихся функций (сигмоида, tanh), так как их производная не стремится к нулю при больших значениях.

3. Нормализация: Batch Normalization стабилизирует распределение активаций в каждом слое, что позволяет использовать более высокие learning rates и смягчает проблему затухания.

model = nn.Sequential(
    nn.Linear(784, 256),
    nn.BatchNorm1d(256),  # Нормализует активации
    nn.ReLU(),
    nn.Linear(256, 10)
)

4. Архитектурные решения: Использую skip-connections (остаточные связи) из архитектур типа ResNet. Они создают путь для градиента, позволяя ему "перепрыгивать" через слои.

class ResidualBlock(nn.Module):
    def __init__(self, in_features):
        super().__init__()
        self.linear = nn.Linear(in_features, in_features)
        self.bn = nn.BatchNorm1d(in_features)
    def forward(self, x):
        residual = x
        out = self.linear(x)
        out = self.bn(out)
        out = nn.ReLU()(out)
        return out + residual  # Прямой пропуск градиента

5. Оптимизаторы и планирование скорости обучения: Адаптивные оптимизаторы (Adam, RMSprop) могут частично компенсировать малые градиенты. Также помогает планировщик learning rate (например, ReduceLROnPlateau), который снижает LR при застое.

На практике, комбинация BatchNorm, ReLU и остаточных связей решает проблему затухания для сетей в сотни слоев.