- BrainTools - https://www.braintools.ru -
После ознакомления с несколькими статьями (ЭЭГ [1]. Архитектура KAN показалась перспективной благодаря своей способности моделировать сложные нелинейные зависимости, что может быть особенно полезно для анализа сигналов ЭЭГ.
Из‑за ограниченного времени было решено взять за основу готовое исследование и проверить эффективность KAN на тех же данных, используя преимущественно стандартные настройки.
Для классификации был выбран датасет, доступный по ссылке [2]. Набор данных включает записи ЭЭГ от 14 пациентов с параноидной шизофренией и 14 здоровых людей из контрольной группы. Данные были записаны с частотой дискретизации 250 Гц с использованием стандартной схемы размещения электродов 10–20 и 19 каналов: Fp1, Fp2, F7, F3, Fz, F4, F8, T3, C3, Cz, C4, T4, T5, P3, Pz, P4, T6, O1, O2. Предобработка и фильтрация данных выполнены в соответствии с примером на гитхаб [3].
Первым этапом было тестирование библиотеки pykan [4] с различными конфигурациями, представленными в таблице:
|
Модель |
width |
grid |
k |
|
model_pykan_1 |
[19, 1] |
3 |
3 |
|
model_pykan_2 |
[19, 9, 1] |
3 |
3 |
|
model_pykan_3 |
[19, 9, 4, 1] |
3 |
3 |
Основным изменяемым параметром был width. В дальнейшем планируется экспериментировать с другими параметрами.
Также можно визуализировать архитектуры моделей. Пример кода для построения графиков:
model_pykan_1 = KAN(width=[19, 1], grid=3, k=3)
x = torch.normal(0,1,size=(10, 19))
model_pykan_1(x)
figure(figsize=(38, 20), dpi=80)
model_pykan_1.plot(beta=100, )
Для разделения данных на обучающую и валидационную выборки использовался метод gkf.split(), который гарантирует, что данные одного пациента не попадут одновременно в обе выборки. Функция обучения [5]:
def kan_train(df_x, df_y, df_group, model):
accuracy_train = []
accuracy_test=[]
for train_index, val_index in gkf.split(df_x, df_y, groups=df_group):
train_features, train_labels = df_x[train_index], df_y[train_index]
val_features, val_labels = df_x[val_index], df_y[val_index]
dataset = dataset_user(train_features[:, -1, :], val_features[:, -1, :], train_labels, val_labels)
def train_acc():
return torch.mean((torch.round(model(dataset['train_input'])[:,0]) == dataset['train_label'][:,0]).float())
def test_acc():
return torch.mean((torch.round(model(dataset['test_input'])[:,0]) == dataset['test_label'][:,0]).float())
results = model.fit [6](dataset, opt="LBFGS", steps=50, metrics=(train_acc, test_acc));
print(results['train_acc'][-1], results['test_acc'][-1])
accuracy_train.append(results['train_acc'][-1])
accuracy_test.append(results['test_acc'][-1])
return accuracy_train, accuracy_test
После обучения были получены следующие результаты:
|
Модель |
Test Acc |
Val Acc |
|
model_pykan_1 |
0.67 |
0.55 |
|
model_pykan_2 |
0.88 |
0.58 |
|
model_pykan_3 |
0.81 |
0.53 |
Как видно из таблицы, с увеличением количества слоёв модель начинает переобучаться.
Первым этапом было тестирование библиотеки deepkan [7] с различными конфигурациями, представленными в таблице:
|
Модель |
input_dim |
hidden_layers |
num_knots |
spline_order |
|
model_deepkan_1 |
19 |
[1] |
3 |
3 |
|
model_deepkan_2 |
19 |
[9, 1] |
3 |
3 |
|
model_deepkan_3 |
19 |
[9, 4, 1] |
3 |
3 |
Основным изменяемым параметром был hidden_layers (Список, определяющий размерности скрытых слоев). В дальнейшем планируется экспериментировать с другими параметрами.
После обучения были получены следующие результаты:
|
Модель |
Test Acc |
Val Acc |
|
model_deepkan_1 |
0.5490 |
0.5476 |
|
model_deepkan_2 |
0.5480 |
0.5466 |
|
model_deepkan_3 |
0.5425 |
0.4730 |
Как видно из таблицы, данная реализация KAN ведет себя стабильней.
Далее была протестирована модель, найденная на GitHub. Подробнее с ней можно ознакомиться по ссылке [3]. Сама модель выглядит так:
Результат: avg cross‑val acc: 0.6662.
После этого я заменил линейный слой на SplineLinearLayer из библиотеки KAN. Описание библиотеки доступно тут [7].
Получившиеся модель:
Результат: avg cross‑val acc: 0.6588.
Учитывая, что модель с GitHub построена на довольно старой архитектуре, KAN на базовых настройках не показал значительного улучшения результата. Однако стоит отметить, что совместное использование KAN с другими методами может привести к более высоким результатам.
Автор: MxaTs
Источник [9]
Сайт-источник BrainTools: https://www.braintools.ru
Путь до страницы источника: https://www.braintools.ru/article/12738
URLs in this post:
[1] ЭЭГ: http://www.braintools.ru/methods-for-studying-brain/electroencephalography-eeg
[2] ссылке: https://repod.icm.edu.pl/dataset.xhtml?persistentId=doi:10.18150/repod.0107441
[3] гитхаб: https://github.com/talhaanwarch/youtube-tutorials/blob/main/2.2%20EEG%20DL%20Classification.ipynb
[4] pykan: https://kindxiaoming.github.io/pykan/intro.html
[5] обучения: http://www.braintools.ru/article/5125
[6] model.fit: http://model.fit
[7] deepkan: https://pypi.org/project/Deep-KAN/
[8] Тетрадка с представленными тестами: https://colab.research.google.com/drive/1ct9h4dc3ypKKpr1LFRDOqZRx4T6BWkcX?usp=sharing
[9] Источник: https://habr.com/ru/articles/887204/?utm_campaign=887204&utm_source=habrahabr&utm_medium=rss
Нажмите здесь для печати.