- BrainTools - https://www.braintools.ru -

GFusion: как мы обучали диффузионную LLM в GigaChat

Салют, Хабр!

Хочу поделиться проектом, которым я занимался во время стажировки в команде GigaChat Pretrain. В течение нескольких месяцев мы исследовали диффузионные языковые модели (dLLM) — относительно новое направление в LLM, в котором многие идеи только начинают проверяться на практике.

Главной целью было не тратить огромное количество ресурсов на обучение [1] с нуля, а взять базовую авторегрессионную модель GigaChat3-10B-A1.8B-base и перевести её в диффузионный режим. Так появились наши экспериментальные GFusion-10B-A1.8B-base [2] и GFusion-10B-A1.8B [3]!

GFusion: как мы обучали диффузионную LLM в GigaChat - 1

TLDR

Прежде чем углубляться в детали, отмечу главные результаты проекта:

  • Ускорение генерации. В режиме одного пользователя GFusion в среднем на 70% быстрее своего брата GigaChat3, и даже превосходит вариант с дополнительной MTP-головой на 39%.

  • Способности модели. GFusion достигает ускорения при просадке качества всего на 2-4 п. п. в сравнении с GigaChat3, причем этот трейд-офф можно контролировать через параметры генерации диффузии.

  • Open-source. Мы поддержали GFusion в SGLang, а также добавили новый алгоритм семплирования, который ускоряет генерацию других диффузионных LLM.

  • Ускорение обучения. Оптимизировали attention для текстовой диффузии, что позволило нам получить +60% к средней скорости обучения GFusion по сравнению с Flex-Attention.

Сравнение скорости генерации GFusion и вариантов GigaChat3

Сравнение скорости генерации GFusion и вариантов GigaChat3

Далее я подробнее расскажу о том, как мы проходили полный цикл обучения, какие подходы к обучению dLLM сработали лучше, откуда берётся ускорение и с какими ограничениями мы столкнулись по дороге.


Почему диффузия?

Большинство современных LLM являются авторегрессионными (AR): каждый следующий токен предсказывается по предыдущему контексту. Эта схема отлично оптимизирована на практике, но у неё есть фундаментальное ограничение: строго последовательная генерация по одному токену.

Часть этого ограничения умеют обходить через спекулятивный декодинг, когда более мелкая модель предлагает несколько последовательных предсказаний, а основная верифицирует их за один проход. Это позволяет получить ускорение, но основная модель в любом случае остаётся авторегрессионной.

Диффузионные модели ослабляют это ограничение: они работают не с одним следующим токеном, а с частично замаскированным блоком фиксированного размера. Модель итеративно восстанавливает маски, постепенно превращая блок в осмысленный текст. Токены внутри одного блока могут декодироваться не только слева направо, а в произвольном порядке и по несколько за один проход. При этом сами блоки генерируются авторегрессионно, что позволяет переиспользовать KV-кеш для уже готового контекста.

Процесс генерации диффузии на примере блока из четырёх токенов

Процесс генерации диффузии на примере блока из четырёх токенов

Если за один проход модель декодирует не один, а сразу несколько токенов, то для генерации той же последовательности требуется меньше forward pass-ов. Это удобно измерять метрикой TPF (Tokens-Per-Forward) — средним количеством токенов, которое модель финализирует за один проход.

В таком контексте обычную авторегрессионную генерацию можно рассматривать как частный случай диффузии с блоком размером 1: модель генерирует ровно один токен, поэтому её TPF всегда равен 1. Выигрыш у диффузионного подхода появляется только в том случае, когда TPF становится заметно больше 1.

Что разберем дальше

Полный цикл обучения GFusion состоял из следующих этапов:

  • Базовый претрейн: основная часть адаптации AR-модели к диффузионной генерации.

  • Мидтрейн: стадия прокачивания способностей и знаний модели.

  • Расширение контекста: увеличиваем контекстное окно до 32 тыс. токенов.

  • SFT: учим модель следовать инструкциям и формату.

  • Confidence tuning: короткая стадия после SFT для ускорения генерации.

