Устранение дисбаланса классов в PyTorch с помощью WeightedRandomSampler. imbalanced classification.. imbalanced classification. PyTorch.. imbalanced classification. PyTorch. weightedrandomsampler.

Как известно, если в наборе данных для обучения классификатора разные классы представлены в разном объёме, это может привести к ухудшению качества результата.

Одним из методов борьбы с дисбалансом классов является оверсемплинг, т.е. демонстрация классификатору в процессе обучения редких классов с повышенной частотой.

В исследовании 2017 года авторы утверждают, что из всех испробованных ими методов овесемплинг показал лучший результат и не приводил к переобучению классификаторов на основе свёрточных нейронных сетей.

Класс WeightedRandomSampler в PyTorch позволяет гибко настраивать оверсемплинг и избавляет от излишнего копирования данных внутри датасета. Однако, документации и других материалов по нему довольно мало.

В статье (за пейволлом) дано теоретическое обоснование работы этого класса. Здесь мы не будем углубляться в теорию вероятности, а вместо этого будем запускать код и смотреть на результат.

В статье будут даны ответы на следующие вопросы.

  1. Как вычислить веса для балансировки датасета?

  2. Поскольку подход использует генератор случайных чисел, можно ли быть уверенным, что датасет будет сбалансирован так, как мне нужно?

  3. Увидит ли модель все мои данные в процессе обучения?

  4. Как поступить, если мне целенаправленно нужно получить разбалансированный датасет? Как контролировать соотношение классов?

Блокнот Jupyter из оргинальной статьи доступен по ссылке.

Создание разбалансированного набора данных

В качестве основы возьмем набор данных Oxford Pets, который содержит фотографии 37 разных пород собак и кошек.

Выкачиваем и распаковываем данные:

Скрытый текст
$ wget https://thor.robots.ox.ac.uk/~vgg/data/pets/images.tar.gz -P data/pets

--2022-08-23 19:00:06--  https://thor.robots.ox.ac.uk/~vgg/data/pets/images.tar.gz
Resolving thor.robots.ox.ac.uk (thor.robots.ox.ac.uk)... 129.67.95.98
Connecting to thor.robots.ox.ac.uk (thor.robots.ox.ac.uk)|129.67.95.98|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 791918971 (755M) [application/octet-stream]
Saving to: ‘data/pets/images.tar.gz’

images.tar.gz       100%[===================>] 755.23M   138MB/s    in 5.1s    

2022-08-23 19:00:11 (149 MB/s) - ‘data/pets/images.tar.gz’ saved [791918971/791918971]

$ tar -xzf data/pets/images.tar.gz -C data/pets

Теперь посмотрим на них. Код ниже сканирует каталог с картинками из датасета и строит Pandas DataFrame из двух колонок: имя файла и метка.

Скрытый текст
from pathlib import Path
import pandas as pd

data_path = Path('data/pets/images')
images = []
labels = []

for p in data_path.glob('*.jpg'):
    image_name = p.parts[-1]
    images.append(image_name)
    labels.append('_'.join(image_name.split('_')[0:-1]))

df = pd.DataFrame(data={'image': images, 'label': labels})
df.head()

Результат

image

label

0

newfoundland_200.jpg

newfoundland

1

chihuahua_56.jpg

chihuahua

2

Abyssinian_31.jpg

Abyssinian

3

newfoundland_195.jpg

newfoundland

4

staffordshire_bull_terrier_74.jpg

staffordshire_bull_terrier

Сколько картинок для каждого класса можно узнать с помощью метода .value_counts()

Полный вывод df.value_counts()
newfoundland                  200
chihuahua                     200
shiba_inu                     200
Persian                       200
Bombay                        200
Siamese                       200
beagle                        200
saint_bernard                 200
Ragdoll                       200
miniature_pinscher            200
yorkshire_terrier             200
basset_hound                  200
Bengal                        200
german_shorthaired            200
Egyptian_Mau                  200
keeshond                      200
Maine_Coon                    200
Russian_Blue                  200
Sphynx                        200
boxer                         200
leonberger                    200
Abyssinian                    200
pug                           200
samoyed                       200
japanese_chin                 200
wheaten_terrier               200
British_Shorthair             200
havanese                      200
american_bulldog              200
pomeranian                    200
american_pit_bull_terrier     200
Birman                        200
great_pyrenees                200
english_cocker_spaniel        200
english_setter                200
scottish_terrier              199
staffordshire_bull_terrier    191

