Автодифференцирование на C++: обратное распространение через лямбды и std::function. backpropagation.. backpropagation. graph.. backpropagation. graph. с++.

Меня зовут Кирилл Колодяжный, я разрабатываю системы хранения данных в YADRO. Это третья, заключительная часть моего цикла о паттернах C++, которые я применяю для решения задач машинного обучения, а вы можете использовать и в другой работе. В этой статье поговорим, как построить вычислительные графы и реализовать обратное распространение ошибки без сложных иерархий классов, с помощью лямбда-функций и стандартной библиотеки.

В конце материала я сравнил свой подход с вариантом PyTorch и оставил ссылки на полезные материалы, в том числе на предыдущие части цикла.

Недавно в рассылке для разработчиков на С++ я рассказал об еще одной полезной практике — проведении код-ревью по правилам от Google. В этом же выпуске особенностями проверки кода поделились мои коллеги из команды разработки ПО для KORNFELD и телекома. Оставьте email-адрес на странице подписки, чтобы получить это письмо.

Зачем разработчикам вычислительный граф

Когда мы обучаем нейросеть, входные данные проходят через десятки слоев, каждый из которых применяет свои веса, активации, нормализацию. В конце мы получаем решение, сравниваем его с правильным ответом и затем разбираемся, как изменить каждый параметр сети, чтобы ошибка уменьшилась.

Автодифференцирование на C++: обратное распространение через лямбды и std::function - 1

В математике это называется вычислением производной функции потерь по каждому весу. Но если в сети есть хотя бы 10 слоев, аналитическая формула производной превращается в чудовище, которое невозможно ни описать, ни отладить.

Хорошая новость: правило цепочки (chain rule) или правило дифференцирования сложной функции из матанализа позволяет разбить эту гигантскую производную на маленькие кусочки — производные элементарных операций (сложение, умножение, сигмоида).

Плохая новость: нужно как-то запомнить, в каком порядке применялись операции, чтобы потом пройти назад и собрать все эти кусочки вместе.

Именно для этого нам пригодится концепция вычислительного графа.

Вычислительный граф — это способ представить программу как направленный ациклический граф (DAG), в котором:

  • вершины — это операции (сложение, умножение, матричное умножение);

  • ребра — это связи, по которым передаются данные между операциями.

Вот как выглядит граф для простого выражения z = x * y + b:

Голубые узлы — входные данные (листья графа), оранжевые узлы — операции, зеленый узел — финальный результат.

Голубые узлы — входные данные (листья графа), оранжевые узлы — операции, зеленый узел — финальный результат.

Во время прямого прохода (forward pass) мы вычисляем значения снизу вверх: сначала x * y, потом результат плюс b. Во время обратного прохода (backward pass) идем сверху вниз и вычисляем градиенты: от z к x, y и b.

Класс Variable: операция и данные

Давайте перейдем к коду. Нам нужен класс, который будет представлять узел графа.

Класс Variable — это не только хранилище данных, но и контейнер для операции для обратного прохода. В нашей реализации мы также храним результат прямого прохода и градиент, но это скорее деталь архитектуры, а не обязательное свойство узла вычислительного графа. В других фреймворках данные и операции могут быть разделены.

class Variable {
public:
    Variable(const Tensor& data, 
             const std::vector<Variable>& children,
             bool requires_grad = false);

    void backward();

private:
    // Результат операции (данные, полученные после вычисления)
    Tensor data_;
    
    // Градиент — накапливаемая производная
    Tensor grad_;
    
    // Аргументы операции (предки в графе прямого прохода)
    std::vector<Variable> children_;
    
    // Функция обратного прохода для этой конкретной операции
    std::function<void(const Tensor&)> backward_fn_;
    
