Shap-графики: как наглядно объяснить заказчику логику работы модели. CatBoost.. CatBoost. data science.. CatBoost. data science. machine learning.. CatBoost. data science. machine learning. python.. CatBoost. data science. machine learning. python. shap.. CatBoost. data science. machine learning. python. shap. моделирование.
Shap-графики: как наглядно объяснить заказчику логику работы модели - 1

Всем привет. Я Андрей Бояренков, лидер кластера бизнес-моделей стрима “Разработка моделей КИБ и СМБ” банка ВТБ.

Наш кластер отвечает за:

 – выстраивание и внедрение процессов AutoML,

 – за разработку моделей для процессов: ПОДФТ Precollection ЖЦК (жизненного цикла клиента),

 – а также за разработку моделей цифровых помощников для различных подразделений банка, в т.ч. риск-андеррайтинга, операционных рисков и комплаенса.

Участие заказчика в процессах разработки и применения моделей является достаточно важным. Как правило, оно является ключевым на следующих этапах:

– постановка задачи на разработку модели, включая определение сегмента и целевой переменной,

– согласование лонг-листа фичей модели и методологии их расчета,

– прием результатов разработки модели (подтверждение соответствия метрик качества модели изначально заявленным),

– подтверждение бизнес-логики работы фичей в модели.

В данной статье хочу рассказать о том, какие на мой взгляд типы графиков необходимо построить, чтобы наиболее оптимальным образом показать заказчику логику работы фичей в моделях. В качестве ограничения установим, что целью является показать логику работы фичей не на уровне конкретного наблюдения, а на уровне выборки в целом. 

Графики покажем на примере находящегося в открытом доступе датасете telecom_churn. В частности, его можно найти по указанным ссылкам: 

1. https://www.kaggle.com/datasets/keyush06/telecom-churncsv

2. https://www.kaggle.com/datasets/nikkitha8/telecom-churn

3. https://www.kaggle.com/code/kashnitsky/topic-1-exploratory-data-analysis-with-pandas

4. https://habr.com/ru/companies/ods/articles/322626/

Ниже представлены основные шаги для обработки выборки telecom_churn в целях ее дальнейшего использования:

1. Сначала импортируем библиотеки которые пригодятся для нашего исследования.

2. Далее создадим датафрейм data с выборкой telecom_churn, скорректируем названия строк и присвоим бинарный тип данных целевой переменной churn. Детальный EDA (Explanatory Data Analysis) проводить не будем, т.к. цель статьи заключается только в том чтобы показать возможные графики для интерпретации работы фичей.

3. Для расширения признакового пространства также рассчитаем ряд дополнительных фичей.

4. Сделаем дополнительную предобработку типов данных.

5. Создадим лонг-лист фичей features и список категориальных фичей cat_feat. 

6. Разделим выборку со стратификацией по churn на train (75%) и test (25%) для проверки качества модели на независимой выборке.

7. Выборку train разделим на train (60%) для обучения модели и val (15%) для использования критерия ранной остановки.

Код на Python для проведения данных действий с выборкой см. ниже:

