{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "This project will be broken up into several parts as follows:\n", "\n", "__Part 1:__ Preparing the words\n", "\n", "+ Inspecting the Dataset\n", "+ Using Word Embeddings\n", "+ Organizing the Data\n", "\n", "__Part 2:__ Building the Model\n", "\n", "+ Bi-Directional Encoder\n", "+ Building Attention\n", "+ Decoder with Attention\n", "\n", "__Part 3:__ Training the Model\n", "\n", "+ Training Function\n", "+ Training Loop\n", "\n", "__Part 4:__ Evaluation\n", "\n", "\n", "This project closely follows the [PyTorch Sequence to Sequence tutorial](https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html), while attempting to go into more depth on both the model implementation and the explanation. Thanks to [Sean Robertson](https://github.com/spro/practical-pytorch) and [PyTorch](https://pytorch.org/tutorials/) for providing such great tutorials.\n", "\n", "If you are working through this notebook, it is strongly recommended that you install [Jupyter Notebook Extensions](https://github.com/ipython-contrib/jupyter_contrib_nbextensions) so you can turn on collapsible headings, which make the notebook much easier to navigate." 
] }, { "cell_type": "code", "execution_count": 250, "metadata": {}, "outputs": [], "source": [ "# Before we get started, load all the packages we will need\n", "\n", "# PyTorch\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "import torch.optim as optim\n", "from torch.utils.data import Dataset, DataLoader\n", "\n", "import numpy as np\n", "import os.path\n", "import time\n", "import math\n", "import random\n", "import matplotlib.pyplot as plt\n", "import string\n", "\n", "# Use the GPU if available\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")" ] }, { "cell_type": "code", "execution_count": 251, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "device(type='cuda')" ] }, "execution_count": 251, "metadata": {}, "output_type": "execute_result" } ], "source": [ "device" ] }, { "cell_type": "markdown", "metadata": { "heading_collapsed": true }, "source": [ "## Part 1: Preparing the Words" ] }, { "cell_type": "markdown", "metadata": { "heading_collapsed": true, "hidden": true }, "source": [ "### Inspecting the Dataset" ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "The dataset that will be used is a text file of English sentences and a file of the corresponding French sentences.\n", "\n", "Each sentence is on its own line. The sentences will be split into a list." ] }, { "cell_type": "markdown", "metadata": { "heading_collapsed": true, "hidden": true }, "source": [ "#### Load the data\n", "The data will be stored in two lists where each item is a sentence. The lists are:\n", "+ `english_sentences`\n", "+ `french_sentences`\n", "\n", "Download the first dataset from the project's GitHub repo, then place the files in a `data` folder inside the notebook's directory; the loading code expects the paths `data/small_vocab_en` and `data/small_vocab_fr`." 
] }, { "cell_type": "code", "execution_count": 252, "metadata": { "hidden": true }, "outputs": [], "source": [ "with open('data/small_vocab_en', \"r\") as f:\n", " data1 = f.read()\n", "with open('data/small_vocab_fr', \"r\") as f:\n", " data2 = f.read()\n", " \n", "# The data is just in a text file with each sentence on its own line\n", "english_sentences = data1.split('\\n')\n", "french_sentences = data2.split('\\n')" ] }, { "cell_type": "code", "execution_count": 253, "metadata": { "hidden": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Number of English sentences: 137861 \n", "Number of French sentences: 137861 \n", "\n", "Example/Target pair:\n", "\n", " california is usually quiet during march , and it is usually hot in june .\n", " california est généralement calme en mars , et il est généralement chaud en juin .\n" ] } ], "source": [ "print('Number of English sentences:', len(english_sentences), \n", " '\\nNumber of French sentences:', len(french_sentences),'\\n')\n", "print('Example/Target pair:\\n')\n", "print(' '+english_sentences[2])\n", "print(' '+french_sentences[2])" ] }, { "cell_type": "markdown", "metadata": { "heading_collapsed": true, "hidden": true }, "source": [ "#### Vocabulary\n", "Let's take a closer look at the dataset.\n" ] }, { "cell_type": "code", "execution_count": 254, "metadata": { "hidden": true }, "outputs": [ { "data": { "text/plain": [ "['california',\n", " 'is',\n", " 'usually',\n", " 'quiet',\n", " 'during',\n", " 'march',\n", " ',',\n", " 'and',\n", " 'it',\n", " 'is',\n", " 'usually',\n", " 'hot',\n", " 'in',\n", " 'june',\n", " '.']" ] }, "execution_count": 254, "metadata": {}, "output_type": "execute_result" } ], "source": [ "english_sentences[2].split()" ] }, { "cell_type": "code", "execution_count": 255, "metadata": { "hidden": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The longest english sentence in our dataset is: 17\n" ] } ], "source": [ "max_en_length = 
0\n", "for sentence in english_sentences:\n", " length = len(sentence.split())\n", " max_en_length = max(max_en_length, length)\n", "print(\"The longest english sentence in our dataset is:\", max_en_length)" ] }, { "cell_type": "code", "execution_count": 256, "metadata": { "hidden": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The longest french sentence in our dataset is: 23\n" ] } ], "source": [ "max_fr_length = 0\n", "for sentence in french_sentences:\n", " length = len(sentence.split())\n", " max_fr_length = max(max_fr_length, length)\n", "print(\"The longest french sentence in our dataset is:\", max_fr_length)" ] }, { "cell_type": "code", "execution_count": 257, "metadata": { "hidden": true }, "outputs": [], "source": [ "max_seq_length = max(max_fr_length, max_en_length) + 1\n", "seq_length = max_seq_length" ] }, { "cell_type": "code", "execution_count": 258, "metadata": { "hidden": true }, "outputs": [], "source": [ "en_word_count = {}\n", "fr_word_count = {}\n", "\n", "for sentence in english_sentences:\n", " for word in sentence.split():\n", " if word in en_word_count:\n", " en_word_count[word] += 1\n", " else:\n", " en_word_count[word] = 1\n", " \n", "for sentence in french_sentences:\n", " for word in sentence.split():\n", " if word in fr_word_count:\n", " fr_word_count[word] += 1\n", " else:\n", " fr_word_count[word] = 1\n" ] }, { "cell_type": "code", "execution_count": 259, "metadata": { "hidden": true }, "outputs": [], "source": [ "# Add an end-of-sentence token to each word count dict\n", "en_word_count['<EOS>'] = len(english_sentences)\n", "fr_word_count['<EOS>'] = len(french_sentences)" ] }, { "cell_type": "code", "execution_count": 260, "metadata": { "hidden": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Number of unique English words: 228\n", "Number of unique French words: 356\n" ] } ], "source": [ "print('Number of unique English words:', len(en_word_count))\n", "print('Number of unique French words:', len(fr_word_count))" ] }, { "cell_type": "code", "execution_count": 261, "metadata": { "hidden": true }, "outputs": [], "source": [ "def get_value(items_tuple):\n", " return items_tuple[1]\n", "\n", "# Sort the word counts to see which words are most/least common\n", "sorted_en_words = sorted(en_word_count.items(), key=get_value, reverse=True)" ] }, { "cell_type": "code", "execution_count": 262, "metadata": { "hidden": true }, "outputs": [ { "data": { "text/plain": [ "[('is', 205858),\n", " (',', 140897),\n", " ('<EOS>', 137861),\n", " ('.', 129039),\n", " ('in', 75525),\n", " ('it', 75137),\n", " ('during', 74933),\n", " ('the', 67628),\n", " ('but', 63987),\n", " ('and', 59850)]" ] }, "execution_count": 262, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sorted_en_words[:10]" ] }, { "cell_type": "code", "execution_count": 263, "metadata": { "hidden": true }, "outputs": [], "source": [ "sorted_fr_words = sorted(fr_word_count.items(), key=get_value, reverse=True)" ] }, { "cell_type": "code", "execution_count": 264, "metadata": { "hidden": true }, "outputs": [ { "data": { "text/plain": [ "[('est', 196809),\n", " ('<EOS>', 137861),\n", " ('.', 135619),\n", " (',', 123135),\n", " ('en', 105768),\n", " ('il', 84079),\n", " ('les', 65255),\n", " ('mais', 63987),\n", " ('et', 59851),\n", " ('la', 49861)]" ] }, "execution_count": 264, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sorted_fr_words[:10]" ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "So the vocabulary is pretty small. We may want a bigger dataset later, but we'll see how this one does." ] }, { "cell_type": "markdown", "metadata": { "heading_collapsed": true, "hidden": true }, "source": [ "#### Alternate Dataset\n", "Skip this section for now. You can come back and try training on this second dataset later. 
It is more diverse, so it takes longer to train.\n", "\n", "Download the French-English dataset from [here](http://www.manythings.org/anki/). You could train the model on any of the other language pairs, but you would need different word embeddings, or they would need to be trained from scratch." ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "hidden": true }, "outputs": [], "source": [ "with open('data/fra.txt', \"r\") as f:\n", " data1 = f.read()" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "hidden": true }, "outputs": [], "source": [ "pairs = data1.split('\\n')\n", "english_sentences = []\n", "french_sentences = []\n", "\n", "# Build the punctuation-removal table once, and cap the sentence length\n", "max_sent_length = 10\n", "punctuation_table = str.maketrans({c: None for c in string.punctuation})\n", "\n", "for pair in pairs:\n", " pair_split = pair.split('\\t')\n", " if len(pair_split) != 2:\n", " continue\n", " english = pair_split[0].lower()\n", " french = pair_split[1].lower()\n", " \n", " # Remove punctuation and limit sentence length\n", " english = english.translate(punctuation_table)\n", " french = french.translate(punctuation_table)\n", " if len(english.split()) > max_sent_length or len(french.split()) > max_sent_length:\n", " continue\n", " \n", " english_sentences.append(english)\n", " french_sentences.append(french)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "hidden": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "139692 139692\n" ] }, { "data": { "text/plain": [ "['i', 'have', 'to', 'fight']" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "print(len(english_sentences), len(french_sentences))\n", "english_sentences[10000].split()\n" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "hidden": true }, "outputs": [ { "data": { "text/plain": [ "['il', 'me', 'faut', 'me', 'battre']" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], 
"source": [ "french_sentences[10000].split()" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "hidden": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['would', 'you', 'consider', 'taking', 'care', 'of', 'my', 'children', 'next', 'saturday']\n" ] }, { "data": { "text/plain": [ "['pourriezvous',\n", " 'réfléchir',\n", " 'à',\n", " 'vous',\n", " 'occuper',\n", " 'de',\n", " 'mes',\n", " 'enfants',\n", " 'samedi',\n", " 'prochain']" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "print(english_sentences[-100].split())\n", "french_sentences[-100].split()\n" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "hidden": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The longest english sentence in our dataset is: 10\n" ] } ], "source": [ "max_en_length = 0\n", "for sentence in english_sentences:\n", " length = len(sentence.split())\n", " max_en_length = max(max_en_length, length)\n", "print(\"The longest english sentence in our dataset is:\", max_en_length) " ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "hidden": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The longest french sentence in our dataset is: 10\n" ] } ], "source": [ "max_fr_length = 0\n", "for sentence in french_sentences:\n", " length = len(sentence.split())\n", " max_fr_length = max(max_fr_length, length)\n", "print(\"The longest french sentence in our dataset is:\", max_fr_length) " ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "hidden": true }, "outputs": [], "source": [ "max_seq_length = max(max_fr_length, max_en_length) + 1\n", "seq_length = max_seq_length" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "hidden": true }, "outputs": [], "source": [ "en_word_count = {}\n", "fr_word_count = {}\n", "\n", "for sentence in english_sentences:\n", " for word in sentence.split():\n", " if word 
in en_word_count:\n", " en_word_count[word] += 1\n", " else:\n", " en_word_count[word] = 1\n", " \n", "for sentence in french_sentences:\n", " for word in sentence.split():\n", " if word in fr_word_count:\n", " fr_word_count[word] += 1\n", " else:\n", " fr_word_count[word] = 1\n" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "hidden": true }, "outputs": [], "source": [ "# Add an end-of-sentence token to each word count dict\n", "en_word_count['<EOS>'] = len(english_sentences)\n", "fr_word_count['<EOS>'] = len(french_sentences)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "hidden": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Number of unique English words: 12603\n", "Number of unique French words: 25809\n" ] } ], "source": [ "print('Number of unique English words:', len(en_word_count))\n", "print('Number of unique French words:', len(fr_word_count))" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "hidden": true }, "outputs": [], "source": [ "# Map words to indices, starting at 3 so indices 0-2 stay free for special tokens\n", "fr_word2idx = {k: v+3 for v, k in enumerate(fr_word_count.keys())}\n", "en_word2idx = {k: v+3 for v, k in enumerate(en_word_count.keys())}" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "hidden": true }, "outputs": [], "source": [ "fr_word2idx['