{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Load the Pretrained Model and the dataset\n",
"We use bert-base-uncased as the model and SST-2 as the dataset for example. More models can be found in [PaddleNLP Model Zoo](https://paddlenlp.readthedocs.io/zh/latest/model_zoo/transformers.html#transformer).\n",
"\n",
"Obviously, PaddleNLP is needed to run this notebook, which is easy to install:\n",
"```bash\n",
"pip install setuptools_scm \n",
"pip install --upgrade paddlenlp\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[32m[2021-11-04 16:50:48,431] [ INFO]\u001b[0m - Already cached /root/.paddlenlp/models/bert-base-uncased/bert-base-uncased.pdparams\u001b[0m\n",
"W1104 16:50:48.433992 22865 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.2, Runtime API Version: 10.2\n",
"W1104 16:50:48.439213 22865 device_context.cc:465] device: 0, cuDNN Version: 7.6.\n",
"\u001b[32m[2021-11-04 16:50:58,691] [ INFO]\u001b[0m - Already cached /root/.paddlenlp/models/bert-base-uncased/bert-base-uncased-vocab.txt\u001b[0m\n",
"INFO:paddle.utils.download:unique_endpoints {'10.255.126.17:35174'}\n"
]
}
],
"source": [
"import paddle\n",
"import paddlenlp\n",
"from paddlenlp.transformers import BertForSequenceClassification, BertTokenizer\n",
"\n",
"MODEL_NAME = \"bert-base-uncased\"\n",
"model = BertForSequenceClassification.from_pretrained(MODEL_NAME, num_classes=2)\n",
"tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)\n",
"\n",
"from paddlenlp.datasets import load_dataset\n",
"train_ds, dev_ds, test_ds = load_dataset(\n",
" \"glue\", name='sst-2', splits=[\"train\", \"dev\", \"test\"]\n",
")"
]
},
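{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an optional sanity check, the next cell peeks at one raw training example and the split sizes. This is a minimal sketch; the exact fields inside each example come from PaddleNLP's GLUE loader and are worth verifying on your installed version."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional sanity check: inspect one raw SST-2 example and the split sizes.\n",
"# Each example is a dict holding the review text and an integer label (0/1).\n",
"print(train_ds[0])\n",
"print('train/dev/test sizes:', len(train_ds), len(dev_ds), len(test_ds))"
]
},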
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Prepare the Model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train the model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# training the model and save to save_dir\n",
"# only needs to run once.\n",
"# total steps ~2100 (1 epoch)\n",
"\n",
"from assets.utils import training_model\n",
"training_model(model, tokenizer, train_ds, dev_ds, save_dir=f'assets/sst-2-{MODEL_NAME}')\n",
"\n",
"# global step 2100, epoch: 1, batch: 2100, loss: 0.22977, acc: 0.91710\n",
"# eval loss: 0.20062, accu: 0.91972"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Or Load the trained model"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# Load the trained model.\n",
"state_dict = paddle.load(f'assets/sst-2-{MODEL_NAME}/model_state.pdparams')\n",
"model.set_dict(state_dict)"
]
},
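{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before predicting or interpreting, it is safest to switch the model to evaluation mode so dropout is disabled (the helper functions used below may already do this internally; this cell is just a precaution):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Disable dropout for deterministic inference.\n",
"model.eval()"
]
},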
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# See the prediction results"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Data: {'text': \"it 's a charming and often affecting journey . \"} \t Lable: positive\n",
"Data: {'text': 'the movie achieves as great an impact by keeping these thoughts hidden as ... ( quills ) did by showing them . '} \t Lable: positive\n",
"Data: {'text': 'this one is definitely one to skip , even for horror movie fanatics . '} \t Lable: positive\n",
"Data: {'text': 'in its best moments , resembles a bad high school production of grease , without benefit of song . '} \t Lable: negative\n"
]
}
],
"source": [
"from assets.utils import predict\n",
"\n",
"reviews = [\n",
" \"it 's a charming and often affecting journey . \",\n",
" 'the movie achieves as great an impact by keeping these thoughts hidden as ... ( quills ) did by showing them . ',\n",
" 'this one is definitely one to skip , even for horror movie fanatics . ',\n",
" 'in its best moments , resembles a bad high school production of grease , without benefit of song . '\n",
"]\n",
"\n",
"data = [ {\"text\": r} for r in reviews]\n",
"\n",
"label_map = {0: 'negative', 1: 'positive'}\n",
"batch_size = 32\n",
"\n",
"results = predict(\n",
" model, data, tokenizer, label_map, batch_size=batch_size)\n",
"\n",
"for idx, text in enumerate(data):\n",
" print('Data: {} \\t Lable: {}'.format(text, results[idx]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Prepare for Interpretations"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import interpretdl as it\n",
"import numpy as np\n",
"from assets.utils import convert_example, aggregate_subwords_and_importances\n",
"from paddlenlp.data import Stack, Tuple, Pad\n",
"from interpretdl.data_processor.visualizer import VisualizationTextRecord, visualize_text\n",
"\n",
"def preprocess_fn(data):\n",
" examples = []\n",
" \n",
" if not isinstance(data, list):\n",
" data = [data]\n",
" \n",
" for text in data:\n",
" input_ids, segment_ids = convert_example(\n",
" text,\n",
" tokenizer,\n",
" max_seq_length=128,\n",
" is_test=True\n",
" )\n",
" examples.append((input_ids, segment_ids))\n",
"\n",
" batchify_fn = lambda samples, fn=Tuple(\n",
" Pad(axis=0, pad_val=tokenizer.pad_token_id), # input id\n",
" Pad(axis=0, pad_val=tokenizer.pad_token_id), # segment id\n",
" ): fn(samples)\n",
" \n",
" input_ids, segment_ids = batchify_fn(examples)\n",
" return paddle.to_tensor(input_ids, stop_gradient=False), paddle.to_tensor(segment_ids, stop_gradient=False)"
]
},
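{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick, optional check of what `preprocess_fn` produces: both returned tensors are padded to the longest sequence in the batch, and `stop_gradient=False` allows gradients to flow back toward the embeddings during interpretation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Preprocess the four reviews and inspect the padded batch shapes.\n",
"batch_input_ids, batch_segment_ids = preprocess_fn(data)\n",
"print('input_ids shape:', batch_input_ids.shape)  # [batch_size, seq_len]\n",
"print('segment_ids shape:', batch_segment_ids.shape)"
]
},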
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## IG Interpreter"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/html": [
"
| True Label | Predicted Label (Prob) | Target Label | Word Importance |
|---|
| 1 | 1 (1.00) | 1 | it ' s a charming and often affecting journey . |
|
| 1 | 1 (0.99) | 1 | the movie achieves as great an impact by keeping these thoughts hidden as . . . ( quills ) did by showing them . |
|
| 0 | 1 (0.76) | 1 | this one is definitely one to skip , even for horror movie fanatics . |
|
| 0 | 0 (1.00) | 0 | in its best moments , resembles a bad high school production of grease , without benefit of song . |
|
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"ig = it.IntGradNLPInterpreter(model, device='gpu:0')\n",
"\n",
"pred_labels, pred_probs, avg_gradients = ig.interpret(\n",
" preprocess_fn(data),\n",
" steps=50,\n",
" return_pred=True)\n",
"\n",
"true_labels = [1, 1, 0, 0] * 5\n",
"recs = []\n",
"for i in range(avg_gradients.shape[0]):\n",
" subwords = \" \".join(tokenizer._tokenize(data[i]['text'])).split(' ')\n",
" subword_importances = avg_gradients[i]\n",
" words, word_importances = aggregate_subwords_and_importances(subwords, subword_importances)\n",
" word_importances = np.array(word_importances) / np.linalg.norm(\n",
" word_importances)\n",
" \n",
" pred_label = pred_labels[i]\n",
" pred_prob = pred_probs[i, pred_label]\n",
" true_label = true_labels[i]\n",
" interp_class = pred_label\n",
" \n",
" if interp_class == 0:\n",
" word_importances = -word_importances\n",
" recs.append(\n",
" VisualizationTextRecord(words, word_importances, true_label,\n",
" pred_label, pred_prob, interp_class)\n",
" )\n",
"\n",
"visualize_text(recs)\n",
"# The visualization is not available at github"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## LIME Interpreter"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"| True Label | Predicted Label (Prob) | Target Label | Word Importance |
|---|
| 1 | 1 (1.00) | 1 | it ' s a charming and often affecting journey . |
|
| 1 | 1 (0.99) | 1 | the movie achieves as great an impact by keeping these thoughts hidden as . . . ( quills ) did by showing them . |
|
| 0 | 1 (0.82) | 1 | this one is definitely one to skip , even for horror movie fanatics . |
|
| 0 | 0 (1.00) | 0 | in its best moments , resembles a bad high school production of grease , without benefit of song . |
|
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"true_labels = [1, 1, 0, 0] * 5\n",
"recs = []\n",
"\n",
"lime = it.LIMENLPInterpreter(model, device='gpu:0')\n",
"for i, review in enumerate(data):\n",
" pred_class, pred_prob, lime_weights = lime.interpret(\n",
" review,\n",
" preprocess_fn,\n",
" num_samples=1000,\n",
" batch_size=32,\n",
" unk_id=tokenizer.convert_tokens_to_ids('[UNK]'),\n",
" pad_id=tokenizer.convert_tokens_to_ids('[PAD]'),\n",
" return_pred=True)\n",
"\n",
" # subwords\n",
" subwords = \" \".join(tokenizer._tokenize(review['text'])).split(' ')\n",
" interp_class = list(lime_weights.keys())[0]\n",
" weights = lime_weights[interp_class][1 : -1]\n",
" subword_importances = [t[1] for t in lime_weights[interp_class][1 : -1]]\n",
" \n",
" words, word_importances = aggregate_subwords_and_importances(subwords, subword_importances)\n",
" word_importances = np.array(word_importances) / np.linalg.norm(\n",
" word_importances)\n",
" \n",
" true_label = true_labels[i]\n",
" \n",
" if interp_class == 0:\n",
" word_importances = -word_importances\n",
" \n",
" rec = VisualizationTextRecord(\n",
" words, \n",
" word_importances, \n",
" true_label, \n",
" pred_class[0], \n",
" pred_prob[0],\n",
" interp_class\n",
" )\n",
" \n",
" recs.append(rec)\n",
"\n",
"visualize_text(recs)\n",
"# The visualization is not available at github"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## GradShapNLPInterpreter"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"| True Label | Predicted Label (Prob) | Target Label | Word Importance |
|---|
| 1 | 1 (1.00) | 1 | it ' s a charming and often affecting journey . |
|
| 1 | 1 (1.00) | 1 | the movie achieves as great an impact by keeping these thoughts hidden as . . . ( quills ) did by showing them . |
|
| 0 | 1 (0.76) | 1 | this one is definitely one to skip , even for horror movie fanatics . |
|
| 0 | 0 (1.00) | 0 | in its best moments , resembles a bad high school production of grease , without benefit of song . |
|
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"ig = it.GradShapNLPInterpreter(model, device='gpu:0')\n",
"\n",
"pred_labels, pred_probs, avg_gradients = ig.interpret(\n",
" preprocess_fn(data),\n",
" n_samples=10,\n",
" noise_amount=0.1,\n",
" return_pred=True)\n",
"\n",
"true_labels = [1, 1, 0, 0] * 5\n",
"recs = []\n",
"for i in range(avg_gradients.shape[0]):\n",
" subwords = \" \".join(tokenizer._tokenize(data[i]['text'])).split(' ')\n",
" subword_importances = avg_gradients[i]\n",
" words, word_importances = aggregate_subwords_and_importances(subwords, subword_importances)\n",
" word_importances = np.array(word_importances) / np.linalg.norm(\n",
" word_importances)\n",
" \n",
" pred_label = pred_labels[i]\n",
" pred_prob = pred_probs[i, pred_label]\n",
" true_label = true_labels[i]\n",
" interp_class = pred_label\n",
" \n",
" if interp_class == 0:\n",
" word_importances = -word_importances\n",
" recs.append(\n",
" VisualizationTextRecord(words, word_importances, true_label,\n",
" pred_label, pred_prob, interp_class)\n",
" )\n",
"\n",
"visualize_text(recs)\n",
"# The visualization is not available at github"
]
}
],
"metadata": {
"interpreter": {
"hash": "a77f6a7464ce9eaa05196aac170c6d9a8812f0d84e7e92f765cacaf98118350b"
},
"kernelspec": {
"display_name": "Python 3.7.9 64-bit ('paddle2.0': conda)",
"name": "python3"
},
"language_info": {
"name": "python",
"version": ""
}
},
"nbformat": 4,
"nbformat_minor": 5
}