{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import pickle\n",
"\n",
"from dataset import NewsDataset\n",
"from model import DistilBertForSequenceClassification\n",
"\n",
"from smooth_gradient import SmoothGradient\n",
"from integrated_gradient import IntegratedGradient\n",
"\n",
"import torch\n",
"from torch import nn\n",
"from torch.utils.data import DataLoader\n",
"from transformers import DistilBertConfig, DistilBertTokenizer\n",
"\n",
"from IPython.display import display, HTML"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"config = DistilBertConfig()\n",
"tokenizer = DistilBertTokenizer.from_pretrained(\"distilbert-base-uncased\")\n",
"distilbert = DistilBertForSequenceClassification(config, num_labels=93)\n",
"\n",
"criterion = nn.CrossEntropyLoss()\n",
"\n",
"batch_size = 1\n",
"\n",
"path = '/media/vitaliy/9C690A1791D68B8B/after/learningfolder/distilbert_medium_titles/distilbert.pth'\n",
"\n",
"if torch.cuda.is_available():\n",
" distilbert.load_state_dict(\n",
" torch.load(path)\n",
" )\n",
"else:\n",
" distilbert.load_state_dict(\n",
" torch.load(path, map_location=torch.device('cpu'))\n",
" )\n",
" \n",
"with open('../label_encoder.sklrn', 'rb') as f:\n",
" le = pickle.load(f)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"test_example = [\n",
" [\"Interpretation of HuggingFase's model decision\"], \n",
" [\"Transformer-based models have taken a leading role \"\n",
" \"in NLP today.\"]\n",
"]\n",
"\n",
"test_dataset = NewsDataset(\n",
" data_list=test_example,\n",
" tokenizer=tokenizer,\n",
" max_length=config.max_position_embeddings, \n",
")\n",
"\n",
"test_dataloader = DataLoader(\n",
" test_dataset,\n",
" batch_size=batch_size,\n",
" shuffle=False,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"integrated_grad = IntegratedGradient(\n",
" distilbert, \n",
" criterion, \n",
" tokenizer, \n",
" show_progress=False,\n",
" encoder=\"bert\"\n",
")\n",
"instances = integrated_grad.saliency_interpret(test_dataloader)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using bos_token, but it is not set yet.\n"
]
},
{
"data": {
"text/html": [
" [CLS] interpretation of huggingfase ' s model decision [SEP] transformer - based models have taken a leading role in nlp today . [SEP] Label: 44 |47.84%|"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"coloder_string = integrated_grad.colorize(instances[0])\n",
"display(HTML(coloder_string))"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Converted label #44: machine-learning\n"
]
}
],
"source": [
"label = instances[0]['label']\n",
"print(f\"Converted label #{label}: {le.inverse_transform([label])[0]}\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "sqrt",
"language": "python",
"name": "sqrt"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 4
}