Ответ
Да, я использовал дистилляцию для уменьшения размера и ускорения больших языковых моделей (например, BERT) с минимальной потерей качества. Это критически важно для деплоя моделей в production, где есть ограничения по latency и ресурсам.
Пример реализации дистилляции для задачи классификации текста:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# 1. Загрузка предобученных моделей
teacher_name = 'bert-large-uncased'
student_name = 'bert-base-uncased'
teacher = AutoModelForSequenceClassification.from_pretrained(teacher_name, num_labels=5)
student = AutoModelForSequenceClassification.from_pretrained(student_name, num_labels=5)
# 2. Определение кастомной функции потерь для дистилляции
def distillation_loss(student_logits, teacher_logits, labels, alpha=0.5, temperature=2.0):
"""
Комбинирует loss от hard labels и soft targets от учителя.
"""
# Soft loss: KL-дивергенция между сглаженными распределениями
soft_loss = nn.KLDivLoss(reduction='batchmean')(
F.log_softmax(student_logits / temperature, dim=-1),
F.softmax(teacher_logits / temperature, dim=-1)
) * (temperature ** 2)
# Hard loss: стандартный кросс-энтропийный loss для истинных меток
hard_loss = F.cross_entropy(student_logits, labels)
# Комбинированный loss
return alpha * soft_loss + (1.0 - alpha) * hard_loss
# 3. Обучение студента с учетом выходов учителя (псевдокод цикла обучения)
optimizer = torch.optim.AdamW(student.parameters(), lr=5e-5)
teacher.eval() # Учитель замораживается
for batch in dataloader:
input_ids, attention_mask, labels = batch
with torch.no_grad():
teacher_logits = teacher(input_ids, attention_mask=attention_mask).logits
student_logits = student(input_ids, attention_mask=attention_mask).logits
loss = distillation_loss(student_logits, teacher_logits, labels)
loss.backward()
optimizer.step()
Практический результат:
- Удалось уменьшить модель для анализа тональности от
bert-large(334M параметров) доdistilbert-base(66M параметров). Инференс ускорился примерно в 3 раза на CPU, при падении точности (F1-score) всего на ~0.5-1%. - Ключевые гиперпараметры: Температура (
temperature) — обычно от 2 до 5. Она "смягчает" распределение вероятностей учителя, делая его более информативным для студента. Коэффициентalphaбалансирует вклад soft и hard loss.
Ответ 18+ 🔞
А, ну это про то, как из здоровенной, неповоротливой модели сделать шуструю, чтоб в продакшене не тормозила, да? Ёпта, классическая история — берёшь какого-нибудь bert-large, который жрёт память и тормозит, как танк в болоте, и пытаешься запихнуть его на сервер, где ресурсов — кот наплакал. И тут на помощь приходит дистилляция, ебать мои старые костыли.
Смысл в чём? Берёшь большую, умную модель — «учителя». Она там всё знает, всё понимает, но она жирная. А потом берёшь маленькую, шуструю модель — «студента». И заставляешь студента не просто по учебнику учиться (по жёстким меткам), а подражать учителю, смотреть на его «мягкие» предсказания. Учитель-то не просто говорит «это класс 1», он говорит «на 85% это класс 1, но есть 10% на класс 2 и 5% на класс 3». Вот эта дополнительная информация — золото! Студент из неё учится гораздо лучше, чем из сухих цифр.
Вот смотри, как это в коде выглядит, тут всё прозрачно:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# 1. Загрузка предобученных моделей
teacher_name = 'bert-large-uncased'
student_name = 'bert-base-uncased'
teacher = AutoModelForSequenceClassification.from_pretrained(teacher_name, num_labels=5)
student = AutoModelForSequenceClassification.from_pretrained(student_name, num_labels=5)
# 2. Определение кастомной функции потерь для дистилляции
def distillation_loss(student_logits, teacher_logits, labels, alpha=0.5, temperature=2.0):
"""
Комбинирует loss от hard labels и soft targets от учителя.
"""
# Soft loss: KL-дивергенция между сглаженными распределениями
soft_loss = nn.KLDivLoss(reduction='batchmean')(
F.log_softmax(student_logits / temperature, dim=-1),
F.softmax(teacher_logits / temperature, dim=-1)
) * (temperature ** 2)
# Hard loss: стандартный кросс-энтропийный loss для истинных меток
hard_loss = F.cross_entropy(student_logits, labels)
# Комбинированный loss
return alpha * soft_loss + (1.0 - alpha) * hard_loss
# 3. Обучение студента с учетом выходов учителя (псевдокод цикла обучения)
optimizer = torch.optim.AdamW(student.parameters(), lr=5e-5)
teacher.eval() # Учитель замораживается
for batch in dataloader:
input_ids, attention_mask, labels = batch
with torch.no_grad():
teacher_logits = teacher(input_ids, attention_mask=attention_mask).logits
student_logits = student(input_ids, attention_mask=attention_mask).logits
loss = distillation_loss(student_logits, teacher_logits, labels)
loss.backward()
optimizer.step()
Видишь, в чём фокус? Учитель у нас в режиме eval, мы у него только логиты выпрашиваем. А студент учится на двух фронтах: и на настоящих ответах (hard_loss), и на сглаженных, умных догадках учителя (soft_loss). Параметр temperature — это вообще гениальная штука. Он как будто «подогревает» распределение вероятностей учителя, делает его более размазанным и информативным. Без него учитель мог бы просто тыкать в один класс с вероятностью 99%, и студенту было бы нечего с этого списать, кроме как «ага, этот класс». А так — получает целую карту знаний, ебать колотить.
А что на практике? Да огонь, просто пиzда рулю. У меня был случай: нужно было bert-large для анализа тональности запихнуть в контейнер, где памяти — овердохуища не было. Сделали дистилляцию в distilbert-base. Итог? Модель стала легче в 5 раз, а инференс на CPU ускорился раз в три. И самое главное — точность просела всего на какие-то жалкие полпроцента-процент по F1. Пользователь вообще разницы не заметил, а инфраструктурные ребята перестали материться, что у них сервис падает. Вот это я понимаю — результат.
Так что да, технология — хитрая жопа, но когда надо из жирного, умного «учителя» сделать шустрого «студента», который в продакшене не подведёт, то дистилляция — это один из самых рабочих вариантов. Главное — не переборщить с температурой и правильно альфу подобрать, а то студент или учителя тупо скопирует, или вообще в сторону загуляет.