Датасет неплохо сбалансирован, для каждой породы есть по 200 картинок или чуть меньше.

Для целей исследования нам нужно выбрать из него разбалансированное подмножество. Функция create_dfs выбирает из датасета два класса с указанными именами и нужным соотношением семплов.

def create_dfs
from pathlib import Path

import pandas as pd
from sklearn.model_selection import train_test_split

def create_dfs(data_path, majority_class_name, minority_class_name, num_minority_class):
    data_path = Path(data_path)
    images = []
    labels = []

    for p in data_path.glob("*.jpg"):
        image_name = p.parts[-1]
        images.append(image_name)
        labels.append("_".join(image_name.split("_")[0:-1]))

    df = pd.DataFrame(data={"image": images, "label": labels})
    sample_df = df.query(f'label == "{majority_class_name}" or label == "{minority_class_name}"')
    train_df, val_df = train_test_split(
        sample_df, stratify=sample_df.label, train_size=0.8, random_state=42
    )
    keep_images = set(train_df.query(f'label == "{minority_class_name}"').image.head(num_minority_class).tolist())
    imbalanced_train_df = train_df[
        (train_df.image.isin(keep_images)) | (train_df.label == f"{majority_class_name}")
    ]

    return imbalanced_train_df, val_df

Посмотрим, как работает

imbalanced_train_df, val_df = create_dfs(data_path,
                                         "Siamese",
                                         "Birman",
                                         num_minority_class=10)

print('Train distribution')
print(imbalanced_train_df.label.value_counts())
print('==============')
print('Validation distribution')

val_df.label.value_counts()

Результат

Train distribution
Siamese    160
Birman      10
Name: label, dtype: int64
==============
Validation distribution
Birman     40
Siamese    40
Name: label, dtype: int64

Давайте закодируем метки классов целыми числами и создадим словарь с соответствием между текстовыми и целочисленными метками.

label_to_id = {v: idx for idx, v in
               enumerate(imbalanced_train_df.label.unique())}
{'Siamese': 0, 'Birman': 1}

Сиамская и бирманская породы кошек выбраны потому, что они немного похожи внешне, это можно увидеть, посмотрев на случайно выбранные картинки из датасета. Потом эти данные можно использовать для решения несложной, но не тривиальной задачи.

from PIL import Image

def show_image(row):
    print(row.label)
    return Image.open(data_path/row.image)

show_image(imbalanced_train_df.iloc[0])
show_image(imbalanced_train_df.iloc[7])

В результате получаются такие киски

Устранение дисбаланса классов в PyTorch с помощью WeightedRandomSampler - 1

Визуализация

После того, как мы поглядели на данные, давайте разберемся с дисбалансом классов, которые модель будет видеть в каждом батче в ходе одной эпохи. Пока нам не надо загружать картинки, нас интересуют только метки классов. Поэтому для ускорения создадим TensorDataset из меток и индексов.

Скрытый текст
import torch
from torch.utils.data import DataLoader, TensorDataset

ds = TensorDataset(torch.as_tensor([(idx, label_to_id[l]) for idx, l in enumerate(imbalanced_train_df.label.values)]))
dl = DataLoader(ds, shuffle=True, batch_size=10)

Ниже приведена функция, которая проходит по датасету и сохраняет метки и индексы, появляющиеся из DataLoader-а. Если её аргумент with_outputs имеет значение True, результат печатается и рисуется на гистограмме.

visualize_dataloader
import numpy as np
import matplotlib.pyplot as plt

