{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 7. Neural Machine Translation and Models with Attention" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "I recommend you take a look at these material first." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture9.pdf\n", "* http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture10.pdf\n", "* http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture11.pdf\n", "* https://arxiv.org/pdf/1409.0473.pdf\n", "* https://github.com/spro/practical-pytorch/blob/master/seq2seq-translation/seq2seq-translation-batched.ipynb\n", "* https://medium.com/huggingface/understanding-emotions-from-keras-to-pytorch-3ccb61d5a983\n", "* http://www.manythings.org/anki/" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": true }, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "from torch.autograd import Variable\n", "import torch.optim as optim\n", "import torch.nn.functional as F\n", "import nltk\n", "import random\n", "import numpy as np\n", "from collections import Counter, OrderedDict\n", "import nltk\n", "from copy import deepcopy\n", "import os\n", "import re\n", "import unicodedata\n", "flatten = lambda l: [item for sublist in l for item in sublist]\n", "\n", "from torch.nn.utils.rnn import PackedSequence,pack_padded_sequence\n", "import matplotlib.pyplot as plt\n", "import matplotlib.ticker as ticker\n", "random.seed(1024)\n", "%matplotlib inline" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": true }, "outputs": [], "source": [ "USE_CUDA = torch.cuda.is_available()\n", "gpus = [0]\n", "torch.cuda.set_device(gpus[0])\n", "\n", "FloatTensor = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor\n", "LongTensor = torch.cuda.LongTensor if USE_CUDA else torch.LongTensor\n", "ByteTensor = torch.cuda.ByteTensor if USE_CUDA else torch.ByteTensor" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def getBatch(batch_size, train_data):\n", " random.shuffle(train_data)\n", " sindex=0\n", " eindex=batch_size\n", " while eindex < len(train_data):\n", " batch = train_data[sindex: eindex]\n", " temp = eindex\n", " eindex = eindex + batch_size\n", " sindex = temp\n", " yield batch\n", " \n", " if eindex >= len(train_data):\n", " batch = train_data[sindex:]\n", " yield batch" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Padding" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "
borrowed image from https://medium.com/huggingface/understanding-emotions-from-keras-to-pytorch-3ccb61d5a983
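
The `pad_to_batch` function in the next cell implements exactly this idea: it sorts a batch by source length (longest first, which `pack_padded_sequence` expects), right-pads every shorter sequence with the pad index (index 0 in the vocabularies built below), and also returns the real lengths so the padded positions can be ignored later. A minimal sketch of the idea with made-up toy indices, not the notebook's data:

```python
# Toy sketch of padding (hypothetical indices, not the notebook's data).
import torch

PAD_IDX = 0  # same convention as below: index 0 is the pad token
seqs = [torch.LongTensor([[5, 9, 2]]), torch.LongTensor([[7, 4, 11, 8, 2]])]  # (1, T_i) each

max_len = max(s.size(1) for s in seqs)
padded = []
for s in sorted(seqs, key=lambda s: s.size(1), reverse=True):  # longest first
    if s.size(1) < max_len:
        s = torch.cat([s, torch.LongTensor([[PAD_IDX] * (max_len - s.size(1))])], 1)
    padded.append(s)

batch = torch.cat(padded)                                  # LongTensor of shape (2, 5)
lengths = [int((row != PAD_IDX).sum()) for row in batch]   # real lengths: [5, 3]
```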
" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# It is for Sequence 2 Sequence format\n", "def pad_to_batch(batch, x_to_ix, y_to_ix):\n", " \n", " sorted_batch = sorted(batch, key=lambda b:b[0].size(1), reverse=True) # sort by len\n", " x,y = list(zip(*sorted_batch))\n", " max_x = max([s.size(1) for s in x])\n", " max_y = max([s.size(1) for s in y])\n", " x_p, y_p = [], []\n", " for i in range(len(batch)):\n", " if x[i].size(1) < max_x:\n", " x_p.append(torch.cat([x[i], Variable(LongTensor([x_to_ix['']] * (max_x - x[i].size(1)))).view(1, -1)], 1))\n", " else:\n", " x_p.append(x[i])\n", " if y[i].size(1) < max_y:\n", " y_p.append(torch.cat([y[i], Variable(LongTensor([y_to_ix['']] * (max_y - y[i].size(1)))).view(1, -1)], 1))\n", " else:\n", " y_p.append(y[i])\n", " \n", " input_var = torch.cat(x_p)\n", " target_var = torch.cat(y_p)\n", " input_len = [list(map(lambda s: s ==0, t.data)).count(False) for t in input_var]\n", " target_len = [list(map(lambda s: s ==0, t.data)).count(False) for t in target_var]\n", " \n", " return input_var, target_var, input_len, target_len" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def prepare_sequence(seq, to_index):\n", " idxs = list(map(lambda w: to_index[w] if to_index.get(w) is not None else to_index[\"\"], seq))\n", " return Variable(LongTensor(idxs))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Data load and Preprocessing " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Borrowed code from https://github.com/spro/practical-pytorch/blob/master/seq2seq-translation/seq2seq-translation-batched.ipynb" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# Turn a Unicode string to plain ASCII, thanks to http://stackoverflow.com/a/518232/2809427\n", "def unicode_to_ascii(s):\n", " return ''.join(\n", " c for c in unicodedata.normalize('NFD', s)\n", " if unicodedata.category(c) != 'Mn'\n", " )\n", "\n", "# Lowercase, trim, and remove non-letter characters\n", "def normalize_string(s):\n", " s = unicode_to_ascii(s.lower().strip())\n", " s = re.sub(r\"([,.!?])\", r\" \\1 \", s)\n", " s = re.sub(r\"[^a-zA-Z,.!?]+\", r\" \", s)\n", " s = re.sub(r\"\\s+\", r\" \", s).strip()\n", " return s" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "

The corpus file eng-fra.txt (see http://www.manythings.org/anki/) holds tab-separated English/French sentence pairs. In this notebook the English side is the source and the French side is the target, i.e. we train an English -> French translator.
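A small illustration of the preprocessing with a made-up line (it assumes the `normalize_string` / `unicode_to_ascii` cells above have been run; the sample sentence is not from the corpus):

```python
# Made-up example line; the real corpus lines have the same TAB-separated layout.
line = "I'm very happy.\tJe suis très heureux."
eng, fra = line.split('\t')

normalize_string(eng)   # -> 'i m very happy .'
normalize_string(fra)   # -> 'je suis tres heureux .'
```

`normalize_string` lowercases, strips accents, spaces out `, . ! ?` and drops every other non-letter character, which is why the apostrophe and the accent disappear above.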
" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": true }, "outputs": [], "source": [ "corpus = open('../dataset/eng-fra.txt', 'r', encoding='utf-8').readlines()" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "142787" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(corpus)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "collapsed": true }, "outputs": [], "source": [ "corpus = corpus[:30000] # for practice" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "collapsed": true }, "outputs": [], "source": [ "MIN_LENGTH = 3\n", "MAX_LENGTH = 25" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "29830 29830\n", "['i', 'see', '.'] ['je', 'comprends', '.']\n", "CPU times: user 836 ms, sys: 8 ms, total: 844 ms\n", "Wall time: 843 ms\n" ] } ], "source": [ "%%time\n", "X_r, y_r = [], [] # raw\n", "\n", "for parallel in corpus:\n", " so,ta = parallel[:-1].split('\\t')\n", " if so.strip() == \"\" or ta.strip() == \"\": \n", " continue\n", " \n", " normalized_so = normalize_string(so).split()\n", " normalized_ta = normalize_string(ta).split()\n", " \n", " if len(normalized_so) >= MIN_LENGTH and len(normalized_so) <= MAX_LENGTH \\\n", " and len(normalized_ta) >= MIN_LENGTH and len(normalized_ta) <= MAX_LENGTH:\n", " X_r.append(normalized_so)\n", " y_r.append(normalized_ta)\n", " \n", "\n", "print(len(X_r), len(y_r))\n", "print(X_r[0], y_r[0])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Build Vocab" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "4433 7704\n" ] } ], "source": [ "source_vocab = list(set(flatten(X_r)))\n", "target_vocab = list(set(flatten(y_r)))\n", "print(len(source_vocab), len(target_vocab))" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "collapsed": true }, "outputs": [], "source": [ "source2index = {'': 0, '': 1, '': 2, '': 3}\n", "for vo in source_vocab:\n", " if source2index.get(vo) is None:\n", " source2index[vo] = len(source2index)\n", "index2source = {v:k for k, v in source2index.items()}\n", "\n", "target2index = {'': 0, '': 1, '': 2, '': 3}\n", "for vo in target_vocab:\n", " if target2index.get(vo) is None:\n", " target2index[vo] = len(target2index)\n", "index2target = {v:k for k, v in target2index.items()}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Prepare train data" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 2.16 s, sys: 364 ms, total: 2.52 s\n", "Wall time: 2.89 s\n" ] } ], "source": [ "%%time\n", "X_p, y_p = [], []\n", "\n", "for so, ta in zip(X_r, y_r):\n", " X_p.append(prepare_sequence(so + [''], source2index).view(1, -1))\n", " y_p.append(prepare_sequence(ta + [''], target2index).view(1, -1))\n", " \n", "train_data = list(zip(X_p, y_p))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Modeling " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "
borrowed image from http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture10.pdf
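
The `Encoder` below embeds the padded source batch and runs it through an (optionally bidirectional) GRU; `pack_padded_sequence` / `pad_packed_sequence` are what let the GRU skip the padded time steps. A minimal, standalone sketch of that pack/unpack round trip (random tensors, toy sizes chosen only for illustration):

```python
# Toy pack/unpack round trip (random data; shapes are illustrative only).
import torch
from torch.autograd import Variable
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

embedded = Variable(torch.randn(2, 5, 4))   # (batch, max_len, emb_dim); 2nd sequence is padded
lengths = [5, 3]                            # real lengths, sorted longest first

packed = pack_padded_sequence(embedded, lengths, batch_first=True)
gru = torch.nn.GRU(4, 6, batch_first=True)
packed_out, hidden = gru(packed)            # the GRU never processes the padded steps

outputs, out_lengths = pad_packed_sequence(packed_out, batch_first=True)
# outputs: (2, 5, 6), padded back; hidden: (1, 2, 6), last real state of each sequence
```

The Encoder's `forward` does the same thing on the real batch and then returns both the padded outputs (needed for attention) and the final hidden state (used to initialize the decoder context).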
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you're not familier with pack_padded_sequence and pad_packed_sequence, check this post." ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "collapsed": true }, "outputs": [], "source": [ "class Encoder(nn.Module):\n", " def __init__(self, input_size, embedding_size,hidden_size, n_layers=1,bidirec=False):\n", " super(Encoder, self).__init__()\n", " \n", " self.input_size = input_size\n", " self.hidden_size = hidden_size\n", " self.n_layers = n_layers\n", " \n", " self.embedding = nn.Embedding(input_size, embedding_size)\n", " \n", " if bidirec:\n", " self.n_direction = 2 \n", " self.gru = nn.GRU(embedding_size, hidden_size, n_layers, batch_first=True, bidirectional=True)\n", " else:\n", " self.n_direction = 1\n", " self.gru = nn.GRU(embedding_size, hidden_size, n_layers, batch_first=True)\n", " \n", " def init_hidden(self, inputs):\n", " hidden = Variable(torch.zeros(self.n_layers * self.n_direction, inputs.size(0), self.hidden_size))\n", " return hidden.cuda() if USE_CUDA else hidden\n", " \n", " def init_weight(self):\n", " self.embedding.weight = nn.init.xavier_uniform(self.embedding.weight)\n", " self.gru.weight_hh_l0 = nn.init.xavier_uniform(self.gru.weight_hh_l0)\n", " self.gru.weight_ih_l0 = nn.init.xavier_uniform(self.gru.weight_ih_l0)\n", " \n", " def forward(self, inputs, input_lengths):\n", " \"\"\"\n", " inputs : B, T (LongTensor)\n", " input_lengths : real lengths of input batch (list)\n", " \"\"\"\n", " hidden = self.init_hidden(inputs)\n", " \n", " embedded = self.embedding(inputs)\n", " packed = pack_padded_sequence(embedded, input_lengths, batch_first=True)\n", " outputs, hidden = self.gru(packed, hidden)\n", " outputs, output_lengths = torch.nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True) # unpack (back to padded)\n", " \n", " if self.n_layers > 1:\n", " if self.n_direction == 2:\n", " hidden = hidden[-2:]\n", " else:\n", " hidden = hidden[-1]\n", " \n", " return outputs, torch.cat([h for h in hidden], 1).unsqueeze(1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Attention Mechanism ( https://arxiv.org/pdf/1409.0473.pdf )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "I used general-type for score function $h_t^TW_ah_s^-$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "
borrowed image from http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture10.pdf
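
The decoder scores each encoder state with the "general" form $h_t^\top W_a \bar{h}_s$: a single linear layer (`self.attn`) plays the role of $W_a$, and a batched matrix multiply produces one score per source position, which a softmax turns into the weights used to build the context vector. A shape-only sketch of that computation with random tensors (toy sizes, not the trained model):

```python
# Shape-level sketch of the "general" attention score (random tensors, toy sizes).
import torch
import torch.nn.functional as F
from torch.autograd import Variable

B, T, D = 2, 4, 6                                 # batch, source length, hidden size
hidden = Variable(torch.randn(B, D, 1))           # current decoder state h_t, as (B, D, 1)
encoder_outputs = Variable(torch.randn(B, T, D))  # encoder states, one per source position

W_a = torch.nn.Linear(D, D)                       # plays the role of self.attn below
energies = W_a(encoder_outputs.contiguous().view(B * T, D)).view(B, T, D)  # W_a h_s
scores = energies.bmm(hidden).squeeze(2)          # h_t^T (W_a h_s) for every s -> (B, T)

alpha = F.softmax(scores, 1).unsqueeze(1)         # attention weights, (B, 1, T)
context = alpha.bmm(encoder_outputs)              # weighted sum of encoder states, (B, 1, D)
```

`Decoder.Attention` below follows the same steps and additionally returns `alpha`, so the weights can be visualized after training.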
" ] }, { "cell_type": "code", "execution_count": 37, "metadata": { "collapsed": true }, "outputs": [], "source": [ "class Decoder(nn.Module):\n", " def __init__(self, input_size, embedding_size, hidden_size, n_layers=1, dropout_p=0.1):\n", " super(Decoder, self).__init__()\n", " \n", " self.hidden_size = hidden_size\n", " self.n_layers = n_layers\n", " \n", " # Define the layers\n", " self.embedding = nn.Embedding(input_size, embedding_size)\n", " self.dropout = nn.Dropout(dropout_p)\n", " \n", " self.gru = nn.GRU(embedding_size + hidden_size, hidden_size, n_layers, batch_first=True)\n", " self.linear = nn.Linear(hidden_size * 2, input_size)\n", " self.attn = nn.Linear(self.hidden_size, self.hidden_size) # Attention\n", " \n", " def init_hidden(self,inputs):\n", " hidden = Variable(torch.zeros(self.n_layers, inputs.size(0), self.hidden_size))\n", " return hidden.cuda() if USE_CUDA else hidden\n", " \n", " \n", " def init_weight(self):\n", " self.embedding.weight = nn.init.xavier_uniform(self.embedding.weight)\n", " self.gru.weight_hh_l0 = nn.init.xavier_uniform(self.gru.weight_hh_l0)\n", " self.gru.weight_ih_l0 = nn.init.xavier_uniform(self.gru.weight_ih_l0)\n", " self.linear.weight = nn.init.xavier_uniform(self.linear.weight)\n", " self.attn.weight = nn.init.xavier_uniform(self.attn.weight)\n", "# self.attn.bias.data.fill_(0)\n", " \n", " def Attention(self, hidden, encoder_outputs, encoder_maskings):\n", " \"\"\"\n", " hidden : 1,B,D\n", " encoder_outputs : B,T,D\n", " encoder_maskings : B,T # ByteTensor\n", " \"\"\"\n", " hidden = hidden[0].unsqueeze(2) # (1,B,D) -> (B,D,1)\n", " \n", " batch_size = encoder_outputs.size(0) # B\n", " max_len = encoder_outputs.size(1) # T\n", " energies = self.attn(encoder_outputs.contiguous().view(batch_size * max_len, -1)) # B*T,D -> B*T,D\n", " energies = energies.view(batch_size,max_len, -1) # B,T,D\n", " attn_energies = energies.bmm(hidden).squeeze(2) # B,T,D * B,D,1 --> B,T\n", " \n", "# if isinstance(encoder_maskings,torch.autograd.variable.Variable):\n", "# attn_energies = attn_energies.masked_fill(encoder_maskings,float('-inf'))#-1e12) # PAD masking\n", " \n", " alpha = F.softmax(attn_energies,1) # B,T\n", " alpha = alpha.unsqueeze(1) # B,1,T\n", " context = alpha.bmm(encoder_outputs) # B,1,T * B,T,D => B,1,D\n", " \n", " return context, alpha\n", " \n", " \n", " def forward(self, inputs, context, max_length, encoder_outputs, encoder_maskings=None, is_training=False):\n", " \"\"\"\n", " inputs : B,1 (LongTensor, START SYMBOL)\n", " context : B,1,D (FloatTensor, Last encoder hidden state)\n", " max_length : int, max length to decode # for batch\n", " encoder_outputs : B,T,D\n", " encoder_maskings : B,T # ByteTensor\n", " is_training : bool, this is because adapt dropout only training step.\n", " \"\"\"\n", " # Get the embedding of the current input word\n", " embedded = self.embedding(inputs)\n", " hidden = self.init_hidden(inputs)\n", " if is_training:\n", " embedded = self.dropout(embedded)\n", " \n", " decode = []\n", " # Apply GRU to the output so far\n", " for i in range(max_length):\n", "\n", " _, hidden = self.gru(torch.cat((embedded, context), 2), hidden) # h_t = f(h_{t-1},y_{t-1},c)\n", " concated = torch.cat((hidden, context.transpose(0, 1)), 2) # y_t = g(h_t,y_{t-1},c)\n", " score = self.linear(concated.squeeze(0))\n", " softmaxed = F.log_softmax(score,1)\n", " decode.append(softmaxed)\n", " decoded = softmaxed.max(1)[1]\n", " embedded = self.embedding(decoded).unsqueeze(1) # y_{t-1}\n", " if is_training:\n", " embedded = 
self.dropout(embedded)\n", " \n", " # compute next context vector using attention\n", " context, alpha = self.Attention(hidden, encoder_outputs, encoder_maskings)\n", " \n", " # column-wise concat, reshape!!\n", " scores = torch.cat(decode, 1)\n", " return scores.view(inputs.size(0) * max_length, -1)\n", " \n", " def decode(self, context, encoder_outputs):\n", " start_decode = Variable(LongTensor([[target2index['']] * 1])).transpose(0, 1)\n", " embedded = self.embedding(start_decode)\n", " hidden = self.init_hidden(start_decode)\n", " \n", " decodes = []\n", " attentions = []\n", " decoded = embedded\n", " while decoded.data.tolist()[0] != target2index['']: # until \n", " _, hidden = self.gru(torch.cat((embedded, context), 2), hidden) # h_t = f(h_{t-1},y_{t-1},c)\n", " concated = torch.cat((hidden, context.transpose(0, 1)), 2) # y_t = g(h_t,y_{t-1},c)\n", " score = self.linear(concated.squeeze(0))\n", " softmaxed = F.log_softmax(score,1)\n", " decodes.append(softmaxed)\n", " decoded = softmaxed.max(1)[1]\n", " embedded = self.embedding(decoded).unsqueeze(1) # y_{t-1}\n", " context, alpha = self.Attention(hidden, encoder_outputs,None)\n", " attentions.append(alpha.squeeze(1))\n", " \n", " return torch.cat(decodes).max(1)[1], torch.cat(attentions)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Train " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "It takes for a while if you use just cpu...." ] }, { "cell_type": "code", "execution_count": 70, "metadata": { "collapsed": true }, "outputs": [], "source": [ "EPOCH = 50\n", "BATCH_SIZE = 64\n", "EMBEDDING_SIZE = 300\n", "HIDDEN_SIZE = 512\n", "LR = 0.001\n", "DECODER_LEARNING_RATIO = 5.0\n", "RESCHEDULED = False" ] }, { "cell_type": "code", "execution_count": 71, "metadata": { "collapsed": true }, "outputs": [], "source": [ "encoder = Encoder(len(source2index), EMBEDDING_SIZE, HIDDEN_SIZE, 3, True)\n", "decoder = Decoder(len(target2index), EMBEDDING_SIZE, HIDDEN_SIZE * 2)\n", "encoder.init_weight()\n", "decoder.init_weight()\n", "\n", "if USE_CUDA:\n", " encoder = encoder.cuda()\n", " decoder = decoder.cuda()\n", "\n", "loss_function = nn.CrossEntropyLoss(ignore_index=0)\n", "enc_optimizer = optim.Adam(encoder.parameters(), lr=LR)\n", "dec_optimizer = optim.Adam(decoder.parameters(), lr=LR * DECODER_LEARNING_RATIO)" ] }, { "cell_type": "code", "execution_count": 72, "metadata": { "collapsed": false, "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[00/50] [000/466] mean_loss : 8.94\n", "[00/50] [200/466] mean_loss : 4.27\n", "[00/50] [400/466] mean_loss : 3.38\n", "[01/50] [000/466] mean_loss : 2.79\n", "[01/50] [200/466] mean_loss : 2.59\n", "[01/50] [400/466] mean_loss : 2.40\n", "[02/50] [000/466] mean_loss : 1.86\n", "[02/50] [200/466] mean_loss : 1.97\n", "[02/50] [400/466] mean_loss : 1.93\n", "[03/50] [000/466] mean_loss : 1.58\n", "[03/50] [200/466] mean_loss : 1.61\n", "[03/50] [400/466] mean_loss : 1.70\n", "[04/50] [000/466] mean_loss : 1.42\n", "[04/50] [200/466] mean_loss : 1.43\n", "[04/50] [400/466] mean_loss : 1.48\n", "[05/50] [000/466] mean_loss : 1.09\n", "[05/50] [200/466] mean_loss : 1.29\n", "[05/50] [400/466] mean_loss : 1.31\n", "[06/50] [000/466] mean_loss : 1.00\n", "[06/50] [200/466] mean_loss : 1.17\n", "[06/50] [400/466] mean_loss : 1.23\n", "[07/50] [000/466] mean_loss : 0.82\n", "[07/50] [200/466] mean_loss : 1.07\n", "[07/50] [400/466] mean_loss : 1.14\n", "[08/50] [000/466] mean_loss : 1.10\n", "[08/50] [200/466] mean_loss : 1.02\n", "[08/50] 
[400/466] mean_loss : 1.09\n", "[09/50] [000/466] mean_loss : 0.86\n", "[09/50] [200/466] mean_loss : 0.96\n", "[09/50] [400/466] mean_loss : 1.04\n", "[10/50] [000/466] mean_loss : 0.85\n", "[10/50] [200/466] mean_loss : 0.90\n", "[10/50] [400/466] mean_loss : 0.97\n", "[11/50] [000/466] mean_loss : 0.97\n", "[11/50] [200/466] mean_loss : 0.88\n", "[11/50] [400/466] mean_loss : 0.95\n", "[12/50] [000/466] mean_loss : 0.77\n", "[12/50] [200/466] mean_loss : 0.86\n", "[12/50] [400/466] mean_loss : 0.92\n", "[13/50] [000/466] mean_loss : 0.70\n", "[13/50] [200/466] mean_loss : 0.80\n", "[13/50] [400/466] mean_loss : 0.87\n", "[14/50] [000/466] mean_loss : 0.66\n", "[14/50] [200/466] mean_loss : 0.74\n", "[14/50] [400/466] mean_loss : 0.84\n", "[15/50] [000/466] mean_loss : 0.66\n", "[15/50] [200/466] mean_loss : 0.72\n", "[15/50] [400/466] mean_loss : 0.81\n", "[16/50] [000/466] mean_loss : 0.55\n", "[16/50] [200/466] mean_loss : 0.72\n", "[16/50] [400/466] mean_loss : 0.80\n", "[17/50] [000/466] mean_loss : 0.64\n", "[17/50] [200/466] mean_loss : 0.70\n", "[17/50] [400/466] mean_loss : 0.80\n", "[18/50] [000/466] mean_loss : 0.62\n", "[18/50] [200/466] mean_loss : 0.69\n", "[18/50] [400/466] mean_loss : 0.77\n", "[19/50] [000/466] mean_loss : 0.49\n", "[19/50] [200/466] mean_loss : 0.74\n", "[19/50] [400/466] mean_loss : 0.80\n", "[20/50] [000/466] mean_loss : 0.55\n", "[20/50] [200/466] mean_loss : 0.67\n", "[20/50] [400/466] mean_loss : 0.76\n", "[21/50] [000/466] mean_loss : 0.64\n", "[21/50] [200/466] mean_loss : 0.67\n", "[21/50] [400/466] mean_loss : 0.75\n", "[22/50] [000/466] mean_loss : 0.60\n", "[22/50] [200/466] mean_loss : 0.63\n", "[22/50] [400/466] mean_loss : 0.70\n", "[23/50] [000/466] mean_loss : 0.60\n", "[23/50] [200/466] mean_loss : 0.60\n", "[23/50] [400/466] mean_loss : 0.67\n", "[24/50] [000/466] mean_loss : 0.57\n", "[24/50] [200/466] mean_loss : 0.61\n", "[24/50] [400/466] mean_loss : 0.68\n", "[25/50] [000/466] mean_loss : 0.50\n", "[25/50] [200/466] mean_loss : 0.61\n", "[25/50] [400/466] mean_loss : 0.68\n", "[26/50] [000/466] mean_loss : 0.53\n", "[26/50] [200/466] mean_loss : 0.53\n", "[26/50] [400/466] mean_loss : 0.51\n", "[27/50] [000/466] mean_loss : 0.58\n", "[27/50] [200/466] mean_loss : 0.50\n", "[27/50] [400/466] mean_loss : 0.49\n", "[28/50] [000/466] mean_loss : 0.40\n", "[28/50] [200/466] mean_loss : 0.48\n", "[28/50] [400/466] mean_loss : 0.47\n", "[29/50] [000/466] mean_loss : 0.45\n", "[29/50] [200/466] mean_loss : 0.46\n", "[29/50] [400/466] mean_loss : 0.45\n", "[30/50] [000/466] mean_loss : 0.56\n", "[30/50] [200/466] mean_loss : 0.44\n", "[30/50] [400/466] mean_loss : 0.45\n", "[31/50] [000/466] mean_loss : 0.46\n", "[31/50] [200/466] mean_loss : 0.43\n", "[31/50] [400/466] mean_loss : 0.43\n", "[32/50] [000/466] mean_loss : 0.30\n", "[32/50] [200/466] mean_loss : 0.41\n", "[32/50] [400/466] mean_loss : 0.42\n", "[33/50] [000/466] mean_loss : 0.30\n", "[33/50] [200/466] mean_loss : 0.40\n", "[33/50] [400/466] mean_loss : 0.41\n", "[34/50] [000/466] mean_loss : 0.34\n", "[34/50] [200/466] mean_loss : 0.40\n", "[34/50] [400/466] mean_loss : 0.40\n", "[35/50] [000/466] mean_loss : 0.32\n", "[35/50] [200/466] mean_loss : 0.39\n", "[35/50] [400/466] mean_loss : 0.39\n", "[36/50] [000/466] mean_loss : 0.31\n", "[36/50] [200/466] mean_loss : 0.39\n", "[36/50] [400/466] mean_loss : 0.38\n", "[37/50] [000/466] mean_loss : 0.39\n", "[37/50] [200/466] mean_loss : 0.38\n", "[37/50] [400/466] mean_loss : 0.38\n", "[38/50] [000/466] mean_loss : 0.33\n", 
"[38/50] [200/466] mean_loss : 0.37\n", "[38/50] [400/466] mean_loss : 0.37\n", "[39/50] [000/466] mean_loss : 0.39\n", "[39/50] [200/466] mean_loss : 0.37\n", "[39/50] [400/466] mean_loss : 0.37\n", "[40/50] [000/466] mean_loss : 0.41\n", "[40/50] [200/466] mean_loss : 0.36\n", "[40/50] [400/466] mean_loss : 0.36\n", "[41/50] [000/466] mean_loss : 0.31\n", "[41/50] [200/466] mean_loss : 0.36\n", "[41/50] [400/466] mean_loss : 0.36\n", "[42/50] [000/466] mean_loss : 0.30\n", "[42/50] [200/466] mean_loss : 0.35\n", "[42/50] [400/466] mean_loss : 0.35\n", "[43/50] [000/466] mean_loss : 0.23\n", "[43/50] [200/466] mean_loss : 0.35\n", "[43/50] [400/466] mean_loss : 0.34\n", "[44/50] [000/466] mean_loss : 0.31\n", "[44/50] [200/466] mean_loss : 0.34\n", "[44/50] [400/466] mean_loss : 0.35\n", "[45/50] [000/466] mean_loss : 0.25\n", "[45/50] [200/466] mean_loss : 0.33\n", "[45/50] [400/466] mean_loss : 0.35\n", "[46/50] [000/466] mean_loss : 0.47\n", "[46/50] [200/466] mean_loss : 0.33\n", "[46/50] [400/466] mean_loss : 0.34\n", "[47/50] [000/466] mean_loss : 0.43\n", "[47/50] [200/466] mean_loss : 0.33\n", "[47/50] [400/466] mean_loss : 0.33\n", "[48/50] [000/466] mean_loss : 0.30\n", "[48/50] [200/466] mean_loss : 0.33\n", "[48/50] [400/466] mean_loss : 0.33\n", "[49/50] [000/466] mean_loss : 0.39\n", "[49/50] [200/466] mean_loss : 0.33\n", "[49/50] [400/466] mean_loss : 0.32\n" ] } ], "source": [ "for epoch in range(EPOCH):\n", " losses=[]\n", " for i, batch in enumerate(getBatch(BATCH_SIZE, train_data)):\n", " inputs, targets, input_lengths, target_lengths = pad_to_batch(batch, source2index, target2index)\n", " \n", " input_masks = torch.cat([Variable(ByteTensor(tuple(map(lambda s: s ==0, t.data)))) for t in inputs]).view(inputs.size(0), -1)\n", " start_decode = Variable(LongTensor([[target2index['']] * targets.size(0)])).transpose(0, 1)\n", " encoder.zero_grad()\n", " decoder.zero_grad()\n", " output, hidden_c = encoder(inputs, input_lengths)\n", " \n", " preds = decoder(start_decode, hidden_c, targets.size(1), output, input_masks, True)\n", " \n", " loss = loss_function(preds, targets.view(-1))\n", " losses.append(loss.data.tolist()[0] )\n", " loss.backward()\n", " torch.nn.utils.clip_grad_norm(encoder.parameters(), 50.0) # gradient clipping\n", " torch.nn.utils.clip_grad_norm(decoder.parameters(), 50.0) # gradient clipping\n", " enc_optimizer.step()\n", " dec_optimizer.step()\n", "\n", " if i % 200==0:\n", " print(\"[%02d/%d] [%03d/%d] mean_loss : %0.2f\" %(epoch, EPOCH, i, len(train_data)//BATCH_SIZE, np.mean(losses)))\n", " losses=[]\n", "\n", " # You can use http://pytorch.org/docs/master/optim.html#how-to-adjust-learning-rate\n", " if RESCHEDULED == False and epoch == EPOCH//2:\n", " LR *= 0.01\n", " enc_optimizer = optim.Adam(encoder.parameters(), lr=LR)\n", " dec_optimizer = optim.Adam(decoder.parameters(), lr=LR * DECODER_LEARNING_RATIO)\n", " RESCHEDULED = True" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Test" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Visualize Attention" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# borrowed code from https://github.com/spro/practical-pytorch/blob/master/seq2seq-translation/seq2seq-translation-batched.ipynb\n", "\n", "def show_attention(input_words, output_words, attentions):\n", " # Set up figure with colorbar\n", " fig = plt.figure()\n", " ax = fig.add_subplot(111)\n", " cax = ax.matshow(attentions.numpy(), cmap='bone')\n", " 
fig.colorbar(cax)\n", "\n", " # Set up axes\n", " ax.set_xticklabels([''] + input_words, rotation=90)\n", " ax.set_yticklabels([''] + output_words)\n", "\n", " # Show label at every tick\n", " ax.xaxis.set_major_locator(ticker.MultipleLocator(1))\n", " ax.yaxis.set_major_locator(ticker.MultipleLocator(1))\n", "\n", "# show_plot_visdom()\n", " plt.show()\n", " plt.close()" ] }, { "cell_type": "code", "execution_count": 99, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Source : he is telling a lie .\n", "Truth : il dit un mensonge .\n", "Prediction : il dit un mensonge .\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXkAAAEUCAYAAADOaUa5AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAGHtJREFUeJzt3X+4XVV95/H3J0FEIAJjoE4JGNQgzYMWnBic+pMWbKQK\nM4/WgqUFhaYdhQHRPmLrA31oqZVOnYrF1qiIMlpECzXWKBYFaalAEqBgMmAzsUqCBa6liFIC5Hzm\nj72vHm5z7znXfe7d+677eeXZzz17n3XX+d6T5HvXWWvttWSbiIgo04K2A4iIiJmTJB8RUbAk+YiI\ngiXJR0QULEk+IqJgSfIREQVLko+IKFiSfEREwZLkIyIKtlvbAUT5JJ2zi8sPARtt3z7b8UTMJ8qy\nBjHTJH0KWAF8vr70GuAOYCnwGdsXtRRaRPGS5GPGSboBOM72D+rzvYEvAKuoWvPL24wvomTpk4/Z\ncACwo+/8ceCnbP/7hOsRMWLpk4/Z8EngZkmfq89fC3xK0l7A5vbCiihfumtiVkh6EfBz9emNtje0\nGU/EfJEkH7NC0kLgp+j79Gj7O+1FFHOVpGcC9znJayjprokZJ+lM4HzgPmAnIMDAC9qMK+YeSfsB\nW4GTgM8NKB6kJR+zQNIW4Cjb32s7lpjbJJ0BHAsssP3atuOZCzK7JmbDPVQ3P0U09SbgDOAgSf+5\n7WDmgnTXxGzYClwv6Qv0TZm0/b72Qoq5RtIKYMz2PZI+AZwKvKfdqLovLfmYDd8B/hbYHVjUd0RM\nx2nAR+vHlwO/1mIsc0b65COi8yTtCWwCDrX9eH3tauD9tq9vM7auS5KPGSPpT22fLenzVLNpnsT2\n8S2EFXOQpKcA+9m+v+/a0wFsf7+1wOaA9MnHTLq8/vq/Wo0i5jzbj0v6oaQFtnuSDgUOA77Ydmxd\nl5Z8RMwJkjYCLwP2A24E1gOP2f7VVgPruLTkY8ZIupNddNOMs52boWI6ZPsRSacBH7R9kaTsRzBA\nknzMpNe0HUAURZL+K/CrVDNtABa2GM+ckCQfM8b2t9uOIYpyNvAu4GrbmyQ9G7iu5Zg6L33yI1AP\nAv051Rrph0t6AXC87T9oObRWSXqYH3fXqP7q+rFtP72VwDpM0rOAZbavlfQ0YDfbD7cdV5skvQv4\nku3b2o5lLsrNUKPxYaoWxuMAtu8ATmw1og6wvcj20+tjUd/5oiT4/0jSbwCfBT5UX1oC/HV7EXXG\nVuAsSbdJukzSr9QLlcUQ0l0zGnvavkVS/7Un2gqmiyS9lKqF+jFJi4FFtr/Vdlwd81ZgJXAzgO1/\nknRAuyG1z/angU8DSDqSatvIq+rlq6+lauXf0mKInZaW/GiMSXoOddeEpNcD3203pO6QdD7wTqpP\nO1Atb/B/2ouos3bYfmz8RNJuTDE7aT6yfZvt99g+mmpgfxNwesthdVqS/Gi8leoj9mGStlMNEP1W\nmwFJekm9vR6STpb0vrq/tw3/HTge+CGA7XvJ2jW78jVJvwM8TdKxwGeAz7ccUydI2lPSz064vC9w\nk+3VbcQ0VyTJj8Z24GPAhcAVVItxndJqRNVA8CP1f4y3A/8P+ERLsTxW7+Iz/klnr5bieBJJ+0la\nKenl40fLIZ0LPADcCfwmsA54d6sRdcfjVF00/f92PgJkueEB0ic/Gp8D/g24Fbi35VjGPWHbkk4A\n/sz2R+ubSNpwpaQPAfvWg4tvphqsbo2k04GzqAY3bwdeDHwd+Pm2YrLdo3pfWn1vuqhe1uBq4A3A\nxyQdDOyfvYIHS5IfjSW2V7UdxAQP11PPTgZeLmkB8JSWYtmfatbI94HnAecBx7QUy7izgBdRfdw/\nWtJhwB+2EYikK22/YbI7hHNn8I98BFhD9an51+uvMUCS/Gj8g6Tn276z7UD6/ArwRuA02/9St3z+\nuKVYjrX9TqpuLAAk/QnVYGxbHrX9qCQkPdX2XZKe11IsZ9Vfc4fwFOq/I9X3pZxItY5NDJCboRro\na3ntBiyjms+7gx/f7DOvW2CS/gfwFuDZVGMC4xYBN9o+uZXA+NFa5G+iGiT/eeBB4Cm2j2srprlA\n0jNt/0uLr38qVXffdtsntRXHXJIk38Cg2Spt3NYv6e9tv3TC3abQwl2mkvahWjHwPVSDiuMetv2v\nsxXHIJJeAexDNd/6sUHlZ+D1J/5d/egpOnZnsKQv2P6lFl9/T6rpya+zfW1bccwlSfIREQXLFMqI\niIIlyc8ASZ26OSPxTK1r8UD3Yko8s0PSpZLul/SNSZ6XpIslbZF0h6QXDqozSX5mdO0fYOKZWtfi\nge7FlHhmx2VUa/NM5tVUkzyWUb0Hfz6owiT5iIiOsH0DMNWkhBOAT7hyE9UNhlPe9Zt58rXFixd7\n6dKlI6nr4IMPZsWKFY1GtDdu3DiSWMZJ6tQIe+IZrGsxFRrPmO39m1SwatUqj42NDVV248aNm4BH\n+y6tsb1mGi93IHBP3/m2+tqkCyImydeWLl3Khg3duUN6wrLFETEzGk9zHhsbGzp3SHrU9oqmrzkd\nSfIREQ3N4lT07cBBfedL6muTSp98REQDBnb2ekMdI7AW+PV6ls2LgYdsT7l3RVryERGNGI9obxdJ\nfwm8ElgsaRtwPvXCgrb/gmr56eOALcAjVEtzTClJPiKiCUNvRL01g9bjqfdleOt06kySj4hoqMvL\nwyTJR0Q0YKCXJB8RUa605CMiCmV7VDNnZkSSfEREQ2nJR0QUbFRTKGdCknxERAPVwGvbUUwuST4i\noqF010RElKrjA69Fr10j6R/qr0sn22klIqIJU7XkhznaUHRL3vbPtR1DRJQvN0O1RNIPbO/ddhwR\n
UbYu98kX3V0ziKTVkjZI2vDAAw+0HU5EzEke+k8b5nWSt73G9grbK/bfv9EOYBExT7lehXKYow1F\nd9dERMyGXodn1yTJR0Q0kFUoIyIK1+WB16KT/PjMGtv/DBzebjQRUSQ7LfmIiJKlJR8RUSgDO5Pk\nIyLKlZZ8RETBkuQjIgrlDLxGRJQtLfmIiIIlyUdEFKqaXZNlDSIiipU9XiMiStXirk/DSJKPiGhg\nfPu/rkqSj4hoKFMo54CNGzciqe0wOquLLZX8fUVXdPH/x7gk+YiIBmyzM5uGRESUq639W4eRJB8R\n0VCXp1DO6428IyKaGp9dM8wxiKRVku6WtEXSubt4/mBJ10m6TdIdko4bVGeSfEREQ6NI8pIWApcA\nrwaWAydJWj6h2LuBK20fCZwIfHBQbOmuiYhoYnQDryuBLba3Aki6AjgB2Nz/asDT68f7APcOqjRJ\nPiKigRHeDHUgcE/f+TbgqAllfg/4sqQzgb2AYwZVmu6aiIiGevWa8oMOYLGkDX3H6mm+1EnAZbaX\nAMcBl0uaMo+nJR8R0dA0plCO2V4xyXPbgYP6zpfU1/qdBqwCsP11SXsAi4H7J3vBtOQjIhqyhzsG\nWA8sk3SIpN2pBlbXTijzHeAXACT9DLAH8MBUlaYlHxHRgBnN2jW2n5B0BnANsBC41PYmSRcAG2yv\nBd4OfFjS2+qXPtUDBgSS5CMimhjhsga21wHrJlw7r+/xZuAl06kzST4iooEsNRwRUbgk+Vki6feA\nH1DdLHCD7WslnQ2ssf1Iq8FFRLG6vJ58kbNrbJ9n+9r69GxgzzbjiYiSeeg/bZjzLXlJvwucQjVP\n9B5go6TLgL8Bfro+rpM0Zvvo1gKNiCINOT2yNXM6yUv6L1RzSY+g+lluBTaOP2/7YknnAEfbHmsn\nyogoXTYNmTkvA64e72+XNPHGgSnVtxRP97biiIgfGdU8+Zky15N8I7bXAGsAJHX3bykiOq3Ls2vm\n+sDrDcB/k/Q0SYuA1+6izMPAotkNKyLmjSHXkm/rF8GcbsnbvlXSp4F/pBp4Xb+LYmuAL0m6NwOv\nETEjOtySn9NJHsD2hcCFUzz/AeADsxdRRMw3vZ1J8hERRaqmUCbJR0QUK0k+IqJY7Q2qDiNJPiKi\nIfeS5CMiipQ++YiIwjnLGkRElKvDDfkk+YiIRuz0yUdElCx98hERhcoerxERhUuSj4golY13ZnZN\nRESx0pKPOU9S2yH8B137j9XF9yhmR8f+KT5JknxERAMZeI2IKFmWNYiIKJnpZeA1IqJcaclHRBQq\nq1BGRJQuST4iolzubpd8knxERFPpromIKJVNL5uGRESUqes3Qy1oO4CIiDnN1UbewxyDSFol6W5J\nWySdO0mZN0jaLGmTpE8NqjMt+YiIpkbQkpe0ELgEOBbYBqyXtNb25r4yy4B3AS+x/aCkAwbVm5Z8\nREQjxh7uGGAlsMX2VtuPAVcAJ0wo8xvAJbYfBLB9/6BKk+QjIhrq9TzUASyWtKHvWN1XzYHAPX3n\n2+pr/Q4FDpV0o6SbJK0aFFu6ayIiGnDdJz+kMdsrGrzcbsAy4JXAEuAGSc+3/W9TfcOcJmkp8De2\nD6/P3wHsTfUm3AwcDewLnGb779qJMiJKNqLZNduBg/rOl9TX+m0Dbrb9OPAtSd+kSvrrJ6u09O6a\n3WyvBM4Gzm87mIgo04j65NcDyyQdIml34ERg7YQyf03VgEXSYqrum61TVTrnW/IDXFV/3Qgsnfhk\n3R+2euL1iIjhDZXAB9diPyHpDOAaYCFwqe1Nki4ANtheWz/3KkmbgZ3Ab9v+3lT1lpDkn+DJn0j2\n6Hu8o/66k138rLbXAGsAJHX3boaI6K4RrkJpex2wbsK18/oeGzinPoZSQnfNfcABkp4h6anAa9oO\nKCLmDwPe6aGONsz5lrztx+uPM7dQDVLc1XJIETHPdHlZgzmf5AFsXwxcPMXzY+yiTz4iorHhBlVb\nU0SSj4ho0zTmyc+6JPmIiIbSko+IKFTXlxpOko+IaMLG2TQkIqJc2eM1IqJg6a6JiCjVCO94nQlJ\n8hERDWTgNSKiaKa3s7ud8knyERFNpLsmIqJwSfIREeXqcI5Pko+IaCIDrxEzRFLbITxJ1/6jd+39\nKdb0NvKedUnyERGNmF6WNYiIKFfXPsX1S5KPiGgqST4iokxOn3xERNk63JBPko+IaCZ7vEZElMtk\ndk1ERKlM+uQjIoqW7pqIiGK50yOvSfIREU1kqeGIiLL1dibJR0QUKatQRkSULN01EREl6/bNUAsG\nFZC0VNJdki6T9E1Jn5R0jKQbJf2TpJWS9pJ0qaRbJN0m6YT6e0+VdJWkL9VlL6qvL6zr+4akOyW9\nrb5+hKSbJN0h6WpJ+9XXr5f03rr+b0p6WX19T0lXStpcl79Z0or6uVdJ+rqkWyV9RtLeM/UmRsT8\nZnuoow3DtuSfC/wy8GZgPfBG4KXA8cDvAJuBr9p+s6R9gVskXVt/7xHAkcAO4G5JHwAOAA60fThA\n/T0AnwDOtP01SRcA5wNnj8dqe6Wk4+rrxwBvAR60vVzS4cDtdX2LgXcDx9j+oaR3AucAF0zz/YmI\nGKjLN0MNbMnXvmX7Tts9YBPwFVe/lu4ElgKvAs6VdDtwPbAHcHD9vV+x/ZDtR6l+GTwL2Ao8W9IH\nJK0Cvi9pH2Bf21+rv+/jwMv7Yriq/rqxfk2oftFcAWD7G8Ad9fUXA8uBG+uYTqlf90kkrZa0QdKG\nId+HiIgnGV+FcphjEEmrJN0taYukc6co9zpJHu+5mMqwLfkdfY97fee9uo6dwOts3z0hkKMmfO9O\nqhb5g5J+FvhF4LeANwBvGzKGnUPELeBvbZ80VSHba4A1dazd/VUcEZ02iq4YSQuBS4BjgW3Aeklr\nbW+eUG4RcBZw8zD1DtuSH+Qa4EzVm0pKOnKqwnV3ygLbf0XVrfJC2w8BD473twO/BnxtsjpqN1L9\ngkDScuD59fWbgJdIem793F6SDp3+jxURMchw/fFD/CJYCWyxvdX2Y1S9FCfsotzvA+8FHh0mulHN\nrvl94E+BOyQtAL4FvGaK8gcCH6vLAryr/noK8BeS9qTq0nnTgNf9IPBxSZuBu6i6kh6y/YCkU4G/\nlPTUuuy7gW9O78eKiBhgdJuGHAjc03e+DTiqv4CkFwIH2f6CpN8eptKBSd72PwOH952fOslzv7mL\n770MuKzvvD/xv3AX5W+n6k+feP2VfY/H+HGf/KPAybYflfQc4Frg23W5rwIvmuJHi4gYiWl01yye\nMAa4pu42HqhuFL8POHU6sc31efJ7AtdJegpVP/xb6o85ERGzYpp3vI7ZnmywdDtwUN/5kvrauEVU\njerr657xZwJrJR1ve9LJI3M6ydt+GBg4uhwRMXOMR7NpyHpgmaRDqJL7iVTT1atXqcYtF4+fS7oe\neMdUCR5GN/AaETE/Gdwb7piyGvsJ4AyqiSz/F7jS9iZJF
0g6/icNb0635CMiumBUd7PaXgesm3Dt\nvEnKvnKYOpPkIyIa6vLaNUnyERENZKnhiIiS2fR2jmTgdUYkyUdENJWWfEREuUySfEREkZydoSIi\nSmY8aBJ8i5LkIyIaSks+IqJgvdEsazAjkuQjIhqo1opPko+IKFe6ayIiypUplBERBcvAa0REsUyv\nt7PtICaVJB8R0UBuhoqIKFySfEREwZLkIyKK5UyhjIgomcnNUBERRbKzrEFERMGcPvmIiJJl7ZqI\niIKlJR8RUbAk+YiIUjlTKCMiimWg56xdExFRqMyu6SxJq4HVbccREXNbknxH2V4DrAGQ1N2/pYjo\ntCT5iIhCVeOumScfEVEo4w4va7Cg7QBmg6R1kn667Tgiokwe8k8b5kVL3vZxbccQEeVKn3xERLGc\nPvmIiFJ1fY/XedEnHxExk2wPdQwiaZWkuyVtkXTuLp4/R9JmSXdI+oqkZw2qM0k+IqKhXq831DEV\nSQuBS4BXA8uBkyQtn1DsNmCF7RcAnwUuGhRbknxERCMG94Y7prYS2GJ7q+3HgCuAE570SvZ1th+p\nT28ClgyqNEk+IqKhaUyhXCxpQ9/Rv6zKgcA9fefb6muTOQ344qDYMvAaEdHANAdex2yvaPqakk4G\nVgCvGFQ2ST4ioqERza7ZDhzUd76kvvYkko4Bfhd4he0dgypNko+IaGRk8+TXA8skHUKV3E8E3thf\nQNKRwIeAVbbvH6bSJPmIiIYGzZwZhu0nJJ0BXAMsBC61vUnSBcAG22uBPwb2Bj4jCeA7to+fqt4k\n+YiIBkZ5M5TtdcC6CdfO63t8zHTrTJKPiGgke7xGRBTNZO2aiIhidXntmiT5iIhGPJKB15mSJB8R\n0UC2/4uIKFy6ayIiCpYkHxFRrEyhjIgoWlubdA8jST4iogEber2dbYcxqST5iIhGhtvary1J8hER\nDSXJR0QULEk+IqJguRkqIqJUzhTKiIhiGeilJT86kk4EnmP7wrZjiYiAbnfXLGg7gEEk7S5pr75L\nrwa+NGTZiIgZVk2hHOZoQ2eTvKSfkfQnwN3AofU1AUcAt0p6haTb6+M2SYuA/YBNkj4k6UXtRR8R\n80mS/JAk7SXpTZL+HvgwsBl4ge3b6iJHAv/o6t16B/BW20cALwP+3fZ9wPOA64AL6+T/PyX9p9n/\naSJiPhjf47WrSb5rffLfBe4ATrd91y6eXwV8sX58I/A+SZ8ErrK9DcD2DuAK4ApJBwN/Blwk6dm2\n7+2vTNJqYPXM/CgRMT8Yd3hZg0615IHXA9uBqySdJ+lZE55/FfBlANt/BJwOPA24UdJh44UkHSDp\n7cDngYXAG4H7Jr6Y7TW2V9heMSM/TUTMCx7yTxs61ZK3/WXgy5KeAZwMfE7SGFUyfxDYzfb3ACQ9\nx/adwJ11//thkr4LfBw4DLgcOM729jZ+loiYP3LH6zTVifz9wPslrQR2AscC1/YVO1vS0UAP2ETV\njbMHcDFwnbv8rkdEUbqcbjqZ5PvZvgVA0vnAR/qun7mL4juAr85SaBER9aBqd+fJdz7Jj7N9etsx\nRETsSlryEREF6/XSko+IKFda8hERpTImLfmIiCKN3/HaVUnyERENJclHRBQsST4iolim1+G1a5Lk\nIyIa6HqffNcWKIuImHvG93kddAwgaZWkuyVtkXTuLp5/qqRP18/fLGnpoDqT5CMiGhl2Dcqpk7yk\nhcAlVLvfLQdOkrR8QrHTgAdtPxf438B7B0WXJB8R0ZDdG+oYYCWwxfZW249R7YtxwoQyJ1CttAvw\nWeAX6h3zJpU++YiIhka0rMGBwD1959uAoyYrY/sJSQ8BzwDGJqs0Sf7HxoBvj6iuxUzxprcg8Uxt\nJPEMaFBNV5Hv0QiNKp6JGxP9JK6himcYe0ja0He+xvaaEcQwqST5mu39R1WXpA1d2m0q8Uyta/FA\n92JKPJOzvWpEVW0HDuo7X1Jf21WZbZJ2A/YBvjdVpemTj4johvXAMkmHSNodOBFYO6HMWuCU+vHr\nga8O2iApLfmIiA6o+9jPoOr+WQhcanuTpAuADbbXAh8FLpe0BfhXql8EU0qSnxkz2sf2E0g8U+ta\nPNC9mBLPLLC9Dlg34dp5fY8fBX55OnWqy3dqRUREM+mTj4goWJJ8RETBkuQjIgqWJB8RUbAk+YiI\ngiXJR0QULEk+IqJg/x/AvR9vGPeJlAAAAABJRU5ErkJggg==\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "test = random.choice(train_data)\n", "input_ = test[0]\n", "truth = test[1]\n", "\n", "output, hidden = encoder(input_, [input_.size(1)])\n", "pred, attn = decoder.decode(hidden, output)\n", "\n", "input_ = [index2source[i] for i in input_.data.tolist()[0]]\n", "pred = [index2target[i] for i in pred.data.tolist()]\n", "\n", "\n", "print('Source : ',' '.join([i for i in input_ if i not in ['']]))\n", "print('Truth : ',' '.join([index2target[i] for i in truth.data.tolist()[0] if i not in [2, 3]]))\n", "print('Prediction : ',' '.join([i for i in pred if i not in ['']]))\n", "\n", "if USE_CUDA:\n", " attn = attn.cpu()\n", "\n", "show_attention(input_, pred, attn.data)" ] }, { "cell_type": "markdown", "metadata": { "collapsed": true }, "source": [ "# TODO " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* BLEU\n", "* Beam Search\n", "* Sampled Softmax" ] }, { "cell_type": "markdown", "metadata": { "collapsed": true }, "source": [ "## Further topics" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* Convolutional Sequence to Sequence learning\n", "* Attention is all you need\n", "* Unsupervised Machine Translation Using Monolingual Corpora Only" ] }, { 
"cell_type": "markdown", "metadata": { "collapsed": true }, "source": [ "## Suggested Reading " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* SMT chapter13. Neural Machine Translation" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.5.2" } }, "nbformat": 4, "nbformat_minor": 2 }