Не буду подробно разбирать каждый этап, а сосредоточусь на том, что характерно именно для dLLM: на экспериментах на претрейне, алгоритмах декодирования и ускорении на SFT.

Адаптация AR-модели к диффузии

В свежих статьях по dLLM предлагают много разных решений: новые функции потерь (loss-функции), attention-маски, стратегии зашумления и способы моделирования. Почти каждая работа показывает преимущество своего подхода, но зачастую сравнения проводят на разных моделях, данных и бюджетах обучения. Поэтому для начала мы решили сравнить основные варианты для нашего единого сетапа.

Как обучается диффузия?

Пусть есть исходная последовательность из обучающей выборки:

x_0=(x_0^{(1)}, dots, x_0^{(L)}).

Из нее мы получаем зашумлённую версию:

x_t=(x_t^{(1)}, dots, x_t^{(L)}), quad t in (0, 1).

Для этого сначала семплируем уровень шума t in (0, 1), а затем каждый токен заменяем на токен маски с этой вероятностью. Модель обучается восстанавливать замаскированные токены, имея доступ к чистому контексту перед текущим блоком и к зашумлённому состоянию самого блока.

В наших экспериментах лучше всего сработал шум из распределения t sim mathcal{U}(0.25, 0.85).

Интуиция [4] простая: при слишком маленьком шуме задача становится почти тривиальной, а при слишком большом — чрезмерно сложной. Диапазон от 0,25 до 0,85 даёт модели разные уровни сложности, но не превращает обучение в угадывание почти всей последовательности с нуля.

В отличие от стандартного AR-обучения, на вход диффузионной модели подаётся

x=text{Concat}(x_t, x_0),

то есть не только зашумлённая последовательность, но и исходная.

Для такой схемы обучения требуется особенная attention-маска. Необходимо, чтобы внутри одного блока токены могли взаимодействовать друг с другом, а между блоками сохранялась авторегрессионная зависимость. Функция потерь при этом остаётся обычной кросс-энтропией, однако мы считаем её только для замаскированных позиций x_t.

Мы также рассматривали гибридную постановку, когда к основной диффузионной функции потерь добавляют AR-регуляризацию на токены исходной последовательности x_0. В этом варианте модель одновременно учится восстанавливать замаскированные токены и поддерживать авторегрессионное обновление контекста.

Attention-маски для разных подходов моделирования

Attention-маски для разных подходов моделирования

Размер блока

В качестве целевого размера блока обычно берут от 16 до 64 токенов. Он должен быть достаточно большим, чтобы появлялся параллелизм и потенциальное ускорение, но не настолько большим, чтобы задача восстановления становилась невозможной.

Самый простой способ превратить AR-модель в dLLM — сразу обучать её с блоком целевого размера. Однако для авторегрессионной модели такой переход оказывается слишком резким, так как задача меняется с предсказания одного следующего токена на восстановление целого блока.

Более стабильным оказался вариант с постепенным увеличением размера блока, например:

B=1 rightarrow 2 rightarrow 4 rightarrow 8 rightarrow 16

В этом режиме заметно улучшается сходимость, а модель показывает более высокие значения метрик.При размере блока 1 задача почти совпадает с авторегрессионной генерацией, поэтому переход от AR к dLLM получается более плавным. По мере дальнейшего увеличения блока задача модели постепенно усложняется, что заставляет её восстанавливать всё больше токенов за один проход.

Эффект от постепенной адаптации блока для 3B-модели

Эффект от постепенной адаптации блока для 3B-модели

Тонкости моделирования

Для обучения модели необходимо определиться с методом генерации. Есть два основных варианта:

  • Предсказывать следующий токен, если он замаскирован. Такой метод лучше использует знания исходной AR-модели, которая была обучена почти на эту же задачу.

  • Предсказывать текущий замаскированный токен. Это стандартная постановка для современных dLLM: модель восстанавливает токены, на позициях которых стоят токены маски.

Иллюстрация подходов моделирования

Иллюстрация подходов моделирования