import pandas as pd
import re
import shap
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from catboost import CatBoostClassifier
from sklearn.model_selection import train_test_split
from sklearn import preprocessing
from sklearn.metrics import make_scorer, roc_auc_score
from scipy import stats
import warnings
warnings.filterwarnings('ignore')
data = pd.read_csv('telecom_churn.csv', sep=',') # чтение датафрейма
data.columns = data.columns.str.replace(' ', '_') # корректировка названий строк
data['churn'] = data['churn'].map({True: 1, False: 0}).astype('int32')
#расчет суммы показателей за все время
col_minutes = [x for x in data.columns if len(re.findall('total_w*_minutes', x)) != 0]
col_calls = [x for x in data.columns if len(re.findall('total_w*_calls', x)) != 0]
col_charge = [x for x in data.columns if len(re.findall('total_w*_charge', x)) != 0]
data['total_minutes'] = data[col_minutes].sum(axis=1)
data['total_calls'] = data[col_calls].sum(axis=1)
data['total_charge'] = data[col_charge].sum(axis=1)
data['charge_per_minute'] = data['total_charge'] / data['total_minutes'] #стоимость совокупной минуты
data['charge_per_call'] = data['total_charge'] / data['total_calls'] #стоимость одного звонка
data['minutes_per_call'] = data['total_minutes'] / data['total_calls'] #продолжительность одного звонка
#сделать категориальными фичи по которым уникальных значений от 3 до 9 включительно
for column in data.columns: if data[column].nunique() < 10 and data[column].nunique() >= 3: data[column] = data[column].astype('str') #сделать бинарными фичи по которым 2 уникальных значения и при этом они являются типами 'object' или 'bool'
bool_columns = []
for column in data.dtypes[(data.dtypes=='object')|(data.dtypes=='bool')].index: if data[column].nunique() == 2: bool_columns.append(column)
le = preprocessing.LabelEncoder()
for column in bool_columns: data[column] = le.fit_transform(data[column])
# для того чтобы показать как графики с интерпретацией фичей работают в том числе на категориальных фичах
# искусственно сделаем фичу 'customer_service_calls' категориальной (для каждого значения введем отдельное обозначение
# от 'A' до 'F', а для значений 6 и более создадим отдельную категорию 'G' в которую попадет чуть более 1% выборки
data['customer_service_calls'] = data['customer_service_calls'].replace( {0: 'A', 1: 'B', 2: 'C', 3: 'D', 4: 'E', 5: 'F'}).mask(data['customer_service_calls'] >= 6, 'G')
# Создадим лонг-лист фичей features и список категориальных фичей cat_feat (исключен идентификатор phone_number и таргет churn)
features = data.columns.to_list()
features = [x for x in features if x not in ['phone_number','churn']]
cat_feat = data.select_dtypes(exclude = [np.number]).columns.to_list()
cat_feat = [x for x in cat_feat if x not in ['phone_number','churn']]
# Разделим выборку со стратификацией по churn на трейн(75%) и тест(25%) для проверки качества модели на независимой выборке:
x_train,x_test,y_train,y_test=train_test_split(data, data['churn'], stratify = data['churn'], test_size = 0.25, random_state=42)
x_train,x_val,y_train,y_val=train_test_split(x_train, y_train, stratify = y_train, test_size = 0.20, random_state = 42)

Далее на всех фичах из лонг–листа обучим Catboost небольшой глубины и c небольшим количеством деревьев. Нам это нужно для получения списка наболее значимых фичей.

Полученная модель на тестовой выборке показала коэффициент Джини 83,08%.

Значимость оценим с помощью get_feature_importance() в основе которого алгоритм PredictionValueChange, который показывает, насколько в среднем изменится прогноз модели при изменении значения одной из фичей.

При изменении значения фичи замена осуществляется на среднее значение фичи по выборке, а для категориальных признаков — на самый частый категориальный класс.

Выберем пять наиболее значимых фич. Именно на них мы построим новую модель и проанализируем графики для интерпретации.

Как видно из таблицы пятью самыми значимыми получились: ‘total_charge’, ‘customer_service_calls’, ‘international_plan’, ‘total_intl_calls’ и ‘number_vmail_messages’.

Для информации выведем в отдельной таблице аналитику по каждой из них. Сделать это можно с помощью стандартной функции describe с дополнительной обработкой для добавления полезной информации.

Код на Python для проведения данных действий см. ниже:

