Как я создал альтернативу трансформерам. llm.. llm. mamba.. llm. mamba. transformers.. llm. mamba. transformers. глубокое обучение.. llm. mamba. transformers. глубокое обучение. искусственный интеллект.. llm. mamba. transformers. глубокое обучение. искусственный интеллект. математика.. llm. mamba. transformers. глубокое обучение. искусственный интеллект. математика. Машинное обучение.. llm. mamba. transformers. глубокое обучение. искусственный интеллект. математика. Машинное обучение. машинное+обучение.. llm. mamba. transformers. глубокое обучение. искусственный интеллект. математика. Машинное обучение. машинное+обучение. нейросети.

В этой статье я хочу остановиться на разборе предложенной мной архитектуры декодера и тех вариантов, с которыми я сравниваю её в исследовании, но сделать это проще и интуитивнее, чем в самой работе. На мой взгляд, существующие объяснения архитектур декодеров часто подаются разрозненно. Каждый подход описывают отдельно, без общей опоры. А ведь всё можно свести к одному фундаменту, и тогда становятся гораздо заметнее как сильные стороны каждого решения, так и их ограничения.

Для начала приведу все необходимые ссылки.

Само исследование: https://arxiv.org/abs/2604.18580

Код: https://github.com/LibratioAI/sessa

Структура статьи:

  1. Transformer

  2. S4D и Mamba

  3. Sessa: Selective State Space Attention

Transformer

Начнём с классики — трансформера, но посмотрим на него под другим углом: не на то, как он устроен и какие у него компоненты, а на то, как идейно к нему прийти. Для этого мы воспользуемся свёрткой. В нашем случае для простоты будем оперировать дискретными функциями со скалярными элементами.

К примеру, рассмотрим на входе последовательность из пяти элементов (слова, эмбеддинги и т. п.) и ядро из трех элементов. В классическом определении ядро свёртки отражается по нулю (переворачивается), но будем считать, что это уже сделано.

Наша цель — посмотреть, какую информацию такой миксер может извлекать из префикса последовательности. В decoder-only модели выход миксера затем используется для предсказания следующего элемента, но сначала рассмотрим саму операцию смешивания на примере длины 5.

Пусть вход

x=[x_1,x_2,x_3,x_4,x_5],

а ядро

mathbf b=bigl[b^{(0)},b^{(1)},b^{(2)}bigr].

Тогда выходы миксера на позициях могут быть, например,

hat x_3=b^{(0)}x_3+b^{(1)}x_2+b^{(2)}x_1,quad hat x_4=b^{(0)}x_4+b^{(1)}x_3+b^{(2)}x_2,quad hat x_5=b^{(0)}x_5+b^{(1)}x_4+b^{(2)}x_3.

В decoder-only модели такой выход в позиции t затем используется для предсказания следующего элемента x_{t+1}. Но дальше нас будет интересовать именно сам миксер: какую информацию из префикса он извлекает в позиции t. Поэтому hat x_t далее обозначает выход миксера, а не финальное предсказание следующего токена.

И, как видим, у нас есть две проблемы: ядро фиксировано по коэффициентам и имеет фиксированную длину окна. Такая свёртка в теории цифровой обработки сигналов называется FIR (Finite Impulse Response, конечный импульсный отклик) фильтром: его реакция на единичный импульс длится конечное число шагов, потому что ядро имеет конечную поддержку.

Попробуем исправить этот крайне неприятный недостаток: сделаем коэффициенты адаптивными относительно входа, например, просто через умножение wcdot x. Тогда:

b_t^{(k)}=f(x_{t-k})=w,x_{t-k},qquad kin{0,1,2}.

и выход миксера в позиции t:

hat{x}_{t}=sum_{k=0}^{2} b_t^{(k)},x_{t-k}=(w x_t),x_t + (w x_{t-1}),x_{t-1} + (w x_{t-2}),x_{t-2}.

Например, если t=5:

hat{x}_5=(w x_5),x_5 + (w x_4),x_4 + (w x_3),x_3.

При этом есть одно “но”: даже если сделать коэффициенты зависящими от входа, это само по себе не снимает ограничения по длине. Пока мы работаем в фиксированном окне радиуса K, оператор остаётся локальным и видит только последние K элементов. А если расширить суммирование на большее число позиций (вплоть до всего контекста), то при увеличении длины меняется режим вычисления: в сумме появляется больше слагаемых, и без нормировки выход начинает дрейфовать по масштабу и по распределению вкладов. Поэтому веса нужно сделать устойчивыми к росту T, например, так, чтобы их сумма была фиксированной. И здесь естественно появляется softmax: он нормирует сырые коэффициенты в веса, суммирующиеся к единице, хотя, конечно, это ещё не attention как механизм как таковой.

В нашем примере:

b_t^{(0)}=w x_t,quad b_t^{(1)}=w x_{t-1},quad b_t^{(2)}=w x_{t-2}.tilde b_t^{(k)}=frac{exp!big(b_t^{(k)}big)} {sum_{r=0}^{2}exp!big(b_t^{(r)}big)}, qquad kin{0,1,2}.hat{x}_{t}=sum_{k=0}^{2}tilde b_t^{(k)},x_{t-k}.

Например, при t=5:

