- BrainTools - https://www.braintools.ru -

Пробуем KAN (Kolmogorov-Arnold Networks) для классификации данных ЭЭГ

После ознакомления с несколькими статьями (ЭЭГ [1]. Архитектура KAN показалась перспективной благодаря своей способности моделировать сложные нелинейные зависимости, что может быть особенно полезно для анализа сигналов ЭЭГ.

Из‑за ограниченного времени было решено взять за основу готовое исследование и проверить эффективность KAN на тех же данных, используя преимущественно стандартные настройки.

Для классификации был выбран датасет, доступный по ссылке [2]. Набор данных включает записи ЭЭГ от 14 пациентов с параноидной шизофренией и 14 здоровых людей из контрольной группы. Данные были записаны с частотой дискретизации 250 Гц с использованием стандартной схемы размещения электродов 10–20 и 19 каналов: Fp1, Fp2, F7, F3, Fz, F4, F8, T3, C3, Cz, C4, T4, T5, P3, Pz, P4, T6, O1, O2. Предобработка и фильтрация данных выполнены в соответствии с примером на гитхаб [3].

Тестирование pykan

Первым этапом было тестирование библиотеки pykan [4] с различными конфигурациями, представленными в таблице:

Модель

width

grid

k

model_pykan_1

[19, 1]

3

3

model_pykan_2

[19, 9, 1]

3

3

model_pykan_3

[19, 9, 4, 1]

3

3

Основным изменяемым параметром был width. В дальнейшем планируется экспериментировать с другими параметрами.

Также можно визуализировать архитектуры моделей. Пример кода для построения графиков:

model_pykan_1 = KAN(width=[19, 1], grid=3, k=3)

x = torch.normal(0,1,size=(10, 19))

model_pykan_1(x)

figure(figsize=(38, 20), dpi=80)

model_pykan_1.plot(beta=100, )

Графики архитектур моделей

Графики архитектур моделей

Для разделения данных на обучающую и валидационную выборки использовался метод gkf.split(), который гарантирует, что данные одного пациента не попадут одновременно в обе выборки. Функция обучения [5]:

def kan_train(df_x, df_y, df_group, model):

accuracy_train = []

accuracy_test=[]

for train_index, val_index in gkf.split(df_x, df_y, groups=df_group):

train_features, train_labels = df_x[train_index], df_y[train_index]

val_features, val_labels = df_x[val_index], df_y[val_index]

dataset = dataset_user(train_features[:, -1, :], val_features[:, -1, :], train_labels, val_labels)

def train_acc():

return torch.mean((torch.round(model(dataset['train_input'])[:,0]) == dataset['train_label'][:,0]).float())

def test_acc():

return torch.mean((torch.round(model(dataset['test_input'])[:,0]) == dataset['test_label'][:,0]).float())

results = model.fit [6](dataset, opt="LBFGS", steps=50, metrics=(train_acc, test_acc));

print(results['train_acc'][-1], results['test_acc'][-1])

accuracy_train.append(results['train_acc'][-1])

accuracy_test.append(results['test_acc'][-1])

return accuracy_train, accuracy_test

После обучения были получены следующие результаты:

Модель

Test Acc

Val Acc

model_pykan_1

0.67

0.55

model_pykan_2

0.88

0.58

model_pykan_3

0.81

0.53

Как видно из таблицы, с увеличением количества слоёв модель начинает переобучаться.

Тестирование DeepKAN

Первым этапом было тестирование библиотеки deepkan [7] с различными конфигурациями, представленными в таблице:

Модель

input_dim

hidden_layers

num_knots

spline_order

model_deepkan_1

19

[1]

3

3

model_deepkan_2

19

[9, 1]

3

3

model_deepkan_3

19

[9, 4, 1]

3

3

Основным изменяемым параметром был hidden_layers (Список, определяющий размерности скрытых слоев). В дальнейшем планируется экспериментировать с другими параметрами.

После обучения были получены следующие результаты:

Модель

Test Acc

Val Acc

model_deepkan_1

0.5490

0.5476

model_deepkan_2

0.5480

0.5466

model_deepkan_3

0.5425

0.4730

Как видно из таблицы, данная реализация KAN ведет себя стабильней.

Тестирование модели с GitHub

Далее была протестирована модель, найденная на GitHub. Подробнее с ней можно ознакомиться по ссылке [3]. Сама модель выглядит так:

модель представленная на гитхаб

модель представленная на гитхаб

Результат: avg cross‑val acc: 0.6662.

Замена линейного слоя на SplineLinearLayer

После этого я заменил линейный слой на SplineLinearLayer из библиотеки KAN. Описание библиотеки доступно тут [7].

Получившиеся модель:

Модель с замененным линейным слоем на SplineLinearLayer

Модель с замененным линейным слоем на SplineLinearLayer

Результат: avg cross‑val acc: 0.6588.

Выводы

Учитывая, что модель с GitHub построена на довольно старой архитектуре, KAN на базовых настройках не показал значительного улучшения результата. Однако стоит отметить, что совместное использование KAN с другими методами может привести к более высоким результатам.

Тетрадка с представленными тестами [8]

Автор: MxaTs

Источник [9]


Сайт-источник BrainTools: https://www.braintools.ru

Путь до страницы источника: https://www.braintools.ru/article/12738

URLs in this post:

[1] ЭЭГ: http://www.braintools.ru/methods-for-studying-brain/electroencephalography-eeg

[2] ссылке: https://repod.icm.edu.pl/dataset.xhtml?persistentId=doi:10.18150/repod.0107441

[3] гитхаб: https://github.com/talhaanwarch/youtube-tutorials/blob/main/2.2%20EEG%20DL%20Classification.ipynb

[4] pykan: https://kindxiaoming.github.io/pykan/intro.html

[5] обучения: http://www.braintools.ru/article/5125

[6] model.fit: http://model.fit

[7] deepkan: https://pypi.org/project/Deep-KAN/

[8] Тетрадка с представленными тестами: https://colab.research.google.com/drive/1ct9h4dc3ypKKpr1LFRDOqZRx4T6BWkcX?usp=sharing

[9] Источник: https://habr.com/ru/articles/887204/?utm_campaign=887204&utm_source=habrahabr&utm_medium=rss

www.BrainTools.ru

Rambler's Top100