def gini(y_true, y_pred): gini = 2 roc_auc_score(y_true, y_pred) - 1 return gini
gini_scorer = make_scorer(gini, greater_is_better = True)
params = { "verbose": False, "eval_metric": 'Logloss', 'iterations': 1000, 'random_state': 42, 'early_stopping_rounds': 10, 'max_depth': 4
}
model = CatBoostClassifier(*params)
model.fit(x_train[features], y_train, eval_set = (x_val[features], y_val), cat_features = cat_feat)
y_pred_train = model.predict_proba(x_train[features]).T[1]
y_pred_val = model.predict_proba(x_val[features]).T[1]
y_pred_test = model.predict_proba(x_test[features]).T[1]
gini_train = np.round(gini(y_train, y_pred_train),3)
gini_val = np.round(gini(y_val, y_pred_val),3)
gini_test = np.round(gini(y_test, y_pred_test),3)
print('Джини бустинга на всех фичах:', 'train', gini_train, 'val', gini_val, 'test', gini_test)
fe_stats = pd.DataFrame({'feature_importance': model.get_feature_importance(), 'feature_names':features}).sort_values(by=['feature_importance'], ascending=False)
display(fe_stats[0:5])
features = fe_stats[0:5].feature_names.to_list()
cat_feat = ['customer_service_calls']


target = ['churn']
data = data[features + target].copy()
data_info = data.describe(percentiles = [0.01, 0.05, 0.5, 0.95, 0.99], include=list(np.unique(data[data.columns].dtypes.astype('str').values))).T
data_info['type'] = data[data.columns].dtypes
data_info['null'] = data[data.columns].isnull().sum()
data_info['null%'] = np.round(data[data.columns].isnull().mean() * 100, 1)
data_info['nunique'] = data[data.columns].nunique()
data_info['count'] = data_info['count'].astype('int')
data_info = data_info.drop('unique', axis = 1)
data_info.sort_values(by = ['type','nunique'], ascending = [False, False], inplace = True)
data_info=data_info.loc[:,['type','count','nunique','null','null%','top','freq','min','max','mean','std','1%','5%','50%','95%','99%']]
data_info[['min','max','mean','std','1%','5%','50%','95%','99%']] = data_info[['min','max','mean','std','1%','5%','50%','95%','99%']].astype(float).round(4)
display(data_info)

Джини бустинга на всех фичах: train 0.934 val 0.859 test 0.831

feature_importance

feature_names

21

39.692171

total_charge

18

17.874816

customer_service_calls

3

14.006796

international_plan

16

6.891904

total_intl_calls

5

5.593090

number_vmail_messages

type

count

nunique

null

null%

top

freq

min

max

mean

std

1%

5%

50%

95%

99%

customer_service_calls

object

3333

7

0

0.0

B

1181

NaN

NaN

NaN

NaN

NaN

NaN

NaN

NaN

NaN

total_charge

float64

3333

2678

0

0.0

NaN

NaN

22.93

96.15

59.4498

10.5023

33.8532

42.338

59.47

76.516

83.8396

number_vmail_messages

int64

3333

46

0

0.0

NaN

NaN

0.00

51.00

8.0990

13.6884

0.0000

0.000

0.00

36.000

43.0000

total_intl_calls

int64

3333

21

0

0.0

NaN

NaN

0.00

20.00

4.4794

2.4612

1.0000

1.000

4.00

9.000

13.0000

international_plan

int32

3333

2

0

0.0

NaN

NaN

0.00

1.00

0.0969

0.2959

0.0000

0.000

0.00

1.000

1.0000

churn

int32

3333

2

0

0.0

NaN

NaN

0.00

1.00

0.1449

0.3521

0.0000

0.000

0.00

1.000

1.0000

Далее обучим catboost только на пяти самых значимых фичах. Полученная модель на тестовой выборке показала коэффициент Джини 82,5%.

params = {

    "verbose": False,

    'eval_metric': 'Logloss',

    'iterations': 1000,

    'early_stopping_rounds': 10,

    'depth': 4, #по умолчанию 6

    'random_state': 42,

}

model = CatBoostClassifier(**params)

model.fit(x_train[features], y_train, eval_set = (x_val[features], y_val), cat_features = cat_feat)


y_pred_train = model.predict_proba(x_train[features]).T[1]

y_pred_val = model.predict_proba(x_val[features]).T[1]