    // Флаг: нужно ли вычислять градиент для этого узла
    bool requires_grad_{true};
};
Автодифференцирование на C++: обратное распространение через лямбды и std::function - 3
  • data_ — результат операции. Например, если это узел умножения x * y, то здесь будет произведение.

  • grad_ — градиент, который приходит с предыдущего шага обратного прохода. Изначально нулевой, накапливается при вызове backward().

  • children_ — входные аргументы операции. В примере z = x * y у узла z будет два аргумента: x и y. При прямом проходе это на самом деле предки данного узла, но на обратном их роль меняется.

  • backward_fn_ — функция, которая знает, как вычислить градиенты для этой операции.

  • requires_grad_ — оптимизация: не для всех узлов нужно вычислять градиенты (например, константы можно пропускать).

На самом деле далее я подразумеваю, что описанная функциональность и поля класса содержаться в объекте класса типа VariableImpl, а объекты типа Varibale хранят только указатели на них, что дает возможность копировать их.

Строим граф через перегрузку операторов

Теперь интересный момент: мы хотим, чтобы обычный код вида z = x * y + b автоматически строил граф. Для этого перегружаем операторы:

Variable operator*(const Variable& a, const Variable& b) {
    // 1. Прямой проход: вычисляем результат
    Variable result(a.data_ * b.data_, {a, b}, 
                    a.requires_grad_ || b.requires_grad_);

    // 2. Инициализируем функцию обратного прохода
    result.backward_fn_ = [a = а, b = b](const Tensor& out_grad) mutable {
        // Правило произведения: d(a*b)/da = b, d(a*b)/db = a
        if (a->requires_grad_)
            a->grad_ += b->data_ * out_grad;
        if (b->requires_grad_)
            b->grad_ += a->data_ * out_grad;
    };

    return result;
}
Автодифференцирование на C++: обратное распространение через лямбды и std::function - 4
  1. Создаем новый узел result с данными получившимися в результате операции a.data_ * b.data_ и детьми(или предками) {a, b}.

  2. Записываем в backward_fn_ лямбду, которая знает правило дифференцирования для умножения.

  3. Через capture list [a = ..., b = ...] захватываем указатели на аргументы. Таким неявным способом создаются ребра графа. Это критически важно: когда позже мы будем вызывать backward(), нам понадобятся исходные значения a.data_ и b.data_ для вычисления градиентов.

  4. mutable позволяет изменять захваченные объекты (в данном случае — добавлять к grad_).

Аналогично реализуются +, -, matmul, sigmoid и другие операции. Каждая операция — содержит свою лямбду с правилом дифференцирования.

Обратный проход: топологическая сортировка

Теперь у нас есть граф. Как его пройти в обратном порядке?

Проблема в том, что граф — это нелинейная цепочка. Один узел может быть использован в нескольких местах. Например:

Автодифференцирование на C++: обратное распространение через лямбды и std::function - 5

Здесь x участвует в двух умножениях. При обратном проходе градиенты от обоих путей должны сложиться в x.grad_.

Решение — топологическая сортировка. Мы обходим граф в глубину (DFS), запоминаем порядок посещения, а потом разворачиваем его.

