97 часов на одной RTX 4090: MoE с подключаемыми экспертами, самодистилляция и почему перплексия — плохая метрика. dynamic architecture.. dynamic architecture. llm.. dynamic architecture. llm. mixture of experts.. dynamic architecture. llm. mixture of experts. moe.. dynamic architecture. llm. mixture of experts. moe. Research.. dynamic architecture. llm. mixture of experts. moe. Research. self-distillation.. dynamic architecture. llm. mixture of experts. moe. Research. self-distillation. Блог компании BorisovAI: Платформа анализа трендов.. dynamic architecture. llm. mixture of experts. moe. Research. self-distillation. Блог компании BorisovAI: Платформа анализа трендов. искусственный интеллект.. dynamic architecture. llm. mixture of experts. moe. Research. self-distillation. Блог компании BorisovAI: Платформа анализа трендов. искусственный интеллект. Машинное обучение.

Меня зовут Борисов Павел, занимаюсь ML-исследованиями. Последние месяцы ковырялся с архитектурой MoE, где эксперты подключаются поверх замороженной модели. 22 эксперимента на одной RTX 4090, ниже разбор что получилось.

Что за архитектура

Берём предобученную языковую модель и замораживаем целиком, ни один вес не меняется. К каждому MLP-слою прикручиваем маленький обучаемый модуль, «эксперт». Сверху маршрутизатор — линейный слой на 37 тысяч параметров, который для каждого токена выбирает эксперта.

Архитектура SEED-NET: замороженный backbone, подключаемые эксперты и обучаемый маршрутизатор

Архитектура SEED-NET: замороженный backbone, подключаемые эксперты и обучаемый маршрутизатор

По сути это MoE [1], но отличие от Mixtral [2] или DeepSeek-MoE [3] принципиальное. Там маршрутизация это часть предобучения на триллионе токенов. Тут базовая модель не трогается, эксперты подключаются как плагины.

Обучение нового эксперта — три шага:

  1. Изоляция. Замораживаем всё кроме одного эксперта. Прогоняем тексты нужной области (математику, код, научные статьи). ~15 минут на GPU.

  2. Интеграция. Размораживаем только маршрутизатор. Показываем тексты из всех областей, он учится направлять токены к нужному эксперту. ~15 минут.

  3. Горячее подключение. Новая область = новый эксперт, повторяем шаги 1-2. Старых не трогаем.

Деградация при подключении нового эксперта: 0.000%. На трёх масштабах. Веса базовой модели заморожены, катастрофическое забывание [4] исключено архитектурно.

Результаты на трёх масштабах

Тестировал на GPT-2 (124M), Pythia-410M и Pythia-1B:

Базовая модель

Областей

Снижение перплексии

Отставание от идеала

Деградация

GPT-2 124M

4

33.4%

6.6%

Pythia-410M

6

34.3%

5.9%

0.000%

Pythia-1B

8

31.2%

3.2%

0.000%

Снижение перплексии ~31-34% на всех масштабах. Отставание от оракула сужается с ростом: 6.6% → 5.9% → 3.2%.

Математический эксперт специализировался лучше всех, перплексия -85.1%, междоменный разрыв 64.9x. Разговорный хуже всех (3.1%) — слишком общий стиль, мало специфики.

Что не работает

Вспомогательные лоссы. Четыре варианта штрафов за неравномерную нагрузку (balance, diversity, entropy, importance), все ухудшили на 11-27%. Wang et al. [5] писали о том же, DeepSeek-V3 [6] от штрафов тоже отказался.

Совместное обучение экспертов и маршрутизатора — коллапс, точность с 80.8% до 73.7%. В DeepSeek-MoE [3] пришли к похожему.

Маршрутизация по меткам. Без подсказок маршрутизатор нашёл границы точнее (6.6%) чем с явными метками (7.3%).

Вместо штрафов использовал безлоссовую балансировку [5] — обучаемое смещение для выравнивания нагрузки. 100% экспертов живы на всех масштабах.

Бенчмарки и проблема с перплексией

Итого: перплексия -31%, маршрутизация 96%, деградация нулевая. Всё отлично, пока не запустил бенчмарки.

31% снижение перплексии на Pythia-1B с 8 экспертами = +0.29 п.п. на MMLU. Почти ничего.

Ладно, Pythia маленькая. Взял Qwen 2.5 3B, она даёт 74.4% на GSM8K из коробки. Обучил математического эксперта. Перплексия на математике -23.9%, междоменный разрыв 64.9x, маршрутизация 100%.

GSM8K после подключения: 65.8%. Минус 8.6 п.п.

Перепроверил три раза. Разморозка верхних слоёв, совместное обучение, двухфазная схема — всё в районе -8.4…-8.6 п.п.

В чём дело: эксперт обучился на учебниках и статьях, выучил статистику языка математики, то есть что после «решим уравнение» скорее идёт формула, а не рецепт. Перплексия от этого снижается, но GSM8K-то требует логику рассуждения, а не знание частотности слов. Hu et al. [7] и Fang et al. [8] показывали околонулевую корреляцию перплексии с бенчмарками, и вот это ровно оно.