y_pred_test = model.predict_proba(x_test[features]).T[1]

gini_train = np.round(gini(y_train, y_pred_train),3)

gini_val = np.round(gini(y_val, y_pred_val),3)

gini_test = np.round(gini(y_test, y_pred_test),3)

print('Джини бустинга на всех фичах:', 'train', gini_train, 'val', gini_val, 'test', gini_test)

Джини бустинга на всех фичах: train 0.904 val 0.865 test 0.825.

Интерпретировать фичи модели очень удобно с помощью библиотеки SHAP. Детально про алгоритмы расчета Shap написано много, посмотреть можно в следующих статьях:

https://habr.com/ru/articles/428213/

https://habr.com/ru/companies/wunderfund/articles/739744/

https://habr.com/ru/companies/ods/articles/599573/

https://habr.com/ru/companies/otus/articles/465329/

https://www.kaggle.com/code/dansbecker/shap-values

Ссылка на официальную документацию по Shap:

https://shap.readthedocs.io/en/latest/example_notebooks/tabular_examples/tree_based_models/Catboost%20tutorial.html

Сутево SHAP-value отвечают на вопрос: “Насколько изменится предсказание модели для конкретного наблюдения по сравнению со средним прогнозом, если мы добавим данную фичу, учитывая все возможные комбинации / перестановки фичей”.

В первую очередь рекомендую смотреть график Shap.summary_plot. На нем показываются фичи на уровне модели в целом.

Код для вывода графика следующий:

explainer = shap.TreeExplainer(model)

shap_values = explainer(x_train[features].iloc[:,:])

shap.summary_plot(shap_values, x_train[features], max_display = len(features))

Сортировка на графике идет сверху вниз от самой важной фичи до менее важных по среднему абсолютному значению Shap.values.

Но нам также интересна дополнительная информация в виде значений важности фичей, которая выводится с помощью отдельного графика:

shap.plots.bar(shap_values, max_display = len(features))

Поэтому чтобы не создавать много графиков мы можем улучшить shap.summary_plot, добавив непосредственно на него среднее абсолютное значение Shap.values. Код в данном случае будет выглядеть следующим образом:

# Инициализация explainer

explainer = shap.TreeExplainer(model)

# Вычисление SHAP значений

shap_values = explainer(x_train[features].iloc[:,:])

# Вычисление mean(|SHAP value|) для каждого признака

mean_abs_shap = np.mean(np.abs(shap_values.values), axis=0)

# Создание подписей с SHAP значениями

feature_labels = [f"{features[i]} ({mean_abs_shap[i]*100:.2f})" 

                 for i in range(len(features))]

# Построение графика с модифицированными подписями

shap.summary_plot(shap_values, x_train[features], feature_names=feature_labels,  # Используем названия фичей как подписи

                 max_display=len(features), show=False)

# Добавляем заголовок с пояснением

plt.title("SHAP Summary Plot: значения в скобках - mean(|SHAP value|)", y = 1.05)

plt.tight_layout()

plt.show()

Shap-графики: как наглядно объяснить заказчику логику работы модели - 2

Посмотрев на этот график, можно определить степень влияния и направление влияния по каждой фиче.

В качестве напоминания о том, как интерпретировать графики Shap:

1. Каждая линия на графике представляет собой фактор модели.

2. Каждая точка для определенного фактора — это отдельный прогноз в выборке.

3. Расположение по оси Х меньше или больше нуля показывает, увеличивает ли фича или уменьшает прогноз относительно среднего по выборке.

4. Также показывает значительность влияния — чем дальше от нуля, тем соответственно значительнее было влияние фичи на конкретное предсказание. Например, если наблюдение по определенной фиче имеет значение Shap равное +0,1%, это означает, что значение фичи для данного наблюдения приводит к увеличению значения Shap на эту величину.

5. Справа располагается шкала значений фичей. Точка красного цвета означает, что значение фичи очень высокое, синего цвета — низкое значение фичи. Если множество прогнозов дают похожий результат для данной фичи, то это приводит к тому, что линия становится намного шире (точки начинают накапливаться).

