{ "cells": [ { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "17XKXlnyVBQF" }, "source": [ "# Embeddings\n", "\n", "An embedding maps discrete, categorical values to a continous space. Major advances in NLP applications have come from these continuous representations of words.\n", "\n", "If we have some sentence," ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "303nHJfnVCoR" }, "outputs": [], "source": [ "!pip install pymagnitude pytorch_pretrained_bert tensorboardcolab -q" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "_H68bmuYVBQA" }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import torch\n", "import torch.nn as nn\n", "from pymagnitude import Magnitude\n", "from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM\n", "from scipy import spatial\n", "from sklearn.manifold import TSNE\n", "from tensorboardcolab import TensorBoardColab\n", "from torch.utils.tensorboard import SummaryWriter\n", "from tqdm import tqdm_notebook as tqdm\n", "\n", "\n", "%config InlineBackend.figure_format = 'svg'\n", "%matplotlib inline\n", "\n", "RED, BLUE = '#FF4136', '#0074D9'" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 35 }, "colab_type": "code", "executionInfo": { "elapsed": 4393, "status": "ok", "timestamp": 1573668627926, "user": { "displayName": "Jeffrey Hsu", "photoUrl": "https://lh3.googleusercontent.com/a-/AAuE7mCITqjB_-x31R-SfFCiChG69Qj2xNbcXl_P3vxw=s64", "userId": "09103891542297935234" }, "user_tz": -60 }, "id": "fxTuHmKiVBQH", "outputId": "679b82a6-07fa-4d6b-9cd1-b0d795f39485" }, "outputs": [ { "data": { "text/plain": [ "['the', 'quick', 'brown', 'fox', 'jumps', 'over', 'the', 'lazy', 'dog']" ] }, "execution_count": 3, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "sentence = 'the quick brown fox jumps over the lazy dog'\n", "words = sentence.split()\n", "words" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "iKfSGpJxVBQP" }, "source": [ "We first turn this sentence into numbers by assigning each unique word an integer." ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 155 }, "colab_type": "code", "executionInfo": { "elapsed": 525, "status": "ok", "timestamp": 1573668628272, "user": { "displayName": "Jeffrey Hsu", "photoUrl": "https://lh3.googleusercontent.com/a-/AAuE7mCITqjB_-x31R-SfFCiChG69Qj2xNbcXl_P3vxw=s64", "userId": "09103891542297935234" }, "user_tz": -60 }, "id": "LWXE9OmDVBQR", "outputId": "a5f98afb-3c4a-4806-ebb4-a2564a388cc9" }, "outputs": [ { "data": { "text/plain": [ "{'brown': 0,\n", " 'dog': 1,\n", " 'fox': 2,\n", " 'jumps': 3,\n", " 'lazy': 4,\n", " 'over': 5,\n", " 'quick': 6,\n", " 'the': 7}" ] }, "execution_count": 4, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "word2idx = {word: idx for idx, word in enumerate(sorted(set(words)))}\n", "word2idx" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "R8plCT7bVBQX" }, "source": [ "Then, we turn each word in our sentence into its assigned index." 
] }, { "cell_type": "code", "execution_count": 6, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 35 }, "colab_type": "code", "executionInfo": { "elapsed": 986, "status": "ok", "timestamp": 1573668638283, "user": { "displayName": "Jeffrey Hsu", "photoUrl": "https://lh3.googleusercontent.com/a-/AAuE7mCITqjB_-x31R-SfFCiChG69Qj2xNbcXl_P3vxw=s64", "userId": "09103891542297935234" }, "user_tz": -60 }, "id": "aB40yyPSVBQY", "outputId": "9c8d972e-c9bc-4e34-dad5-655e6f48d7ee" }, "outputs": [ { "data": { "text/plain": [ "tensor([7, 6, 0, 2, 3, 5, 7, 4, 1])" ] }, "execution_count": 6, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "idxs = torch.LongTensor([word2idx[word] for word in sentence.split()])\n", "idxs" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "4xE--RIrVBQe" }, "source": [ "Next, we want to create an **embedding layer**. The embedding layer is a 2-D matrix of shape `(n_vocab x embedding_dimension)`. If we apply our input list of indices to the embedding layer, each value in the input list of indices maps to that specific row of the embedding layer matrix. The output shape after applying the input list of indices to the embedding layer is another 2-D matrix of shape `(n_words x embedding_dimension)`." ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 190 }, "colab_type": "code", "executionInfo": { "elapsed": 1065, "status": "ok", "timestamp": 1573668640850, "user": { "displayName": "Jeffrey Hsu", "photoUrl": "https://lh3.googleusercontent.com/a-/AAuE7mCITqjB_-x31R-SfFCiChG69Qj2xNbcXl_P3vxw=s64", "userId": "09103891542297935234" }, "user_tz": -60 }, "id": "VlF7QIr5VBQg", "outputId": "b3c9cd0a-d17a-4788-b0a8-788145ead85b" }, "outputs": [ { "data": { "text/plain": [ "(tensor([[ 1.2617, -0.4338, 1.0826],\n", " [-1.1667, -0.5306, 1.2059],\n", " [ 1.9853, -0.1801, 0.6577],\n", " [ 1.4299, 0.6668, 0.1062],\n", " [-1.1479, -0.7475, 0.1726],\n", " [ 0.2242, 0.4477, -0.1458],\n", " [ 1.2617, -0.4338, 1.0826],\n", " [ 1.0471, -0.8155, -0.6301],\n", " [ 0.4003, 0.2927, -0.7212]], grad_fn=),\n", " torch.Size([9, 3]))" ] }, "execution_count": 7, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "embedding_layer = nn.Embedding(num_embeddings=len(word2idx), embedding_dim=3)\n", "embeddings = embedding_layer(idxs)\n", "embeddings, embeddings.shape" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "5G_N4Cb0VBQl" }, "source": [ "The PyTorch builtin embedding layer comes with randomly initialized weights that are updated with gradient descent as your model learns to map input indices to some kind of output. However, often it is better to use pretrained embeddings that do not update but instead are frozen." ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "nWFKrgx-VBQm" }, "source": [ "## GloVe Embeddings\n", "\n", "GloVe embeddings are one of the most popular pretrained word embeddings in use. You can download them [here](https://nlp.stanford.edu/projects/glove/). 
{ "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "nWFKrgx-VBQm" }, "source": [ "## GloVe Embeddings\n", "\n", "GloVe embeddings are one of the most popular pretrained word embeddings in use. You can download them [here](https://nlp.stanford.edu/projects/glove/). For the best performance in most applications, I recommend their Common Crawl embeddings with 840B tokens; however, they take the longest to download, so let's download the Wikipedia embeddings with 6B tokens instead." ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "FKo_Pg6wVBQn" }, "outputs": [], "source": [ "# Download GloVe vectors (uncomment the below)\n", "\n", "# !wget http://nlp.stanford.edu/data/glove.6B.zip && unzip glove.6B.zip && mkdir glove && mv glove*.txt glove\n", "\n", "# GLOVE_FILENAME = 'glove/glove.6B.50d.txt'\n", "# glove_index = {}\n", "# n_lines = sum(1 for line in open(GLOVE_FILENAME))\n", "# with open(GLOVE_FILENAME) as fp:\n", "#     for line in tqdm(fp, total=n_lines):\n", "#         split = line.split()\n", "#         word = split[0]\n", "#         vector = np.array(split[1:]).astype(float)\n", "#         glove_index[word] = vector\n", "\n", "# glove_embeddings = np.array([glove_index[word] for word in words])\n", "\n", "# # Because the length of the input sequence is 9 words and the embedding\n", "# # dimension is 50, the output shape is `(9 x 50)`.\n", "# glove_embeddings.shape" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "2StD14zGVBQ3" }, "source": [ "### Magnitude Library for Fast Vector Loading" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "rvyAGoEIVBQ4" }, "source": [ "Loading the entire GloVe file can take up a lot of memory. We can use the `magnitude` library for more efficient embedding vector loading. You can download the magnitude version of the GloVe embeddings [here](https://github.com/plasticityai/magnitude#pre-converted-magnitude-formats-of-popular-embeddings-models)." ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 311 }, "colab_type": "code", "executionInfo": { "elapsed": 6822, "status": "ok", "timestamp": 1573668659403, "user": { "displayName": "Jeffrey Hsu", "photoUrl": "https://lh3.googleusercontent.com/a-/AAuE7mCITqjB_-x31R-SfFCiChG69Qj2xNbcXl_P3vxw=s64", "userId": "09103891542297935234" }, "user_tz": -60 }, "id": "vnzGlMubVBQ5", "outputId": "67c51492-9ffa-4c75-f002-9625c2cbef02" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "--2019-11-13 18:11:06-- http://magnitude.plasticity.ai/glove/light/glove.6B.50d.magnitude\n", "Resolving magnitude.plasticity.ai (magnitude.plasticity.ai)... 52.216.184.90\n", "Connecting to magnitude.plasticity.ai (magnitude.plasticity.ai)|52.216.184.90|:80... connected.\n", "HTTP request sent, awaiting response... 200 OK\n", "Length: 102670336 (98M) [binary/octet-stream]\n", "Saving to: ‘glove.6B.50d.magnitude.1’\n", "\n", "glove.6B.50d.magnit 100%[===================>] 97.91M 28.2MB/s in 3.5s \n", "\n", "2019-11-13 18:11:10 (28.2 MB/s) - ‘glove.6B.50d.magnitude.1’ saved [102670336/102670336]\n", "\n",
"FINISHED --2019-11-13 18:11:10--\n", "Total wall clock time: 4.0s\n", "Downloaded: 1 files, 98M in 3.5s (28.2 MB/s)\n" ] } ], "source": [ "!mkdir -p glove && wget http://magnitude.plasticity.ai/glove/light/glove.6B.50d.magnitude -P glove/" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "w-0r7FHLVBQ-" }, "outputs": [], "source": [ "# Load Magnitude GloVe vectors\n", "glove_vectors = Magnitude('glove/glove.6B.50d.magnitude')" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "DP2sOnZ1VBRC" }, "outputs": [], "source": [ "glove_embeddings = glove_vectors.query(words)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "ARcZ2PwsVBRG" }, "source": [ "## Similarity operations on embeddings" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "8Ara5883VBRH" }, "outputs": [], "source": [ "def cosine_similarity(word1, word2):\n", "    vector1, vector2 = glove_vectors.query(word1), glove_vectors.query(word2)\n", "    return 1 - spatial.distance.cosine(vector1, vector2)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 86 }, "colab_type": "code", "executionInfo": { "elapsed": 915, "status": "ok", "timestamp": 1573668660335, "user": { "displayName": "Jeffrey Hsu", "photoUrl": "https://lh3.googleusercontent.com/a-/AAuE7mCITqjB_-x31R-SfFCiChG69Qj2xNbcXl_P3vxw=s64", "userId": "09103891542297935234" }, "user_tz": -60 }, "id": "LQV1Ur9PVBRO", "outputId": "5ac73c0f-385a-4ccd-810f-9596094c1b84" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Similarity between \"dog\" and \"cat\":\t0.92\n", "Similarity between \"tree\" and \"cat\":\t0.57\n", "Similarity between \"tree\" and \"leaf\":\t0.74\n", "Similarity between \"king\" and \"queen\":\t0.78\n" ] } ], "source": [ "word_pairs = [\n", "    ('dog', 'cat'),\n", "    ('tree', 'cat'),\n", "    ('tree', 'leaf'),\n", "    ('king', 'queen'),\n", "]\n", "\n", "for word1, word2 in word_pairs:\n", "    print(f'Similarity between \"{word1}\" and \"{word2}\":\\t{cosine_similarity(word1, word2):.2f}')" ] },
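{ "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "mostSimilarMd" }, "source": [ "Magnitude can also query nearest neighbors directly. A small sketch using `pymagnitude`'s `most_similar` (the exact neighbors and scores will vary with the embedding file), including the classic king - man + woman analogy:" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "mostSimilarCode" }, "outputs": [], "source": [ "# Nearest neighbors of 'cat' in the GloVe vector space\n", "print(glove_vectors.most_similar('cat', topn=5))\n", "\n", "# The classic analogy: king - man + woman should land near 'queen'\n", "print(glove_vectors.most_similar(positive=['king', 'woman'], negative=['man'], topn=5))" ] },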
] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "MYoO6T2kVBRX" }, "outputs": [], "source": [ "ANIMALS = [\n", " 'whale',\n", " 'fish',\n", " 'horse',\n", " 'rabbit',\n", " 'sheep',\n", " 'lion',\n", " 'dog',\n", " 'cat',\n", " 'tiger',\n", " 'hamster',\n", " 'pig',\n", " 'goat',\n", " 'lizard',\n", " 'elephant',\n", " 'giraffe',\n", " 'hippo',\n", " 'zebra',\n", "]\n", "\n", "HOUSEHOLD_OBJECTS = [\n", " 'stapler',\n", " 'screw',\n", " 'nail',\n", " 'tv',\n", " 'dresser',\n", " 'keyboard',\n", " 'hairdryer',\n", " 'couch',\n", " 'sofa',\n", " 'lamp',\n", " 'chair',\n", " 'desk',\n", " 'pen',\n", " 'pencil',\n", " 'table',\n", " 'sock',\n", " 'floor',\n", " 'wall',\n", "]" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 35 }, "colab_type": "code", "executionInfo": { "elapsed": 1006, "status": "ok", "timestamp": 1573668664007, "user": { "displayName": "Jeffrey Hsu", "photoUrl": "https://lh3.googleusercontent.com/a-/AAuE7mCITqjB_-x31R-SfFCiChG69Qj2xNbcXl_P3vxw=s64", "userId": "09103891542297935234" }, "user_tz": -60 }, "id": "5R_k2AiCVBRd", "outputId": "fb62b449-432c-463e-ddf9-0f2bc8cb86e7" }, "outputs": [ { "data": { "text/plain": [ "(35, 2)" ] }, "execution_count": 15, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "tsne_words_embedded = TSNE(n_components=2).fit_transform(glove_vectors.query(ANIMALS + HOUSEHOLD_OBJECTS))\n", "tsne_words_embedded.shape" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 621 }, "colab_type": "code", "executionInfo": { "elapsed": 997, "status": "ok", "timestamp": 1573668668518, "user": { "displayName": "Jeffrey Hsu", "photoUrl": "https://lh3.googleusercontent.com/a-/AAuE7mCITqjB_-x31R-SfFCiChG69Qj2xNbcXl_P3vxw=s64", "userId": "09103891542297935234" }, "user_tz": -60 }, "id": "OfM7fFagVBRh", "outputId": "d612bf27-ecdc-43df-9c36-af50f56259a5" }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "tags": [] }, "output_type": "display_data" } ], "source": [ "x, y = zip(*tsne_words_embedded)\n", "\n", "fig, ax = plt.subplots(figsize=(10, 8))\n", "\n", "for i, label in enumerate(ANIMALS + HOUSEHOLD_OBJECTS):\n", " if label in ANIMALS:\n", " color = BLUE\n", " elif label in HOUSEHOLD_OBJECTS:\n", " color = RED\n", " \n", " ax.scatter(x[i], y[i], c=color)\n", " ax.annotate(label, (x[i], y[i]))\n", "\n", "ax.axis('off')\n", "\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "_iUWIHDD6jMF" }, "source": [ "#### 3D visualization in TensorBoard" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 69 }, "colab_type": "code", "executionInfo": { "elapsed": 14667, "status": "ok", "timestamp": 1573668690689, "user": { "displayName": "Jeffrey Hsu", "photoUrl": "https://lh3.googleusercontent.com/a-/AAuE7mCITqjB_-x31R-SfFCiChG69Qj2xNbcXl_P3vxw=s64", "userId": "09103891542297935234" }, "user_tz": -60 }, "id": "ac9o1CsK6huN", "outputId": "b86328f6-d6b2-4103-8a6d-6b7087cf5dc7" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Wait for 8 seconds...\n", "TensorBoard link:\n", "https://4d5e2df6.ngrok.io\n" ] } ], "source": [ "tbc=TensorBoardColab()\n", "\n", "vectors = glove_vectors.query(ANIMALS + HOUSEHOLD_OBJECTS)\n", "with SummaryWriter(log_dir=\"Graph/\") as writer:\n", " writer.add_embedding(mat=vectors, metadata=ANIMALS+HOUSEHOLD_OBJECTS)" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "IrlNQfk767Pa" }, "outputs": [], "source": [ "# %tensorboard --logdir=Graph/" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "IFfVbmhfVBRl" }, "source": [ "## Context embeddings\n", "\n", "GloVe and Fasttext are two examples of global embeddings, where the embeddings don't change even though the \"sense\" of the word might change given the context. This can be a problem for cases such as:\n", "\n", "- A **mouse** stole some cheese.\n", "- I bought a new **mouse** the other day for my computer.\n", "\n", "The word mouse can mean both an animal and a computer accessory depending on the context, yet for GloVe they would receive the same exact distributed representation. We can combat this by taking into account the surroudning words to create a context-sensitive embedding. 
Context embeddings such as BERT are really popular right now.\n", "\n" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 52 }, "colab_type": "code", "executionInfo": { "elapsed": 26218, "status": "ok", "timestamp": 1573669096189, "user": { "displayName": "Jeffrey Hsu", "photoUrl": "https://lh3.googleusercontent.com/a-/AAuE7mCITqjB_-x31R-SfFCiChG69Qj2xNbcXl_P3vxw=s64", "userId": "09103891542297935234" }, "user_tz": -60 }, "id": "v2Kqxd54VBRm", "outputId": "be676015-97d7-47a5-e69c-890ad30940fd" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 231508/231508 [00:00<00:00, 869934.62B/s]\n", "100%|██████████| 407873900/407873900 [00:16<00:00, 24702870.58B/s]\n" ] } ], "source": [ "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\n", "model = BertModel.from_pretrained('bert-base-uncased')\n", "model.eval()\n", "\n", "def to_bert_embeddings(text, return_tokens=False):\n", "    if isinstance(text, list):\n", "        # Already tokenized\n", "        tokens = tokenizer.tokenize(' '.join(text))\n", "    else:\n", "        # Need to tokenize\n", "        tokens = tokenizer.tokenize(text)\n", "\n", "    tokens_with_tags = ['[CLS]'] + tokens + ['[SEP]']\n", "    indices = tokenizer.convert_tokens_to_ids(tokens_with_tags)\n", "\n", "    out = model(torch.LongTensor(indices).unsqueeze(0))\n", "\n", "    # Concatenate the last four layers and use that as the embedding\n", "    # source: https://jalammar.github.io/illustrated-bert/\n", "    embeddings_matrix = torch.stack(out[0]).squeeze(1)[-4:]  # use last 4 layers\n", "    embeddings = []\n", "    for j in range(embeddings_matrix.shape[1]):\n", "        embeddings.append(embeddings_matrix[:, j, :].flatten().detach().numpy())\n", "\n", "    # Ignore [CLS] and [SEP]\n", "    embeddings = embeddings[1:-1]\n", "\n", "    if return_tokens:\n", "        assert len(embeddings) == len(tokens)\n", "        return embeddings, tokens\n", "\n", "    return embeddings" ] },
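{ "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "bertSanityMd" }, "source": [ "As a quick sanity check (a minimal sketch reusing `to_bert_embeddings` and `scipy`; the example sentences are made up), the embedding of \"mouse\" in an animal context should differ from its embedding in a computer context, unlike with GloVe:" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "bertSanityCode" }, "outputs": [], "source": [ "# Two sentences where 'mouse' has different senses\n", "animal_emb, animal_tokens = to_bert_embeddings('the mouse ate the cheese', return_tokens=True)\n", "computer_emb, computer_tokens = to_bert_embeddings('i bought a new computer mouse', return_tokens=True)\n", "\n", "# Pull out the contextual embedding of 'mouse' in each sentence\n", "animal_mouse = animal_emb[animal_tokens.index('mouse')]\n", "computer_mouse = computer_emb[computer_tokens.index('mouse')]\n", "\n", "# The two 'mouse' vectors are no longer identical\n", "1 - spatial.distance.cosine(animal_mouse, computer_mouse)" ] },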
{ "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "W6PAaDILVBRq" }, "outputs": [], "source": [ "words_sentences = [\n", "    ('mouse', 'I saw a mouse run off with some cheese.'),\n", "    ('mouse', 'I bought a new computer mouse yesterday.'),\n", "    ('cat', 'My cat jumped on the bed.'),\n", "    ('keyboard', 'My computer keyboard broke when I spilled juice on it.'),\n", "    ('dessert', 'I had a banana fudge sundae for dessert.'),\n", "    ('dinner', 'What did you eat for dinner?'),\n", "    ('lunch', 'Yesterday I had a bacon lettuce tomato sandwich for lunch. It was tasty!'),\n", "    ('computer', 'My computer broke after the motherboard was overloaded.'),\n", "    ('program', 'I like to program in Java and Python.'),\n", "    ('pasta', 'I like to put tomatoes and cheese in my pasta.'),\n", "]\n", "words = [words_sentence[0] for words_sentence in words_sentences]\n", "sentences = [words_sentence[1] for words_sentence in words_sentences]" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "KVSuEP8fVBRt" }, "outputs": [], "source": [ "embeddings_lst, tokens_lst = zip(*[to_bert_embeddings(sentence, return_tokens=True) for sentence in sentences])\n", "\n", "# Keep only the examples whose target word survived tokenization intact\n", "words, tokens_lst, embeddings_lst = zip(*[(word, tokens, embeddings) for word, tokens, embeddings in zip(words, tokens_lst, embeddings_lst) if word in tokens])\n", "\n", "# Convert tuples to lists\n", "words, tokens_lst, embeddings_lst = map(list, [words, tokens_lst, embeddings_lst])" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "SBCrt11cVBRw" }, "outputs": [], "source": [ "# Position of each target word within its tokenized sentence\n", "target_indices = [tokens.index(word) for word, tokens in zip(words, tokens_lst)]" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "IT7nqNYbVBRz" }, "outputs": [], "source": [ "# Contextual embedding of each target word\n", "target_embeddings = [embeddings[idx] for idx, embeddings in zip(target_indices, embeddings_lst)]" ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 765 }, "colab_type": "code", "executionInfo": { "elapsed": 1543, "status": "ok", "timestamp": 1573669128409, "user": { "displayName": "Jeffrey Hsu", "photoUrl": "https://lh3.googleusercontent.com/a-/AAuE7mCITqjB_-x31R-SfFCiChG69Qj2xNbcXl_P3vxw=s64", "userId": "09103891542297935234" }, "user_tz": -60 }, "id": "_x17Kq7mVBR1", "outputId": "1ae24825-d89c-4e32-c22f-89f9df80b765" }, "outputs": [ { "data": { "text/plain": [
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "tags": [] }, "output_type": "display_data" } ], "source": [ "tsne_words_embedded = TSNE(n_components=2).fit_transform(target_embeddings)\n", "x, y = zip(*tsne_words_embedded)\n", "\n", "fig, ax = plt.subplots(figsize=(5, 10))\n", "\n", "for word, tokens, x_i, y_i in zip(words, tokens_lst, x, y):\n", " ax.scatter(x_i, y_i, c=RED)\n", " ax.annotate(' '.join([f'$\\\\bf{x}$' if x == word else x for x in tokens]), (x_i, y_i))\n", "\n", "ax.axis('off')\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "x64xA81sVBR6" }, "source": [ "## Try-it-yourself\n", "\n", "- Use the Magnitude library to load other pretrained embeddings such as Fasttext\n", "- Try comparing the GloVe embeddings with the Fasttext embeddings by making t-SNE plots of both, or checking the similarity scores between the same set of words\n", "- Make t-SNE plots using your own words and categories" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "QDP37tWKVBR7" }, "outputs": [], "source": [ "" ] } ], "metadata": { "accelerator": "GPU", "colab": { "name": "2_embeddings.ipynb", "provenance": [ { "file_id": "https://github.com/scoutbee/pytorch-nlp-notebooks/blob/add_transformer/2_embeddings.ipynb", "timestamp": 1573667665920 } ] }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.3" } }, "nbformat": 4, "nbformat_minor": 0 }