hat{x}_5=tilde b_5^{(0)}x_5 + tilde b_5^{(1)}x_4 + tilde b_5^{(2)}x_3. tilde b_5^{(0)}=frac{e^{b_5^{(0)}}} {e^{b_5^{(0)}}+e^{b_5^{(1)}}+e^{b_5^{(2)}}}, quad   tilde b_5^{(1)}=frac{e^{b_5^{(1)}}} {e^{b_5^{(0)}}+e^{b_5^{(1)}}+e^{b_5^{(2)}}},  quad  tilde b_5^{(2)}=frac{e^{b_5^{(2)}}} {e^{b_5^{(0)}}+e^{b_5^{(1)}}+e^{b_5^{(2)}}}.

Несмотря на то что мы исправили две ключевые проблемы, такая система всё ещё не справляется со многими задачами. Возьмём, например, простую подзадачу для миксера: хотим, чтобы выход в позиции t извлекал предыдущий вход,

hat x_tapprox x_{t-1}.

Это означает, что на каждом шаге вес на позиции x_{t-1} должен быть больше, чем на позиции x_t и на позиции x_{t-2}. В частности, достаточно следующих требований:

text{для }t=3:quad tilde b_3^{(1)}>tilde b_3^{(0)},qquad text{для }t=4:quad tilde b_4^{(1)}>tilde b_4^{(2)}.

Но softmax монотонен по своим аргументам: если  tilde b_t^{(a)}>tilde b_t^{(b)}, то обязательно b_t^{(a)}>b_t^{(b)}. Значит из требований получаем:

tilde b_3^{(1)}>tilde b_3^{(0)} Rightarrow b_3^{(1)}>b_3^{(0)} Rightarrow w x_2 > w x_3,  quad  tilde b_4^{(1)}>tilde b_4^{(2)} Rightarrow b_4^{(1)}>b_4^{(2)} Rightarrow w x_3 > w x_2.

Получается противоречие: одновременно требуется w x_2 > w x_3 и w x_3 > w x_2.

Если интуитивно, то когда оценки зависят только от важности токена самого по себе, то есть фактически только от источника, модель не может менять предпочтение, кого копировать, в зависимости от шага t.Именно здесь и появляется необходимость в том, чтобы оценки зависели от того, кто читает и на что смотрит. И вот тут уже естественно появляется классический attention (с позиционным кодированием): вместо важности токена самой по себе, мы нормируем совместимость пары t, j,

b_t^{(j)}=f(x_t,x_j)=w,x_t x_j+r_{t-j} equiv beth_{t,j}, qquad beth_{t,j}=frac{q_t^top k_j}{sqrt{d_k}}+r_{t-j}. tilde b_{t,j}=operatorname{softmax}_{uinmathcal W(t)}!big(b_t^{(u)}big)_j=frac{exp!big(b_t^{(j)}big)}{sum_{uinmathcal W(t)}exp!big(b_t^{(u)}big)},qquad jinmathcal W(t). hat{x}_{t}=sum_{jinmathcal W(t)}tilde b_{t,j},x_j.

Где  r_{t-j} отвечает за позиционную информацию

И на нашем примере с конечным окном K=3, мы смотрим только на jin{t,t-1,t-2}. При t=5:

b_5^{(5)}=w,x_5x_5+r_0,quad b_5^{(4)}=w,x_5x_4+r_1,quad b_5^{(3)}=w,x_5x_3+r_2. tilde b_{5,5}=frac{e^{b_5^{(5)}}}{e^{b_5^{(5)}}+e^{b_5^{(4)}}+e^{b_5^{(3)}}},quad  tilde b_{5,4}=frac{e^{b_5^{(4)}}}{e^{b_5^{(5)}}+e^{b_5^{(4)}}+e^{b_5^{(3)}}},quad  tilde b_{5,3}=frac{e^{b_5^{(3)}}}{e^{b_5^{(5)}}+e^{b_5^{(4)}}+e^{b_5^{(3)}}},hat{x}_5=tilde b_{5,5}x_5+tilde b_{5,4}x_4+tilde b_{5,3}x_3.

В общем виде, когда x_t является вектором, вместо скаляра w используются линейные отображения для построения q_tи k_jс произвольным K.Вот и всё: теперь к этому attention добавляем MLP и получаем слой трансформера.

А теперь поговорим немного о свойствах этой архитектуры.

Память

Чтобы честно сравнить память трёх архитектур, нужно выбрать согласованный режим анализа. Потому что можно подобрать такие параметры, при которых каждая из них будет лучше другой, или наоборот. В контексте исследования я рассматривал контролируемый общий режим, в котором моделям сложно сфокусироваться на определённом элементе входной последовательности, то есть ломаем sharp retrieval. В случае трансформера это диффузный режим. А для самой оценки нужно понять, как выход в момент t чувствителен к токену в прошлом tau, используя якобианы.

Для создания диффузного режима в трансформере достаточно, чтобы разброс логитов, то есть разница между максимальным и минимальным значением перед softmax, был ограничен константой, не растущей с длиной контекста. Тогда softmax не может устойчиво сконцентрировать почти всю массу на одном токене, и веса остаются распределёнными по большому числу позиций.

При этом я доказываю оценки для режимов:

Однослойный: Если рассматривать режим без заморозки коэффициентов attention, верхняя оценка равна:

qquad Oleft(frac{1}{ell}right)

А вот нижняя оценка из-за добавления градиента по коэффициентам attention тривиальна: она больше или равна нулю, потому что путь по коэффициентам attention может частично компенсировать value-путь, в худшем случае может почти занулить.

