This page was generated from examples/cfproto_cat_adult_ord.ipynb.

Counterfactual explanations with ordinally encoded categorical variables

This example notebook illustrates how to obtain counterfactual explanations for instances with a mixture of ordinally encoded categorical and numerical variables. A more elaborate notebook highlighting additional functionality can be found here. We generate counterfactuals for instances in the adult dataset where we predict whether a person’s income is above or below $50k.

[1]:
import tensorflow as tf
tf.logging.set_verbosity(tf.logging.ERROR)  # suppress deprecation messages
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Dense, Input, Embedding, Concatenate, Reshape, Dropout, Lambda
from tensorflow.keras.models import Model
from tensorflow.keras.utils import to_categorical

%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import os
from sklearn.preprocessing import OneHotEncoder
from time import time
from alibi.datasets import fetch_adult
from alibi.explainers import CounterFactualProto

Load adult dataset

The fetch_adult function returns a Bunch object containing the features, the targets, the feature names and a mapping of the categories in each categorical variable.

[2]:
adult = fetch_adult()
data = adult.data
target = adult.target
feature_names = adult.feature_names
category_map_tmp = adult.category_map
target_names = adult.target_names

Define shuffled training and test sets:

[3]:
def set_seed(s=0):
    np.random.seed(s)
    tf.set_random_seed(s)
[4]:
set_seed()
data_perm = np.random.permutation(np.c_[data, target])
X = data_perm[:,:-1]
y = data_perm[:,-1]
[5]:
idx = 30000
y_train, y_test = y[:idx], y[idx+1:]
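
As a small illustration of the shuffle and split above, the following sketch uses toy data (all names here are hypothetical and not part of the notebook). It checks that permuting the joined feature/label array keeps each row paired with its label, and shows that a test slice starting at `idx + 1` leaves row `idx` out of both sets.

```python
import numpy as np

np.random.seed(0)

# Toy data: row i holds features (2i, 2i + 1) and label i.
feats = np.arange(10).reshape(5, 2)
labels = np.arange(5)

# Shuffle features and labels jointly, as done above with np.c_.
perm = np.random.permutation(np.c_[feats, labels])
X_toy, y_toy = perm[:, :-1], perm[:, -1]

# Each shuffled row still carries its own label.
print(np.array_equal(X_toy[:, 0] // 2, y_toy))

# A test slice that starts at idx + 1 skips the row at position idx.
idx_toy = 3
y_tr, y_te = y_toy[:idx_toy], y_toy[idx_toy + 1:]
print(len(y_tr) + len(y_te), 'of', len(y_toy), 'rows used')
```

The outputs shown in this notebook, such as the test set shape, correspond to the `idx+1` slicing, so the original split is kept as is. Slicing the test set with `idx:` instead would include the skipped row.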

Reorganize data so categorical features come first:

[6]:
X = np.c_[X[:, 1:8], X[:, 11], X[:, 0], X[:, 8:11]]

Adjust feature_names and category_map as well:

[7]:
feature_names = feature_names[1:8] + feature_names[11:12] + feature_names[0:1] + feature_names[8:11]
print(feature_names)
['Workclass', 'Education', 'Marital Status', 'Occupation', 'Relationship', 'Race', 'Sex', 'Country', 'Age', 'Capital Gain', 'Capital Loss', 'Hours per week']
[8]:
category_map = {}
for i, (_, v) in enumerate(category_map_tmp.items()):
    category_map[i] = v
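
The reordering above can also be expressed with a single index list that is applied to both the columns and the feature names, which makes it harder for the two to drift apart. This is a minimal sketch on toy data. The original column order is inferred from the slices used above, and `order`, `toy` and `names` are hypothetical names.

```python
import numpy as np

# Original column order, inferred from the slices used above.
names = ['Age', 'Workclass', 'Education', 'Marital Status', 'Occupation',
         'Relationship', 'Race', 'Sex', 'Capital Gain', 'Capital Loss',
         'Hours per week', 'Country']

# Categorical columns first (1..7 and 11), numerical columns last (0, 8..10).
order = list(range(1, 8)) + [11] + [0] + list(range(8, 11))

toy = np.arange(24).reshape(2, 12)   # in the first row, column j holds the value j
toy_reordered = toy[:, order]        # same result as the np.c_[...] call above
names_reordered = [names[j] for j in order]

print(names_reordered)
print(toy_reordered[0])
```

The first toy row prints the original column index at each new position, so it doubles as a lookup from new to old columns.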

Create a dictionary whose keys are the categorical columns and whose values are the number of categories for each of those variables in the dataset. This dictionary is later passed to the counterfactual explainer.

[9]:
cat_vars_ord = {}
n_categories = len(list(category_map.keys()))
for i in range(n_categories):
    cat_vars_ord[i] = len(np.unique(X[:, i]))
print(cat_vars_ord)
{0: 9, 1: 7, 2: 4, 3: 9, 4: 6, 5: 5, 6: 2, 7: 11}
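
Counting unique values only sees the categories that actually occur in the data. The sketch below repeats the counting step on toy data and compares the result with a category map. The array `X_toy` and the map `toy_category_map` are hypothetical and only serve as an illustration.

```python
import numpy as np

# Toy ordinally encoded data: two categorical columns, one numerical column.
X_toy = np.array([[0, 2, 0.5],
                  [1, 0, -0.3],
                  [1, 1, 0.9],
                  [0, 2, 0.1]])
toy_category_map = {0: ['a', 'b'], 1: ['x', 'y', 'z']}  # hypothetical category names

# Count the distinct codes present in each categorical column.
toy_cat_vars = {i: len(np.unique(X_toy[:, i])) for i in toy_category_map}
print(toy_cat_vars)

# The counts match the category map only if every category occurs at least once.
for i, cats in toy_category_map.items():
    assert toy_cat_vars[i] == len(cats)
```

If a category never appeared in the data, using `len(category_map[i])` would be a possible alternative to counting unique values.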

Preprocess data

Scale numerical features between -1 and 1:

[10]:
X_num = X[:, -4:].astype(np.float32, copy=False)
xmin, xmax = X_num.min(axis=0), X_num.max(axis=0)
rng = (-1., 1.)
X_num_scaled = (X_num - xmin) / (xmax - xmin) * (rng[1] - rng[0]) + rng[0]
X_num_scaled_train = X_num_scaled[:idx, :]
X_num_scaled_test = X_num_scaled[idx+1:, :]
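
The scaling above is a column-wise min-max transformation. The sketch below wraps it in a function and adds its inverse, which can be useful for reading scaled values back in their original units. The helper names `scale` and `unscale` are hypothetical and the data is a toy example.

```python
import numpy as np

def scale(x, xmin, xmax, rng=(-1., 1.)):
    """Map x column-wise from [xmin, xmax] to [rng[0], rng[1]]."""
    return (x - xmin) / (xmax - xmin) * (rng[1] - rng[0]) + rng[0]

def unscale(x_scaled, xmin, xmax, rng=(-1., 1.)):
    """Invert scale(): map values in [rng[0], rng[1]] back to [xmin, xmax]."""
    return (x_scaled - rng[0]) / (rng[1] - rng[0]) * (xmax - xmin) + xmin

toy = np.array([[20., 0.], [40., 50.], [60., 100.]])
tmin, tmax = toy.min(axis=0), toy.max(axis=0)

toy_scaled = scale(toy, tmin, tmax)
print(toy_scaled)                        # each column now spans -1 to 1
print(unscale(toy_scaled, tmin, tmax))   # recovers the toy values
```

Both helpers assume that `xmax > xmin` for every column. A constant column would lead to a division by zero.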

Combine numerical and categorical data:

[11]:
X = np.c_[X[:, :-4], X_num_scaled].astype(np.float32, copy=False)
X_train, X_test = X[:idx, :], X[idx+1:, :]
print(X_train.shape, X_test.shape)
(30000, 12) (2560, 12)

Train a neural net

The neural net will use entity embeddings for the categorical variables.

[12]:
def nn_ord():

    x_in = Input(shape=(12,))
    layers_in = []

    # embedding layers
    for i, (_, v) in enumerate(cat_vars_ord.items()):
        emb_in = Lambda(lambda x: x[:, i:i+1])(x_in)
        emb_dim = int(max(min(np.ceil(.5 * v), 50), 2))
        emb_layer = Embedding(input_dim=v+1, output_dim=emb_dim, input_length=1)(emb_in)
        emb_layer = Reshape(target_shape=(emb_dim,))(emb_layer)
        layers_in.append(emb_layer)

    # numerical layers
    num_in = Lambda(lambda x: x[:, -4:])(x_in)
    num_layer = Dense(16)(num_in)
    layers_in.append(num_layer)

    # combine
    x = Concatenate()(layers_in)
    x = Dense(60, activation='relu')(x)
    x = Dropout(.2)(x)
    x = Dense(60, activation='relu')(x)
    x = Dropout(.2)(x)
    x = Dense(60, activation='relu')(x)
    x = Dropout(.2)(x)
    x_out = Dense(2, activation='softmax')(x)

    nn = Model(inputs=x_in, outputs=x_out)
    nn.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

    return nn
[13]:
set_seed()
nn = nn_ord()
nn.summary()
nn.fit(X_train, to_categorical(y_train), batch_size=128, epochs=20, verbose=0)
Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_1 (InputLayer)            [(None, 12)]         0
__________________________________________________________________________________________________
lambda (Lambda)                 (None, 1)            0           input_1[0][0]
__________________________________________________________________________________________________
lambda_1 (Lambda)               (None, 1)            0           input_1[0][0]
__________________________________________________________________________________________________
lambda_2 (Lambda)               (None, 1)            0           input_1[0][0]
__________________________________________________________________________________________________
lambda_3 (Lambda)               (None, 1)            0           input_1[0][0]
__________________________________________________________________________________________________
lambda_4 (Lambda)               (None, 1)            0           input_1[0][0]
__________________________________________________________________________________________________
lambda_5 (Lambda)               (None, 1)            0           input_1[0][0]
__________________________________________________________________________________________________
lambda_6 (Lambda)               (None, 1)            0           input_1[0][0]
__________________________________________________________________________________________________
lambda_7 (Lambda)               (None, 1)            0           input_1[0][0]
__________________________________________________________________________________________________
embedding (Embedding)           (None, 1, 5)         50          lambda[0][0]
__________________________________________________________________________________________________
embedding_1 (Embedding)         (None, 1, 4)         32          lambda_1[0][0]
__________________________________________________________________________________________________
embedding_2 (Embedding)         (None, 1, 2)         10          lambda_2[0][0]
__________________________________________________________________________________________________
embedding_3 (Embedding)         (None, 1, 5)         50          lambda_3[0][0]
__________________________________________________________________________________________________
embedding_4 (Embedding)         (None, 1, 3)         21          lambda_4[0][0]
__________________________________________________________________________________________________
embedding_5 (Embedding)         (None, 1, 3)         18          lambda_5[0][0]
__________________________________________________________________________________________________
embedding_6 (Embedding)         (None, 1, 2)         6           lambda_6[0][0]
__________________________________________________________________________________________________
embedding_7 (Embedding)         (None, 1, 6)         72          lambda_7[0][0]
__________________________________________________________________________________________________
lambda_8 (Lambda)               (None, 4)            0           input_1[0][0]
__________________________________________________________________________________________________
reshape (Reshape)               (None, 5)            0           embedding[0][0]
__________________________________________________________________________________________________
reshape_1 (Reshape)             (None, 4)            0           embedding_1[0][0]
__________________________________________________________________________________________________
reshape_2 (Reshape)             (None, 2)            0           embedding_2[0][0]
__________________________________________________________________________________________________
reshape_3 (Reshape)             (None, 5)            0           embedding_3[0][0]
__________________________________________________________________________________________________
reshape_4 (Reshape)             (None, 3)            0           embedding_4[0][0]
__________________________________________________________________________________________________
reshape_5 (Reshape)             (None, 3)            0           embedding_5[0][0]
__________________________________________________________________________________________________
reshape_6 (Reshape)             (None, 2)            0           embedding_6[0][0]
__________________________________________________________________________________________________
reshape_7 (Reshape)             (None, 6)            0           embedding_7[0][0]
__________________________________________________________________________________________________
dense (Dense)                   (None, 16)           80          lambda_8[0][0]
__________________________________________________________________________________________________
concatenate (Concatenate)       (None, 46)           0           reshape[0][0]
                                                                 reshape_1[0][0]
                                                                 reshape_2[0][0]
                                                                 reshape_3[0][0]
                                                                 reshape_4[0][0]
                                                                 reshape_5[0][0]
                                                                 reshape_6[0][0]
                                                                 reshape_7[0][0]
                                                                 dense[0][0]
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 60)           2820        concatenate[0][0]
__________________________________________________________________________________________________
dropout (Dropout)               (None, 60)           0           dense_1[0][0]
__________________________________________________________________________________________________
dense_2 (Dense)                 (None, 60)           3660        dropout[0][0]
__________________________________________________________________________________________________
dropout_1 (Dropout)             (None, 60)           0           dense_2[0][0]
__________________________________________________________________________________________________
dense_3 (Dense)                 (None, 60)           3660        dropout_1[0][0]
__________________________________________________________________________________________________
dropout_2 (Dropout)             (None, 60)           0           dense_3[0][0]
__________________________________________________________________________________________________
dense_4 (Dense)                 (None, 2)            122         dropout_2[0][0]
==================================================================================================
Total params: 10,601
Trainable params: 10,601
Non-trainable params: 0
__________________________________________________________________________________________________
[13]:
<tensorflow.python.keras.callbacks.History at 0x7fa7c32fb2e8>
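
The embedding sizes and parameter counts in the summary follow directly from the rule used in `nn_ord`: half the number of categories, rounded up and clipped to the range 2 to 50, with `input_dim` set to the number of categories plus one. The sketch below recomputes these numbers in plain Python. The helper name `emb_size` is hypothetical.

```python
import math

# Number of categories per categorical column, as printed earlier in this notebook.
cat_vars = {0: 9, 1: 7, 2: 4, 3: 9, 4: 6, 5: 5, 6: 2, 7: 11}

def emb_size(n_cat):
    """Half the number of categories, rounded up, clipped to [2, 50]."""
    return int(max(min(math.ceil(.5 * n_cat), 50), 2))

emb_dims = {i: emb_size(v) for i, v in cat_vars.items()}
emb_params = {i: (v + 1) * emb_dims[i] for i, v in cat_vars.items()}  # input_dim = v + 1

print(emb_dims)
print(emb_params)
print(sum(emb_dims.values()) + 16)  # embeddings plus the Dense(16) numerical branch
```

The last value is the width of the concatenated layer, which agrees with the `(None, 46)` output shape in the summary.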

Generate counterfactual

Original instance:

[14]:
X = X_test[0].reshape((1,) + X_test[0].shape)

Initialize counterfactual parameters:

[15]:
shape = X.shape
beta = .01
c_init = 1.
c_steps = 5
max_iterations = 500
rng = (-1., 1.)  # scale features between -1 and 1
rng_shape = (1,) + data.shape[1:]
feature_range = ((np.ones(rng_shape) * rng[0]).astype(np.float32),
                 (np.ones(rng_shape) * rng[1]).astype(np.float32))

Initialize the explainer. Because the Embedding layers in tf.keras do not let gradients propagate back to the ordinally encoded inputs, we only use the model’s predict function: the model is treated as a black box and the gradients are computed numerically.

[16]:
set_seed()

# define predict function
predict_fn = lambda x: nn.predict(x)

cf = CounterFactualProto(predict_fn,
                         shape,
                         beta=beta,
                         cat_vars=cat_vars_ord,
                         max_iterations=max_iterations,
                         feature_range=feature_range,
                         c_init=c_init,
                         c_steps=c_steps,
                         eps=(.01, .01)  # perturbation size for numerical gradients
                        )
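
To give an idea of what a numerical gradient of a black-box prediction function looks like, here is a minimal central-difference sketch. It is not the explainer's internal implementation. The function `numerical_gradient` and the toy softmax model are hypothetical, and the sketch only covers continuous inputs.

```python
import numpy as np

def numerical_gradient(predict_fn, x, eps=0.01):
    """Central-difference estimate of d predict_fn / d x for one instance.

    x has shape (1, n_features); the result has shape (n_outputs, n_features).
    """
    n_features = x.shape[1]
    grads = []
    for j in range(n_features):
        step = np.zeros_like(x)
        step[0, j] = eps
        diff = predict_fn(x + step) - predict_fn(x - step)  # shape (1, n_outputs)
        grads.append(diff[0] / (2 * eps))
    return np.stack(grads, axis=1)

# Hypothetical black box: a softmax over two fixed linear scores.
W = np.array([[1.0, -2.0, 0.5], [-1.0, 2.0, -0.5]])

def toy_predict(x):
    z = x @ W.T
    e = np.exp(z - z.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

x0 = np.array([[0.1, 0.2, -0.3]])
grad = numerical_gradient(toy_predict, x0)
print(grad.shape)
```

Each feature costs two extra calls to the prediction function, which is one reason black-box explanations tend to be slower than explanations that can use the model's own gradients.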

Fit explainer. Please check the documentation for more info about the optional arguments.

[17]:
cf.fit(X_train, d_type='abdm', disc_perc=[25, 50, 75])
[17]:
CounterFactualProto(meta={
    'name': 'CounterFactualProto',
    'type': ['blackbox', 'tensorflow', 'keras'],
    'explanations': ['local'],
    'params': {
        'kappa': 0.0,
        'beta': 0.01,
        'feature_range': (
            array([[-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.]],
      dtype=float32),
            array([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]], dtype=float32)
        ),
        'gamma': 0.0,
        'theta': 0.0,
        'cat_vars': {
            0: 9,
            1: 7,
            2: 4,
            3: 9,
            4: 6,
            5: 5,
            6: 2,
            7: 11
        },
        'ohe': False,
        'use_kdtree': False,
        'learning_rate_init': 0.01,
        'max_iterations': 500,
        'c_init': 1.0,
        'c_steps': 5,
        'eps': (0.01, 0.01),
        'clip': (-1000.0, 1000.0),
        'update_num_grad': 1,
        'write_dir': None,
        'shape': (1, 12),
        'is_model': False,
        'is_model_keras': False,
        'is_ae': False,
        'is_ae_keras': False,
        'is_enc': False,
        'is_enc_keras': False,
        'enc_or_kdtree': False,
        'is_cat': True,
        'trustscore_kwargs': None,
        'd_type': 'abdm',
        'w': None,
        'disc_perc': [25, 50, 75],
        'standardize_cat_vars': False,
        'smooth': 1.0,
        'center': True,
        'update_feature_range': True
    }
})
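
The `disc_perc` argument lists the percentiles used to bin numerical features for the `'abdm'` distance measure. The sketch below shows what percentile-based binning with `[25, 50, 75]` does to a toy feature. It is an illustration under the assumption of plain NumPy binning, and the library's internal procedure may differ in detail. The helper name `bin_by_percentiles` is hypothetical.

```python
import numpy as np

def bin_by_percentiles(x, percentiles=(25, 50, 75)):
    """Assign each value of a 1-D array to a bin delimited by the given percentiles."""
    edges = np.percentile(x, percentiles)
    return np.digitize(x, edges), edges

values = np.arange(1., 9.)  # toy numerical feature: 1, 2, ..., 8
bins, edges = bin_by_percentiles(values)

print(edges)  # bin edges at the 25th, 50th and 75th percentile
print(bins)   # bin index (0 to 3) for each value
```

With three percentiles each numerical feature falls into one of four bins, so it can be treated like a categorical variable when pairwise distances between categories are inferred.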

Explain instance:

[18]:
set_seed()
explanation = cf.explain(X)

Helper function to more clearly describe explanations:

[19]:
def describe_instance(X, explanation, eps=1e-2):
    print('Original instance: {}  -- proba: {}'.format(target_names[explanation.orig_class],
                                                       explanation.orig_proba[0]))
    print('Counterfactual instance: {}  -- proba: {}'.format(target_names[explanation.cf['class']],
                                                             explanation.cf['proba'][0]))
    print('\nCounterfactual perturbations...')
    print('\nCategorical:')
    X_orig_ord = X
    X_cf_ord = explanation.cf['X']
    delta_cat = {}
    for i, (_, v) in enumerate(category_map.items()):
        cat_orig = v[int(X_orig_ord[0, i])]
        cat_cf = v[int(X_cf_ord[0, i])]
        if cat_orig != cat_cf:
            delta_cat[feature_names[i]] = [cat_orig, cat_cf]
    if delta_cat:
        for k, v in delta_cat.items():
            print('{}: {}  -->   {}'.format(k, v[0], v[1]))
    print('\nNumerical:')
    delta_num = X_cf_ord[0, -4:] - X_orig_ord[0, -4:]
    n_keys = len(list(cat_vars_ord.keys()))
    for i in range(delta_num.shape[0]):
        if np.abs(delta_num[i]) > eps:
            print('{}: {:.2f}  -->   {:.2f}'.format(feature_names[i+n_keys],
                                            X_orig_ord[0,i+n_keys],
                                            X_cf_ord[0,i+n_keys]))
[20]:
describe_instance(X, explanation)
Original instance: <=50K  -- proba: [0.8347201  0.16527994]
Counterfactual instance: >50K  -- proba: [0.47529536 0.52470464]

Counterfactual perturbations...

Categorical:
Occupation: White-Collar  -->   Professional

Numerical:
Capital Gain: -1.00  -->   -0.87

The person’s income would be predicted to be above $50k after increasing his or her capital gain and changing occupation to Professional. Note that the capital gain values are reported on the scaled range between -1 and 1.