{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Load the Pretrained Model and the dataset\n", "We use bert-base-uncased as the model and SST-2 as the dataset for example. More models can be found in [PaddleNLP Model Zoo](https://paddlenlp.readthedocs.io/zh/latest/model_zoo/transformers.html#transformer).\n", "\n", "Obviously, PaddleNLP is needed to run this notebook, which is easy to install:\n", "```bash\n", "pip install setuptools_scm \n", "pip install --upgrade paddlenlp\n", "```" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\u001b[32m[2021-11-04 16:50:48,431] [ INFO]\u001b[0m - Already cached /root/.paddlenlp/models/bert-base-uncased/bert-base-uncased.pdparams\u001b[0m\n", "W1104 16:50:48.433992 22865 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.2, Runtime API Version: 10.2\n", "W1104 16:50:48.439213 22865 device_context.cc:465] device: 0, cuDNN Version: 7.6.\n", "\u001b[32m[2021-11-04 16:50:58,691] [ INFO]\u001b[0m - Already cached /root/.paddlenlp/models/bert-base-uncased/bert-base-uncased-vocab.txt\u001b[0m\n", "INFO:paddle.utils.download:unique_endpoints {'10.255.126.17:35174'}\n" ] } ], "source": [ "import paddle\n", "import paddlenlp\n", "from paddlenlp.transformers import BertForSequenceClassification, BertTokenizer\n", "\n", "MODEL_NAME = \"bert-base-uncased\"\n", "model = BertForSequenceClassification.from_pretrained(MODEL_NAME, num_classes=2)\n", "tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)\n", "\n", "from paddlenlp.datasets import load_dataset\n", "train_ds, dev_ds, test_ds = load_dataset(\n", " \"glue\", name='sst-2', splits=[\"train\", \"dev\", \"test\"]\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Prepare the Model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Train the model" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# training the model and save to save_dir\n", "# only needs to run once.\n", "# total steps ~2100 (1 epoch)\n", "\n", "from assets.utils import training_model\n", "training_model(model, tokenizer, train_ds, dev_ds, save_dir=f'assets/sst-2-{MODEL_NAME}')\n", "\n", "# global step 2100, epoch: 1, batch: 2100, loss: 0.22977, acc: 0.91710\n", "# eval loss: 0.20062, accu: 0.91972" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Or Load the trained model" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "tags": [] }, "outputs": [], "source": [ "# Load the trained model.\n", "state_dict = paddle.load(f'assets/sst-2-{MODEL_NAME}/model_state.pdparams')\n", "model.set_dict(state_dict)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# See the prediction results" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Data: {'text': \"it 's a charming and often affecting journey . \"} \t Lable: positive\n", "Data: {'text': 'the movie achieves as great an impact by keeping these thoughts hidden as ... ( quills ) did by showing them . '} \t Lable: positive\n", "Data: {'text': 'this one is definitely one to skip , even for horror movie fanatics . '} \t Lable: positive\n", "Data: {'text': 'in its best moments , resembles a bad high school production of grease , without benefit of song . 
"source": [ "from assets.utils import predict\n", "\n", "reviews = [\n", "    \"it 's a charming and often affecting journey . \",\n", "    'the movie achieves as great an impact by keeping these thoughts hidden as ... ( quills ) did by showing them . ',\n", "    'this one is definitely one to skip , even for horror movie fanatics . ',\n", "    'in its best moments , resembles a bad high school production of grease , without benefit of song . '\n", "]\n", "\n", "data = [{\"text\": r} for r in reviews]\n", "\n", "label_map = {0: 'negative', 1: 'positive'}\n", "batch_size = 32\n", "\n", "results = predict(\n", "    model, data, tokenizer, label_map, batch_size=batch_size)\n", "\n", "for idx, text in enumerate(data):\n", "    print('Data: {} \\t Label: {}'.format(text, results[idx]))" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "# Prepare for Interpretations" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import interpretdl as it\n", "import numpy as np\n", "from assets.utils import convert_example, aggregate_subwords_and_importances\n", "from paddlenlp.data import Tuple, Pad\n", "from interpretdl.data_processor.visualizer import VisualizationTextRecord, visualize_text\n", "\n", "def preprocess_fn(data):\n", "    examples = []\n", "\n", "    if not isinstance(data, list):\n", "        data = [data]\n", "\n", "    # Convert each raw text into token ids and segment ids.\n", "    for text in data:\n", "        input_ids, segment_ids = convert_example(\n", "            text,\n", "            tokenizer,\n", "            max_seq_length=128,\n", "            is_test=True\n", "        )\n", "        examples.append((input_ids, segment_ids))\n", "\n", "    # Pad each field of the batch to a common length.\n", "    batchify_fn = lambda samples, fn=Tuple(\n", "        Pad(axis=0, pad_val=tokenizer.pad_token_id),       # input ids\n", "        Pad(axis=0, pad_val=tokenizer.pad_token_type_id),  # segment ids\n", "    ): fn(samples)\n", "\n", "    input_ids, segment_ids = batchify_fn(examples)\n", "    return paddle.to_tensor(input_ids, stop_gradient=False), paddle.to_tensor(segment_ids, stop_gradient=False)" ] },
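{ "cell_type": "markdown", "metadata": {}, "source": [ "The interpreters below score every *subword*; `aggregate_subwords_and_importances` from `assets.utils` is then used to merge those scores back into per-word scores for display. The next cell is a minimal sketch of that idea, under the assumption that WordPiece continuation pieces start with `##` and that their scores are summed into the preceding word; it is an illustration, not the `assets.utils` implementation." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# A minimal sketch (an assumption, not the assets.utils implementation):\n", "# merge WordPiece pieces back into words by summing the scores of\n", "# '##'-prefixed continuation pieces into the preceding word.\n", "def toy_aggregate(subwords, importances):\n", "    words, word_scores = [], []\n", "    for piece, score in zip(subwords, importances):\n", "        if piece.startswith('##') and words:\n", "            words[-1] += piece[2:]\n", "            word_scores[-1] += score\n", "        else:\n", "            words.append(piece)\n", "            word_scores.append(score)\n", "    return words, word_scores\n", "\n", "toy_aggregate(['fan', '##atic', '##s'], [0.20, 0.10, 0.05])\n", "# expected: (['fanatics'], [0.35])" ] },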
{ "cell_type": "markdown", "metadata": {}, "source": [ "## IG Interpreter" ] },
{ "cell_type": "code", "execution_count": 9, "metadata": { "tags": [] }, "outputs": [ { "data": { "text/html": [ "<table><tr><th>True Label</th><th>Predicted Label (Prob)</th><th>Target Label</th><th>Word Importance</th></tr><tr><td>1</td><td>1 (1.00)</td><td>1</td><td>it ' s a charming and often affecting journey .</td></tr><tr><td>1</td><td>1 (0.99)</td><td>1</td><td>the movie achieves as great an impact by keeping these thoughts hidden as . . . ( quills ) did by showing them .</td></tr><tr><td>0</td><td>1 (0.76)</td><td>1</td><td>this one is definitely one to skip , even for horror movie fanatics .</td></tr><tr><td>0</td><td>0 (1.00)</td><td>0</td><td>in its best moments , resembles a bad high school production of grease , without benefit of song .</td></tr></table>" ],
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "ig = it.IntGradNLPInterpreter(model, device='gpu:0')\n", "\n", "pred_labels, pred_probs, avg_gradients = ig.interpret(\n", " preprocess_fn(data),\n", " steps=50,\n", " return_pred=True)\n", "\n", "true_labels = [1, 1, 0, 0] * 5\n", "recs = []\n", "for i in range(avg_gradients.shape[0]):\n", " subwords = \" \".join(tokenizer._tokenize(data[i]['text'])).split(' ')\n", " subword_importances = avg_gradients[i]\n", " words, word_importances = aggregate_subwords_and_importances(subwords, subword_importances)\n", " word_importances = np.array(word_importances) / np.linalg.norm(\n", " word_importances)\n", " \n", " pred_label = pred_labels[i]\n", " pred_prob = pred_probs[i, pred_label]\n", " true_label = true_labels[i]\n", " interp_class = pred_label\n", " \n", " if interp_class == 0:\n", " word_importances = -word_importances\n", " recs.append(\n", " VisualizationTextRecord(words, word_importances, true_label,\n", " pred_label, pred_prob, interp_class)\n", " )\n", "\n", "visualize_text(recs)\n", "# The visualization is not available at github" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## LIME Interpreter" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
{ "cell_type": "markdown", "metadata": {}, "source": [ "## LIME Interpreter" ] },
{ "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<table><tr><th>True Label</th><th>Predicted Label (Prob)</th><th>Target Label</th><th>Word Importance</th></tr><tr><td>1</td><td>1 (1.00)</td><td>1</td><td>it ' s a charming and often affecting journey .</td></tr><tr><td>1</td><td>1 (0.99)</td><td>1</td><td>the movie achieves as great an impact by keeping these thoughts hidden as . . . ( quills ) did by showing them .</td></tr><tr><td>0</td><td>1 (0.82)</td><td>1</td><td>this one is definitely one to skip , even for horror movie fanatics .</td></tr><tr><td>0</td><td>0 (1.00)</td><td>0</td><td>in its best moments , resembles a bad high school production of grease , without benefit of song .</td></tr></table>" ],
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "true_labels = [1, 1, 0, 0] * 5\n", "recs = []\n", "\n", "lime = it.LIMENLPInterpreter(model, device='gpu:0')\n", "for i, review in enumerate(data):\n", " pred_class, pred_prob, lime_weights = lime.interpret(\n", " review,\n", " preprocess_fn,\n", " num_samples=1000,\n", " batch_size=32,\n", " unk_id=tokenizer.convert_tokens_to_ids('[UNK]'),\n", " pad_id=tokenizer.convert_tokens_to_ids('[PAD]'),\n", " return_pred=True)\n", "\n", " # subwords\n", " subwords = \" \".join(tokenizer._tokenize(review['text'])).split(' ')\n", " interp_class = list(lime_weights.keys())[0]\n", " weights = lime_weights[interp_class][1 : -1]\n", " subword_importances = [t[1] for t in lime_weights[interp_class][1 : -1]]\n", " \n", " words, word_importances = aggregate_subwords_and_importances(subwords, subword_importances)\n", " word_importances = np.array(word_importances) / np.linalg.norm(\n", " word_importances)\n", " \n", " true_label = true_labels[i]\n", " \n", " if interp_class == 0:\n", " word_importances = -word_importances\n", " \n", " rec = VisualizationTextRecord(\n", " words, \n", " word_importances, \n", " true_label, \n", " pred_class[0], \n", " pred_prob[0],\n", " interp_class\n", " )\n", " \n", " recs.append(rec)\n", "\n", "visualize_text(recs)\n", "# The visualization is not available at github" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## GradShapNLPInterpreter" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
{ "cell_type": "markdown", "metadata": {}, "source": [ "## GradShapNLPInterpreter" ] },
{ "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<table><tr><th>True Label</th><th>Predicted Label (Prob)</th><th>Target Label</th><th>Word Importance</th></tr><tr><td>1</td><td>1 (1.00)</td><td>1</td><td>it ' s a charming and often affecting journey .</td></tr><tr><td>1</td><td>1 (1.00)</td><td>1</td><td>the movie achieves as great an impact by keeping these thoughts hidden as . . . ( quills ) did by showing them .</td></tr><tr><td>0</td><td>1 (0.76)</td><td>1</td><td>this one is definitely one to skip , even for horror movie fanatics .</td></tr><tr><td>0</td><td>0 (1.00)</td><td>0</td><td>in its best moments , resembles a bad high school production of grease , without benefit of song .</td></tr></table>" ],
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "ig = it.GradShapNLPInterpreter(model, device='gpu:0')\n", "\n", "pred_labels, pred_probs, avg_gradients = ig.interpret(\n", " preprocess_fn(data),\n", " n_samples=10,\n", " noise_amount=0.1,\n", " return_pred=True)\n", "\n", "true_labels = [1, 1, 0, 0] * 5\n", "recs = []\n", "for i in range(avg_gradients.shape[0]):\n", " subwords = \" \".join(tokenizer._tokenize(data[i]['text'])).split(' ')\n", " subword_importances = avg_gradients[i]\n", " words, word_importances = aggregate_subwords_and_importances(subwords, subword_importances)\n", " word_importances = np.array(word_importances) / np.linalg.norm(\n", " word_importances)\n", " \n", " pred_label = pred_labels[i]\n", " pred_prob = pred_probs[i, pred_label]\n", " true_label = true_labels[i]\n", " interp_class = pred_label\n", " \n", " if interp_class == 0:\n", " word_importances = -word_importances\n", " recs.append(\n", " VisualizationTextRecord(words, word_importances, true_label,\n", " pred_label, pred_prob, interp_class)\n", " )\n", "\n", "visualize_text(recs)\n", "# The visualization is not available at github" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "interpreter": { "hash": "a77f6a7464ce9eaa05196aac170c6d9a8812f0d84e7e92f765cacaf98118350b" }, "kernelspec": { "display_name": "Python 3.7.9 64-bit ('paddle2.0': conda)", "name": "python3" }, "language_info": { "name": "python", "version": "" } }, "nbformat": 4, "nbformat_minor": 5 }