This page was generated from examples/cfrl_adult.ipynb.

Counterfactual with Reinforcement Learning (CFRL) on Adult Census

This method, described in Model-agnostic and Scalable Counterfactual Explanations via Reinforcement Learning, can generate counterfactual instances for any black-box model. The usual optimization procedure is transformed into a learnable process, allowing the generation of batches of counterfactual instances in a single forward pass, even for high-dimensional data. The training pipeline is model-agnostic and relies only on prediction feedback obtained by querying the black-box model. Furthermore, the method allows both target and feature conditioning.

We exemplify the use case for the TensorFlow backend. This means that all models (the autoencoder, the actor and the critic) are TensorFlow models. Our implementation supports the PyTorch backend as well.

CFRL uses Deep Deterministic Policy Gradient (DDPG), interleaving a state-action function approximator called the critic with a learned approximator called the actor that predicts the optimal action. The method assumes that the critic is differentiable with respect to the action argument, which allows the actor’s parameters to be optimized efficiently through gradient-based methods.

The DDPG algorithm requires two separate networks, an actor \(\mu\) and a critic \(Q\). Given the encoded representation of the input instance \(z = enc(x)\), the model prediction \(y_M\), the target prediction \(y_T\) and the conditioning vector \(c\), the actor outputs the counterfactual’s latent representation \(z_{CF} = \mu(z, y_M, y_T, c)\). The decoder then projects the embedding \(z_{CF}\) back to the original input space, followed by optional post-processing.
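
Schematically, generating a counterfactual instance amounts to encoding the input, letting the actor act in the latent space and decoding the result. Below is a minimal sketch of this pipeline, assuming Keras-style callables encoder, decoder and actor, and one-hot vectors y_m, y_t and c (illustrative names only, not the library internals):

# Minimal sketch of the counterfactual generation pipeline (illustrative only).
z = encoder(x)                               # encode the input instance
state = tf.concat([z, y_m, y_t, c], axis=1)  # actor input: embedding, model prediction, target and conditioning
z_cf = actor(state)                          # counterfactual embedding
x_cf = decoder(z_cf)                         # decode back to the input space (followed by optional post-processing)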

The training step consists of simultaneously optimizing the actor and critic networks. The critic regresses on the reward \(R\) determined by the model prediction, while the actor maximizes the critic’s output for the given instance through \(L_{max}\). The actor also minimizes two additional objectives to encourage the generation of sparse, in-distribution counterfactuals. The sparsity loss \(L_{sparsity}\) operates on the decoded counterfactual \(x_{CF}\) and combines the \(L_1\) loss over the standardized numerical features with the \(L_0\) loss over the categorical ones. The consistency loss \(L_{consist}\) encourages the encoding of the counterfactual \(x_{CF}\) to match the latent representation it was decoded from, which helps produce in-distribution counterfactual instances. Formally, the actor’s loss can be written as: \(L_{actor} = L_{max} + \lambda_{1}L_{sparsity} + \lambda_{2}L_{consist}\)
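
The sketch below mirrors this formula, where lambda_1 and lambda_2 are the sparsity and consistency coefficients, and sparsity_loss and consistency_loss are hypothetical helpers standing in for the terms described above (illustrative only, not the library implementation):

# Rough sketch of the actor objective (illustrative only).
q_value = critic(tf.concat([z, z_cf], axis=1))            # critic's estimate of the reward
loss_max = -tf.reduce_mean(q_value)                       # maximize the critic's output
loss_sparsity = sparsity_loss(x, x_cf)                    # L1 over numerical + L0 over categorical features
loss_consistency = consistency_loss(encoder(x_cf), z_cf)  # re-encode the counterfactual and compare embeddings
loss_actor = loss_max + lambda_1 * loss_sparsity + lambda_2 * loss_consistency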

This example will use the xgboost library, which can be installed with:

[ ]:
!pip install xgboost
[1]:
import os
import numpy as np
import pandas as pd
from copy import deepcopy
from typing import List, Tuple, Dict, Callable

import tensorflow as tf
import tensorflow.keras as keras

from sklearn.compose import ColumnTransformer
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

from xgboost import XGBClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.linear_model import LogisticRegression

from alibi.explainers import CounterfactualRLTabular, CounterfactualRL
from alibi.datasets import fetch_adult
from alibi.models.tensorflow.autoencoder import HeAE
from alibi.models.tensorflow.actor_critic import Actor, Critic
from alibi.models.tensorflow.cfrl_models import ADULTEncoder, ADULTDecoder
from alibi.explainers.cfrl_base import Callback
from alibi.explainers.backends.cfrl_tabular import get_he_preprocessor, get_statistics, \
    get_conditional_vector, apply_category_mapping

Load Adult Census Dataset

[2]:
# Fetch adult dataset
adult = fetch_adult()

# Separate columns into numerical and categorical.
categorical_names = [adult.feature_names[i] for i in adult.category_map.keys()]
categorical_ids = list(adult.category_map.keys())

numerical_names = [name for i, name in enumerate(adult.feature_names) if i not in adult.category_map.keys()]
numerical_ids = [i for i in range(len(adult.feature_names)) if i not in adult.category_map.keys()]

# Split data into train and test
X, Y = adult.data, adult.target
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.2, random_state=13)

Train black-box classifier

[3]:
# Define numerical standard scaler.
num_transf = StandardScaler()

# Define categorical one-hot encoder.
cat_transf = OneHotEncoder(
    categories=[range(len(x)) for x in adult.category_map.values()],
    handle_unknown="ignore"
)

# Define column transformer
preprocessor = ColumnTransformer(
    transformers=[
        ("cat", cat_transf, categorical_ids),
        ("num", num_transf, numerical_ids),
    ],
    sparse_threshold=0
)
[4]:
# Fit preprocessor.
preprocessor.fit(X_train)

# Preprocess train and test dataset.
X_train_ohe = preprocessor.transform(X_train)
X_test_ohe = preprocessor.transform(X_test)
[5]:
# Select one of the below classifiers.
# clf = XGBClassifier(min_child_weight=0.5, max_depth=3, gamma=0.2)
# clf = LogisticRegression(C=10)
# clf = DecisionTreeClassifier(max_depth=10, min_samples_split=5)
clf = RandomForestClassifier(max_depth=15, min_samples_split=10, n_estimators=50)

# Fit the classifier.
clf.fit(X_train_ohe, Y_train)
[5]:
RandomForestClassifier(max_depth=15, min_samples_split=10, n_estimators=50)

Define the predictor (black-box)

Now that we’ve trained the classifier, we can define the black-box model. Note that the output of the black-box is a distribution, which can be either a soft-label distribution (probabilities/logits for each class) or a hard-label distribution (one-hot encoding). Internally, CFRL takes the argmax. Moreover, the output DOES NOT HAVE TO BE DIFFERENTIABLE.

[6]:
# Define prediction function.
predictor = lambda x: clf.predict_proba(preprocessor.transform(x))
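
Since CFRL internally takes the argmax of the predictor’s output, a hard-label black-box works just as well. A minimal sketch of an equivalent one-hot predictor (not used in the rest of this example):

# Hard-label (one-hot) alternative -- a sketch, not used below.
# predictor = lambda x: np.eye(2)[clf.predict(preprocessor.transform(x))]
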
[7]:
# Compute accuracy.
acc = accuracy_score(y_true=Y_test, y_pred=predictor(X_test).argmax(axis=1))
print("Accuracy: %.3f" % acc)
Accuracy: 0.862

Define and train autoencoder

Instead of directly modelling the perturbation vector in the potentially high-dimensional input space, we first train an autoencoder. The weights of the encoder are frozen and the actor applies the counterfactual perturbations in the latent space of the encoder. The pre-trained decoder maps the counterfactual embedding back to the input feature space.

The autoencoder follows a standard design. The model is composed of two submodules, the encoder and the decoder. The forward pass consists of passing the input through the encoder to obtain the embedding, then passing the embedding through the decoder to obtain the reconstruction.

class HeAE(keras.Model):
    def __init__(self, encoder: keras.Model, decoder: keras.Model, **kwargs) -> None:
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def call(self, x: tf.Tensor, **kwargs):
        z = self.encoder(x)
        x_hat = self.decoder(z)
        return x_hat

The heterogeneous variant used in this example performs an additional type check to ensure that the output of the decoder is a list of tensors.
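
A minimal sketch of what such a check might look like inside the forward pass (illustrative only; see alibi.models.tensorflow.autoencoder for the actual implementation):

# Sketch of the heterogeneous forward pass with the additional check (illustrative only).
def call(self, x: tf.Tensor, **kwargs):
    z = self.encoder(x)
    x_hat = self.decoder(z)
    # The heterogeneous decoder returns one output head per feature group.
    if not isinstance(x_hat, list):
        raise ValueError("The decoder output is expected to be a list of tensors.")
    return x_hat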

Heterogeneous datasets require special treatment. In this work we model the numerical features with normal distributions of constant standard deviation and the categorical features with categorical distributions. Due to this choice of feature modelling, some numerical features can end up with a different type than the original ones. For example, a feature like Age, originally of type int, can become a float after the autoencoder reconstruction (e.g., Age=26 -> Age=26.3). This behavior can be undesirable, so we cast the output of the autoencoder (decoder component) back to the expected types.
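
As a toy illustration of the casting step (handled internally by the inverse preprocessor defined in the next cell via the feature_types mapping), a reconstructed integer feature is mapped back to its original dtype roughly as follows:

# Toy illustration of the casting step (the inverse preprocessor handles this internally).
age_reconstructed = np.array([26.3, 41.7])         # decoder output for the `Age` feature
age_cast = np.rint(age_reconstructed).astype(int)  # array([26, 42]) -- back to the original integer type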

[8]:
# Define attribute types, required for datatype conversion.
feature_types = {"Age": int, "Capital Gain": int, "Capital Loss": int, "Hours per week": int}

# Define data preprocessor and inverse preprocessor. The inverse preprocessor includes datatype conversions.
heae_preprocessor, heae_inv_preprocessor = get_he_preprocessor(X=X_train,
                                                               feature_names=adult.feature_names,
                                                               category_map=adult.category_map,
                                                               feature_types=feature_types)

# Define trainset
trainset_input = heae_preprocessor(X_train).astype(np.float32)
trainset_outputs = {
    "output_1": X_train_ohe[:, :len(numerical_ids)]
}

for i, cat_id in enumerate(categorical_ids):
    trainset_outputs.update({
        f"output_{i+2}": X_train[:, cat_id]
    })

trainset = tf.data.Dataset.from_tensor_slices((trainset_input, trainset_outputs))
trainset = trainset.shuffle(1024).batch(128, drop_remainder=True)
[9]:
# Define autoencoder path and create dir if it doesn't exist.
heae_path = os.path.join("tensorflow", "ADULT_autoencoder")
if not os.path.exists(heae_path):
    os.makedirs(heae_path)

# Define constants.
EPOCHS = 50              # epochs to train the autoencoder
HIDDEN_DIM = 128         # hidden dimension of the autoencoder
LATENT_DIM = 15          # define latent dimension

# Define output dimensions.
OUTPUT_DIMS = [len(numerical_ids)]
OUTPUT_DIMS += [len(adult.category_map[cat_id]) for cat_id in categorical_ids]

# Define the heterogeneous auto-encoder.
heae = HeAE(encoder=ADULTEncoder(hidden_dim=HIDDEN_DIM, latent_dim=LATENT_DIM),
            decoder=ADULTDecoder(hidden_dim=HIDDEN_DIM, output_dims=OUTPUT_DIMS))

# Define loss functions.
he_loss = [keras.losses.MeanSquaredError()]
he_loss_weights = [1.]

# Add categorical losses.
for i in range(len(categorical_names)):
    he_loss.append(keras.losses.SparseCategoricalCrossentropy(from_logits=True))
    he_loss_weights.append(1./len(categorical_names))

# Define metrics.
metrics = {}
for i, cat_name in enumerate(categorical_names):
    metrics.update({f"output_{i+2}": keras.metrics.SparseCategoricalAccuracy()})

# Compile model.
heae.compile(optimizer=keras.optimizers.Adam(learning_rate=1e-3),
             loss=he_loss,
             loss_weights=he_loss_weights,
             metrics=metrics)

if len(os.listdir(heae_path)) == 0:
    # Fit and save autoencoder.
    heae.fit(trainset, epochs=EPOCHS)
    heae.save(heae_path, save_format="tf")
else:
    # Load the model.
    heae = keras.models.load_model(heae_path, compile=False)
Epoch 1/50
203/203 [==============================] - 2s 3ms/step - loss: 1.1512 - output_1_loss: 0.0282 - output_2_loss: 1.3514 - output_3_loss: 1.4338 - output_4_loss: 0.8165 - output_5_loss: 1.8348 - output_6_loss: 1.2185 - output_7_loss: 0.8043 - output_8_loss: 0.4713 - output_9_loss: 1.0538 - output_2_sparse_categorical_accuracy: 0.6555 - output_3_sparse_categorical_accuracy: 0.4752 - output_4_sparse_categorical_accuracy: 0.7058 - output_5_sparse_categorical_accuracy: 0.3164 - output_6_sparse_categorical_accuracy: 0.5849 - output_7_sparse_categorical_accuracy: 0.7588 - output_8_sparse_categorical_accuracy: 0.7580 - output_9_sparse_categorical_accuracy: 0.8097
Epoch 2/50
203/203 [==============================] - 1s 3ms/step - loss: 0.2462 - output_1_loss: 0.0093 - output_2_loss: 0.2583 - output_3_loss: 0.3268 - output_4_loss: 0.1679 - output_5_loss: 0.3543 - output_6_loss: 0.2553 - output_7_loss: 0.2078 - output_8_loss: 0.0690 - output_9_loss: 0.2557 - output_2_sparse_categorical_accuracy: 0.9339 - output_3_sparse_categorical_accuracy: 0.9160 - output_4_sparse_categorical_accuracy: 0.9530 - output_5_sparse_categorical_accuracy: 0.9150 - output_6_sparse_categorical_accuracy: 0.9340 - output_7_sparse_categorical_accuracy: 0.9378 - output_8_sparse_categorical_accuracy: 0.9830 - output_9_sparse_categorical_accuracy: 0.9210
Epoch 3/50
203/203 [==============================] - 1s 3ms/step - loss: 0.1100 - output_1_loss: 0.0065 - output_2_loss: 0.1009 - output_3_loss: 0.1382 - output_4_loss: 0.0677 - output_5_loss: 0.1248 - output_6_loss: 0.1131 - output_7_loss: 0.0871 - output_8_loss: 0.0255 - output_9_loss: 0.1711 - output_2_sparse_categorical_accuracy: 0.9803 - output_3_sparse_categorical_accuracy: 0.9678 - output_4_sparse_categorical_accuracy: 0.9856 - output_5_sparse_categorical_accuracy: 0.9756 - output_6_sparse_categorical_accuracy: 0.9712 - output_7_sparse_categorical_accuracy: 0.9759 - output_8_sparse_categorical_accuracy: 0.9968 - output_9_sparse_categorical_accuracy: 0.9452
Epoch 4/50
203/203 [==============================] - 1s 3ms/step - loss: 0.0711 - output_1_loss: 0.0051 - output_2_loss: 0.0652 - output_3_loss: 0.0857 - output_4_loss: 0.0348 - output_5_loss: 0.0779 - output_6_loss: 0.0630 - output_7_loss: 0.0597 - output_8_loss: 0.0137 - output_9_loss: 0.1280 - output_2_sparse_categorical_accuracy: 0.9863 - output_3_sparse_categorical_accuracy: 0.9797 - output_4_sparse_categorical_accuracy: 0.9935 - output_5_sparse_categorical_accuracy: 0.9839 - output_6_sparse_categorical_accuracy: 0.9857 - output_7_sparse_categorical_accuracy: 0.9834 - output_8_sparse_categorical_accuracy: 0.9985 - output_9_sparse_categorical_accuracy: 0.9593
Epoch 5/50
203/203 [==============================] - 1s 3ms/step - loss: 0.0514 - output_1_loss: 0.0040 - output_2_loss: 0.0461 - output_3_loss: 0.0595 - output_4_loss: 0.0233 - output_5_loss: 0.0585 - output_6_loss: 0.0413 - output_7_loss: 0.0445 - output_8_loss: 0.0087 - output_9_loss: 0.0974 - output_2_sparse_categorical_accuracy: 0.9894 - output_3_sparse_categorical_accuracy: 0.9861 - output_4_sparse_categorical_accuracy: 0.9957 - output_5_sparse_categorical_accuracy: 0.9872 - output_6_sparse_categorical_accuracy: 0.9913 - output_7_sparse_categorical_accuracy: 0.9883 - output_8_sparse_categorical_accuracy: 0.9993 - output_9_sparse_categorical_accuracy: 0.9703
Epoch 6/50
203/203 [==============================] - 1s 3ms/step - loss: 0.0406 - output_1_loss: 0.0034 - output_2_loss: 0.0371 - output_3_loss: 0.0458 - output_4_loss: 0.0180 - output_5_loss: 0.0471 - output_6_loss: 0.0315 - output_7_loss: 0.0342 - output_8_loss: 0.0061 - output_9_loss: 0.0774 - output_2_sparse_categorical_accuracy: 0.9906 - output_3_sparse_categorical_accuracy: 0.9893 - output_4_sparse_categorical_accuracy: 0.9968 - output_5_sparse_categorical_accuracy: 0.9896 - output_6_sparse_categorical_accuracy: 0.9929 - output_7_sparse_categorical_accuracy: 0.9912 - output_8_sparse_categorical_accuracy: 0.9995 - output_9_sparse_categorical_accuracy: 0.9786
Epoch 7/50
203/203 [==============================] - 1s 3ms/step - loss: 0.0330 - output_1_loss: 0.0029 - output_2_loss: 0.0287 - output_3_loss: 0.0361 - output_4_loss: 0.0145 - output_5_loss: 0.0389 - output_6_loss: 0.0256 - output_7_loss: 0.0273 - output_8_loss: 0.0049 - output_9_loss: 0.0646 - output_2_sparse_categorical_accuracy: 0.9935 - output_3_sparse_categorical_accuracy: 0.9910 - output_4_sparse_categorical_accuracy: 0.9974 - output_5_sparse_categorical_accuracy: 0.9911 - output_6_sparse_categorical_accuracy: 0.9946 - output_7_sparse_categorical_accuracy: 0.9937 - output_8_sparse_categorical_accuracy: 0.9996 - output_9_sparse_categorical_accuracy: 0.9814
Epoch 8/50
203/203 [==============================] - 1s 3ms/step - loss: 0.0278 - output_1_loss: 0.0025 - output_2_loss: 0.0241 - output_3_loss: 0.0291 - output_4_loss: 0.0124 - output_5_loss: 0.0335 - output_6_loss: 0.0214 - output_7_loss: 0.0241 - output_8_loss: 0.0040 - output_9_loss: 0.0538 - output_2_sparse_categorical_accuracy: 0.9949 - output_3_sparse_categorical_accuracy: 0.9927 - output_4_sparse_categorical_accuracy: 0.9978 - output_5_sparse_categorical_accuracy: 0.9926 - output_6_sparse_categorical_accuracy: 0.9952 - output_7_sparse_categorical_accuracy: 0.9946 - output_8_sparse_categorical_accuracy: 0.9998 - output_9_sparse_categorical_accuracy: 0.9847
Epoch 9/50
203/203 [==============================] - 1s 4ms/step - loss: 0.0237 - output_1_loss: 0.0022 - output_2_loss: 0.0198 - output_3_loss: 0.0241 - output_4_loss: 0.0109 - output_5_loss: 0.0286 - output_6_loss: 0.0187 - output_7_loss: 0.0206 - output_8_loss: 0.0034 - output_9_loss: 0.0457 - output_2_sparse_categorical_accuracy: 0.9955 - output_3_sparse_categorical_accuracy: 0.9941 - output_4_sparse_categorical_accuracy: 0.9981 - output_5_sparse_categorical_accuracy: 0.9935 - output_6_sparse_categorical_accuracy: 0.9954 - output_7_sparse_categorical_accuracy: 0.9953 - output_8_sparse_categorical_accuracy: 0.9998 - output_9_sparse_categorical_accuracy: 0.9876
Epoch 10/50
203/203 [==============================] - 1s 3ms/step - loss: 0.0203 - output_1_loss: 0.0020 - output_2_loss: 0.0163 - output_3_loss: 0.0216 - output_4_loss: 0.0094 - output_5_loss: 0.0241 - output_6_loss: 0.0159 - output_7_loss: 0.0181 - output_8_loss: 0.0027 - output_9_loss: 0.0385 - output_2_sparse_categorical_accuracy: 0.9962 - output_3_sparse_categorical_accuracy: 0.9946 - output_4_sparse_categorical_accuracy: 0.9983 - output_5_sparse_categorical_accuracy: 0.9943 - output_6_sparse_categorical_accuracy: 0.9972 - output_7_sparse_categorical_accuracy: 0.9958 - output_8_sparse_categorical_accuracy: 0.9999 - output_9_sparse_categorical_accuracy: 0.9891
Epoch 11/50
203/203 [==============================] - 1s 3ms/step - loss: 0.0178 - output_1_loss: 0.0017 - output_2_loss: 0.0142 - output_3_loss: 0.0182 - output_4_loss: 0.0089 - output_5_loss: 0.0211 - output_6_loss: 0.0140 - output_7_loss: 0.0163 - output_8_loss: 0.0024 - output_9_loss: 0.0333 - output_2_sparse_categorical_accuracy: 0.9971 - output_3_sparse_categorical_accuracy: 0.9957 - output_4_sparse_categorical_accuracy: 0.9981 - output_5_sparse_categorical_accuracy: 0.9952 - output_6_sparse_categorical_accuracy: 0.9974 - output_7_sparse_categorical_accuracy: 0.9963 - output_8_sparse_categorical_accuracy: 0.9998 - output_9_sparse_categorical_accuracy: 0.9912
Epoch 12/50
203/203 [==============================] - 1s 3ms/step - loss: 0.0156 - output_1_loss: 0.0016 - output_2_loss: 0.0121 - output_3_loss: 0.0161 - output_4_loss: 0.0078 - output_5_loss: 0.0177 - output_6_loss: 0.0124 - output_7_loss: 0.0146 - output_8_loss: 0.0021 - output_9_loss: 0.0291 - output_2_sparse_categorical_accuracy: 0.9975 - output_3_sparse_categorical_accuracy: 0.9959 - output_4_sparse_categorical_accuracy: 0.9984 - output_5_sparse_categorical_accuracy: 0.9961 - output_6_sparse_categorical_accuracy: 0.9974 - output_7_sparse_categorical_accuracy: 0.9966 - output_8_sparse_categorical_accuracy: 1.0000 - output_9_sparse_categorical_accuracy: 0.9927
Epoch 13/50
203/203 [==============================] - 1s 3ms/step - loss: 0.0138 - output_1_loss: 0.0015 - output_2_loss: 0.0102 - output_3_loss: 0.0141 - output_4_loss: 0.0071 - output_5_loss: 0.0155 - output_6_loss: 0.0112 - output_7_loss: 0.0127 - output_8_loss: 0.0020 - output_9_loss: 0.0257 - output_2_sparse_categorical_accuracy: 0.9981 - output_3_sparse_categorical_accuracy: 0.9966 - output_4_sparse_categorical_accuracy: 0.9985 - output_5_sparse_categorical_accuracy: 0.9968 - output_6_sparse_categorical_accuracy: 0.9982 - output_7_sparse_categorical_accuracy: 0.9972 - output_8_sparse_categorical_accuracy: 0.9999 - output_9_sparse_categorical_accuracy: 0.9934
Epoch 14/50
203/203 [==============================] - 1s 3ms/step - loss: 0.0123 - output_1_loss: 0.0014 - output_2_loss: 0.0091 - output_3_loss: 0.0127 - output_4_loss: 0.0064 - output_5_loss: 0.0138 - output_6_loss: 0.0099 - output_7_loss: 0.0108 - output_8_loss: 0.0017 - output_9_loss: 0.0228 - output_2_sparse_categorical_accuracy: 0.9981 - output_3_sparse_categorical_accuracy: 0.9973 - output_4_sparse_categorical_accuracy: 0.9987 - output_5_sparse_categorical_accuracy: 0.9971 - output_6_sparse_categorical_accuracy: 0.9983 - output_7_sparse_categorical_accuracy: 0.9975 - output_8_sparse_categorical_accuracy: 0.9999 - output_9_sparse_categorical_accuracy: 0.9943
Epoch 15/50
203/203 [==============================] - 1s 3ms/step - loss: 0.0110 - output_1_loss: 0.0013 - output_2_loss: 0.0077 - output_3_loss: 0.0112 - output_4_loss: 0.0057 - output_5_loss: 0.0120 - output_6_loss: 0.0087 - output_7_loss: 0.0101 - output_8_loss: 0.0017 - output_9_loss: 0.0200 - output_2_sparse_categorical_accuracy: 0.9989 - output_3_sparse_categorical_accuracy: 0.9981 - output_4_sparse_categorical_accuracy: 0.9988 - output_5_sparse_categorical_accuracy: 0.9976 - output_6_sparse_categorical_accuracy: 0.9985 - output_7_sparse_categorical_accuracy: 0.9973 - output_8_sparse_categorical_accuracy: 0.9998 - output_9_sparse_categorical_accuracy: 0.9952
Epoch 16/50
203/203 [==============================] - 1s 3ms/step - loss: 0.0097 - output_1_loss: 0.0012 - output_2_loss: 0.0068 - output_3_loss: 0.0091 - output_4_loss: 0.0051 - output_5_loss: 0.0107 - output_6_loss: 0.0078 - output_7_loss: 0.0092 - output_8_loss: 0.0014 - output_9_loss: 0.0175 - output_2_sparse_categorical_accuracy: 0.9993 - output_3_sparse_categorical_accuracy: 0.9985 - output_4_sparse_categorical_accuracy: 0.9992 - output_5_sparse_categorical_accuracy: 0.9979 - output_6_sparse_categorical_accuracy: 0.9988 - output_7_sparse_categorical_accuracy: 0.9982 - output_8_sparse_categorical_accuracy: 1.0000 - output_9_sparse_categorical_accuracy: 0.9955
Epoch 17/50
203/203 [==============================] - 1s 3ms/step - loss: 0.0088 - output_1_loss: 0.0012 - output_2_loss: 0.0060 - output_3_loss: 0.0087 - output_4_loss: 0.0046 - output_5_loss: 0.0095 - output_6_loss: 0.0071 - output_7_loss: 0.0082 - output_8_loss: 0.0011 - output_9_loss: 0.0155 - output_2_sparse_categorical_accuracy: 0.9994 - output_3_sparse_categorical_accuracy: 0.9984 - output_4_sparse_categorical_accuracy: 0.9992 - output_5_sparse_categorical_accuracy: 0.9983 - output_6_sparse_categorical_accuracy: 0.9988 - output_7_sparse_categorical_accuracy: 0.9981 - output_8_sparse_categorical_accuracy: 0.9999 - output_9_sparse_categorical_accuracy: 0.9967
Epoch 18/50
203/203 [==============================] - 1s 3ms/step - loss: 0.0078 - output_1_loss: 0.0011 - output_2_loss: 0.0056 - output_3_loss: 0.0077 - output_4_loss: 0.0043 - output_5_loss: 0.0078 - output_6_loss: 0.0062 - output_7_loss: 0.0072 - output_8_loss: 0.0012 - output_9_loss: 0.0137 - output_2_sparse_categorical_accuracy: 0.9992 - output_3_sparse_categorical_accuracy: 0.9988 - output_4_sparse_categorical_accuracy: 0.9995 - output_5_sparse_categorical_accuracy: 0.9991 - output_6_sparse_categorical_accuracy: 0.9989 - output_7_sparse_categorical_accuracy: 0.9987 - output_8_sparse_categorical_accuracy: 0.9997 - output_9_sparse_categorical_accuracy: 0.9972
Epoch 19/50
203/203 [==============================] - 1s 3ms/step - loss: 0.0071 - output_1_loss: 0.0010 - output_2_loss: 0.0046 - output_3_loss: 0.0069 - output_4_loss: 0.0037 - output_5_loss: 0.0075 - output_6_loss: 0.0056 - output_7_loss: 0.0068 - output_8_loss: 0.0011 - output_9_loss: 0.0125 - output_2_sparse_categorical_accuracy: 0.9997 - output_3_sparse_categorical_accuracy: 0.9989 - output_4_sparse_categorical_accuracy: 0.9995 - output_5_sparse_categorical_accuracy: 0.9988 - output_6_sparse_categorical_accuracy: 0.9991 - output_7_sparse_categorical_accuracy: 0.9985 - output_8_sparse_categorical_accuracy: 1.0000 - output_9_sparse_categorical_accuracy: 0.9976
Epoch 20/50
203/203 [==============================] - 1s 3ms/step - loss: 0.0065 - output_1_loss: 9.6288e-04 - output_2_loss: 0.0043 - output_3_loss: 0.0061 - output_4_loss: 0.0037 - output_5_loss: 0.0068 - output_6_loss: 0.0052 - output_7_loss: 0.0059 - output_8_loss: 9.3056e-04 - output_9_loss: 0.0113 - output_2_sparse_categorical_accuracy: 0.9996 - output_3_sparse_categorical_accuracy: 0.9994 - output_4_sparse_categorical_accuracy: 0.9994 - output_5_sparse_categorical_accuracy: 0.9988 - output_6_sparse_categorical_accuracy: 0.9989 - output_7_sparse_categorical_accuracy: 0.9991 - output_8_sparse_categorical_accuracy: 0.9999 - output_9_sparse_categorical_accuracy: 0.9977
Epoch 21/50
203/203 [==============================] - 1s 3ms/step - loss: 0.0060 - output_1_loss: 9.3135e-04 - output_2_loss: 0.0039 - output_3_loss: 0.0055 - output_4_loss: 0.0037 - output_5_loss: 0.0061 - output_6_loss: 0.0050 - output_7_loss: 0.0056 - output_8_loss: 8.3433e-04 - output_9_loss: 0.0104 - output_2_sparse_categorical_accuracy: 0.9997 - output_3_sparse_categorical_accuracy: 0.9995 - output_4_sparse_categorical_accuracy: 0.9994 - output_5_sparse_categorical_accuracy: 0.9992 - output_6_sparse_categorical_accuracy: 0.9992 - output_7_sparse_categorical_accuracy: 0.9987 - output_8_sparse_categorical_accuracy: 0.9998 - output_9_sparse_categorical_accuracy: 0.9975
Epoch 22/50
203/203 [==============================] - 1s 3ms/step - loss: 0.0056 - output_1_loss: 8.9429e-04 - output_2_loss: 0.0036 - output_3_loss: 0.0051 - output_4_loss: 0.0030 - output_5_loss: 0.0057 - output_6_loss: 0.0044 - output_7_loss: 0.0054 - output_8_loss: 8.6064e-04 - output_9_loss: 0.0094 - output_2_sparse_categorical_accuracy: 0.9996 - output_3_sparse_categorical_accuracy: 0.9996 - output_4_sparse_categorical_accuracy: 0.9996 - output_5_sparse_categorical_accuracy: 0.9995 - output_6_sparse_categorical_accuracy: 0.9996 - output_7_sparse_categorical_accuracy: 0.9992 - output_8_sparse_categorical_accuracy: 0.9998 - output_9_sparse_categorical_accuracy: 0.9981
Epoch 23/50
203/203 [==============================] - 1s 3ms/step - loss: 0.0056 - output_1_loss: 8.8916e-04 - output_2_loss: 0.0034 - output_3_loss: 0.0052 - output_4_loss: 0.0032 - output_5_loss: 0.0059 - output_6_loss: 0.0044 - output_7_loss: 0.0051 - output_8_loss: 6.4341e-04 - output_9_loss: 0.0094 - output_2_sparse_categorical_accuracy: 0.9997 - output_3_sparse_categorical_accuracy: 0.9993 - output_4_sparse_categorical_accuracy: 0.9997 - output_5_sparse_categorical_accuracy: 0.9992 - output_6_sparse_categorical_accuracy: 0.9994 - output_7_sparse_categorical_accuracy: 0.9987 - output_8_sparse_categorical_accuracy: 0.9999 - output_9_sparse_categorical_accuracy: 0.9975
Epoch 24/50
203/203 [==============================] - 1s 3ms/step - loss: 0.0048 - output_1_loss: 7.9097e-04 - output_2_loss: 0.0038 - output_3_loss: 0.0044 - output_4_loss: 0.0027 - output_5_loss: 0.0046 - output_6_loss: 0.0036 - output_7_loss: 0.0044 - output_8_loss: 6.6250e-04 - output_9_loss: 0.0077 - output_2_sparse_categorical_accuracy: 0.9994 - output_3_sparse_categorical_accuracy: 0.9994 - output_4_sparse_categorical_accuracy: 0.9996 - output_5_sparse_categorical_accuracy: 0.9996 - output_6_sparse_categorical_accuracy: 0.9996 - output_7_sparse_categorical_accuracy: 0.9994 - output_8_sparse_categorical_accuracy: 0.9999 - output_9_sparse_categorical_accuracy: 0.9989
Epoch 25/50
203/203 [==============================] - 1s 3ms/step - loss: 0.0041 - output_1_loss: 7.7147e-04 - output_2_loss: 0.0027 - output_3_loss: 0.0039 - output_4_loss: 0.0022 - output_5_loss: 0.0041 - output_6_loss: 0.0030 - output_7_loss: 0.0037 - output_8_loss: 5.6684e-04 - output_9_loss: 0.0067 - output_2_sparse_categorical_accuracy: 0.9997 - output_3_sparse_categorical_accuracy: 0.9997 - output_4_sparse_categorical_accuracy: 0.9998 - output_5_sparse_categorical_accuracy: 0.9997 - output_6_sparse_categorical_accuracy: 0.9997 - output_7_sparse_categorical_accuracy: 0.9993 - output_8_sparse_categorical_accuracy: 1.0000 - output_9_sparse_categorical_accuracy: 0.9988
Epoch 26/50
203/203 [==============================] - 1s 3ms/step - loss: 0.0039 - output_1_loss: 7.5079e-04 - output_2_loss: 0.0023 - output_3_loss: 0.0033 - output_4_loss: 0.0024 - output_5_loss: 0.0037 - output_6_loss: 0.0030 - output_7_loss: 0.0037 - output_8_loss: 5.3748e-04 - output_9_loss: 0.0062 - output_2_sparse_categorical_accuracy: 0.9999 - output_3_sparse_categorical_accuracy: 0.9997 - output_4_sparse_categorical_accuracy: 0.9998 - output_5_sparse_categorical_accuracy: 0.9998 - output_6_sparse_categorical_accuracy: 0.9996 - output_7_sparse_categorical_accuracy: 0.9995 - output_8_sparse_categorical_accuracy: 1.0000 - output_9_sparse_categorical_accuracy: 0.9989
Epoch 27/50
203/203 [==============================] - 1s 3ms/step - loss: 0.0035 - output_1_loss: 6.8695e-04 - output_2_loss: 0.0021 - output_3_loss: 0.0030 - output_4_loss: 0.0020 - output_5_loss: 0.0036 - output_6_loss: 0.0026 - output_7_loss: 0.0032 - output_8_loss: 4.6048e-04 - output_9_loss: 0.0056 - output_2_sparse_categorical_accuracy: 0.9999 - output_3_sparse_categorical_accuracy: 0.9999 - output_4_sparse_categorical_accuracy: 0.9998 - output_5_sparse_categorical_accuracy: 0.9994 - output_6_sparse_categorical_accuracy: 0.9997 - output_7_sparse_categorical_accuracy: 0.9995 - output_8_sparse_categorical_accuracy: 1.0000 - output_9_sparse_categorical_accuracy: 0.9994
Epoch 28/50
203/203 [==============================] - 1s 3ms/step - loss: 0.0036 - output_1_loss: 6.9408e-04 - output_2_loss: 0.0022 - output_3_loss: 0.0030 - output_4_loss: 0.0024 - output_5_loss: 0.0034 - output_6_loss: 0.0031 - output_7_loss: 0.0032 - output_8_loss: 6.4372e-04 - output_9_loss: 0.0054 - output_2_sparse_categorical_accuracy: 0.9998 - output_3_sparse_categorical_accuracy: 0.9998 - output_4_sparse_categorical_accuracy: 0.9997 - output_5_sparse_categorical_accuracy: 0.9998 - output_6_sparse_categorical_accuracy: 0.9995 - output_7_sparse_categorical_accuracy: 0.9996 - output_8_sparse_categorical_accuracy: 1.0000 - output_9_sparse_categorical_accuracy: 0.9992
Epoch 29/50
203/203 [==============================] - 1s 3ms/step - loss: 0.0030 - output_1_loss: 6.3778e-04 - output_2_loss: 0.0018 - output_3_loss: 0.0024 - output_4_loss: 0.0017 - output_5_loss: 0.0029 - output_6_loss: 0.0023 - output_7_loss: 0.0028 - output_8_loss: 4.5146e-04 - output_9_loss: 0.0046 - output_2_sparse_categorical_accuracy: 1.0000 - output_3_sparse_categorical_accuracy: 0.9998 - output_4_sparse_categorical_accuracy: 0.9999 - output_5_sparse_categorical_accuracy: 0.9998 - output_6_sparse_categorical_accuracy: 0.9998 - output_7_sparse_categorical_accuracy: 0.9994 - output_8_sparse_categorical_accuracy: 1.0000 - output_9_sparse_categorical_accuracy: 0.9994
Epoch 30/50
203/203 [==============================] - 1s 3ms/step - loss: 0.0038 - output_1_loss: 7.3461e-04 - output_2_loss: 0.0027 - output_3_loss: 0.0029 - output_4_loss: 0.0026 - output_5_loss: 0.0033 - output_6_loss: 0.0042 - output_7_loss: 0.0031 - output_8_loss: 0.0010 - output_9_loss: 0.0048 - output_2_sparse_categorical_accuracy: 0.9996 - output_3_sparse_categorical_accuracy: 0.9998 - output_4_sparse_categorical_accuracy: 0.9994 - output_5_sparse_categorical_accuracy: 0.9997 - output_6_sparse_categorical_accuracy: 0.9992 - output_7_sparse_categorical_accuracy: 0.9995 - output_8_sparse_categorical_accuracy: 0.9998 - output_9_sparse_categorical_accuracy: 0.9992
Epoch 31/50
203/203 [==============================] - 1s 3ms/step - loss: 0.0036 - output_1_loss: 7.2778e-04 - output_2_loss: 0.0021 - output_3_loss: 0.0036 - output_4_loss: 0.0021 - output_5_loss: 0.0032 - output_6_loss: 0.0028 - output_7_loss: 0.0033 - output_8_loss: 7.6275e-04 - output_9_loss: 0.0049 - output_2_sparse_categorical_accuracy: 0.9998 - output_3_sparse_categorical_accuracy: 0.9993 - output_4_sparse_categorical_accuracy: 0.9994 - output_5_sparse_categorical_accuracy: 0.9997 - output_6_sparse_categorical_accuracy: 0.9996 - output_7_sparse_categorical_accuracy: 0.9994 - output_8_sparse_categorical_accuracy: 0.9998 - output_9_sparse_categorical_accuracy: 0.9995
Epoch 32/50
203/203 [==============================] - 1s 3ms/step - loss: 0.0028 - output_1_loss: 6.1721e-04 - output_2_loss: 0.0017 - output_3_loss: 0.0025 - output_4_loss: 0.0019 - output_5_loss: 0.0024 - output_6_loss: 0.0029 - output_7_loss: 0.0023 - output_8_loss: 3.2920e-04 - output_9_loss: 0.0038 - output_2_sparse_categorical_accuracy: 0.9999 - output_3_sparse_categorical_accuracy: 0.9996 - output_4_sparse_categorical_accuracy: 0.9997 - output_5_sparse_categorical_accuracy: 0.9999 - output_6_sparse_categorical_accuracy: 0.9996 - output_7_sparse_categorical_accuracy: 0.9996 - output_8_sparse_categorical_accuracy: 1.0000 - output_9_sparse_categorical_accuracy: 0.9997
Epoch 33/50
203/203 [==============================] - 1s 3ms/step - loss: 0.0022 - output_1_loss: 5.3992e-04 - output_2_loss: 0.0014 - output_3_loss: 0.0019 - output_4_loss: 0.0011 - output_5_loss: 0.0020 - output_6_loss: 0.0015 - output_7_loss: 0.0020 - output_8_loss: 3.0189e-04 - output_9_loss: 0.0035 - output_2_sparse_categorical_accuracy: 0.9999 - output_3_sparse_categorical_accuracy: 1.0000 - output_4_sparse_categorical_accuracy: 1.0000 - output_5_sparse_categorical_accuracy: 1.0000 - output_6_sparse_categorical_accuracy: 0.9999 - output_7_sparse_categorical_accuracy: 0.9997 - output_8_sparse_categorical_accuracy: 1.0000 - output_9_sparse_categorical_accuracy: 0.9994
Epoch 34/50
203/203 [==============================] - 1s 3ms/step - loss: 0.0020 - output_1_loss: 5.1552e-04 - output_2_loss: 0.0012 - output_3_loss: 0.0016 - output_4_loss: 9.2134e-04 - output_5_loss: 0.0019 - output_6_loss: 0.0015 - output_7_loss: 0.0019 - output_8_loss: 2.3325e-04 - output_9_loss: 0.0029 - output_2_sparse_categorical_accuracy: 1.0000 - output_3_sparse_categorical_accuracy: 1.0000 - output_4_sparse_categorical_accuracy: 1.0000 - output_5_sparse_categorical_accuracy: 1.0000 - output_6_sparse_categorical_accuracy: 0.9998 - output_7_sparse_categorical_accuracy: 0.9997 - output_8_sparse_categorical_accuracy: 1.0000 - output_9_sparse_categorical_accuracy: 0.9998
Epoch 35/50
203/203 [==============================] - 1s 3ms/step - loss: 0.0019 - output_1_loss: 5.0352e-04 - output_2_loss: 0.0011 - output_3_loss: 0.0014 - output_4_loss: 0.0012 - output_5_loss: 0.0017 - output_6_loss: 0.0013 - output_7_loss: 0.0018 - output_8_loss: 2.4349e-04 - output_9_loss: 0.0027 - output_2_sparse_categorical_accuracy: 0.9999 - output_3_sparse_categorical_accuracy: 1.0000 - output_4_sparse_categorical_accuracy: 0.9999 - output_5_sparse_categorical_accuracy: 0.9999 - output_6_sparse_categorical_accuracy: 0.9999 - output_7_sparse_categorical_accuracy: 0.9997 - output_8_sparse_categorical_accuracy: 1.0000 - output_9_sparse_categorical_accuracy: 0.9999
Epoch 36/50
203/203 [==============================] - 1s 3ms/step - loss: 0.0019 - output_1_loss: 4.7634e-04 - output_2_loss: 9.9962e-04 - output_3_loss: 0.0014 - output_4_loss: 9.8430e-04 - output_5_loss: 0.0017 - output_6_loss: 0.0012 - output_7_loss: 0.0018 - output_8_loss: 3.5185e-04 - output_9_loss: 0.0025 - output_2_sparse_categorical_accuracy: 1.0000 - output_3_sparse_categorical_accuracy: 1.0000 - output_4_sparse_categorical_accuracy: 0.9998 - output_5_sparse_categorical_accuracy: 0.9999 - output_6_sparse_categorical_accuracy: 1.0000 - output_7_sparse_categorical_accuracy: 0.9998 - output_8_sparse_categorical_accuracy: 0.9999 - output_9_sparse_categorical_accuracy: 0.9998
Epoch 37/50
203/203 [==============================] - 1s 3ms/step - loss: 0.0021 - output_1_loss: 4.9684e-04 - output_2_loss: 0.0012 - output_3_loss: 0.0016 - output_4_loss: 0.0011 - output_5_loss: 0.0017 - output_6_loss: 0.0017 - output_7_loss: 0.0018 - output_8_loss: 3.7633e-04 - output_9_loss: 0.0031 - output_2_sparse_categorical_accuracy: 1.0000 - output_3_sparse_categorical_accuracy: 1.0000 - output_4_sparse_categorical_accuracy: 0.9998 - output_5_sparse_categorical_accuracy: 0.9997 - output_6_sparse_categorical_accuracy: 0.9997 - output_7_sparse_categorical_accuracy: 0.9997 - output_8_sparse_categorical_accuracy: 0.9999 - output_9_sparse_categorical_accuracy: 0.9998
Epoch 38/50
203/203 [==============================] - 1s 3ms/step - loss: 0.0017 - output_1_loss: 4.5096e-04 - output_2_loss: 0.0011 - output_3_loss: 0.0015 - output_4_loss: 8.4671e-04 - output_5_loss: 0.0015 - output_6_loss: 0.0011 - output_7_loss: 0.0016 - output_8_loss: 3.7380e-04 - output_9_loss: 0.0023 - output_2_sparse_categorical_accuracy: 0.9999 - output_3_sparse_categorical_accuracy: 0.9999 - output_4_sparse_categorical_accuracy: 0.9999 - output_5_sparse_categorical_accuracy: 0.9999 - output_6_sparse_categorical_accuracy: 1.0000 - output_7_sparse_categorical_accuracy: 0.9997 - output_8_sparse_categorical_accuracy: 0.9999 - output_9_sparse_categorical_accuracy: 0.9998
Epoch 39/50
203/203 [==============================] - 1s 3ms/step - loss: 0.0018 - output_1_loss: 4.7383e-04 - output_2_loss: 0.0011 - output_3_loss: 0.0012 - output_4_loss: 0.0012 - output_5_loss: 0.0016 - output_6_loss: 0.0013 - output_7_loss: 0.0015 - output_8_loss: 4.9181e-04 - output_9_loss: 0.0024 - output_2_sparse_categorical_accuracy: 1.0000 - output_3_sparse_categorical_accuracy: 0.9999 - output_4_sparse_categorical_accuracy: 0.9999 - output_5_sparse_categorical_accuracy: 1.0000 - output_6_sparse_categorical_accuracy: 0.9999 - output_7_sparse_categorical_accuracy: 0.9998 - output_8_sparse_categorical_accuracy: 1.0000 - output_9_sparse_categorical_accuracy: 0.9999
Epoch 40/50
203/203 [==============================] - 1s 3ms/step - loss: 0.0024 - output_1_loss: 5.3746e-04 - output_2_loss: 0.0014 - output_3_loss: 0.0024 - output_4_loss: 0.0011 - output_5_loss: 0.0024 - output_6_loss: 0.0016 - output_7_loss: 0.0022 - output_8_loss: 1.9366e-04 - output_9_loss: 0.0037 - output_2_sparse_categorical_accuracy: 0.9999 - output_3_sparse_categorical_accuracy: 0.9996 - output_4_sparse_categorical_accuracy: 0.9998 - output_5_sparse_categorical_accuracy: 0.9996 - output_6_sparse_categorical_accuracy: 0.9997 - output_7_sparse_categorical_accuracy: 0.9993 - output_8_sparse_categorical_accuracy: 1.0000 - output_9_sparse_categorical_accuracy: 0.9993
Epoch 41/50
203/203 [==============================] - 1s 3ms/step - loss: 0.0025 - output_1_loss: 5.6036e-04 - output_2_loss: 0.0017 - output_3_loss: 0.0016 - output_4_loss: 0.0017 - output_5_loss: 0.0030 - output_6_loss: 0.0023 - output_7_loss: 0.0020 - output_8_loss: 3.4897e-04 - output_9_loss: 0.0030 - output_2_sparse_categorical_accuracy: 0.9998 - output_3_sparse_categorical_accuracy: 0.9999 - output_4_sparse_categorical_accuracy: 0.9996 - output_5_sparse_categorical_accuracy: 0.9993 - output_6_sparse_categorical_accuracy: 0.9994 - output_7_sparse_categorical_accuracy: 0.9996 - output_8_sparse_categorical_accuracy: 0.9999 - output_9_sparse_categorical_accuracy: 0.9995
Epoch 42/50
203/203 [==============================] - 1s 3ms/step - loss: 0.0020 - output_1_loss: 4.9773e-04 - output_2_loss: 0.0013 - output_3_loss: 0.0016 - output_4_loss: 9.0410e-04 - output_5_loss: 0.0019 - output_6_loss: 0.0014 - output_7_loss: 0.0019 - output_8_loss: 3.2195e-04 - output_9_loss: 0.0030 - output_2_sparse_categorical_accuracy: 0.9997 - output_3_sparse_categorical_accuracy: 0.9997 - output_4_sparse_categorical_accuracy: 0.9999 - output_5_sparse_categorical_accuracy: 0.9996 - output_6_sparse_categorical_accuracy: 0.9998 - output_7_sparse_categorical_accuracy: 0.9996 - output_8_sparse_categorical_accuracy: 1.0000 - output_9_sparse_categorical_accuracy: 0.9995
Epoch 43/50
203/203 [==============================] - 1s 3ms/step - loss: 0.0023 - output_1_loss: 5.2164e-04 - output_2_loss: 0.0014 - output_3_loss: 0.0021 - output_4_loss: 0.0017 - output_5_loss: 0.0020 - output_6_loss: 0.0014 - output_7_loss: 0.0027 - output_8_loss: 4.1985e-04 - output_9_loss: 0.0027 - output_2_sparse_categorical_accuracy: 0.9997 - output_3_sparse_categorical_accuracy: 0.9995 - output_4_sparse_categorical_accuracy: 0.9997 - output_5_sparse_categorical_accuracy: 0.9997 - output_6_sparse_categorical_accuracy: 0.9996 - output_7_sparse_categorical_accuracy: 0.9993 - output_8_sparse_categorical_accuracy: 0.9998 - output_9_sparse_categorical_accuracy: 0.9996
Epoch 44/50
203/203 [==============================] - 1s 3ms/step - loss: 0.0015 - output_1_loss: 4.4213e-04 - output_2_loss: 8.2949e-04 - output_3_loss: 0.0011 - output_4_loss: 7.3075e-04 - output_5_loss: 0.0014 - output_6_loss: 9.0615e-04 - output_7_loss: 0.0013 - output_8_loss: 2.2119e-04 - output_9_loss: 0.0018 - output_2_sparse_categorical_accuracy: 0.9999 - output_3_sparse_categorical_accuracy: 1.0000 - output_4_sparse_categorical_accuracy: 0.9999 - output_5_sparse_categorical_accuracy: 0.9998 - output_6_sparse_categorical_accuracy: 0.9999 - output_7_sparse_categorical_accuracy: 0.9997 - output_8_sparse_categorical_accuracy: 1.0000 - output_9_sparse_categorical_accuracy: 0.9996
Epoch 45/50
203/203 [==============================] - 1s 3ms/step - loss: 0.0015 - output_1_loss: 4.7961e-04 - output_2_loss: 0.0018 - output_3_loss: 0.0010 - output_4_loss: 5.5773e-04 - output_5_loss: 0.0011 - output_6_loss: 8.5916e-04 - output_7_loss: 0.0011 - output_8_loss: 2.1571e-04 - output_9_loss: 0.0018 - output_2_sparse_categorical_accuracy: 0.9995 - output_3_sparse_categorical_accuracy: 0.9999 - output_4_sparse_categorical_accuracy: 1.0000 - output_5_sparse_categorical_accuracy: 0.9999 - output_6_sparse_categorical_accuracy: 0.9999 - output_7_sparse_categorical_accuracy: 0.9999 - output_8_sparse_categorical_accuracy: 0.9999 - output_9_sparse_categorical_accuracy: 1.0000
Epoch 46/50
203/203 [==============================] - 1s 3ms/step - loss: 0.0011 - output_1_loss: 3.7174e-04 - output_2_loss: 4.5386e-04 - output_3_loss: 0.0011 - output_4_loss: 4.3714e-04 - output_5_loss: 9.6851e-04 - output_6_loss: 6.5524e-04 - output_7_loss: 0.0011 - output_8_loss: 1.2404e-04 - output_9_loss: 0.0012 - output_2_sparse_categorical_accuracy: 0.9999 - output_3_sparse_categorical_accuracy: 0.9998 - output_4_sparse_categorical_accuracy: 1.0000 - output_5_sparse_categorical_accuracy: 0.9999 - output_6_sparse_categorical_accuracy: 1.0000 - output_7_sparse_categorical_accuracy: 0.9999 - output_8_sparse_categorical_accuracy: 1.0000 - output_9_sparse_categorical_accuracy: 1.0000
Epoch 47/50
203/203 [==============================] - 1s 3ms/step - loss: 8.7130e-04 - output_1_loss: 3.5579e-04 - output_2_loss: 3.7196e-04 - output_3_loss: 6.0979e-04 - output_4_loss: 3.6849e-04 - output_5_loss: 6.6380e-04 - output_6_loss: 4.5160e-04 - output_7_loss: 5.4598e-04 - output_8_loss: 1.0443e-04 - output_9_loss: 0.0010 - output_2_sparse_categorical_accuracy: 1.0000 - output_3_sparse_categorical_accuracy: 1.0000 - output_4_sparse_categorical_accuracy: 1.0000 - output_5_sparse_categorical_accuracy: 1.0000 - output_6_sparse_categorical_accuracy: 1.0000 - output_7_sparse_categorical_accuracy: 0.9999 - output_8_sparse_categorical_accuracy: 1.0000 - output_9_sparse_categorical_accuracy: 0.9999
Epoch 48/50
203/203 [==============================] - 1s 3ms/step - loss: 7.9094e-04 - output_1_loss: 3.0706e-04 - output_2_loss: 3.4289e-04 - output_3_loss: 5.3048e-04 - output_4_loss: 2.6831e-04 - output_5_loss: 6.4289e-04 - output_6_loss: 4.3096e-04 - output_7_loss: 6.1185e-04 - output_8_loss: 9.0556e-05 - output_9_loss: 9.5306e-04 - output_2_sparse_categorical_accuracy: 1.0000 - output_3_sparse_categorical_accuracy: 1.0000 - output_4_sparse_categorical_accuracy: 1.0000 - output_5_sparse_categorical_accuracy: 1.0000 - output_6_sparse_categorical_accuracy: 1.0000 - output_7_sparse_categorical_accuracy: 0.9999 - output_8_sparse_categorical_accuracy: 1.0000 - output_9_sparse_categorical_accuracy: 0.9999
Epoch 49/50
203/203 [==============================] - 1s 3ms/step - loss: 7.5165e-04 - output_1_loss: 2.9656e-04 - output_2_loss: 3.2701e-04 - output_3_loss: 5.1400e-04 - output_4_loss: 3.3226e-04 - output_5_loss: 5.3249e-04 - output_6_loss: 3.9655e-04 - output_7_loss: 5.1729e-04 - output_8_loss: 1.1316e-04 - output_9_loss: 9.0793e-04 - output_2_sparse_categorical_accuracy: 1.0000 - output_3_sparse_categorical_accuracy: 1.0000 - output_4_sparse_categorical_accuracy: 1.0000 - output_5_sparse_categorical_accuracy: 1.0000 - output_6_sparse_categorical_accuracy: 1.0000 - output_7_sparse_categorical_accuracy: 1.0000 - output_8_sparse_categorical_accuracy: 1.0000 - output_9_sparse_categorical_accuracy: 0.9999
Epoch 50/50
203/203 [==============================] - 1s 3ms/step - loss: 7.3585e-04 - output_1_loss: 2.9797e-04 - output_2_loss: 3.0015e-04 - output_3_loss: 4.5490e-04 - output_4_loss: 3.0367e-04 - output_5_loss: 6.2487e-04 - output_6_loss: 3.9739e-04 - output_7_loss: 5.6449e-04 - output_8_loss: 8.0562e-05 - output_9_loss: 7.7697e-04 - output_2_sparse_categorical_accuracy: 1.0000 - output_3_sparse_categorical_accuracy: 1.0000 - output_4_sparse_categorical_accuracy: 1.0000 - output_5_sparse_categorical_accuracy: 0.9999 - output_6_sparse_categorical_accuracy: 1.0000 - output_7_sparse_categorical_accuracy: 0.9999 - output_8_sparse_categorical_accuracy: 1.0000 - output_9_sparse_categorical_accuracy: 1.0000
INFO:tensorflow:Assets written to: tensorflow/ADULT_autoencoder/assets

Counterfactual with Reinforcement Learning

[10]:
# Define constants
COEFF_SPARSITY = 0.5               # sparsity coefficient
COEFF_CONSISTENCY = 0.5            # consistency coefficient
TRAIN_STEPS = 10000                # number of training steps -> consider increasing the number of steps
BATCH_SIZE = 100                   # batch size

Define dataset specific attributes and constraints

A desirable property of a method for generating counterfactuals is to allow feature conditioning. Real-world datasets usually include immutable features such as Sex or Race, which should remain unchanged throughout the counterfactual search procedure. Similarly, a numerical feature such as Age should only increase for a counterfactual to be actionable.

[11]:
# Define immutable features.
immutable_features = ['Marital Status', 'Relationship', 'Race', 'Sex']

# Define ranges. This means that the `Age` feature cannot decrease.
ranges = {'Age': [0.0, 1.0]}

Define and fit the explainer

[12]:
explainer = CounterfactualRLTabular(predictor=predictor,
                                    encoder=heae.encoder,
                                    decoder=heae.decoder,
                                    latent_dim=LATENT_DIM,
                                    encoder_preprocessor=heae_preprocessor,
                                    decoder_inv_preprocessor=heae_inv_preprocessor,
                                    coeff_sparsity=COEFF_SPARSITY,
                                    coeff_consistency=COEFF_CONSISTENCY,
                                    category_map=adult.category_map,
                                    feature_names=adult.feature_names,
                                    ranges=ranges,
                                    immutable_features=immutable_features,
                                    train_steps=TRAIN_STEPS,
                                    batch_size=BATCH_SIZE,
                                    backend="tensorflow")
[13]:
# Fit the explainer.
explainer = explainer.fit(X=X_train)
100%|██████████| 10000/10000 [05:16<00:00, 31.59it/s]

Test explainer

[14]:
# Select some positive examples.
X_positive = X_test[np.argmax(predictor(X_test), axis=1) == 1]

X = X_positive[:1000]
Y_t = np.array([0])  # target the `<=50K` class
C = [{"Age": [0, 20], "Workclass": ["State-gov", "?", "Local-gov"]}]  # allow `Age` to increase by up to 20; restrict `Workclass` to the listed values
[15]:
# Generate counterfactual instances.
explanation = explainer.explain(X, Y_t, C)
100%|██████████| 10/10 [00:00<00:00, 34.95it/s]
[16]:
# Concat labels to the original instances.
orig = np.concatenate(
    [explanation.data['orig']['X'], explanation.data['orig']['class']],
    axis=1
)

# Concat labels to the counterfactual instances.
cf = np.concatenate(
    [explanation.data['cf']['X'], explanation.data['cf']['class']],
    axis=1
)

# Define new feature names and category map by including the label.
feature_names = adult.feature_names + ["Label"]
category_map = deepcopy(adult.category_map)
category_map.update({feature_names.index("Label"): adult.target_names})

# Replace label encodings with strings.
orig_pd = pd.DataFrame(
    apply_category_mapping(orig, category_map),
    columns=feature_names
)

cf_pd = pd.DataFrame(
    apply_category_mapping(cf, category_map),
    columns=feature_names
)
[17]:
orig_pd.head(n=10)
[17]:
Age Workclass Education Marital Status Occupation Relationship Race Sex Capital Gain Capital Loss Hours per week Country Label
0 60 Private High School grad Married Blue-Collar Husband White Male 7298 0 40 United-States >50K
1 35 Private High School grad Married White-Collar Husband White Male 7688 0 50 United-States >50K
2 39 State-gov Masters Married Professional Wife White Female 5178 0 38 United-States >50K
3 44 Self-emp-inc High School grad Married Sales Husband White Male 0 0 50 United-States >50K
4 44 Federal-gov High School grad Married Admin Husband White Male 0 0 40 United-States >50K
5 39 Private Bachelors Separated White-Collar Not-in-family White Female 13550 0 50 United-States >50K
6 45 Private High School grad Married Blue-Collar Husband White Male 0 1902 40 ? >50K
7 50 Private Bachelors Married Professional Husband White Male 0 0 50 United-States >50K
8 29 Private Bachelors Married White-Collar Wife White Female 0 0 50 United-States >50K
9 47 Private Bachelors Married Professional Husband White Male 0 0 50 United-States >50K
[18]:
cf_pd.head(n=10)
[18]:
Age Workclass Education Marital Status Occupation Relationship Race Sex Capital Gain Capital Loss Hours per week Country Label
0 60 Private High School grad Married Blue-Collar Husband White Male 899 84 41 United-States <=50K
1 40 Private High School grad Married Blue-Collar Husband White Male 889 77 41 United-States <=50K
2 44 ? High School grad Married ? Wife White Female 909 79 40 United-States <=50K
3 44 Self-emp-inc High School grad Married Blue-Collar Husband White Male 949 84 41 United-States <=50K
4 46 State-gov High School grad Married Admin Husband White Male 924 82 41 United-States <=50K
5 39 Private Bachelors Separated White-Collar Not-in-family White Female 976 78 41 United-States <=50K
6 45 Private High School grad Married Blue-Collar Husband White Male 884 86 41 ? <=50K
7 50 Private High School grad Married Blue-Collar Husband White Male 888 81 41 United-States <=50K
8 39 Private High School grad Married Blue-Collar Wife White Female 889 75 41 United-States <=50K
9 47 Private High School grad Married Blue-Collar Husband White Male 887 81 41 United-States <=50K

Diversity

[19]:
# Generate counterfactual instances.
X = X_positive[1].reshape(1, -1)
explanation = explainer.explain(X=X, Y_t=Y_t, C=C, diversity=True, num_samples=100, batch_size=10)
11it [00:00, 29.65it/s]
[20]:
# Concat label column.
orig = np.concatenate(
    [explanation.data['orig']['X'], explanation.data['orig']['class']],
    axis=1
)

cf = np.concatenate(
    [explanation.data['cf']['X'], explanation.data['cf']['class']],
    axis=1
)

# Transform label encodings to strings.
orig_pd = pd.DataFrame(
    apply_category_mapping(orig, category_map),
    columns=feature_names,
)

cf_pd = pd.DataFrame(
    apply_category_mapping(cf, category_map),
    columns=feature_names,
)
[21]:
orig_pd.head(n=5)
[21]:
Age Workclass Education Marital Status Occupation Relationship Race Sex Capital Gain Capital Loss Hours per week Country Label
0 35 Private High School grad Married White-Collar Husband White Male 7688 0 50 United-States >50K
[22]:
cf_pd.head(n=5)
[22]:
Age Workclass Education Marital Status Occupation Relationship Race Sex Capital Gain Capital Loss Hours per week Country Label
0 35 Private High School grad Married ? Husband White Male 873 74 41 United-States <=50K
1 35 Private High School grad Married Admin Husband White Male 860 74 41 United-States <=50K
2 36 Private High School grad Married Blue-Collar Husband White Male 885 78 41 United-States <=50K
3 36 Private High School grad Married Blue-Collar Husband White Male 887 79 41 United-States <=50K
4 36 Private High School grad Married Blue-Collar Husband White Male 895 81 41 United-States <=50K

Logging

Logging is clearly important when dealing with deep learning models. Thus, we provide an interface for writing custom callbacks, called after each training step, that can be used for logging purposes. In the following cells we provide some examples of logging to Weights and Biases.

Logging reward callback

[23]:
class RewardCallback(Callback):
    def __call__(self,
                 step: int,
                 update: int,
                 model: CounterfactualRL,
                 sample: Dict[str, np.ndarray],
                 losses: Dict[str, float]):

        if (step + update) % 100 != 0:
            return

        # get the counterfactual and target
        Y_t = sample["Y_t"]
        X_cf = model.params["decoder_inv_preprocessor"](sample["X_cf"])

        # get prediction label
        Y_m_cf = predictor(X_cf)

        # compute reward
        reward = np.mean(model.params["reward_func"](Y_m_cf, Y_t))
        wandb.log({"reward": reward})

Logging losses callback

[24]:
class LossCallback(Callback):
    def __call__(self,
                 step: int,
                 update: int,
                 model: CounterfactualRL,
                 sample: Dict[str, np.ndarray],
                 losses: Dict[str, float]):
        # Log training losses.
        if (step + update) % 100 == 0:
            wandb.log(losses)

Logging tables callback

[25]:
class TablesCallback(Callback):
    def __call__(self,
                 step: int,
                 update: int,
                 model: CounterfactualRL,
                 sample: Dict[str, np.ndarray],
                 losses: Dict[str, float]):
        # Log every 1000 steps
        if step % 1000 != 0:
            return

        # Define number of samples to be displayed.
        NUM_SAMPLES = 5

        X = heae_inv_preprocessor(sample["X"][:NUM_SAMPLES])        # input instance
        X_cf = heae_inv_preprocessor(sample["X_cf"][:NUM_SAMPLES])  # counterfactual

        Y_m = np.argmax(sample["Y_m"][:NUM_SAMPLES], axis=1).astype(int).reshape(-1, 1) # input labels
        Y_t = np.argmax(sample["Y_t"][:NUM_SAMPLES], axis=1).astype(int).reshape(-1, 1) # target labels
        Y_m_cf = np.argmax(predictor(X_cf), axis=1).astype(int).reshape(-1, 1)          # counterfactual labels

        # Define feature names and category map for input.
        feature_names = adult.feature_names + ["Label"]
        category_map = deepcopy(adult.category_map)
        category_map.update({feature_names.index("Label"): adult.target_names})

        # Construct input array.
        inputs = np.concatenate([X, Y_m], axis=1)
        inputs = pd.DataFrame(apply_category_mapping(inputs, category_map),
                              columns=feature_names)

        # Define feature names and category map for counterfactual output.
        feature_names += ["Target"]
        category_map.update({feature_names.index("Target"): adult.target_names})

        # Construct output array.
        outputs = np.concatenate([X_cf, Y_m_cf, Y_t], axis=1)
        outputs = pd.DataFrame(apply_category_mapping(outputs, category_map),
                               columns=feature_names)

        # Log table.
        wandb.log({
            "Input": wandb.Table(dataframe=inputs),
            "Output": wandb.Table(dataframe=outputs)
        })

Having defined the callbacks, we can define a new explainer that will include logging.

import wandb

# Initialize wandb.
wandb_project = "Adult Census Counterfactual with Reinforcement Learning"
wandb.init(project=wandb_project)

# Define explainer as before and include callbacks.
explainer = CounterfactualRLTabular(...,
                                    callbacks=[LossCallback(), RewardCallback(), TablesCallback()])

# Fit the explainer.
explainer = explainer.fit(X=X_train)

# Close wandb.
wandb.finish()