Анализируя ось x возникает логичный вопрос — а в каких единицах измерения выводится график shap.summary_plot?

В shap.TreeExplainer есть следующие параметры, установленные по умолчанию:

feature_perturbation='tree_path_dependent', model_output = 'raw'.

Для бинарной классификации Shap-значения при параметрах по умолчанию показывают, насколько каждый признак отклоняет предсказание модели от среднего значения в пространстве логарифмических шансов (log-odds).

Базовое значение (среднее): explainer.expected_value (средний log-odds по выборке).

Формула: log-odds=log⁡(p /(p1 − p)), где p — вероятность положительного класса.

Но всегда интересно посмотреть график Shap.summary_plot именно в вероятностях. В таком случае есть два варианта. Первый — указать feature_perturbation=”interventional”, model_output=”probability”. Но данный вариант в некоторых версиях Shap может не сработать, поэтому можно пойти обходным путем через ручное преобразование log-odds в вероятности.

Код и график при переводе на вероятности см. ниже:

try:

    # Способ 1: Прямое получение вероятностей (может не работать в некоторых версиях), работает медленнее

    explainer = shap.TreeExplainer(model, feature_perturbation="interventional", model_output="probability")

    shap_values = explainer(x_train[features].iloc[:,:])

except ValueError:

    # Способ 2: Обходной путь через ручное преобразование

    explainer = shap.TreeExplainer(model)

    shap_values_raw = explainer(x_train[features].iloc[:,:])

    expected_value = explainer.expected_value

    shap_values = 1/(1+np.exp(-(expected_value + shap_values_raw.values))) - 1/(1+np.exp(-expected_value))

shap.summary_plot(shap_values, x_train[features], max_display = len(features))

Shap-графики: как наглядно объяснить заказчику логику работы модели - 5

Но Shap.summary_plot — очень верхнеуровневый график, на нем не всегда понятны детали работы фичи. 

Например, по наиболее важной фиче total_charge видно, что с ростом значения фичи вероятность таргета также существенно растет, а вот что происходит при низких и средних значениях — из графика не вполне понятно. Поэтому многие фичи, особенно с нелинейной зависимостью лучше еще дополнительно посмотреть на отдельном графике shap.plots.scatter(shap_values[:,x])

В данный график можно добавить много полезной информации, и эту доработку удобнее  сделать с помощью кастомного графика scatter plot. На нем детально видно изменение логики работы фичи в зависимости от ее значения. 

В нашем конкретном случае по фиче total_charge видно, что сначала некоторые наблюдения даже увеличивают вероятность срабатывания таргета, потом происходит снижение, а далее резкий рост предсказаний. 

Дополнительно можно отразить на графике: 

1) Наблюдения соответствующие таргету (синие точки);

2) Динамика средних значений изменения предсказаний в зависимости от значений фичи;

3) Распределение значений фичи по выборке;

4) Среднее / медиана значений фичи.

# Ручное создание scatter plot с кастомными настройками

= 'total_charge'

shap_values = explainer(x_train[features].iloc[:,:])

# Получаем данные

x_data = x_train[x]

y_data = shap_values[:, x].values

y_labels = y_train

# Создаем фигуру и основную ось

fig, ax1 = plt.subplots(figsize=(11, 6))

fig.patch.set_facecolor('white')

ax1.set_facecolor('white')

colors = np.where(y_train == 1, 'blue', 'red')

scatter = ax1.scatter(

    x=x_train[x], y=shap_values[:, x].values, c=colors, cmap='coolwarm', alpha=1, edgecolors='w', linewidths=0.1)

#Добавляем линию средних значений

# Разбиваем на 20 бинов для расчета среднего

bin_means, bin_edges,  = stats.binnedstatistic(x_data, y_data, statistic='mean', bins=20)

bin_centers = bin_edges[:-1] + np.diff(bin_edges)/2