Мы сравнили оба варианта и не увидели значимой разницы в результатах. В итоге остановились на стандартном моделировании с предсказанием текущего замаскированного токена, так как этот подход идейно проще и не создаёт дополнительных проблем при генерации.

Кроме чисто диффузионного обучения, можно использовать и гибридную постановку. Идея в том, чтобы к основной функции потерь dLLM, которая считается по зашумлённой части входа x_t, с некоторым весом lambda прибавлять авторегрессионную loss-функцию по чистой последовательности x_0:

mathcal{L_{mathrm{total}}}=mathcal{L_{mathrm{dLLM}}} + lambda mathcal{L_mathrm{AR}}, quad lambda geq 0.

В таком случае mathcal{L_mathrm{AR}} выступает в роли регуляризации, помогающей сохранить знания исходной AR-модели, в то время как диффузионная часть учится генерировать блоки.

Сравнение метрик для стандартной и гибридной постановок диффузии на претрейне

Сравнение метрик для стандартной и гибридной постановок диффузии на претрейне

Гибридная постановка действительно давала прирост метрик на ранних шагах, но по мере обучения вариант без авторегрессионного компонента заметно её обгонял. AR-регуляризация полезна на ранней стадии адаптации, однако на более поздних шагах начинает мешать модели полностью перейти в диффузионный режим. Для основного этапа обучения мы решили использовать чисто диффузионную функцию потерь.

Авторегрессия не отпускает

Интуитивно понятно, что токены на начальных позициях блока предсказать сильно проще, так как они находятся ближе к уже декодированному контексту. Если для блока размером 32 посмотреть на средний номер шага, на котором финализируется каждый токен, то эта гипотеза подтверждается и на практике.

Средний номер шага, на котором декодируется токен маски для каждой позиции

Средний номер шага, на котором декодируется токен маски для каждой позиции

Такое наблюдение наводит на мысль, что есть некоторое расхождение между обучением и инференсом. На этапе генерации модель зачастую видит блоки, в которых чистый контекст находится левее, а замаскированные токены — правее. На обучении же токены маскируются независимо и равновероятно, поэтому распределение входов получается другим.

Мы попробовали несколько вариантов авторегрессионного семплирования, которые делают обучающие примеры более похожими на режим генерации. Как и в случае с дополнительной AR-функцией потерь, качество на ранних шагах становилось лучше, но по ходу обучения сильно уступало равномерному маскированию.

Относительное изменение метрик при использовании авторегрессионного семплирования

Относительное изменение метрик при использовании авторегрессионного семплирования

Стадия претрейна

После экспериментов мы зафиксировали финальный рецепт претрейна: предсказание текущего замаскированного токена, постепенный рост размера блока и стандартная диффузионная loss-функция.

Модель инициализировали весами GigaChat3-10B-A1.8B-base и обучали в несколько стадий на контексте в 4 тыс. токенов. На начальных блоках размером 1 и 2 мы использовали дополнительную AR-функцию потерь, а затем отключали её. Количество шагов для каждого размера блока подбирали так, чтобы модель успевала адаптироваться на каждом этапе.

Расписание learning rate для претрейна

Расписание learning rate для претрейна

Отдельно отмечу стадию с блоком размером 8. На ней мы наблюдали заметные просадки метрик, поэтому решили сделать её короче. Вероятная причина в том, что для такого блока задача восстановления уже становится достаточно сложной, но ещё не дает модели достаточно контекста для его декодирования.

После базового претрейна мы перешли к мидтрейну, где фокус сместился с адаптации к диффузионному режиму на прокачивание способностей модели за счёт более качественных данных. Всего на этом этапе модель сделала 20 тысяч шагов, а суммарно за обе стадии прошла около 530 млрд токенов.

Декодирование: где появляется ускорение

Для обученной модели необходим алгоритм семплирования — после каждого шага от него зависит, какие токены маски можно заменить на обычные токены, а какие ещё нет.

