Momentum Attention: когда внимание получает инерцию. ai.. ai. attention.. ai. attention. deep learning.. ai. attention. deep learning. machine learning.. ai. attention. deep learning. machine learning. PyTorch.. ai. attention. deep learning. machine learning. PyTorch. Research.. ai. attention. deep learning. machine learning. PyTorch. Research. transformers.. ai. attention. deep learning. machine learning. PyTorch. Research. transformers. нейросети.

В классическом self-attention каждый токен смотрит на другие токены, чтобы понять, что важно в данный момент.
Внимание распределяется мгновенно:

Momentum Attention: когда внимание получает инерцию - 1

Именно этот механизм сделал трансформеры тем, чем они стали.

Но вот в чём проблема – внимание не имеет памяти.
На каждой итерации оно переобучается заново, не зная, куда оно смотрело в прошлый раз.
Из за этого внимание может скакать, шуметь и терять контекст, особенно в длинных последовательностях.

Проблема: внимание без инерции

Представьте, что вы идёте по неровной дороге.
Если вы будете менять направление мгновенно, без инерции, вас просто будет бросать из стороны в сторону.
Точно так же и внимание в трансформере:
оно то цепляется за один токен, то внезапно переключается на другой,
порождая хаотичные изменения в градиентах и мешая стабильному обучению.

А что, если добавить вниманию немного физики?

Momentum это понятие из механики.
Если у тела есть скорость, оно не останавливается мгновенно, а плавно замедляется.
Почему бы не применить тот же принцип к вниманию?

Идея:

Пусть текущее внимание немного зависит от того, каким оно было раньше.
Не только “куда я смотрю сейчас?”,
но и “куда я смотрел мгновение назад?”.

От классического внимания к Momentum Attention

В классике:

Momentum Attention: когда внимание получает инерцию - 2

Теперь добавим инерцию к Value-векторам:

Пояснение: Если бы я добавил инерцию к attn_scores, модель была бы вынуждена смотреть на те же самые токены, что и на прошлом шаге. Это очень жесткое ограничение. Добавляя инерцию к V, я позволяю вниманию свободно выбирать, куда смотреть на каждом шаге (Q и K новые), но информация, которую оно извлекает (V), будет смесью новой и старой.

Momentum Attention: когда внимание получает инерцию - 3

Тогда:

Momentum Attention: когда внимание получает инерцию - 4

То есть текущее внимание теперь частично помнит, какие значения были важны на предыдущем шаге. α (например, 0.9) задаёт вес настоящего по сравнению с прошлым

Простой пример на pytorch

import torch
import torch.nn as nn
import torch.nn.functional as F

class MomentumAttention(nn.Module):
    def __init__(self, d_model, n_heads=8, alpha=0.9):
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError("d_model должен делиться на n_heads без остатка")
            
        self.alpha = alpha
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, Q, K, V, prev_V=None):

        B, T_q, D = Q.shape
        _, T_k, _ = K.shape

        # Линейные проекции и разделение на головы
        q = self.W_q(Q).view(B, T_q, self.n_heads, self.d_k).transpose(1, 2) # [B, n_heads, T_q, d_k]
        k = self.W_k(K).view(B, T_k, self.n_heads, self.d_k).transpose(1, 2) # [B, n_heads, T_k, d_k]
        v = self.W_v(V).view(B, T_k, self.n_heads, self.d_k).transpose(1, 2) # [B, n_heads, T_k, d_k]

        # Применение Momentum к векторам Value
        if prev_V is None:
            # На самом первом шаге инерции нет, используем текущее значение
            v_momentum = v
        else:
            # Совмещаем текущее значение с прошлым
            v_momentum = self.alpha * v + (1 - self.alpha) * prev_V

        # 3. Сохраняем новое состояние для следующего шага.
        # .detach() используется, чтобы градиенты не текли через всю историю состояний,
        # что превратило бы механизм в полноценный RNN и сильно усложнило бы обучение.
        new_prev_V = v_momentum.detach()

        # 4. Стандартный механизм self-attention
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / (self.d_k ** 0.5)
        attn_weights = F.softmax(attn_scores, dim=-1)
        
        # Внимание применяется к инерционным значениям v_momentum
        out = torch.matmul(attn_weights, v_momentum)

        # 5. Собираем головы вместе и пропускаем через финальный линейный слой
        out = out.transpose(1, 2).contiguous().view(B, T_q, D)
        
        return self.W_o(out), new_prev_V


