This page was generated from examples/anchor_image_fashion_mnist.ipynb.

Anchor explanations for fashion MNIST

[2]:
import matplotlib
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
tf.logging.set_verbosity(tf.logging.ERROR)  # suppress deprecation messages
from tensorflow.keras.layers import Conv2D, Dense, Dropout, Flatten, MaxPooling2D, Input
from tensorflow.keras.models import Model
from tensorflow.keras.utils import to_categorical
from alibi.explainers import AnchorImage
tf.logging.set_verbosity(tf.logging.ERROR)  # suppress deprecation messages

Load and prepare fashion MNIST data

[3]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()
print('x_train shape:', x_train.shape, 'y_train shape:', y_train.shape)
x_train shape: (60000, 28, 28) y_train shape: (60000,)
[4]:
idx = 0
plt.imshow(x_train[idx]);
../_images/examples_anchor_image_fashion_mnist_4_0.png

Scale, reshape and categorize data

[5]:
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255
x_train = np.reshape(x_train, x_train.shape + (1,))
x_test = np.reshape(x_test, x_test.shape + (1,))
print('x_train shape:', x_train.shape, 'x_test shape:', x_test.shape)
y_train = to_categorical(y_train)
y_test = to_categorical(y_test)
print('y_train shape:', y_train.shape, 'y_test shape:', y_test.shape)
x_train shape: (60000, 28, 28, 1) x_test shape: (10000, 28, 28, 1)
y_train shape: (60000, 10) y_test shape: (10000, 10)

Define CNN model

[6]:
def model():
    x_in = Input(shape=(28, 28, 1))
    x = Conv2D(filters=64, kernel_size=2, padding='same', activation='relu')(x_in)
    x = MaxPooling2D(pool_size=2)(x)
    x = Dropout(0.3)(x)

    x = Conv2D(filters=32, kernel_size=2, padding='same', activation='relu')(x)
    x = MaxPooling2D(pool_size=2)(x)
    x = Dropout(0.3)(x)

    x = Flatten()(x)
    x = Dense(256, activation='relu')(x)
    x = Dropout(0.5)(x)
    x_out = Dense(10, activation='softmax')(x)

    cnn = Model(inputs=x_in, outputs=x_out)
    cnn.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

    return cnn
[6]:
cnn = model()
cnn.summary()
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         [(None, 28, 28, 1)]       0
_________________________________________________________________
conv2d (Conv2D)              (None, 28, 28, 64)        320
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 14, 14, 64)        0
_________________________________________________________________
dropout (Dropout)            (None, 14, 14, 64)        0
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 14, 14, 32)        8224
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 7, 7, 32)          0
_________________________________________________________________
dropout_1 (Dropout)          (None, 7, 7, 32)          0
_________________________________________________________________
flatten (Flatten)            (None, 1568)              0
_________________________________________________________________
dense (Dense)                (None, 256)               401664
_________________________________________________________________
dropout_2 (Dropout)          (None, 256)               0
_________________________________________________________________
dense_1 (Dense)              (None, 10)                2570
=================================================================
Total params: 412,778
Trainable params: 412,778
Non-trainable params: 0
_________________________________________________________________

Train model

[7]:
cnn.fit(x_train, y_train, batch_size=64, epochs=3)
Epoch 1/3
60000/60000 [==============================] - 38s 630us/sample - loss: 0.5847 - acc: 0.7862
Epoch 2/3
60000/60000 [==============================] - 39s 643us/sample - loss: 0.4134 - acc: 0.8507
Epoch 3/3
60000/60000 [==============================] - 39s 644us/sample - loss: 0.3660 - acc: 0.8671
[7]:
<tensorflow.python.keras.callbacks.History at 0x7ffb62bdafd0>
[8]:
# Evaluate the model on test set
score = cnn.evaluate(x_test, y_test, verbose=0)
print('Test accuracy: ', score[1])
Test accuracy:  0.8839

Define superpixels

Function to generate rectangular superpixels for a given image. Alternatively, use one of the built in methods. It is important to have meaningful superpixels in order to generate a useful explanation. Please check scikit-image’s segmentation methods (felzenszwalb, slic and quickshift built in the explainer) for more information on the built in methods.

[9]:
def superpixel(image, size=(4, 7)):
    segments = np.zeros([image.shape[0], image.shape[1]])
    row_idx, col_idx = np.where(segments == 0)
    for i, j in zip(row_idx, col_idx):
        segments[i, j] = int((image.shape[1]/size[1]) * (i//size[0]) + j//size[1])
    return segments
[10]:
segments = superpixel(x_train[idx])
plt.imshow(segments);
../_images/examples_anchor_image_fashion_mnist_15_0.png

Define prediction function

[11]:
predict_fn = lambda x: cnn.predict(x)

Initialize anchor image explainer

[12]:
image_shape = x_train[idx].shape
explainer = AnchorImage(predict_fn, image_shape, segmentation_fn=superpixel)

Explain a prediction

The explanation returns a mask with the superpixels that constitute the anchor.

Image to be explained:

[13]:
i = 0
image = x_test[i]
plt.imshow(image[:,:,0]);
../_images/examples_anchor_image_fashion_mnist_21_0.png

Model prediction:

[14]:
cnn.predict(image.reshape(1, 28, 28, 1)).argmax()
[14]:
9

The predicted category correctly corresponds to the class ankle boot:

Label

Description

0

T-shirt/top

1

Trouser

2

Pullover

3

Dress

4

Coat

5

Sandal

6

Shirt

7

Sneaker

8

Bag

9

Ankle boot

Generate explanation:

[15]:
explanation = explainer.explain(image, threshold=.95, p_sample=.8)

Show anchor:

[16]:
plt.imshow(explanation['anchor'][:,:,0]);
../_images/examples_anchor_image_fashion_mnist_28_0.png

From the example, it looks like the arch of the boot alone is sufficient to predict an ankle boot.