Большинство современных dLLM использует threshold sampling. Для этого метода токен маски декодируется только для тех позиций, для которых вероятность самого уверенного токена выше заранее заданного порога tau in [0, 1). Этот алгоритм хорош своей простотой, однако есть один недостаток: он рассматривает каждую позицию независимо друг от друга.

Мы решили рассмотреть и альтернативный подход из статьи [5]entropy-bounded sampling. Первым шагом для каждой замаскированной позиции вычисляем энтропию предсказанного распределения. Затем мы декодируем токены в порядке от самых уверенных к более неопределённым, пока накопленная энтропия не превысит заранее заданный порог gamma geq 0.

Сравнение threshold и entropy-bounded алгоритмов для чекпоинта после мидтрейна

Сравнение threshold и entropy-bounded алгоритмов для чекпоинта после мидтрейна

Затем мы провели сравнение обоих алгоритмов на нашей модели после мидтрейна. В результате EB-семплер превзошел базовый алгоритм во всем — генерация стала быстрее, так еще и метрики выросли. Отсюда можно сделать вывод, что от алгоритма декодирования напрямую зависят качество генерации и то, какого ускорения удается достичь. В дальнейшем мы везде используем EB-семплирование.

Расширение контекста

Следующим шагом мы расширили контекст модели до окна в 32 тыс. токенов, дообучив её на 30 млрд токенов длинных примеров. Для интерполяции RoPE c 4 тыс. до 32 тыс. использовали YaRN.

Эта стадия менее характерна именно для диффузионной генерации, но она важна для практического использования модели и дальнейшего обучения на SFT. Стоит отметить, что после этапа претрейна удалось почти полностью избежать снижения качества по сравнению с авторегрессионным чекпоинтом.

Сравнение базовых моделей с GFusion-10B-A1.8B-base, символом * отмечена авторегрессионная модель.

Сравнение базовых моделей с GFusion-10B-A1.8B-base, символом * отмечена авторегрессионная модель.

Стадия SFT

Complementary masking

На стадии SFT мы использовали идею complementary masking. В нашей постановке функция потерь считается только по замаскированным позициям, а остальная часть токенов не используется. В качестве решения для каждого обучающего примера будем добавлять парный пример с обратной маской: видимые токены маскируются, и наоборот.

Пример применения идеи complementary masking

Пример применения идеи complementary masking

Интересно, что лучший результат мы получили после шести эпох обучения — это вдвое больше, чем обычно требуется для авторегрессии. Такая разница объясняется тем, что на каждой эпохе диффузия может маскировать один и тот же пример множеством способов, что снижает эффект переобучения и позволяет извлекать из данных больше сигнала.

Confidence tuning (CT)

После SFT мы добавили короткую стадию для дополнительного ускорения модели. Для этого мы добавили дополнительный компонент к функции потерь, который штрафует модель за высокую энтропию распределения для корректно предсказанных токенов:

mathcal{L}_{mathrm{total}}=mathcal{L}_{mathrm{dLLM}} + beta sum_{t in mathcal{T}} H(p_t), quad

где mathcal{T} — корректно предсказанные токены, beta geq 0.

Чем более уверенные предсказания делает диффузионная модель, тем больше позиций при EB-семплировании мы можем декодировать за один шаг. Отсюда растёт TPF, и для генерации ответа модели требуется меньше вычислений.

При этом включать CT сразу было бы плохой идеей. В начале обучения модель корректно предсказывает в основном частотные слова, пунктуацию, очевидные продолжения. Если начать снижать энтропию в самом начале SFT, то появится некоторое смещение в сторону ускорения генерации таких простых токенов.

На практике beta=0.3 и обучение в 800 дополнительных шагов показали себя лучше всего с точки зрения [6] ускорения без существенной деградации модели.

Результаты

Итоговые варианты GFusion мы главным образом сравнивали с GigaChat3, а также другими диффузионными языковыми моделями похожих размеров (LLaDA-MoE-7B, LLaDA2.0-mini preview). Результаты приведены ниже — деградация качества по сравнению с авторегрессией действительно сильно заметнее, чем после стадии претрейна, однако GFusion не отстаёт от аналогичных dLLM.