# Пример модели, которая использует MomentumAttention

class AutoregressiveModel(nn.Module):
    def __init__(self, vocab_size, d_model, n_heads, alpha):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.momentum_attn = MomentumAttention(d_model, n_heads, alpha)
        self.layernorm = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.ReLU(),
            nn.Linear(4 * d_model, d_model)
        )
        self.out_proj = nn.Linear(d_model, vocab_size)

    def forward(self, input_ids):

        B, T = input_ids.shape
        x = self.embedding(input_ids)

        # Инициализируем состояние для всей последовательности
        prev_V_state = None
        all_step_outputs = []

        # Цикл по каждому шагу (токену) в последовательности
        for t in range(T):
            # Берем срез данных для текущего шага
            # В реальном декодере Q - это текущий токен, K и V - все предыдущие.
            # Для простоты демонстрации механизма инерции, мы используем только текущий токен
            # как Q, K, и V. Это показывает, как состояние `prev_V_state` передается.
            current_x_step = x[:, t:t+1, :] # Shape: [B, 1, D]
            
            # Вызываем слой внимания, передавая ему состояние с прошлого шага
            attn_output, prev_V_state = self.momentum_attn(
                Q=current_x_step, 
                K=current_x_step, 
                V=current_x_step, 
                prev_V=prev_V_state
            )
            
            # Стандартные блоки трансформера (residual connection, layernorm, FFN)
            h = self.layernorm(current_x_step + attn_output)
            step_output = self.ffn(h)
            all_step_outputs.append(step_output)

        # Собираем выходы со всех шагов в один тензор
        full_output = torch.cat(all_step_outputs, dim=1) # Shape: [B, T, D]
        
        # Финальная проекция в размер словаря
        logits = self.out_proj(full_output)
        return logits


# Параметры
batch_size = 4
seq_len = 10
vocab_size = 100
d_model = 64
n_heads = 8
alpha = 0.9

# Создаем модель
model = AutoregressiveModel(vocab_size, d_model, n_heads, alpha)

# Создаем случайные входные данные
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))

print(f"Входные данные (shape): {input_ids.shape}")

# Получаем выход модели
output_logits = model(input_ids)

print(f"Выходные логиты (shape): {output_logits.shape}")

# Проверка корректности размеров
assert output_logits.shape == (batch_size, seq_len, vocab_size)

print("nМодель успешно отработала")

Что это даёт

  • Сглаживание представлений.
    Вектора V не перескакивают резко между шагами прошлое состояние частично сохраняется, что снижает турбулентность активаций.

  • Более стабильное распределение внимания.
    Модель получает эффект инерции в значениях, и внимание не скачет при малых изменениях входа. Это особенно полезно в авторегрессионных моделях, где выходы сильно зависят от предыдущего шага.

  • Облегчённое обучение.
    Так как prev_V передаётся через detach(), градиенты не текут сквозь всю историю, что предотвращает взрыв или затухание градиентов в отличие от полного RNN-подхода.

  • Простая интеграция.
    Механизм не требует изменения архитектуры он полностью совместим с обычным MultiHeadAttention и может быть вставлен в любой трансформерный блок.

Возможные минусы

  • Накопление смещения (drift).
    Если alpha слишком велико, старые состояния начинают тянуть новые векторные представления, и внимание может начать запоминать шум.

  • Сложность выбора alpha.
    Значение 0.9 подходит не всегда при быстрых изменениях контекста модель может терять реактивность (поздно реагировать на новые токены).

  • Невозможность параллелизации по времени.
    Так как состояние prev_V передаётся последовательно, обучение по всей последовательности становится менее параллельным (особенно при autoregressive setup).

  • Потенциальная инерция ошибок.
    Если модель делает ошибку на шаге t, она может частично переноситься дальше через prev_V, особенно при большом alpha.

Заключение

Momentum Attention это шаг в сторону более живых архитектур.
Мы не просто учим модель смотреть на токены,
мы учим её чувствовать движение своего внимания как будто у неё появилась инерция восприятия.

Автор: YH7H22

Источник

Rambler's Top100