Параллельность RNN?. apple.. apple. ICLR 2026.. apple. ICLR 2026. rnn.. apple. ICLR 2026. rnn. машинное+обучение.. apple. ICLR 2026. rnn. машинное+обучение. рекуррентная нейросеть.

Смотрели итоги прошедшего ICLR? Меня заинтересовала довольно провокационная, на первый взгляд, статья от Эплов — ParaRNN. Казалось бы, параллельность РНН — это их главный недостаток, благодаря которому их заменили трансформеры (в большинстве задач).

Так вот, давайте разберемся со всем, на максимально низком уровне, если знаете, что такое RNN и производная — то эта статья для вас.

1. Алгоритм DEER

DEER = Deep Equilibrium Evaluation of Recurrence (Lim et al., 2024). Базовый алгоритм, на котором строится ParaRNN.

1.1. Постановка как задача нахождения корня

Пусть у нас есть обыкновенная RNN с переходной функцией f: mathbb{R}^D to mathbb{R}^D, начальным состоянием mathbf{s}_0 и неизвестными состояниями mathbf{s}_1, ldots, mathbf{s}_T. Введем остаток (residual):

mathbf{r}(mathbf{s}_{1:T}) :=[mathbf{s}_1 - f(mathbf{s}_0), mathbf{s}_2 - f(mathbf{s}_1), ldots, mathbf{s}_T - f(mathbf{s}_{T-1})] in mathbb{R}^{T times D}

Истинная траектория mathbf{s}^*_1, ldots, mathbf{s}^*_T – это единственное решение уравнения:

mathbf{r}(mathbf{s}^*_{1:T})=mathbf{0}


Когда говорят «применить RNN к последовательности», имеют в виду стандартную процедуру: взять начальное состояние mathbf{s}_0, применить переходную функцию f, получить mathbf{s}_1, потом еще раз применить f, получить mathbf{s}_2, и так далее:

mathbf{s}_1=f(mathbf{s}_0), quad mathbf{s}_2=f(mathbf{s}_1), quad ldots, quad mathbf{s}_T=f(mathbf{s}_{T-1})

Соответственно, получается, что mathbf{r} – вектор, у которого все элементы равны 0, опять же потому что при соблюдении рекуррентности mathbf{s}_1=f(mathbf{s}_0) и mathbf{s}_1 - f(mathbf{s}_0)=0.


1.2. Итерации Ньютона

Соответственно дальше, необходимо найти решение уравнения mathbf{r}(mathbf{s})=0, или в полном случае — вектор, решающий систему уравнений. Но для начала разберемся со скалярным случаем.

Скалярный случай: одно уравнение от одной переменной

Пусть у нас есть гладкая функция r: mathbb{R} to mathbb{R} и мы хотим найти такое s^*, что r(s^*)=0. Геометрически — найти точку, где график функции пересекает ось абсцисс.

Идея метода Ньютона строится на простой мысли: в малой окрестности точки гладкая функция почти неотличима от своей касательной. Если мы стоим в текущем приближении s^{(i)} (которое, в общем случае, не корень — там r(s^{(i)}) neq 0), мы можем сделать вид, что r — это ее касательная в этой точке, и для такой линейной функции легко аналитически найти, где она пересекает ось.

Касательная к r в точке s^{(i)} — это первое слагаемое разложения Тейлора:

r(s) approx r(s^{(i)}) + r'(s^{(i)}),(s - s^{(i)})


Вспомним: Разложение Тейлора — это способ приблизить любую гладкую функцию вблизи точки s_0 многочленом:

r(s)=r(s_0) + r'(s_0)(s - s_0) + frac{r''(s_0)}{2!}(s - s_0)^2 + frac{r'''(s_0)}{3!}(s - s_0)^3 + ldots