Многослойный: рассматриваем стек из L слоёв attention и смотрим полный градиент от выхода на позиции t к входу на позиции tau.

В этом режиме верхняя оценка равна:

Oleft( frac{(log(1+ell))^{L-1}}{1+ell} right), qquad ell=t-tau.

Глубина добавляет дополнительные маршруты через промежуточные позиции и даёт логарифмическое усиление по сравнению с одним слоем:

L=1:quad Oleft(frac{1}{ell}right), qquad L=2:quad Oleft(frac{log ell}{ell}right).

Но при фиксированной глубине L влияние старых токенов всё равно затухает:

frac{(log ell)^{L-1}}{ell}to 0.

А нижняя оценка, как и в однослойном режиме, остаётся тривиальной: она больше или равна нулю.

В исследовании я также ввожу термины one path / many paths (пути) и one-hop / multi-hop (переходы). Они интуитивно объясняют, почему в диффузных режимах память ведёт себя по-разному у трёх архитектур. В контексте трансформера и механизма внимания удобно представить маршрутизацию влияния как ориентированный ациклический граф (DAG) по временным индексам: для трансформера это граф с прямыми рёбрами tau to t (при tau le t), где вес ребра задаётся вниманием alpha_{t,tau}. Тогда влияние из tau в t внутри одного слоя реализуется одним переходом (one-hop) и по одному пути (one path). Так как интуитивно, обратиться к определённому элементу attention может только одним способом в одном слое — дать ему максимальный вес, то есть один путь и один шаг.

Как я создал альтернативу трансформерам - 59

S4D и Mamba

Целью данных SSM является исправить недостаток attention, а именно квадратичную сложность вычислений при росте длины последовательности, при этом оставаясь примерно на том же уровне качества на длинных контекстах. И в целом цель хорошая, если бы не несколько фундаментальных “но”, которые мы в этом разделе обсудим. Но, чтобы прийти интуитивно к Mamba, мы сделаем это через промежуточный этап в виде S4D, возьмём нашу предыдущую LTI-систему и посмотрим на неё с другой стороны.

И вместо того, чтобы сначала делать коэффициенты адаптивными, а потом пытаться обобщить это на произвольную длину, сделаем наоборот: исправим сначала проблему конечного импульсного отклика, а потом уже будем думать про адаптивность. Основная идея такая: мы можем передавать дальше не только текущий взвешенный вход, но и состояние y, в котором накапливается прошлое. Тогда каждый новый вклад будет суммироваться с накопленным предыдущим, и система сможет помнить сколь угодно далёкий вход.

В цифровой обработке сигналов такая система известна как IIR-фильтр, а эта форма записи — как разностное уравнение. И теперь у нас есть кроме forward-ветки, которую мы рассматривали в предыдущем пункте ещё и feedback. И отталкиваясь от предыдущего пункта, мы учитываем как в нашем примере три последних x и взвешиваем их коэффициентами b:

y_t=y_{t-1} + b^{(0)}x_t + b^{(1)}x_{t-1} + b^{(2)}x_{t-2},qquad hat x_t=y_t.

Очевидно, что сейчас в качестве коэффициента в обратной связи мы используем единицу, но нам никто не мешает сделать его произвольным:

y_t=a y_{t-1} + b^{(0)}x_t + b^{(1)}x_{t-1} + b^{(2)}x_{t-2},qquad hat x_t=y_t.

И тут всплывает довольно важный момент про устойчивость системы. Есть два близких, но разных способа на это смотреть.

BIBO-стабильность: если вход ограничен, то выход тоже должен оставаться ограниченным. Обычно это рассматривают при нулевом начальном состоянии.

Внутренняя устойчивость: если вход равен нулю, но начальное состояние ненулевое, то собственная динамика системы не должна разгоняться.

Рассмотрим простой случай:

y_t=a y_{t-1} + u_t,

где

u_t=b^{(0)}x_t+b^{(1)}x_{t-1}+b^{(2)}x_{t-2}.

Коэффициент обратной связи равен 1.

Если для наглядности взять постоянный вход x_tequiv 1, то при

b^{(0)}+b^{(1)}+b^{(2)}ne 0

получаем

u_t=b^{(0)}+b^{(1)}+b^{(2)}.

Тогда

y_t=y_{t-1}+u_t,

и значит

y_t=y_0+tbigl(b^{(0)}+b^{(1)}+b^{(2)}bigr).

То есть выход растёт линейно. Поэтому при a=1 глобальную BIBO-стабильность в общем случае гарантировать нельзя.

Коэффициент по модулю больше 1.

Если |a|>1, то система неустойчива. Это можно увидеть двумя способами.

Во-первых, как нарушение внутренней устойчивости. Пусть вход равен нулю:

u_tequiv 0,

но начальное состояние ненулевое. Тогда

y_t=a^t y_0.

При |a|>1 это растёт экспоненциально.

Во-вторых, это можно увидеть именно как нарушение BIBO-стабильности. Возьмём нулевое начальное состояние и ограниченный импульсный вход:

u_0=1,qquad u_t=0quad text{для }t>0.

Тогда выход будет содержать множитель

a^t.

При |a|>1 он становится неограниченным. Значит ограниченный вход может породить неограниченный выход.

Коэффициент по модулю меньше 1.

Если |a|<1, то собственная динамика затухает:

y_t=a^t y_0,  qquad  |a|^tto 0.

