- BrainTools - https://www.braintools.ru -

Архитектура важнее размера: внедряем каузальные свертки в трансформер и получаем связный сторителлинг

Дело было вечером, делать было нечего. Я сидел за ноутом и разбирал новую идею Deepseek Engram: Лян Ванфень собрал вместе хеш-таблицы и почти-линейный трансформер – получилось дешево и сердито.

Однако есть в Engram один недостаток – он требует много RAM (каламбурчик, хаха). А хотелось архитектуру, на инференс которой не придется скидываться всем поселком.

Небольшой ликбез

Engram, по сути, перешивает токены и добавляет к ним факты. Реализовано это довольно хитро, через хеш-функцию, O(1) по сложности. Благодаря такой пристройке трансформер уделяет больше внимания [1] на грамматику и связь слов в предложении.

Основная идея

А что если вместо дорогого по вычислениям Engram взять простые свертки? Они дешевые, быстрые и могут запомнить базовые факты.

Именно об этом я и подумал. И тут же решил проводить тесты.

К сожалению у меня нет в гараже кластера на 8xH200 (да и гаража у меня нет), поэтому обучить что-то большое не получится. Однако для быстрого эксперимента хватит Colab и его Т4 16Гб.

Архитектура модели

Слой

Слой

За пару минут набросал схему в Obsidian. Теперь про каждый блок отдельно

RMSNorm
Базовый слой нормализации, в современный трансформерах без него будет тяжко.

Conv1D
Ключевое нововведение. Depthwise и kernel = 3 обогащают токены и перемешивают их. Чтобы сетка не ‘поглядывала’ реализовал каузальные свертки.

Визуализация MQA

Визуализация MQA

MQA
Довольно быстрая и дешевая реализация классического Self-Attention, но все еще не линейная или реккурентная архитектура.

FFN + SwiGLU
Два главных компонента: новая функция активации и необычное расширение в линейном слоев – x8/3 на 3 слоя вместо устоявшегося х4 на 2 слоя (позволяет сохранить то же кол-во параметров при большем кол-ве операций).
Эта комбинация отлично показала себя в моделях Llama, где была применена впервые.

Все это решил обозвать NormIs-1. Логики в названии нет абсолютно никакой.

Меньше слов – больше кода

Не стал что-то менять в нормализации и сделал самую простую версию.

сlass RMSNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5
        self.g = nn.Parameter(torch.ones(dim))
    def forward(self, x):
        return F.normalize(x, dim=-1) * self.scale * self.g

Так-же сделал с FFN – просто и понятно

class SwiGLU(nn.Module):
    def __init__(self, dim):
        super().__init__()
        hidden_dim = int(dim * 4 * 2 / 3)
        self.w_gate = nn.Linear(dim, hidden_dim, bias=False)
        self.w_val = nn.Linear(dim, hidden_dim, bias=False)
        self.w_out = nn.Linear(hidden_dim, dim, bias=False)
    def forward(self, x):
        gate = F.silu(self.w_gate(x))
        val = self.w_val(x)
        return self.w_out(gate * val)

Наивная реализация сверток. Спойлер – простой forward() потом вышел мне боком из-за медленной памяти [2].

class CausalConv1D(nn.Module):
    def __init__(self, dim, kernel_size=3):
        super().__init__()
        self.pad = kernel_size - 1
        self.conv = nn.Conv1d(dim, dim, kernel_size, groups=dim)
    def forward(self, x):
        x = x.transpose(1, 2)
        x = F.pad(x, (self.pad, 0))
        x = self.conv(x)
        x = x.transpose(1, 2)
        return x

А вот и все ноутбуки с обучением [3] (ссылки на Colab):
Кастомная архитектура [4]
MHA + MQA [5]

Метрики

Один из самых важных вопросов – а как вообще оценить NormIs-1? С чем его сравнивать? Какие метрики измерять?

Введем двух дополнительных кандидатов – трансформер на MQA и на MHA без сверток.