def visualise_dataloader(dl, id_to_label=None, with_outputs=True):
    
    total_num_images = len(dl.dataset)
    idxs_seen = []
    class_0_batch_counts = []
    class_1_batch_counts = []
        
    for i, batch in enumerate(dl):
        idxs = batch[0][:, 0].tolist()
        classes = batch[0][:, 1]
        class_ids, class_counts = classes.unique(return_counts=True)
        class_ids = set(class_ids.tolist())
        class_counts = class_counts.tolist()

        idxs_seen.extend(idxs)

        if len(class_ids) == 2:
            class_0_batch_counts.append(class_counts[0])
            class_1_batch_counts.append(class_counts[1])
        elif len(class_ids) == 1 and 0 in class_ids:
            class_0_batch_counts.append(class_counts[0])
            class_1_batch_counts.append(0)
        elif len(class_ids) == 1 and 1 in class_ids:
            class_0_batch_counts.append(0)
            class_1_batch_counts.append(class_counts[0])   
        else:
            raise ValueError('More than two classes detected')
            
    if with_outputs:
        
        fig, ax = plt.subplots(1, figsize=(15, 15))

        ind = np.arange(len(class_0_batch_counts))
        width = 0.35  
    
        ax.bar(ind, class_0_batch_counts, width, label=(id_to_label[0] if id_to_label is not None else '0'))
        ax.bar(ind + width, class_1_batch_counts, width, label=(id_to_label[1] if id_to_label is not None else '1'))
        ax.set_xticks(ind, ind+1)
        ax.set_xlabel('Batch index', fontsize=12)
        ax.set_ylabel('No. of images in batch', fontsize=12)
        ax.set_aspect('equal')
    
        plt.legend()
        plt.show()
        
        num_images_seen = len(idxs_seen)
        
        print(f'Avg Proportion of {(id_to_label[0] 
              if id_to_label is not None else "Class 0")} per batch: {(np.array(class_0_batch_counts)/10).mean()}')
        print(f'Avg Proportion of {(id_to_label[1] if id_to_label is not None else "Class 1")} per batch: {(np.array(class_1_batch_counts)/10).mean()}')
        print('=============')
        print(f'Num. unique images seen: {len(set(idxs_seen))}/{total_num_images}')
    return class_0_batch_counts, class_1_batch_counts, idxs_seen
       

Вот что увидела бы модель в процессе обучения

class_0_batch_counts, class_1_batch_counts, idxs_seen 
            = visualise_dataloader(dl, 
                                  {0: "Siamese (Majority class)",
                                   1: "Birman (Minority class)"})
Устранение дисбаланса классов в PyTorch с помощью WeightedRandomSampler - 2
Avg Proportion of Siamese (Majority class) per batch: 0.9411764705882353
Avg Proportion of Birman (Minority class) per batch: 0.05882352941176472
=============
Num. unique images seen: 170/170

В некоторых батчах вообще нет бирманских кошек. Классификатор увидел все 170 уникальных картинок.

Если взглянуть на нормализованные значения .value_counts(), цифры будут те же, что получены выше осреднением по батчам.

Скрытый текст
imbalanced_train_df.label.value_counts(normalize=True)

Siamese    0.941176
Birman     0.058824
Name: label, dtype: float64

Балансировка датасета с помощью WeightedRandomSampler

В справке к этому классу сказано, что его конструктору нужно задать список из весов ( weights ) и длину генерируемой последовательности ( num_samples ). Вес следует задать каждому изображению из нашего дасета.

Также указано, что сумма весов в списке не обязательно должна равняться 1. Это вводит в заблуждение, т.к. эти веса соответствуют вероятностям появления соответствующих семплов из датасета, и PyTorch их масштабирует, чтобы они суммировались к 1.

Исходный код WeightedRandomSampler доступен, он довольно прост. Ниже приведена его существенная часть, без проверок на корректность входных данных.