Кроме того, импульсный отклик такой системы имеет вид геометрической прогрессии. Поэтому вклад прошлого суммируется с экспоненциально убывающими весами, и при ограниченном входе выход остаётся ограниченным. В этом случае система является BIBO-стабильной.

Но при этом нам никто не мешает использовать произвольное число элементов для взвешивания в обратной связи. В общем виде разностное уравнение будет выглядеть как:

y_t=sum_{m=1}^{P} a_{m}y_{t-m} + sum_{k=0}^{Q} b^{(k)}x_{t-k}, qquad hat{x}_t=y_t.

И здесь важно, что устойчивость уже описывается не одним числом a, а корнями характеристического уравнения, составленного из этих коэффициентов. Но это отдельная обширная тема. В данном случае достаточно просто принять это как факт. И дальше ситуация уже не так тривиальна, как с одним коэффициентом: отдельные a_m действительно могут быть по модулю больше 1, но при этом система всё равно может быть устойчивой, если все корни характеристического уравнения лежат внутри единичного круга.

Теперь нужно связать разностное уравнение с формой миксера S4D, которая выглядит подобным образом в простом скалярном случае:

h_t^{(n)}=alpha_n^{mathrm{S4D}},h_{t-1}^{(n)}+beta_n^{mathrm{S4D}},x_t,qquad n=1,dots,N,y_t=sum_{n=1}^{N} c_n^{mathrm{S4D}},h_t^{(n)}+d^{mathrm{S4D}},x_t.

Здесь alpha_n^{mathrm{S4D}}​ являются параметрами перехода состояния, и их не следует путать с коэффициентами a_m​ в разностной форме y_t=sum_{m=1}^P a_m y_{t-m}+dots

И хотя на первый взгляд она кажется немного другой, так как у нас N параллельных фильтров, которые учитывают только предыдущее состояние, я сейчас покажу, как из неё получить разностное уравнение.

Рассмотрим два состояния без входа: x_t=0. Для простоты возьмём c_1=c_2=1:

h_t^{(1)}=alpha_1^{mathrm{S4D}},h_{t-1}^{(1)},qquad h_t^{(2)}=alpha_2^{mathrm{S4D}},h_{t-1}^{(2)},qquad y_t=h_t^{(1)}+h_t^{(2)}.

выпишем y_{t-1}:

y_{t-1}=h_{t-1}^{(1)}+h_{t-1}^{(2)}.

Подставляем динамику в y_t:

y_t=alpha_1^{mathrm{S4D}},h_{t-1}^{(1)}+alpha_2^{mathrm{S4D}},h_{t-1}^{(2)}.

Теперь у нас система:

begin{cases} y_{t-1}=h_{t-1}^{(1)}+h_{t-1}^{(2)},\ y_t=alpha_1^{mathrm{S4D}},h_{t-1}^{(1)}+alpha_2^{mathrm{S4D}},h_{t-1}^{(2)}. end{cases}

Умножим первое уравнение на alpha_1^{mathrm{S4D}} ​и вычтем из второго:

y_t-alpha_1^{mathrm{S4D}},y_{t-1}=bigl(alpha_2^{mathrm{S4D}}-alpha_1^{mathrm{S4D}}bigr),h_{t-1}^{(2)}.

Сдвинем на один шаг назад:

y_{t-1}-alpha_1^{mathrm{S4D}},y_{t-2}=bigl(alpha_2^{mathrm{S4D}}-alpha_1^{mathrm{S4D}}bigr),h_{t-2}^{(2)}.

Но по динамике второго состояния h_{t-1}^{(2)}=alpha_2^{mathrm{S4D}},h_{t-2}^{(2)}​. Домножим предыдущее равенство на alpha_2^{mathrm{S4D}}​:

alpha_2^{mathrm{S4D}}bigl(y_{t-1}-alpha_1^{mathrm{S4D}},y_{t-2}bigr)=bigl(alpha_2^{mathrm{S4D}}-alpha_1^{mathrm{S4D}}bigr),h_{t-1}^{(2)}.

Значит правые части совпадают, и получаем:

y_t-alpha_1^{mathrm{S4D}},y_{t-1}=alpha_2^{mathrm{S4D}}bigl(y_{t-1}-alpha_1^{mathrm{S4D}},y_{t-2}bigr).y_t=bigl(alpha_1^{mathrm{S4D}}+alpha_2^{mathrm{S4D}}bigr),y_{t-1} -alpha_1^{mathrm{S4D}}alpha_2^{mathrm{S4D}},y_{t-2}.

И теперь если мы рассмотрим разностное уравнение для IIR фильтра:

y_t=a_1,y_{t-1}+a_2,y_{t-2}

То a_1 и a_2 равны:

a_1=alpha_1^{mathrm{S4D}}+alpha_2^{mathrm{S4D}},qquad a_2=-,alpha_1^{mathrm{S4D}}alpha_2^{mathrm{S4D}}.

При этом alpha_1^{mathrm{S4D}},alpha_2^{mathrm{S4D}} являются полюсами, а значит достаточно, чтобы они были по модулю меньше 1. Что и описывается в исследовании, посвящённом S4D.

Теперь, чтобы увидеть, что такое Mamba, достаточно просто сделать коэффициенты внутри состояний зависимыми от входа, то есть:

begin{aligned} widetilde{alpha}_{n,t}^{mathrm{Mamba}}(x_t),\ widetilde{beta}_{n,t}^{mathrm{Mamba}}(x_t),\ widetilde{c}_{n,t}^{mathrm{Mamba}}(x_t). end{aligned}