# Рисуем линию средних

ax1.plot(bin_centers, bin_means, color='black', linewidth=2.2, linestyle='-.', label='Среднее SHAP')

# Настройки основной оси

ax1.set_ylabel('SHAP value', fontsize=14)

ax1.grid(True, alpha=0.3)

ax1.axvline(np.median(x_data), color='black', linestyle='-.', label='Медиана', linewidth=1)

#Добавляем вторую ось для распределения

ax2 = ax1.twinx()  # Создаем вторую ось с общим X

ax2.hist(x_data, bins=50, color='skyblue', alpha=0.2, density=True)

ax2.set_ylabel('Плотность распределения', fontsize=14)

ax2.grid(False)  # Отключаем сетку для второй оси

# Общие элементы

plt.title(f'SHAP значения и распределение', fontsize=14)

Text(0.5, 1.0, ‘SHAP-значения и распределение’)

Shap-графики: как наглядно объяснить заказчику логику работы модели - 6

Графики PDP (Partial Dependence Plot) и ICE (Individual Conditional Expectation) могут хорошо дополнять Shap.Summary_plot и Scatter_plot с точки зрения раскрытия бизнес-логики фичи. PDP отвечает на вопрос: как в среднем меняется прогноз модели, если зафиксировать определенное значение исследуемой фичи, усредняя влияние всех остальных. ICE отвечает на вопрос: «Как меняется прогноз по конкретному наблюдению при изменении значения фичи, оставляя остальные неизменными». Более детальную информацию про графики PDP и ICE можно посмотреть по следующей ссылке: https://scikit-learn.org/stable/modules/partial_dependence.html

Методика расчета PDP (усредняет все кривые, показывая общий тренд):
Шаг 1: фиксируем возможные значения фичи в выборке;
Шаг 2: для каждого значения фичи подставляем его во все наблюдения и пересчитываем прогноз по каждому наблюдению;
Шаг 3: далее усредняем значения оценок по всем наблюдениям каждого значения фичи;
Шаг 4: на данных усредненных значениях прогнозов строим график PDP.

Методика расчета ICE (показывает индивидуальные зависимости для каждого наблюдения):
Шаг 1: фиксируем диапазон значений для фичи;
Шаг 2: для каждого наблюдения подставляем все значения данного признака и пересчитываем прогноз по каждому наблюдению;
Шаг 3: выводим кривую для каждого наблюдения.

Каждая линия на графике — это отдельное наблюдение в выборке с посчитанным Shap-value в разрезе каждого значения исследуемой фичи. PDP для total_charge показывает, что с ростом значения фичи предсказание в среднем немного снижается. Но посмотрев на ICE мы видим, что для основной части клиентов предсказание не изменилось, а для небольшой доли наблюдений предсказание снизилось очень сильно. Именно за счет данных наблюдений сложилась такая ситуация с бизнес-логикой. Соответственно можно провести отдельный анализ и понять причины данных изменений.

Т.е. с помощью ICE можно находить аномальные наблюдения или клиентов, у которых зависимости резко отличаются от общих.
Недостаток ICE заключается в том, что на больших датасетах бывает трудно интерпретировать (много пересекающихся линий). Также ICE показывает логику работы фичи, но не показывает статистическую значимость эффекта.

У графика ICE есть два варианта визуализации:
При ‘centered’ = True из всех значений кривой вычитается предсказание при стартовом значении фичи, поэтому все кривые в точке старта графика формально начинаются с Y = 0. Этот вариант может быть полезным для упрощения анализа, т.к. все кривые выровнены по одной стартовой точке, что фильтрует индивидуальные смещения, акцентируя внимание именно на форме зависимости.

Если отключить нормализацию ‘centered’ = False, кривые будут начинаться с фактических значений предсказаний для каждой кривой.
Оба варианта графика представим ниже.