Скрытый текст
class WeightedRandomSampler(Sampler[int]):

    weights: torch.Tensor
    num_samples: int
    replacement: bool

    def __init__(
        self,
        weights: Sequence[float], num_samples: int,
        replacement: bool = True, generator = None,
    ) -> None:

        weights_tensor = torch.as_tensor(weights, dtype=torch.double)

        self.weights = weights_tensor
        self.num_samples = num_samples
        self.replacement = replacement
        self.generator = generator

    def __iter__(self) -> Iterator[int]:
        rand_tensor = torch.multinomial(
            self.weights, self.num_samples, 
            self.replacement, generator=self.generator
        )
        yield from iter(rand_tensor.tolist())

    def __len__(self) -> int:
        return self.num_samples

Теперь посмотрим, как можно задать веса для отдельных семплов из данных. Ранее мы уже определили с помощью метода .value_counts, сколько у нас картинок каждого класса: 160 картинок с сиамскими кошками и 10 с бирманскими.

Примем за вес каждого класса величину, обратную количеству картинок в этом классе:

class_counts = imbalanced_train_df.label.value_counts()
class_weights = 1 / class_counts 
class_weights
Siamese    0.00625
Birman     0.10000
Name: label, dtype: float64

Присваиваем вес каждому элементу данных,

sample_weights = [1.0 / class_counts[i] 
                  for i in imbalanced_train_df.label.values]

… создаем sampler и DataLoader, передавая ему sampler в конструкторе

from torch.utils.data import WeightedRandomSampler

sampler= WeightedRandomSampler(weights=sample_weights, 
                               num_samples=len(ds), 
                               replacement=True)

dl = DataLoader(ds, sampler=sampler, batch_size=10)

При создании объекта sampler, ему в конструкторе класса был указан аргумент replacement=True, без этого не получилось бы никакого оверсемплинга. Парамер num_samples задан равным размеру датасета, len(ds). Позже мы это еще обсудим.

Визуализируем распределение данных в батчах с новым семплером:

Устранение дисбаланса классов в PyTorch с помощью WeightedRandomSampler - 3
Avg Proportion of Siamese (Majority class) per batch: 0.4882352941176471
Avg Proportion of Birman (Minority class) per batch: 0.5117647058823529
=============
Num. unique images seen: 71/170

Как видно, теперь классы распределены приблизительно одинаково, но за эпоху было выбрано чуть более трети данных.

Чтобы понять, какие именно картинки выбираются, мы можем создать DataFrame с количеством раз, которое каждая картинка за эпоху выбиралась для показа.

Скрытый текст
from collections import Counter

image_counts_df = (
    pd.merge(
        imbalanced_train_df.reset_index(drop=True)
                           .rename(columns={"index": "image_idx"}),
        pd.DataFrame.from_records(
            {"image_idx": k, "seen_count": v} for k, v in Counter(idxs_seen).items()
        ),
        how="left",
    )
    .fillna(0)
    .sort_values("seen_count", ascending=False)
    .reset_index(drop=True)
)

image_counts_df.query('label == "Birman"')

image_idx

image

label

seen_count

0

7

Birman_153.jpg

Birman

13.0

1

9

Birman_54.jpg

Birman

11.0

2

12

Birman_173.jpg

Birman

11.0

3

2

Birman_95.jpg

Birman

10.0

4

15

Birman_135.jpg

Birman

9.0

5

8

Birman_100.jpg

Birman

9.0

6

13

Birman_110.jpg

Birman

9.0

7

16

Birman_87.jpg

Birman

8.0

8

3

Birman_123.jpg

Birman

5.0

9

4

Birman_86.jpg

Birman

5.0

Для наглядности таблицу можно представить в виде графика выборочной функции распределения и изобразить его с спомощью функции ecdfplot из библиотеки Seaborn.

Код построения графика
import seaborn as sns

fig, ax = plt.subplots(figsize=(10, 5))
sns.ecdfplot(image_counts_df.query('label == "Birman"').seen_count, ax=ax, label='Birman')
sns.ecdfplot(image_counts_df.query('label == "Siamese"').seen_count, ax=ax, label='Siamese')
ax.set_xlabel('Количество показов', fontsize=12)
ax.set_ylabel('Доля изображений в данных', fontsize=12)
ax.legend(fontsize=12)
fig.tight_layout()
Устранение дисбаланса классов в PyTorch с помощью WeightedRandomSampler - 4

