Ответ
Стандартная бинарная кросс-энтропия (BCE) не всегда оптимальна, особенно при дисбалансе классов или наличии лёгких примеров. Вот ключевые модификации и альтернативы:
-
Взвешенная BCE (Weighted BCE): Присваивает больший вес классу с меньшим количеством примеров.
import torch.nn as nn # pos_weight = (количество негативных примеров) / (количество позитивных примеров) criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([5.0])) -
Focal Loss: Предложена для детекции объектов. Динамически уменьшает вес для лёгких примеров (хорошо классифицируемых), фокусируя обучение на сложных.
def focal_loss(logits, targets, alpha=0.25, gamma=2.0): bce_loss = nn.functional.binary_cross_entropy_with_logits(logits, targets, reduction='none') pt = torch.exp(-bce_loss) # p_t из статьи focal_weight = alpha * (1 - pt) ** gamma loss = focal_weight * bce_loss return loss.mean() -
Dice Loss / F1 Loss: Прямо оптимизирует метрику, похожую на Dice coefficient (F1-score). Особенно полезна для задач сегментации медицинских изображений, где фон доминирует.
def dice_loss(pred, target, smooth=1e-6): pred = torch.sigmoid(pred) intersection = (pred * target).sum() dice = (2. * intersection + smooth) / (pred.sum() + target.sum() + smooth) return 1 - dice -
Label Smoothing: Заменяет жёсткие метки (0 и 1) на «мягкие» (например, 0.1 и 0.9). Это работает как регуляризатор, предотвращая излишнюю уверенность модели.
-
Lovász Hinge Loss: Дифференцируемый суррогат для оптимизации метрики IoU (Intersection over Union) в задачах бинарной сегментации.
Когда что использовать:
- Дисбаланс классов (1:100): Focal Loss или взвешенная BCE.
- Сегментация изображений: Dice Loss + BCE (комбинированная).
- Переобучение / излишняя уверенность: Label Smoothing.
- Требуется максимизация IoU: Lovász Loss.
Ответ 18+ 🔞
А, ну это же про наши любимые лоссы, да? Сидишь такой, пишешь модель, а она тебе выдает предсказания хуй пойми какие. И думаешь — бля, ну вроде BCE втулил, а нихуя не работает как надо. Так, слушай сюда, сейчас разжую.
Вот смотри, обычная бинарная кросс-энтропия — она как тупой молоток. Иногда бьёт точно, а иногда по пальцам, особенно если данные кривые. Например, если у тебя классы разбалансированы пиздец как — на один позитивный пример тысяча негативных. Модель быстро смекает, что проще всегда предсказывать ноль, и будет тебе счастливо минимизировать лосс, а на деле-то нихуя не решает задачу. Охуенно, да?
1. Взвешенная BCE (Weighted BCE)
Первая мысль — давай накрутим весов. Если позитивных примеров мало, давай за каждый такой косяк модели будем бить её в три раза сильнее. В PyTorch это pos_weight. Берёшь, считаешь, сколько у тебя негативов и позитивов, и втуливаешь.
import torch.nn as nn
# pos_weight = (количество негативных примеров) / (количество позитивных примеров)
criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([5.0]))
Работает часто, но это как костыль — помогает идти, но бегать не заставит. Иногда хватает.
2. Focal Loss А вот это уже поинтереснее. Её придумали ребята из Facebook для детекции объектов. Суть в чём? А в том, что модель быстро находит кучу простых примеров (например, чистый фон), на них отлично учится, и её нихуя не волнуют те несколько сложных пикселей, где сидит объект. Focal Loss тупо говорит: «Э, дружок, на простых примерах ты и так охуенен, давай меньше на них обращай внимания. А вот на этих, сложных, где ты нихуя не уверен, — давай поднажмём».
def focal_loss(logits, targets, alpha=0.25, gamma=2.0):
bce_loss = nn.functional.binary_cross_entropy_with_logits(logits, targets, reduction='none')
pt = torch.exp(-bce_loss) # p_t из статьи
focal_weight = alpha * (1 - pt) ** gamma
loss = focal_weight * bce_loss
return loss.mean()
Параметр gamma — это твоя «жестокость». Чем больше, тем сильнее модель будет забивать хуй на лёгкие примеры и париться над сложными. При дисбалансе — то, что доктор прописал.
3. Dice Loss / F1 Loss А это вообще отдельная песня, особенно для сегментации, например, медицинских снимков. Представь: тебе надо найти опухоль размером с горошину на снимке лёгких размером с Техас. Фон — овердохуища, объект — хуй с горы. BCE будет сходить с ума по фону, а на опухоль положит болт. Dice Loss же тупо пытается оптимизировать метрику, похожую на F1-score. Она считает, насколько предсказанная маска и истинная маска «пересекаются».
def dice_loss(pred, target, smooth=1e-6):
pred = torch.sigmoid(pred)
intersection = (pred * target).sum()
dice = (2. * intersection + smooth) / (pred.sum() + target.sum() + smooth)
return 1 - dice
Часто её используют вместе с BCE, типа Dice Loss + 0.5 * BCE. Так и пишут в статьях — «мы использовали комбинированный лосс». Звучит солидно, и иногда реально работает лучше.
4. Label Smoothing Это такой хитрожопый приём. Вместо того чтобы говорить модели «вот это точно 1, а вот это точно 0», мы говорим «вот это скорее всего 1 (0.9), а вот это скорее всего 0 (0.1)». Зачем? Чтобы модель не становилась слишком самоуверенной, сука. А то выучит датасет наизусть, а на новых данных обосрётся. Это как регуляризатор, не дающий модели зазубривать.
5. Lovász Hinge Loss Это уже для самых отчаянных, кто хочет оптимизировать прямо IoU (Intersection over Union) — главную метрику в сегментации. Сама IoU не дифференцируема, а Lovász Loss — её дифференцируемый суррогат. Если простыми словами — это способ заставить модель думать не о том, насколько каждый пиксель похож на правду, а о том, насколько вся предсказанная область в целом совпадает с истинной. Сложная, но мощная штука.
Короче, когда что брать:
- Дисбаланс пиздец какой (1:100): Первым делом смотри на Focal Loss или взвешенную BCE.
- Сегментация медицинская (ищем горошину в стоге сена): Твой выбор — Dice Loss, часто в паре с BCE.
- Модель слишком уверена в себе, переобучается: Попробуй Label Smoothing, охлади её пыл.
- Нужно выжать максимум из IoU, и плевать на всё: Разбирайся с Lovász Loss, если не боишься немного математики.
Пробуй, эксперементируй, и помни — нет серебряной пули. Иногда простая взвешенная BCE выигрывает у всех этих навороченных штук. Волнение ебать, но такова жизнь.