void build_topo_order(const Variable& node,
                      std::vector<Variable>& topo,
                      std::set<Variable>& visited) {
if (!visited.contains(node)) {
  visited.insert(node);
  for (auto& child : node->children_) {
    build_topo_order(child.get(), topo, visited);
  }
  topo.push_back(node);
}

void Variable::backward() {
    if (!requires_grad_) return;

    // 1. Топологическая сортировка (DFS)
    std::vector<Variable> topo_ordered;
    std::set<Variable> visited;
    build_topo_order(*this, topo_ordered, visited);

    // 2. Разворачиваем порядок (от листьев к корню)
    std::reverse(topo_ordered.begin(), topo_ordered.end());

    // 3. Инициализируем градиент корня как 1
    grad_ = Tensor::ones_like(data_);

    // 4. Последовательно вызываем backward_fn_ для каждого узла
    for (auto& v : topo_ordered) {
        if (v->requires_grad_ && v->grad_.defined()) {
            v->backward_fn_(v->grad_);
        }
    }
}
Автодифференцирование на C++: обратное распространение через лямбды и std::function - 6

После этого все градиенты накапливаются в полях grad_, в листьях графа.

Обновляем веса: градиентный спуск

Финальный шаг — использовать градиенты для обновления параметров, то есть для обучения. Но откуда берутся эти параметры? Давайте посмотрим на примере простого линейного слоя:

class Linear {
public:
    Linear(size_t in_features, size_t out_features) {
        // Инициализируем веса случайными значениями
        // requires_grad = true, потому что веса нужно обучать
        weight_ = Variable(
            Tensor::randn({in_features, out_features}), 
            {}, 
            true  // <-- этот параметр нужно дифференцировать
        );
        
        // Инициализируем смещение (bias) нулями
        // тоже требует градиента
        bias_ = Variable(
            Tensor::zeros({out_features}), 
            {}, 
            true  // <-- bias тоже обучаемый параметр
        );
    }
    
    // Прямой проход: y = x @ W + b
    Variable forward(const Variable& input) {
        return matmul(input, weight_) + bias_;
    }
    
    // Возвращаем список всех обучаемых параметров слоя
    std::vector<Variable> parameters() {
        return {weight_, bias_};
    }
    
private:
    Variable weight_;  // Веса слоя
    Variable bias_;    // Смещение
};
Автодифференцирование на C++: обратное распространение через лямбды и std::function - 7
  • weight_ — матрица весов размера [in_features × out_features]. Это основной обучаемый параметр, который сеть подстраивает в процессе обучения.

  • bias_ — вектор смещения размера [out_features]. Добавляется к результату умножения, позволяет сдвигать функцию активации.

  • requires_grad = true — критически важный флаг. Только для этих узлов будет вычисляться градиент при вызове backward().

  • parameters() — метод, который возвращает ссылки на все обучаемые параметры. Именно эти переменные мы будем обновлять в цикле градиентного спуска.

    Собираем все вместе: полный цикл обучения

    Теперь разберемся, как использовать Linear в реальном цикле обучения:

// Создаём слой: 784 входа (например, пиксели изображения), 10 выходов (классы)
Linear layer(784, 10);

// Собираем все параметры сети в один список
// В реальной сети может быть много слоёв, поэтому собираем от всех
std::vector<Variable> parameters = layer.parameters();

float learning_rate = 0.01f;

for (size_t epoch = 0; epoch < 100; ++epoch) {
    // 1. Прямой проход: получаем предсказания
    Variable predictions = layer.forward(input_data);
    
    // 2. Вычисляем функцию потерь (например, MSE)
    Variable loss = mse_loss(predictions, target_labels);
    
    // 3. Обратный проход: вычисляем градиенты всех параметров
    loss.backward();
    
    // 4. Обновляем веса через градиентный спуск
    for (auto& param : parameters) {
        if (param->requires_grad() && param->grad().defined()) {
            // Основное правило обновления: w = w - lr * grad
            param->data() -= param->grad() * learning_rate;
            
            // Сбрасываем градиент для следующей итерации
            // Иначе градиенты будут накапливаться между эпохами
            param->grad().zero_();
        }
    }
}
Автодифференцирование на C++: обратное распространение через лямбды и std::function - 8

Ключевые моменты:

  1. layer.parameters() — возвращает параметры. В нашем случае это weight_ и bias_ из класса Linear.

  2. loss.backward() — после этого вызова в полях grad_ всех параметров (весов и смещений) накоплены корректные градиенты.

  3. param->data() -= param->grad() * learning_rate — градиентный спуск: двигает каждый параметр в направлении, противоположном градиенту, чтобы уменьшить функцию потерь.

  4. param->grad().zero_() — обязательный сброс градиентов. Если этого не сделать, градиенты будут суммироваться между итерациями, что приведёт к некорректному обучению.

Почему мы не обновляем градиенты прямо в backward_fn_?

Тут происходит разделение ответственности: backward() только вычисляет градиенты и отвечает на вопрос «в какую сторону менять параметры?». Обновление весов — это отдельный шаг, который отвечает на вопрос, насколько сильно менять. Это разделение позволяет:

  • использовать разные оптимизаторы (SGD, Adam, RMSprop) без изменения логики обратного прохода,

  • легко добавлять регуляризацию (L2, L1) на этапе обновления весов,

  • делать градиентный клиппинг (ограничение градиентов) перед обновлением.

Так и работает градиентный спуск: после вычисления всех градиентов мы обновляем параметры сети, чтобы минимизировать функцию потерь.

А зачем каждый раз создавать новую лямбда-функцию? Не проще ли сделать иерархию классов, как в промышленных фреймворках? Давайте посмотрим, как это реализовано в PyTorch на уровне C++.

В основе autograd PyTorch лежит система классов, определенных в torch/csrc/autograd/. Основные компоненты рассмотрим ниже.

Класс Node, базовый класс для всех операций:

struct TORCH_API Node : std::enable_shared_from_this<Node> {
    // ... другие поля и методы
    
    // Оператор вызова функции — это точка входа в узел
    variable_list operator()(variable_list&& inputs) {
        // Здесь происходит логика вызова apply() с дополнительной обработкой
        // (проверки, трассировка, обработка ошибок)
        return apply(std::move(inputs));
    }

protected:
    // Чисто виртуальный метод, который переопределяется для каждой операции
    // Именно здесь реализуется логика вычисления градиентов
    virtual variable_list apply(variable_list&& inputs) = 0;
    
    // Список рёбер для связи с другими узлами
    // Каждое ребро указывает на следующий узел в графе
    edge_list next_edges_;
    
    // ... другие поля
};
Автодифференцирование на C++: обратное распространение через лямбды и std::function - 9
  • operator() — это точка входа, через которую вызывается узел. Когда движок autograd обходит граф, он вызывает этот оператор для каждого узла.

  • apply() — виртуальный метод, который переопределяется в каждом конкретном классе операции (например, MulBackward0). Именно здесь содержится логика вычисления градиентов.

  • next_edges_ — список ребер, которые связывают этот узел с другими узлами графа. Через эти ребра происходит обход графа при обратном проходе.

Структура Edge, которая связывает узлы между собой:

struct Edge {
    // Указатель на узел (функцию), к которой ведёт это ребро
    std::shared_ptr<Node> function;
    
    // Номер входа функции — идентифицирует конкретный вход среди нескольких
    uint32_t input_nr;
};
Автодифференцирование на C++: обратное распространение через лямбды и std::function - 10
  • function — указатель на узел графа. Это позволяет ребрам хранить ссылки на другие операции.

  • input_nr — номер входа функции. Если у узла несколько входов (например, умножение принимает два аргумента), это поле указывает, к какому именно входу относится данное ребро.

 Класс AutogradMeta, метаданные тензора:

struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface {
    // Накопленный градиент для этого тензора
    Variable grad_;
    
    // Указатель на функцию градиента — узел, который вычислит градиент
    // для этой операции при обратном проходе
    std::shared_ptr<Node> grad_fn_;
    
    // Аккумулятор градиентов — используется, когда тензор участвует
    // в нескольких операциях и нужно собрать градиенты со всех путей
    std::weak_ptr<Node> grad_accumulator_;
    
    // ... другие поля
};
Автодифференцирование на C++: обратное распространение через лямбды и std::function - 11
  • grad_ — здесь накапливается градиент для данного тензора. При каждом вызове backward() градиенты суммируются в этом поле.

  • grad_fn_ — указатель на узел графа, который отвечает за вычисление градиента этой операции. Именно через этот указатель движок autograd знает, какую функцию вызвать при обратном проходе.

  • grad_accumulator_ — специальный механизм для случая, когда тензор используется в нескольких операциях. Например, если один и тот же тензор участвует в двух разных путях вычислений, градиенты от обоих путей нужно сложить.

Объекты этого класса создаются, например, так:

Tensor at::empty(IntArrayRef size, const TensorOptions& options) {
    // Создаём базовый тензор с данными
    Tensor tensor = detail::empty_aten_default(size, options);
    
    // Если требуется градиент — создаём AutogradMeta
    if (options.requires_grad()) {
        // Выделяем метаданные для autograd
        // Это происходит только если requires_grad = true
        tensor.mutable_autograd_meta() = 
            c10::make_intrusive<AutogradMeta>();
    }
    
    return tensor;
}
Автодифференцирование на C++: обратное распространение через лямбды и std::function - 12

Конкретная операция в PyTorch, пример умножения тензоров:

// Упрощённая версия
Tensor mul(const Tensor& self, const Tensor& other) {
    // 1. Вычисляем результат прямого прохода
    Tensor result = at::mul(self, other);
    
    // 2. Проверяем, нужно ли строить граф
    if (self.requires_grad() || other.requires_grad()) {
        // 3. Создаём узел обратного прохода
        // Сохраняем self и other для использования в backward()
        auto grad_fn = std::make_shared<MulBackward0>(self, other);
        
        // 4. Собираем рёбра от входных тензоров
        grad_fn->set_next_edges(collect_next_edges({self, other}));
        
        // 5. Связываем результат с узлом графа
        // Теперь result.grad_fn() указывает на MulBackward0
        set_history(result, grad_fn, 0);
    }
    
    return result;
}
Автодифференцирование на C++: обратное распространение через лямбды и std::function - 13

На самом деле эта и другие функции в PyTorch в основном создаются автоматической кодогенерацией при сборке библиотеки. Существует специальный набор правил в формате YAML.

Связывание результата с узлом через set_history:

inline void set_history(
    Tensor& self,
    std::shared_ptr<Node> grad_fn,
    uint32_t output_nr = 0) {
    
    // Получаем или создаём AutogradMeta для тензора
    auto* meta = self.mutable_autograd_meta();
    
    // Устанавливаем указатель на функцию градиента
    // Это связывает тензор с узлом графа
    meta->set_grad_fn(std::move(grad_fn));
    
    // Устанавливает номер выхода (если у узла несколько выходов)
    meta->set_output_nr(output_nr);
}
Автодифференцирование на C++: обратное распространение через лямбды и std::function - 14
  • mutable_autograd_meta() — гарантирует, что у тензора есть AutogradMeta. Если нет — создает.

  • set_grad_fn() — устанавливает связь между тензором и узлом графа. Теперь при вызове backward() движок знает, какую функцию для расчета градиента вызывать для этого тензора.

  • output_nr — номер выхода узла. Если операция возвращает несколько тензоров (например, torch.topk), каждый выход имеет свой номер.

    А вот так реализуется функция для расчета градиента:

variable_list MulBackward0::apply(variable_list&& grads) {
     ...
    
    // Создаём список для хранения градиентов по входам
    // Размер списка равен количеству входов, для которых нужно вычислить градиент
    variable_list grad_inputs(gen.size());
    
    // Извлекаем градиент с предыдущего шага обратного прохода
    // grads[0] — это производная функции потерь по выходу этой операции
    auto& grad = grads[0];
    
    ...    
    
    // Проверяем, определён ли хотя бы один градиент во входном списке
    // Это нужно для корректной обработки случаев, когда некоторые
    // градиенты могут быть неопределены
    bool any_grad_defined = any_variable_defined(grads);
    
    // Проверяем, нужно ли вычислять градиент для первого аргумента (self)
    if (should_compute_output({ self_ix })) {
        // Вычисляем градиент по первому аргументу: d(a*b)/da = b * grad
        // Если градиент определён, умножаем grad на первый аргумент (self)
        // Умножаем grad на второй аргумент (other)
        auto grad_result = any_grad_defined ? 
            (mul_tensor_backward(grad, other, self_scalar_type)) : Tensor();
        // mul_tensor_backward — функция для обратного умножения тензоров
        
        // Копируем вычисленный градиент в соответствующую позицию
        copy_range(grad_inputs, self_ix, grad_result);
    }
    
    // Повторяем для вторго аргумента
    ...
    
    // Возвращаем список вычисленных градиентов по всем входам операции
    // Этот список будет передан предыдущим узлам графа
    return grad_inputs;
}
Автодифференцирование на C++: обратное распространение через лямбды и std::function - 15

Как это работает вместе:

  1. При создании тензора с requires_grad=True создаётся объект AutogradMeta.

  2. При вызове операции (например, умножения) создаётся новый узел MulBackward0, наследующийся от Node.

  3. Функция collect_next_edges() собираёт ребра (Edge) от входных тензоров.

  4. Метод set_history() связывает результат операции с созданным узлом.

  5. При обратном проходе движок autograd обходит граф через Edge и вызывает apply() для каждого узла.

Сравнение подходов

PyTorch (иерархия классов)

Преимущества:

  • Явная структура: каждая операция — отдельный класс с четким интерфейсом.

  • Возможность переиспользования классов операций.

  • Поддержка сложных оптимизаций (fusion, pruning графа).

Недостатки:

  • Требует написания двух методов (forward и backward) для каждой операции.

  • Отдельный механизм для сохранения промежуточных значений (self_.unpack(), other_.unpack()).

  • Более сложный код для добавления новых операций.

  • Много шаблонного кода (генерация индексов, проверка should_compute_output).

Наш подход (лямбды + std::function)

Преимущества:

  • Все в одном месте: forward и backward записаны рядом в лямбде.

  • Не нужно создавать иерархии классов и наследоваться.

  • Гибкость: можно захватить любое состояние через capture list.

  • Проще для понимания и прототипирования.

  • Меньше шаблонного кода.

Недостатки:

  • Создается новая лямбда для каждой операции (хотя компилятор может это оптимизировать).

  • Невозможно анализировать граф для экспорта и оптимизаций (например, для автоматического слияния операций).

Давайте разберемся, достаточно ли хорошую производительность обеспечивает наш подход. Для этого рассмотрим потенциальные проблемы.

Проблема 1: вы создаете новую лямбда-функцию для каждого узла графа. Это же накладные расходы!

Ответ: во-первых, Variable и Tensor у нас реализованы по паттерну pImpl (или используют shared_ptr), поэтому «копирование» узла — это копирование указателя, а не всех данных. Во-вторых, лямбда — это просто объект-функтор, и современный компилятор C++ может его оптимизировать (inline, devirtualization).

Проблема 2: а зачем хранить backward_fn_ внутри узла? Нельзя ли вынести в отдельную таблицу?

Ответ: можно, но тогда придется отдельно хранить состояние для каждой операции (например, значения a.data_ и b.data_ для умножения). Лямбда с захватом — это естественный способ инкапсулировать и код, и данные вместе. В PyTorch для этого используется AutogradMeta и механизм save_for_backward, что требует дополнительной инфраструктуры .

Проблема 3: а как же производительность по сравнению с PyTorch?

Ответ: PyTorch использует отдельную иерархию классов для узлов и ребер графа, что дает больше контроля над памятью и позволяет применять сложные оптимизации. Но для образовательных целей или легковесного фреймворка наш подход вполне адекватен. Если нужно больше производительности — можно добавить пул объектов и избежать лишних аллокаций. Но это уже оптимизации другого порядка, которые стоит применять после получения результатов нагрузочного тестирования и профилирования.

Что дальше: тензорные компиляторы и оптимизация графов

Наш подход с явным построением графа — это только начало. В современных фреймворках вычислительные графы не просто строятся, но и оптимизируются, компилируются и упрощаются с помощью специальных компиляторов. Давайте разберем основные направления.

XLA (Accelerated Linear Algebra)

XLA — это компилятор линейной алгебры, разработанный Google для TensorFlow, а теперь используемый также в JAX и PyTorch.

Как это работает:

  1. Захват графа: XLA получает вычислительный граф (в нашем случае — это последовательность операций Variable).

  2. Оптимизации:

    • Operator Fusion: несколько операций объединяются в одну. Например, вместо отдельных ядер для a * b + c создается одно ядро, которое делает все сразу. Это уменьшает накладные расходы на запуск ядер и чтение/запись памяти.

    • Constant Folding: константные выражения вычисляются на этапе компиляции, а не во время выполнения.

    • Dead Code Elimination: удаляются операции, результаты которых нигде не используются.

    • Common Subexpression Elimination: если одно и то же выражение используют несколько раз, оно вычисляется один раз и кешируется.

  3. Компиляция: оптимизированный граф компилируется в машинный код для конкретной архитектуры (CPU, GPU, TPU).

  4. Выполнение: скомпилированный код выполняется напрямую, без интерпретации графа.

PyTorch 2.0 и Torch Dynamo

В PyTorch 2.0 появился Torch Dynamo — компилятор, который захватывает граф динамически во время выполнения.

Ключевое отличие от нашего подхода:

Наш код строит граф явно через перегрузку операторов. Dynamo работает иначе:

  • Bytecode Analysis: Dynamo анализирует байт-код Python функции во время выполнения.

  • Graph Capture: когда функция вызывается, Dynamo «перехватывает» операции и строит граф (через torch.fx).

  • Backend Compilation: граф передается в бэкенд (например, inductor, XLA, TensorRT) для оптимизации и компиляции.

  • Caching: скомпилированный граф кэшируется для повторного использования.

Torch Inductor — это компилятор по умолчанию в PyTorch 2.0, который генерирует оптимизированный код для GPU через Triton (язык для написания GPU-ядер).

JAX: функциональный подход к графам

JAX — это фреймворк Google, который использует функциональное программирование и XLA для компиляции.

Основные принципы:

  • Pure Functions: функции в JAX должны быть чистыми (без побочных эффектов). Это упрощает анализ и оптимизацию графа.

  • Functional Transforms: JAX предоставляет трансформеры:

    • jax.grad — автоматическое дифференцирование (аналог нашего backward()).

    • jax.jit — компиляция через XLA.

    • jax.vmap — векторизация (автоматический batching).

    • jax.pmap — параллелизация на нескольких устройствах.

  • Static Graph: после применения jit граф становится статическим (не меняется между вызовами), что позволяет проводить агрессивные оптимизации

Сравнение подходов к графам

Подход

Когда строится граф

Оптимизации

Компиляция

Наш подход

Во время выполнения (eager mode)

Нет

Нет

PyTorch (eager)

Во время выполнения (eager mode)

Частично

Нет

PyTorch 2.0 (Dynamo)

Динамически при первом вызове

Fusion, constant folding, …

Inductor/XLA

TensorFlow 1.x

Статически до выполнения

Полные оптимизации

XLA

TensorFlow 2.x + jit

При декорации @tf.function

Полные оптимизации

XLA

JAX

При вызове jax.jit

Агрессивные оптимизации

XLA

Наш подход — это базовый уровень понимания вычислительных графов. Современные компиляторы делают то же самое, но на несколько уровней выше. Но это важно понимать как для изучения более сложных подходов, так и для выбора инструментов для работы. Например, для прототипирования удобно использовать eager-режим, когда граф строится в императивном стиле. Но для развертывания в продакшене лучше использовать скомпилированные и оптимизированные графы.


Итак, что мы сделали:

  • Определили Variable как узел вычислительного графа, который представляет операцию над данными.

  • Перегрузили операторы так, чтобы они автоматически строили граф и записывали лямбды для обратного прохода.

  • Реализовали backward() через топологическую сортировку и последовательный вызов backward_fn_.

  • Обновили веса через обычный градиентный спуск.

Все это умещается в один класс без иерархий наследования, при этом используется только стандартная библиотека C++ и лямбда-выражения.

Полезные ссылки для тех, кто хочет подробнее изучить тему:

Автор: Mik42

Источник

Rambler's Top100