This page was generated from examples/integrated_gradients_transformers.ipynb.

Integrated gradients for transformers models

In this example, we apply the integrated gradients method to two different sentiment analysis models. The first one is a pretrained sentiment analysis model from the transformers library. The second model is a combination of a pretrained (distil)BERT model and a simple feed-forward network. The entire model, (distil)BERT plus the feed-forward head, is trained on the IMDB reviews dataset.

In text classification models, integrated gradients (IG) defines an attribution value for each word in the input sentence. The attributions are calculated as the integral of the model gradients with respect to the word embedding layer along a straight path from a baseline instance \(x^\prime\) to the input instance \(x\). A description of the method can be found here. Integrated gradients was originally proposed in Sundararajan et al., “Axiomatic Attribution for Deep Networks”.
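As defined in the paper, the attribution of the \(i\)-th feature (here, a component of the word embedding) for a model output \(F\) is

\[IG_i(x) = (x_i - x_i^\prime) \int_0^1 \frac{\partial F\big(x^\prime + \alpha\,(x - x^\prime)\big)}{\partial x_i}\, d\alpha,\]

and in practice the integral is approximated numerically over a fixed number of steps.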

[1]:
import re
import os
import numpy as np
import matplotlib as mpl
import matplotlib.cm
import tensorflow as tf
import tensorflow.keras as keras

from tqdm import tqdm
from typing import Optional, Union, List, Dict
from IPython.display import HTML
from tensorflow.keras.datasets import imdb
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.preprocessing import sequence
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import SparseCategoricalCrossentropy
from transformers import PreTrainedTokenizer
from alibi.explainers import IntegratedGradients

Here we define some functions needed to process the data and visualize the results. For consistency with other text examples in alibi, we will use the IMDB reviews dataset provided by Keras. Since the dataset consists of reviews that are already tokenized, we need to decode each sentence and re-tokenize it using the (distil)BERT tokenizer.

[2]:
def decode_sentence(x: List[int], reverse_index: Dict[int, str], unk_token: str = '[UNK]') -> str:
    """
    Decodes the tokenized sentences from keras IMDB dataset into plain text.

    Parameters
    ----------
    x:
        List of token ids to be decoded.
    reverse_index:
        Reverse index map, from `int` to `str`.
    unk_token:
        Unknown token to be used for ids missing from the reverse index.

    Returns
    -------
        Decoded sentence.
    """
    # the `-3` offset is due to the special tokens used by keras
    # see https://stackoverflow.com/questions/42821330/restore-original-text-from-keras-s-imdb-dataset
    return " ".join([reverse_index.get(i - 3, unk_token) for i in x])


def process_sentences(sentence: List[str],
                      tokenizer: PreTrainedTokenizer,
                      max_len: int) -> Dict[str, np.ndarray]:
    """
    Tokenize the text sentences.

    Parameters
    ----------
    sentence:
        List of sentences to be processed.
    tokenizer:
        Tokenizer to be used.
    max_len:
        Maximum token length; longer sentences are truncated and shorter ones padded.

    Returns
    -------
        Tokenized representation containing:
         - input_ids
         - attention_mask
    """
    # since we are using the model for classification, we need to include the special tokens (e.g., '[CLS]', '[SEP]')
    # check the example here: https://huggingface.co/transformers/v4.4.2/quicktour.html
    z = tokenizer(sentence,
                  add_special_tokens=True,
                  padding='max_length',
                  max_length=max_len,
                  truncation=True,
                  return_attention_mask=True,
                  return_tensors='np')
    return z
[3]:
def hlstr(string: str, color: str = 'white') -> str:
    """
    Return HTML markup highlighting text with the desired color.
    """
    return f"<mark style=background-color:{color}>{string} </mark>"


def colorize(attrs: np.ndarray, cmap: str = 'PiYG') -> List:
    """
    Compute hex colors based on the attributions for a single instance.
    Uses a diverging colorscale by default and normalizes and scales
    the colormap so that colors are consistent with the attributions.

    Parameters
    ----------
    attrs:
        Attributions to be visualized.
    cmap:
        Matplotlib cmap type.
    """
    cmap_bound = np.abs(attrs).max()
    norm = mpl.colors.Normalize(vmin=-cmap_bound, vmax=cmap_bound)
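    # note: `mpl.cm.get_cmap` is deprecated in newer matplotlib releases; `mpl.colormaps[cmap]` is the replacement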
    cmap = mpl.cm.get_cmap(cmap)
    return list(map(lambda x: mpl.colors.rgb2hex(cmap(norm(x))), attrs))


def display(X: np.ndarray,
            attrs: np.ndarray,
            tokenizer: PreTrainedTokenizer,
            pred: np.ndarray) -> None:
    """
    Display the attribution of a given instance.

    Parameters
    ----------
    X:
        Instance to display the attributions for.
    attrs:
        Attributions values for the given instance.
    tokenizer:
        Tokenizer to be used for decoding.
    pred:
        Classification label (prediction) for the given instance.
    """
    pred_dict = {1: 'Positive review', 0: 'Negative review'}

    # remove padding
    fst_pad_indices = np.where(X == tokenizer.pad_token_id)[0]
    if len(fst_pad_indices) > 0:
        X, attrs = X[:fst_pad_indices[0]], attrs[:fst_pad_indices[0]]

    # decode tokens and get colors
    tokens = [tokenizer.decode([X[i]]) for i in range(len(X))]
    colors = colorize(attrs)

    print('Predicted label =  {}: {}'.format(pred, pred_dict[pred]))
    return HTML("".join(list(map(hlstr, tokens, colors))))

Automodel

In this section, we will use the TensorFlow auto model for sequence classification provided by the transformers library.

The model is pretrained on the Stanford Sentiment Treebank (SST) dataset. The Stanford Sentiment Treebank is the first corpus with fully labeled parse trees that allows for a complete analysis of the compositional effects of sentiment in language.

Each phrase is labeled as either negative, somewhat negative, neutral, somewhat positive or positive. The corpus with all 5 labels is referred to as SST-5 or SST fine-grained. Binary classification experiments on full sentences (negative or somewhat negative vs somewhat positive or positive with neutral sentences discarded) refer to the dataset as SST-2 or SST binary. In this example, we will use a text classifier pretrained on the SST-2 dataset.

[4]:
from transformers import TFAutoModelForSequenceClassification, AutoTokenizer

# load model and tokenizer
model_name = "distilbert-base-uncased-finetuned-sst-2-english"
auto_model_distilbert = TFAutoModelForSequenceClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
All model checkpoint layers were used when initializing TFDistilBertForSequenceClassification.

All the layers of TFDistilBertForSequenceClassification were initialized from the model checkpoint at distilbert-base-uncased-finetuned-sst-2-english.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFDistilBertForSequenceClassification for predictions without further training.

The auto_model output is a custom object containing the output logits. We use a wrapper that extracts the logits as a tensor and applies a softmax function to obtain class probabilities.

[5]:
class AutoModelWrapper(keras.Model):
    def __init__(self, transformer: keras.Model, **kwargs):
        """
        Constructor.

        Parameters
        ----------
        transformer:
            Transformer to be wrapped.
        """
        super().__init__()
        self.transformer = transformer

    def call(self,
             input_ids: Union[np.ndarray, tf.Tensor],
             attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
             training: bool = False):
        """
        Performs a forward pass through the model.

        Parameters
        ----------
        input_ids:
            Indices of input sequence tokens in the vocabulary.
        attention_mask:
            Mask to avoid performing attention on padding token indices.

        Returns
        -------
            Classification probabilities.
        """
        out = self.transformer(input_ids=input_ids, attention_mask=attention_mask, training=training)
        return tf.nn.softmax(out.logits, axis=-1)

    def get_config(self):
        return {}

    @classmethod
    def from_config(cls, config):
        return cls(**config)
[6]:
auto_model = AutoModelWrapper(auto_model_distilbert)

Calculate integrated gradients

[7]:
max_features = 10000
max_len = 128

Here we consider some simple sentences such as “I love you, I like you” and “I love you, I like you, but I also kind of dislike you”.

[8]:
text_samples = ['I love you, I like you',
                'I love you, I like you, but I also kind of dislike you',
                'Everything is so nice about you']

# since using the uncased model, we need to use lowercase sentences
text_samples = [text.lower() for text in text_samples]

# tokenize the sentences using the transformer's tokenizer.
tokenized_samples = process_sentences(text_samples, tokenizer, max_len)
X_test = tokenized_samples['input_ids'].astype(np.int32)

# the values of the kwargs have to be `tf.Tensor`.
# see transformers issue #14404: https://github.com/huggingface/transformers/issues/14404
kwargs = {k: tf.constant(v) for k, v in tokenized_samples.items() if k == 'attention_mask'}

The auto model consists of a main distilBERT layer (layer 0) followed by two dense layers and a dropout layer.

[9]:
auto_model.layers[0].layers
[9]:
[<transformers.models.distilbert.modeling_tf_distilbert.TFDistilBertMainLayer at 0x7f30403fa5b0>,
 <keras.layers.core.dense.Dense at 0x7f2cdc56a0d0>,
 <keras.layers.core.dense.Dense at 0x7f2cdc56a340>,
 <keras.layers.core.dropout.Dropout at 0x7f2cdc56a670>]

We will proceed with the embedding layer from distilBERT. We calculate attributions to the outputs of the embedding layer because we can easily construct an appropriate baseline for IG, which results in more interpretable attributions. If we considered a hidden layer instead, we would inevitably capture higher-order interactions between the input tokens, which might be harder to interpret as we no longer have a one-to-one mapping between layer outputs and input tokens. Moreover, the embedding layer is a standard choice since we cannot compute attributions for the raw input due to its discrete structure (i.e., we cannot differentiate the output of the model with respect to the discrete input representation). That being said, you can use any other layer and compute attributions to its outputs instead.

[10]:
# Extract the embedding layer
layer = auto_model.layers[0].layers[0].embeddings

# # Extract the first layer from the transformer
# layer = auto_model.layers[0].layers[0].transformer.layer[0]
[11]:
# define IG
n_steps = 50
internal_batch_size = 5
method = "gausslegendre"

ig  = IntegratedGradients(auto_model,
                          layer=layer,
                          n_steps=n_steps,
                          method=method,
                          internal_batch_size=internal_batch_size)
[12]:
# get predictions
predictions = auto_model(X_test, **kwargs).numpy().argmax(axis=1)

# Get the baselines. Note that the baseline keeps the special tokens (e.g., [CLS], [SEP], [UNK], [PAD]) and
# the regular tokens are replaced by the [PAD] token, which is a neutral token.
# By keeping the special tokens such as [CLS], [SEP], [UNK] in place, we ensure that their attribution
# will be 0 when using the embedding layer, since the integral over the degenerate path [x, x] is 0.
mask = np.isin(X_test, tokenizer.all_special_ids)
baselines = X_test * mask + tokenizer.pad_token_id * (1 - mask)

# get explanation
explanation = ig.explain(X_test,
                         forward_kwargs=kwargs,
                         baselines=baselines,
                         target=predictions)
WARNING:tensorflow:The parameters `output_attentions`, `output_hidden_states` and `use_cache` cannot be updated when calling a model.They have to be set to True/False in the config object (i.e.: `config=XConfig.from_pretrained('name', output_attentions=True)`).
WARNING:tensorflow:AutoGraph could not transform <bound method Socket.send of <zmq.sugar.socket.Socket object at 0x7f3040410160>> and will run it as-is.
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: module, class, method, function, traceback, frame, or code object was expected, got cython_function_or_method
To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
WARNING: AutoGraph could not transform <bound method Socket.send of <zmq.sugar.socket.Socket object at 0x7f3040410160>> and will run it as-is.
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: module, class, method, function, traceback, frame, or code object was expected, got cython_function_or_method
To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
WARNING:tensorflow:The parameter `return_dict` cannot be set in graph mode and will always be set to `True`.
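As a quick optional check, we can decode one of the baselines to confirm that the special tokens are kept while the regular tokens have been replaced by [PAD]:

# decode the beginning of the first baseline; regular tokens should show up as [PAD]
print(tokenizer.decode(baselines[0][:10]))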

Let’s check the attributions’ shapes.

[13]:
# Get attributions values from the explanation object
attrs = explanation.attributions[0]
print('Attributions shape:', attrs.shape)
Attributions shape: (3, 128, 768)

As you can see, the attribution of each token corresponds to a tensor of 768 elements. We compress all this information into a single number by summing up the 768 components. The nice thing about this is that we remain consistent with the Completeness Axiom, which states that the attributions add up to the difference between the output of our model for the given instance and the output of our model for the given baseline.
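Optionally, we can check the Completeness Axiom numerically: summing all the attributions of an instance (over tokens and embedding components) should approximately recover the difference between the model output for the instance and for its baseline, evaluated at the predicted class. A minimal sketch using only the objects defined above:

# completeness check: total attribution vs. F(x) - F(x') for the predicted class
probas_x = auto_model(X_test, **kwargs).numpy()
probas_b = auto_model(baselines, **kwargs).numpy()

for i, target in enumerate(predictions):
    print(f"total attribution: {attrs[i].sum():.4f}, "
          f"F(x) - F(x'): {probas_x[i, target] - probas_b[i, target]:.4f}")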

[14]:
attrs = attrs.sum(axis=2)
print('Attributions shape:', attrs.shape)
Attributions shape: (3, 128)
[15]:
index = 1
display(X=X_test[index], attrs=attrs[index], pred=predictions[index], tokenizer=tokenizer)
Predicted label =  0: Negative review
[15]:
[CLS] i love you , i like you , but i also kind of dislike you [SEP]

Note that since the sentence is classified as negative, words like ``dislike`` contribute positively to the score while words like ``love`` contribute negatively.

Sentiment analysis on IMDB with a fine-tuned model head

Load and process data

[16]:
# load imdb reviews datasets.
(x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=max_features)

# extract sub-set of training and testing
train_size, test_size = 25000, 1000
x_train, y_train = x_train[:train_size], y_train[:train_size]
x_test, y_test = x_test[:test_size], y_test[:test_size]

# remove the first integer token which is a special character
# that marks the beginning of the sentence
x_train = [x[1:] for x in x_train]
x_test = [x[1:] for x in x_test]

# get mappings. The keys are transformed to lower case since we will use uncased models.
reverse_index = {value: key.lower() for (key, value) in imdb.get_word_index().items()}

Load model and corresponding tokenizer

Now we load the model and the corresponding tokenizer. You can choose between the BERT model and the distilBERT model. Note that we will be fine-tuning these models, which requires access to a GPU. In our experiments, we trained distilBERT on a single Quadro RTX 5000, which requires around 5GB of memory; the entire training took around 5-6 min. We recommend using distilBERT as we did not notice a big difference in performance between the two models after fine-tuning.

[17]:
# Choose whether to use the BERT model by setting the following flag to `True`
# Otherwise the distilBERT will be used
use_bert = False
[18]:
if use_bert:
    from transformers import BertTokenizerFast
    tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
else:
    from transformers import DistilBertTokenizerFast
    tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
[19]:
if use_bert:
    from transformers import TFBertModel, BertConfig
    config = BertConfig(output_hidden_states=True)
    transformer = TFBertModel.from_pretrained('bert-base-uncased', config=config)
else:
    from transformers import TFDistilBertModel, DistilBertConfig
    config = DistilBertConfig(output_hidden_states=True)
    transformer = TFDistilBertModel.from_pretrained('distilbert-base-uncased', config=config)
Some layers from the model checkpoint at distilbert-base-uncased were not used when initializing TFDistilBertModel: ['vocab_transform', 'vocab_layer_norm', 'vocab_projector', 'activation_13']
- This IS expected if you are initializing TFDistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFDistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
All the layers of TFDistilBertModel were initialized from the model checkpoint at distilbert-base-uncased.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFDistilBertModel for predictions without further training.

We decode each sentence in the tokenized Keras IMDB dataset to obtain the corresponding plain text. The dataset is already in pretty good shape, so we don't need any extra preprocessing. The only thing we do is replace the unknown tokens with the tokenizer's unknown token (i.e., [UNK]).

[20]:
X_train, X_test = [], []

# decode training sentences
for i in range(len(x_train)):
    tr_sentence = decode_sentence(x_train[i], reverse_index, unk_token=tokenizer.unk_token)
    X_train.append(tr_sentence)

# decode testing sentences
for i in range(len(x_test)):
    te_sentence = decode_sentence(x_test[i], reverse_index, unk_token=tokenizer.unk_token)
    X_test.append(te_sentence)

We then re-tokenize the plain text using the (distil)BERT tokenizer.

[21]:
# tokenize datasets
X_train = process_sentences(X_train, tokenizer, max_len)
X_test = process_sentences(X_test, tokenizer, max_len)

Train model

Here we train a classification model by leveraging the pretrained (distil)BERT transformer. Since we are dealing with classification, we only require the output embedding corresponding to the [CLS] token (remember that we introduced some special tokens such as [CLS] and [SEP]). The output embedding for the [CLS] token is a 768-dimensional vector which encodes the entire sentence. The model head consists of one dense layer with 128 hidden units, followed by a 2-unit layer with softmax activation, with a dropout layer (rate 0.2) in between.

[22]:
class IMDBClassifier(tf.keras.Model):
    def __init__(self,
                 transformer,
                 hidden_dims: int = 128,
                 output_dims: int = 2,
                 dropout_rate: float = 0.2):
        """
        Constructor

        Parameters
        ----------
        transformer:
            Transformer model to be leveraged.
        hidden_dims:
            Hidden layer's dimension.
        output_dims:
            Output layer's dimension.
        dropout_rate:
            Dropout layer's dropout rate.
        """
        super().__init__()
        self.hidden_dims = hidden_dims
        self.output_dims = output_dims
        self.dropout_rate = dropout_rate

        self.transformer = transformer
        self.dense_1 = tf.keras.layers.Dense(self.hidden_dims, activation='relu')
        self.dropout_1 = tf.keras.layers.Dropout(self.dropout_rate)
        self.dense_2 = tf.keras.layers.Dense(self.output_dims, activation='softmax')

    def call(self,
             input_ids: Union[np.ndarray, tf.Tensor],
             attention_mask: Optional[Union[np.ndarray, tf.Tensor]]=None,
             training=False):
        """
        Performs a forward pass through the model.

        Parameters
        ----------
        input_ids:
            Indices of input sequence tokens in the vocabulary.
        attention_mask:
            Mask to avoid performing attention on padding token indices.

        Returns
        -------
            Classification probabilities.
        """
        out = self.transformer(input_ids=input_ids, attention_mask=attention_mask, training=training)
        out = out.last_hidden_state[:, 0, :]  # extract the embedding corresponding to [CLS] token
        out = self.dense_1(out)
        out = self.dropout_1(out, training=training)
        out = self.dense_2(out)
        return out
[23]:
# define the classification model
model = IMDBClassifier(transformer)

We train the entire classification model, including the transformer, which can be hardware-demanding. Training just the top layers is also possible (see the sketch below), but in our experiments it resulted in considerably lower accuracy.
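If you do want to train only the model head, a minimal sketch is to freeze the transformer weights before compiling the model:

# optional: freeze the pretrained transformer so that only the dense head is trained
model.transformer.trainable = False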

[24]:
filepath = './model_transformers/'  # change to desired save directory
load_model = False

# optimization params
learning_rate = 2e-5
batch_size = 32
epochs = 2

# compile the model
model.compile(optimizer=Adam(learning_rate=learning_rate),
              loss=SparseCategoricalCrossentropy(from_logits=False),
              metrics=['accuracy'])

if not load_model:
    checkpoint_path = os.path.join(filepath, "training/cp-{epoch:04d}.ckpt")
    checkpoint_dir = os.path.dirname(checkpoint_path)

    # Create a callback that saves the model's weights every epoch
    cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                     verbose=1,
                                                     save_weights_only=True,
                                                     save_freq='epoch')

    # fit the model
    model.fit(x=[X_train['input_ids'], X_train['attention_mask']],
              y=y_train,
              validation_data=([X_test['input_ids'], X_test['attention_mask']], y_test),
              epochs=epochs,
              batch_size=batch_size,
              callbacks=[cp_callback])
else:
    epoch = 2
    load_path = os.path.join(filepath, f"training/cp-{epoch:04d}.ckpt")
    model.load_weights(load_path)
Epoch 1/2
WARNING:tensorflow:The parameters `output_attentions`, `output_hidden_states` and `use_cache` cannot be updated when calling a model.They have to be set to True/False in the config object (i.e.: `config=XConfig.from_pretrained('name', output_attentions=True)`).
WARNING:tensorflow:The parameter `return_dict` cannot be set in graph mode and will always be set to `True`.
WARNING:tensorflow:The parameters `output_attentions`, `output_hidden_states` and `use_cache` cannot be updated when calling a model.They have to be set to True/False in the config object (i.e.: `config=XConfig.from_pretrained('name', output_attentions=True)`).
WARNING:tensorflow:The parameter `return_dict` cannot be set in graph mode and will always be set to `True`.
782/782 [==============================] - ETA: 0s - loss: 0.3514 - accuracy: 0.8422WARNING:tensorflow:The parameters `output_attentions`, `output_hidden_states` and `use_cache` cannot be updated when calling a model.They have to be set to True/False in the config object (i.e.: `config=XConfig.from_pretrained('name', output_attentions=True)`).
WARNING:tensorflow:The parameter `return_dict` cannot be set in graph mode and will always be set to `True`.

Epoch 00001: saving model to ./model_transformers/training/cp-0001.ckpt
782/782 [==============================] - 167s 209ms/step - loss: 0.3514 - accuracy: 0.8422 - val_loss: 0.2839 - val_accuracy: 0.8670
Epoch 2/2
782/782 [==============================] - ETA: 0s - loss: 0.2279 - accuracy: 0.9074
Epoch 00002: saving model to ./model_transformers/training/cp-0002.ckpt
782/782 [==============================] - 163s 209ms/step - loss: 0.2279 - accuracy: 0.9074 - val_loss: 0.3096 - val_accuracy: 0.8720

Calculate integrated gradients

We pick the first 10 sentences from the test set as examples. You can easily add some text of your own too, as we demonstrate below.

[25]:
# include IMDB reviews from the test dataset
text_samples = [decode_sentence(x_test[i], reverse_index, unk_token=tokenizer.unk_token) for i in range(10)]

# include your own text here
text_samples.append("best movie i've ever seen nothing bad to say about it")

# tokenize text
tokenized_samples = process_sentences(text_samples, tokenizer, max_len)
X_test = tokenized_samples['input_ids']

# the values of the kwargs have to be `tf.Tensor`.
# see transformers issue #14404: https://github.com/huggingface/transformers/issues/14404
kwargs = {k: tf.constant(v) for k, v in tokenized_samples.items() if k == 'attention_mask'}

We calculate the attributions with respect to the embedding layer of the (distil)BERT model. You can choose any other layer.

[26]:
if use_bert:
    layer = model.layers[0].bert.embeddings
    # layer = model.layers[0].bert.encoder.layer[0]
else:
    layer = model.layers[0].distilbert.embeddings
    # layer = model.layers[0].distilbert.transformer.layer[0]
[27]:
# define IG
n_steps = 50
method = "gausslegendre"
internal_batch_size = 5

ig  = IntegratedGradients(model,
                          layer=layer,
                          n_steps=n_steps,
                          method=method,
                          internal_batch_size=internal_batch_size)
[28]:
# compute model's prediction and construct baselines
predictions = model(X_test, **kwargs).numpy().argmax(axis=1)

# construct the baseline as before
mask = np.isin(X_test, tokenizer.all_special_ids)
baselines = X_test * mask + tokenizer.pad_token_id * (1 - mask)

# get explanation
explanation = ig.explain(X_test,
                         forward_kwargs=kwargs,
                         baselines=baselines,
                         target=predictions)
WARNING:tensorflow:The parameters `output_attentions`, `output_hidden_states` and `use_cache` cannot be updated when calling a model.They have to be set to True/False in the config object (i.e.: `config=XConfig.from_pretrained('name', output_attentions=True)`).
WARNING:tensorflow:The parameter `return_dict` cannot be set in graph mode and will always be set to `True`.
[29]:
# Get attributions values from the explanation object
attrs = explanation.attributions[0]
print('Attributions shape:', attrs.shape)
Attributions shape: (11, 128, 768)
[30]:
attrs = attrs.sum(axis=2)
print('Attributions shape:', attrs.shape)
Attributions shape: (11, 128)

Check attributions for our example

[31]:
index = -1
display(X=X_test[index], attrs=attrs[index], pred=predictions[index], tokenizer=tokenizer)
Predicted label =  1: Positive review
[31]:
[CLS] best movie i ' ve ever seen nothing bad to say about it [SEP]

Check attributions for some test examples

[32]:
index = 0
display(X=X_test[index], attrs=attrs[index], pred=predictions[index], tokenizer=tokenizer)
Predicted label =  0: Negative review
[32]:
[CLS] please give this one a miss br br [UNK] [UNK] and the rest of the cast rendered terrible performances the show is flat flat flat br br i don ' t know how michael madison could have allowed this one on his plate he almost seemed to know this wasn ' t going to work out and his performance was quite [UNK] so all you madison fans give this a miss [SEP]
[33]:
index = 1
display(X=X_test[index], attrs=attrs[index], pred=predictions[index], tokenizer=tokenizer)
Predicted label =  1: Positive review
[33]:
[CLS] this film requires a lot of patience because it focuses on mood and character development the plot is very simple and many of the scenes take place on the same set in frances [UNK] the sandy dennis character apartment but the film builds to a disturbing climax br br the characters create an atmosphere [UNK] with sexual tension and psychological [UNK] it ' s very interesting that robert alt ##man directed this considering the style and structure of his other films still the trademark alt ##man audio style is evident here and there i think what really makes this film work is the brilliant performance by sandy dennis it ' s definitely one of her darker characters but she plays it so perfectly and convincing ##ly [SEP]