Из таблицы и графика видно, что для достижения нужной пропорции классов, каждое изображение редкого класса демонстрировалось от 5 до 13 раз за эпоху. Много изображений из доминирующего класса вообще не были показаны, а некоторые были показаны несколько раз.

К сожалению, это необходимая жертва. Можно было создать WeightedRandomSampler с параметром replacement=False, но тогда не было бы никакого оверсемплинга вообще.

Вот что будет с replacement=False
Устранение дисбаланса классов в PyTorch с помощью WeightedRandomSampler - 5
Avg Proportion of Siamese (Majority class) per batch: 0.9411764705882353
Avg Proportion of Birman (Minority class) per batch: 0.058823529411764705
=============
Num. unique images seen: 170/170

sampler в начале выбрал все изображения бирманских кошек, а потом выдавал только сиамских

Напрашивается вопрос: как обеспечить показ всех изображений в ходе обучения?

Как показать модели все данные в процессе обучения?

Настройка арумента num_samples

Если мы удвоим num_samples в конструкторе класса WeightedRandomSampler, показанных данных будет больше:

Код
sampler= WeightedRandomSampler(weights=sample_weights, 
                               num_samples=2*len(ds),   # <<<<< 
                               replacement=True)
dl = DataLoader(ds, sampler=sampler, batch_size=10)
Устранение дисбаланса классов в PyTorch с помощью WeightedRandomSampler - 6
Avg Proportion of Siamese (Majority class) per batch: 0.49117647058823527
Avg Proportion of Birman (Minority class) per batch: 0.5088235294117647
=============
Num. unique images seen: 114/170

В прошлый раз было показано 71 уникальное изображение из 170.

Обратите внимание на удвоившееся количество батчей, именно в этом проявляется эффект удвоения num_samples

Это рабочий вариант, но он размывает определение эпохи. В машинном обучении эпохой обычно называется один полный проход по всему набору данных. Если каждый элемент из датасета демонстрируется только один раз, из такого определения сразу понятно, что видела модель в каждый конкретный момент обучения. Если некоторые элементы датасета выбираются по нескольку раз, а некоторые пропускаются, ясности уже меньше.

Т.к. определение эпохи нужно по большей части нам для отслеживания процесса обучения и слабо связано с самой моделью (при обучении она просто наблюдает постоянный поток изображений), я предпочитаю оставить значение num_samples равным размеру датасета и ожидать, что рано или поздно на какой-то эпохе все данные будут показаны.

Сколько эпох потребуется, чтобы показать все изображения?

Давайте заново создадим все объекты и посмотрим, сколько эпох должно пройти до показа всего датасета полностью. Для этого проведем простой эксперимент, в котором будем перебирать датасет несколько раз, имитируя несколько эпох обучения, отслеживать количество уникальных показанных изображений и строить график этого количества в зависимости от числа эпох.

Код
sampler = WeightedRandomSampler(weights=sample_weights, num_samples=len(ds), replacement=True)
dl = DataLoader(ds, sampler=sampler, batch_size=10)

num_epochs = 20

unique_idxs_seen = set()
cumulative_unique_images_per_epoch = {}

for i in range(num_epochs):
    class_0_batch_counts, class_1_batch_counts, idxs_seen = visualise_dataloader(
        dl, with_outputs=False
    )
    unique_idxs_seen.update(idxs_seen)
    cumulative_unique_images_per_epoch[i] = len(unique_idxs_seen)

target_epoch = list(cumulative_unique_images_per_epoch.values()).index(len(ds))

fig, ax = plt.subplots(1, figsize=(10, 10))

ax.plot(
    cumulative_unique_images_per_epoch.keys(),
    cumulative_unique_images_per_epoch.values(),
    label="Количество показанных уникальных изображений",
)
ax.plot(
    cumulative_unique_images_per_epoch.keys(),
    [len(ds)] * len(cumulative_unique_images_per_epoch),
    label="Размер датасета",
    linestyle="--",
)

