{ "cells": [ { "cell_type": "markdown", "metadata": { "toc": true }, "source": [ "

Table of Contents

\n", "
" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "ExecuteTime": { "end_time": "2021-05-17T04:59:51.687711Z", "start_time": "2021-05-17T04:59:46.035704Z" }, "tags": [] }, "outputs": [ { "data": { "text/html": [ "\n", "\n", "\n", "" ], "text/plain": [ "" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# code for loading the format for the notebook\n", "import os\n", "\n", "# path : store the current path to convert back to it later\n", "path = os.getcwd()\n", "os.chdir(os.path.join('..', '..', 'notebook_format'))\n", "\n", "from formats import load_style\n", "load_style(css_style='custom2.css', plot_style=False)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.10/dist-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Author: Ethen\n", "\n", "Last updated: 2023-09-11 22:48:58\n", "\n", "Python implementation: CPython\n", "Python version : 3.10.6\n", "IPython version : 8.13.2\n", "\n", "datasets : 2.14.4\n", "numpy : 1.23.2\n", "torch : 2.0.1\n", "tokenizers: 0.13.3\n", "\n" ] } ], "source": [ "os.chdir(path)\n", "\n", "# 1. magic for inline plot\n", "# 2. magic to print version\n", "# 3. magic so that the notebook will reload external python modules\n", "# 4. 
magic to enable retina (high resolution) plots\n", "# https://gist.github.com/minrk/3301035\n", "%matplotlib inline\n", "%load_ext watermark\n", "%load_ext autoreload\n", "%autoreload 2\n", "%config InlineBackend.figure_format='retina'\n", "\n", "import os\n", "import math\n", "import time\n", "import torch\n", "import random\n", "import numpy as np\n", "import torch.nn as nn\n", "import torch.optim as optim\n", "from torch.utils.data import DataLoader\n", "from tokenizers import ByteLevelBPETokenizer\n", "from datasets import load_dataset, disable_progress_bar\n", "\n", "# prevents progress bar and logging from flooding our document\n", "disable_progress_bar()\n", "\n", "%watermark -a 'Ethen' -d -t -v -u -p datasets,numpy,torch,tokenizers" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Transformer" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A Seq2Seq-based machine translation system usually comprises two main components: an encoder that encodes the source sentence into context vectors, and a decoder that decodes the context vectors into the target sentence. The Transformer model is no different in this regard. At the time of writing, its growing popularity is primarily due to **self attention layers** and **parallel computation**.\n", "\n", "Previous RNN-based encoders and decoders are constrained to sequential computation. A hidden state at time $t$ in a recurrent layer has only seen token $x_t$ and the tokens before it. Even though this recurrence gives us the benefit of modeling long dependencies, it hinders training speed, as we can't process the next time step until we finish processing the current one. The Transformer model aims to mitigate this issue by relying solely on the attention mechanism, where each context vector produced by the model has seen the tokens at all positions within the input sequence. In other words, instead of compressing the entire source sentence, $X = (x_1, ...
, x_n)$, into a single context vector, $z$, it produces a sequence of context vectors, $Z = (z_1, ... , z_n)$, in one parallel computation. We'll get to the details of self attention, the attention mechanism used throughout the Transformer model, in later sections. One important thing to note here is that the breakthrough of this model is not the invention of the attention mechanism, as this concept existed well before. The highlight is that we can build a highly performant model with the attention mechanism in isolation, i.e. without recurrent (RNN) or convolutional (CNN) neural networks in the mix.\n", "\n", "In this article, we will implement the Transformer module from the famous Attention Is All You Need paper [[9]](https://arxiv.org/abs/1706.03762). This implementation's structure is largely based on [[1]](https://nbviewer.jupyter.org/github/bentrevett/pytorch-seq2seq/blob/master/6%20-%20Attention%20is%20All%20You%20Need.ipynb). The primary differences are that we'll be using Huggingface's datasets instead of torchtext for data loading, and that we'll also showcase how to implement the Transformer module using PyTorch's built-in Transformer Encoder and Decoder blocks." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Data Preprocessing" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We'll be using the [Multi30k dataset](http://www.statmt.org/wmt16/multimodal-task.html) to demonstrate the Transformer model on a machine translation task. Its German to English training set contains around 29K sentence pairs. We'll start off by downloading the raw dataset and extracting it. Feel free to swap this step with any other machine translation dataset. If the original links for these datasets fail to load, use this alternative [google drive link](https://drive.google.com/drive/folders/10zrAb3BHz4xCuyqUvqphiBTivahX0V8I?usp=drive_link)."
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "import tarfile\n", "import zipfile\n", "import requests\n", "import subprocess\n", "from tqdm import tqdm\n", "from urllib.parse import urlparse\n", "\n", "\n", "def download_file(url: str, directory: str):\n", " \"\"\"\n", " Download the file at ``url`` to ``directory``.\n", " Extract the file content to ``directory`` if the original file\n", " is a tar, tar.gz, tgz or zip file.\n", "\n", " Parameters\n", " ----------\n", " url : str\n", " url of the file.\n", "\n", " directory : str\n", " Directory to download the file to.\n", " \"\"\"\n", " response = requests.get(url, stream=True)\n", " response.raise_for_status()\n", "\n", " content_len = response.headers.get('Content-Length')\n", " total = int(content_len) if content_len is not None else 0\n", "\n", " os.makedirs(directory, exist_ok=True)\n", " file_name = get_file_name_from_url(url)\n", " file_path = os.path.join(directory, file_name)\n", "\n", " with tqdm(unit='B', total=total) as pbar, open(file_path, 'wb') as f:\n", " for chunk in response.iter_content(chunk_size=1024):\n", " if chunk:\n", " pbar.update(len(chunk))\n", " f.write(chunk)\n", "\n", " extract_compressed_file(file_path, directory)\n", "\n", "\n", "def extract_compressed_file(compressed_file_path: str, directory: str):\n", " \"\"\"\n", " Extract a compressed file to ``directory``.
Supports zip, tar.gz, tgz,\n", " tar extensions.\n", "\n", " Parameters\n", " ----------\n", " compressed_file_path : str\n", " Path to the compressed file.\n", "\n", " directory : str\n", " The file will be extracted to this directory.\n", " \"\"\"\n", " basename = os.path.basename(compressed_file_path)\n", " if basename.endswith('.zip'):\n", " with zipfile.ZipFile(compressed_file_path, \"r\") as zip_f:\n", " zip_f.extractall(directory)\n", " elif basename.endswith(('.tar.gz', '.tgz', '.tar')):\n", " # tarfile.open auto-detects gzip compression\n", " with tarfile.open(compressed_file_path) as f:\n", " f.extractall(directory)\n", "\n", "\n", "def get_file_name_from_url(url: str) -> str:\n", " \"\"\"\n", " Return the file_name from a URL\n", "\n", " Parameters\n", " ----------\n", " url : str\n", " URL to extract file_name from\n", "\n", " Returns\n", " -------\n", " file_name : str\n", " \"\"\"\n", " parse = urlparse(url)\n", " return os.path.basename(parse.path)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "urls = [\n", " 'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/training.tar.gz',\n", " 'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/validation.tar.gz',\n", " 'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/mmt16_task1_test.tar.gz'\n", "]\n", "directory = './translation/wmt16'\n", "for url in urls:\n", " download_file(url, directory)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We print out the content of the data directory along with some sample data."
] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "mmt16_task1_test.tar.gz test.en train.en\t val.de validation.tar.gz\n", "test.de\t\t\t train.de training.tar.gz val.en\n" ] } ], "source": [ "!ls $directory" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche.\n", "Mehrere Männer mit Schutzhelmen bedienen ein Antriebsradsystem.\n", "Ein kleines Mädchen klettert in ein Spielhaus aus Holz.\n", "Ein Mann in einem blauen Hemd steht auf einer Leiter und putzt ein Fenster.\n", "Zwei Männer stehen am Herd und bereiten Essen zu.\n", "Ein Mann in grün hält eine Gitarre, während der andere Mann sein Hemd ansieht.\n", "Ein Mann lächelt einen ausgestopften Löwen an.\n", "Ein schickes Mädchen spricht mit dem Handy während sie langsam die Straße entlangschwebt.\n", "Eine Frau mit einer großen Geldbörse geht an einem Tor vorbei.\n", "Jungen tanzen mitten in der Nacht auf Pfosten.\n" ] } ], "source": [ "!head $directory/train.de" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Two young, White males are outside near many bushes.\n", "Several men in hard hats are operating a giant pulley system.\n", "A little girl climbing into a wooden playhouse.\n", "A man in a blue shirt is standing on a ladder cleaning a window.\n", "Two men are at the stove preparing food.\n", "A man in green holds a guitar while the other man observes his shirt.\n", "A man is smiling at a stuffed lion\n", "A trendy girl talking on her cellphone while gliding slowly down the street.\n", "A woman with a large purse is walking by a gate.\n", "Boys dancing on poles in the middle of the night.\n" ] } ], "source": [ "!head $directory/train.en" ] }, { "cell_type": "markdown", "metadata": {}, "source": 
[ "The original dataset splits the source and target languages into two separate files (e.g. `train.de` and `train.en` hold the German and English training sentences). This format is useful when we wish to train a tokenizer on the source or target language, as we'll soon see.\n", "\n", "On the other hand, having each source and target pair together in a single file makes it easier to load them in batches for training or evaluating our machine translation model. We'll create the paired dataset and then [load it](https://huggingface.co/docs/datasets/loading_datasets.html#csv-files). For this step, it helps to have a basic understanding of Huggingface's [datasets](https://huggingface.co/docs/datasets/quicktour.html) library." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "def create_translation_data(\n", " source_input_path: str,\n", " target_input_path: str,\n", " output_path: str,\n", " delimiter: str = '\\t',\n", " encoding: str = 'utf-8'\n", "):\n", " \"\"\"\n", " Creates the paired source and target dataset from the separated ones.\n", " e.g.
creates `train.tsv` from `train.de` and `train.en`\n", " \"\"\"\n", " with open(source_input_path, encoding=encoding) as f_source_in, \\\n", " open(target_input_path, encoding=encoding) as f_target_in, \\\n", " open(output_path, 'w', encoding=encoding) as f_out:\n", "\n", " for source_raw in f_source_in:\n", " source_raw = source_raw.strip()\n", " target_raw = f_target_in.readline().strip()\n", " if source_raw and target_raw:\n", " output_line = source_raw + delimiter + target_raw + '\\n'\n", " f_out.write(output_line)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'train': ['train.tsv'], 'val': ['val.tsv'], 'test': ['test.tsv']}" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "source_lang = 'de'\n", "target_lang = 'en'\n", "\n", "data_files = {}\n", "for split in ['train', 'val', 'test']:\n", " source_input_path = os.path.join(directory, f'{split}.{source_lang}')\n", " target_input_path = os.path.join(directory, f'{split}.{target_lang}')\n", " output_path = f'{split}.tsv'\n", " create_translation_data(source_input_path, target_input_path, output_path)\n", " data_files[split] = [output_path]\n", "\n", "data_files" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DatasetDict({\n", " train: Dataset({\n", " features: ['de', 'en'],\n", " num_rows: 29000\n", " })\n", " val: Dataset({\n", " features: ['de', 'en'],\n", " num_rows: 1014\n", " })\n", " test: Dataset({\n", " features: ['de', 'en'],\n", " num_rows: 1000\n", " })\n", "})" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dataset_dict = load_dataset(\n", " 'csv',\n", " delimiter='\\t',\n", " column_names=[source_lang, target_lang],\n", " data_files=data_files\n", ")\n", "dataset_dict" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can access each split, and each record/pair within it, with the
following syntax." ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'de': 'Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche.',\n", " 'en': 'Two young, White males are outside near many bushes.'}" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dataset_dict['train'][0]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "From each raw pair, we need to use or train a tokenizer to convert the text into numerical indices. Here we'll train our tokenizers from scratch using Huggingface's [tokenizers](https://github.com/huggingface/tokenizers) library. Feel free to swap this step out for other tokenization procedures. What's important is to leave room for special tokens: the init token that marks the start of a sentence, the end of sentence token that marks the end of a sentence, and the padding token that pads the sentences in a batch to equal length." ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "./translation/wmt16/train.de ./translation/wmt16/train.en\n" ] } ], "source": [ "# use only the training set to train our tokenizer\n", "split = 'train'\n", "source_input_path = os.path.join(directory, f'{split}.{source_lang}')\n", "target_input_path = os.path.join(directory, f'{split}.{target_lang}')\n", "print(source_input_path, target_input_path)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "elapsed: 1.4205942153930664\n", "source vocab size: 5000\n", "target vocab size: 5000\n" ] } ], "source": [ "# special tokens must be distinct, non-empty strings\n", "init_token = '<sos>'\n", "eos_token = '<eos>'\n", "pad_token = '<pad>'\n", "\n", "tokenizer_params = {\n", " 'min_frequency': 2,\n", " 'vocab_size': 5000,\n", " 'show_progress': False,\n", " 'special_tokens': [init_token, eos_token, pad_token]\n", "}\n", "\n", "start_time = 
time.time()\n", "source_tokenizer = ByteLevelBPETokenizer(lowercase=True)\n", "source_tokenizer.train(source_input_path, **tokenizer_params)\n", "\n", "target_tokenizer = ByteLevelBPETokenizer(lowercase=True)\n", "target_tokenizer.train(target_input_path, **tokenizer_params)\n", "end_time = time.time()\n", "\n", "print('elapsed: ', end_time - start_time)\n", "print('source vocab size: ', source_tokenizer.get_vocab_size())\n", "print('target vocab size: ', target_tokenizer.get_vocab_size())" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "source_eos_idx = source_tokenizer.token_to_id(eos_token)\n", "target_eos_idx = target_tokenizer.token_to_id(eos_token)\n", "\n", "source_init_idx = source_tokenizer.token_to_id(init_token)\n", "target_init_idx = target_tokenizer.token_to_id(init_token)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We'll perform this tokenization step for all our datasets up front, so that we do as little preprocessing as possible while feeding data to the model. Note that we do not perform the padding step at this stage; padding is done per batch by the data loader." ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "elapsed: 2.2638769149780273\n" ] }, { "data": { "text/plain": [ "DatasetDict({\n", " train: Dataset({\n", " features: ['de', 'en', 'source_ids', 'target_ids'],\n", " num_rows: 29000\n", " })\n", " val: Dataset({\n", " features: ['de', 'en', 'source_ids', 'target_ids'],\n", " num_rows: 1014\n", " })\n", " test: Dataset({\n", " features: ['de', 'en', 'source_ids', 'target_ids'],\n", " num_rows: 1000\n", " })\n", "})" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def encode(example):\n", " \"\"\"\n", " Encode the raw text into numerical token ids.
Creates two new fields\n", " ``source_ids`` and ``target_ids``.\n", " Also prepends the init token and appends the eos token to each sentence.\n", " \"\"\"\n", " source_raw = example[source_lang]\n", " target_raw = example[target_lang]\n", " source_encoded = source_tokenizer.encode(source_raw).ids\n", " source_encoded = [source_init_idx] + source_encoded + [source_eos_idx]\n", " target_encoded = target_tokenizer.encode(target_raw).ids\n", " target_encoded = [target_init_idx] + target_encoded + [target_eos_idx]\n", " example['source_ids'] = source_encoded\n", " example['target_ids'] = target_encoded\n", " return example\n", "\n", "\n", "start_time = time.time()\n", "dataset_dict_encoded = dataset_dict.map(encode, num_proc=8)\n", "end_time = time.time()\n", "print('elapsed: ', end_time - start_time)\n", "\n", "dataset_dict_encoded" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'de': 'Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche.',\n", " 'en': 'Two young, White males are outside near many bushes.',\n", " 'source_ids': [0,\n", " 343,\n", " 377,\n", " 1190,\n", " 412,\n", " 648,\n", " 348,\n", " 659,\n", " 280,\n", " 326,\n", " 725,\n", " 1283,\n", " 262,\n", " 727,\n", " 706,\n", " 16,\n", " 1],\n", " 'target_ids': [0, 335, 372, 14, 369, 2181, 320, 493, 556, 1202, 3157, 16, 1]}" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dataset_train = dataset_dict_encoded['train']\n", "dataset_train[0]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The final data preprocessing step is to prepare the [DataLoader](https://pytorch.org/docs/stable/data.html#dataloader-collate-fn), which produces batches of tokenized ids for our model. The customized collate function performs the batching as well as the padding."
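, "\n", "\n", "One detail worth spelling out is how long each padded batch ends up being. The collate class below pads each batch to the `percentage`-th percentile of its sequence lengths, capped at `max_len`. Here is a minimal sketch of that rule in isolation, using hypothetical sequence lengths for a single batch:\n", "\n", "```python\n", "import numpy as np\n", "\n", "# hypothetical token counts of the sequences in one batch\n", "lengths = [5, 7, 9, 30]\n", "max_len = 100\n", "\n", "for percentage in (100, 50):\n", "    # same rule as TranslationPairCollate.process_encoded_text\n", "    pad_to = min(int(np.percentile(lengths, percentage)), max_len)\n", "    print(percentage, pad_to)\n", "```\n", "\n", "With `percentage=100` (the default used in this notebook), the batch is padded to its longest sequence, 30 tokens. With `percentage=50`, it is padded to 8 tokens (the interpolated median length), and longer sequences such as the 30-token one are truncated. This trades some information for much less padding."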
] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "class TranslationPairCollate:\n", "\n", " def __init__(self, max_len, pad_idx, device, percentage=100):\n", " self.device = device\n", " self.max_len = max_len\n", " self.pad_idx = pad_idx\n", " self.percentage = percentage\n", "\n", " def __call__(self, batch):\n", " source_batch = []\n", " source_len = []\n", " target_batch = []\n", " target_len = []\n", " for example in batch:\n", " source = example['source_ids']\n", " source_len.append(len(source))\n", " source_batch.append(source)\n", "\n", " target = example['target_ids']\n", " target_len.append(len(target))\n", " target_batch.append(target)\n", "\n", " source_padded = self.process_encoded_text(source_batch, source_len, self.max_len, self.pad_idx)\n", " target_padded = self.process_encoded_text(target_batch, target_len, self.max_len, self.pad_idx)\n", " return source_padded, target_padded\n", "\n", " def process_encoded_text(self, sequences, sequences_len, max_len, pad_idx):\n", " # pad to the given percentile of this batch's sequence lengths, capped at max_len\n", " sequences_len_percentile = int(np.percentile(sequences_len, self.percentage))\n", " max_len = min(sequences_len_percentile, max_len)\n", " padded_sequences = pad_sequences(sequences, max_len, pad_idx)\n", " return torch.LongTensor(padded_sequences)\n", "\n", "\n", "def pad_sequences(sequences, max_len, pad_idx):\n", " \"\"\"\n", " Pad the list of sequences (numerical token ids) to the same length.\n", " Sequences that are shorter than the specified ``max_len`` will be appended\n", " with the specified ``pad_idx``.
Those that are longer will be truncated.\n", "\n", " Parameters\n", " ----------\n", " sequences : list[list[int]]\n", " List of sequences, each a list of numerical token ids.\n", "\n", " max_len : int\n", " Maximum length of all sequences.\n", "\n", " pad_idx : int\n", " Padding index.\n", "\n", " Returns\n", " -------\n", " padded_sequences : 2d ndarray\n", " Array of shape (num_samples, max_len).\n", " \"\"\"\n", " num_samples = len(sequences)\n", " padded_sequences = np.full((num_samples, max_len), pad_idx)\n", " for i, sequence in enumerate(sequences):\n", " sequence = np.array(sequence)[:max_len]\n", " padded_sequences[i, :len(sequence)] = sequence\n", "\n", " return padded_sequences" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor([[ 0, 343, 377, ..., 2, 2, 2],\n", " [ 0, 640, 412, ..., 2, 2, 2],\n", " [ 0, 261, 542, ..., 2, 2, 2],\n", " ...,\n", " [ 0, 343, 500, ..., 2, 2, 2],\n", " [ 0, 296, 442, ..., 2, 2, 2],\n", " [ 0, 296, 317, ..., 2, 2, 2]]),\n", " tensor([[ 0, 335, 372, ..., 2, 2, 2],\n", " [ 0, 808, 400, ..., 2, 2, 2],\n", " [ 0, 67, 504, ..., 2, 2, 2],\n", " ...,\n", " [ 0, 335, 479, ..., 2, 2, 2],\n", " [ 0, 67, 413, ..., 2, 2, 2],\n", " [ 0, 67, 325, ..., 2, 2, 2]]))" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "max_len = 100\n", "batch_size = 128\n", "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", "\n", "pad_idx = source_tokenizer.token_to_id(pad_token)\n", "translation_pair_collate_fn = TranslationPairCollate(max_len, pad_idx, device)\n", "\n", "data_loader_params = {\n", " 'batch_size': batch_size,\n", " 'collate_fn': translation_pair_collate_fn,\n", " 'pin_memory': True\n", "}\n", "\n", "dataloader_train = DataLoader(dataset_train, **data_loader_params)\n", "\n", "# we can print out 1 batch of source and target\n", "source, target = next(iter(dataloader_train))\n", "source, target" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ 
"# create the data loader for both validation and test set\n", "dataset_val = dataset_dict_encoded['val']\n", "dataloader_val = DataLoader(dataset_val, **data_loader_params)\n", "\n", "dataset_test = dataset_dict_encoded['test']\n", "dataloader_test = DataLoader(dataset_test, **data_loader_params)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Model Architecture From Scratch" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Having prepared the data, we can now start implementing the Transformer model's architecture, which looks like the following:\n", "\n", "" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Position Wise Embedding" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First, input tokens are passed through a standard embedding layer. Next, since the entire sentence is fed into the model in one go, by default the model has no notion of the tokens' order within the sequence. We cope with this by using a second embedding layer, the positional embedding. This is an embedding layer where our input is not the token id but the token's position within the sequence. If we configure our position embedding to have a \"vocabulary\" size of 100, our model can accept sentences up to 100 tokens long.\n", "\n", "The original Transformer implementation from the Attention Is All You Need paper does not learn positional embeddings. Instead, it uses a fixed, static positional encoding. Modern Transformer architectures, such as BERT, use learned positional embeddings, so we use them in this tutorial.
Feel free to check out other tutorials [[7]](https://pytorch.org/tutorials/beginner/transformer_tutorial.html) [[8]](http://nlp.seas.harvard.edu/annotated-transformer/) to read more about the positional encoding used in the original Transformer model.\n", "\n", "Next, token and positional embeddings are combined using an elementwise sum, giving us a single vector that contains information on both the token and its position within the sequence. Before they are summed, token embeddings are multiplied by a scaling factor $\\sqrt{d_{model}}$, where $d_{model}$ is the hidden dimension size, `hid_dim`. This is said to reduce variance in the embeddings, and without this scaling factor the model can be difficult to train reliably. Dropout is then applied to the combined embeddings." ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "class PositionWiseEmbedding(nn.Module):\n", "\n", " def __init__(self, input_dim, hid_dim, max_len, dropout_p):\n", " super().__init__()\n", " self.input_dim = input_dim\n", " self.hid_dim = hid_dim\n", " self.max_len = max_len\n", " self.dropout_p = dropout_p\n", "\n", " self.tok_embedding = nn.Embedding(input_dim, hid_dim)\n", " self.pos_embedding = nn.Embedding(max_len, hid_dim)\n", " self.dropout = nn.Dropout(dropout_p)\n", " self.scale = torch.sqrt(torch.FloatTensor([hid_dim]))\n", "\n", " def forward(self, inputs):\n", "\n", " # inputs = [batch size, inputs len]\n", " batch_size = inputs.shape[0]\n", " inputs_len = inputs.shape[1]\n", "\n", " pos = torch.arange(0, inputs_len).unsqueeze(0).repeat(batch_size, 1).to(inputs.device)\n", " scale = self.scale.to(inputs.device)\n", " embedded = (self.tok_embedding(inputs) * scale) + self.pos_embedding(pos)\n", "\n", " # output = [batch size, inputs len, hid dim]\n", " output = self.dropout(embedded)\n", " return output" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/plain": [ 
"PositionWiseEmbedding(\n", " (tok_embedding): Embedding(5000, 64)\n", " (pos_embedding): Embedding(100, 64)\n", " (dropout): Dropout(p=0.5, inplace=False)\n", ")" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "input_dim = source_tokenizer.get_vocab_size()\n", "hid_dim = 64\n", "max_len = 100\n", "dropout_p = 0.5\n", "embedding = PositionWiseEmbedding(input_dim, hid_dim, max_len, dropout_p).to(device)\n", "embedding" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([128, 40, 64])" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "src_embedded = embedding(source.to(device))\n", "src_embedded.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The combined embeddings are then passed through $N$ encoder layers to get our context vectors $Z$. Before jumping straight into the encoder layers, we'll introduce some of the core building blocks behind them." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Multi Head Attention Layer" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "One of the key concepts introduced by the Transformer model is the **multi-head attention layer**.\n", "\n", "\n", "\n", "The purpose of an attention mechanism is to relate inputs from different parts of the sequence. An attention operation is made up of *queries*, *keys* and *values*.
It might be helpful to look at these terms from an information retrieval perspective: every time we issue a query to a search engine, the search engine matches it against some key (title, description) and retrieves the associated value (content).\n", "\n", "To be specific, the Transformer model uses scaled dot-product attention, where the query is combined with the key to get attention weights, which are then used to compute a weighted sum of the values.\n", "\n", "\\begin{align}\n", "\\text{Attention}(Q, K, V) = \\text{Softmax} \\big( \\frac{QK^T}{\\sqrt{d_k}} \\big)V\n", "\\end{align}\n", "\n", "Where $Q = XW^Q, K = XW^K, V = XW^V$, $X$ is our input matrix, and $W^Q$, $W^K$, $W^V$ are the weights of the linear layers for the query, key and value. $d_k$ is the head dimension, `head_dim`, which we will explain further shortly. In essence, we are multiplying our input matrix with 3 different weight matrices. We first perform a dot product between the query and key and scale the result by $\\sqrt{d_k}$, then apply a softmax to obtain the attention weights, which measure how strongly each query token relates to each key token. Finally, we take a dot product between these weights and the values to get the weighted values. Scaling is done to prevent the results of the dot product from growing too large, which would push the softmax into regions where its gradients become extremely small.\n", "\n", "Multi-head attention extends the single attention mechanism so the model can potentially pay attention to different concepts that exist at different sequence positions. For readers familiar with convolutional neural networks, this trick is similar to introducing multiple filters so that each can learn different aspects of the input. Instead of performing a single attention operation, the queries, keys and values have their `hid_dim` split into $h$ heads, each of size $d_k$, and the scaled dot-product attention is calculated over all heads in parallel. After this computation, we recombine the heads back into the `hid_dim` shape.
By reducing the dimensionality of each head/concept, the total computational cost is similar to a full dimension single-head attention.\n", "\n", "\\begin{align}\n", "\\text{MultiHead}(Q, K, V) &= \\text{Concat}(\\text{head}_1,...,\\text{head}_h)W^O \\\\\n", "\\text{head}_i &= \\text{Attention}(Q_i, K_i, V_i)\n", "\\end{align}\n", "\n", "Where $W^O$ is the linear layer applied at the end of the multi-head attention layer.\n", "\n", "In the implementation below, we carry out the multi head attention in parallel using batch matrix multiplication as opposed to a for loop. And while calculating the attention weights, we introduce the capability of applying a mask so the model does not pay attention to irrelevant tokens. We'll elaborate more on this in future sections." ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "class MultiHeadAttention(nn.Module):\n", "\n", " def __init__(self, hid_dim, n_heads):\n", " super().__init__()\n", " self.hid_dim = hid_dim\n", " self.n_heads = n_heads\n", " self.head_dim = hid_dim // n_heads\n", " assert hid_dim % n_heads == 0\n", "\n", " self.key_weight = nn.Linear(hid_dim, hid_dim)\n", " self.query_weight = nn.Linear(hid_dim, hid_dim)\n", " self.value_weight = nn.Linear(hid_dim, hid_dim)\n", " self.linear_weight = nn.Linear(hid_dim, hid_dim)\n", "\n", " def forward(self, query, key, value, mask = None):\n", " batch_size = query.shape[0]\n", " query_len = query.shape[1]\n", " key_len = key.shape[1]\n", "\n", " # key/query/value (proj) = [batch size, input len, hid dim]\n", " key_proj = self.key_weight(key)\n", " query_proj = self.query_weight(query)\n", " value_proj = self.value_weight(value)\n", "\n", " # compute the weights between query and key\n", " query_proj = query_proj.view(batch_size, query_len, self.n_heads, self.head_dim)\n", " query_proj = query_proj.permute(0, 2, 1, 3)\n", " key_proj = key_proj.view(batch_size, key_len, self.n_heads, self.head_dim)\n", " key_proj = 
key_proj.permute(0, 2, 3, 1)\n", "\n", " # energy, attention = [batch size, num heads, query len, key len]\n", " energy = torch.matmul(query_proj, key_proj) / math.sqrt(self.head_dim)\n", "\n", " if mask is not None:\n", " energy = energy.masked_fill(mask == 0, -1e10)\n", "\n", " attention = torch.softmax(energy, dim=-1)\n", "\n", " # output = [batch size, num heads, query len, head dim]\n", " value_proj = value_proj.view(batch_size, key_len, self.n_heads, self.head_dim)\n", " value_proj = value_proj.permute(0, 2, 1, 3)\n", " output = torch.matmul(attention, value_proj)\n", "\n", " # linear_proj = [batch size, query len, hid dim]\n", " output = output.permute(0, 2, 1, 3).contiguous().view(batch_size, query_len, self.hid_dim)\n", " linear_proj = self.linear_weight(output)\n", " return linear_proj, attention" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([128, 40, 64])\n", "torch.Size([128, 8, 40, 40])\n" ] } ], "source": [ "n_heads = 8\n", "self_attention = MultiHeadAttention(hid_dim, n_heads).to(device)\n", "self_attention_output, attention = self_attention(src_embedded, src_embedded, src_embedded)\n", "print(self_attention_output.shape)\n", "print(attention.shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Position Wise Feed Forward Layer" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Another building block is the position wise feed forward layer, which consists of two linear transformations. The same transformations are applied at every position: while feed forward layers are typically used on a tensor of shape (`batch_size`, `hidden_dim`), here the layer operates directly on a tensor of shape (`batch_size`, `seq_len`, `hidden_dim`), transforming each position independently with the same weights.\n", "\n", "The input is transformed from `hid_dim` to `pf_dim`, where `pf_dim` is usually a lot larger than `hid_dim`.
Then an activation function (ReLU in the implementation below) is applied before it is transformed back into a `hid_dim` representation." ] },
{ "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [], "source": [ "class PositionWiseFeedForward(nn.Module):\n", "\n", "    def __init__(self, hid_dim, pf_dim):\n", "        super().__init__()\n", "        self.hid_dim = hid_dim\n", "        self.pf_dim = pf_dim\n", "\n", "        self.fc1 = nn.Linear(hid_dim, pf_dim)\n", "        self.fc2 = nn.Linear(pf_dim, hid_dim)\n", "\n", "    def forward(self, inputs):\n", "        # inputs = [batch size, seq len, hid dim]\n", "        # the same two linear layers are applied to every position independently\n", "        fc1_output = torch.relu(self.fc1(inputs))\n", "        fc2_output = self.fc2(fc1_output)\n", "        return fc2_output" ] },
{ "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([128, 40, 64])" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "hid_dim = 64\n", "pf_dim = 256\n", "position_ff = PositionWiseFeedForward(hid_dim, pf_dim).to(device)\n", "position_ff_output = position_ff(self_attention_output)\n", "position_ff_output.shape" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Encoder" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "We'll now put our building blocks together to form the encoder.\n", "\n", "\n", "\n", "We first pass the source sentence through a position-wise embedding layer, which is followed by *N* (configurable) encoder layers, the \"meat\" of modern transformer-based architectures. The main role of our encoder is to update the token representations so that they capture contextual information about our text sequence, e.g. the representation of the word \"bank\" will be updated to be more \"financial establishment\"-like and less \"land along a river\"-like if words such as money and investment are close to it.\n", "\n", "Inside the encoder layer, we start from the multi-head attention layer, apply dropout to its output, add a residual connection, and pass the result through a layer normalization layer, 
followed by a position-wise feed forward layer, after which we again apply dropout, a residual connection and layer normalization to get the output, which is then fed into the next layer. This sounds like a mouthful, but hopefully the code will clarify things a bit. Things worth noting:\n", "\n", "- Parameters are not shared between layers.\n", "- The encoder layer uses its multi-head attention layer to attend to the source sentence itself, i.e. it calculates and applies attention over its own sequence instead of another sequence, hence we call it self attention. This is the only layer that propagates information along the sequence; the other layers operate on each token in isolation.\n", "- The gist behind layer normalization is that it normalizes each token's features across the hidden dimension, so that for every token the features have a mean of 0 and a standard deviation of 1 (before a learnable scale and shift are applied). This trick, along with residual connections, makes it easier to train neural networks with a large number of layers, like the Transformer."
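] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "As a quick sanity check of the layer normalization bullet above, the sketch below (an illustrative addition, with arbitrary toy sizes) applies `nn.LayerNorm` to a random tensor of shape [batch size, seq len, hid dim]. At initialization the learnable scale and shift are 1 and 0, so every token's vector should come out with a mean of roughly 0 and a standard deviation of roughly 1 across the hidden dimension. We also recompute the result by hand, relying on `nn.LayerNorm` using the biased variance and a default `eps` of 1e-5. A local random generator is used so this cell leaves the notebook's global random state untouched." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "\n", "# local generator so this cell does not change the notebook's global random state\n", "gen = torch.Generator().manual_seed(0)\n", "\n", "# toy input = [batch size, seq len, hid dim], deliberately far from zero mean / unit variance\n", "toy_inputs = torch.randn(2, 5, 8, generator=gen) * 3 + 7\n", "toy_layer_norm = nn.LayerNorm(8)\n", "normalized = toy_layer_norm(toy_inputs)\n", "\n", "# statistics are computed per token, across the last (hidden) dimension\n", "token_mean = normalized.mean(dim=-1)\n", "token_std = normalized.std(dim=-1, unbiased=False)\n", "print(token_mean.abs().max().item())\n", "print(token_std.min().item(), token_std.max().item())\n", "\n", "# the same normalization written out by hand (biased variance, eps = 1e-5)\n", "toy_mean = toy_inputs.mean(dim=-1, keepdim=True)\n", "toy_var = toy_inputs.var(dim=-1, unbiased=False, keepdim=True)\n", "manual_normalized = (toy_inputs - toy_mean) / torch.sqrt(toy_var + 1e-5)\n", "print(torch.allclose(normalized, manual_normalized, atol=1e-5))"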
] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [], "source": [ "class EncoderLayer(nn.Module):\n", "\n", " def __init__(self, hid_dim, n_heads, pf_dim, dropout_p):\n", " super().__init__()\n", " self.hid_dim = hid_dim\n", " self.n_heads = n_heads\n", " self.pf_dim = pf_dim\n", " self.dropout_p = dropout_p\n", "\n", " self.self_attention_layer_norm = nn.LayerNorm(hid_dim)\n", " self.position_ff_layer_norm = nn.LayerNorm(hid_dim)\n", " self.self_attention = MultiHeadAttention(hid_dim, n_heads)\n", " self.position_ff = PositionWiseFeedForward(hid_dim, pf_dim)\n", "\n", " self.dropout = nn.Dropout(dropout_p)\n", "\n", " def forward(self, src, src_mask):\n", " # src = [batch size, src len, hid dim]\n", " # src_mask = [batch size, 1, 1, src len] \n", " self_attention_output, _ = self.self_attention(src, src, src, src_mask)\n", "\n", " # residual connection and layer norm\n", " self_attention_output = self.dropout(self_attention_output)\n", " self_attention_output = self.self_attention_layer_norm(src + self_attention_output)\n", "\n", " position_ff_output = self.position_ff(self_attention_output)\n", "\n", " # residual connection and layer norm\n", " # [batch size, src len, hid dim]\n", " position_ff_output = self.dropout(position_ff_output)\n", " output = self.position_ff_layer_norm(self_attention_output + position_ff_output) \n", " return output" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [], "source": [ "class Encoder(nn.Module):\n", "\n", " def __init__(self, input_dim, hid_dim, max_len, dropout_p, n_heads, pf_dim, n_layers):\n", " super().__init__()\n", " self.input_dim = input_dim\n", " self.hid_dim = hid_dim\n", " self.max_len = max_len\n", " self.dropout_p = dropout_p\n", " self.n_heads = n_heads\n", " self.pf_dim = pf_dim\n", " self.n_layers = n_layers\n", "\n", " self.pos_wise_embedding = PositionWiseEmbedding(input_dim, hid_dim, max_len, dropout_p)\n", " self.layers = nn.ModuleList([\n", " 
EncoderLayer(hid_dim, n_heads, pf_dim, dropout_p)\n", " for _ in range(n_layers)\n", " ])\n", "\n", " def forward(self, src, src_mask = None):\n", "\n", " # src = [batch size, src len]\n", " # src_mask = [batch size, 1, 1, src len]\n", " src = self.pos_wise_embedding(src)\n", " for layer in self.layers:\n", " src = layer(src, src_mask)\n", "\n", " # [batch size, src len, hid dim]\n", " return src" ] },
{ "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Encoder(\n", " (pos_wise_embedding): PositionWiseEmbedding(\n", " (tok_embedding): Embedding(5000, 64)\n", " (pos_embedding): Embedding(100, 64)\n", " (dropout): Dropout(p=0.5, inplace=False)\n", " )\n", " (layers): ModuleList(\n", " (0): EncoderLayer(\n", " (self_attention_layer_norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n", " (position_ff_layer_norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n", " (self_attention): MultiHeadAttention(\n", " (key_weight): Linear(in_features=64, out_features=64, bias=True)\n", " (query_weight): Linear(in_features=64, out_features=64, bias=True)\n", " (value_weight): Linear(in_features=64, out_features=64, bias=True)\n", " (linear_weight): Linear(in_features=64, out_features=64, bias=True)\n", " )\n", " (position_ff): PositionWiseFeedForward(\n", " (fc1): Linear(in_features=64, out_features=256, bias=True)\n", " (fc2): Linear(in_features=256, out_features=64, bias=True)\n", " )\n", " (dropout): Dropout(p=0.5, inplace=False)\n", " )\n", " )\n", ")" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "input_dim = source_tokenizer.get_vocab_size()\n", "hid_dim = 64\n", "max_len = 100\n", "dropout_p = 0.5\n", "n_heads = 8\n", "pf_dim = 256\n", "n_layers = 1\n", "encoder = Encoder(input_dim, hid_dim, max_len, dropout_p, n_heads, pf_dim, n_layers).to(device)\n", "encoder" ] },
{ "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([128, 40, 64])" ] }, "execution_count": 30, "metadata": {}, "output_type": "execute_result" } ], "source": [ "encoder_output = encoder(source.to(device))\n", "encoder_output.shape" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Decoder" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Now comes the decoder part:\n", "\n", "\n", "\n", "The decoder's main goal is to take our source sentence's encoded representation, $Z$, and convert it into predicted tokens of the target sentence, $\\hat{Y}$. We then compare $\\hat{Y}$ with the actual tokens in the target sentence, $Y$, to calculate our loss and update our parameters to improve our predictions.\n", "\n", "The decoder layer contains similar building blocks to the encoder layer, except that it now has two multi-head attention layers, `self_attention` and `encoder_attention`.\n", "\n", "The former performs self attention on our target sentence's embedding representation to generate a decoder representation. In the latter, the encoder/decoder attention layer, the decoder's intermediate representation serves as the queries, whereas the keys and values come from the encoder's output."
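] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "To make the roles in the encoder/decoder attention concrete, here is a minimal sketch with a single head and no learned projections (a deliberate simplification of our `MultiHeadAttention`, using made-up toy sizes and a hypothetical helper, `single_head_attention`). The queries come from the decoder side, with length 7 here, while the keys and values come from the encoder output, with length 11, so each target position gets a distribution over all source positions, and the output holds one vector per target position." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import math\n", "import torch\n", "\n", "def single_head_attention(query, key, value):\n", "    # energy, attention = [batch size, query len, key len]\n", "    energy = torch.matmul(query, key.transpose(1, 2)) / math.sqrt(query.shape[-1])\n", "    attention = torch.softmax(energy, dim=-1)\n", "    # output = [batch size, query len, hid dim]\n", "    return torch.matmul(attention, value), attention\n", "\n", "# local generator so this cell does not change the notebook's global random state\n", "gen = torch.Generator().manual_seed(0)\n", "\n", "# random stand-ins: decoder side has trg len 7, encoder output has src len 11, hid dim 16\n", "toy_decoder_hidden = torch.randn(2, 7, 16, generator=gen)\n", "toy_encoded_src = torch.randn(2, 11, 16, generator=gen)\n", "\n", "# queries from the decoder, keys and values from the encoder\n", "toy_output, toy_attention = single_head_attention(toy_decoder_hidden, toy_encoded_src, toy_encoded_src)\n", "print(toy_attention.shape, toy_output.shape)  # torch.Size([2, 7, 11]) torch.Size([2, 7, 16])"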
] },
{ "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [], "source": [ "class DecoderLayer(nn.Module):\n", "\n", "    def __init__(self, hid_dim, n_heads, pf_dim, dropout_p):\n", "        super().__init__()\n", "        self.hid_dim = hid_dim\n", "        self.n_heads = n_heads\n", "        self.pf_dim = pf_dim\n", "        self.dropout_p = dropout_p\n", "\n", "        self.self_attention_layer_norm = nn.LayerNorm(hid_dim)\n", "        self.encoder_attention_layer_norm = nn.LayerNorm(hid_dim)\n", "        self.position_ff_layer_norm = nn.LayerNorm(hid_dim)\n", "        self.self_attention = MultiHeadAttention(hid_dim, n_heads)\n", "        self.encoder_attention = MultiHeadAttention(hid_dim, n_heads)\n", "        self.position_ff = PositionWiseFeedForward(hid_dim, pf_dim)\n", "\n", "        self.dropout = nn.Dropout(dropout_p)\n", "\n", "    def forward(self, trg, encoded_src, trg_mask, src_mask):\n", "        # trg = [batch size, trg len, hid dim]\n", "        # encoded_src = [batch size, src len, hid dim]\n", "        # trg_mask = [batch size, 1, trg len, trg len]\n", "        # src_mask = [batch size, 1, 1, src len]\n", "        self_attention_output, _ = self.self_attention(trg, trg, trg, trg_mask)\n", "\n", "        # residual connection and layer norm\n", "        self_attention_output = self.dropout(self_attention_output)\n", "        self_attention_output = self.self_attention_layer_norm(trg + self_attention_output)\n", "\n", "        # encoder/decoder attention: queries from the decoder, keys and values from the encoder\n", "        encoder_attention_output, _ = self.encoder_attention(\n", "            self_attention_output,\n", "            encoded_src,\n", "            encoded_src,\n", "            src_mask\n", "        )\n", "\n", "        # residual connection and layer norm, the residual is this sub-layer's input\n", "        encoder_attention_output = self.dropout(encoder_attention_output)\n", "        encoder_attention_output = self.encoder_attention_layer_norm(self_attention_output + encoder_attention_output)\n", "\n", "        position_ff_output = self.position_ff(encoder_attention_output)\n", "\n", "        # residual connection and layer norm\n", "        # [batch size, trg len, hid dim]\n", "        position_ff_output = self.dropout(position_ff_output)\n", "        output = self.position_ff_layer_norm(encoder_attention_output + position_ff_output)\n", "        return output" ] },
{ "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [], "source": [ "class 
Decoder(nn.Module):\n", "\n", " def __init__(self, output_dim, hid_dim, max_len, dropout_p, n_heads, pf_dim, n_layers):\n", " super().__init__()\n", " self.output_dim = output_dim\n", " self.hid_dim = hid_dim\n", " self.max_len = max_len\n", " self.dropout_p = dropout_p\n", " self.n_heads = n_heads\n", " self.pf_dim = pf_dim\n", " self.n_layers = n_layers\n", "\n", " self.pos_wise_embedding = PositionWiseEmbedding(output_dim, hid_dim, max_len, dropout_p)\n", " self.layers = nn.ModuleList([\n", " DecoderLayer(hid_dim, n_heads, pf_dim, dropout_p)\n", " for _ in range(n_layers)\n", " ])\n", " self.fc_out = nn.Linear(hid_dim, output_dim)\n", "\n", " def forward(self, trg, encoded_src, trg_mask = None, src_mask = None):\n", "\n", " trg = self.pos_wise_embedding(trg)\n", " for layer in self.layers:\n", " trg = layer(trg, encoded_src, trg_mask, src_mask)\n", " \n", " output = self.fc_out(trg)\n", " return output" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Decoder(\n", " (pos_wise_embedding): PositionWiseEmbedding(\n", " (tok_embedding): Embedding(5000, 64)\n", " (pos_embedding): Embedding(100, 64)\n", " (dropout): Dropout(p=0.5, inplace=False)\n", " )\n", " (layers): ModuleList(\n", " (0): DecoderLayer(\n", " (self_attention_layer_norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n", " (encoder_attention_layer_norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n", " (position_ff_layer_norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n", " (self_attention): MultiHeadAttention(\n", " (key_weight): Linear(in_features=64, out_features=64, bias=True)\n", " (query_weight): Linear(in_features=64, out_features=64, bias=True)\n", " (value_weight): Linear(in_features=64, out_features=64, bias=True)\n", " (linear_weight): Linear(in_features=64, out_features=64, bias=True)\n", " )\n", " (encoder_attention): MultiHeadAttention(\n", " (key_weight): Linear(in_features=64, out_features=64, 
bias=True)\n", " (query_weight): Linear(in_features=64, out_features=64, bias=True)\n", " (value_weight): Linear(in_features=64, out_features=64, bias=True)\n", " (linear_weight): Linear(in_features=64, out_features=64, bias=True)\n", " )\n", " (position_ff): PositionWiseFeedForward(\n", " (fc1): Linear(in_features=64, out_features=32, bias=True)\n", " (fc2): Linear(in_features=32, out_features=64, bias=True)\n", " )\n", " (dropout): Dropout(p=0.5, inplace=False)\n", " )\n", " )\n", " (fc_out): Linear(in_features=64, out_features=5000, bias=True)\n", ")" ] }, "execution_count": 33, "metadata": {}, "output_type": "execute_result" } ], "source": [ "output_dim = target_tokenizer.get_vocab_size()\n", "hid_dim = 64\n", "max_len = 100\n", "dropout_p = 0.5\n", "n_heads = 8\n", "pf_dim = 32\n", "n_layers = 1\n", "decoder = Decoder(output_dim, hid_dim, max_len, dropout_p, n_heads, pf_dim, n_layers).to(device)\n", "decoder" ] },
{ "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([128, 26, 5000])" ] }, "execution_count": 34, "metadata": {}, "output_type": "execute_result" } ], "source": [ "encoder_output = encoder(source.to(device))\n", "decoder_output = decoder(target.to(device), encoder_output)\n", "decoder_output.shape" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Seq2Seq" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Now that we have our encoder and decoder, the final part is to have a Seq2Seq module that encapsulates the two. In this module, we'll also handle masking.\n", "\n", "The source mask is created by checking where our source sequence is not equal to the padding token (`src_pad_idx`). It is 1 where the token is not a padding token and 0 where it is. This mask is used in the encoder's self attention as well as in the decoder's encoder/decoder attention, since we want our model to not pay any attention to padding tokens, which contain no useful information.\n", "\n", "The target mask is a bit more involved. 
First, we create a mask for the padding tokens, as we did for the source mask. Next, we create a \"subsequent\" mask, `trg_sub_mask`, using `torch.tril`, which returns the lower triangular part of a matrix: elements above the diagonal are set to zero, while elements on and below the diagonal keep the values of the input tensor. In this case, the input tensor is filled with ones, meaning our `trg_sub_mask` will look something like this (for a target with 5 tokens):\n", "\n", "\\begin{matrix}\n", "1 & 0 & 0 & 0 & 0 \\\\\n", "1 & 1 & 0 & 0 & 0 \\\\\n", "1 & 1 & 1 & 0 & 0 \\\\\n", "1 & 1 & 1 & 1 & 0 \\\\\n", "1 & 1 & 1 & 1 & 1 \\\\\n", "\\end{matrix}\n", " \n", "This shows what each target token (row) is allowed to look at (column). Our first target token has a mask of [1, 0, 0, 0, 0], which means it can only look at the first target token, whereas the second target token has a mask of [1, 1, 0, 0, 0], which means it can look at both the first and second target tokens, and so on.\n", "\n", "The \"subsequent\" mask is then combined with the padding mask using a logical AND, ensuring that neither subsequent tokens nor padding tokens can be attended to. For example, if the last two tokens were padding tokens, the final target mask would look like:\n", "\n", "\\begin{matrix}\n", "1 & 0 & 0 & 0 & 0 \\\\\n", "1 & 1 & 0 & 0 & 0 \\\\\n", "1 & 1 & 1 & 0 & 0 \\\\\n", "1 & 1 & 1 & 0 & 0 \\\\\n", "1 & 1 & 1 & 0 & 0 \\\\\n", "\\end{matrix}\n", " \n", "These masks are fed into the model along with the source and target sentences to get our predicted target output.\n", " \n", "Side note: here is some other terminology that we might come across. The need to create a subsequent mask is very common in autoregressive models, where the task is to predict the next token in the sequence (e.g. language models). By introducing this masking, we are making the self attention block causal. 
Different implementations or libraries may specify this masking differently, but the core idea is to prevent the model from \"cheating\" by copying the tokens that come after the one it is currently processing." ] },
{ "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [], "source": [ "class Seq2Seq(nn.Module):\n", "\n", "    def __init__(self, encoder, decoder, src_pad_idx, trg_pad_idx):\n", "        super().__init__()\n", "        self.encoder = encoder\n", "        self.decoder = decoder\n", "        self.src_pad_idx = src_pad_idx\n", "        self.trg_pad_idx = trg_pad_idx\n", "\n", "    def make_src_mask(self, src):\n", "        \"\"\"\n", "        the padding mask is unsqueezed so it can be correctly broadcasted\n", "        when applying the mask to the attention weights, which is of shape\n", "        [batch size, n heads, seq len, seq len].\n", "        \"\"\"\n", "        # src_pad_mask = [batch size, 1, 1, src len]\n", "        src_pad_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)\n", "        return src_pad_mask\n", "\n", "    def make_trg_mask(self, trg):\n", "        # trg_pad_mask = [batch size, 1, 1, trg len]\n", "        trg_pad_mask = (trg != self.trg_pad_idx).unsqueeze(1).unsqueeze(2)\n", "\n", "        # trg_sub_mask = [trg len, trg len]\n", "        trg_len = trg.shape[1]\n", "        trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len))).bool().to(trg.device)\n", "\n", "        # broadcasts to trg_mask = [batch size, 1, trg len, trg len]\n", "        trg_mask = trg_pad_mask & trg_sub_mask\n", "        return trg_mask\n", "\n", "    def forward(self, src, trg):\n", "        src_mask = self.make_src_mask(src)\n", "        trg_mask = self.make_trg_mask(trg)\n", "\n", "        encoded_src = self.encoder(src, src_mask)\n", "        decoder_output = self.decoder(trg, encoded_src, trg_mask, src_mask)\n", "        return decoder_output" ] },
{ "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Seq2Seq(\n", " (encoder): Encoder(\n", " (pos_wise_embedding): PositionWiseEmbedding(\n", " (tok_embedding): Embedding(5000, 512)\n", " (pos_embedding): Embedding(100, 512)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (layers): ModuleList(\n", " (0-5): 6 x EncoderLayer(\n", " (self_attention_layer_norm): LayerNorm((512,), eps=1e-05, 
elementwise_affine=True)\n", " (position_ff_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", " (self_attention): MultiHeadAttention(\n", " (key_weight): Linear(in_features=512, out_features=512, bias=True)\n", " (query_weight): Linear(in_features=512, out_features=512, bias=True)\n", " (value_weight): Linear(in_features=512, out_features=512, bias=True)\n", " (linear_weight): Linear(in_features=512, out_features=512, bias=True)\n", " )\n", " (position_ff): PositionWiseFeedForward(\n", " (fc1): Linear(in_features=512, out_features=512, bias=True)\n", " (fc2): Linear(in_features=512, out_features=512, bias=True)\n", " )\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " )\n", " (decoder): Decoder(\n", " (pos_wise_embedding): PositionWiseEmbedding(\n", " (tok_embedding): Embedding(5000, 512)\n", " (pos_embedding): Embedding(100, 512)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (layers): ModuleList(\n", " (0-2): 3 x DecoderLayer(\n", " (self_attention_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", " (encoder_attention_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", " (position_ff_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", " (self_attention): MultiHeadAttention(\n", " (key_weight): Linear(in_features=512, out_features=512, bias=True)\n", " (query_weight): Linear(in_features=512, out_features=512, bias=True)\n", " (value_weight): Linear(in_features=512, out_features=512, bias=True)\n", " (linear_weight): Linear(in_features=512, out_features=512, bias=True)\n", " )\n", " (encoder_attention): MultiHeadAttention(\n", " (key_weight): Linear(in_features=512, out_features=512, bias=True)\n", " (query_weight): Linear(in_features=512, out_features=512, bias=True)\n", " (value_weight): Linear(in_features=512, out_features=512, bias=True)\n", " (linear_weight): Linear(in_features=512, out_features=512, bias=True)\n", " )\n", " (position_ff): 
PositionWiseFeedForward(\n", " (fc1): Linear(in_features=512, out_features=512, bias=True)\n", " (fc2): Linear(in_features=512, out_features=512, bias=True)\n", " )\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (fc_out): Linear(in_features=512, out_features=5000, bias=True)\n", " )\n", ")" ] }, "execution_count": 36, "metadata": {}, "output_type": "execute_result" } ], "source": [ "source_pad_idx = source_tokenizer.token_to_id(pad_token)\n", "target_pad_idx = target_tokenizer.token_to_id(pad_token)\n", "\n", "INPUT_DIM = source_tokenizer.get_vocab_size()\n", "OUTPUT_DIM = target_tokenizer.get_vocab_size()\n", "MAX_LEN = 100\n", "HID_DIM = 512\n", "ENC_LAYERS = 6\n", "DEC_LAYERS = 3\n", "ENC_HEADS = 8\n", "DEC_HEADS = 8\n", "ENC_PF_DIM = 512\n", "DEC_PF_DIM = 512\n", "ENC_DROPOUT = 0.1\n", "DEC_DROPOUT = 0.1\n", "\n", "encoder = Encoder(\n", " INPUT_DIM, \n", " HID_DIM,\n", " MAX_LEN,\n", " ENC_DROPOUT, \n", " ENC_HEADS, \n", " ENC_PF_DIM, \n", " ENC_LAYERS\n", ")\n", "\n", "decoder = Decoder(\n", " OUTPUT_DIM, \n", " HID_DIM,\n", " MAX_LEN,\n", " DEC_DROPOUT,\n", " DEC_HEADS,\n", " DEC_PF_DIM,\n", " DEC_LAYERS\n", ")\n", "\n", "model = Seq2Seq(encoder, decoder, source_pad_idx, target_pad_idx).to(device)\n", "model" ] }, { "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([128, 26, 5000])" ] }, "execution_count": 37, "metadata": {}, "output_type": "execute_result" } ], "source": [ "output = model(source.to(device), target.to(device))\n", "output.shape" ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The model has 25,144,200 trainable parameters\n" ] } ], "source": [ "def count_parameters(model):\n", " return sum(p.numel() for p in model.parameters() if p.requires_grad)\n", "\n", "print(f'The model has {count_parameters(model):,} trainable parameters')" ] }, { "cell_type": "markdown", "metadata": 
{}, "source": [ "## Model Training" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "The training loop also requires a bit of explanation.\n", "\n", "We want our model to predict the end-of-sentence ($eos$) token, but not to receive it as an input, hence we slice the last token off the end of our target sequence (for padded sequences this removes a padding token rather than $eos$, which is harmless since the loss ignores padding targets).\n", "\n", "\\begin{align}\n", "\\text{trg} &= [sos, x_1, x_2, x_3, eos] \\\\\n", "\\text{trg[:-1]} &= [sos, x_1, x_2, x_3]\n", "\\end{align}\n", "\n", "\n", "We then calculate our loss using the original target tensor with the start-of-sentence ($sos$) token sliced off the front, retaining the $eos$ token.\n", "\n", "\\begin{align}\n", "\\text{output} &= [y_1, y_2, y_3, eos] \\\\\n", "\\text{trg[1:]} &= [x_1, x_2, x_3, eos]\n", "\\end{align}\n", "\n", "All in all, our model receives the target up to, but excluding, the last token, whereas the ground truth runs from the second token onward." ] },
{ "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [], "source": [ "def train(model, iterator, optimizer, criterion, clip):\n", "\n", "    model.train()\n", "    epoch_loss = 0\n", "    for i, (src, trg) in enumerate(iterator):\n", "        src = src.to(device)\n", "        trg = trg.to(device)\n", "\n", "        optimizer.zero_grad()\n", "        # feed the target without its last token\n", "        output = model(src, trg[:, :-1])\n", "\n", "        # output = [batch size, trg len - 1, output dim]\n", "        # trg = [batch size, trg len]\n", "        output_dim = output.shape[-1]\n", "        output = output.contiguous().view(-1, output_dim)\n", "        # compare against the target without its first token\n", "        trg = trg[:, 1:].contiguous().view(-1)\n", "\n", "        loss = criterion(output, trg)\n", "        loss.backward()\n", "        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)\n", "        optimizer.step()\n", "        epoch_loss += loss.item()\n", "\n", "    return epoch_loss / len(iterator)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "The evaluation loop is similar to the training loop, just without updating the model's parameters (and with gradient computation disabled via `torch.no_grad`)."
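] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "The toy sketch below illustrates the target slicing used in both loops. The token ids and the padding index 0 are made up for illustration; the notebook's real ids come from the tokenizers. It also checks that `nn.CrossEntropyLoss(ignore_index=...)` averages the loss over the non-padding positions only. Note that in the padded second sequence, the position that receives the $eos$ token is asked to predict a padding token, which the loss then ignores." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "\n", "# made-up ids for illustration: 1 = sos, 2 = eos, 0 = padding, others are regular tokens\n", "toy_pad_idx = 0\n", "toy_trg = torch.tensor([[1, 5, 6, 7, 2], [1, 8, 2, 0, 0]])\n", "\n", "toy_decoder_input = toy_trg[:, :-1]  # what the model receives\n", "toy_labels = toy_trg[:, 1:]  # what the model is trained to predict\n", "print(toy_decoder_input.tolist())  # [[1, 5, 6, 7], [1, 8, 2, 0]]\n", "print(toy_labels.tolist())  # [[5, 6, 7, 2], [8, 2, 0, 0]]\n", "\n", "# random logits standing in for the model output = [batch size, trg len - 1, vocab size]\n", "toy_vocab_size = 10\n", "gen = torch.Generator().manual_seed(0)\n", "toy_logits = torch.randn(2, 4, toy_vocab_size, generator=gen)\n", "\n", "flat_logits = toy_logits.reshape(-1, toy_vocab_size)\n", "flat_labels = toy_labels.reshape(-1)\n", "toy_loss = nn.CrossEntropyLoss(ignore_index=toy_pad_idx)(flat_logits, flat_labels)\n", "\n", "# equivalent: drop the padding positions, then take the plain mean\n", "keep = flat_labels != toy_pad_idx\n", "manual_loss = nn.functional.cross_entropy(flat_logits[keep], flat_labels[keep])\n", "print(torch.allclose(toy_loss, manual_loss))"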
] },
{ "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [], "source": [ "def evaluate(model, iterator, criterion):\n", "\n", "    model.eval()\n", "    epoch_loss = 0\n", "    with torch.no_grad():\n", "        for i, (src, trg) in enumerate(iterator):\n", "            src = src.to(device)\n", "            trg = trg.to(device)\n", "\n", "            output = model(src, trg[:, :-1])\n", "\n", "            # output = [batch size, trg len - 1, output dim]\n", "            # trg = [batch size, trg len]\n", "            output_dim = output.shape[-1]\n", "            output = output.contiguous().view(-1, output_dim)\n", "            trg = trg[:, 1:].contiguous().view(-1)\n", "\n", "            loss = criterion(output, trg)\n", "            epoch_loss += loss.item()\n", "\n", "    return epoch_loss / len(iterator)" ] },
{ "cell_type": "code", "execution_count": 41, "metadata": {}, "outputs": [], "source": [ "def epoch_time(start_time, end_time):\n", "    elapsed_time = end_time - start_time\n", "    elapsed_mins = int(elapsed_time / 60)\n", "    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))\n", "    return elapsed_mins, elapsed_secs" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "While defining our loss function, we also make sure to ignore the loss computed over padding tokens, via `ignore_index=target_pad_idx`." ] },
{ "cell_type": "code", "execution_count": 42, "metadata": {}, "outputs": [], "source": [ "MODEL_CHECKPOINT = 'transformer.pt'\n", "N_EPOCHS = 10\n", "CLIP = 1\n", "LEARNING_RATE = 0.0001\n", "optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)\n", "criterion = nn.CrossEntropyLoss(ignore_index=target_pad_idx)" ] },
{ "cell_type": "code", "execution_count": 43, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 01 | Time: 0m 16s\n", "\tTrain Loss: 4.926 | Train PPL: 137.785\n", "\t Val. Loss: 4.057 | Val. PPL: 57.784\n", "Epoch: 02 | Time: 0m 16s\n", "\tTrain Loss: 3.808 | Train PPL: 45.049\n", "\t Val. Loss: 3.566 | Val. PPL: 35.371\n", "Epoch: 03 | Time: 0m 16s\n", "\tTrain Loss: 3.422 | Train PPL: 30.628\n", "\t Val. 
Loss: 3.290 | Val. PPL: 26.830\n", "Epoch: 04 | Time: 0m 16s\n", "\tTrain Loss: 3.156 | Train PPL: 23.484\n", "\t Val. Loss: 3.078 | Val. PPL: 21.713\n", "Epoch: 05 | Time: 0m 16s\n", "\tTrain Loss: 2.945 | Train PPL: 19.018\n", "\t Val. Loss: 2.918 | Val. PPL: 18.507\n", "Epoch: 06 | Time: 0m 16s\n", "\tTrain Loss: 2.766 | Train PPL: 15.891\n", "\t Val. Loss: 2.789 | Val. PPL: 16.257\n", "Epoch: 07 | Time: 0m 17s\n", "\tTrain Loss: 2.606 | Train PPL: 13.544\n", "\t Val. Loss: 2.681 | Val. PPL: 14.599\n", "Epoch: 08 | Time: 0m 16s\n", "\tTrain Loss: 2.463 | Train PPL: 11.738\n", "\t Val. Loss: 2.585 | Val. PPL: 13.259\n", "Epoch: 09 | Time: 0m 17s\n", "\tTrain Loss: 2.336 | Train PPL: 10.340\n", "\t Val. Loss: 2.515 | Val. PPL: 12.368\n", "Epoch: 10 | Time: 0m 17s\n", "\tTrain Loss: 2.216 | Train PPL: 9.173\n", "\t Val. Loss: 2.448 | Val. PPL: 11.561\n" ] } ], "source": [ "best_valid_loss = float('inf')\n", "for epoch in range(N_EPOCHS):\n", " start_time = time.time()\n", " train_loss = train(model, dataloader_train, optimizer, criterion, CLIP)\n", " valid_loss = evaluate(model, dataloader_val, criterion)\n", " end_time = time.time()\n", " epoch_mins, epoch_secs = epoch_time(start_time, end_time)\n", "\n", " if valid_loss < best_valid_loss:\n", " best_valid_loss = valid_loss\n", " torch.save(model.state_dict(), MODEL_CHECKPOINT)\n", "\n", " print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')\n", " print(f'\\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')\n", " print(f'\\t Val. Loss: {valid_loss:.3f} | Val. 
PPL: {math.exp(valid_loss):7.3f}')" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Model Evaluation" ] },
{ "cell_type": "code", "execution_count": 44, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "| Test Loss: 2.429 | Test PPL: 11.352 |\n" ] } ], "source": [ "model.load_state_dict(torch.load(MODEL_CHECKPOINT))\n", "test_loss = evaluate(model, dataloader_test, criterion)\n", "print(f'| Test Loss: {test_loss:.3f} | Test PPL: {math.exp(test_loss):7.3f} |')" ] },
{ "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [], "source": [ "def predict(source, model, source_tokenizer, target_tokenizer):\n", "    \"\"\"\n", "    Given the raw source sentence, predict its translation using greedy search.\n", "    This is a naive implementation without batching.\n", "    \"\"\"\n", "    # make sure dropout is disabled during inference\n", "    model.eval()\n", "\n", "    src_indices = [source_init_idx] + source_tokenizer.encode(source).ids + [source_eos_idx]\n", "    src_tensor = torch.LongTensor(src_indices).unsqueeze(0).to(device)\n", "    src_mask = model.make_src_mask(src_tensor)\n", "\n", "    # separating out the encoder and decoder allows us to generate the encoded source\n", "    # sentence once and share it throughout the target prediction step\n", "    with torch.no_grad():\n", "        encoded_src = model.encoder(src_tensor, src_mask)\n", "\n", "    # greedy search\n", "    # sequentially predict the target sequence starting from the init sentence token,\n", "    # generating at most max_len tokens\n", "    trg_indices = [target_init_idx]\n", "    for _ in range(max_len):\n", "        trg_tensor = torch.LongTensor(trg_indices).unsqueeze(0).to(device)\n", "        trg_mask = model.make_trg_mask(trg_tensor)\n", "\n", "        with torch.no_grad():\n", "            output = model.decoder(trg_tensor, encoded_src, trg_mask, src_mask)\n", "\n", "        # add the last predicted token\n", "        pred_token = output.argmax(dim=2)[:, -1].item()\n", "        trg_indices.append(pred_token)\n", "        if pred_token == target_eos_idx:\n", "            break\n", "\n", "    return target_tokenizer.decode(trg_indices)" ] },
{ "cell_type": "code", "execution_count": 
46, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "source: Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche.\n", "target: Two young, White males are outside near many bushes.\n" ] }, { "data": { "text/plain": [ "'two young men are outside near white.'" ] }, "execution_count": 46, "metadata": {}, "output_type": "execute_result" } ], "source": [ "translation = dataset_dict['train'][0]\n", "source_raw = translation[source_lang]\n", "target_raw = translation[target_lang]\n", "print('source: ', source_raw)\n", "print('target: ', target_raw)\n", "\n", "predict(source_raw, model, source_tokenizer, target_tokenizer)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Transformer Module" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Instead of relying on our own Transformer encoder and decoder implementation, we can also use the pre-built one that ships with PyTorch's `nn` module. The major differences are that, by default, it expects inputs of shape (seq len, batch size, embedding dim), and that it expects the padding and subsequent masks in a different [shape](https://pytorch.org/docs/master/generated/torch.nn.Transformer.html#torch.nn.Transformer.forward) and convention: padding masks are passed separately with shape (batch size, seq len), and in a boolean mask `True` marks a position that is *not* allowed to be attended to, the opposite of the masks built by our `Seq2Seq` module."
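] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "The sketch below contrasts the two mask conventions on a made-up target of 5 tokens whose last two are padding (the padding index 0 is an assumption for illustration). Our `Seq2Seq.make_trg_mask` marks positions that can be attended to with `True`, whereas boolean masks passed to PyTorch's transformer modules mark positions that cannot be attended to with `True`. The subsequent mask is therefore the logical negation of the lower triangular one, and padding is passed separately as a [batch size, seq len] key padding mask." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "\n", "# made-up target of 5 tokens whose last two are padding (padding index 0 is an assumption)\n", "toy_pad_idx = 0\n", "toy_trg = torch.tensor([[1, 5, 6, 0, 0]])\n", "toy_trg_len = toy_trg.shape[1]\n", "\n", "# our Seq2Seq convention: True means the position may be attended to\n", "allowed = torch.tril(torch.ones(toy_trg_len, toy_trg_len)).bool()\n", "\n", "# nn.Transformer convention: True means the position may NOT be attended to\n", "blocked = ~allowed\n", "\n", "# padding is passed separately as a [batch size, seq len] mask, True marking padding\n", "toy_key_padding_mask = toy_trg == toy_pad_idx\n", "\n", "print(blocked.int())\n", "print(toy_key_padding_mask)"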
] },
{ "cell_type": "code", "execution_count": 47, "metadata": {}, "outputs": [], "source": [ "class Transformer(nn.Module):\n", " \"\"\"\n", " Sequence to sequence model built on top of PyTorch's nn.TransformerEncoder and nn.TransformerDecoder.\n", "\n", " References\n", " ----------\n", " https://pytorch.org/docs/master/generated/torch.nn.Transformer.html\n", " \"\"\"\n", "\n", " def __init__(\n", " self,\n", " encoder_embedding_dim,\n", " decoder_embedding_dim,\n", " source_pad_idx,\n", " target_pad_idx,\n", " encoder_max_len = 100,\n", " decoder_max_len = 100,\n", " model_dim = 512,\n", " num_head = 8,\n", " encoder_num_layers = 3,\n", " decoder_num_layers = 3,\n", " feedforward_dim = 512,\n", " dropout = 0.1\n", " ):\n", " super().__init__()\n", " self.source_pad_idx = source_pad_idx\n", " self.target_pad_idx = target_pad_idx\n", "\n", " self.encoder_embedding = PositionWiseEmbedding(\n", " encoder_embedding_dim,\n", " model_dim,\n", " encoder_max_len,\n", " dropout\n", " )\n", " self.decoder_embedding = PositionWiseEmbedding(\n", " decoder_embedding_dim,\n", " model_dim,\n", " decoder_max_len,\n", " dropout\n", " )\n", "\n", " layer_params = {\n", " 'd_model': model_dim,\n", " 'nhead': num_head,\n", " 'dim_feedforward': feedforward_dim,\n", " 'dropout': dropout\n", " }\n", " self.encoder = nn.TransformerEncoder(\n", " nn.TransformerEncoderLayer(**layer_params),\n", " encoder_num_layers\n", " )\n", " self.decoder = nn.TransformerDecoder(\n", " nn.TransformerDecoderLayer(**layer_params),\n", " decoder_num_layers\n", " )\n", " self.linear = nn.Linear(model_dim, decoder_embedding_dim)\n", "\n", " def forward(self, src_tensor, trg_tensor):\n", " # src_tensor = [batch size, src len], trg_tensor = [batch size, trg len]\n", " # PyTorch's Transformer encoder and decoder expect the first dimension to be\n", " # the sequence length rather than the batch size (batch_first=False by default),\n", " # which is why encode and decode permute the embeddings to [seq len, batch size, model dim]\n", " src_encoded = self.encode(src_tensor)\n", " 
output = self.decode(trg_tensor, src_encoded)\n", " return output\n", "\n", " def encode(self, src_tensor):\n", " src_key_padding_mask = 
generate_key_padding_mask(src_tensor, self.source_pad_idx)\n", " src_embedded = self.encoder_embedding(src_tensor).permute(1, 0, 2)\n", " return self.encoder(src_embedded, src_key_padding_mask=src_key_padding_mask)\n", "\n", " def decode(self, trg_tensor, src_encoded):\n", " trg_key_padding_mask = generate_key_padding_mask(trg_tensor, self.target_pad_idx)\n", " trg_mask = generate_subsequent_mask(trg_tensor)\n", " trg_embedded = self.decoder_embedding(trg_tensor).permute(1, 0, 2)\n", " decoder_output = self.decoder(\n", " trg_embedded,\n", " src_encoded,\n", " trg_mask,\n", " tgt_key_padding_mask=trg_key_padding_mask\n", " ).permute(1, 0, 2)\n", " return self.linear(decoder_output)\n", "\n", " def predict(self, src_tensor, max_len = 100):\n", " # separating out the encoder and decoder allows us to generate the encoded source\n", " # sentence once and share it throughout the target prediction step\n", " with torch.no_grad():\n", " src_encoded = self.encode(src_tensor)\n", " \n", " # greedy search\n", " # sequentially predict the target sequence starting from the init sentence token\n", " trg_indices = [target_init_idx]\n", " for _ in range(max_len):\n", " trg_tensor = torch.LongTensor(trg_indices).unsqueeze(0).to(src_tensor.device)\n", " with torch.no_grad():\n", " output = self.decode(trg_tensor, src_encoded)\n", "\n", " # add the last predicted token\n", " pred_token = output.argmax(dim=2)[:, -1].item()\n", " trg_indices.append(pred_token)\n", " if pred_token == target_eos_idx:\n", " break\n", "\n", " return trg_indices\n", "\n", "\n", "def generate_subsequent_mask(inputs):\n", " \"\"\"\n", " If a BoolTensor is provided, positions with True are not\n", " allowed to attend while False values will be unchanged\n", " \"\"\"\n", " inputs_len = inputs.shape[1]\n", " square = torch.ones((inputs_len, inputs_len)).to(inputs.device)\n", " mask = (torch.tril(square) == 0.0).bool()\n", " return mask\n", "\n", "\n", "def generate_key_padding_mask(inputs, pad_idx):\n", " 
return inputs == pad_idx" ] }, { "cell_type": "code", "execution_count": 48, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([128, 26, 5000])" ] }, "execution_count": 48, "metadata": {}, "output_type": "execute_result" } ], "source": [ "INPUT_DIM = source_tokenizer.get_vocab_size()\n", "OUTPUT_DIM = target_tokenizer.get_vocab_size()\n", "\n", "transformer = Transformer(INPUT_DIM, OUTPUT_DIM, source_pad_idx, target_pad_idx).to(device)\n", "\n", "with torch.no_grad():\n", "    output = transformer(source.to(device), target.to(device))\n", "\n", "output.shape" ] }, { "cell_type": "code", "execution_count": 49, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The model has 20,410,248 trainable parameters\n" ] } ], "source": [ "print(f'The model has {count_parameters(transformer):,} trainable parameters')" ] }, { "cell_type": "code", "execution_count": 50, "metadata": {}, "outputs": [], "source": [ "MODEL_CHECKPOINT = 'transformer.pt'\n", "N_EPOCHS = 10\n", "CLIP = 1\n", "LEARNING_RATE = 0.0005\n", "optimizer = optim.Adam(transformer.parameters(), lr=LEARNING_RATE)\n", "criterion = nn.CrossEntropyLoss(ignore_index=target_pad_idx)" ] }, { "cell_type": "code", "execution_count": 51, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 01 | Time: 0m 13s\n", "\tTrain Loss: 4.061 | Train PPL: 58.036\n", "\t Val. Loss: 3.313 | Val. PPL: 27.469\n", "Epoch: 02 | Time: 0m 13s\n", "\tTrain Loss: 3.010 | Train PPL: 20.289\n", "\t Val. Loss: 2.769 | Val. PPL: 15.946\n", "Epoch: 03 | Time: 0m 13s\n", "\tTrain Loss: 2.530 | Train PPL: 12.555\n", "\t Val. Loss: 2.478 | Val. PPL: 11.921\n", "Epoch: 04 | Time: 0m 13s\n", "\tTrain Loss: 2.180 | Train PPL: 8.848\n", "\t Val. Loss: 2.319 | Val. PPL: 10.167\n", "Epoch: 05 | Time: 0m 13s\n", "\tTrain Loss: 1.922 | Train PPL: 6.834\n", "\t Val. Loss: 2.213 | Val. 
PPL: 9.140\n", "Epoch: 06 | Time: 0m 13s\n", "\tTrain Loss: 1.703 | Train PPL: 5.489\n", "\t Val. Loss: 2.163 | Val. PPL: 8.697\n", "Epoch: 07 | Time: 0m 13s\n", "\tTrain Loss: 1.527 | Train PPL: 4.602\n", "\t Val. Loss: 2.148 | Val. PPL: 8.570\n", "Epoch: 08 | Time: 0m 13s\n", "\tTrain Loss: 1.361 | Train PPL: 3.898\n", "\t Val. Loss: 2.139 | Val. PPL: 8.488\n", "Epoch: 09 | Time: 0m 13s\n", "\tTrain Loss: 1.235 | Train PPL: 3.438\n", "\t Val. Loss: 2.143 | Val. PPL: 8.526\n", "Epoch: 10 | Time: 0m 13s\n", "\tTrain Loss: 1.128 | Train PPL: 3.089\n", "\t Val. Loss: 2.192 | Val. PPL: 8.957\n" ] } ], "source": [ "best_valid_loss = float('inf')\n", "for epoch in range(N_EPOCHS):\n", "    start_time = time.time()\n", "    train_loss = train(transformer, dataloader_train, optimizer, criterion, CLIP)\n", "    valid_loss = evaluate(transformer, dataloader_val, criterion)\n", "    end_time = time.time()\n", "    epoch_mins, epoch_secs = epoch_time(start_time, end_time)\n", "\n", "    if valid_loss < best_valid_loss:\n", "        best_valid_loss = valid_loss\n", "        torch.save(transformer.state_dict(), MODEL_CHECKPOINT)\n", "\n", "    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')\n", "    print(f'\\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')\n", "    print(f'\\t Val. Loss: {valid_loss:.3f} | Val. 
PPL: {math.exp(valid_loss):7.3f}')" ] }, { "cell_type": "code", "execution_count": 52, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "| Test Loss: 2.134 | Test PPL: 8.445 |\n" ] } ], "source": [ "transformer.load_state_dict(torch.load(MODEL_CHECKPOINT))\n", "test_loss = evaluate(transformer, dataloader_test, criterion)\n", "print(f'| Test Loss: {test_loss:.3f} | Test PPL: {math.exp(test_loss):7.3f} |')" ] }, { "cell_type": "code", "execution_count": 53, "metadata": {}, "outputs": [], "source": [ "def transformer_predict(source, model, source_tokenizer, target_tokenizer):\n", "    # wrap the tokenized source sentence with the start and end of sentence tokens\n", "    src_indices = [source_init_idx] + source_tokenizer.encode(source).ids + [source_eos_idx]\n", "    src_tensor = torch.LongTensor(src_indices).unsqueeze(0).to(device)\n", "    trg_indices = model.predict(src_tensor)\n", "    return target_tokenizer.decode(trg_indices)" ] }, { "cell_type": "code", "execution_count": 54, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "source: Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche.\n", "target: Two young, White males are outside near many bushes.\n" ] }, { "data": { "text/plain": [ "'two young men are outside, one in white, are in the doorway.'" ] }, "execution_count": 54, "metadata": {}, "output_type": "execute_result" } ], "source": [ "translation = dataset_dict['train'][0]\n", "source_raw = translation[source_lang]\n", "target_raw = translation[target_lang]\n", "print('source: ', source_raw)\n", "print('target: ', target_raw)\n", "\n", "transformer_predict(source_raw, transformer, source_tokenizer, target_tokenizer)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this notebook, we walked through an implementation of the Transformer model. Although originally proposed for NLP tasks such as machine translation, this building block has also been gaining popularity in other fields such as computer vision [[2]](https://nbviewer.jupyter.org/github/phlippe/uvadlc_notebooks/blob/master/docs/tutorial_notebooks/tutorial6/Transformers_and_MHAttention.ipynb)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Reference" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- [[1]](https://nbviewer.jupyter.org/github/bentrevett/pytorch-seq2seq/blob/master/6%20-%20Attention%20is%20All%20You%20Need.ipynb) Jupyter Notebook: Attention is All You Need\n", "- [[2]](https://nbviewer.jupyter.org/github/phlippe/uvadlc_notebooks/blob/master/docs/tutorial_notebooks/tutorial6/Transformers_and_MHAttention.ipynb) Jupyter Notebook: Tutorial 6: Transformers and Multi-Head Attention\n", "- [[3]](https://colab.research.google.com/drive/1swXWW5sOLW8zSZBaQBYcGQkQ_Bje_bmI) Colab: Simple PyTorch Transformer Example with Greedy Decoding\n", "- [[4]](http://peterbloem.nl/blog/transformers) Blog: Transformers from scratch\n", "- [[5]](https://scale.com/blog/pytorch-improvements) Blog: Making Pytorch Transformer Twice as Fast on Sequence Generation\n", "- [[6]](https://theaisummer.com/transformer/) Blog: How Transformers work in deep learning and NLP: an intuitive introduction\n", "- [[7]](https://pytorch.org/tutorials/beginner/transformer_tutorial.html) PyTorch Documentation: Sequence to sequence modeling with nn.Transformer and Torchtext\n", "- [[8]](http://nlp.seas.harvard.edu/annotated-transformer/) Blog: The Annotated Transformer\n", "- [[9]](https://arxiv.org/abs/1706.03762) Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. 
Gomez, Lukasz Kaiser, Illia Polosukhin - Attention is All you Need (2017)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.12" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": true, "toc_position": { "height": "calc(100% - 180px)", "left": "10px", "top": "150px", "width": "281.797px" }, "toc_section_display": true, "toc_window_display": true }, "varInspector": { "cols": { "lenName": 16, "lenType": 16, "lenVar": 40 }, "kernels_config": { "python": { "delete_cmd_postfix": "", "delete_cmd_prefix": "del ", "library": "var_list.py", "varRefreshCmd": "print(var_dic_list())" }, "r": { "delete_cmd_postfix": ") ", "delete_cmd_prefix": "rm(", "library": "var_list.r", "varRefreshCmd": "cat(var_dic_list()) " } }, "types_to_exclude": [ "module", "function", "builtin_function_or_method", "instance", "_Feature" ], "window_display": false } }, "nbformat": 4, "nbformat_minor": 4 }