При этом создатели архитектуры Mamba заложили подобную селективность:

tilde a_{n,t}^{mathrm{Mamba}}=exp!big(Delta_t, a_{n}^{mathrm{Mamba}}big), qquad tilde b_{n,t}^{mathrm{Mamba}}=frac{exp!big(Delta_t, a_{n}^{mathrm{Mamba}}big)-1}{a_{n}^{mathrm{Mamba}}}; b_{n,t}^{mathrm{Mamba}}.begin{aligned} &text{где}\ &a_{n}^{mathrm{Mamba}}=-exp(theta_n), qquad theta_ninmathbb{R} text{— обучаемый параметр},\ &b_{n,t}^{mathrm{Mamba}}=w_{n},x_t, qquad w_{n}inmathbb{R},\ &Delta_t=operatorname{softplus}(delta + w_{Delta},x_t), qquad deltainmathbb{R}, w_{Delta}inmathbb{R}. end{aligned}

И тогда, если представить миксер Mamba в виде разностного уравнения, то мы получим:

y_t=sum_{m=1}^{N} a_{m,t}bigl(x_t,x_{t-1},dots,x_{t-N}bigr)y_{t-m} + sum_{k=0}^{N-1} b_{k,t}bigl(x_t,x_{t-1},dots,x_{t-N}bigr)x_{t-k},qquad hat{x}_t=y_t.

Данный вид получается из-за того, что параметры внутри каждого состояния зависят от входа, а значит их невозможно свести к одному коэффициенту при преобразовании в разностную форму.
Но при этом, как мы можем заметить, в разностной форме каждый коэффициент внутри окна зависит от входных данных в этом окне.

Как получить данную формулу.

Рассмотрим одно состояние:

h_t=A_t h_{t-1}+d_t x_t,qquad y_t=c_t h_t,

где

A_t=operatorname{diag}(alpha_{1,t},dots,alpha_{N,t}), qquad d_t=begin{bmatrix} beta_{1,t}\ vdots\ beta_{N,t} end{bmatrix}, qquad c_t=begin{bmatrix} c_{1,t}&dots&c_{N,t} end{bmatrix}.

В Mamba параметры A_t,d_t,c_t зависят только от текущего токена x_t.

Введём

Phi(a,b)=A_aA_{a-1}cdots A_{b+1},qquad a>b,Phi(a,a)=I.

Тогда для j=0,1,dots,N

h_{t-j}=Phi(t-j,t-N)h_{t-N} + sum_{s=t-N+1}^{t-j}Phi(t-j,s)d_sx_s.

Умножая на c_{t-j}, получаем

y_{t-j}=c_{t-j}Phi(t-j,t-N)h_{t-N} + sum_{s=t-N+1}^{t-j}c_{t-j}Phi(t-j,s)d_sx_s.

Обозначим

v_{j,t}:=c_{t-j}Phi(t-j,t-N),u_{j,t}:=sum_{s=t-N+1}^{t-j}c_{t-j}Phi(t-j,s)d_sx_s.

Тогда

y_{t-j}=v_{j,t}h_{t-N}+u_{j,t}, qquad j=0,1,dots,N.mathcal O_t=begin{bmatrix} v_{1,t}\ v_{2,t}\ vdots\ v_{N,t} end{bmatrix}=begin{bmatrix} c_{t-1}Phi(t-1,t-N)\ c_{t-2}Phi(t-2,t-N)\ vdots\ c_{t-N} end{bmatrix}.

Локальная невырожденность:

det mathcal O_tneq 0.

Тогда существует единственный набор коэффициентов

a_t=begin{bmatrix} a_{1,t}&dots&a_{N,t} end{bmatrix}

такой, что

v_{0,t}=sum_{m=1}^{N}a_{m,t}v_{m,t}.a_t=v_{0,t}mathcal O_t^{-1}.

Теперь

y_t=v_{0,t}h_{t-N}+u_{0,t},

а для m=1,dots,N

y_{t-m}=v_{m,t}h_{t-N}+u_{m,t}.

Следовательно,

y_t-sum_{m=1}^{N}a_{m,t}y_{t-m}=left( v_{0,t}-sum_{m=1}^{N}a_{m,t}v_{m,t} right)h_{t-N} + u_{0,t}-sum_{m=1}^{N}a_{m,t}u_{m,t}.

По построению

v_{0,t}-sum_{m=1}^{N}a_{m,t}v_{m,t}=0.

Значит

y_t-sum_{m=1}^{N}a_{m,t}y_{t-m}=u_{0,t}-sum_{m=1}^{N}a_{m,t}u_{m,t}.

Правая часть является линейной комбинацией входов

x_t,x_{t-1},dots,x_{t-N+1}.

Поэтому существуют коэффициенты b_{0,t},dots,b_{N-1,t}, такие что

u_{0,t}-sum_{m=1}^{N}a_{m,t}u_{m,t}=sum_{k=0}^{N-1}b_{k,t}x_{t-k}.

Для  k=0,dots,N-1

b_{k,t}=c_tPhi(t,t-k)d_{t-k} - sum_{m=1}^{k} a_{m,t}, c_{t-m}Phi(t-m,t-k)d_{t-k}.

Следовательно,

y_t=sum_{m=1}^{N}a_{m,t}y_{t-m} + sum_{k=0}^{N-1}b_{k,t}x_{t-k}.