ax.plot(target_epoch, len(ds), "go")
ax.legend(fontsize=12)
ax.set_xlabel("Количество эпох", fontsize=12)
ax.grid()

_ = ax.set_xticks(
    np.arange(20, step=2),
)
Устранение дисбаланса классов в PyTorch с помощью WeightedRandomSampler - 7

Здесь мы видим, что на 8 эпохе модели были показаны все изображения из датасета. Это число меняется в небольших пределах, если код запускать повторно. Для более надежной оценки воспользуемся идеями метода Монте-Карло, запустим эксперимент несколько раз и посмотрим, как распределяются результаты. Для наглядности результаты представим в виде графика ядерной оценки плотности.

Код
num_trials = 1000
num_epochs_needed = []

for t in range(num_trials):
    unique_idxs_seen = set()
    for i in range(num_epochs):
        class_0_batch_counts, class_1_batch_counts, idxs_seen = visualise_dataloader(dl, with_outputs=False)
        unique_idxs_seen.update(idxs_seen)
        if len(unique_idxs_seen) == len(ds):
            num_epochs_needed.append(i)
            break
            
fig, ax = plt.subplots(figsize=(10, 5))
sns.kdeplot(num_epochs_needed)
ax.set_xlabel('Количество эпох до показа всех изображений', fontsize=12)
ax.set_title('Распределение количества эпох до показа всех изображений по 1000 экспериментам')
fig.tight_layout()
Устранение дисбаланса классов в PyTorch с помощью WeightedRandomSampler - 8

Из графика видно, что после 9-10 эпох можно почти не сомневаться, что модель видела весь датасет.

Конечно, эта оценка сильно зависит от пропорций классов в разбаланстированном датаесете. Логично предположить, что чем сильнее дисбаланс, тем больше эпох понадобится для показа всего датасета. Мы можем подтвердить это, повторяя эксперимент с разными уровнями дисбаланса классов.

Код
from collections import defaultdict


def run_trial_for_different_imalances(data_path, minority_class_counts, label_to_id):
    num_epochs_needed = defaultdict(list)
    for c in minority_class_counts:
        t_df, _ = create_dfs(data_path, "Siamese", "Birman", c)
        t_class_counts = t_df.label.value_counts()
        t_class_weights = [1 / t_class_counts[i] for i in t_df.label.values]
        t_ds = TensorDataset(
            torch.as_tensor([(idx, label_to_id[l]) for idx, l in enumerate(t_df.label.values)])
        )
        t_sampler = WeightedRandomSampler(
            weights=t_class_weights, num_samples=len(t_ds), replacement=True
        )
        t_dl = DataLoader(t_ds, sampler=t_sampler, batch_size=5)

        num_trials = 1000
        for t in range(num_trials):
            unique_idxs_seen = set()
            for i in range(num_epochs):
                _, _, idxs_seen = visualise_dataloader(t_dl, with_outputs=False)
                unique_idxs_seen.update(idxs_seen)
                if len(unique_idxs_seen) == len(t_ds):
                    num_epochs_needed[c].append(i)
                    break

    return num_epochs_needed


num_majority_class = imbalanced_train_df.label.value_counts()["Siamese"]

num_epochs_needed = run_trial_for_different_imalances(
    data_path,
    [
        int(num_majority_class),
        int(num_majority_class / 2),
        int(num_majority_class / 4),
        int(num_majority_class / 8),
        int(num_majority_class / 16),
        int(num_majority_class / 32),
    ],
    label_to_id,
)

fig, ax = plt.subplots(figsize=(10, 5))
for num_items, frequencies in num_epochs_needed.items():
    sns.kdeplot(frequencies, ax=ax, label=f"{int(num_majority_class / num_items)}: 1")
ax.legend(title="Соотношение классов частый : редкий")
ax.set_xlabel("Количество эпох до показа всех данных", fontsize=12)
fig.tight_layout()
Устранение дисбаланса классов в PyTorch с помощью WeightedRandomSampler - 9

