{ "cells": [ { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "UtKtydW6lwCp" }, "source": [ "# Transformer Sequence-to-Sequence Model\n", "\n", "Here we showcase a vanilla transformer model from the paper [\"Attention Is All You Need\"](https://arxiv.org/pdf/1706.03762.pdf) (Vaswani et al. 2017), built with both encoder and decoder layers and trained on an English-to-French translation dataset.\n", "\n", "- [The Illustrated Transformer](http://jalammar.github.io/illustrated-transformer/): A great introduction to the detailed mechanisms of multi-head self-attention." ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "qQtcdT1ryIUe" }, "outputs": [], "source": [ "!pip install pandas==0.24.0 -q" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "sKYN9scCmFO9" }, "outputs": [], "source": [ "# The repository's Python modules are not available in Colab, so we manually copy the code to the GPU instance.\n", "\n", "!wget https://github.com/scoutbee/pytorch-nlp-notebooks/archive/develop.zip -q\n", "!unzip -qq develop.zip\n", "!mkdir transformer\n", "!mv pytorch-nlp-notebooks-develop/transformer/* transformer/\n", "!rm develop.zip\n", "!rm -r pytorch-nlp-notebooks-develop" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "N0DgZgVZlwCr" }, "outputs": [], "source": [ "import numpy as np\n", "import copy\n", "import time\n", "from pathlib import Path\n", "\n", "from google_drive_downloader import GoogleDriveDownloader as gdd\n", "from tqdm import tqdm_notebook, tqdm\n", "\n", "# PyTorch\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "from torch import optim\n", "from torch.utils.data import Dataset, DataLoader, Subset\n", "from torch.utils.data.dataset import random_split\n", "\n", "# Check out the model architecture in the transformer folder\n", "from transformer.model import 
Transformer\n", "from transformer.batch import *\n", "\n", "tqdm.pandas()\n", "\n", "# Show better CUDA error messages. `!export` would only set the variable in a\n", "# throwaway subshell, so set it for this Python process before CUDA is initialized.\n", "import os\n", "os.environ['CUDA_LAUNCH_BLOCKING'] = '1'" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "XmeD9Q-DlwCv" }, "source": [ "In order to perform deep learning on a GPU (so that everything runs super quick!), CUDA has to be installed and configured. Fortunately, Google Colab already has this set up, but if you want to try this on your own GPU, you can [install CUDA from here](https://developer.nvidia.com/cuda-downloads). Make sure you also [install cuDNN](https://developer.nvidia.com/cudnn) for optimized performance." ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 35 }, "colab_type": "code", "executionInfo": { "elapsed": 13806, "status": "ok", "timestamp": 1573666791266, "user": { "displayName": "Jeffrey Hsu", "photoUrl": "https://lh3.googleusercontent.com/a-/AAuE7mCITqjB_-x31R-SfFCiChG69Qj2xNbcXl_P3vxw=s64", "userId": "09103891542297935234" }, "user_tz": -60 }, "id": "0KXUq4F4lwCv", "outputId": "2e5edbe7-b24b-43c5-bb33-928c318b4fdb" }, "outputs": [ { "data": { "text/plain": [ "device(type='cuda')" ] }, "execution_count": 4, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", "device" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "hGd6mJXSlwCy" }, "source": [ "## Download the data\n", "\n", "We will download a dataset of English-to-French translations from a public Google Drive folder." 
] }, { "cell_type": "code", "execution_count": 5, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 35 }, "colab_type": "code", "executionInfo": { "elapsed": 15129, "status": "ok", "timestamp": 1573666792598, "user": { "displayName": "Jeffrey Hsu", "photoUrl": "https://lh3.googleusercontent.com/a-/AAuE7mCITqjB_-x31R-SfFCiChG69Qj2xNbcXl_P3vxw=s64", "userId": "09103891542297935234" }, "user_tz": -60 }, "id": "dwOw3oDXlwCz", "outputId": "914b0cc6-a465-4cc4-82b2-59a8ad60ee1c" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Downloading 1Jf7QoW2NK6_ayEXZji6DAXDSIRMvapm3 into data/english_to_french.txt... Done.\n" ] } ], "source": [ "DATA_PATH = 'data/english_to_french.txt'\n", "if not Path(DATA_PATH).is_file():\n", " gdd.download_file_from_google_drive(\n", " file_id='1Jf7QoW2NK6_ayEXZji6DAXDSIRMvapm3',\n", " dest_path=DATA_PATH,\n", " )" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 35 }, "colab_type": "code", "executionInfo": { "elapsed": 20840, "status": "ok", "timestamp": 1573666798319, "user": { "displayName": "Jeffrey Hsu", "photoUrl": "https://lh3.googleusercontent.com/a-/AAuE7mCITqjB_-x31R-SfFCiChG69Qj2xNbcXl_P3vxw=s64", "userId": "09103891542297935234" }, "user_tz": -60 }, "id": "LcGJyn_UlwC3", "outputId": "bc981027-f46c-4204-d89e-5787243227b3" }, "outputs": [ { "data": { "text/plain": [ "40288" ] }, "execution_count": 6, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "dataset = EnglishFrenchTranslations(DATA_PATH, max_vocab=1000, max_seq_len=100)\n", "len(dataset)" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "Ybh1AV6AlwC5" }, "outputs": [], "source": [ "# Get the indices of special tokens\n", "SRC_VOCAB = dataset.token2idx_inputs\n", "TRG_VOCAB = dataset.token2idx_targets\n", "src_pad = 
torch.tensor(SRC_VOCAB[dataset.padding_token]).to(device)\n", "src_sos = torch.tensor(SRC_VOCAB[dataset.start_of_sequence_token]).to(device)\n", "src_eos = torch.tensor(SRC_VOCAB[dataset.end_of_sequence_token]).to(device)\n", "trg_pad = torch.tensor(TRG_VOCAB[dataset.padding_token]).to(device)\n", "trg_sos = torch.tensor(TRG_VOCAB[dataset.start_of_sequence_token]).to(device)\n", "trg_eos = torch.tensor(TRG_VOCAB[dataset.end_of_sequence_token]).to(device)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "5xIhg1oOlwC8" }, "source": [ "## Split into training and test set" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "sYJTyIkDlwC8" }, "outputs": [], "source": [ "train_size = int(0.9999 * len(dataset))\n", "test_size = len(dataset) - train_size\n", "train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "URixCO7KlwC-" }, "source": [ "## Batching - Create data generators using `DataLoader`" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "S2Dd3Rb1lwC-" }, "outputs": [], "source": [ "batch_size = 256\n", "train_loader = DataLoader(\n", " train_dataset, \n", " batch_size=batch_size, \n", " collate_fn=lambda batch: collate(batch, src_pad, trg_pad, device),\n", ")\n", "\n", "test_loader = DataLoader(\n", " test_dataset, \n", " batch_size=1, \n", " collate_fn=lambda batch: collate(batch, src_pad, trg_pad, device),\n", ")" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "c-pu0LgNlwDA" }, "source": [ "## Define the `Transformer` model\n" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 64 }, "colab_type": "code", "executionInfo": { "elapsed": 24417, "status": "ok", "timestamp": 1573666801925, "user": { "displayName": "Jeffrey Hsu", "photoUrl": "https://lh3.googleusercontent.com/a-/AAuE7mCITqjB_-x31R-SfFCiChG69Qj2xNbcXl_P3vxw=s64", "userId": "09103891542297935234" }, "user_tz": -60 }, "id": "6uNID0COlwDB", "outputId": "c882a5f4-2162-4801-8d4a-7ce2c8aac844" }, "outputs": [ { "data": { "text/html": [ "