Так как A_t диагональна,

Phi(t-j,t-N)=operatorname{diag} left( prod_{r=t-N+1}^{t-j}alpha_{1,r}, dots, prod_{r=t-N+1}^{t-j}alpha_{N,r} right).

Значит

(v_{j,t})_n=c_{n,t-j} prod_{r=t-N+1}^{t-j}alpha_{n,r}.

Поэтому v_{0,t} зависит только от параметров на индексах

t,t-1,dots,t-N+1,

а матрица mathcal O_t зависит только от параметров на индексах

t-1,t-2,dots,t-N.

Следовательно

a_t=v_{0,t}mathcal O_t^{-1}

зависит только от параметров на индексах

t,t-1,dots,t-N.

Но параметры индексов s зависят только от x_s, поэтому

a_{m,t}=a_{m,t}(x_t,x_{t-1},dots,x_{t-N}).b_{k,t}=c_tPhi(t,t-k)d_{t-k} - sum_{m=1}^{k} a_{m,t}, c_{t-m}Phi(t-m,t-k)d_{t-k}.

Здесь все множители c,d,Phi зависят только от параметров на индексах

t,t-1,dots,t-N+1,

а коэффициенты a_{m,t} зависят от

x_t,x_{t-1},dots,x_{t-N}.

Следовательно

b_{k,t}=b_{k,t}(x_t,x_{t-1},dots,x_{t-N}).

Поэтому, при detmathcal O_tneq 0 получаем

y_t=sum_{m=1}^{N} a_{m,t}bigl(x_t,x_{t-1},dots,x_{t-N}bigr)y_{t-m} + sum_{k=0}^{N-1} b_{k,t}bigl(x_t,x_{t-1},dots,x_{t-N}bigr)x_{t-k}, qquad hat{x}_t=y_t.

И теперь можно перейти к анализу свойств данной архитектуры.

Память

Как мы уже определились в предыдущем пункте, сравнение памяти должно быть честным и согласованным, и несмотря на то что в Mamba нет attention как в трансформере, за роль аналога диффузного режима будет отвечать величина Delta_t.

У Mamba есть режим, который в одном исследовании назвали freezing time, суть в том, что если Delta_tapprox 0 тоleft|widetilde{alpha}_{n,t}^{mathrm{Mamba}}(x_t)right|approx 1, а widetilde{beta}_{n,t}^{mathrm{Mamba}}(x_t)approx 0,состояние сохраняется, а новый вход не попадает. А значит оно может в этом случае хранить его сколь угодно долго, напоминая режим attention с one-hot. Для контролируемого режима достаточно предположить, что этот механизм плохо работает и не может корректно определить, где заморозить время, а где нет (failed freezing time), то есть ломаем sharp retrieval. И для этого режима в исследовании я показываю оценки также для однослойного и многослойного вариантов.

Однослойный. В этом режиме верхняя оценка равна:

Oleft(e^{-lambda c_Delta ell}right),

Нижняя оценка в этом режиме тривиальна и будет больше либо равна нулю.

Многослойный. Рассматриваем стек из L слоёв Mamba и полный градиент. В этом режиме верхняя оценка равна:

Oleft((1+ell)^{L-1}e^{-c_ast ell}right).

То есть глубина добавляет дополнительные маршруты через промежуточные слои и даёт полиномиальный множитель перед экспонентой, но общий характер памяти не меняется: экспонента всё равно доминирует. Например,

L=1:quad Oleft(e^{-c_ast ell}right), qquad   L=2:quad Oleft(ell e^{-c_ast ell}right).

Но при фиксированной глубине L влияние старых токенов всё равно затухает:

(1+ell)^{L-1}e^{-c_ast ell}to 0.

Нижняя оценка, как и в однослойном режиме, остаётся тривиальной: она больше либо равна нулю.

Но допустим мы рассмотрим отдельно реальную ситуацию с freezing time, и тут есть некоторые подводные камни, о которых я рассказываю в исследовании. Механизм заморозки опирается на распознавание по входным представлениям, а поскольку  Delta_t вычисляется как функция от входа, то когда разделимость сигнала падает и появляется шум в последовательности, распознавание может стать ненадёжным: модель не поддерживает режим Delta_tapprox 0 на всём релевантном промежутке, и возникают шаги Delta_t существенно больше нуля. Поскольку реальные последовательности часто содержат шум, подобный режим может встречаться довольно часто.

Ну и напоследок перед тем, как перейти к Sessa, представим маршрутизацию влияния в Mamba в виде ориентированного ациклического графа. Но он устроен иначе, чем в трансформере: это цепочка с рёбрами только между соседними шагами (t-1)to t где вес ребра задаётся переходом состояния. Тогда влияние из tau в t внутри одного слоя проходит по одному пути tautotau+1tocdotsto t (one path), но требует  ell=t-tau последовательных переходов (multi-hop). Интуитивно, при фиксированной модели и произвольной последовательности, чтобы обратиться к далёкому элементуtau, модели нужно пройти через цепочку обратной связи. Как мы разбирали выше, динамика на каждом шаге учитывает только конечное число элементов в forward- и feedback-ветках. Поэтому получается один путь, но много переходов. Это и объясняет, почему Mamba часто имеет экспоненциальное затухание.

Как я создал альтернативу трансформерам - 184

Sessa: Selective State Space Attention

Теперь перейдём к архитектуре, которую я предлагаю как альтернативу Mamba и Transformer.

