{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "nlp_lime.ipynb",
"provenance": [],
"toc_visible": true
},
"interpreter": {
"hash": "916dbcbb3f70747c44a77c7bcd40155683ae19c65e1c03b4aa3499c5328201f1"
},
"kernel_info": {
"name": "nlp-py38"
},
"kernelspec": {
"display_name": "Python 3.8 - PyTorch",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
},
"nteract": {
"version": "nteract-front-end@1.0.0"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "iKENeWEEwKIb"
},
"source": [
"Explanations for NLP using LIME\n",
"==============================="
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JctlpI10jJ15"
},
"source": [
"Introduction\n",
"------------"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sT_t9OxYwKIc"
},
"source": [
"LIME is a novel explanation technique that explains the predictions of any classifier in an interpretable manner by learning an interpretable model locally around the prediction made by the original model. In other words, LIME builds an (interpretable) model around this prediction in order to understand how the space (manifold) behaves in that confined region. The idea is that, although an interpretable model may approximate the original model poorly over the whole dataset, it can approximate it very well within a bounded region of the space.\n",
"\n",
"\n",
"This method is flexible enough to explain both text models and image classifiers, since it can be extended to these domains (with some caveats).\n",
"\n",
"For a more detailed introduction, see the blog post: [Model interpretability — Making your model confess: LIME](https://santiagof.medium.com/model-interpretability-making-your-model-confess-lime-89db7f70a72b)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "t67TU7HLY1Hg"
},
"source": [
"How does LIME compute its local model?\n",
"--------------------------------------\n",
"\n",
"To learn the local behavior of the model (function *f*), LIME approximates *f* on instances that are close to the instance we want to explain. These instances are generated (sampled) and then weighted according to how far they are from the instance being explained.\n",
"\n",
"For text, these samples are generated by randomly removing words from the original observation. If the model relied on a given word, removing it should cause a noticeable drop in the probability the model assigns to the predicted class. In addition, as mentioned above, this loss is weighted by how similar each sample is to the original observation, so we need a similarity metric. For text, this metric is cosine similarity, which measures the angle between two vectors (here, vectors indicating which words of the original observation are kept).\n",
"\n",
"> Note that this similarity metric is computed on the text itself (which words are kept or removed), so the internal representations a model learns, such as BERT embeddings, play no role in it."
]
},
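{
"cell_type": "markdown",
"metadata": {},
"source": [
"The procedure above can be sketched in a few lines of NumPy. This is a minimal illustration under simplifying assumptions, not the implementation used by `lime` or `eli5`: words are split on whitespace (repeated words would collapse into a single entry of the returned dictionary), each perturbed sample is a binary mask in which 1 keeps a word and 0 removes it, each sample is weighted with an exponential kernel on its cosine distance to the original text (the scale factor of 100 and `kernel_width=25.0` are illustrative choices), and the local surrogate is a weighted ridge regression solved in closed form. The names `lime_text_sketch` and `toy_predict_proba` are hypothetical, and the toy classifier, which only looks for the word `carrefour`, stands in for the BERT model. With it, `carrefour` should receive by far the largest weight, while the remaining words should get weights close to zero."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import numpy as np\n",
"\n",
"# Minimal LIME-style sketch for text (illustrative; not the lime or eli5 implementation)\n",
"def lime_text_sketch(text, predict_fn, label, n_samples=500, kernel_width=25.0, alpha=1.0, seed=0):\n",
"    rng = np.random.default_rng(seed)\n",
"    words = text.split()\n",
"    d = len(words)\n",
"    # Binary masks: 1 keeps a word, 0 removes it; row 0 is the unmodified text\n",
"    masks = rng.integers(0, 2, size=(n_samples, d))\n",
"    masks[0, :] = 1\n",
"    samples = [' '.join(w for w, keep in zip(words, m) if keep) for m in masks]\n",
"    # Probability the black-box model assigns to the class being explained\n",
"    y = predict_fn(samples)[:, label]\n",
"    # Cosine similarity between each mask and the all-ones vector (the original text)\n",
"    kept = masks.sum(axis=1)\n",
"    cos_sim = kept / np.maximum(np.sqrt(kept * d), 1e-12)\n",
"    # Exponential kernel on the cosine distance (scale chosen for illustration)\n",
"    distance = (1.0 - cos_sim) * 100\n",
"    weights = np.exp(-(distance ** 2) / kernel_width ** 2)\n",
"    # Weighted ridge regression with an unpenalized intercept, solved in closed form\n",
"    X = np.column_stack([np.ones(n_samples), masks])\n",
"    penalty = alpha * np.diag([0.0] + [1.0] * d)\n",
"    A = X.T @ (X * weights[:, None]) + penalty\n",
"    b = X.T @ (weights * y)\n",
"    coefs = np.linalg.solve(A, b)[1:]\n",
"    return dict(zip(words, coefs.tolist()))\n",
"\n",
"# Hypothetical toy classifier: class 1 is likely whenever 'carrefour' appears\n",
"def toy_predict_proba(texts):\n",
"    p = np.array([0.9 if 'carrefour' in t.split() else 0.1 for t in texts])\n",
"    return np.column_stack([1 - p, p])\n",
"\n",
"print(lime_text_sketch('nos estafaron en carrefour', toy_predict_proba, label=1))"
],
"execution_count": null,
"outputs": []
},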
{
"cell_type": "markdown",
"metadata": {
"id": "Dcyc_TQ6dis7"
},
"source": [
"### Running this notebook"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nJFcNrbmjJ17"
},
"source": [
"To run this notebook, download the dataset and install the following libraries:"
]
},
{
"cell_type": "code",
"metadata": {
"id": "GgK8b6e_jJ17"
},
"source": [
"!wget https://raw.githubusercontent.com/santiagxf/M72109/master/NLP/Datasets/mascorpus/tweets_marketing.csv \\\n",
" --quiet --no-clobber --directory-prefix ./Datasets/mascorpus/"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "sxwUvomeZ9hd"
},
"source": [
"!pip install transformers --quiet\n",
"!pip install lime --quiet\n",
"!pip install eli5 --quiet"
],
"execution_count": 16,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "vr7vZC0cwff0"
},
"source": [
"We will download a model previously trained on the tweet classification problem:"
]
},
{
"cell_type": "code",
"metadata": {
"id": "XgXwXTX2wff1"
},
"source": [
"!wget https://santiagxf.blob.core.windows.net/public/models/tweet_classification_bert.zip --no-clobber --quiet\n",
"!unzip -qq tweet_classification_bert.zip"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "Ntcs1AlpfckX"
},
"source": [
"import warnings\n",
"warnings.filterwarnings('ignore')"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "20Sih7EhY1H0"
},
"source": [
"We load the dataset the model was trained on, in case we need it:"
]
},
{
"cell_type": "code",
"metadata": {
"id": "I8vqJD9JwKIv"
},
"source": [
"import pandas as pd\n",
"\n",
"tweets = pd.read_csv('Datasets/mascorpus/tweets_marketing.csv')"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "HcO4lC5qY1H3"
},
"source": [
"Loading an NLP model\n",
"--------------------"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_gBXNzwYwKIu"
},
"source": [
"We load the previously downloaded model using the `transformers` library. Note that we load both the `tokenizer` and the model itself."
]
},
{
"cell_type": "code",
"metadata": {
"id": "FuUI_dXFxh6u"
},
"source": [
"from transformers import AutoTokenizer, AutoModelForSequenceClassification\n",
"\n",
"model_name = \"tweet_classification_bert\"\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
"model = AutoModelForSequenceClassification.from_pretrained(model_name)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "dArhsrI6Y1H5"
},
"source": [
"Training the LIME model\n",
"-----------------------"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "aQk-Q4D6Y1H5"
},
"source": [
"### Prediction function"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rqn8GMaiY1H5"
},
"source": [
"To use LIME, we need a function that receives text (in practice, a list of perturbed versions of the text to explain) and returns the probability distribution over the labels our model predicts, as an array with one row per text and one column per class. The following function does this:"
]
},
{
"cell_type": "code",
"metadata": {
"id": "KU9ZxTd9yNrf"
},
"source": [
"import torch\n",
"import numpy as np\n",
"from typing import Union, List\n",
"\n",
"@torch.no_grad()  # inference only: no gradients are needed\n",
"def predict_proba(text: Union[str, List[str]]) -> np.ndarray:\n",
" \"\"\"\n",
" Runs the model on arbitrary text and returns the probability distribution over the classes.\n",
"\n",
" Parameters\n",
" ----------\n",
" text: Union[str, List[str]]\n",
" Text (or list of texts) on which to run the model\n",
"\n",
" Returns\n",
" -------\n",
" np.ndarray\n",
" Probability distribution over the model's classes, with shape (n_texts, n_classes).\n",
" \"\"\"\n",
" inputs = tokenizer(text, padding=True, truncation=True, max_length=20, return_tensors='pt')\n",
" predictions = model(**inputs)\n",
" # Turn the logits into probabilities that sum to 1 for each text\n",
" smx = torch.nn.Softmax(dim=1)(predictions.logits)\n",
"\n",
" return smx.numpy()"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "50fz5eanY1H6"
},
"source": [
"Let's check that the function works:"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Bswgkwan7Kn7",
"outputId": "e57e4dd8-50d4-47b4-f0a4-f0f080adad07"
},
"source": [
"predict_proba([\"la casa estaba si vacia claro\"])"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([[0.4471247 , 0.02632283, 0.03702476, 0.08742985, 0.04531398,\n",
" 0.318541 , 0.03824287]], dtype=float32)"
]
},
"metadata": {},
"execution_count": 9
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "juB0uVfFY1H_"
},
"source": [
"Recall that our model predicts the sector the tweet belongs to, which can be one of:"
]
},
{
"cell_type": "code",
"metadata": {
"id": "RMBEop__zFAI"
},
"source": [
"target_names = ['ALIMENTACION', 'AUTOMOCION', 'BANCA', 'BEBIDAS', 'DEPORTES', 'RETAIL', 'TELCO']"
],
"execution_count": null,
"outputs": []
},
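{
"cell_type": "markdown",
"metadata": {},
"source": [
"To read the output of `predict_proba` more easily, each probability can be paired with its sector name. The helper `top_classes` is a hypothetical convenience function, not part of `transformers` or `eli5`; it assumes a single prediction of shape `(n_classes,)` or `(1, n_classes)`. The example reuses the probabilities printed earlier for the text `la casa estaba si vacia claro`, for which the three most likely sectors are ALIMENTACION, RETAIL and BEBIDAS."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import numpy as np\n",
"from typing import List, Tuple\n",
"\n",
"def top_classes(probs: np.ndarray, names: List[str], k: int = 3) -> List[Tuple[str, float]]:\n",
"    # Accept either shape (n_classes,) or (1, n_classes)\n",
"    probs = np.asarray(probs).ravel()\n",
"    # Indices of the k largest probabilities, from most to least likely\n",
"    order = np.argsort(probs)[::-1][:k]\n",
"    return [(names[i], float(probs[i])) for i in order]\n",
"\n",
"# Same list as in the previous cell, repeated so that this cell runs on its own\n",
"target_names = ['ALIMENTACION', 'AUTOMOCION', 'BANCA', 'BEBIDAS', 'DEPORTES', 'RETAIL', 'TELCO']\n",
"# Probabilities printed earlier for 'la casa estaba si vacia claro'\n",
"example = np.array([[0.4471247, 0.02632283, 0.03702476, 0.08742985, 0.04531398, 0.318541, 0.03824287]])\n",
"print(top_classes(example, target_names))"
],
"execution_count": null,
"outputs": []
},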
{
"cell_type": "markdown",
"metadata": {
"id": "uo8ljJ8cY1IB"
},
"source": [
"### Training the local model around the instance to explain"
]
},
{
"cell_type": "code",
"metadata": {
"id": "8GpTjik0zrV_"
},
"source": [
"from eli5.lime import TextExplainer\n",
"\n",
"te = TextExplainer(random_state=42, n_samples=500).fit(\"Nos estafaron en carrefour. No vuelvo a comprar alli jamas\", predict_proba)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "TLOV-TupY1IC"
},
"source": [
"Let's look at the explanations:"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"id": "9abSfxwd1M7B",
"outputId": "f1f6fbba-b1c9-4a93-d48c-689199c45ae4"
},
"source": [
"te.show_prediction(target_names=target_names)"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"<div>\n",
"<p><b>y=ALIMENTACION</b> (probability 0.016, score -4.122) top features</p>\n",
"<table><tr><th>Contribution</th><th>Feature</th></tr>\n",
"<tr><td>-0.134</td><td>&lt;BIAS&gt;</td></tr>\n",
"<tr><td>-3.988</td><td>Highlighted in text (sum)</td></tr></table>\n",
"<p>nos estafaron en carrefour. no vuelvo a comprar alli jamas</p>\n",
"<p><b>y=AUTOMOCION</b> (probability 0.023, score -3.706) top features</p>\n",
"<table><tr><th>Contribution</th><th>Feature</th></tr>\n",
"<tr><td>-0.578</td><td>&lt;BIAS&gt;</td></tr>\n",
"<tr><td>-3.128</td><td>Highlighted in text (sum)</td></tr></table>\n",
"<p>nos estafaron en carrefour. no vuelvo a comprar alli jamas</p>\n",
"<p><b>y=BANCA</b> (probability 0.038, score -3.213) top features</p>\n",
"<table><tr><th>Contribution</th><th>Feature</th></tr>\n",
"<tr><td>-0.564</td><td>&lt;BIAS&gt;</td></tr>\n",
"<tr><td>-2.649</td><td>Highlighted in text (sum)</td></tr></table>\n",
"<p>nos estafaron en carrefour. no vuelvo a comprar alli jamas</p>\n",
"<p><b>y=BEBIDAS</b> (probability 0.013, score -4.269) top features</p>\n",
"<table><tr><th>Contribution</th><th>Feature</th></tr>\n",
"<tr><td>-0.534</td><td>&lt;BIAS&gt;</td></tr>\n",
"<tr><td>-3.735</td><td>Highlighted in text (sum)</td></tr></table>\n",
"<p>nos estafaron en carrefour. no vuelvo a comprar alli jamas</p>\n",
"<p><b>y=DEPORTES</b> (probability 0.015, score -4.166) top features</p>\n",
"<table><tr><th>Contribution</th><th>Feature</th></tr>\n",
"<tr><td>-0.599</td><td>&lt;BIAS&gt;</td></tr>\n",
"<tr><td>-3.567</td><td>Highlighted in text (sum)</td></tr></table>\n",
"<p>nos estafaron en carrefour. no vuelvo a comprar alli jamas</p>\n",
"<p><b>y=RETAIL</b> (probability 0.880, score 2.249) top features</p>\n",
"<table><tr><th>Contribution</th><th>Feature</th></tr>\n",
"<tr><td>+2.635</td><td>Highlighted in text (sum)</td></tr>\n",
"<tr><td>-0.386</td><td>&lt;BIAS&gt;</td></tr></table>\n",
"<p>nos estafaron en carrefour. no vuelvo a comprar alli jamas</p>\n",
"<p><b>y=TELCO</b> (probability 0.015, score -4.165) top features</p>\n",
"<table><tr><th>Contribution</th><th>Feature</th></tr>\n",
"<tr><td>-0.624</td><td>&lt;BIAS&gt;</td></tr>\n",
"<tr><td>-3.541</td><td>Highlighted in text (sum)</td></tr></table>\n",
"<p>nos estafaron en carrefour. no vuelvo a comprar alli jamas</p>\n",
"</div>\n"
], "text/plain": [ "