Маршрутизатор при этом работал на 0.4% отставания от оракула. На MMLU показал +0.15 п.п. выше оракульного выбора, то есть обходил эксперта на задачах где тот вредит.

Самодистилляция

Раз проблема в данных, попробовал другой подход: обучить эксперта на пошаговых решениях самой модели вместо сырого текста. По мотивам STaR [10], только STaR дообучивает всю модель, а тут внешний эксперт поверх замороженной.

Взял 750 задач GSM8K, Qwen решил 638 правильно. Получилось 119 тысяч токенов, это в 33 раза меньше чем 4 миллиона токенов сырого текста. Формат «Вопрос/Ответ», как при инференсе.

GSM8K: 75.5%. +1.1 п.п. к базе, +9.7 п.п. к варианту с сырым текстом. При этом перплексия ухудшилась на 17.8%.

Ещё заметил что формат данных важен: «Вопрос/Ответ» (совпадает с форматом инференса) дал +2-3 п.п. по сравнению с «Задача/Решение». Для сравнения, LoRA-вариант (13.4 млн параметров вместо 67.6 млн) показал 74.0%, всего -0.4 п.п. от базы, но маршрутизации там нет.

Цикл самоулучшения

Дальше захотелось замкнуть цикл: модель решает задачи, обучаем эксперта на решениях, модель решает лучше, обучаем снова…

Цикл

Верных

Новых

GSM8K

0 (исходный)

638/750

75.5%

1

658/750

+20

75.5%

2

668/750

+10

76.0%

+20, потом +10, затухает. Но 76.0% вроде есть. Проблема в другом.

Ошибка с seed

При создании эксперта через QwenWithMoE() PyTorch инициализировал веса рандомно. Seed я не фиксировал. Разброс от инициализации ~5 п.п., а эффект цикла 0.5 п.п.

После torch.manual_seed(42) и увеличения выборки до 500 задач:

Seed

GSM8K

42

76.4%

123

75.2%

456

76.0%

Среднее

75.87% ± 0.61 п.п.

Перепроверка цикла

С фиксированным seed и 500 задачами:

Холодный старт (свежий эксперт каждый цикл): 76.4% → 74.6% → 74.6%. Плато.

Тёплый старт (продолжаем обучение): 76.4% → 75.0% → 71.6%.

При тёплом старте перплексия продолжала падать: 1.58 → 1.45 → 1.36. По лоссам всё хорошо, а GSM8K деградирует. Причина: между циклами 85-90% задач повторялись, эксперт на них переобучался.

То «улучшение» 75.5% → 76.0% из таблицы выше — статистический шум. На 200 задачах доверительный интервал ~5 п.п., эффект 0.5 п.п.

Label smoothing

Попробовал сглаживание меток, минус 9 п.п. на GSM8K. По Müller et al. [11] сглаживание делает варианты ответа более равновероятными. В классификации картинок это нормально, но в математике «15 минус 7» это 8, а не «скорее 8 чем 7». Каждый промежуточный шаг рассуждения должен быть точным.

Итоги

С архитектурой всё хорошо: 0.000% деградации, 96% точность маршрутизации, безлоссовая балансировка [5] вместо штрафов. Самодистилляция дала +9.7 п.п. по сравнению с обучением на сыром тексте (119 тысяч токенов собственных решений vs 4 миллиона из учебников). Замкнуть цикл самоулучшения не получилось, задачи повторяются между итерациями и эксперт переобучается. И главное — перплексия оказалась бесполезной для оценки рассуждений, она может падать когда бенчмарки деградируют и расти когда они улучшаются.


Ссылки

  1. Shazeer et al. Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer. ICLR 2017. arxiv:1701.06538

  2. Jiang et al. Mixtral of Experts. 2024. arxiv:2401.04088

  3. Dai et al. DeepSeekMoE: Towards Ultimate Expert Specialization in Mixture-of-Experts Language Models. 2024. arxiv:2401.06066

  4. Kirkpatrick et al. Overcoming Catastrophic Forgetting in Neural Networks. PNAS 2017. arxiv:1612.00796

  5. Wang et al. Auxiliary-Loss-Free Load Balancing Strategy for Mixture-of-Experts. 2024. arxiv:2408.15664

  6. DeepSeek-AI. DeepSeek-V3 Technical Report. 2024. arxiv:2412.19437

  7. Hu et al. Can Perplexity Reflect Large Language Model’s Ability in Long Text Understanding? 2024. arxiv:2405.06105

  8. Fang et al. What is Wrong with Perplexity for Long-context Language Modeling? 2024. arxiv:2410.23771

  9. Wei et al. Chain-of-Thought Prompting Elicits Reasoning in Large Language Models. NeurIPS 2022. arxiv:2201.11903

  10. Zelikman et al. STaR: Bootstrapping Reasoning With Reasoning. NeurIPS 2022. arxiv:2203.14465

  11. Müller, Kornblith, Hinton. When Does Label Smoothing Help? NeurIPS 2019. arxiv:1906.02629


Код и результаты экспериментов: GitVerse | GitFlic*

Автор: borisovai-ru

Источник

Rambler's Top100