from sklearn.inspection import PartialDependenceDisplay
fig, ax = plt.subplots(figsize=(10, 4))
fig.patch.set_facecolor('white')
ax.set_facecolor('white')
features_info = {"features": ['total_charge'], "kind": "both", "centered": False,}
common_params = {"subsample": 500, "n_jobs": 2, "grid_resolution": 30, "random_state": 42, 'method':'auto'} display = PartialDependenceDisplay.from_estimator(model, x_train[features].iloc[:,:], features_info, ax = ax, common_params)
= display.figure.suptitle("ICE and PDP total_charge", fontsize=12)

Shap-графики: как наглядно объяснить заказчику логику работы модели - 7

from sklearn.inspection import PartialDependenceDisplay
fig, ax = plt.subplots(figsize=(10, 4))
fig.patch.set_facecolor('white')
ax.set_facecolor('white')
features_info = {"features": ['total_charge'], "kind": "both", "centered": True,}
common_params = {"subsample": 500, "n_jobs": 2, "grid_resolution": 30, "random_state": 42, 'method':'auto'} display = PartialDependenceDisplay.from_estimator(model, x_train[features].iloc[:,:], features_info, ax = ax, common_params)
= display.figure.suptitle("ICE and PDP total_charge (centered)", fontsize=12)

Shap-графики: как наглядно объяснить заказчику логику работы модели - 8

Графики, которые мы показывали, хорошо подходят под непрерывные значения фичи, но мало подходят под категориальные.
Как мы видели на Shap.summary_plot, категориальные фичи не были раскрашены, детально понять ситуацию было сложно.
График plotbar поможет детальнее раскрыть логику работы для категориальных фичей.
График plt.boxplot (ящик с усами) визуализирует основные описательные статистики распределения данных. На нем представлена следующая информация:
1. Ящик (Box)
Границы ящика: Нижняя граница (Q1) — первый квартиль (25-й процентиль), Верхняя граница (Q3) — третий квартиль (75-й процентиль).
Линия внутри ящика (оранжевая линия): Медиана (Q2, 50-й процентиль).
Ширина ящика показывает межквартильный размах (Interquartile range) (IQR = Q3 – Q1) — диапазон, где сосредоточено 50% данных.
2. Усы (Whiskers)
Верхний ус: Q3 + 1.5 IQR (максимальное значение, не считая выбросов). Нижний ус: Обычно Q1 – 1.5 IQR (минимальное значение, не считая выбросов). Усы могут дополнительно настраиваться (например, до 95% процентилей). Отдельно также добавим в каждый ящик линию средних значений (зеленый пунктир) и отдельно общую линию (красный пунктир), соединяющую средние значения между ящиками, чтобы наглядно видеть динамику. Ниже под график добавим гистограмму с распределением количества значений в каждом ящике.

x = 'customer_service_calls'
values = shap_values[:,features.index(x)]
data = x_train[x]
categories = sorted(list(set(x_train[x])))
groups = []
means = []
for category in categories: relevant_values = values.values[values.data == category] groups.append(relevant_values) means.append(np.mean(relevant_values))
labels = [u for u in categories] # Создаём сетку графиков: верх — boxplot, низ — гистограмма
fig, (ax_box, ax_hist) = plt.subplots(nrows=2, sharex=True, gridspec_kw={"height_ratios": (0.6, 0.4)}, figsize=(8, 6)) #plt.figure(figsize=(8, 5))
ax_box.boxplot(groups, labels = labels, showmeans=True, meanline=True)
ax_box.set_ylabel('Shap values', size=15)
ax_box.set_xlabel(x, size=15)
# Добавление линии, соединяющей средние
ax_box.plot(range(1, len(categories) + 1), means, marker='.', color='red', linestyle='--', linewidth=1, label='Средние')
ax_box.grid(True, linestyle='--', alpha=0.5) # Добавляем подписи
for i, category in enumerate(categories): ax_hist.text(i+1, plt.ylim()[0], f"n={len(groups[i])}", ha="center", va='bottom') # Гистограмма количества наблюдений на нижней панели
counts = [len(x) for x in groups]
ax_hist.bar(range(1, len(categories)+1), counts, color="skyblue", edgecolor="black")
ax_hist.set_ylabel("Количество")
plt.tight_layout()
plt.show()

