{ "cells": [ { "cell_type": "markdown", "id": "b9609829", "metadata": { "slideshow": { "slide_type": "-" } }, "source": [ "# Machine Translation and the Dataset\n", "\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "1af4809f", "metadata": { "attributes": { "classes": [], "id": "", "n": "3" }, "execution": { "iopub.execute_input": "2023-08-18T19:25:35.594366Z", "iopub.status.busy": "2023-08-18T19:25:35.593736Z", "iopub.status.idle": "2023-08-18T19:25:38.332100Z", "shell.execute_reply": "2023-08-18T19:25:38.330551Z" }, "origin_pos": 3, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "import os\n", "import torch\n", "from d2l import torch as d2l" ] }, { "cell_type": "markdown", "id": "448434f2", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Downloading and Preprocessing the Dataset" ] }, { "cell_type": "code", "execution_count": 3, "id": "0c27fc69", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:25:38.347854Z", "iopub.status.busy": "2023-08-18T19:25:38.347146Z", "iopub.status.idle": "2023-08-18T19:25:38.705705Z", "shell.execute_reply": "2023-08-18T19:25:38.704813Z" }, "origin_pos": 8, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Downloading ../data/fra-eng.zip from http://d2l-data.s3-accelerate.amazonaws.com/fra-eng.zip...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Go.\tVa !\n", "Hi.\tSalut !\n", "Run!\tCours !\n", "Run!\tCourez !\n", "Who?\tQui ?\n", "Wow!\tÇa alors !\n", "\n" ] } ], "source": [ "class MTFraEng(d2l.DataModule):\n", "    \"\"\"The English-French dataset.\"\"\"\n", "    def _download(self):\n", "        d2l.extract(d2l.download(\n", "            d2l.DATA_URL+'fra-eng.zip', self.root,\n", "            '94646ad1522d915e7b0f9296181140edcf86a4f5'))\n", "        with open(self.root + '/fra-eng/fra.txt', encoding='utf-8') as f:\n", "            return f.read()\n", "\n", "data = MTFraEng()\n", "raw_text = data._download()\n", "print(raw_text[:75])" ] }, { "cell_type": "markdown", "id": "7063c83a", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Proceed with several preprocessing steps" ] }, { "cell_type": "code", "execution_count": 5, "id": "4a3c36d4", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:25:38.717999Z", "iopub.status.busy": "2023-08-18T19:25:38.717445Z", "iopub.status.idle": "2023-08-18T19:25:41.820750Z", "shell.execute_reply": "2023-08-18T19:25:41.819869Z" }, "origin_pos": 11, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "go .\tva !\n", "hi .\tsalut !\n", "run !\tcours !\n", "run !\tcourez !\n", "who ?\tqui ?\n", "wow !\tça alors !\n" ] } ], "source": [ "@d2l.add_to_class(MTFraEng)\n", "def _preprocess(self, text):\n", "    text = text.replace('\\u202f', ' ').replace('\\xa0', ' ')\n", "    no_space = lambda char, prev_char: char in ',.!?' 
and prev_char != ' '\n", "    out = [' ' + char if i > 0 and no_space(char, text[i - 1]) else char\n", "           for i, char in enumerate(text.lower())]\n", "    return ''.join(out)\n", "\n", "text = data._preprocess(raw_text)\n", "print(text[:80])" ] }, { "cell_type": "markdown", "id": "bb5478b0", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Tokenization" ] }, { "cell_type": "code", "execution_count": 7, "id": "3bedb46e", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:25:41.833010Z", "iopub.status.busy": "2023-08-18T19:25:41.832457Z", "iopub.status.idle": "2023-08-18T19:25:43.340916Z", "shell.execute_reply": "2023-08-18T19:25:43.340019Z" }, "origin_pos": 14, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "([['go', '.', '<eos>'],\n", " ['hi', '.', '<eos>'],\n", " ['run', '!', '<eos>'],\n", " ['run', '!', '<eos>'],\n", " ['who', '?', '<eos>'],\n", " ['wow', '!', '<eos>']],\n", " [['va', '!', '<eos>'],\n", " ['salut', '!', '<eos>'],\n", " ['cours', '!', '<eos>'],\n", " ['courez', '!', '<eos>'],\n", " ['qui', '?', '<eos>'],\n", " ['ça', 'alors', '!', '<eos>']])" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "@d2l.add_to_class(MTFraEng)\n", "def _tokenize(self, text, max_examples=None):\n", "    src, tgt = [], []\n", "    for i, line in enumerate(text.split('\\n')):\n", "        if max_examples and i > max_examples: break\n", "        parts = line.split('\\t')\n", "        if len(parts) == 2:\n", "            # Skip empty tokens and append <eos> to every sequence\n", "            src.append([t for t in f'{parts[0]} <eos>'.split(' ') if t])\n", "            tgt.append([t for t in f'{parts[1]} <eos>'.split(' ') if t])\n", "    return src, tgt\n", "\n", "src, tgt = data._tokenize(text)\n", "src[:6], tgt[:6]" ] }, { "cell_type": "markdown", "id": "2b962bcf", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Plot the histogram of the number of tokens per text sequence" ] }, { "cell_type": "code", "execution_count": 9, "id": "b8fa90b4", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:25:43.353348Z", "iopub.status.busy": "2023-08-18T19:25:43.352783Z", "iopub.status.idle": "2023-08-18T19:25:43.722959Z", "shell.execute_reply": "2023-08-18T19:25:43.722055Z" }, "origin_pos": 17, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "[SVG figure omitted: histogram of the number of tokens per text sequence for the source and target corpora]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "def show_list_len_pair_hist(legend, xlabel, ylabel, xlist, ylist):\n", "    \"\"\"Plot the histogram for list length pairs.\"\"\"\n", "    d2l.set_figsize()\n", "    _, _, patches = d2l.plt.hist(\n", "        [[len(l) for l in xlist], [len(l) for l in ylist]])\n", "    d2l.plt.xlabel(xlabel)\n", "    d2l.plt.ylabel(ylabel)\n", "    for patch in patches[1].patches:\n", "        patch.set_hatch('/')\n", "    d2l.plt.legend(legend)\n", "\n", "show_list_len_pair_hist(['source', 'target'], '# tokens per sequence',\n", "                        'count', src, tgt);" ] }, { "cell_type": "markdown", "id": "2f827d6c", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "In language modeling, each example sequence\n", "had a fixed length" ] }, { "cell_type": "code", "execution_count": 11, "id": "642915b2", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:25:43.734774Z", "iopub.status.busy": "2023-08-18T19:25:43.734205Z", "iopub.status.idle": "2023-08-18T19:25:43.741936Z", "shell.execute_reply": "2023-08-18T19:25:43.741036Z" }, "origin_pos": 20, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "@d2l.add_to_class(MTFraEng)\n", "def __init__(self, batch_size, num_steps=9, num_train=512, num_val=128):\n", "    super(MTFraEng, self).__init__()\n", "    self.save_hyperparameters()\n", "    self.arrays, self.src_vocab, self.tgt_vocab = self._build_arrays(\n", "        self._download())\n", "\n", "@d2l.add_to_class(MTFraEng)\n", "def _build_arrays(self, raw_text, src_vocab=None, tgt_vocab=None):\n", "    def _build_array(sentences, vocab, is_tgt=False):\n", "        # Truncate or pad every sequence to num_steps tokens\n", "        pad_or_trim = lambda seq, t: (\n", "            seq[:t] if len(seq) > t else seq + ['<pad>'] * (t - len(seq)))\n", "        sentences = [pad_or_trim(s, self.num_steps) for s in sentences]\n", "        if is_tgt:\n", "            sentences = [['<bos>'] + s for s in sentences]\n", "        if vocab is None:\n", "            vocab = d2l.Vocab(sentences, min_freq=2)\n", "        array = torch.tensor([vocab[s] for s in sentences])\n", "        valid_len = (array != vocab['<pad>']).type(torch.int32).sum(1)\n", "        return array, vocab, valid_len\n", "    src, tgt = self._tokenize(self._preprocess(raw_text),\n", "                              self.num_train + self.num_val)\n", "    src_array, src_vocab, src_valid_len = _build_array(src, src_vocab)\n", "    tgt_array, tgt_vocab, _ = _build_array(tgt, tgt_vocab, True)\n", "    # Return (source, decoder input, source valid length, label) plus vocabs\n", "    return ((src_array, tgt_array[:,:-1], src_valid_len, tgt_array[:,1:]),\n", "            src_vocab, tgt_vocab)" ] }, { "cell_type": "markdown", "id": "df7d488a", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Reading the Dataset" ] }, { "cell_type": "code", "execution_count": 12, "id": "5eb64a4b", "metadata": { "attributes": { "classes": [], "id": "", "n": "10" }, "execution": { "iopub.execute_input": "2023-08-18T19:25:43.745419Z", "iopub.status.busy": "2023-08-18T19:25:43.744866Z", "iopub.status.idle": "2023-08-18T19:25:43.749246Z", "shell.execute_reply": "2023-08-18T19:25:43.748444Z" }, "origin_pos": 22, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "@d2l.add_to_class(MTFraEng)\n", "def get_dataloader(self, train):\n", "    idx = slice(0, self.num_train) if train else slice(self.num_train, None)\n", "    return self.get_tensorloader(self.arrays, train, idx)" ] }, { "cell_type": "markdown", "id": "973dad82", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Read the first minibatch from the English--French dataset" ] }, { "cell_type": "code", "execution_count": 13, "id": "ef39df99", "metadata": { "attributes": { "classes": [], "id": "", "n": "11" }, "execution": { "iopub.execute_input": "2023-08-18T19:25:43.752740Z", "iopub.status.busy": "2023-08-18T19:25:43.752195Z", "iopub.status.idle": "2023-08-18T19:25:47.133736Z", "shell.execute_reply": "2023-08-18T19:25:47.132842Z" }, "origin_pos": 24, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "source: tensor([[117, 182, 0, 3, 4, 4, 4, 4, 4],\n", " [ 62, 72, 2, 3, 4, 4, 4, 4, 4],\n", " [ 57, 124, 0, 3, 4, 4, 4, 4, 4]], dtype=torch.int32)\n", "decoder input: tensor([[ 3, 37, 100, 58, 160, 0, 4, 5, 5],\n", " [ 3, 6, 2, 4, 5, 5, 5, 5, 5],\n", " [ 3, 180, 0, 4, 5, 5, 5, 5, 5]], dtype=torch.int32)\n", "source len excluding pad: tensor([4, 4, 4], dtype=torch.int32)\n", "label: tensor([[ 37, 100, 58, 160, 0, 4, 5, 5, 5],\n", " [ 6, 2, 4, 5, 5, 5, 5, 5, 5],\n", " [180, 0, 4, 5, 5, 5, 5, 5, 5]], dtype=torch.int32)\n" ] } ], "source": [ "data = MTFraEng(batch_size=3)\n", "src, tgt, src_valid_len, label = next(iter(data.train_dataloader()))\n", "print('source:', src.type(torch.int32))\n", "print('decoder input:', tgt.type(torch.int32))\n", "print('source len excluding pad:', src_valid_len.type(torch.int32))\n", "print('label:', label.type(torch.int32))" ] }, { "cell_type": "code", "execution_count": 15, "id": "102aced4", "metadata": { "attributes": { "classes": [], "id": "", "n": "13" }, "execution": { "iopub.execute_input": "2023-08-18T19:25:47.145332Z", "iopub.status.busy": "2023-08-18T19:25:47.144692Z", "iopub.status.idle": "2023-08-18T19:25:47.150415Z", "shell.execute_reply": "2023-08-18T19:25:47.149609Z" }, "origin_pos": 27, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "source: ['hi', '.', '<eos>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>']\n", "target: ['<bos>', 'salut', '.', '<eos>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>']\n" ] } ], "source": [ "@d2l.add_to_class(MTFraEng)\n", "def build(self, src_sentences, tgt_sentences):\n", "    raw_text = '\\n'.join([src + '\\t' + tgt for src, tgt in zip(\n", "        src_sentences, tgt_sentences)])\n", "    arrays, _, _ = self._build_arrays(\n", "        raw_text, self.src_vocab, self.tgt_vocab)\n", "    return arrays\n", "\n", "src, tgt, _, _ = data.build(['hi .'], ['salut .'])\n", "print('source:', data.src_vocab.to_tokens(src[0].type(torch.int32)))\n", "print('target:', data.tgt_vocab.to_tokens(tgt[0].type(torch.int32)))" ] } ], "metadata": { "celltoolbar": "Slideshow", "language_info": { "name": "python" }, "required_libs": [], "rise": { "autolaunch": true, "enable_chalkboard": true, "overlay": "
", "scroll": true } }, "nbformat": 4, "nbformat_minor": 5 }