{ "cells": [ { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "SvzbCir3NuHY", "toc": true }, "source": [ "

Table of Contents

\n", "
" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 279 }, "colab_type": "code", "id": "TctW1rEMvcx_", "outputId": "3a6f7a0d-a408-493e-c967-a64be19f9768" }, "outputs": [ { "data": { "text/html": [ "\n", "\n", "\n", "" ], "text/plain": [ "" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# code for loading the format for the notebook\n", "import os\n", "\n", "# path : store the current path to convert back to it later\n", "path = os.getcwd()\n", "os.chdir(os.path.join('..', '..', 'notebook_format'))\n", "\n", "from formats import load_style\n", "load_style(css_style='custom2.css', plot_style=False)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 190 }, "colab_type": "code", "id": "0LSlT1mCu7QI", "outputId": "b33d802c-ec04-4b51-c81f-f2868742b1f2" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Ethen 2020-01-07 21:44:32 \n", "\n", "CPython 3.6.4\n", "IPython 7.9.0\n", "\n", "numpy 1.16.5\n", "torch 1.3.1\n", "torchtext 0.4.0\n", "spacy 2.1.6\n" ] } ], "source": [ "os.chdir(path)\n", "\n", "# 1. magic for inline plot\n", "# 2. magic to print version\n", "# 3. magic so that the notebook will reload external python modules\n", "# 4. magic to enable retina (high resolution) plots\n", "# https://gist.github.com/minrk/3301035\n", "%matplotlib inline\n", "%load_ext watermark\n", "%load_ext autoreload\n", "%autoreload 2\n", "%config InlineBackend.figure_format='retina'\n", "\n", "import os\n", "import math\n", "import time\n", "import spacy\n", "import torch\n", "import random\n", "import numpy as np\n", "import torch.nn as nn\n", "import torch.optim as optim\n", "from typing import List\n", "from torchtext.datasets import Multi30k\n", "from torchtext.data import Field, BucketIterator\n", "\n", "%watermark -a 'Ethen' -d -t -v -p numpy,torch,torchtext,spacy" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "56MoW7GvAhmx" }, "source": [ "# Seq2Seq" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "a4Ksc6nau7QO" }, "source": [ "**Seq2Seq (Sequence to Sequence)** is a many to many network where two neural networks, one encoder and one decoder work together to transform one sequence to another. The core highlight of this method is having no restrictions on the length of the source and target sequence. At a high-level, the way it works is:\n", "\n", "- The encoder network condenses an input sequence into a vector, this vector is a smaller dimensional representation and is often referred to as the context/thought vector. This thought vector is served as an abstract representation for the entire input sequence.\n", "- The decoder network takes in that thought vector and unfolds that vector into the output sequence.\n", "\n", "The main use case includes:\n", "\n", "- chatbots\n", "- text summarization\n", "- speech recognition\n", "- image captioning\n", "- machine translation\n", "\n", "In this notebook, we'll be implementing the seq2seq model ourselves using Pytorch and use it in the context of German to English translations." 
] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "SObZK6KONuHh" }, "source": [ "## Seq2Seq Introduction" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "WPRrmCJhNuHh" }, "source": [ "The following sections are heavily \"borrowed\" from the wonderful tutorial on this topic listed below.\n", "\n", "- [Jupyter Notebook: Sequence to Sequence Learning with Neural Networks](https://nbviewer.jupyter.org/github/bentrevett/pytorch-seq2seq/blob/master/1%20-%20Sequence%20to%20Sequence%20Learning%20with%20Neural%20Networks.ipynb)\n", "\n", "Some personal preference modifications have been made.\n", "\n", "\n", "\n", "\n", "The above image shows an example translation. The input/source sentence, \"guten morgen\", is input into the encoder (green) one word at a time. We also append a *start of sequence* (``) and *end of sequence* (``) token to the start and end of sentence, respectively. At each time-step, the input to the encoder is both the current word, $x_t$, as well as the hidden state from the previous time-step, $h_{t-1}$, and the encoder outputs a new hidden state $h_t$. We can think of the hidden state as a vector representation of the sentence so far. The can be represented as a function of both of $x_t$ and $h_{t-1}$:\n", "\n", "$$h_t = \\text{Encoder}(x_t, h_{t-1})$$\n", "\n", "We're using the term encoder loosely here, in practice, it can be any type of architecture, the most common ones being RNN-type network such as *LSTM* (Long Short-Term Memory) or a *GRU* (Gated Recurrent Unit). \n", "\n", "Here, we have $X = \\{x_1, x_2, ..., x_T\\}$, where $x_1 = \\text{}, x_2 = \\text{guten}$, etc. The initial hidden state, $h_0$, is usually either initialized to zeros or a learned parameter.\n", "\n", "Once the final word, $x_T$, has been passed into the encoder, we use the final hidden state, $h_T$, as the context vector, i.e. $h_T = z$. This is a vector representation of the entire source sentence.\n", "\n", "Now we have our context vector, $z$, we can start decoding it to get the target sentence, \"good morning\". Again, we append the start and end of sequence tokens to the target sentence. At each time-step, the input to the decoder (blue) is the current word, $y_t$, as well as the hidden state from the previous time-step, $s_{t-1}$, where the initial decoder hidden state, $s_0$, is the context vector, $s_0 = z = h_T$, i.e. the initial decoder hidden state is the final encoder hidden state. Thus, similar to the encoder, we can represent the decoder as:\n", "\n", "$$s_t = \\text{Decoder}(y_t, s_{t-1})$$\n", "\n", "In the decoder, we need to go from the hidden state to an actual word, therefore at each time-step we use $s_t$ to predict (by passing it through a `Linear` layer, shown in purple) what we think is the next word in the sequence, $\\hat{y}_t$. \n", "\n", "$$\\hat{y}_t = f(s_t)$$\n", "\n", "The words in the decoder are always generated one after another, with one per time-step. We always use `` for the first input to the decoder, $y_1$, but for subsequent inputs, $y_{t>1}$, we will sometimes use the actual, ground truth next word in the sequence, $y_t$ and sometimes use the word predicted by our decoder, $\\hat{y}_{t-1}$. This is called **teacher forcing**, which we'll later see in action.\n", "\n", "When training/testing our model, we always know how many words are in our target sentence, so we stop generating words once we hit that many. During inference (i.e. 
real world usage) it is common to keep generating words until the model outputs an `` token or after a certain amount of words have been generated.\n", "\n", "Once we have our predicted target sentence, $\\hat{Y} = \\{ \\hat{y}_1, \\hat{y}_2, ..., \\hat{y}_T \\}$, we compare it against our actual target sentence, $Y = \\{ y_1, y_2, ..., y_T \\}$, to calculate our loss. We then use this loss to update all of the parameters in our model." ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "yol-A_GPu7QP" }, "source": [ "## Data Preparation" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "Rhuxf0XFNuHi" }, "source": [ "We'll be coding up the models in PyTorch and using TorchText to help us do all of the pre-processing required. We'll also be using spaCy to assist in the tokenization of the data. We will introduce the functionalities some these libraries along the way as well." ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 35 }, "colab_type": "code", "id": "v2XqXgDvNuHj", "outputId": "727fd0c1-5cbe-4a3c-ca9e-681e0d0d6926" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "SEED = 2222\n", "random.seed(SEED)\n", "torch.manual_seed(SEED)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "LeR8tiHyNuHl" }, "source": [ "The next two code chunks:\n", "\n", "- Downloads the spacy model for the German and English language.\n", "- Create the tokenizer functions, which will take in the sentence as the input and return the sentence as a list of tokens. These functions can then be passed to torchtext." ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "colab": {}, "colab_type": "code", "id": "BidHjuzuvzOb" }, "outputs": [], "source": [ "# !python -m spacy download de\n", "# !python -m spacy download en" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "colab": {}, "colab_type": "code", "id": "PqNyAnYDu7QX" }, "outputs": [], "source": [ "# the link below contains explanation of how spacy's tokenization works\n", "# https://spacy.io/usage/spacy-101#annotations-token\n", "spacy_de = spacy.load('de_core_news_sm')\n", "spacy_en = spacy.load('en_core_web_sm')\n", "\n", "\n", "def tokenize_de(text: str) -> List[str]:\n", " return [tok.text for tok in spacy_de.tokenizer(text)][::-1]\n", "\n", "def tokenize_en(text: str) -> List[str]:\n", " return [tok.text for tok in spacy_en.tokenizer(text)]" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 35 }, "colab_type": "code", "id": "l2jD0Ha-NuHq", "outputId": "6c77f7fc-bfb6-4175-db8f-553ff9b1469e" }, "outputs": [ { "data": { "text/plain": [ "['I', 'do', \"n't\", 'like', 'apple', '.']" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "text = \"I don't like apple.\"\n", "tokenize_en(text)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The tokenizer is language specific, e.g. it knows that in the English language don't should be tokenized into do not (n't).\n", "\n", "Another thing to note is that **the order of the source sentence is reversed during the tokenization process**. 
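To see the reversal in action, here is a quick check with the German tokenizer defined above; the example sentence is made up and the printed tokens in the comment are indicative (the exact split depends on the spacy model version):

```python
# the German tokenizer returns the spacy tokens in reverse order
text = "Guten Morgen, wie geht es dir?"
print(tokenize_de(text))
# roughly: ['?', 'dir', 'es', 'geht', 'wie', ',', 'Morgen', 'Guten']
```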
The rationale behind this comes from the original seq2seq paper, where they identified that this trick improved the results of their model.\n", "\n", "> Normally, when we concatenate a source sentence with a target sentence, each word in the source sentence is far from its corresponding word in the target sentence. By reversing the source sentence, the first few words in the source sentence now become very close to the first few words in the target sentence, thus the model has an easier time establishing communication between the source and target sentences.\n", "> Although the average distance between words in the source and target language remains the same during this process, it was shown that the model learned much better even on later parts of the sentence." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Declaring Fields" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "v3e8UrLVNuHr" }, "source": [ "Moving on, we will begin leveraging torchtext's functionality. The first one is [`Field`](https://pytorch.org/text/data.html#field), which is where we specify how we wish to preprocess our text data for a certain field.\n", "\n", "Here, we set the `tokenize` argument to the correct tokenization function for the source and target field, with German being the source field and English being the target field. The field also appends the \"start of sequence\" and \"end of sequence\" tokens via the `init_token` and `eos_token` arguments, and converts all words to lowercase. The docstring of the `Field` object is pretty well-written; please refer to it to see other arguments that it takes in." ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "colab": {}, "colab_type": "code", "id": "mrmtdyldu7QZ" }, "outputs": [], "source": [ "source = Field(tokenize=tokenize_de, init_token='<sos>', eos_token='<eos>', lower=True)\n", "target = Field(tokenize=tokenize_en, init_token='<sos>', eos_token='<eos>', lower=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Constructing Dataset" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "iH7CvaYzNuHu" }, "source": [ "We've defined the logic of processing our raw text data, now we need to tell the fields what data they should work on. This is where `Dataset` comes in. The dataset we'll be using is the [Multi30k dataset](https://pytorch.org/text/datasets.html#multi30k). This is a dataset with ~30,000 parallel English, German and French sentences, each with ~12 words per sentence. Torchtext comes with a capability for us to download and load the training, validation and test data.\n", "\n", "`exts` specifies which languages to use as the source and target (source goes first) and `fields` specifies which field to use for the source and target."
] }, { "cell_type": "code", "execution_count": 8, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 173 }, "colab_type": "code", "id": "JY202h93u7Qc", "outputId": "9816cfb8-9378-4a46-d326-0734920bb7ed" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Number of training examples: 29000\n", "Number of validation examples: 1014\n", "Number of testing examples: 1000\n" ] } ], "source": [ "train_data, valid_data, test_data = Multi30k.splits(exts=('.de', '.en'), fields=(source, target))\n", "print(f\"Number of training examples: {len(train_data.examples)}\")\n", "print(f\"Number of validation examples: {len(valid_data.examples)}\")\n", "print(f\"Number of testing examples: {len(test_data.examples)}\")" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "0gzRvy1qNuHx" }, "source": [ "Upon loading the dataset, we can indexed and iterate over the `Dataset` like a normal list. Each element in the dataset bundles the attributes of a single record for us. We can index our dataset like a list and then access the `.src` and `.trg` attribute to take a look at the tokenized source and target sentence." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['.',\n", " 'büsche',\n", " 'vieler',\n", " 'nähe',\n", " 'der',\n", " 'in',\n", " 'freien',\n", " 'im',\n", " 'sind',\n", " 'männer',\n", " 'weiße',\n", " 'junge',\n", " 'zwei']" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# equivalent, albeit more verbiage train_data.examples[0].src\n", "train_data[0].src" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 207 }, "colab_type": "code", "id": "nI3Re4TSu7Ql", "outputId": "280981e6-908d-4976-face-266e05beda11" }, "outputs": [ { "data": { "text/plain": [ "['two',\n", " 'young',\n", " ',',\n", " 'white',\n", " 'males',\n", " 'are',\n", " 'outside',\n", " 'near',\n", " 'many',\n", " 'bushes',\n", " '.']" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_data[0].trg" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "sIHr_teXNuH1" }, "source": [ "The next missing piece is to build the vocabulary for the source and target languages. That way we can convert our tokenized tokens into integers so that they can be fed into downstream models. Constructing the vocabulary and word to integer mapping is done by calling the `build_vocab` method of a `Field` on a dataset. This adds the `vocab` attribute to the field.\n", "\n", "The vocabularies of the source and target languages are distinct. Using the `min_freq` argument, we only allow tokens that appear at least 2 times to appear in our vocabulary. Tokens that appear only once are converted into an `` (unknown) token (we can customize this in the Field earlier if we like).\n", "\n", "It is important to note that our vocabulary should only be built from the training set and not the validation/test set. This prevents \"information leakage\" into our model, giving us artificially inflated validation/test scores." 
] }, { "cell_type": "code", "execution_count": 11, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 52 }, "colab_type": "code", "id": "1FXsfBuVu7Qp", "outputId": "e4ec0bd3-fd95-4580-c770-9874d9fc0dc8" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Unique tokens in source (de) vocabulary: 7855\n", "Unique tokens in target (en) vocabulary: 5893\n" ] } ], "source": [ "source.build_vocab(train_data, min_freq=2)\n", "target.build_vocab(train_data, min_freq=2)\n", "print(f\"Unique tokens in source (de) vocabulary: {len(source.vocab)}\")\n", "print(f\"Unique tokens in target (en) vocabulary: {len(target.vocab)}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Constructing Iterator" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "WUERCPo6NuH3" }, "source": [ "The final step of preparing the data is to create the iterators. Very similar to `DataLoader` in the standard pytorch package, `Iterator` in torchtext converts our data into batches, so that they can be fed into the model. These can be iterated on to return a batch of data which will have a `src` and `trg` attribute (PyTorch tensors containing a batch of numericalized source and target sentences). Numericalized is just a fancy way of saying they have been converted from a sequence of tokens to a sequence of corresponding indices, where the mapping between the tokens and indices comes from the learned vocabulary. \n", "\n", "When we get a batch of examples using an iterator we need to make sure that all of the source sentences are padded to the same length, the same with the target sentences. Luckily, torchtext iterators handle this for us! `BucketIterator` is a extremely useful torchtext feature. It automatically shuffles and buckets the input sequences into sequences of similar length, this minimizes the amount of padding that we need to perform." 
] }, { "cell_type": "code", "execution_count": 12, "metadata": { "colab": {}, "colab_type": "code", "id": "EgdSzIrvu7Qw" }, "outputs": [], "source": [ "BATCH_SIZE = 128\n", "\n", "# pytorch boilerplate that determines whether a GPU is present or not,\n", "# this determines whether our dataset or model can to moved to a GPU\n", "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", "\n", "# create batches out of the dataset and sends them to the appropriate device\n", "train_iterator, valid_iterator, test_iterator = BucketIterator.splits(\n", " (train_data, valid_data, test_data), batch_size=BATCH_SIZE, device=device)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 86 }, "colab_type": "code", "id": "YAhFBwDeu7Qy", "outputId": "4452fa8b-a53a-4222-ed10-63cc29283a48" }, "outputs": [ { "data": { "text/plain": [ "\n", "[torchtext.data.batch.Batch of size 128 from MULTI30K]\n", "\t[.src]:[torch.LongTensor of size 10x128]\n", "\t[.trg]:[torch.LongTensor of size 14x128]" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# pretend that we're iterating over the iterator and print out the print element\n", "test_batch = next(iter(test_iterator))\n", "test_batch" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[ 2, 2, 2, ..., 2, 2, 2],\n", " [ 4, 4, 4, ..., 4, 714, 4],\n", " [ 123, 91, 3449, ..., 669, 12, 1643],\n", " ...,\n", " [6788, 41, 26, ..., 1, 1, 1],\n", " [ 18, 105, 5, ..., 1, 1, 1],\n", " [ 3, 3, 3, ..., 1, 1, 1]])" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "test_batch.src" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "DJk8V1JCNuH9" }, "source": [ "We can list out the first batch, we see each element of the iterator is a `Batch` object, similar to element of a `Dataset`, we can access the fields via its attributes. The next important thing to note that it is of size [sentence length, batch size], and the longest sentence in the first batch of the source language has a length of 10." ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "RSp_0hvMu7Q7" }, "source": [ "## Seq2Seq Implementation" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "colab": {}, "colab_type": "code", "id": "Kf46A1-cNuH-" }, "outputs": [], "source": [ "# adjustable parameters\n", "INPUT_DIM = len(source.vocab)\n", "OUTPUT_DIM = len(target.vocab)\n", "ENC_EMB_DIM = 256\n", "DEC_EMB_DIM = 256\n", "HID_DIM = 512\n", "N_LAYERS = 2\n", "ENC_DROPOUT = 0.5\n", "DEC_DROPOUT = 0.5" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "9YHCdUyhNuIA" }, "source": [ "To define our seq2seq model, we first specify the encoder and decoder separately." 
] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "y-TxYAFmNuIB" }, "source": [ "### Encoder Module" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "colab": {}, "colab_type": "code", "id": "oPtj49Giu7Q8" }, "outputs": [], "source": [ "class Encoder(nn.Module):\n", " \"\"\"\n", " Input :\n", " - source batch\n", " Layer : \n", " source batch -> Embedding -> LSTM\n", " Output :\n", " - LSTM hidden state\n", " - LSTM cell state\n", "\n", " Parmeters\n", " ---------\n", " input_dim : int\n", " Input dimension, should equal to the source vocab size.\n", " \n", " emb_dim : int\n", " Embedding layer's dimension.\n", " \n", " hid_dim : int\n", " LSTM Hidden/Cell state's dimension.\n", " \n", " n_layers : int\n", " Number of LSTM layers.\n", " \n", " dropout : float\n", " Dropout for the LSTM layer.\n", " \"\"\"\n", "\n", " def __init__(self, input_dim: int, emb_dim: int, hid_dim: int, n_layers: int, dropout: float):\n", " super().__init__()\n", " self.emb_dim = emb_dim\n", " self.hid_dim = hid_dim\n", " self.input_dim = input_dim\n", " self.n_layers = n_layers\n", " self.dropout = dropout\n", "\n", " self.embedding = nn.Embedding(input_dim, emb_dim)\n", " self.rnn = nn.LSTM(emb_dim, hid_dim, n_layers, dropout=dropout)\n", "\n", " def forward(self, src_batch: torch.LongTensor):\n", " \"\"\"\n", "\n", " Parameters\n", " ----------\n", " src_batch : 2d torch.LongTensor\n", " Batched tokenized source sentence of shape [sent len, batch size].\n", "\n", " Returns\n", " -------\n", " hidden, cell : 3d torch.LongTensor\n", " Hidden and cell state of the LSTM layer. Each state's shape\n", " [n layers * n directions, batch size, hidden dim]\n", " \"\"\"\n", " embedded = self.embedding(src_batch) # [sent len, batch size, emb dim]\n", " outputs, (hidden, cell) = self.rnn(embedded)\n", " # outputs -> [sent len, batch size, hidden dim * n directions]\n", " return hidden, cell" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 35 }, "colab_type": "code", "id": "EQWgm8N_u7Q_", "outputId": "bd7ccbf9-30e9-4328-aed7-1b9515746a72" }, "outputs": [ { "data": { "text/plain": [ "(torch.Size([2, 128, 512]), torch.Size([2, 128, 512]))" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "encoder = Encoder(INPUT_DIM, ENC_EMB_DIM, HID_DIM, N_LAYERS, ENC_DROPOUT).to(device)\n", "hidden, cell = encoder(test_batch.src)\n", "hidden.shape, cell.shape" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "4jHc__y-NuIG" }, "source": [ "### Decoder Module" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "ampEHPmZNuIG" }, "source": [ "The decoder accept a batch of input tokens, previous hidden states and previous cell states. Note that in the decoder module, we are only decoding one token at a time, the input tokens will always have a sequence length of 1. This is different from the encoder module where we encode the entire source sentence all at once." 
] }, { "cell_type": "code", "execution_count": 18, "metadata": { "colab": {}, "colab_type": "code", "id": "-37LnQSyu7RC" }, "outputs": [], "source": [ "class Decoder(nn.Module):\n", " \"\"\"\n", " Input :\n", " - first token in the target batch\n", " - LSTM hidden state from the encoder\n", " - LSTM cell state from the encoder\n", " Layer :\n", " target batch -> Embedding -- \n", " |\n", " encoder hidden state ------|--> LSTM -> Linear\n", " |\n", " encoder cell state -------\n", " \n", " Output :\n", " - prediction\n", " - LSTM hidden state\n", " - LSTM cell state\n", "\n", " Parmeters\n", " ---------\n", " output : int\n", " Output dimension, should equal to the target vocab size.\n", " \n", " emb_dim : int\n", " Embedding layer's dimension.\n", " \n", " hid_dim : int\n", " LSTM Hidden/Cell state's dimension.\n", " \n", " n_layers : int\n", " Number of LSTM layers.\n", " \n", " dropout : float\n", " Dropout for the LSTM layer.\n", " \"\"\"\n", "\n", " def __init__(self, output_dim: int, emb_dim: int, hid_dim: int, n_layers: int, dropout: float):\n", " super().__init__()\n", " self.emb_dim = emb_dim\n", " self.hid_dim = hid_dim\n", " self.output_dim = output_dim\n", " self.n_layers = n_layers\n", " self.dropout = dropout\n", "\n", " self.embedding = nn.Embedding(output_dim, emb_dim)\n", " self.rnn = nn.LSTM(emb_dim, hid_dim, n_layers, dropout=dropout)\n", " self.out = nn.Linear(hid_dim, output_dim)\n", "\n", " def forward(self, trg: torch.LongTensor, hidden: torch.FloatTensor, cell: torch.FloatTensor):\n", " \"\"\"\n", "\n", " Parameters\n", " ----------\n", " trg : 1d torch.LongTensor\n", " Batched tokenized source sentence of shape [batch size].\n", " \n", " hidden, cell : 3d torch.FloatTensor\n", " Hidden and cell state of the LSTM layer. Each state's shape\n", " [n layers * n directions, batch size, hidden dim]\n", "\n", " Returns\n", " -------\n", " prediction : 2d torch.LongTensor\n", " For each token in the batch, the predicted target vobulary.\n", " Shape [batch size, output dim]\n", "\n", " hidden, cell : 3d torch.FloatTensor\n", " Hidden and cell state of the LSTM layer. Each state's shape\n", " [n layers * n directions, batch size, hidden dim]\n", " \"\"\"\n", " # [1, batch size, emb dim], the 1 serves as sent len\n", " embedded = self.embedding(trg.unsqueeze(0))\n", " outputs, (hidden, cell) = self.rnn(embedded, (hidden, cell))\n", " prediction = self.out(outputs.squeeze(0))\n", " return prediction, hidden, cell" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 35 }, "colab_type": "code", "id": "14up-mqQu7RF", "outputId": "0a785245-cb95-46d7-fa40-0f466a8c4b10" }, "outputs": [ { "data": { "text/plain": [ "(torch.Size([128, 5893]), torch.Size([2, 128, 512]), torch.Size([2, 128, 512]))" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "decoder = Decoder(OUTPUT_DIM, DEC_EMB_DIM, HID_DIM, N_LAYERS, DEC_DROPOUT).to(device)\n", "\n", "# notice that we are not passing the entire the .trg\n", "prediction, hidden, cell = decoder(test_batch.trg[0], hidden, cell)\n", "prediction.shape, hidden.shape, cell.shape" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "K7eHyywYNuIK" }, "source": [ "### Seq2Seq Module" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "MLXDDIsKNuIL" }, "source": [ "For the final part of the implementation, we'll implement the seq2seq model. 
This will handle: \n", "\n", "- receiving the input/source sentence\n", "- using the encoder to produce the context vectors \n", "- using the decoder to produce the predicted output/target sentence\n", "\n", "The `Seq2Seq` model takes in an `Encoder`, `Decoder`, and a `device` (used to place tensors on the GPU, if it exists).\n", "\n", "For this implementation, we have to ensure that the number of layers and the hidden (and cell) dimensions are equal in the `Encoder` and `Decoder`. This is not always the case, as we do not necessarily need the same number of layers or the same hidden dimension sizes in a sequence-to-sequence model. However, if we do have a different number of layers we will need to make decisions about how this is handled. For example, if our encoder has 2 layers and our decoder only has 1, how is this handled? Do we average the two context vectors output by the decoder? Do we pass both through a linear layer? Do we only use the context vector from the highest layer? etc.\n", "\n", "Our `forward` method takes the source sentence, target sentence and a teacher-forcing ratio. The teacher forcing ratio is used when training our model. When decoding, at each time-step we will predict what the next token in the target sequence will be from the previous tokens decoded. With probability equal to the teaching forcing ratio (`teacher_forcing_ratio`) we will use the actual ground-truth next token in the sequence as the input to the decoder during the next time-step. However, with probability `1 - teacher_forcing_ratio`, we will use the token that the model predicted as the next input to the model, even if it doesn't match the actual next token in the sequence. Note that the teacher forcing ratio is only done during training and should be shut off during evaluation.\n", "\n", "The first thing we do in the `forward` method is to create an `outputs` tensor that will store all of our predictions, $\\hat{Y}$.\n", "\n", "We then feed the input/source sentence, $X$/`src`, into the encoder and receive our final hidden and cell states.\n", "\n", "The first input to the decoder is the start of sequence (``) token. As our `trg` tensor already has the `` token appended (all the way back when we defined the `init_token` in our target field) we get our $y_1$ by slicing into it. We know how long our target sentences should be (`max_len`), so we loop that many times. During each iteration of the loop, we:\n", "- pass the input, previous hidden and previous cell states ($y_t, s_{t-1}, c_{t-1}$) into the decoder\n", "- receive a prediction, next hidden state and next cell state ($\\hat{y}_{t+1}, s_{t}, c_{t}$) from the decoder\n", "- place our prediction, $\\hat{y}_{t+1}$/`output` in our tensor of predictions, $\\hat{Y}$/`outputs`\n", "- decide if we are going to \"teacher force\" or not\n", " - if we do, the next `input` is the ground-truth next token in the sequence, $y_{t+1}$/`trg[t]`\n", " - if we don't, the next `input` is the predicted next token in the sequence, $\\hat{y}_{t+1}$/`top1`, which we get by doing an `argmax` over the output tensor\n", " \n", "Once we've made all of our predictions, we return our tensor full of predictions, $\\hat{Y}$/`outputs`.\n", "\n", "**Note**: our decoder loop starts at 1, not 0. This means the 0th element of our `outputs` tensor remains all zeros. 
So our `trg` and `outputs` look something like:\n", "\n", "$$\\begin{align*}\n", "\\text{trg} = [\\text{<sos>}, &y_1, y_2, y_3, \\text{<eos>}]\\\\\n", "\\text{outputs} = [0, &\\hat{y}_1, \\hat{y}_2, \\hat{y}_3, \\text{<eos>}]\n", "\\end{align*}$$\n", "\n", "Later on when we calculate the loss, we cut off the first element of each tensor to get:\n", "\n", "$$\\begin{align*}\n", "\\text{trg} = [&y_1, y_2, y_3, \\text{<eos>}]\\\\\n", "\\text{outputs} = [&\\hat{y}_1, \\hat{y}_2, \\hat{y}_3, \\text{<eos>}]\n", "\\end{align*}$$\n", "\n", "All of this should make more sense after we look at the code in the next few sections. Feel free to check out the discussion in these two github issues for some more context on this topic. [issue-45](https://github.com/bentrevett/pytorch-seq2seq/issues/45) and [issue-46](https://github.com/bentrevett/pytorch-seq2seq/issues/46)" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "colab": {}, "colab_type": "code", "id": "pHotrPb9u7RJ" }, "outputs": [], "source": [ "class Seq2Seq(nn.Module):\n", " def __init__(self, encoder: Encoder, decoder: Decoder, device: torch.device):\n", " super().__init__()\n", " self.encoder = encoder\n", " self.decoder = decoder\n", " self.device = device\n", "\n", " assert encoder.hid_dim == decoder.hid_dim, \\\n", " 'Hidden dimensions of encoder and decoder must be equal!'\n", " assert encoder.n_layers == decoder.n_layers, \\\n", " 'Encoder and decoder must have equal number of layers!'\n", "\n", " def forward(self, src_batch: torch.LongTensor, trg_batch: torch.LongTensor,\n", " teacher_forcing_ratio: float=0.5):\n", "\n", " max_len, batch_size = trg_batch.shape\n", " trg_vocab_size = self.decoder.output_dim\n", "\n", " # tensor to store decoder's output\n", " outputs = torch.zeros(max_len, batch_size, trg_vocab_size).to(self.device)\n", "\n", " # last hidden & cell state of the encoder is used as the decoder's initial hidden state\n", " hidden, cell = self.encoder(src_batch)\n", "\n", " trg = trg_batch[0]\n", " for i in range(1, max_len):\n", " prediction, hidden, cell = self.decoder(trg, hidden, cell)\n", " outputs[i] = prediction\n", "\n", " if random.random() < teacher_forcing_ratio:\n", " trg = trg_batch[i]\n", " else:\n", " trg = prediction.argmax(1)\n", "\n", " return outputs" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 207 }, "colab_type": "code", "id": "vJCpmflDu7RM", "outputId": "ee8de530-9c16-4326-a2fe-309a9504c148" }, "outputs": [ { "data": { "text/plain": [ "Seq2Seq(\n", " (encoder): Encoder(\n", " (embedding): Embedding(7855, 256)\n", " (rnn): LSTM(256, 512, num_layers=2, dropout=0.5)\n", " )\n", " (decoder): Decoder(\n", " (embedding): Embedding(5893, 256)\n", " (rnn): LSTM(256, 512, num_layers=2, dropout=0.5)\n", " (out): Linear(in_features=512, out_features=5893, bias=True)\n", " )\n", ")" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# note that this implementation assumes that the size of the hidden layer\n", "# and the number of layers are the same between the encoder and decoder\n", "encoder = Encoder(INPUT_DIM, ENC_EMB_DIM, HID_DIM, N_LAYERS, ENC_DROPOUT)\n", "decoder = Decoder(OUTPUT_DIM, DEC_EMB_DIM, HID_DIM, N_LAYERS, DEC_DROPOUT)\n", "seq2seq = Seq2Seq(encoder, decoder, device).to(device)\n", "seq2seq" ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 35 }, "colab_type": "code", "id": "AaYWJV8Qu7RP", "outputId": "f7103afa-564d-4249-95cf-e9dbbc589e7a" },
"outputs": [ { "data": { "text/plain": [ "torch.Size([14, 128, 5893])" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "outputs = seq2seq(test_batch.src, test_batch.trg)\n", "outputs.shape" ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 35 }, "colab_type": "code", "id": "5nG9tXrDu7RW", "outputId": "e699554e-f40d-489b-af57-44540bec5d8f" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The model has 13,899,013 trainable parameters\n" ] } ], "source": [ "def count_parameters(model):\n", " return sum(p.numel() for p in model.parameters() if p.requires_grad)\n", "\n", "print(f'The model has {count_parameters(seq2seq):,} trainable parameters')" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "bg1sODOKNuIS" }, "source": [ "### Training Seq2Seq" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "4DuuKjl8NuIT" }, "source": [ "We've done the hard work of defining our seq2seq module. The final touch is to specify the training/evaluation loop." ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "colab": {}, "colab_type": "code", "id": "2bN056hyu7RY" }, "outputs": [], "source": [ "optimizer = optim.Adam(seq2seq.parameters())\n", "\n", "# ignore the padding index when calculating the loss\n", "PAD_IDX = target.vocab.stoi['']\n", "criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)" ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "colab": {}, "colab_type": "code", "id": "WZxHYlCJu7Ra" }, "outputs": [], "source": [ "def train(seq2seq, iterator, optimizer, criterion):\n", " seq2seq.train()\n", "\n", " epoch_loss = 0\n", " for batch in iterator:\n", " optimizer.zero_grad()\n", " outputs = seq2seq(batch.src, batch.trg)\n", "\n", " # 1. as mentioned in the seq2seq section, we will\n", " # cut off the first element when performing the evaluation\n", " # 2. 
the loss function only works on 2d inputs\n", " # with 1d targets we need to flatten each of them\n", " outputs_flatten = outputs[1:].view(-1, outputs.shape[-1])\n", " trg_flatten = batch.trg[1:].view(-1)\n", " loss = criterion(outputs_flatten, trg_flatten)\n", "\n", " loss.backward()\n", " optimizer.step()\n", "\n", " epoch_loss += loss.item()\n", "\n", " return epoch_loss / len(iterator)" ] }, { "cell_type": "code", "execution_count": 26, "metadata": { "colab": {}, "colab_type": "code", "id": "8LiGyozLu7Rc" }, "outputs": [], "source": [ "def evaluate(seq2seq, iterator, criterion):\n", " seq2seq.eval()\n", "\n", " epoch_loss = 0\n", " with torch.no_grad():\n", " for batch in iterator:\n", " # turn off teacher forcing\n", " outputs = seq2seq(batch.src, batch.trg, teacher_forcing_ratio=0) \n", "\n", " # trg = [trg sent len, batch size]\n", " # output = [trg sent len, batch size, output dim]\n", " outputs_flatten = outputs[1:].view(-1, outputs.shape[-1])\n", " trg_flatten = batch.trg[1:].view(-1)\n", " loss = criterion(outputs_flatten, trg_flatten)\n", " epoch_loss += loss.item()\n", "\n", " return epoch_loss / len(iterator)" ] }, { "cell_type": "code", "execution_count": 27, "metadata": { "colab": {}, "colab_type": "code", "id": "upC7CqpZu7Re" }, "outputs": [], "source": [ "def epoch_time(start_time, end_time):\n", " elapsed_time = end_time - start_time\n", " elapsed_mins = int(elapsed_time / 60)\n", " elapsed_secs = int(elapsed_time - (elapsed_mins * 60))\n", " return elapsed_mins, elapsed_secs" ] }, { "cell_type": "code", "execution_count": 28, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 1000 }, "colab_type": "code", "id": "vl5m50OHu7Rg", "outputId": "d5189287-7926-47dc-bf6b-c9e6035535c1" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 01 | Time: 1m 12s\n", "\tTrain Loss: 5.023 | Train PPL: 151.870\n", "\t Val. Loss: 4.904 | Val. PPL: 134.856\n", "Epoch: 02 | Time: 1m 12s\n", "\tTrain Loss: 4.396 | Train PPL: 81.134\n", "\t Val. Loss: 4.651 | Val. PPL: 104.687\n", "Epoch: 03 | Time: 1m 12s\n", "\tTrain Loss: 4.076 | Train PPL: 58.924\n", "\t Val. Loss: 4.411 | Val. PPL: 82.381\n", "Epoch: 04 | Time: 1m 12s\n", "\tTrain Loss: 3.811 | Train PPL: 45.217\n", "\t Val. Loss: 4.314 | Val. PPL: 74.703\n", "Epoch: 05 | Time: 1m 12s\n", "\tTrain Loss: 3.569 | Train PPL: 35.482\n", "\t Val. Loss: 4.014 | Val. PPL: 55.342\n", "Epoch: 06 | Time: 1m 12s\n", "\tTrain Loss: 3.355 | Train PPL: 28.659\n", "\t Val. Loss: 3.933 | Val. PPL: 51.046\n", "Epoch: 07 | Time: 1m 12s\n", "\tTrain Loss: 3.187 | Train PPL: 24.222\n", "\t Val. Loss: 3.811 | Val. PPL: 45.207\n", "Epoch: 08 | Time: 1m 12s\n", "\tTrain Loss: 3.028 | Train PPL: 20.662\n", "\t Val. Loss: 3.810 | Val. PPL: 45.140\n", "Epoch: 09 | Time: 1m 12s\n", "\tTrain Loss: 2.863 | Train PPL: 17.513\n", "\t Val. Loss: 3.709 | Val. PPL: 40.809\n", "Epoch: 10 | Time: 1m 12s\n", "\tTrain Loss: 2.751 | Train PPL: 15.661\n", "\t Val. Loss: 3.755 | Val. PPL: 42.746\n", "Epoch: 11 | Time: 1m 12s\n", "\tTrain Loss: 2.615 | Train PPL: 13.666\n", "\t Val. Loss: 3.727 | Val. PPL: 41.568\n", "Epoch: 12 | Time: 1m 12s\n", "\tTrain Loss: 2.481 | Train PPL: 11.959\n", "\t Val. Loss: 3.692 | Val. PPL: 40.135\n", "Epoch: 13 | Time: 1m 12s\n", "\tTrain Loss: 2.389 | Train PPL: 10.898\n", "\t Val. Loss: 3.734 | Val. PPL: 41.846\n", "Epoch: 14 | Time: 1m 12s\n", "\tTrain Loss: 2.281 | Train PPL: 9.791\n", "\t Val. Loss: 3.748 | Val. 
PPL: 42.419\n", "Epoch: 15 | Time: 1m 12s\n", "\tTrain Loss: 2.179 | Train PPL: 8.838\n", "\t Val. Loss: 3.722 | Val. PPL: 41.360\n", "Epoch: 16 | Time: 1m 12s\n", "\tTrain Loss: 2.082 | Train PPL: 8.019\n", "\t Val. Loss: 3.798 | Val. PPL: 44.629\n", "Epoch: 17 | Time: 1m 12s\n", "\tTrain Loss: 2.017 | Train PPL: 7.514\n", "\t Val. Loss: 3.731 | Val. PPL: 41.717\n", "Epoch: 18 | Time: 1m 12s\n", "\tTrain Loss: 1.912 | Train PPL: 6.767\n", "\t Val. Loss: 3.791 | Val. PPL: 44.289\n", "Epoch: 19 | Time: 1m 11s\n", "\tTrain Loss: 1.839 | Train PPL: 6.292\n", "\t Val. Loss: 3.789 | Val. PPL: 44.197\n", "Epoch: 20 | Time: 1m 11s\n", "\tTrain Loss: 1.758 | Train PPL: 5.802\n", "\t Val. Loss: 3.880 | Val. PPL: 48.423\n" ] } ], "source": [ "N_EPOCHS = 20\n", "best_valid_loss = float('inf')\n", "\n", "for epoch in range(N_EPOCHS): \n", " start_time = time.time()\n", " train_loss = train(seq2seq, train_iterator, optimizer, criterion)\n", " valid_loss = evaluate(seq2seq, valid_iterator, criterion)\n", " end_time = time.time()\n", "\n", " epoch_mins, epoch_secs = epoch_time(start_time, end_time)\n", "\n", " if valid_loss < best_valid_loss:\n", " best_valid_loss = valid_loss\n", " torch.save(seq2seq.state_dict(), 'tut1-model.pt')\n", "\n", " # it's easier to see a change in perplexity between epoch as it's an exponential\n", " # of the loss, hence the scale of the measure is much bigger\n", " print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')\n", " print(f'\\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')\n", " print(f'\\t Val. Loss: {valid_loss:.3f} | Val. PPL: {math.exp(valid_loss):7.3f}')" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "kG6YrsRqNuIb" }, "source": [ "### Evaluating Seq2Seq" ] }, { "cell_type": "code", "execution_count": 29, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 35 }, "colab_type": "code", "id": "4XXs0ut3u7Ri", "outputId": "299ccbc8-7b1e-44c3-a6a7-601312d7b978" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "| Test Loss: 3.650 | Test PPL: 38.477 |\n" ] } ], "source": [ "seq2seq.load_state_dict(torch.load('tut1-model.pt'))\n", "\n", "test_loss = evaluate(seq2seq, test_iterator, criterion)\n", "print(f'| Test Loss: {test_loss:.3f} | Test PPL: {math.exp(test_loss):7.3f} |')" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "cKstx7PxaeOT" }, "source": [ "Here, we pick a random example in our dataset, print out the original source and target sentence. Then takes a look at whether the \"predicted\" target sentence generated by the model." ] }, { "cell_type": "code", "execution_count": 30, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 52 }, "colab_type": "code", "id": "S3U31H1iPlSz", "outputId": "72e60e7b-4391-479d-daaf-8b6b250a3b63" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "source sentence: . 
büsche vieler nähe der in freien im sind männer weiße junge zwei\n", "target sentence: two young , white males are outside near many bushes .\n" ] } ], "source": [ "example_idx = 0\n", "example = train_data.examples[example_idx]\n", "print('source sentence: ', ' '.join(example.src))\n", "print('target sentence: ', ' '.join(example.trg))" ] }, { "cell_type": "code", "execution_count": 31, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 52 }, "colab_type": "code", "id": "WNrAw9WwReEV", "outputId": "53b3577f-756a-48bc-8ce7-d854d5f2f8f8" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([13, 1])\n" ] }, { "data": { "text/plain": [ "torch.Size([13, 1, 5893])" ] }, "execution_count": 31, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "src_tensor = source.process([example.src]).to(device)\n", "trg_tensor = target.process([example.trg]).to(device)\n", "print(trg_tensor.shape)\n", "\n", "seq2seq.eval()\n", "with torch.no_grad():\n", " outputs = seq2seq(src_tensor, trg_tensor, teacher_forcing_ratio=0)\n", "\n", "outputs.shape" ] }, { "cell_type": "code", "execution_count": 32, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 35 }, "colab_type": "code", "id": "x7qOB8YI2pUT", "outputId": "ee195603-6586-4cb6-8240-5d14320e9b3f" }, "outputs": [ { "data": { "text/plain": [ "'two young men in large large are near some sort . '" ] }, "execution_count": 32, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "output_idx = outputs[1:].squeeze(1).argmax(1)\n", "' '.join([target.vocab.itos[idx] for idx in output_idx])" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "YgPOZFr5aeOa" }, "source": [ "## Summary" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "Yv-fzIU7aeOb" }, "source": [ "In this document:\n", "\n", "- We took a stab at implementing a vanilla version of the seq2seq model, and trained it on a German to English translation task.\n", "- Implemented the trick introduced by the original seq2seq paper where they reverse the order of the tokens in the source sentence.\n", "\n", "There are a lot of other tricks/ideas mentioned in the original paper that are worth exploring, e.g.\n", "\n", "- An LSTM with 4 layers was chosen.\n", "- Beam Search was also used to decode the sentence.\n", "- Instead of only relying on log-loss or perplexity, BLEU score was used as an additional metric to evaluate the quality of the translation." ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "nqqhN5-yNuIl" }, "source": [ "# Reference" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "ZbnFj1oHavDp" }, "source": [ "- [Blog: A Comprehensive Introduction to Torchtext (Practical Torchtext part 1)](https://mlexplained.com/2018/02/08/a-comprehensive-tutorial-to-torchtext/)\n", "- [Jupyter Notebook: Using TorchText with Your Own Datasets](https://nbviewer.jupyter.org/github/bentrevett/pytorch-sentiment-analysis/blob/master/A%20-%20Using%20TorchText%20with%20Your%20Own%20Datasets.ipynb)\n", "- [Jupyter Notebook: Sequence to Sequence Learning with Neural Networks](https://nbviewer.jupyter.org/github/bentrevett/pytorch-seq2seq/blob/master/1%20-%20Sequence%20to%20Sequence%20Learning%20with%20Neural%20Networks.ipynb)\n", "- [Paper: Sutskever, I., Vinyals, O., and Le, Q. (2014).
Sequence to sequence learning with neural networks.](https://arxiv.org/abs/1409.3215)" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "1_torch_seq2seq_intro.ipynb", "provenance": [], "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.4" }, "toc": { "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": true, "toc_position": { "height": "calc(100% - 180px)", "left": "10px", "top": "150px", "width": "272px" }, "toc_section_display": true, "toc_window_display": true }, "varInspector": { "cols": { "lenName": 16, "lenType": 16, "lenVar": 40 }, "kernels_config": { "python": { "delete_cmd_postfix": "", "delete_cmd_prefix": "del ", "library": "var_list.py", "varRefreshCmd": "print(var_dic_list())" }, "r": { "delete_cmd_postfix": ") ", "delete_cmd_prefix": "rm(", "library": "var_list.r", "varRefreshCmd": "cat(var_dic_list()) " } }, "types_to_exclude": [ "module", "function", "builtin_function_or_method", "instance", "_Feature" ], "window_display": false } }, "nbformat": 4, "nbformat_minor": 1 }