График подтвердил наше предположение.

Интересная деталь: при использовании WeightedRandomSampler на сбалансированном датасете (синяя линия 1:1) нужно около 5 эпох для показа модели всех данных. Отсюда следует, что этот семплер не очень подходит для сбалансированных данных.

Целенаправленное создание разбалансированного датасета

Теперь посмотрим, как можно использовать WeightedRandomSampler для получения любого нужного нам соотношения классов в датасете, не обязательно равномерного.

Одним из возможных вариантов, где такое может понадобиться, является задача детектирования объектов. В датасетах для таких задач обычно встречается большое количество изображений с фоном, без того объекта, на котором должна сфокусироваться наша модель.

Давайте попробуем разбалансировать датасет так, чтобы редкий класс стал доминирующим

Еще раз убедимся, что датасет не поменялся
class_counts = imbalanced_train_df.label.value_counts(); 
class_counts
label
Siamese    160
Birman      10
Name: count, dtype: int64

Теперь зададим нужные нам пропорции классов.

target_proportions = {"Siamese": 0.1, "Birman": 0.9}

Новые веса семплов зададим, умножив целевую пропорцию на вес класса

sample_weights = [target_proportions[i] * 1.0 / class_counts[i] 
                  for i in imbalanced_train_df.label.values]

Снова визуализируем батчи:

Устранение дисбаланса классов в PyTorch с помощью WeightedRandomSampler - 10
Avg Proportion of Siamese (Majority class) per batch: 0.11764705882352941
Avg Proportion of Birman (Minority class) per batch: 0.8823529411764706
=============
Num. unique images seen: 28/170

Цель достигнута, и классы распределены в нужной нам пропорции.

Улучшается ли качество обучения с оверсемплингом

Надеюсь, к этому моменту у нас уже сформировалось понимание того, как работает WeightedRandomSampler.

Теперь можно проверить на практике, что дает повторяющийся показ одного и того же изображения при обучении сети.

Давайте обучим несложный классификатор на нашем разбалансированном датасете. Для этого будем использовать следующий рецепт.

  • Архитектрура: ResNet-RS50

  • Оптимизатор: AdamW

  • Планировщик скорости обучения: Cosine decay

  • Изображения сжимаются до размера 224×224 пикселя.

Сеть возьмем предобученную на ImageNet. Т.к. размер нашего обучающего набора довольно мал, упростим себе жизнь еще больше, тренируя только последний слой классификатора. Поскольку наши изображения довольно похожи на изображения из ImageNet, резонно предположить, что признаки, которые сеть научилась извлекать из ImageNet, здесь тоже сработают. Для валидации будем использовать набор, созданный ранее, содержащий по 40 изображений сиамских и бирманских кошек, мы его пока не использовали.

Далее автор публикации приводит код, создающий, обучающий и валидирующий сеть. Он доступен в оригинальном блокноте, и я его приводить не буду. На валидации код вычисляет Precision, Recall, F1 и Confusion Matrix и печатает их.

Результатов, сохраненных автором в блокноте, мне воспроизвести не удалось, поэтому в таблице ниже приведены результаты, полученные мной на 20 эпохах обучения.

Оверсемплинг

Аугментация данных

Эпоха

F1

Accuracy

Recall

Выключен

Выключена

11

0.769

0.8125

0.625

Выключен

Включена

14

0.861

0.875

0.775

Включен

Выключена

2

0.845

0.862

0.75

Включен

Включена

2

0.899

0.899

0.899

Что же, эффект есть! Предлагаю пользоваться.

Заключение

WeightedRandomSampler позволяет гибко настраивать баланс классов в обучающих данных, избавляя от излишнего их копирования. С его помощью можно как сбалансировать датасет, так и создать дисбаланс классов в нужных соотношениях. Если данные уже сбалансированы, то при его использовании понадобится удлиннить обучение на несколько эпох, чтобы модель гарантированно увидела все данные.

Автор: wl2776

Источник

Rambler's Top100