- BrainTools - https://www.braintools.ru -

Тихий убийца Трансформеров: как Weight Decay уничтожает эмбеддинги и нормализацию

У каждого из нас есть “мышечная память” при написании кода обучения [1] нейросетей. Мы собираем архитектуру, а затем пишем примерно такую строчку, даже не задумываясь:

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.1)

Weight Decay (L2-регуляризация) это база. Мы знаем, что он тянет веса к нулю, не дает отдельным нейронам [2] “зазвездиться” и предотвращает переобучение. Для линейных слоев (W * X) это работает великолепно. Но Трансформер состоит не только из матриц W. В нем есть специфические слои, для которых Weight Decay это не лекарство от переобучения, а тихий убийца, который медленно разрушает геометрию латентного пространства и душит градиенты.

Давайте залезем под капот оптимизатора и посмотрим, как слепое применение Weight Decay уничтожает ваши эмбеддинги и слои нормализации.

Физика Weight Decay

Чтобы понять проблему, нужно вспомнить математику [3] AdamW. В отличие от градиента, который обновляет вес только если есть ошибка [4], Weight Decay применяется безусловно на каждом шаге оптимизатора:

Wnew​=Wold​−η⋅∇LηλWold

Где λ это наш weight_decay.

Физически это гравитация. На каждом шаге (на каждом батче) оптимизатор “откусывает” от каждого веса микроскопический процент его значения, независимо от того, что говорят данные.

А теперь посмотрим, что эта гравитация делает с разными частями сети.

Жертва №1: Эмбеддинги (Черная дыра для редких токенов)

Слой эмбеддингов (nn.Embedding) это огромная lookup-таблица (Словарь * Размерность).

Главное отличие эмбеддингов от линейных слоев разреженность обновлений.
Когда вы прогоняете батч текста, в нем участвуют, скажем, 2000 уникальных токенов. Градиент (∇L) вычисляется только для этих 2000 токенов. Для остальных 48 000 слов из вашего словаря градиент равен нулю.

Но оптимизатору AdamW всё равно! Вы передали ему model.parameters(), и он применяет правило Weight Decay ко всей матрице эмбеддингов.

Что происходит в реальности:
Представьте редкое слово, например, “Утконос”. Оно встретилось в первом батче, модель сдвинула его вектор в правильном направлении. Следующий раз слово “Утконос” встретится через 10 000 батчей.

Все эти 10 000 шагов градиент для “Утконоса” равен нулю. Но формула Weight Decay продолжает работать:

Wутконос​=Wутконос​−0−ηλWутконос​

Оптимизатор методично умножает вектор редкого слова на условные 0.999 десять тысяч раз подряд. К тому моменту, когда “Утконос” снова появится в тексте, его вектор схлопнется в ноль. Вся семантическая геометрия, которую модель выучила для редких слов, стирается в пыль.

Из-за глобального Weight Decay эмбеддинги редких токенов постоянно “засасывает” в центр координат, лишая модель способности понимать узкоспециализированный контекст.

Жертва №2: Слои нормализации (Удушение сигнала)

Современные архитектуры (LLaMA, Mistral, Gemma) используют RMSNorm. У этих слоев нет весов в классическом понимании. У них есть обучаемый параметр Scale (γ)

Зачем нужен Scale (γ)? Нормализация принудительно делает дисперсию сигнала равной единице. Но иногда следующему слою (например, функции активации) нужна другая амплитуда сигнала для корректной работы. Обучаемый параметр γ существует исключительно для того, чтобы сеть могла восстановить нужный масштаб дисперсии.

Что происходит, если мы применяем к γ Weight Decay? Мы буквально говорим оптимизатору: 

“Штрафуй сеть за большую амплитуду сигнала”.

Weight Decay постоянно тянет γ к нулю. Сеть пытается сделать сигнал громче, чтобы протолкнуть его через глубокие слои, а оптимизатор бьет ее по рукам и заставляет “говорить шепотом”. Это создает искусственное сопротивление в потоке градиентов. Вы заставляете сеть тратить драгоценную емкость оптимизатора на то, чтобы бороться с вашей же регуляризацией.

Как это лечить? (И почему об этом не пишут в туториалах)

В серьезных репозиториях вы никогда не найдете слепого

 optimizer = AdamW(model.parameters()).

Правильный инженерный подход – декаплинг (разделение) параметров. Мы должны применять Weight Decay только к многомерным матрицам весов (Linear, Conv), и отключать его для одномерных тензоров (Norm, Bias) и эмбеддингов.

На Pytorch это делается через создание групп параметров. Вот как выглядит здоровый код инициализации оптимизатора для Трансформера:

def configure_optimizers(model, weight_decay, learning_rate):
    # Разделяем параметры на те, что нужно "декеить", и те, что нет
    decay = set()
    no_decay = set()
    
    # Бежим по всем модулям сети
    for mn, m in model.named_modules():
        for pn, p in m.named_parameters():
            fpn = '%s.%s' % (mn, pn) if mn else pn # полный путь к параметру
            
            if pn.endswith('bias'):
                no_decay.add(fpn)
            elif pn.endswith('weight') and isinstance(m, (nn.Linear, nn.Conv2d)):
                decay.add(fpn)
            elif isinstance(m, (nn.LayerNorm, nn.Embedding, nn.RMSNorm)):
                no_decay.add(fpn)

    # Собираем группы для оптимизатора
    param_dict = {pn: p for pn, p in model.named_parameters()}
    optim_groups = [
        {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": weight_decay},
        {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
    ]
    
    optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate)
    return optimizer

Заключение

Индустрия ИИ сегодня страдает от избытка абстракций. Высокоуровневые фреймворки прячут от нас математику, позволяя обучать модели в три строчки кода. Но дьявол всегда кроется в деталях.

Один дефолтный параметр weight_decay=0.1, примененный не к тем тензорам, может стоить вам пары процентов точности, потери редких фактов в LLM или привести к загадочной нестабильности при масштабировании глубины сети.

Нейросети ленивы. Оптимизаторы слепы. Задача инженера направлять их, опираясь на геометрию и физику процесса, а не на дефолтные параметры из туториалов.

Удачного обучения, и берегите свои эмбеддинги!

Автор: YH7H22

Источник [5]


Сайт-источник BrainTools: https://www.braintools.ru

Путь до страницы источника: https://www.braintools.ru/article/30367

URLs in this post:

[1] обучения: http://www.braintools.ru/article/5125

[2] нейронам: http://www.braintools.ru/article/6020

[3] математику: http://www.braintools.ru/article/7620

[4] ошибка: http://www.braintools.ru/article/4192

[5] Источник: https://habr.com/ru/articles/1036146/?utm_source=habrahabr&utm_medium=rss&utm_campaign=1036146

www.BrainTools.ru

Rambler's Top100