MHA считается лучшим по качеству, но он-же медленнее всего. Это Topline
MQA – топ по скорости, но может терять в качестве. Это Baseline.

Архитектура слоев у двух дополнительных кандидатов

Архитектура слоев у двух дополнительных кандидатов

Метрики ‘интеллекта’ модели – Loss (Cross-Entropy) и Perplexity. Метрики скорости – время обучения и TPS (tokens per second).

Моя цель – усидеть на двух стульях: получить интеллект [6] уровня MHA, не потеряв при этом в скорости генерации MQA. Если NormIs-1 догонит Topline по качеству, оставшись таким же быстрым – это победа.

Сравнение

Чтобы эксперимент был честным, я зафиксировал все гиперпараметры. Изменялась только архитектура внутреннего блока.

Конфигурация:

  • Датасет: TinyStories. Идеален для микро-моделей: в нем простая лексика, но строгие требования к грамматике и логике [7].

  • Токенизатор: Свой собственный, обученный на 8К токенов. Это позволило не раздувать матрицу эмбеддингов и сфокусировать ‘мозги’ модели на смысле, а не на хранении словаря.

  • Геометрия: model_dim = 128context = 256. Компактно, но достаточно для коротких рассказов.

  • Обучение: steps = 5000batch = 64.

Итого на претрейн – 81.920.000токенов.

Запустил обучение и ушел пить чай. По моим расчетам каждая модель училась бы не более получаса.

прочитать с соответствующей интонацией=)

прочитать с соответствующей интонацией =)

И вот наступил момент Х, пора сравнивать.

Сравнение

MHA Topline

MQA Baseline

NormIs-1

Параметры

1.84M

1.75M

1.75M

Время обучения

24:04

24:15

25:03

Val. Perplexity

7.9

8.24

7.94

Val. Loss

2.0668

2.1095

2.0713

Tokens/sec

362

339

202

Качество довольно хорошее. NormIs остался на уровне MHA, имея меньше параметров.

Но вот скорость обучения и инференса выглядит печально. А все из-за наивной реализации сверток. Граф вычислений на PyTorch должен создавать новый CUDA Kernel для каждой свертки.

Из-за этого модель значительно медленнее при инференсе, а при обучении это не так заметно. Думаю, если написать нормальный движек, то NormIs получит свои 300т/с+

Вот ссылка [8] на папку – там графики падения лосса и примеры генерации модели.

Выводы и идеи

Результат хороший, но не прорывной. Дальше я хочу попробовать эту же конфигурации, но на большем масштабе (20М+ параметров) и на сложной задаче (например, Fineweb-Edu).

Спасибо что дочитали статью. Это мой первый опыт [9] написания подобных текстов.

Буду рад если получится дать фидбек на мои решения. Я в ML недавно, только учусь. Будет интересно послушать людей с опытом.

Автор: morginalium8

Источник [10]


Сайт-источник BrainTools: https://www.braintools.ru

Путь до страницы источника: https://www.braintools.ru/article/29702

URLs in this post:

[1] внимания: http://www.braintools.ru/article/7595

[2] памяти: http://www.braintools.ru/article/4140

[3] обучением: http://www.braintools.ru/article/5125

[4] Кастомная архитектура : https://colab.research.google.com/drive/1TC4ZOKJ_GbvxpS1lZU669b2q0lvIFeyM?usp=sharing

[5] MHA + MQA: https://colab.research.google.com/drive/1iSRfHcrV8CtmnvcmysuUb16BYI6svXMp?usp=sharing

[6] интеллект: http://www.braintools.ru/article/7605

[7] логике: http://www.braintools.ru/article/7640

[8] ссылка: https://drive.google.com/drive/folders/1UMNzLbMWHYBYfimec-RLhGzUYrqKmEL-?usp=sharing

[9] опыт: http://www.braintools.ru/article/6952

[10] Источник: https://habr.com/ru/articles/1030492/?utm_campaign=1030492&utm_source=habrahabr&utm_medium=rss

www.BrainTools.ru

Rambler's Top100