{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "name": "day34_03_lecture_chap1_exercise_public.ipynb", "provenance": [], "collapsed_sections": [ "B1UDdsNvruEi" ], "include_colab_link": true }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "accelerator": "GPU" }, "cells": [ { "cell_type": "markdown", "metadata": { "id": "view-in-github", "colab_type": "text" }, "source": [ "<a href=\"https://colab.research.google.com/github/rkti498/e_shikaku/blob/main/day34_03_lecture_chap1_exercise_public.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" ] }, { "cell_type": "markdown", "metadata": { "id": "Ht1-DGiJAM7f" }, "source": [ "# 演習 Sequence-to-Sequence (Seq2Seq) モデル" ] }, { "cell_type": "markdown", "metadata": { "id": "lImUIx2SAM7h" }, "source": [ "Sequence-to-Sequence (Seq2Seq) モデルは、系列を入力として系列を出力するモデルです。\n", "\n", "入力系列をRNNで固定長のベクトルに変換(= Encode)し、そのベクトルを用いて系列を出力(= Decode)することから、Encoder-Decoder モデルとも呼ばれます。\n", "\n", "RNNの代わりにLSTMやGRUでも可能です。\n", "\n", "機械翻訳のほか、文書要約や対話生成にも使われます。<br>\n", "今回は機械翻訳を例にとって解説していきます。" ] }, { "cell_type": "markdown", "metadata": { "id": "lUXx-h-RUGnt" }, "source": [ "" ] }, { "cell_type": "code", "metadata": { "id": "N-CmUMypoDTn" }, "source": [ "" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 250 }, "id": "MMRW9btFUHnI", "outputId": "f141506f-2b4a-47e8-8516-ca9cc9400262" }, "source": [ "# このページの情報で、特定のバージョンじゃないとだめと書いてあった。\n", "# https://github.com/buildout/buildout.wheel/issues/18\n", "\n", "!pip3 install wheel==0.34.1" ], "execution_count": 4, "outputs": [ { "output_type": "stream", "text": [ "Collecting wheel==0.34.1\n", " Downloading https://files.pythonhosted.org/packages/81/44/db78754a73d9a88c5bd1bb692b40004410970e88aa0c5dff20b57f231505/wheel-0.34.1-py2.py3-none-any.whl\n", "\u001b[31mERROR: tensorflow 2.5.0 has requirement wheel~=0.35, but you'll have wheel 0.34.1 which is incompatible.\u001b[0m\n", "Installing collected packages: wheel\n", " Found existing installation: wheel 0.36.2\n", " Uninstalling wheel-0.36.2:\n", " Successfully uninstalled wheel-0.36.2\n", "Successfully installed wheel-0.34.1\n" ], "name": "stdout" }, { "output_type": "display_data", "data": { "application/vnd.colab-display-data+json": { "pip_warning": { "packages": [ "wheel" ] } } }, "metadata": { "tags": [] } } ] }, { "cell_type": "markdown", "metadata": { "id": "V4sqRdFNGAIN" }, "source": [ "" ] }, { "cell_type": "code", "metadata": { "id": "AlrUn0gcNevH", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "4062bee3-3238-4cca-d019-316b52121218" }, "source": [ "from os import path\n", "from wheel.pep425tags import get_abbr_impl, get_impl_ver, get_abi_tag\n", "platform = '{}{}-{}'.format(get_abbr_impl(), get_impl_ver(), get_abi_tag())\n", "\n", "accelerator = 'cu80' if path.exists('/opt/bin/nvidia-smi') else 'cpu'\n", "\n", "!pip install -q http://download.pytorch.org/whl/{accelerator}/torch-0.4.0-{platform}-linux_x86_64.whl torchvision\n", "import torch\n", "print(torch.__version__)\n", "print(torch.cuda.is_available())" ], "execution_count": 1, "outputs": [ { "output_type": "stream", "text": [ "\u001b[31m ERROR: HTTP error 403 while getting http://download.pytorch.org/whl/cu80/torch-0.4.0-cp37-cp37m-linux_x86_64.whl\u001b[0m\n", "\u001b[31m ERROR: Could not install requirement torch==0.4.0 from http://download.pytorch.org/whl/cu80/torch-0.4.0-cp37-cp37m-linux_x86_64.whl because of error 403 Client Error: Forbidden for url: http://download.pytorch.org/whl/cu80/torch-0.4.0-cp37-cp37m-linux_x86_64.whl\u001b[0m\n", "\u001b[31mERROR: Could not install requirement torch==0.4.0 from http://download.pytorch.org/whl/cu80/torch-0.4.0-cp37-cp37m-linux_x86_64.whl because of HTTP error 403 Client Error: Forbidden for url: http://download.pytorch.org/whl/cu80/torch-0.4.0-cp37-cp37m-linux_x86_64.whl for URL http://download.pytorch.org/whl/cu80/torch-0.4.0-cp37-cp37m-linux_x86_64.whl\u001b[0m\n", "1.9.0+cu102\n", "True\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "55_hatEGIB3N", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "02a2d5fb-45e1-4b6a-c6e4-c75916f34de9" }, "source": [ "! wget https://www.dropbox.com/s/9narw5x4uizmehh/utils.py\n", "! mkdir images data\n", "\n", "# data取得\n", "! wget https://www.dropbox.com/s/o4kyc52a8we25wy/dev.en -P data/\n", "! wget https://www.dropbox.com/s/kdgskm5hzg6znuc/dev.ja -P data/\n", "! wget https://www.dropbox.com/s/gyyx4gohv9v65uh/test.en -P data/\n", "! wget https://www.dropbox.com/s/hotxwbgoe2n013k/test.ja -P data/\n", "! wget https://www.dropbox.com/s/5lsftkmb20ay9e1/train.en -P data/\n", "! wget https://www.dropbox.com/s/ak53qirssci6f1j/train.ja -P data/" ], "execution_count": 2, "outputs": [ { "output_type": "stream", "text": [ "--2021-07-18 06:34:36-- https://www.dropbox.com/s/9narw5x4uizmehh/utils.py\n", "Resolving www.dropbox.com (www.dropbox.com)... 162.125.1.18, 2620:100:6016:18::a27d:112\n", "Connecting to www.dropbox.com (www.dropbox.com)|162.125.1.18|:443... connected.\n", "HTTP request sent, awaiting response... 301 Moved Permanently\n", "Location: /s/raw/9narw5x4uizmehh/utils.py [following]\n", "--2021-07-18 06:34:36-- https://www.dropbox.com/s/raw/9narw5x4uizmehh/utils.py\n", "Reusing existing connection to www.dropbox.com:443.\n", "HTTP request sent, awaiting response... 302 Found\n", "Location: https://uc0d07df862f4b979f2bffa57f83.dl.dropboxusercontent.com/cd/0/inline/BSgUmGGCHiSqU-Aw1pqfKOzPawKACdyROQJrqEsxM4-f2H4Gzrz7YDwXXaAliCXeVNoddmtHd5QRGb2HA7V4GS8dHSdUao3EEMxm_3A-SsbUgwbId4NORcApbJpJ8upvnIsWwa2PJRnNoZumkX1ssAF-/file# [following]\n", "--2021-07-18 06:34:36-- https://uc0d07df862f4b979f2bffa57f83.dl.dropboxusercontent.com/cd/0/inline/BSgUmGGCHiSqU-Aw1pqfKOzPawKACdyROQJrqEsxM4-f2H4Gzrz7YDwXXaAliCXeVNoddmtHd5QRGb2HA7V4GS8dHSdUao3EEMxm_3A-SsbUgwbId4NORcApbJpJ8upvnIsWwa2PJRnNoZumkX1ssAF-/file\n", "Resolving uc0d07df862f4b979f2bffa57f83.dl.dropboxusercontent.com (uc0d07df862f4b979f2bffa57f83.dl.dropboxusercontent.com)... 162.125.1.15, 2620:100:6016:15::a27d:10f\n", "Connecting to uc0d07df862f4b979f2bffa57f83.dl.dropboxusercontent.com (uc0d07df862f4b979f2bffa57f83.dl.dropboxusercontent.com)|162.125.1.15|:443... connected.\n", "HTTP request sent, awaiting response... 200 OK\n", "Length: 949 [text/plain]\n", "Saving to: ‘utils.py’\n", "\n", "utils.py 100%[===================>] 949 --.-KB/s in 0s \n", "\n", "2021-07-18 06:34:36 (107 MB/s) - ‘utils.py’ saved [949/949]\n", "\n", "--2021-07-18 06:34:36-- https://www.dropbox.com/s/o4kyc52a8we25wy/dev.en\n", "Resolving www.dropbox.com (www.dropbox.com)... 162.125.1.18, 2620:100:6016:18::a27d:112\n", "Connecting to www.dropbox.com (www.dropbox.com)|162.125.1.18|:443... connected.\n", "HTTP request sent, awaiting response... 301 Moved Permanently\n", "Location: /s/raw/o4kyc52a8we25wy/dev.en [following]\n", "--2021-07-18 06:34:37-- https://www.dropbox.com/s/raw/o4kyc52a8we25wy/dev.en\n", "Reusing existing connection to www.dropbox.com:443.\n", "HTTP request sent, awaiting response... 302 Found\n", "Location: https://ucd48c85be83aba12ac4731c731c.dl.dropboxusercontent.com/cd/0/inline/BSiwAsgw6yNys2vumtK-8EobSB5qkLzkuaxKBAcWUIop_KIUtGQZJVrBojvaz8LTWNj7MiLuBJxzMXg_6tQCs8KOTuBQXIi7ulv0hXbSwh2Vn_hSEBn_XmWBfK6ZBnrRhBkhbY0EyhNAEz3UlKPFwuRL/file# [following]\n", "--2021-07-18 06:34:37-- https://ucd48c85be83aba12ac4731c731c.dl.dropboxusercontent.com/cd/0/inline/BSiwAsgw6yNys2vumtK-8EobSB5qkLzkuaxKBAcWUIop_KIUtGQZJVrBojvaz8LTWNj7MiLuBJxzMXg_6tQCs8KOTuBQXIi7ulv0hXbSwh2Vn_hSEBn_XmWBfK6ZBnrRhBkhbY0EyhNAEz3UlKPFwuRL/file\n", "Resolving ucd48c85be83aba12ac4731c731c.dl.dropboxusercontent.com (ucd48c85be83aba12ac4731c731c.dl.dropboxusercontent.com)... 162.125.1.15, 2620:100:6031:15::a27d:510f\n", "Connecting to ucd48c85be83aba12ac4731c731c.dl.dropboxusercontent.com (ucd48c85be83aba12ac4731c731c.dl.dropboxusercontent.com)|162.125.1.15|:443... connected.\n", "HTTP request sent, awaiting response... 200 OK\n", "Length: 17054 (17K) [text/plain]\n", "Saving to: ‘data/dev.en’\n", "\n", "dev.en 100%[===================>] 16.65K --.-KB/s in 0s \n", "\n", "2021-07-18 06:34:37 (52.7 MB/s) - ‘data/dev.en’ saved [17054/17054]\n", "\n", "--2021-07-18 06:34:37-- https://www.dropbox.com/s/kdgskm5hzg6znuc/dev.ja\n", "Resolving www.dropbox.com (www.dropbox.com)... 162.125.1.18, 2620:100:6016:18::a27d:112\n", "Connecting to www.dropbox.com (www.dropbox.com)|162.125.1.18|:443... connected.\n", "HTTP request sent, awaiting response... 301 Moved Permanently\n", "Location: /s/raw/kdgskm5hzg6znuc/dev.ja [following]\n", "--2021-07-18 06:34:37-- https://www.dropbox.com/s/raw/kdgskm5hzg6znuc/dev.ja\n", "Reusing existing connection to www.dropbox.com:443.\n", "HTTP request sent, awaiting response... 302 Found\n", "Location: https://ucaa0e8f898c49f5a7f5409609c2.dl.dropboxusercontent.com/cd/0/inline/BSjTOog1jPfDXFmgmhf_rdryp5iNiwzfoV9hSJFKtlrfA3Z81FYzrjoaSSqzvLXjAPGVzSVBiLyW_8OikG6gE37sIrowFV-JBT3ve4tbkc_U8n94BJSGQReAq7kafyd3qw-OUeh5JP1ZEKgt_96CdCQb/file# [following]\n", "--2021-07-18 06:34:38-- https://ucaa0e8f898c49f5a7f5409609c2.dl.dropboxusercontent.com/cd/0/inline/BSjTOog1jPfDXFmgmhf_rdryp5iNiwzfoV9hSJFKtlrfA3Z81FYzrjoaSSqzvLXjAPGVzSVBiLyW_8OikG6gE37sIrowFV-JBT3ve4tbkc_U8n94BJSGQReAq7kafyd3qw-OUeh5JP1ZEKgt_96CdCQb/file\n", "Resolving ucaa0e8f898c49f5a7f5409609c2.dl.dropboxusercontent.com (ucaa0e8f898c49f5a7f5409609c2.dl.dropboxusercontent.com)... 162.125.1.15, 2620:100:6016:15::a27d:10f\n", "Connecting to ucaa0e8f898c49f5a7f5409609c2.dl.dropboxusercontent.com (ucaa0e8f898c49f5a7f5409609c2.dl.dropboxusercontent.com)|162.125.1.15|:443... connected.\n", "HTTP request sent, awaiting response... 200 OK\n", "Length: 27781 (27K) [text/plain]\n", "Saving to: ‘data/dev.ja’\n", "\n", "dev.ja 100%[===================>] 27.13K --.-KB/s in 0.02s \n", "\n", "2021-07-18 06:34:38 (1.48 MB/s) - ‘data/dev.ja’ saved [27781/27781]\n", "\n", "--2021-07-18 06:34:38-- https://www.dropbox.com/s/gyyx4gohv9v65uh/test.en\n", "Resolving www.dropbox.com (www.dropbox.com)... 162.125.1.18, 2620:100:6016:18::a27d:112\n", "Connecting to www.dropbox.com (www.dropbox.com)|162.125.1.18|:443... connected.\n", "HTTP request sent, awaiting response... 301 Moved Permanently\n", "Location: /s/raw/gyyx4gohv9v65uh/test.en [following]\n", "--2021-07-18 06:34:38-- https://www.dropbox.com/s/raw/gyyx4gohv9v65uh/test.en\n", "Reusing existing connection to www.dropbox.com:443.\n", "HTTP request sent, awaiting response... 302 Found\n", "Location: https://uc0ed46fdcd88c4a41b3ed069d73.dl.dropboxusercontent.com/cd/0/inline/BSiuGhflXLPPjGibDJfLwUmgmpl547WyWBIaaJUuma-2AwZQ9T4Ds3IRYWyXxXWO1n01MQTthDyWKwAbtzi_q6n6vug-T8cIN_VaxvaHhvbPUJkLWVSGWdSSJCYwfzvLPm7pj1fRkYGh2mPyIbpQpH-m/file# [following]\n", "--2021-07-18 06:34:38-- https://uc0ed46fdcd88c4a41b3ed069d73.dl.dropboxusercontent.com/cd/0/inline/BSiuGhflXLPPjGibDJfLwUmgmpl547WyWBIaaJUuma-2AwZQ9T4Ds3IRYWyXxXWO1n01MQTthDyWKwAbtzi_q6n6vug-T8cIN_VaxvaHhvbPUJkLWVSGWdSSJCYwfzvLPm7pj1fRkYGh2mPyIbpQpH-m/file\n", "Resolving uc0ed46fdcd88c4a41b3ed069d73.dl.dropboxusercontent.com (uc0ed46fdcd88c4a41b3ed069d73.dl.dropboxusercontent.com)... 162.125.1.15, 2620:100:6031:15::a27d:510f\n", "Connecting to uc0ed46fdcd88c4a41b3ed069d73.dl.dropboxusercontent.com (uc0ed46fdcd88c4a41b3ed069d73.dl.dropboxusercontent.com)|162.125.1.15|:443... connected.\n", "HTTP request sent, awaiting response... 200 OK\n", "Length: 17301 (17K) [text/plain]\n", "Saving to: ‘data/test.en’\n", "\n", "test.en 100%[===================>] 16.90K --.-KB/s in 0.002s \n", "\n", "2021-07-18 06:34:39 (10.9 MB/s) - ‘data/test.en’ saved [17301/17301]\n", "\n", "--2021-07-18 06:34:39-- https://www.dropbox.com/s/hotxwbgoe2n013k/test.ja\n", "Resolving www.dropbox.com (www.dropbox.com)... 162.125.1.18, 2620:100:6016:18::a27d:112\n", "Connecting to www.dropbox.com (www.dropbox.com)|162.125.1.18|:443... connected.\n", "HTTP request sent, awaiting response... 301 Moved Permanently\n", "Location: /s/raw/hotxwbgoe2n013k/test.ja [following]\n", "--2021-07-18 06:34:40-- https://www.dropbox.com/s/raw/hotxwbgoe2n013k/test.ja\n", "Reusing existing connection to www.dropbox.com:443.\n", "HTTP request sent, awaiting response... 302 Found\n", "Location: https://uc97815d392b2d5a81bb818c8147.dl.dropboxusercontent.com/cd/0/inline/BSgRc-0Hj7l_EK3_qOouZd86iiU710RlBc0VQjdSxB6PCfvYWaqv9XTbSd8LaxNRJE5qksADXFKoRUKxoeImMOldFqyCqyoMv4AWrgwsIBEIMZFobTY_YP6lu1fu44bkyxrRWz80kwF2-eEgHbQlz08s/file# [following]\n", "--2021-07-18 06:34:40-- https://uc97815d392b2d5a81bb818c8147.dl.dropboxusercontent.com/cd/0/inline/BSgRc-0Hj7l_EK3_qOouZd86iiU710RlBc0VQjdSxB6PCfvYWaqv9XTbSd8LaxNRJE5qksADXFKoRUKxoeImMOldFqyCqyoMv4AWrgwsIBEIMZFobTY_YP6lu1fu44bkyxrRWz80kwF2-eEgHbQlz08s/file\n", "Resolving uc97815d392b2d5a81bb818c8147.dl.dropboxusercontent.com (uc97815d392b2d5a81bb818c8147.dl.dropboxusercontent.com)... 162.125.1.15, 2620:100:6016:15::a27d:10f\n", "Connecting to uc97815d392b2d5a81bb818c8147.dl.dropboxusercontent.com (uc97815d392b2d5a81bb818c8147.dl.dropboxusercontent.com)|162.125.1.15|:443... connected.\n", "HTTP request sent, awaiting response... 200 OK\n", "Length: 27793 (27K) [text/plain]\n", "Saving to: ‘data/test.ja’\n", "\n", "test.ja 100%[===================>] 27.14K --.-KB/s in 0.02s \n", "\n", "2021-07-18 06:34:40 (1.67 MB/s) - ‘data/test.ja’ saved [27793/27793]\n", "\n", "--2021-07-18 06:34:40-- https://www.dropbox.com/s/5lsftkmb20ay9e1/train.en\n", "Resolving www.dropbox.com (www.dropbox.com)... 162.125.1.18, 2620:100:6016:18::a27d:112\n", "Connecting to www.dropbox.com (www.dropbox.com)|162.125.1.18|:443... connected.\n", "HTTP request sent, awaiting response... 301 Moved Permanently\n", "Location: /s/raw/5lsftkmb20ay9e1/train.en [following]\n", "--2021-07-18 06:34:40-- https://www.dropbox.com/s/raw/5lsftkmb20ay9e1/train.en\n", "Reusing existing connection to www.dropbox.com:443.\n", "HTTP request sent, awaiting response... 302 Found\n", "Location: https://ucaa3164b979e3cfa196e47aed8e.dl.dropboxusercontent.com/cd/0/inline/BSjkSPW1zz5m31_fTvKaS5-65LtHWcp93afPXPX2Ez78RpFhVnw8WNq8EhlDRReZ9h8JNdAdgL0snDFXldEwqqe_L05AXuAOLcWuZek9Tlh7ajuCtSzKLOZwmzhpWDyhVX_zS9tcIyq2U_SAo2loDhDy/file# [following]\n", "--2021-07-18 06:34:41-- https://ucaa3164b979e3cfa196e47aed8e.dl.dropboxusercontent.com/cd/0/inline/BSjkSPW1zz5m31_fTvKaS5-65LtHWcp93afPXPX2Ez78RpFhVnw8WNq8EhlDRReZ9h8JNdAdgL0snDFXldEwqqe_L05AXuAOLcWuZek9Tlh7ajuCtSzKLOZwmzhpWDyhVX_zS9tcIyq2U_SAo2loDhDy/file\n", "Resolving ucaa3164b979e3cfa196e47aed8e.dl.dropboxusercontent.com (ucaa3164b979e3cfa196e47aed8e.dl.dropboxusercontent.com)... 162.125.1.15, 2620:100:6031:15::a27d:510f\n", "Connecting to ucaa3164b979e3cfa196e47aed8e.dl.dropboxusercontent.com (ucaa3164b979e3cfa196e47aed8e.dl.dropboxusercontent.com)|162.125.1.15|:443... connected.\n", "HTTP request sent, awaiting response... 200 OK\n", "Length: 1701356 (1.6M) [text/plain]\n", "Saving to: ‘data/train.en’\n", "\n", "train.en 100%[===================>] 1.62M --.-KB/s in 0.1s \n", "\n", "2021-07-18 06:34:41 (13.6 MB/s) - ‘data/train.en’ saved [1701356/1701356]\n", "\n", "--2021-07-18 06:34:41-- https://www.dropbox.com/s/ak53qirssci6f1j/train.ja\n", "Resolving www.dropbox.com (www.dropbox.com)... 162.125.1.18, 2620:100:6016:18::a27d:112\n", "Connecting to www.dropbox.com (www.dropbox.com)|162.125.1.18|:443... connected.\n", "HTTP request sent, awaiting response... 301 Moved Permanently\n", "Location: /s/raw/ak53qirssci6f1j/train.ja [following]\n", "--2021-07-18 06:34:41-- https://www.dropbox.com/s/raw/ak53qirssci6f1j/train.ja\n", "Reusing existing connection to www.dropbox.com:443.\n", "HTTP request sent, awaiting response... 302 Found\n", "Location: https://uc8f66691b57a3c1b4c9c4f73abf.dl.dropboxusercontent.com/cd/0/inline/BSgHUx4UNd_T9p8lYt2u9EB0Xyf5HHEGfnn3hfLGNhtb3bNM7pZtQwlYVRgHnBsx-lk_k0gV85uYgwlBDU2_06mQz43BzAEDOnD_CYw4XdlsMAfI_7FKWE7MXIRNRK6v-O1GdySSK2J9acJLT53yduod/file# [following]\n", "--2021-07-18 06:34:41-- https://uc8f66691b57a3c1b4c9c4f73abf.dl.dropboxusercontent.com/cd/0/inline/BSgHUx4UNd_T9p8lYt2u9EB0Xyf5HHEGfnn3hfLGNhtb3bNM7pZtQwlYVRgHnBsx-lk_k0gV85uYgwlBDU2_06mQz43BzAEDOnD_CYw4XdlsMAfI_7FKWE7MXIRNRK6v-O1GdySSK2J9acJLT53yduod/file\n", "Resolving uc8f66691b57a3c1b4c9c4f73abf.dl.dropboxusercontent.com (uc8f66691b57a3c1b4c9c4f73abf.dl.dropboxusercontent.com)... 162.125.1.15, 2620:100:6016:15::a27d:10f\n", "Connecting to uc8f66691b57a3c1b4c9c4f73abf.dl.dropboxusercontent.com (uc8f66691b57a3c1b4c9c4f73abf.dl.dropboxusercontent.com)|162.125.1.15|:443... connected.\n", "HTTP request sent, awaiting response... 200 OK\n", "Length: 2784447 (2.7M) [text/plain]\n", "Saving to: ‘data/train.ja’\n", "\n", "train.ja 100%[===================>] 2.66M 6.84MB/s in 0.4s \n", "\n", "2021-07-18 06:34:42 (6.84 MB/s) - ‘data/train.ja’ saved [2784447/2784447]\n", "\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "xp5QEw8CICiO", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "0814cac6-874c-4d09-9b13-5a24215cd944" }, "source": [ "! ls data" ], "execution_count": 3, "outputs": [ { "output_type": "stream", "text": [ "dev.en\tdev.ja\ttest.en test.ja train.en train.ja\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "_HauAB3uAM7i", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "841cc0fb-fc94-41d6-a91a-048df9f0d50a" }, "source": [ "import random\n", "import numpy as np\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.utils import shuffle\n", "from nltk import bleu_score\n", "\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "import torch.optim as optim\n", "from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence\n", "from utils import Vocab\n", "\n", "# デバイスの設定\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "\n", "torch.manual_seed(1)\n", "random_state = 42\n", "\n", "print(torch.__version__)" ], "execution_count": 4, "outputs": [ { "output_type": "stream", "text": [ "1.9.0+cu102\n" ], "name": "stdout" } ] }, { "cell_type": "markdown", "metadata": { "id": "pAmQOdx0AM7o" }, "source": [ "# 1.データセットの準備\n", "英語-日本語の対訳コーパスである、Tanaka Corpus ( http://www.edrdg.org/wiki/index.php/Tanaka_Corpus )を使います。<br>\n", "今回はそのうちの一部分を取り出したsmall_parallel_enja: 50k En/Ja Parallel Corpus for Testing SMT Methods ( https://github.com/odashi/small_parallel_enja )を使用します。\n", "\n", "train.enとtrain.jaの中身を見てみましょう。" ] }, { "cell_type": "code", "metadata": { "id": "gVxFp2MmAM7p", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "90ff00e1-c10f-4431-9bc7-a716ca396764" }, "source": [ "! head -10 data/train.en" ], "execution_count": 5, "outputs": [ { "output_type": "stream", "text": [ "i can 't tell who will arrive first .\n", "many animals have been destroyed by men .\n", "i 'm in the tennis club .\n", "emi looks happy .\n", "please bear this fact in mind .\n", "she takes care of my children .\n", "we want to be international .\n", "you ought not to break your promise .\n", "when you cross the street , watch out for cars .\n", "i have nothing to live for .\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "jSgmTKl7AM7u", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "af9d63e2-cbf2-4b89-f60d-96e8ae832b1a" }, "source": [ "! head -10 ./data/train.ja" ], "execution_count": 6, "outputs": [ { "output_type": "stream", "text": [ "誰 が 一番 に 着 く か 私 に は 分か り ま せ ん 。\n", "多く の 動物 が 人間 に よ っ て 滅ぼ さ れ た 。\n", "私 は テニス 部員 で す 。\n", "エミ は 幸せ そう に 見え ま す 。\n", "この 事実 を 心 に 留め て お い て 下さ い 。\n", "彼女 は 私 たち の 世話 を し て くれ る 。\n", "私 達 は 国際 人 に な り た い と 思 い ま す 。\n", "約束 を 破 る べ き で は あ り ま せ ん 。\n", "道路 を 横切 る とき は 車 に 注意 し なさ い 。\n", "私 に は 生き 甲斐 が な い 。\n" ], "name": "stdout" } ] }, { "cell_type": "markdown", "metadata": { "id": "UbgN52WHAM7y" }, "source": [ "それぞれの文章が英語-日本語で対応しているのがわかります。" ] }, { "cell_type": "markdown", "metadata": { "id": "OQhfLZ5lAM7z" }, "source": [ "## 1.1データの読み込みと単語の分割" ] }, { "cell_type": "code", "metadata": { "id": "coc6DTCUAM71" }, "source": [ "def load_data(file_path):\n", " # テキストファイルからデータを読み込むメソッド\n", " data = []\n", " for line in open(file_path, encoding='utf-8'):\n", " words = line.strip().split() # スペースで単語を分割\n", " data.append(words)\n", " return data" ], "execution_count": 7, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "z5UMnxsTAM74" }, "source": [ "train_X = load_data('./data/train.en')\n", "train_Y = load_data('./data/train.ja')" ], "execution_count": 8, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "ryYkPteoAM76" }, "source": [ "# 訓練データと検証データに分割\n", "train_X, valid_X, train_Y, valid_Y = train_test_split(train_X, train_Y, test_size=0.2, random_state=random_state)" ], "execution_count": 9, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "6yB4jwiCAM79" }, "source": [ "この時点で入力と教師データは以下のようになっています" ] }, { "cell_type": "code", "metadata": { "id": "0HV1SNLAAM7-", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "9f55d45a-7c33-4b9d-ee40-2711fbe0f8e4" }, "source": [ "print('train data', train_X[0])\n", "print('valid data', valid_X[0])" ], "execution_count": 10, "outputs": [ { "output_type": "stream", "text": [ "train data ['where', 'shall', 'we', 'eat', 'tonight', '?']\n", "valid data ['you', 'may', 'extend', 'your', 'stay', 'in', 'tokyo', '.']\n" ], "name": "stdout" } ] }, { "cell_type": "markdown", "metadata": { "id": "NeB7llfIAM8E" }, "source": [ "## 1.2単語辞書の作成\n", "データセットに登場する各単語にIDを割り振る" ] }, { "cell_type": "code", "metadata": { "id": "b-OqjDkXAM8F" }, "source": [ "# まず特殊トークンを定義しておく\n", "PAD_TOKEN = '<PAD>' # バッチ処理の際に、短い系列の末尾を埋めるために使う (Padding)\n", "BOS_TOKEN = '<S>' # 系列の始まりを表す (Beggining of sentence)\n", "EOS_TOKEN = '</S>' # 系列の終わりを表す (End of sentence)\n", "UNK_TOKEN = '<UNK>' # 語彙に存在しない単語を表す (Unknown)\n", "PAD = 0\n", "BOS = 1\n", "EOS = 2\n", "UNK = 3" ], "execution_count": 11, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "T_G7dYnTAM8I" }, "source": [ "MIN_COUNT = 2 # 語彙に含める単語の最低出現回数 再提出現回数に満たない単語はUNKに置き換えられる\n", "\n", "# 単語をIDに変換する辞書の初期値を設定\n", "word2id = {\n", " PAD_TOKEN: PAD,\n", " BOS_TOKEN: BOS,\n", " EOS_TOKEN: EOS,\n", " UNK_TOKEN: UNK,\n", " }\n", "\n", "# 単語辞書を作成\n", "vocab_X = Vocab(word2id=word2id)\n", "vocab_Y = Vocab(word2id=word2id)\n", "vocab_X.build_vocab(train_X, min_count=MIN_COUNT)\n", "vocab_Y.build_vocab(train_Y, min_count=MIN_COUNT)" ], "execution_count": 12, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "0xDhdQ4FAM8K", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "404bba10-7912-4ec7-fe9b-7587bd14546f" }, "source": [ "vocab_size_X = len(vocab_X.id2word)\n", "vocab_size_Y = len(vocab_Y.id2word)\n", "print('入力言語の語彙数:', vocab_size_X)\n", "print('出力言語の語彙数:', vocab_size_Y)" ], "execution_count": 13, "outputs": [ { "output_type": "stream", "text": [ "入力言語の語彙数: 3725\n", "出力言語の語彙数: 4405\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "P-K_xHBkC5TC" }, "source": [ "" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "NsoRIue9AM8P" }, "source": [ "# 2.テンソルへの変換" ] }, { "cell_type": "markdown", "metadata": { "id": "HsCajwO2AM8Q" }, "source": [ "### 2.1 IDへの変換\n", "まずはモデルが文章を認識できるように、文章を単語IDのリストに変換します" ] }, { "cell_type": "code", "metadata": { "id": "gm6qa0fNAM8R" }, "source": [ "def sentence_to_ids(vocab, sentence):\n", " # 単語(str)のリストをID(int)のリストに変換する関数\n", " ids = [vocab.word2id.get(word, UNK) for word in sentence]\n", " ids += [EOS] # EOSを加える\n", " return ids" ], "execution_count": 14, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "lk0B0VR_AM8T" }, "source": [ "train_X = [sentence_to_ids(vocab_X, sentence) for sentence in train_X]\n", "train_Y = [sentence_to_ids(vocab_Y, sentence) for sentence in train_Y]\n", "valid_X = [sentence_to_ids(vocab_X, sentence) for sentence in valid_X]\n", "valid_Y = [sentence_to_ids(vocab_Y, sentence) for sentence in valid_Y]" ], "execution_count": 15, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "I4pEMlCxAM8X" }, "source": [ "この時点で入力と教師データは以下のようになっている" ] }, { "cell_type": "code", "metadata": { "id": "D6OKuYgwAM8Y", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "32329dac-28fb-43d9-d22f-1724af677b88" }, "source": [ "print('train data', train_X[0])\n", "print('valid data', valid_X[0])" ], "execution_count": 16, "outputs": [ { "output_type": "stream", "text": [ "train data [132, 321, 28, 290, 367, 12, 2]\n", "valid data [8, 93, 3532, 36, 236, 13, 284, 4, 2]\n" ], "name": "stdout" } ] }, { "cell_type": "markdown", "metadata": { "id": "_DlWrRnQAM8f" }, "source": [ "### 2.2 DataLoaderの定義\n", "データセットからバッチを取得するデータローダーを定義します\n", "- この際、長さの異なる複数の系列をバッチで並列に扱えるように、短い系列の末尾を特定のシンボル(`<PAD>`など)でパディングし、バッチ内の系列の長さを最長のものに合わせる\n", "- (batch_size, max_length)のサイズの行列を得るが、実際にモデルを学習させるときには、バッチをまたいで各時刻ごとに進めていくので、転置して(max_length, batch_size)の形に変える<br>(batch_first=Trueのオプションを使う場合は不要)" ] }, { "cell_type": "code", "metadata": { "id": "YtmFgYLqAM8h" }, "source": [ "def pad_seq(seq, max_length):\n", " # 系列(seq)が指定の文長(max_length)になるように末尾をパディングする\n", " res = seq + [PAD for i in range(max_length - len(seq))]\n", " return res \n", "\n", "\n", "class DataLoader(object):\n", "\n", " def __init__(self, X, Y, batch_size, shuffle=False):\n", " \"\"\"\n", " :param X: list, 入力言語の文章(単語IDのリスト)のリスト\n", " :param Y: list, 出力言語の文章(単語IDのリスト)のリスト\n", " :param batch_size: int, バッチサイズ\n", " :param shuffle: bool, サンプルの順番をシャッフルするか否か\n", " \"\"\"\n", " self.data = list(zip(X, Y))\n", " self.batch_size = batch_size\n", " self.shuffle = shuffle\n", " self.start_index = 0\n", " \n", " self.reset()\n", " \n", " def reset(self):\n", " if self.shuffle: # サンプルの順番をシャッフルする\n", " self.data = shuffle(self.data, random_state=random_state)\n", " self.start_index = 0 # ポインタの位置を初期化する\n", " \n", " def __iter__(self):\n", " return self\n", "\n", " def __next__(self):\n", " # ポインタが最後まで到達したら初期化する\n", " if self.start_index >= len(self.data):\n", " self.reset()\n", " raise StopIteration()\n", "\n", " # バッチを取得\n", " seqs_X, seqs_Y = zip(*self.data[self.start_index:self.start_index+self.batch_size])\n", " # 入力系列seqs_Xの文章の長さ順(降順)に系列ペアをソートする\n", " seq_pairs = sorted(zip(seqs_X, seqs_Y), key=lambda p: len(p[0]), reverse=True)\n", " seqs_X, seqs_Y = zip(*seq_pairs)\n", " # 短い系列の末尾をパディングする\n", " lengths_X = [len(s) for s in seqs_X] # 後述のEncoderのpack_padded_sequenceでも用いる\n", " lengths_Y = [len(s) for s in seqs_Y]\n", " max_length_X = max(lengths_X)\n", " max_length_Y = max(lengths_Y)\n", " padded_X = [pad_seq(s, max_length_X) for s in seqs_X]\n", " padded_Y = [pad_seq(s, max_length_Y) for s in seqs_Y]\n", " # tensorに変換し、転置する\n", " batch_X = torch.tensor(padded_X, dtype=torch.long, device=device).transpose(0, 1)\n", " batch_Y = torch.tensor(padded_Y, dtype=torch.long, device=device).transpose(0, 1)\n", "\n", " # ポインタを更新する\n", " self.start_index += self.batch_size\n", "\n", " return batch_X, batch_Y, lengths_X" ], "execution_count": 17, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "-37frCXrAM8k" }, "source": [ "# 3.モデルの構築\n", "EncoderとDecoderのRNNを定義します。" ] }, { "cell_type": "markdown", "metadata": { "id": "1X3oRjArAM8l" }, "source": [ "### 導入:PackedSequence" ] }, { "cell_type": "markdown", "metadata": { "id": "lRmj-EdbAM8m" }, "source": [ "PyTorchのRNNでは、可変長の系列のバッチを効率よく計算できるように系列を表現する`PackedSequence`というクラスを用いることができます。\n", "\n", "入力バッチのテンソルをこの`PackedSequence`のインスタンスに変換してからRNNに入力することで、パディング部分の計算を省略することができるため、効率的な計算が可能になります。\n", "\n", "`PackedSequence`を作成するには、まず、系列長の異なるバッチに対してパディングを行なってください。\n", "\n", "ここで、パディングを行う前に各サンプルの系列長(`lengths`)を保存しておきます。" ] }, { "cell_type": "code", "metadata": { "id": "yWAV3W89AM8n", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "8d84de3f-9345-4517-d5f0-1514aedaa8e4" }, "source": [ "# 系列長がそれぞれ4,3,2の3つのサンプルからなるバッチを作成\n", "batch = [[1,2,3,4], [5,6,7], [8,9]]\n", "lengths = [len(sample) for sample in batch]\n", "print('各サンプルの系列長:', lengths)\n", "print()\n", "\n", "# 最大系列長に合うように各サンプルをpadding\n", "_max_length = max(lengths)\n", "padded = torch.tensor([pad_seq(sample, _max_length) for sample in batch])\n", "print('paddingされたテンソル:\\n', padded)\n", "padded = padded.transpose(0,1) # (max_length, batch_size)に転置\n", "print('padding & 転置されたテンソル:\\n', padded)\n", "print('padding & 転置されたテンソルのサイズ:\\n', padded.size())\n", "print()" ], "execution_count": 18, "outputs": [ { "output_type": "stream", "text": [ "各サンプルの系列長: [4, 3, 2]\n", "\n", "paddingされたテンソル:\n", " tensor([[1, 2, 3, 4],\n", " [5, 6, 7, 0],\n", " [8, 9, 0, 0]])\n", "padding & 転置されたテンソル:\n", " tensor([[1, 5, 8],\n", " [2, 6, 9],\n", " [3, 7, 0],\n", " [4, 0, 0]])\n", "padding & 転置されたテンソルのサイズ:\n", " torch.Size([4, 3])\n", "\n" ], "name": "stdout" } ] }, { "cell_type": "markdown", "metadata": { "id": "GTAq1q8mAM8q" }, "source": [ "次に、パディングを行ったテンソル(`padded`)と各サンプルの元々の系列長(`lengths`)を`torch.nn.utils.rnn.pack_padded_sequence`という関数に与えると、\n", "`data`と`batch_sizes`という要素を持った`PackedSequence`のインスタンス(`packed`)が作成できます。\n", "- `data`: テンソルの`PAD`以外の値のみを保有するベクトル\n", "- `batch_sizes`: 各時刻で計算が必要な(=`PAD`に到達していない)バッチの数を表すベクトル" ] }, { "cell_type": "code", "metadata": { "id": "FtRm7uqIAM8s", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "0db8122c-0d58-4296-c1bf-bba64e36c128" }, "source": [ "# PackedSequenceに変換(テンソルをRNNに入力する前に適用する)\n", "packed = pack_padded_sequence(padded, lengths=lengths) # 各サンプルの系列長も与える\n", "print('PackedSequenceのインスタンス:\\n', packed) # テンソルのPAD以外の値(data)と各時刻で計算が必要な(=PADに到達していない)バッチの数(batch_sizes)を有するインスタンス\n", "print()" ], "execution_count": 19, "outputs": [ { "output_type": "stream", "text": [ "PackedSequenceのインスタンス:\n", " PackedSequence(data=tensor([1, 5, 8, 2, 6, 9, 3, 7, 4]), batch_sizes=tensor([3, 3, 2, 1]), sorted_indices=None, unsorted_indices=None)\n", "\n" ], "name": "stdout" } ] }, { "cell_type": "markdown", "metadata": { "id": "iYKaiMZDAM8w" }, "source": [ "こうして得られた`PackedSequence`のインスタンスをRNNに入力します。(ここでは省略)\n", "\n", "RNNから出力されたテンソルは`PackedSeauence`のインスタンスのままなので、後段の計算につなぐために`torch.nn.utils.rnn.pad_packed_sequence`の関数によって通常のテンソルに戻します。" ] }, { "cell_type": "code", "metadata": { "id": "F7BBaiVzAM8x", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "92f14706-cef9-4add-c8b9-a30fabdb89ef" }, "source": [ "# PackedSequenceのインスタンスをRNNに入力する(ここでは省略)\n", "output = packed\n", "\n", "# テンソルに戻す(RNNの出力に対して適用する)\n", "output, _length = pad_packed_sequence(output) # PADを含む元のテンソルと各サンプルの系列長を返す\n", "print('PADを含む元のテンソル:\\n', output)\n", "print('各サンプルの系列長:', _length)" ], "execution_count": 20, "outputs": [ { "output_type": "stream", "text": [ "PADを含む元のテンソル:\n", " tensor([[1, 5, 8],\n", " [2, 6, 9],\n", " [3, 7, 0],\n", " [4, 0, 0]])\n", "各サンプルの系列長: tensor([4, 3, 2])\n" ], "name": "stdout" } ] }, { "cell_type": "markdown", "metadata": { "id": "wzT0I8w9AM81" }, "source": [ "### Encoder\n", "今回はEncoder側でバッチを処理する際に、`pack_padded_sequence`関数によってtensorを`PackedSequence`に変換し、処理を終えた後に`pad_packed_sequence`関数によってtensorに戻すという処理を行います。" ] }, { "cell_type": "code", "metadata": { "id": "NdY2WGwMAM82" }, "source": [ "class Encoder(nn.Module):\n", " def __init__(self, input_size, hidden_size):\n", " \"\"\"\n", " :param input_size: int, 入力言語の語彙数\n", " :param hidden_size: int, 隠れ層のユニット数\n", " \"\"\"\n", " super(Encoder, self).__init__()\n", " self.hidden_size = hidden_size\n", "\n", " self.embedding = nn.Embedding(input_size, hidden_size, padding_idx=PAD)\n", " self.gru = nn.GRU(hidden_size, hidden_size)\n", "\n", " def forward(self, seqs, input_lengths, hidden=None):\n", " \"\"\"\n", " :param seqs: tensor, 入力のバッチ, size=(max_length, batch_size)\n", " :param input_lengths: 入力のバッチの各サンプルの文長\n", " :param hidden: tensor, 隠れ状態の初期値, Noneの場合は0で初期化される\n", " :return output: tensor, Encoderの出力, size=(max_length, batch_size, hidden_size)\n", " :return hidden: tensor, Encoderの隠れ状態, size=(1, batch_size, hidden_size)\n", " \"\"\"\n", " emb = self.embedding(seqs) # seqsはパディング済み\n", " packed = pack_padded_sequence(emb, input_lengths) # PackedSequenceオブジェクトに変換\n", " output, hidden = self.gru(packed, hidden)\n", " output, _ = pad_packed_sequence(output)\n", " return output, hidden" ], "execution_count": 21, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "eBw_ZiwDAM85" }, "source": [ "### Decoder\n", "今回はDecoder側ではパディング等行わないので、通常のtensorのままRNNに入力して問題ありません。" ] }, { "cell_type": "code", "metadata": { "id": "UjKk_-9_AM86" }, "source": [ "class Decoder(nn.Module):\n", " def __init__(self, hidden_size, output_size):\n", " \"\"\"\n", " :param hidden_size: int, 隠れ層のユニット数\n", " :param output_size: int, 出力言語の語彙数\n", " :param dropout: float, ドロップアウト率\n", " \"\"\"\n", " super(Decoder, self).__init__()\n", " self.hidden_size = hidden_size\n", " self.output_size = output_size\n", "\n", " self.embedding = nn.Embedding(output_size, hidden_size, padding_idx=PAD)\n", " self.gru = nn.GRU(hidden_size, hidden_size)\n", " self.out = nn.Linear(hidden_size, output_size)\n", "\n", " def forward(self, seqs, hidden):\n", " \"\"\"\n", " :param seqs: tensor, 入力のバッチ, size=(1, batch_size)\n", " :param hidden: tensor, 隠れ状態の初期値, Noneの場合は0で初期化される\n", " :return output: tensor, Decoderの出力, size=(1, batch_size, output_size)\n", " :return hidden: tensor, Decoderの隠れ状態, size=(1, batch_size, hidden_size)\n", " \"\"\"\n", " emb = self.embedding(seqs)\n", " output, hidden = self.gru(emb, hidden)\n", " output = self.out(output)\n", " return output, hidden" ], "execution_count": 22, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "Tf64KCf2AM88" }, "source": [ "## EncoderDecoder\n", "上で定義したEncoderとDecoderを用いた、一連の処理をまとめるEncoderDecoderのクラスを定義します。\n", "\n", "ここで、Decoder側の処理で注意する点があります。\n", "\n", "RNNでは、時刻$t$の出力を時刻$t+1$の入力とすることができるが、この方法でDecoderを学習させると連鎖的に誤差が大きくなっていき、学習が不安定になったり収束が遅くなったりする問題が発生します。\n", "\n", "\n", "この問題への対策として**Teacher Forcing**というテクニックがあります。\n", "これは、訓練時にはDecoder側の入力に、ターゲット系列(参照訳)をそのまま使うというものです。\n", "これにより学習が安定し、収束が早くなるというメリットがありますが、逆に評価時は前の時刻にDecoderが生成したものが使われるため、学習時と分布が異なってしまうというデメリットもあります。\n", "\n", "\n", "Teacher Forcingの拡張として、ターゲット系列を入力とするか生成された結果を入力とするかを確率的にサンプリングする**Scheduled Sampling**という手法があります。\n", "\n", "ここではScheduled Samplingを採用し、一定の確率に基づいてターゲット系列を入力とするか生成された結果を入力とするかを切り替えられるようにクラスを定義しておきます。" ] }, { "cell_type": "code", "metadata": { "id": "OB9Nlcd9AM89" }, "source": [ "class EncoderDecoder(nn.Module):\n", " \"\"\"EncoderとDecoderの処理をまとめる\"\"\"\n", " def __init__(self, input_size, output_size, hidden_size):\n", " \"\"\"\n", " :param input_size: int, 入力言語の語彙数\n", " :param output_size: int, 出力言語の語彙数\n", " :param hidden_size: int, 隠れ層のユニット数\n", " \"\"\"\n", " super(EncoderDecoder, self).__init__()\n", " self.encoder = Encoder(input_size, hidden_size)\n", " self.decoder = Decoder(hidden_size, output_size)\n", "\n", " def forward(self, batch_X, lengths_X, max_length, batch_Y=None, use_teacher_forcing=False):\n", " \"\"\"\n", " :param batch_X: tensor, 入力系列のバッチ, size=(max_length, batch_size)\n", " :param lengths_X: list, 入力系列のバッチ内の各サンプルの文長\n", " :param max_length: int, Decoderの最大文長\n", " :param batch_Y: tensor, Decoderで用いるターゲット系列\n", " :param use_teacher_forcing: Decoderでターゲット系列を入力とするフラグ\n", " :return decoder_outputs: tensor, Decoderの出力, \n", " size=(max_length, batch_size, self.decoder.output_size)\n", " \"\"\"\n", " # encoderに系列を入力(複数時刻をまとめて処理)\n", " _, encoder_hidden = self.encoder(batch_X, lengths_X)\n", " \n", " _batch_size = batch_X.size(1)\n", "\n", " # decoderの入力と隠れ層の初期状態を定義\n", " decoder_input = torch.tensor([BOS] * _batch_size, dtype=torch.long, device=device) # 最初の入力にはBOSを使用する\n", " decoder_input = decoder_input.unsqueeze(0) # (1, batch_size)\n", " decoder_hidden = encoder_hidden # Encoderの最終隠れ状態を取得\n", "\n", " # decoderの出力のホルダーを定義\n", " decoder_outputs = torch.zeros(max_length, _batch_size, self.decoder.output_size, device=device) # max_length分の固定長\n", "\n", " # 各時刻ごとに処理\n", " for t in range(max_length):\n", " decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden)\n", " decoder_outputs[t] = decoder_output\n", " # 次の時刻のdecoderの入力を決定\n", " if use_teacher_forcing and batch_Y is not None: # teacher forceの場合、ターゲット系列を用いる\n", " decoder_input = batch_Y[t].unsqueeze(0)\n", " else: # teacher forceでない場合、自身の出力を用いる\n", " decoder_input = decoder_output.max(-1)[1]\n", " \n", " return decoder_outputs" ], "execution_count": 23, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "qBTFmbwLAM9A" }, "source": [ "# 4.訓練\n", "### 4.1 損失関数の定義\n", "基本的にはクロスエントロピーを損失関数として用いますが、パディングを行うと短い系列の末尾には`<PAD>`トークンが入るため、この部分の損失を計算しないように、マスクをかけます。" ] }, { "cell_type": "code", "metadata": { "id": "Kt-r-nxJAM9B", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "22f6e9e9-88b1-4e77-cbea-1e10ae8a7743" }, "source": [ "mce = nn.CrossEntropyLoss(size_average=False, ignore_index=PAD) # PADを無視する\n", "def masked_cross_entropy(logits, target):\n", " logits_flat = logits.view(-1, logits.size(-1)) # (max_seq_len * batch_size, output_size)\n", " target_flat = target.view(-1) # (max_seq_len * batch_size, 1)\n", " return mce(logits_flat, target_flat)" ], "execution_count": 24, "outputs": [ { "output_type": "stream", "text": [ "/usr/local/lib/python3.7/dist-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.\n", " warnings.warn(warning.format(ret))\n" ], "name": "stderr" } ] }, { "cell_type": "markdown", "metadata": { "id": "GgQZl0GvAM9E" }, "source": [ "### 4.2学習" ] }, { "cell_type": "code", "metadata": { "id": "qurGD8IsAM9F" }, "source": [ "# ハイパーパラメータの設定\n", "num_epochs = 10\n", "batch_size = 64\n", "lr = 1e-3 # 学習率\n", "teacher_forcing_rate = 0.2 # Teacher Forcingを行う確率\n", "ckpt_path = 'model.pth' # 学習済みのモデルを保存するパス\n", "\n", "model_args = {\n", " 'input_size': vocab_size_X,\n", " 'output_size': vocab_size_Y,\n", " 'hidden_size': 256,\n", "}" ], "execution_count": 25, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "4xNxWJGPAM9I" }, "source": [ "# データローダを定義\n", "train_dataloader = DataLoader(train_X, train_Y, batch_size=batch_size, shuffle=True)\n", "valid_dataloader = DataLoader(valid_X, valid_Y, batch_size=batch_size, shuffle=False)\n", "\n", "# モデルとOptimizerを定義\n", "model = EncoderDecoder(**model_args).to(device)\n", "optimizer = optim.Adam(model.parameters(), lr=lr)" ], "execution_count": 26, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "_67wGmARAM9L" }, "source": [ "実際に損失関数を計算する関数を定義します。" ] }, { "cell_type": "code", "metadata": { "id": "Adggq9xOAM9L" }, "source": [ "def compute_loss(batch_X, batch_Y, lengths_X, model, optimizer=None, is_train=True):\n", " # 損失を計算する関数\n", " model.train(is_train) # train/evalモードの切替え\n", " \n", " # 一定確率でTeacher Forcingを行う\n", " use_teacher_forcing = is_train and (random.random() < teacher_forcing_rate)\n", " max_length = batch_Y.size(0)\n", " # 推論\n", " pred_Y = model(batch_X, lengths_X, max_length, batch_Y, use_teacher_forcing)\n", " \n", " # 損失関数を計算\n", " loss = masked_cross_entropy(pred_Y.contiguous(), batch_Y.contiguous())\n", " \n", " if is_train: # 訓練時はパラメータを更新\n", " optimizer.zero_grad()\n", " loss.backward()\n", " optimizer.step()\n", " \n", " batch_Y = batch_Y.transpose(0, 1).contiguous().data.cpu().tolist()\n", " pred = pred_Y.max(dim=-1)[1].data.cpu().numpy().T.tolist()\n", "\n", " return loss.item(), batch_Y, pred" ], "execution_count": 28, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "QHJXFnLpAM9O" }, "source": [ "ここで、Loss以外に、学習の進捗を確認するためにモデルの性能を評価する指標として、BLEUを計算します。\n", "\n", "BLEUは機械翻訳の分野において最も一般的な自動評価基準の一つで、予め用意した複数の参照訳と、機械翻訳モデルが出力した訳のn-gramのマッチ率に基づく指標です。\n", "\n", "NLTK (Natural Language Tool Kit) という自然言語処理で用いられるライブラリを用いて簡単に計算することができます。" ] }, { "cell_type": "code", "metadata": { "id": "ImK-xzAWAM9P" }, "source": [ "def calc_bleu(refs, hyps):\n", " \"\"\"\n", " BLEUスコアを計算する関数\n", " :param refs: list, 参照訳。単語のリストのリスト (例: [['I', 'have', 'a', 'pen'], ...])\n", " :param hyps: list, モデルの生成した訳。単語のリストのリスト (例: ['I', 'have', 'a', 'pen'])\n", " :return: float, BLEUスコア(0~100)\n", " \"\"\"\n", " refs = [[ref[:ref.index(EOS)]] for ref in refs] # EOSは評価しないで良いので切り捨てる, refsのほうは複数なのでlistが一個多くかかっている\n", " hyps = [hyp[:hyp.index(EOS)] if EOS in hyp else hyp for hyp in hyps]\n", " return 100 * bleu_score.corpus_bleu(refs, hyps)" ], "execution_count": 29, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "inYRxu8aAM9T" }, "source": [ "それではモデルの訓練を行います。" ] }, { "cell_type": "code", "metadata": { "id": "bz-Dx5p6AM9T", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "2ef9587b-c01f-4c29-a3e9-75b1562cc12b" }, "source": [ "# 訓練\n", "best_valid_bleu = 0.\n", "\n", "for epoch in range(1, num_epochs+1):\n", " train_loss = 0.\n", " train_refs = []\n", " train_hyps = []\n", " valid_loss = 0.\n", " valid_refs = []\n", " valid_hyps = []\n", " # train\n", " for batch in train_dataloader:\n", " batch_X, batch_Y, lengths_X = batch\n", " loss, gold, pred = compute_loss(\n", " batch_X, batch_Y, lengths_X, model, optimizer, \n", " is_train=True\n", " )\n", " train_loss += loss\n", " train_refs += gold\n", " train_hyps += pred\n", " # valid\n", " for batch in valid_dataloader:\n", " batch_X, batch_Y, lengths_X = batch\n", " loss, gold, pred = compute_loss(\n", " batch_X, batch_Y, lengths_X, model, \n", " is_train=False\n", " )\n", " valid_loss += loss\n", " valid_refs += gold\n", " valid_hyps += pred\n", " # 損失をサンプル数で割って正規化\n", " train_loss = np.sum(train_loss) / len(train_dataloader.data)\n", " valid_loss = np.sum(valid_loss) / len(valid_dataloader.data)\n", " # BLEUを計算\n", " train_bleu = calc_bleu(train_refs, train_hyps)\n", " valid_bleu = calc_bleu(valid_refs, valid_hyps)\n", "\n", " # validationデータでBLEUが改善した場合にはモデルを保存\n", " if valid_bleu > best_valid_bleu:\n", " ckpt = model.state_dict()\n", " torch.save(ckpt, ckpt_path)\n", " best_valid_bleu = valid_bleu\n", "\n", " print('Epoch {}: train_loss: {:5.2f} train_bleu: {:2.2f} valid_loss: {:5.2f} valid_bleu: {:2.2f}'.format(\n", " epoch, train_loss, train_bleu, valid_loss, valid_bleu))\n", " \n", " print('-'*80)" ], "execution_count": 30, "outputs": [ { "output_type": "stream", "text": [ "Epoch 1: train_loss: 52.44 train_bleu: 3.30 valid_loss: 48.78 valid_bleu: 5.10\n", "--------------------------------------------------------------------------------\n", "Epoch 2: train_loss: 44.48 train_bleu: 7.57 valid_loss: 44.77 valid_bleu: 8.37\n", "--------------------------------------------------------------------------------\n", "Epoch 3: train_loss: 40.05 train_bleu: 11.49 valid_loss: 41.95 valid_bleu: 8.68\n", "--------------------------------------------------------------------------------\n", "Epoch 4: train_loss: 37.40 train_bleu: 14.04 valid_loss: 41.00 valid_bleu: 13.31\n", "--------------------------------------------------------------------------------\n", "Epoch 5: train_loss: 34.79 train_bleu: 17.00 valid_loss: 40.30 valid_bleu: 14.62\n", "--------------------------------------------------------------------------------\n", "Epoch 6: train_loss: 32.96 train_bleu: 19.18 valid_loss: 39.93 valid_bleu: 15.41\n", "--------------------------------------------------------------------------------\n", "Epoch 7: train_loss: 31.71 train_bleu: 20.88 valid_loss: 39.90 valid_bleu: 16.35\n", "--------------------------------------------------------------------------------\n", "Epoch 8: train_loss: 30.40 train_bleu: 22.62 valid_loss: 40.41 valid_bleu: 17.56\n", "--------------------------------------------------------------------------------\n", "Epoch 9: train_loss: 29.20 train_bleu: 24.48 valid_loss: 40.64 valid_bleu: 18.55\n", "--------------------------------------------------------------------------------\n", "Epoch 10: train_loss: 27.63 train_bleu: 27.09 valid_loss: 40.98 valid_bleu: 19.21\n", "--------------------------------------------------------------------------------\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "Y3tlT8z9SCoF" }, "source": [ "" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "BdgBkAlkAM9V" }, "source": [ "# 5.評価" ] }, { "cell_type": "code", "metadata": { "id": "ze8jkchYAM9W", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "c7487f67-9d30-451d-90e8-58f1c3d6e5b6" }, "source": [ "# 学習済みモデルの読み込み\n", "ckpt = torch.load(ckpt_path) # cpuで処理する場合はmap_locationで指定する必要があります。\n", "model.load_state_dict(ckpt)\n", "model.eval()" ], "execution_count": 38, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "EncoderDecoder(\n", " (encoder): Encoder(\n", " (embedding): Embedding(3725, 256, padding_idx=0)\n", " (gru): GRU(256, 256)\n", " )\n", " (decoder): Decoder(\n", " (embedding): Embedding(4405, 256, padding_idx=0)\n", " (gru): GRU(256, 256)\n", " (out): Linear(in_features=256, out_features=4405, bias=True)\n", " )\n", ")" ] }, "metadata": { "tags": [] }, "execution_count": 38 } ] }, { "cell_type": "code", "metadata": { "id": "YyKW9WY6AM9Y" }, "source": [ "def ids_to_sentence(vocab, ids):\n", " # IDのリストを単語のリストに変換する\n", " return [vocab.id2word[_id] for _id in ids]\n", "\n", "def trim_eos(ids):\n", " # IDのリストからEOS以降の単語を除外する\n", " if EOS in ids:\n", " return ids[:ids.index(EOS)]\n", " else:\n", " return ids" ], "execution_count": 39, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "r7qCpnSpAM9b" }, "source": [ "# テストデータの読み込み\n", "test_X = load_data('./data/dev.en')\n", "test_Y = load_data('./data/dev.ja')" ], "execution_count": 40, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "41hLJdkNAM9d" }, "source": [ "test_X = [sentence_to_ids(vocab_X, sentence) for sentence in test_X]\n", "test_Y = [sentence_to_ids(vocab_Y, sentence) for sentence in test_Y]" ], "execution_count": 41, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "IFK0JzSYAM9m" }, "source": [ "test_dataloader = DataLoader(test_X, test_Y, batch_size=1, shuffle=False)" ], "execution_count": 42, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "mYfjq3shAM9q", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "bcba5129-580c-459b-bac9-3c9e07f6973d" }, "source": [ "# 生成\n", "batch_X, batch_Y, lengths_X = next(test_dataloader)\n", "sentence_X = ' '.join(ids_to_sentence(vocab_X, batch_X.data.cpu().numpy()[:-1, 0]))\n", "sentence_Y = ' '.join(ids_to_sentence(vocab_Y, batch_Y.data.cpu().numpy()[:-1, 0]))\n", "print('src: {}'.format(sentence_X))\n", "print('tgt: {}'.format(sentence_Y))\n", "\n", "output = model(batch_X, lengths_X, max_length=20)\n", "output = output.max(dim=-1)[1].view(-1).data.cpu().tolist()\n", "output_sentence = ' '.join(ids_to_sentence(vocab_Y, trim_eos(output)))\n", "output_sentence_without_trim = ' '.join(ids_to_sentence(vocab_Y, output))\n", "print('out: {}'.format(output_sentence))\n", "print('without trim: {}'.format(output_sentence_without_trim))" ], "execution_count": 89, "outputs": [ { "output_type": "stream", "text": [ "src: we went to boston , where we stayed a week .\n", "tgt: 私 たち は ボストン に 行 き 、 そこ に 一 週間 滞在 し た 。\n", "out: 私 たち は 、 、 、 、 、 、 、 、 た た た 。 た\n", "without trim: 私 たち は 、 、 、 、 、 、 、 、 た た た 。 た </S> </S> </S> </S>\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "atSEiLMHAM93", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "c44a8743-6b22-40fb-b7b4-27ffe32ebada" }, "source": [ "# BLEUの計算\n", "test_dataloader = DataLoader(test_X, test_Y, batch_size=1, shuffle=False)\n", "refs_list = []\n", "hyp_list = []\n", "\n", "for batch in test_dataloader:\n", " batch_X, batch_Y, lengths_X = batch\n", " pred_Y = model(batch_X, lengths_X, max_length=20)\n", " pred = pred_Y.max(dim=-1)[1].view(-1).data.cpu().tolist()\n", " refs = batch_Y.view(-1).data.cpu().tolist()\n", " refs_list.append(refs)\n", " hyp_list.append(pred)\n", "bleu = calc_bleu(refs_list, hyp_list)\n", "print(bleu)" ], "execution_count": 70, "outputs": [ { "output_type": "stream", "text": [ "19.696735593854097\n" ], "name": "stdout" } ] }, { "cell_type": "markdown", "metadata": { "id": "B1UDdsNvruEi" }, "source": [ "### Beam Search\n", "テストデータに対して新たな文を生成する際、これまでは各時刻で最も確率の高い単語を正解として採用し、次のステップでの入力として使っていました。\n", "ただ、本当にやりたいのは、文全体の尤度が最も高くなるような文を生成することです。そのため、ただ近視眼的に確率の高い単語を採用していくより、もう少し大局的に評価していく必要があります。\n", "\n", "Beam Searchでは、各時刻において一定の数$K$のそれまでのスコア(対数尤度など)の高い文を保持しながら選択を行っていきます。 \n", "\n", "\n", "図はSlack上のものを参照してください。" ] }, { "cell_type": "code", "metadata": { "id": "2vFRiqFwtKsZ" }, "source": [ "class BeamEncoderDecoder(EncoderDecoder):\n", " \"\"\"\n", " Beam Searchでdecodeを行うためのクラス\n", " \"\"\"\n", " def __init__(self, input_size, output_size, hidden_size, beam_size=4):\n", " \"\"\"\n", " :param input_size: int, 入力言語の語彙数\n", " :param output_size: int, 出力言語の語彙数\n", " :param hidden_size: int, 隠れ層のユニット数\n", " :param beam_size: int, ビーム数\n", " \"\"\"\n", " super(BeamEncoderDecoder, self).__init__(input_size, output_size, hidden_size)\n", " self.beam_size = beam_size\n", "\n", " def forward(self, batch_X, lengths_X, max_length):\n", " \"\"\"\n", " :param batch_X: tensor, 入力系列のバッチ, size=(max_length, batch_size)\n", " :param lengths_X: list, 入力系列のバッチ内の各サンプルの文長\n", " :param max_length: int, Decoderの最大文長\n", " :return decoder_outputs: list, 各ビームのDecoderの出力\n", " :return finished_scores: list of float, 各ビームのスコア\n", " \"\"\"\n", " _, encoder_hidden = self.encoder(batch_X, lengths_X)\n", "\n", " # decoderの入力と隠れ層の初期状態を定義\n", " decoder_input = torch.tensor([BOS] * self.beam_size, dtype=torch.long, device=device)\n", " decoder_input = decoder_input.unsqueeze(0) # (1, batch_size)\n", " decoder_hidden = encoder_hidden\n", "\n", " # beam_sizeの数だけrepeatする\n", " decoder_input = decoder_input.expand(1, beam_size)\n", " decoder_hidden = decoder_hidden.expand(1, beam_size, -1).contiguous()\n", "\n", " k = beam_size\n", " finished_beams = []\n", " finished_scores = []\n", " prev_probs = torch.zeros(beam_size, 1, dtype=torch.float, device=device) # 前の時刻の各ビームの対数尤度を保持しておく\n", " output_size = self.decoder.output_size\n", "\n", " # 各時刻ごとに処理\n", " for t in range(max_length):\n", " # decoder_input: (1, k)\n", " decoder_output, decoder_hidden = self.decoder(decoder_input[-1:], decoder_hidden)\n", " # decoder_output: (1, k, output_size)\n", " # decoder_hidden: (1, k, hidden_size)\n", " decoder_output_t = decoder_output[-1] # (k, output_size)\n", " log_probs = prev_probs + F.log_softmax(decoder_output_t, dim=-1) # (k, output_size)\n", " scores = log_probs # 対数尤度をスコアとする\n", "\n", " # スコアの高いビームとその単語を取得\n", " flat_scores = scores.view(-1) # (k*output_size,)\n", " if t == 0:\n", " flat_scores = flat_scores[:output_size] # t=0のときは後半の同じ値の繰り返しを除外\n", " top_vs, top_is = flat_scores.data.topk(k)\n", " beam_indices = top_is / output_size # (k,)\n", " word_indices = top_is % output_size # (k,)\n", " \n", " # ビームを更新する\n", " _next_beam_indices = []\n", " _next_word_indices = []\n", " for b, w in zip(beam_indices, word_indices):\n", " if w.item() == EOS: # EOSに到達した場合はそのビームは更新して終了\n", " k -= 1\n", " beam = torch.cat([decoder_input.t()[b], w.view(1,)]) # (t+2,)\n", " score = scores[b, w].item()\n", " finished_beams.append(beam)\n", " finished_scores.append(score)\n", " else: # それ以外の場合はビームを更新\n", " _next_beam_indices.append(b)\n", " _next_word_indices.append(w)\n", " if k == 0:\n", " break\n", "\n", " # tensornに変換\n", " next_beam_indices = torch.tensor(_next_beam_indices, device=device)\n", " next_word_indices = torch.tensor(_next_word_indices, device=device)\n", "\n", " # 次の時刻のDecoderの入力を更新\n", " decoder_input = torch.index_select(\n", " decoder_input, dim=-1, index=next_beam_indices)\n", " decoder_input = torch.cat(\n", " [decoder_input, next_word_indices.unsqueeze(0)], dim=0)\n", " \n", " # 次の時刻のDecoderの隠れ層を更新\n", " decoder_hidden = torch.index_select(\n", " decoder_hidden, dim=1, index=next_beam_indices)\n", "\n", " # 各ビームの対数尤度を更新\n", " flat_probs = log_probs.view(-1) # (k*output_size,)\n", " next_indices = (next_beam_indices + 1) * next_word_indices\n", " prev_probs = torch.index_select(\n", " flat_probs, dim=0, index=next_indices).unsqueeze(1) # (k, 1)\n", "\n", " # すべてのビームが完了したらデータを整形\n", " decoder_outputs = [[idx.item() for idx in beam[1:-1]] for beam in finished_beams]\n", " \n", " return decoder_outputs, finished_scores" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "DIZ9NHXHttCg" }, "source": [ "# 学習済みモデルの読み込み\n", "beam_size = 3\n", "beam_model = BeamEncoderDecoder(**model_args, beam_size=beam_size).to(device)\n", "beam_model.load_state_dict(ckpt)\n", "beam_model.eval()" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "KzsI_6bBtwGc" }, "source": [ "test_dataloader = DataLoader(test_X, test_Y, batch_size=1, shuffle=False)" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "iSr_wwtbtxKZ" }, "source": [ "# 生成\n", "batch_X, batch_Y, lengths_X = next(test_dataloader)\n", "sentence_X = ' '.join(ids_to_sentence(vocab_X, batch_X.data.cpu().numpy()[:-1, 0]))\n", "sentence_Y = ' '.join(ids_to_sentence(vocab_Y, batch_Y.data.cpu().numpy()[:-1, 0]))\n", "print('src: {}'.format(sentence_X))\n", "print('tgt: {}'.format(sentence_Y))\n", "\n", "# 普通のdecode\n", "output = model(batch_X, lengths_X, max_length=20)\n", "output = output.max(dim=-1)[1].view(-1).data.cpu().tolist()\n", "output_sentence = ' '.join(ids_to_sentence(vocab_Y, trim_eos(output)))\n", "print('out: {}'.format(output_sentence))\n", "\n", "# beam decode\n", "outputs, scores = beam_model(batch_X, lengths_X, max_length=20)\n", "# scoreの良い順にソート\n", "outputs, scores = zip(*sorted(zip(outputs, scores), key=lambda x: -x[1]))\n", "for o, output in enumerate(outputs):\n", " output_sentence = ' '.join(ids_to_sentence(vocab_Y, output))\n", " print('out{}: {}'.format(o+1, output_sentence)) " ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "roOpX8nAAM95" }, "source": [ "# 参考文献\n", "- [Practical PyTorch: Translation with a Sequence to Sequence Network and Attention](https://github.com/spro/practical-pytorch/blob/master/seq2seq-translation/seq2seq-translation-batched.ipynb)\n", "- [Translation with a Sequence to Sequence Network and Attention](https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html#sphx-glr-intermediate-seq2seq-translation-tutorial-py)\n", "- [Encoder\\-decoderモデルとTeacher Forcing,Scheduled Sampling,Professor Forcing](http://satopirka.com/2018/02/encoder-decoder%E3%83%A2%E3%83%87%E3%83%AB%E3%81%A8teacher-forcingscheduled-samplingprofessor-forcing/)\n", "- [Sequence\\-to\\-Sequence Learning as Beam\\-Search Optimization](https://arxiv.org/abs/1606.02960)" ] }, { "cell_type": "markdown", "metadata": { "id": "xyVihX0V0NDt" }, "source": [ "### 考察\n", "TeacherForcingを使うことによって、誤差が極端に拡大してしまうことを防ぐ。\n", "これを使った場合はデコーダの出力を用いずに教師データを使うので誤りがなくなる\n", "しかしこれだけに頼ってしまうと、学習時と実用時の環境がかけ離れてしまい、推論を行うときに汎化性能のないモデルになってしまうため、一定の割合に制限して使用するようにしている。\n", "データローダーの仕組みは重要。\n", "こういう仕組みを使わないと、膨大なデータをすべて前処理した状態のものをメモリに一括で載せないといけなくなってしまい、メモリが大抵足らなくなる。\n", "\n" ] }, { "cell_type": "code", "metadata": { "id": "2kKxjl1y2nzv" }, "source": [ "# 実行結果\n", "\n", "'''\n", "src: he lived a hard life .\n", "tgt: 彼 は つら い 人生 を 送 っ た 。\n", "out: 彼 は 人生 を 生活 を 送 っ た 。\n", "\n", "\n", "src: no . i 'm sorry , i 've got to go back early .\n", "tgt: ごめん なさ い 。 早 く 帰 ら な く ちゃ 。\n", "out: いいえ 、 帰 っ たら 、 行 き き き だ 。\n", "\n", "src: she wrote to me to come at once .\n", "tgt: 彼女 は 私 に すぐ 来 い と の 便り を よこ し た 。\n", "out: 彼女 は すぐ に 来る と に し た た 。\n", "\n", "src: i can 't swim at all .\n", "tgt: 私 は 少し も 泳げ な い 。\n", "out: 私 は 泳げ 泳げ な い 。\n", "\n", "src: is there any hope of his success ?\n", "tgt: 彼 の 成功 の 見込み は あ り ま す か 。\n", "out: 彼 の 成功 は どう か あ り ま す か 。\n", "\n", "src: i 'll pick him up at 5 .\n", "tgt: 私 は 5 時 に 彼 を 迎え に 行 く つもり で す 。\n", "out: 私 は 彼 を に に を し ま 。 。\n", "\n", "src: it 's so lovely a day .\n", "tgt: 本当 に い い 天気 だ 。\n", "out: それ 一 日 で す 。\n", "\n", "src: show your own business .\n", "tgt: 自分 の 事 を しろ 。\n", "out: 君 の 商売 を <UNK> し て い 。\n", "\n", "\n", "src: i study english every day .\n", "tgt: 私 は 毎日 英語 の 勉強 を する 。\n", "out: 私 は 毎日 英語 を 勉強 し ま す 。\n", "\n", "src: i like spring the best of the seasons .\n", "tgt: 私 は 季節 の 中 で 春 が 好き だ 。\n", "out: 私 は 一番 が が 一番 好き が 一番 が が 好き 好き で 。\n", "\n", "src: you can 't have lost your coat in the house .\n", "tgt: 家 の 中 で コート が 無 くな る はず は な い 。\n", "out: コート を コート を コート を の は で で な な な い 。 。\n", "\n", "src: there are some oranges on the tree .\n", "tgt: 木 に オレンジ が いく つ か な っ て い る 。\n", "out: その の 木 の 木 木 が が い い い 。\n", "\n", "src: she died of shock .\n", "tgt: 彼女 は ショック 死 し た 。\n", "out: 彼女 は 死 ん で い た 。\n", "\n", "src: i like french food very much .\n", "tgt: 私 は フランス 料理 が 好き で す 。\n", "out: 私 は フランス 語 が 好き で す 。\n", "\n", "src: we went to boston , where we stayed a week .\n", "tgt: 私 たち は ボストン に 行 き 、 そこ に 一 週間 滞在 し た 。\n", "out: 私 たち は 、 、 、 、 、 、 、 、 た た た 。 た\n", "\n", "'''\n", "\n", "# 惜しいものもいくつかあるが、かなりわけのわからないものもある。\n", "# おなじ文字や単語が複数連続して出力されてしまうのが特徴的である。\n", "\n" ], "execution_count": null, "outputs": [] } ] }