Сравнение моделей с GFusion-10B-A1.8B, символом * отмечена авторегрессионная модель

Сравнение моделей с GFusion-10B-A1.8B, символом * отмечена авторегрессионная модель

Мы отдельно замерили скорость генерации GFusion до и после стадии confidence tuning. Для этого использовали aiperf [7] + SGLang в режиме одного пользователя (concurrency = 1) на 1xH100.

Сравнение decode TPS для GigaChat3, GFusion и GFusion после стадии confidence tuning

Сравнение decode TPS для GigaChat3, GFusion и GFusion после стадии confidence tuning

GFusion-10B-A1.8B в среднем даёт +70% к скорости генерации по сравнению с GigaChat3-10B-A1.8B, и даже обгоняет вариант с MTP на +39%. Если же понизить порог gamma до 0,50, то модель заметно замедляется, однако качество ответов растёт.

Инфраструктура

SGLang

Для удобства использования GFusion мы добавили её поддержку [8] в SGLang. Для работы MLA в диффузионном режиме с attention-головой размера 192 (не степень двойки) пришлось реализовать недостающую логику [9] для бекенда Flash-Attention.

Мы также добавили новый алгоритм entropy-bounded семплирования. Он показывает себя лучше по скорости и качеству не только для GFusion, но и для других диффузионных языковых моделей в SGLang.

Ускоряем обучение

Для обучения диффузии требуется кастомная attention-маска, поэтому использовать Flash-Attention не получится. Для такого случая есть Flex-Attention от PyTorch, который позволяет эффективно вычислять attention с произвольной маской, однако в нашем случае обучение всё равно далеко от оптимального.

Мы написали свою реализацию ядер на TileLang, которые нативно учитывают чёткую структуру диффузионной маски. Это позволило получить ускорение в +41,7% и +77,4% end-to-end при обучении на контексте в 4 тыс. и 32 тыс. токенов соответственно. Все реализации мы также выкладываем в открытый доступ [10].

Сравнение Flex-Attention с нашей реализацией на TileLang

Сравнение Flex-Attention с нашей реализацией на TileLang

Выводы

В проекте GFusion мы пытались не только обучить свою модель и получить ускорение по сравнению с авторегрессией, но и чуть глубже исследовать влияние отдельных компонентов на разных стадиях обучения, консистентно сравнить известные подходы и оценить их применимость.

Самым ценным в проекте оказался даже не сам результат, а возможность пройти весь путь от проектирования ML-экспериментов до инженерных вызовов, которые необходимо преодолеть для доведения идеи до работающей модели.

Если вам в целом интересны LLM и всё, что происходит вокруг них, — от исследовательских ML-фичей и архитектур до низкоуровневой инфраструктуры и оптимизации моделей, — приходите в проекты GigaChat!

Автор: perkyfever

Источник [11]


Сайт-источник BrainTools: https://www.braintools.ru

Путь до страницы источника: https://www.braintools.ru/article/32566

URLs in this post:

[1] обучение: http://www.braintools.ru/article/5125

[2] GFusion-10B-A1.8B-base: https://huggingface.co/ai-sage/GFusion-10B-A1.8B-base

[3] GFusion-10B-A1.8B: https://huggingface.co/ai-sage/GFusion-10B-A1.8B

[4] Интуиция: http://www.braintools.ru/article/6929

[5] статьи: https://arxiv.org/abs/2505.24857

[6] зрения: http://www.braintools.ru/article/6238

[7] aiperf: https://github.com/ai-dynamo/aiperf

[8] поддержку: https://github.com/sgl-project/sglang/pull/29776

[9] логику: http://www.braintools.ru/article/7640

[10] открытый доступ: https://github.com/tile-ai/tilelang/pull/2499

[11] Источник: https://habr.com/ru/companies/sberbank/articles/1054690/?utm_campaign=1054690&utm_source=habrahabr&utm_medium=rss

www.BrainTools.ru

Rambler's Top100