Идея заключается в том, чтобы объединить два направления: добавить feedback и одновременно сделать forward- и feedback-ветки адаптивными к текущему токену и длине последовательности. Иными словами, добавить feedback в Transformer используя attention.

Сам слой Sessa имеет вид:

x_{ln}=mathrm{LN}(x),qquad   (a,g)=mathrm{split}(x_{ln}W^{mathrm{in}}+b^{mathrm{in}}),qquad   bar a=mathrm{GELU}(a),

после чего из bar a строятся две ветки миксера: forward-ветка

f_t=sum_{jle t}alpha^f_{t,j}v_j,

и feedback-ветка

s_t=f_t+gamma_tsum_{j<t}alpha^b_{t,j}s_j,

или, в параллельной матричной форме,

(I-B)s=f,qquad B_{t,j}=gamma_talpha^b_{t,j}.

Финальный выход слоя равен

y=x+big((sodot g)W^{mathrm{out}}+b^{mathrm{out}}big).

Если записать миксер в разностной форме, как в предыдущем пункте, получаем:

y_t=sum_{j<t} a_{t,j}(x_{le t})y_j + sum_{jle t} b_{t,j}(x_{le t})x_j, qquad hat x_t=y_t.

Как получить данную формулу.
f_t=sum_{jle t}alpha^f_{t,j}v_j, qquad s_t=f_t+gamma_tsum_{j<t}alpha^b_{t,j}s_j.

Подставляя f_t во вторую формулу, получаем

s_t=sum_{j<t}gamma_talpha^b_{t,j}s_j + sum_{jle t}alpha^f_{t,j}v_j.

Теперь обозначим

y_t:=s_t,qquad x_j:=v_j, quad a_{t,j}(x_{le t}):=gamma_talpha^b_{t,j},qquad b_{t,j}(x_{le t}):=alpha^f_{t,j}.

Тогда

y_t=sum_{j<t}a_{t,j}(x_{le t}),y_j+sum_{jle t}b_{t,j}(x_{le t}),x_j, qquad hat x_t=y_t.

В случае трансформера это эквивалентно

y_t=sum_{jle t} b_{t,j}(x_{le t})x_j, qquad hat x_t=y_t.

В Mamba коэффициенты зависят от конечного окна входов, а более далёкая история передаётся через состояние. В Sessa же feedback-коэффициенты напрямую строятся по всему префиксу.

Теперь перейдём к свойствам.

Память.

Для сравнения памяти в случае Sessa применяется диффузный режим, аналогичный трансформеру. Для этого режима однослойный и многослойный вариант:

Однослойный: Если рассматривать полный градиент одного блока без заморозки весов attention:

Oleft(ell^{-beta_{text{tail}}}log(1+ell)right), qquad beta_{text{tail}}=1-gamma_{max}c_2in(0,1).

У Sessa в этом режиме хвост ell^{-beta} с 0<beta<1, то есть медленнее хвоста Transformer 1/ell.

Нижняя оценка здесь тоже тривиальна и будет больше либо равна нулю.

Многослойный: рассматриваем стек из L слоёв Sessa и полный градиент от выхода на позиции t к входу на позиции tau. В этом режиме верхняя оценка равна:

Oleft(   sum_{k=1}^{L}   ell^{k(1-beta_{text{tail}})-1}   bigl(log(1+ell)bigr)^k   right).

Эквивалентно, если смотреть на доминирующий по порядку член при фиксированной глубине, то можно писать:

Oleft(   ell^{L(1-beta_{text{tail}})-1}   bigl(log(1+ell)bigr)^L   right).

То есть глубина здесь даёт не просто логарифмическое усиление, как у трансформера, а более сильное накопление вклада по мере композиции слоёв. Поэтому затухание может быть заметно медленнее.

Например,

L=1:quad Oleft(ell^{-beta_{text{tail}}}log(1+ell)right).

При этом, если

L(1-beta_{text{tail}})<1,

то влияние старых токенов при фиксированной глубине всё ещё затухает. Вне этого режима теорема в таком виде гарантирует уже не обязательное затухание, а контролируемую верхнюю оценку.

Нижняя оценка, как и в однослойном режиме, остаётся тривиальной: она больше либо равна нулю.

Теперь перейдём к рассмотрению маршрутизации влияния в Sessa как ориентированный ациклический граф.

Допустим, мы применяем к начальному состоянию s_0=f_0 рекурентно миксер Sessa получая еще дополнительно три состояния s_1, s_2, s_3. Исходя из формул миксера, это будет выглядеть следующим образом:

begin{aligned} s_0 &=f_0 \ s_1 &=f_1 + B_{1,0}s_0 \ s_2 &=f_2 + B_{2,1}s_1 + B_{2,0}s_0 \ s_3 &=f_3 + B_{3,2}s_2 + B_{3,1}s_1 + B_{3,0}s_0 end{aligned}

подставим s_2,s_1,s_0 в s_3

s_1=f_1 + B_{1,0}f_0.

подставим s_1 и s_0 в s_2

begin{aligned} s_2 &=f_2 + B_{2,1}s_1 + B_{2,0}s_0 \ &=f_2 + B_{2,1}(f_1 + B_{1,0}f_0) + B_{2,0}f_0 \ &=f_2 + B_{2,1}f_1 + (B_{2,1}B_{1,0}+B_{2,0})f_0 end{aligned}

подставим s_2,s_1,s_0 в s_3