Shap-графики: как наглядно объяснить заказчику логику работы модели - 9

Можно провести дополнительную аналитику и посчитать корреляцию между значениями фичи и shap-значениями. Корреляция значений признаков с shap показывает, как связаны исходные значения фичей с их вкладом в предсказание модели. Соответственно, при сильной положительной корреляции чем больше значение фичи, тем сильнее он увеличивает предсказание. При отрицательной корреляции, в свою очередь, чем больше значение фичи, тем сильнее он уменьшает предсказание. Если корреляция околонулевая, то нет линейной зависимости между значением фичи и его влиянием. Неожиданно низкие корреляции могут говорить о нелинейных зависимостях (например, U-образная кривая) или о сильных взаимодействиях с другими фичами. Для примера покажем график корреляции в нашем кейсе. Видим, что по фиче total_charge достаточно высокая корреляция, что говорит о линейной зависимости фичи, а по фиче total_intl_calls —корреляция очень низкая.

# Вычисление корреляций
features.remove('customer_service_calls')
correlations = {}
for i, feat in enumerate(features): correlations[feat] = np.corrcoef(x_train[feat], shap_values.values[:,i])[0,1]
# График
plt.figure(figsize=(8, 2))
sns.barplot(x = list(correlations.values()), y = list(correlations.keys()), palette = "vlag")
plt.axvline(0, color = 'black', linestyle = '--')
plt.title('Корреляция значений фичи с Shap')
plt.xlabel('Коэффициент корреляции Пирсона')
plt.xlim(-1, 1)
plt.show()
features = fe_stats[0:5].feature_names.to_list()

Shap-графики: как наглядно объяснить заказчику логику работы модели - 10

Чтобы определить взаимодействие total_intl_calls с другими фичами, посмотрим на нее более детально на графике Shap.dependence_plot(). Можно вывести характер взаимодействия total_intl_calls с другой интересующей нас фичей прямо указав ее в interaction_index. Если такую не указывать, то shap.dependence_plot по умолчанию выберет для раскраски точек на графике фичу с наибольшим взаимодействием. Чем сильнее Shap-значения основной фичи меняются в зависимости от значений другой фичи, тем выше взаимодействие. Формально это оценивается через дисперсию Shap-значений одной фичи, объяснённых другой фичей. Например, если при высоком значении international_plan Shap-значения total_intl_calls резко вырастут, а при низком снизятся,
то international_plan автоматически будет выбрана для раскраски точек (см. пример на графике).

# не указываем фичу для проверки взаимодействия, выводится фича с наибольшим взаимодействием - international_plan
shap.dependence_plot('total_intl_calls', shap_values.values, x_train[features].iloc[:,:])

Shap-графики: как наглядно объяснить заказчику логику работы модели - 11

# указываем фичу с которой хотим проверить взаимодействие - 'number_vmail_messages'
shap.dependence_plot(‘total_intl_calls’, shap_values.values, x_train[features].iloc[:,:], interaction_index = ‘number_vmail_messages’)

Shap-графики: как наглядно объяснить заказчику логику работы модели - 12

Shap.dependence_plot позволяет уведить неочевидные зависимости, например, что наличие international_plan по разному влияет на прогноз по клиентам с небольшим и большим количеством звонков total_intl_calls. Более детально про Shap.dependence_plot можно прочитать в документации: https://shap.readthedocs.io/en/latest/example_notebooks/tabular_examples/tree_based_models/NHANES I Survival Model.html

На этом завершаем. Надеюсь, что данная статья была полезной для вас.
Какие еще интересные возможности есть у графиков, которые были бы полезны для исследований, пишите в комментариях.

Автор: andrey_boyarenkov

Источник

Rambler's Top100