где каждое следующее слагаемое уточняет приближение, добавляя информацию о все более тонкой особенности формы функции (наклон, кривизна, и т.д.). Логический смысл такой: если функция гладкая, то ее поведение в окрестности точки полностью закодировано в значениях ее производных в этой одной точке — измерив несколько чисел в s_0, мы можем восстановить значения функции рядом. Делитель k! возникает естественно из требования, чтобы в точке s_0 совпадали все производные многочлена и самой функции (он сокращается с факториалом, выскакивающим при k-кратном дифференцировании (s - s_0)^k).


Приравниваем эту линейную аппроксимацию к нулю и находим, где она пересекает ось:

r(s^{(i)}) + r'(s^{(i)}),(s - s^{(i)})=0

Решаем относительно s :

s=s^{(i)} - frac{r(s^{(i)})}{r'(s^{(i)})}

Это и объявляем следующим приближением:

s^{(i+1)}=s^{(i)} - frac{r(s^{(i)})}{r'(s^{(i)})}

Параллельность RNN? - 37

Тут показан графически шаг s^{(i)} to s^{(i+1)}, и на графике видно, что корень уравнения (нас интересует пересечение функции с осью абсцисс) поменялся с 2 до 1, что показывает улучшение, так как эталонное значение — 0.

Можно переписать это через приращение Delta s^{(i+1)} :=s^{(i+1)} - s^{(i)}, что окажется удобнее при обобщении:

r'(s^{(i)}),Delta s^{(i+1)}=-r(s^{(i)})

То есть «найти такое приращение, чтобы линейная поправка r' cdot Delta s скомпенсировала текущий остаток r».

Многомерный случай: N уравнений от N переменных

Теперь обобщаем. Вместо одной функции от одной переменной mathbf{r}: mathbb{R}^N to mathbb{R}^N — векторнозначная функция векторного аргумента, и ищем такой вектор mathbf{s}^* in mathbb{R}^N, что mathbf{r}(mathbf{s}^*)=mathbf{0}.

Логика остается такой же, меняются только объекты или их размерности:

Скаляр

Вектор

функция r(s)

вектор-функция mathbf{r}(mathbf{s})

производная r'(s) — число

якобиан J(mathbf{s}) – матрица N times N

касательная (прямая)

касательная гиперплоскость

деление на r'

умножение на J^{-1} (то есть, решение линейной системы)

Где J(mathbf{s})=dfrac{partial mathbf{r}}{partial mathbf{s}} — якобиан, многомерный аналог обычной производной (об этом подробнее ниже).

Якобиан J(mathbf{s}) — это просто матрица всех частных производных: в позиции (i, j) стоит partial r_i / partial s_j. Он играет роль производной — показывает, как малое изменение mathbf{s} влияет на mathbf{r} в линейном приближении.

Линеаризация mathbf{r} вокруг точки mathbf{s}^{(i)}:

mathbf{r}(mathbf{s}) approx mathbf{r}(mathbf{s}^{(i)}) + J(mathbf{s}^{(i)}),(mathbf{s} - mathbf{s}^{(i)})

Приравниваем линейную аппроксимацию к нулевому вектору:

mathbf{r}(mathbf{s}^{(i)}) + J(mathbf{s}^{(i)}),(mathbf{s} - mathbf{s}^{(i)})=mathbf{0}

И обозначив приращение Delta mathbf{s}^{(i+1)} :=mathbf{s} - mathbf{s}^{(i)}, получаем линейную систему относительно Delta mathbf{s}^{(i+1)}:

J(mathbf{s}^{(i)}),Deltamathbf{s}^{(i+1)}=-mathbf{r}(mathbf{s}^{(i)})

Решив систему и получив Delta mathbf{s}^{(i+1)}, обновляем приближение:

mathbf{s}^{(i+1)}=mathbf{s}^{(i)} + Deltamathbf{s}^{(i+1)}

Также можно записать более компактно через обратную матрицу:

mathbf{s}^{(i+1)}=mathbf{s}^{(i)} - J(mathbf{s}^{(i)})^{-1},mathbf{r}(mathbf{s}^{(i)})

— это та же самая формула, просто короче. Запись с J^{-1} — чисто нотационная: на практике обратную матрицу никто никогда не вычисляет, потому что это и дорого, и численно неустойчиво. Вместо этого решают систему J cdot Deltamathbf{s}=-mathbf{r} напрямую — например, через LU-разложение, или через прямую подстановку, если J имеет специальную структуру (что и происходит у нас).

1.3. Применение к нашей задаче с RNN

В случае RNN все ровно по этому шаблону, только размерности конкретные:

  • mathbf{s}=(mathbf{s}_1, ldots, mathbf{s}_T) in mathbb{R}^{TD} – все скрытые состояния, склеенные в один длинный вектор длины TD.

  • mathbf{r}: mathbb{R}^{TD} to mathbb{R}^{TD} — вектор всех одношаговых остатков, той же длины.

  • J(mathbf{s}) in mathbb{R}^{TD times TD} — якобиан остатка по состоянию.

Применяем тот же ньютоновский шаг:

J(mathbf{s}^{(i)}),Deltamathbf{s}^{(i+1)}=-mathbf{r}(mathbf{s}^{(i)})

И тут возникает вопрос: «А разве решить линейную систему размера TD times TD — это не та же самая последовательная задача? Где здесь параллелизация?»

Если бы J была произвольной плотной матрицей, то да — наивно решение стоило бы O((TD)^3), и никакой выгоды бы не было. Но J не произвольная. Из-за марковости RNN (каждый шаг f видит только предыдущее состояние mathbf{s}_{t-1}, а не всю историю) в якобиане подавляющее большинство блоков — нули. Конкретно: в блочной строке номер t ненулевые элементы есть только в столбцах t и t-1. Получается блочно-бидиагональная структура:

J(mathbf{s})=begin{pmatrix}I_D & 0 & 0 & cdots & 0 \-frac{partial f}{partial mathbf{s}}(mathbf{s}_1) & I_D & 0 & cdots & 0 \0 & -frac{partial f}{partial mathbf{s}}(mathbf{s}_2) & I_D & cdots & 0 \vdots & & ddots & ddots & vdots \0 & 0 & cdots & -frac{partial f}{partial mathbf{s}}(mathbf{s}_{T-1}) & I_Dend{pmatrix}


Что такое якобиан вообще

Когда есть обычная функция от одной переменной r: mathbb{R} to mathbb{R}, ее производная r'(s) – это одно число, которое говорит «насколько быстро меняется выход при малом изменении входа». Оно играет роль локального коэффициента пропорциональности: если сместить s на маленькое delta, то r изменится примерно на r'(s) cdot delta.

Теперь представим, что функция, у которой и вход, и выход — векторы. Скажем, mathbf{r}: mathbb{R}^N to mathbb{R}^M: на вход подаем вектор из N чисел, на выход получаем вектор из M чисел. Понятие «производной» здесь усложняется, потому что теперь надо отвечать на N times M вопросов одновременно: «как меняется i-я компонента выхода при изменении j-й компоненты входа?». Ответы на все эти вопросы естественно собираются в матрицу размера M times N – это и есть якобиан:

J(mathbf{s})=frac{partial mathbf{r}}{partial mathbf{s}}=begin{pmatrix}frac{partial r_1}{partial s_1} & frac{partial r_1}{partial s_2} & cdots & frac{partial r_1}{partial s_N} \frac{partial r_2}{partial s_1} & frac{partial r_2}{partial s_2} & cdots & frac{partial r_2}{partial s_N} \vdots & vdots & ddots & vdots \frac{partial r_M}{partial s_1} & frac{partial r_M}{partial s_2} & cdots & frac{partial r_M}{partial s_N}end{pmatrix}

В позиции (i, j) стоит число partial r_i / partial s_j — частная производная i-й компоненты выхода по j-й компоненте входа. То есть якобиан буквально — это полная карта чувствительностей: каждая ячейка отвечает на конкретный вопрос «насколько чутко эта выходная координата реагирует на эту входную координату».


Вспомним определение остатка:

mathbf{r}_t(mathbf{s}_{1:T})=mathbf{s}_t - f(mathbf{s}_{t-1})

Это выражение зависит только от двух переменных: от mathbf{s}_t (через первое слагаемое) и от mathbf{s}_{t-1} (через второе). Все остальные mathbf{s}_k в формуле просто не присутствуют. А производная по переменной, которой в формуле нет, равна нулю.

Разберем по случаям, какой блок partial mathbf{r}_t / partial mathbf{s}_k получается при разных k:

Случай 1: k=t. Берем производную mathbf{s}_t - f(mathbf{s}_{t-1}) по mathbf{s}_t. Только первое слагаемое зависит от mathbf{s}_t, и его производная по самому себе — единичная матрица. Получаем:

frac{partial mathbf{r}_t}{partial mathbf{s}_t}=I_D

Случай 2: k=t - 1. Берем производную по mathbf{s}_{t-1}. Первое слагаемое от нее не зависит, второе — это -f(mathbf{s}_{t-1}), и его производная — это -partial f / partial mathbf{s}, вычисленная в точке mathbf{s}_{t-1}:

frac{partial mathbf{r}_t}{partial mathbf{s}_{t-1}}=-frac{partial f}{partial mathbf{s}}(mathbf{s}_{t-1})

Случай 3: все остальные k (то есть k neq t и k neq t - 1). Переменная mathbf{s}_k в формуле для mathbf{r}_t просто не встречается. Значит:

frac{partial mathbf{r}_t}{partial mathbf{s}_k}=0_{D times D} quad (text{нулевая матрица})

Вот и все. Из T^2 блоков ненулевых ровно T + (T-1)=2T - 1: T единичных матриц на главной диагонали и T-1 якобианов перехода на поддиагонали. Все остальное — нули. Если расписать всю матрицу:

J(mathbf{s})=begin{pmatrix}I_D & 0 & 0 & cdots & 0 \-frac{partial f}{partial mathbf{s}}(mathbf{s}_1) & I_D & 0 & cdots & 0 \0 & -frac{partial f}{partial mathbf{s}}(mathbf{s}_2) & I_D & cdots & 0 \vdots & & ddots & ddots & vdots \0 & 0 & cdots & -frac{partial f}{partial mathbf{s}}(mathbf{s}_{T-1}) & I_Dend{pmatrix}

Где здесь марковость

Ключевая причина, по которой эта структура появилась — марковское свойство RNN. Переходная функция f в каждом шаге смотрит только на предыдущее состояние mathbf{s}_{t-1}, а не на всю историю mathbf{s}_1, ldots, mathbf{s}_{t-1}. Из-за этого остаток mathbf{r}_t оказывается «локальным» объектом: он зависит только от двух соседних состояний — текущего mathbf{s}_t и предыдущего mathbf{s}_{t-1}.

Сколько нам реально нужно памяти

Хотя формально якобиан имеет TD times TD ячеек, нам нужно хранить только ненулевые блоки. Это:

  • T единичных матриц I_D — но их даже хранить не нужно, мы знаем, что это I_D и можем подставлять на лету;

  • T - 1 якобианов перехода partial f / partial mathbf{s}(mathbf{s}_t) размера D times D — это (T-1) cdot D^2 чисел, что для T=1000, D=256 дает около 65 миллионов чисел вместо 65 миллиардов. Уже выполнимо.

Как структура якобиана дает параллелизм

Теперь главный вопрос: почему такая структура позволяет решить систему J cdot Deltamathbf{s}=-mathbf{r} параллельно? Здесь важно различать два уровня:

Уровень 1: структура позволяет решить систему через прямую подстановку. Возьмем систему J cdot Deltamathbf{s}=-mathbf{r} и распишем ее построчно. Первая блочная строка матрицы J — это (I_D, 0, 0, ldots, 0), поэтому первое уравнение системы это:

I_D cdot Deltamathbf{s}_1=-mathbf{r}_1

  • то есть просто Deltamathbf{s}_1=-mathbf{r}_1. Получили первый кусок ответа практически даром.

Вторая блочная строка J это (-partial f / partial mathbf{s}(mathbf{s}_1), I_D, 0, ldots, 0), поэтому второе уравнение:

-frac{partial f}{partial mathbf{s}}(mathbf{s}_1) cdot Deltamathbf{s}_1 + I_D cdot Deltamathbf{s}_2=-mathbf{r}_2

Откуда:

Deltamathbf{s}_2=frac{partial f}{partial mathbf{s}}(mathbf{s}_1) cdot Deltamathbf{s}_1 - mathbf{r}_2

И вообще, для произвольного

t > 1:

Deltamathbf{s}_t=frac{partial f}{partial mathbf{s}}(mathbf{s}_{t-1}) , Deltamathbf{s}_{t-1} - mathbf{r}_t

Это и есть линейная рекуррентность, в которую превратилась наша гигантская система TD times TD. Заметим, что в общем случае решение линейной системы стоит O(N^3) — но мы здесь обошлись без обращения какой-либо матрицы, благодаря тому что J блочно-бидиагональна. Система решается простой пробежкой по уравнениям сверху вниз. Это и называется forward substitution.

Если бы дело закончилось здесь, мы бы получили лишь последовательный алгоритм за O(T) шагов — каждый Deltamathbf{s}_t зависит от Deltamathbf{s}_{t-1}, и пробежать рекурренцию надо строго по порядку. То же самое, что просто прогнать RNN последовательно. Параллелизм рождается на следующем уровне.

Уровень 2: рекуррентность линейна, и потому ассоциативна. В этом — главный фокус. Обратим внимание на принципиальную разницу между двумя ситуациями:

  • Исходная RNN: mathbf{s}_t=f(mathbf{s}_{t-1}, mathbf{x}_t) — функция f нелинейна, поэтому распараллелить такую рекуррентность нельзя: приходится считать каждый шаг честно по очереди.

  • Рекуррентность для Deltamathbf{s}: Deltamathbf{s}_t=A_t cdot Deltamathbf{s}_{t-1} + mathbf{b}_t (где A_t=partial f / partial mathbf{s}(mathbf{s}_{t-1}), mathbf{b}_t=-mathbf{r}_t) – она линейна. Это значит, что из нее можно получить замкнутую формулу:

Deltamathbf{s}_t=A_t A_{t-1} cdots A_2 cdot Deltamathbf{s}_1 + (A_t cdots A_3 cdot mathbf{b}_2) + (A_t cdots A_4 cdot mathbf{b}_3) + ldots + mathbf{b}_t

Все эти произведения матриц A_t cdot A_{t-1} cdot ldots можно вычислить в любом порядке (умножение матриц ассоциативно: (AB)C=A(BC)). А значит, можно построить дерево вычислений, в котором мы сначала параллельно считаем все попарные произведения A_2 cdot A_1, A_4 cdot A_3, A_6 cdot A_5, …, затем все четверки A_4 cdot A_3 cdot A_2 cdot A_1, A_8 cdot A_7 cdot A_6 cdot A_5, …, и так далее. За log_2 T уровней дерева мы получаем все нужные накопленные произведения, и из них собираем все Deltamathbf{s}_t одновременно.

Это и есть параллельный скан (он же parallel prefix sum в обобщенной форме). Аналогия с обычным сложением: если надо сложить миллиард чисел, последовательно — миллиард шагов, а попарным деревом — всего log_2(10^9) approx 30 уровней. Тот же прием работает для любой ассоциативной операции, а композиция линейных отображений (умножение их матриц) — ассоциативна.

Итог по сложности: один шаг Ньютона выполняется за O(log T) параллельной глубины (вместо O(T) последовательных шагов), а все применение RNN — за O(text{iters} cdot log T), где iters — число итераций Ньютона.

Автор: yeetmq

Источник