Что такое архитектура U-Net?

Ответ

U-Net — это архитектура сверточной нейронной сети (CNN), изначально разработанная для биомедицинской семантической сегментации изображений. Ее название и характерная форма происходят от U-образной симметричной структуры, состоящей из пути сжатия (энкодера) и пути расширения (декодера).

Ключевые компоненты архитектуры:

  1. Энкодер (левая часть «U»):

    • Состоит из повторяющихся блоков, каждый из которых содержит две свертки 3x3 с активацией ReLU и последующей операцией макс-пулинга 2x2.
    • Цель: Извлечь контекстные, высокоуровневые признаки из изображения, постепенно уменьшая его пространственные размеры (ширину и высоту) и увеличивая глубину (количество карт признаков).
  2. Декодер (правая часть «U»):

    • Состоит из симметричных блоков, каждый из которых начинается с операции транспонированной свертки (transposed convolution) или upsampling для увеличения пространственных размеров.
    • За этим следуют две свертки 3x3 с ReLU.
    • Цель: Восстановить пространственную информацию для точного позиционирования границ объектов.
  3. Skip-connections (пропускные соединения):

    • Самая важная особенность U-Net.
    • Карты признаков с каждого уровня энкодера копируются и конкатенируются с соответствующими картами признаков на уровне декодера.
    • Зачем это нужно: При пулинге в энкодере теряется детальная пространственная информация (где именно находится объект). Skip-connections передают эти высокочастотные детали напрямую в декодер, что позволяет совместить «что» (контекст из энкодера) и «где» (детали из skip-connection) для точной локализации.

Упрощенная схема на PyTorch:

import torch
import torch.nn as nn
import torch.nn.functional as F

class DoubleConv(nn.Module):
    """Блок из двух сверток 3x3 с ReLU."""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
    def forward(self, x):
        return self.double_conv(x)

class UNet(nn.Module):
    def __init__(self, n_channels, n_classes):
        super().__init__()
        # Энкодер
        self.enc1 = DoubleConv(n_channels, 64)
        self.enc2 = DoubleConv(64, 128)
        self.pool = nn.MaxPool2d(2)
        # ... и т.д.
        # Декодер
        self.upconv4 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec4 = DoubleConv(512, 256) # 256 из upconv + 256 из skip = 512 вход
        # ... и т.д.
        self.out_conv = nn.Conv2d(64, n_classes, kernel_size=1)

    def forward(self, x):
        # Прямой проход энкодера с сохранением признаков для skip-connections
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        # ...
        # Прямой проход декодера с конкатенацией
        d4 = self.upconv4(bottleneck)
        d4 = torch.cat([d4, e3], dim=1) # SKIP-CONNECTION здесь!
        d4 = self.dec4(d4)
        # ...
        return self.out_conv(d1)

Области применения: Помимо медицинской визуализации (сегментация клеток, органов на МРТ/КТ), U-Net успешно применяется для сегментации в автономном вождении, спутниковых снимках, восстановления изображений и colorization.

Ответ 18+ 🔞

А, слушай, вот эта U-Net — ну просто ёперный театр, честное слово. Представь себе такую хуйню: есть у тебя картинка, а тебе надо на ней каждый пиксель пометить, вот это — кошка, это — собака, а это — трава, на которую они срать ходят. Так вот, U-Net для этого и придумали, изначально для всяких медицинских штук, типа найти на снимке все раковые клетки, чтобы врач не охуел от усталости, разглядывая.

Архитектура у неё, блядь, проще некуда, если вникнуть. Форма как буква «U», отсюда и название, гениально, да? Состоит из двух путей: один сжимает, другой — расширяет. Как будто ты надеваешь старые джинсы после праздников — сначала впихиваешь всё это добро (энкодер), а потом пытаешься привести в божеский вид (декодер).

Вот как эта штука работает, по косточкам:

  1. Энкодер (левая нога «U»):

    • Это как твой друг, который всё упрощает до состояния «нормально всё». Берёт картинку и начинает её прогонять через кучу слоёв. Две свёртки 3x3, активация ReLU (чтоб не засыпала), потом макс-пулинг 2x2 — и пошло-поехало.
    • Цель: выжать из изображения всю суть, высокоуровневые признаки. Картинка становится меньше по размеру (ширина, высота), но зато глубина (количество этих самых признаков) растёт, как долг у студента. В итоге получается такая сжатая, концентрированная суть изображения — «понял, что на картинке, но где именно — хуй его знает».
  2. Декодер (правая нога «U»):

    • А это уже обратный процесс. Начинаем из сжатой абстракции лепить обратно детальную картинку. Для этого используем транспонированную свёртку или апсемплинг — это как надувать спущенный мячик, увеличиваем размер обратно.
    • Цель: восстановить пространственные размеры и наконец-то понять, где же эта чёртова кошка сидит на картинке.
  3. Skip-connections (пропускные соединения):

    • А вот это, блядь, самый сок! Важнейшая фишка U-Net, без неё нихуя не работает.
    • Пока энкодер всё сжимал, детальная информация о том, где что находится, терялась. Это как если б тебе друг пересказал фильм: «Ну там мужик, он всех победил, круто было». А кто мужик? Где? Какие детали? Хуй знает.
    • Так вот, skip-connections — это спасительная соломинка. Мы на каждом этапе сжатия копируем ещё не до конца растёртые карты признаков и, когда декодер начинает разворачивать, подкидываем их ему прямо в объятия, конкатенируем.
    • Зачем? Чтобы декодер, который уже понял «что» (кошка), получил от энкодера напоминание «где» (вот на этом диване, сука, шерсть вся). Без этого сегментация была бы пиздопроебибна — объекты расплывчатые, границы как у пьяного.

Вот тебе упрощённый код на PyTorch, чтобы вообще всё стало ясно, как божий день:

import torch
import torch.nn as nn
import torch.nn.functional as F

class DoubleConv(nn.Module):
    """Блок из двух свёрток 3x3 с ReLU. Просто дважды тыкаем в картинку."""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
    def forward(self, x):
        return self.double_conv(x)

class UNet(nn.Module):
    def __init__(self, n_channels, n_classes):
        super().__init__()
        # Энкодер — сжимаем, блядь, всё что можно
        self.enc1 = DoubleConv(n_channels, 64)
        self.enc2 = DoubleConv(64, 128)
        self.pool = nn.MaxPool2d(2)
        # ... и так далее, глубже и глубже
        # Декодер — разворачиваем обратно
        self.upconv4 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec4 = DoubleConv(512, 256) # 256 из upconv + 256 из skip = 512 вход
        # ... и так далее, к поверхности
        self.out_conv = nn.Conv2d(64, n_classes, kernel_size=1)

    def forward(self, x):
        # Гоняем через энкодер, припрятывая фичи для будущего
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        # ...
        # А теперь декодер, и в самый нужный момент — хоба! — подкидываем skip-connection
        d4 = self.upconv4(bottleneck)
        d4 = torch.cat([d4, e3], dim=1) # Вот он, мать его, SKIP-CONNECTION! Соединяем!
        d4 = self.dec4(d4)
        # ...
        return self.out_conv(d1)

Где эту U-Net применяют? Да везде, где нужно тыкать пальцем в пиксели! Не только в медицине. Автономное вождение — чтобы отделить дорогу от обочины, а пешехода от столба. Спутниковые снимки — чтобы посчитать, сколько у соседа построили нелегальных гаражей. Даже старые чёрно-белые фото в цвет раскрасить. Штука, в общем, универсальная, хоть и простая как валенок. Главное — не забыть про эти самые skip-connections, а то получится ерунда полная.