{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "c2231055-6a63-4425-9248-c5aae455396e",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "source": [
    "<figure>\n",
    "<center>\n",
    "<img src=\"../Imagenes/logo_final.png\"  align=\"left\"/> \n",
    "</center>   \n",
    "</figure>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f4ce5373-7375-4e11-872a-d58813d66c24",
   "metadata": {},
   "source": [
    "# <span style=\"color:red\"><center>Spanish BERT</center></span>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "30115409-9d8e-45df-a9cd-072858c02397",
   "metadata": {},
   "source": [
    "<center>Exploring pre-trained models</center>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "23484740-ab13-447a-bc3f-01686f528f49",
   "metadata": {},
   "source": [
    "## <span style=\"color:blue\">Authors</span>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "405456a4-e556-41d2-bf60-593a5055f8b4",
   "metadata": {
    "tags": []
   },
   "source": [
    "1. Alvaro Mauricio Montenegro Díaz, ammontenegrod@unal.edu.co\n",
    "2. Daniel Mauricio Montenegro Reyes, dextronomo@gmail.com "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c7aaec87-e77d-4152-bb83-bc9e62d4a94c",
   "metadata": {},
   "source": [
    "## <span style=\"color:blue\">Graphic design and digital marketing</span>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aac3be8e-e519-43bf-9bd9-9903a27b570b",
   "metadata": {},
   "source": [
    "1. Maria del Pilar Montenegro Reyes, pmontenegro88@gmail.com "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "80e4663a-f0f2-42d8-a044-6e3ab6d16e3e",
   "metadata": {},
   "source": [
    "## <span style=\"color:blue\">Assistants</span>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7b038974-8718-4142-a66e-895520f68ca7",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "17e9be15-8ae8-4715-b57f-a6e8a8111c0c",
   "metadata": {},
   "source": [
    "## <span style=\"color:blue\">Contents</span>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1a382773-c33f-4247-b44c-6e3806b9af20",
   "metadata": {},
   "source": [
    "* [Introduction](#Introduction)\n",
    "* [Extracting embeddings from a pre-trained BERT model](#Extracting-embeddings-from-a-pre-trained-BERT-model)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c3c46451-72e1-4588-b271-855af1898df1",
   "metadata": {},
   "source": [
    "## <span style=\"color:blue\">Introduction</span>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "192cd6c6-0240-4bd0-b2e6-427aff064ebb",
   "metadata": {},
   "source": [
    "We will use the HuggingFace implementation in PyTorch.\n",
    "\n",
    "+ BERT is a model with absolute position embeddings, so it is generally recommended to pad the inputs on the right rather than on the left.\n",
    "\n",
    "+ BERT was trained with the masked language modeling (MLM) and next sentence prediction (NSP) objectives. It is efficient at predicting masked tokens and at NLU in general, but it is not optimal for text generation."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b86ab415-c6d7-4d89-bbae-098b30f95578",
   "metadata": {},
   "source": [
    "## <span style=\"color:blue\">Extracting embeddings from a pre-trained BERT model</span>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2cd69e7e-8b83-4a1b-ad59-8d5f01381baa",
   "metadata": {},
   "source": [
    "The NLP task is sentiment analysis. We run the first experiment in Spanish. For this task it is fine to use the uncased model, which lowercases the text."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "d93f08d4-ed19-46ef-88cb-f79056f932de",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import BertModel, BertTokenizer\n",
    "import torch\n"
   ]
  },
  {
   "cell_type": "raw",
   "id": "090081a2-b0e4-4345-aff4-b6cd79b6eb3c",
   "metadata": {},
   "source": [
    "## TensorFlow\n",
    "from transformers import TFBertModel, BertTokenizer\n",
    "import tensorflow as tf"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e483a0f6-52b5-439c-96fc-ba7bc99cc4e4",
   "metadata": {},
   "source": [
    "### Load the pre-trained model 'dccuchile/bert-base-spanish-wwm-uncased' and its tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "b0d6e7aa-c090-4cff-80ce-f85de557c3cc",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at dccuchile/bert-base-spanish-wwm-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.bias']\n",
      "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of BertModel were not initialized from the model checkpoint at dccuchile/bert-base-spanish-wwm-uncased and are newly initialized: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "model = BertModel.from_pretrained('dccuchile/bert-base-spanish-wwm-uncased')\n",
    "\n",
    "tokenizer = BertTokenizer.from_pretrained('dccuchile/bert-base-spanish-wwm-uncased')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f85fb650-c2c7-4287-aa32-fa84acb2c658",
   "metadata": {},
   "source": [
    "### Preprocessing the input"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "40e00063-d712-494c-9bc5-39cf1bd332bc",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['yo', 'amo', 'a', 'bogotá']\n"
     ]
    }
   ],
   "source": [
    "# Input sentence\n",
    "sentence = 'Yo amo a Bogotá'\n",
    "\n",
    "# Tokenization\n",
    "tokens = tokenizer.tokenize(sentence)\n",
    "\n",
    "# Show the resulting tokens\n",
    "print(tokens)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "77d5985b-8bfd-4fa0-8255-0dce67066768",
   "metadata": {},
   "source": [
    "### Add the [CLS] token at the beginning and [SEP] at the end of the token list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "f48fc4af-645a-4dc2-a73d-07c25f3623a0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['[CLS]', 'yo', 'amo', 'a', 'bogotá', '[SEP]']\n"
     ]
    }
   ],
   "source": [
    "tokens = ['[CLS]'] + tokens + ['[SEP]']\n",
    "\n",
    "print(tokens)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "396b2885-baed-4daf-8504-da040e19e927",
   "metadata": {},
   "source": [
    "### Padding and attention mask for the sentence"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "df8ec922-bb09-4452-b387-9a44e2a6cdc5",
   "metadata": {},
   "source": [
    "The token list has length 6. Suppose we have decided that the maximum sentence length will be 7. BERT is built to accept sequences of up to 512 tokens. All sentences in a batch must have the same length."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "e144cfa0-e684-4f55-8fc7-d30317821fae",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['[CLS]', 'yo', 'amo', 'a', 'bogotá', '[SEP]', '[PAD]']\n"
     ]
    }
   ],
   "source": [
    "# Padding\n",
    "max_sentence_size = 7\n",
    "pad_size = max_sentence_size - len(tokens)\n",
    "\n",
    "# Append one [PAD] token per missing position (right padding)\n",
    "tokens = tokens + ['[PAD]'] * pad_size\n",
    "\n",
    "print(tokens)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "315fd135-f97a-4453-a658-8972d56d5c75",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1, 1, 1, 1, 1, 1, 0]\n"
     ]
    }
   ],
   "source": [
    "# Attention mask: 1 for real tokens, 0 for padding\n",
    "attention_mask = [1 if token != '[PAD]' else 0 for token in tokens]\n",
    "print(attention_mask)"
   ]
  },
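  {
   "cell_type": "markdown",
   "id": "3b1f7c20-9a4e-4d6b-8e52-1c0a5f7d9e31",
   "metadata": {},
   "source": [
    "The padding and masking steps can also be wrapped in one small helper. The cell below is a minimal sketch in plain Python, not part of the `transformers` API: the name `pad_and_mask`, its arguments and the `example_` variables are assumptions made here for illustration, and they leave the notebook's own `tokens` and `attention_mask` untouched."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d2e4a91-0c3b-4f58-b6a7-5e9f1d8c2a40",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical helper: right-pad a token list and build its attention mask\n",
    "def pad_and_mask(tokens, max_len, pad_token='[PAD]'):\n",
    "    if len(tokens) > max_len:\n",
    "        raise ValueError('the token list is longer than max_len')\n",
    "    padded = tokens + [pad_token] * (max_len - len(tokens))\n",
    "    # 1 marks a real token, 0 marks padding that attention should ignore\n",
    "    mask = [0 if tok == pad_token else 1 for tok in padded]\n",
    "    return padded, mask\n",
    "\n",
    "\n",
    "example_tokens = ['[CLS]', 'yo', 'amo', 'a', 'bogotá', '[SEP]']\n",
    "example_padded, example_mask = pad_and_mask(example_tokens, max_len=7)\n",
    "print(example_padded)\n",
    "print(example_mask)"
   ]
  },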
  {
   "cell_type": "markdown",
   "id": "7ad09897-d398-4f44-8431-a554a2d83972",
   "metadata": {},
   "source": [
    "### Convert the list of tokens into the list of token IDs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "6fbea2a3-320d-4dd7-8521-d1af470d3d45",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[4, 1252, 4017, 1012, 14548, 5, 1]\n"
     ]
    }
   ],
   "source": [
    "token_ids = tokenizer.convert_tokens_to_ids(tokens)\n",
    "print(token_ids)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f15280fd-00da-45ca-a450-a7a3ad1d69f5",
   "metadata": {},
   "source": [
    "### Convert token_ids and attention_mask to tensors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "6e6e7bd2-4899-4cb9-9d8c-5b9c8d8ed639",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[    4,  1252,  4017,  1012, 14548,     5,     1]])\n"
     ]
    }
   ],
   "source": [
    "token_ids = torch.tensor(token_ids).unsqueeze(0)  # unsqueeze adds a leading batch dimension (one row per sentence)\n",
    "attention_mask = torch.tensor(attention_mask).unsqueeze(0)\n",
    "print(token_ids)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "675844ed-0528-44c2-86eb-b3f6c59742df",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[    4,  1252,  4017,  1012, 14548,     5,     1]])\n",
      "tensor([[1, 1, 1, 1, 1, 1, 0]])\n"
     ]
    }
   ],
   "source": [
    "print(token_ids)\n",
    "print(attention_mask)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "86126d72-8263-4d14-93b2-37d269bc10fb",
   "metadata": {},
   "source": [
    "### Question"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b5c9ca04-b524-4655-aa73-899edad4d1f2",
   "metadata": {},
   "source": [
    "How would you do this with TensorFlow?"
   ]
  },
  {
   "cell_type": "raw",
   "id": "a9a18715-4f3b-4759-b455-1c6339eb17e8",
   "metadata": {},
   "source": [
    "# Write your answer here\n",
    "import tensorflow as tf\n",
    "token_ids = tf.expand_dims(tf.constant(token_ids), axis=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a587dc32-086e-4b90-aaa5-e786350cd427",
   "metadata": {},
   "source": [
    "### Extracting the final embedding"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eb47b7a9-de74-4348-8792-8e27c6725df3",
   "metadata": {},
   "source": [
    "The model returns an output object with two fields:\n",
    "\n",
    "* The first, *last_hidden_state*, holds the representation of every token taken only from the final encoder layer (encoder 12).\n",
    "* The second, *pooler_output*, is the representation of the [CLS] token from the final encoder layer, further processed by a linear layer and a *tanh* activation. In the original BERT this linear layer is trained through the NSP (next sentence prediction) task. For this checkpoint, the loading warning lists `bert.pooler.dense.weight` and `bert.pooler.dense.bias` as newly initialized, so *pooler_output* is not meaningful until the model is fine-tuned.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "0613566b-a5a5-410e-86ec-4a8b021315d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "out = model(token_ids, attention_mask = attention_mask)\n",
    "last_hidden_state, pooler_output = out.last_hidden_state, out.pooler_output # out[0], out[1]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "6d52ab29-9b0d-416e-b0f0-823ec5e74168",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 7, 768])\n"
     ]
    }
   ],
   "source": [
    "print(last_hidden_state.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c4b228c0-26b1-4371-8ca1-d9a4371f2c03",
   "metadata": {},
   "source": [
    "* Batch size = 1\n",
    "* Sequence length = 7\n",
    "* Embedding size = 768"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "e6061579-ecac-4724-a27b-c3d74896cc42",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "odict_keys(['last_hidden_state', 'pooler_output'])"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# out behaves like a dictionary. We can list its keys like this:\n",
    "out.keys()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b1287a8b-9a09-4b2f-8fc2-5c6d4ffe5f77",
   "metadata": {},
   "source": [
    "### Extracting the embeddings from all encoder layers"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d5e43f16-437b-419c-b6e5-61f1ca3f1333",
   "metadata": {},
   "source": [
    "In this section we look at how to extract the embeddings produced by each encoder layer (12 in the base model, for example). This is sometimes done to extract different features of the sentences.\n",
    "\n",
    "For example, in the NER (named entity recognition) task, researchers have taken weighted averages of the embeddings from several layers and thereby improved accuracy.\n",
    "\n",
    "To do this, the pre-trained model must be instantiated with the option *output_hidden_states=True*:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8b14fb0-e508-4164-b9a5-cbae7e9db8ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = BertModel.from_pretrained('dccuchile/bert-base-spanish-wwm-uncased', output_hidden_states=True)\n",
    "tokenizer = BertTokenizer.from_pretrained('dccuchile/bert-base-spanish-wwm-uncased')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "9cd791bc-ec6a-43f3-8deb-e6fa50c7df6b",
   "metadata": {},
   "outputs": [],
   "source": [
    "out = model(token_ids, attention_mask=attention_mask)\n",
    "\n",
    "last_hidden_state, pooler_output, hidden_states = \\\n",
    "        out.last_hidden_state, out.pooler_output, out.hidden_states\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "2916d80b-0f59-48fb-b349-b7c986ab065c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 7, 768])\n",
      "torch.Size([1, 768])\n",
      "13\n"
     ]
    }
   ],
   "source": [
    "print(last_hidden_state.shape)\n",
    "print(pooler_output.shape)\n",
    "print(len(hidden_states))  # a tuple with the input embeddings plus\n",
    "                           # the output of every encoder layer"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dc6d576f-40d0-4a9a-a56d-159b0d01c42f",
   "metadata": {},
   "source": [
    "Note that *hidden_states* has 13 elements. Element 0 holds the embeddings from the input layer; elements 1 to 12 hold the output embeddings of each of the 12 encoder layers.\n",
    "\n",
    "The token representations from the last hidden (encoder) layer can be obtained as follows:\n",
    "\n",
    "* *last_hidden_state[0][0]*: the representation of the first token, that is, *[CLS]*.\n",
    "* *last_hidden_state[0][1]*: the representation of the token *yo*.\n",
    "* *last_hidden_state[0][2]*: the representation of the token *amo*.\n",
    "\n",
    "This is the final contextual representation of the tokens.\n",
    "\n",
    "The embeddings of each layer *i* are obtained with *hidden_states[i]*:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "dc65d589-ec50-468f-b440-10fdf46a4932",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 7, 768])\n",
      "torch.Size([1, 7, 768])\n"
     ]
    }
   ],
   "source": [
    "# Embeddings from the input layer\n",
    "input_embedding = hidden_states[0]\n",
    "print(input_embedding.shape)\n",
    "\n",
    "# Embeddings from encoder layer 11\n",
    "embedding_11 = hidden_states[11]\n",
    "print(embedding_11.shape)\n"
   ]
  },
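  {
   "cell_type": "markdown",
   "id": "a4c8e2f6-1b3d-4579-9e0a-6f2d8b4c1e73",
   "metadata": {},
   "source": [
    "As a rough illustration of the weighted layer averages mentioned for NER, the cell below combines several layers into one embedding per token. It is only a sketch under stated assumptions: random tensors stand in for `out.hidden_states` so the cell runs without the model, and the chosen layers (9 to 12) and their weights are hypothetical, not values taken from any paper."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9e1b7d3-5a2f-4c86-8d40-3f7a9e2b6c15",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Stand-in for out.hidden_states: 13 tensors of shape (batch, seq_len, hidden).\n",
    "# Random values are used so this cell runs without loading the model.\n",
    "torch.manual_seed(0)\n",
    "demo_hidden_states = tuple(torch.randn(1, 7, 768) for _ in range(13))\n",
    "\n",
    "# Hypothetical choice: combine the last four encoder layers with fixed weights\n",
    "layer_ids = [9, 10, 11, 12]\n",
    "layer_weights = torch.tensor([0.1, 0.2, 0.3, 0.4])  # sums to 1\n",
    "\n",
    "stacked = torch.stack([demo_hidden_states[i] for i in layer_ids])  # (4, 1, 7, 768)\n",
    "weighted = (layer_weights.view(-1, 1, 1, 1) * stacked).sum(dim=0)  # (1, 7, 768)\n",
    "print(weighted.shape)"
   ]
  },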
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5555e4d2-9e50-4592-9ae7-e7ccc6df3645",
   "metadata": {},
   "outputs": [],
   "source": [
    "help(out)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "372442e1-0221-4d21-95da-9d0058ea7070",
   "metadata": {},
   "source": [
    "### Retrieving the attention weights"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "176cad80-144e-486a-a805-8dc09569c275",
   "metadata": {},
   "source": [
    "The attention weights, taken after the attention softmax, are used to compute the weighted average in the self-attention heads.\n",
    "They are obtained by passing *output_attentions=True* to the model.\n",
    "\n",
    "+ *attentions* is a tuple. Each element holds the attention weights of one encoder layer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d18b8954-dae6-4721-a0ec-7a0ca3ce5d9e",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = BertModel.from_pretrained('dccuchile/bert-base-spanish-wwm-uncased',\n",
    "                                  output_hidden_states=True, output_attentions=True)\n",
    "tokenizer = BertTokenizer.from_pretrained('dccuchile/bert-base-spanish-wwm-uncased')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "0efd6e94-5101-473f-a9b3-24efa41b12a6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "12\n"
     ]
    }
   ],
   "source": [
    "out = model(token_ids, attention_mask=attention_mask)\n",
    "\n",
    "last_hidden_state, pooler_output, hidden_states, attentions = \\\n",
    "        out.last_hidden_state, out.pooler_output, out.hidden_states, \\\n",
    "        out.attentions\n",
    "print(len(attentions))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "be0b076f-cadd-4c72-96f7-12dc08906dd1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 12, 7, 7])\n"
     ]
    }
   ],
   "source": [
    "print(attentions[11].shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a6f09a46-8acc-46b4-aef3-b92b1fe0c2f9",
   "metadata": {},
   "source": [
    "The output shape reads as follows:\n",
    "\n",
    "- The batch size is 1: a single sentence.\n",
    "- There are 12 attention heads.\n",
    "- The sentence has length 7, so each head yields a 7 x 7 matrix of weights.\n",
    "\n",
    "So we have the output of the 12 attention heads for the sentence.\n",
    "\n",
    "Let's take a look at the attention weights of the last encoder layer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "58fabdd1-b720-4734-8acc-f10010d37a62",
   "metadata": {},
   "outputs": [],
   "source": [
    "attention11 = attentions[11].squeeze(0)  # drop the batch dimension"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "c414704d-b8f9-4271-93db-721319905899",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([12, 7, 7])"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "attention11.shape"
   ]
  },
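  {
   "cell_type": "markdown",
   "id": "e5f0a3b8-2c6d-4e17-9a84-0b1d7c3f5a29",
   "metadata": {},
   "source": [
    "Because the weights come out of a softmax, each row of a head's 7 x 7 matrix should be a probability distribution over the tokens being attended to. The cell below is a minimal sketch of that property, assuming random scores in place of the model; the `demo_` names are introduced here for illustration only. The same row-sum check could be applied to `attention11`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6a2d4c0-8e1b-4937-b5c6-2d9e0f4a7b38",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Stand-in for the raw scores of one layer: 12 heads, each a 7 x 7 matrix\n",
    "torch.manual_seed(0)\n",
    "demo_scores = torch.randn(12, 7, 7)\n",
    "\n",
    "# Softmax over the last axis turns each row of scores into attention weights\n",
    "demo_attention = torch.softmax(demo_scores, dim=-1)\n",
    "\n",
    "# Every row (one per attending token) should add up to 1\n",
    "row_sums = demo_attention.sum(dim=-1)  # shape (12, 7)\n",
    "print(row_sums.shape)\n",
    "print(torch.allclose(row_sums, torch.ones(12, 7)))"
   ]
  },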
  {
   "cell_type": "markdown",
   "id": "cf6013b9-45bc-4d5f-b1cc-aa04837e9450",
   "metadata": {},
   "source": [
    "### Function to plot the attention weights of a single head"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "b99ba957-08ff-4553-a4e2-3d1aca5472a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "\n",
    "# Variant for tensors of byte strings (e.g. TensorFlow): labels are decoded from UTF-8\n",
    "def plot_attention_head_cp(in_tokens, translated_tokens, attention):\n",
    "  # Skip the first output token (e.g. a `<START>` marker the model did not generate).\n",
    "  translated_tokens = translated_tokens[1:]\n",
    "\n",
    "  ax = plt.gca()\n",
    "  ax.matshow(attention)\n",
    "  ax.set_xticks(range(len(in_tokens)))\n",
    "  ax.set_yticks(range(len(translated_tokens)))\n",
    "\n",
    "  labels = [label.decode('utf-8') for label in in_tokens.numpy()]\n",
    "  ax.set_xticklabels(labels, rotation=90)\n",
    "\n",
    "  labels = [label.decode('utf-8') for label in translated_tokens.numpy()]\n",
    "  ax.set_yticklabels(labels)\n",
    "\n",
    "\n",
    "# Variant for plain Python lists of token strings, as used in this notebook\n",
    "def plot_attention_head(in_tokens, translated_tokens, attention):\n",
    "  # Rows are the attending tokens (y axis), columns the attended tokens (x axis).\n",
    "  ax = plt.gca()\n",
    "  ax.matshow(attention)\n",
    "  ax.set_xticks(range(len(in_tokens)))\n",
    "  ax.set_yticks(range(len(translated_tokens)))\n",
    "\n",
    "  ax.set_xticklabels(list(in_tokens), rotation=90)\n",
    "  ax.set_yticklabels(list(translated_tokens))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "f9c4698e-32b6-42ba-a90e-a68023d49177",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([7, 7])"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "head = attention11[0]\n",
    "head.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "a25408a7-86c1-4da8-9045-e2235583e2c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "head = head.detach().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "0e583ceb-8a8a-47df-9633-b8857b7de318",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plot_attention_head(in_tokens=tokens, translated_tokens=tokens, attention=head)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6b603500-d79b-4b20-99e3-0996c8454d41",
   "metadata": {},
   "source": [
    "### Visualizing the weights of all attention heads"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "864f57d2-01d6-41b5-872c-402b44a9e7e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_attention_weights(sentence, translated_tokens, attention_heads):\n",
    "  in_tokens = sentence  # already a list of token strings\n",
    "\n",
    "  fig = plt.figure(figsize=(16, 8))\n",
    "\n",
    "  for h, head in enumerate(attention_heads):\n",
    "    ax = fig.add_subplot(3, 4, h+1)\n",
    "\n",
    "    plot_attention_head(in_tokens, translated_tokens, head)\n",
    "\n",
    "    ax.set_xlabel(f'Head {h+1}')\n",
    "\n",
    "  plt.tight_layout()\n",
    "  plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "51b75032-bdf9-4b9f-a069-373bc7afc6fb",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 1152x576 with 12 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "heads = attention11.detach().numpy()\n",
    "\n",
    "plot_attention_weights(sentence=tokens, translated_tokens=tokens, \n",
    "                      attention_heads=heads)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "16bf982c-06c4-4492-8964-a5836c033eeb",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}