- BrainTools - https://www.braintools.ru -
Дело было вечером, делать было нечего. Я сидел за ноутом и разбирал новую идею Deepseek Engram: Лян Ванфень собрал вместе хеш-таблицы и почти-линейный трансформер – получилось дешево и сердито.
Однако есть в Engram один недостаток – он требует много RAM (каламбурчик, хаха). А хотелось архитектуру, на инференс которой не придется скидываться всем поселком.
Engram, по сути, перешивает токены и добавляет к ним факты. Реализовано это довольно хитро, через хеш-функцию, O(1) по сложности. Благодаря такой пристройке трансформер уделяет больше внимания [1] на грамматику и связь слов в предложении.
А что если вместо дорогого по вычислениям Engram взять простые свертки? Они дешевые, быстрые и могут запомнить базовые факты.
Именно об этом я и подумал. И тут же решил проводить тесты.
К сожалению у меня нет в гараже кластера на 8xH200 (да и гаража у меня нет), поэтому обучить что-то большое не получится. Однако для быстрого эксперимента хватит Colab и его Т4 16Гб.
За пару минут набросал схему в Obsidian. Теперь про каждый блок отдельно
RMSNorm
Базовый слой нормализации, в современный трансформерах без него будет тяжко.
Conv1D
Ключевое нововведение. Depthwise и kernel = 3 обогащают токены и перемешивают их. Чтобы сетка не ‘поглядывала’ реализовал каузальные свертки.
MQA
Довольно быстрая и дешевая реализация классического Self-Attention, но все еще не линейная или реккурентная архитектура.
FFN + SwiGLU
Два главных компонента: новая функция активации и необычное расширение в линейном слоев – x8/3 на 3 слоя вместо устоявшегося х4 на 2 слоя (позволяет сохранить то же кол-во параметров при большем кол-ве операций).
Эта комбинация отлично показала себя в моделях Llama, где была применена впервые.
Все это решил обозвать NormIs-1. Логики в названии нет абсолютно никакой.
Не стал что-то менять в нормализации и сделал самую простую версию.
сlass RMSNorm(nn.Module):
def __init__(self, dim):
super().__init__()
self.scale = dim ** 0.5
self.g = nn.Parameter(torch.ones(dim))
def forward(self, x):
return F.normalize(x, dim=-1) * self.scale * self.g
Так-же сделал с FFN – просто и понятно
class SwiGLU(nn.Module):
def __init__(self, dim):
super().__init__()
hidden_dim = int(dim * 4 * 2 / 3)
self.w_gate = nn.Linear(dim, hidden_dim, bias=False)
self.w_val = nn.Linear(dim, hidden_dim, bias=False)
self.w_out = nn.Linear(hidden_dim, dim, bias=False)
def forward(self, x):
gate = F.silu(self.w_gate(x))
val = self.w_val(x)
return self.w_out(gate * val)
Наивная реализация сверток. Спойлер – простой forward() потом вышел мне боком из-за медленной памяти [2].
class CausalConv1D(nn.Module):
def __init__(self, dim, kernel_size=3):
super().__init__()
self.pad = kernel_size - 1
self.conv = nn.Conv1d(dim, dim, kernel_size, groups=dim)
def forward(self, x):
x = x.transpose(1, 2)
x = F.pad(x, (self.pad, 0))
x = self.conv(x)
x = x.transpose(1, 2)
return x
А вот и все ноутбуки с обучением [3] (ссылки на Colab):
Кастомная архитектура [4]
MHA + MQA [5]
Один из самых важных вопросов – а как вообще оценить NormIs-1? С чем его сравнивать? Какие метрики измерять?
Введем двух дополнительных кандидатов – трансформер на MQA и на MHA без сверток.
MHA считается лучшим по качеству, но он-же медленнее всего. Это Topline
MQA – топ по скорости, но может терять в качестве. Это Baseline.
Метрики ‘интеллекта’ модели – Loss (Cross-Entropy) и Perplexity. Метрики скорости – время обучения и TPS (tokens per second).
Моя цель – усидеть на двух стульях: получить интеллект [6] уровня MHA, не потеряв при этом в скорости генерации MQA. Если NormIs-1 догонит Topline по качеству, оставшись таким же быстрым – это победа.
Чтобы эксперимент был честным, я зафиксировал все гиперпараметры. Изменялась только архитектура внутреннего блока.
Конфигурация:
Датасет: TinyStories. Идеален для микро-моделей: в нем простая лексика, но строгие требования к грамматике и логике [7].
Токенизатор: Свой собственный, обученный на 8К токенов. Это позволило не раздувать матрицу эмбеддингов и сфокусировать ‘мозги’ модели на смысле, а не на хранении словаря.
Геометрия: model_dim = 128, context = 256. Компактно, но достаточно для коротких рассказов.
Обучение: steps = 5000, batch = 64.
Итого на претрейн – токенов.
Запустил обучение и ушел пить чай. По моим расчетам каждая модель училась бы не более получаса.
И вот наступил момент Х, пора сравнивать.
|
Сравнение |
MHA Topline |
MQA Baseline |
NormIs-1 |
|---|---|---|---|
|
Параметры |
1.84M |
1.75M |
1.75M |
|
Время обучения |
24:04 |
24:15 |
25:03 |
|
Val. Perplexity |
7.9 |
8.24 |
7.94 |
|
Val. Loss |
2.0668 |
2.1095 |
2.0713 |
|
Tokens/sec |
362 |
339 |
202 |
Качество довольно хорошее. NormIs остался на уровне MHA, имея меньше параметров.
Но вот скорость обучения и инференса выглядит печально. А все из-за наивной реализации сверток. Граф вычислений на PyTorch должен создавать новый CUDA Kernel для каждой свертки.
Из-за этого модель значительно медленнее при инференсе, а при обучении это не так заметно. Думаю, если написать нормальный движек, то NormIs получит свои 300т/с+
Вот ссылка [8] на папку – там графики падения лосса и примеры генерации модели.
Результат хороший, но не прорывной. Дальше я хочу попробовать эту же конфигурации, но на большем масштабе (20М+ параметров) и на сложной задаче (например, Fineweb-Edu).
Спасибо что дочитали статью. Это мой первый опыт [9] написания подобных текстов.
Буду рад если получится дать фидбек на мои решения. Я в ML недавно, только учусь. Будет интересно послушать людей с опытом.
Автор: morginalium8
Источник [10]
Сайт-источник BrainTools: https://www.braintools.ru
Путь до страницы источника: https://www.braintools.ru/article/29702
URLs in this post:
[1] внимания: http://www.braintools.ru/article/7595
[2] памяти: http://www.braintools.ru/article/4140
[3] обучением: http://www.braintools.ru/article/5125
[4] Кастомная архитектура : https://colab.research.google.com/drive/1TC4ZOKJ_GbvxpS1lZU669b2q0lvIFeyM?usp=sharing
[5] MHA + MQA: https://colab.research.google.com/drive/1iSRfHcrV8CtmnvcmysuUb16BYI6svXMp?usp=sharing
[6] интеллект: http://www.braintools.ru/article/7605
[7] логике: http://www.braintools.ru/article/7640
[8] ссылка: https://drive.google.com/drive/folders/1UMNzLbMWHYBYfimec-RLhGzUYrqKmEL-?usp=sharing
[9] опыт: http://www.braintools.ru/article/6952
[10] Источник: https://habr.com/ru/articles/1030492/?utm_campaign=1030492&utm_source=habrahabr&utm_medium=rss
Нажмите здесь для печати.