begin{aligned} s_3 &=f_3 + B_{3,2}s_2 + B_{3,1}s_1 + B_{3,0}s_0 \[4pt] &=f_3 + B_{3,2}Big(f_2 + B_{2,1}f_1 + (B_{2,1}B_{1,0}+B_{2,0})f_0Big)    + B_{3,1}Big(f_1 + B_{1,0}f_0Big)    + B_{3,0}f_0 \[4pt] &=f_3  + B_{3,2}f_2  + (B_{3,2}B_{2,1}+B_{3,1})f_1 \[4pt] &quad + Big( underbrace{B_{3,2}B_{2,1}B_{1,0}}_{text{путь }0to1to2to3} +underbrace{B_{3,2}B_{2,0}}_{text{путь }0to2to3} +underbrace{B_{3,1}B_{1,0}}_{text{путь }0to1to3} +underbrace{B_{3,0}}_{text{путь }0to3} Big)f_0 end{aligned}

Если посмотреть на коэффициент при  f_0, то есть влияние из момента 0 в момент 3:

J_{3,0}=B_{3,0} + B_{3,1}B_{1,0} + B_{3,2}B_{2,0} + B_{3,2}B_{2,1}B_{1,0}

Можно выделить many paths и multi-hop:

Путь 0to 3 с одним переходом: B_{3,0},

Путь 0to1to3 c двумя переходами: B_{3,1}, B_{1,0} и путь 0to2to3 также с двумя переходами: B_{3,2}, B_{2,0},

Путь 0to1to2to3 с тремя переходами: B_{3,2}, B_{2,1}, B_{1,0}.

За счёт того, что модель может добираться до элементов различными путями, это и даёт более медленное затухание.

Если более интуитивно, то за счёт attention в feedback-части модель может обратиться к любому предыдущему feedback-шагу. Этот шаг, в свою очередь, уже учитывает предыдущие шаги, а также все входы через attention в forward-части. Поэтому и возникает множество путей с возможностью реализации множества переходов.

Как я создал альтернативу трансформерам - 231

Гибкое селективное извлечение (flexible selective retrieval):

Это один из наиболее интересных и практически важных результатов. Суть в том, что в каждой из трёх моделей можно настроить параметры так, чтобы она без проблем извлекала необходимый вход. Но при анализе селективного извлечения лучше не смотреть только на якобианы, потому что память, транспорт нужного сигнала и селективность — это немного разные вещи. То есть память, которую мы рассматривали выше, отвечает за то, дошло ли влияние вообще. Транспорт нужного сигнала отвечает за то, дошёл ли именно нужный компонент. А селективность — за то, победил ли нужный источник всех конкурентов. И все эти три вещи нужны для данного анализа.

При этом у селективности есть три профиля:

Затухающий профиль. Чем дальше источник во времени, тем слабее гарантированная селективность. Модель всё ещё может извлекать нужный токен, но это становится всё труднее с ростом расстояния из-за того, что разница между текущим токеном и остальной массой затухает.

Замороженный профиль. Качество retrieval не деградирует с расстоянием, то есть по мере роста лага разница между текущим токеном и остальной массой остаётся одного и того же порядка.

Возрастающий профиль. Разница между текущим токеном и остальной массой может не просто не падать, а расти, то есть идет накопление преимущества.

Все три профиля в комбинации и дают гибкое селективное извлечение. Вне ограничительных режимов каждую из архитектур можно настроить так, чтобы она успешно извлекала нужный вход.

Но в диффузном режиме и его аналоге для Mamba различия становятся принципиальными. Суть в том, что он не только нужен как равноценный режим сравнения памяти, сам по себе он часто может возникать, особенно на длинном контексте, то есть коэффициенты attention начинают конкурировать за массу, или в последовательности немало шума и Mamba теряет нужный токен. А значит данный анализ полезен на практике для длинного контекста.

Каждая из архитектур с одним слоем может реализовать только затухающий профиль. Но когда слоёв становится два или больше, только Sessa может реализовать замороженный и возрастающий профиль. То есть диффузный режим для неё не проблема. Интуитивно это связано с тем, что в многослойной модели возникает столько путей ко входу, что, несмотря на диффузность внимания, сама структура памяти может сфокусировать всю эту массу на нужном входе.

Позиционное кодирование и универсальная аппроксимация.

Sessa, как и Mamba, может кодировать APE внутри себя без явных таблиц или экстраполяции, в отличие от трансформера. Следовательно, её можно обучать без встроенного позиционного кодирования, но стоит учитывать, что к нему или к его альтернативному варианту модель должна ещё прийти в процессе обучения.

Ну, а так как есть APE, то универсальная теорема аппроксимации для отображения последовательности в последовательность сама собой напрашивалась как дополнительный теоретический результат. И это довольно приятно, так как трансформеру для универсальной аппроксимации необходим именно внешний APE.

Итог

Как видно, модели декодеров можно описать на общем фундаменте, где явно проявляются их различия. В статье я сравнил только три архитектуры, но к этому списку вполне можно добавить и другие.

Что касается Sessa, в исследовании я был больше сфокусирован на теории. Но в ближайших планах обучить модель на несколько миллиардов параметров и посмотреть, что там уже с длинным контекстом на больших масштабах.

Если вы хотите поддержать это исследование, я был бы очень благодарен за ваш голос на Hugging Face. Спасибо!

Автор: Flokis_guy

Источник