{ "cells": [
 {
  "cell_type": "markdown",
  "metadata": {},
  "source": [
   "# Neural Machine Translation Tutorial\n",
   "---\n",
   "\n",
   "Paper Implementation: [Neural Machine Translation by Jointly Learning to Align and Translate](https://arxiv.org/abs/1409.0473) - Dzmitry Bahdanau, Kyunghyun Cho, Yoshua Bengio (v7 2016)"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": 1,
  "metadata": {},
  "outputs": [],
  "source": [
   "import torch\n",
   "import torch.nn as nn\n",
   "import torch.optim as optim\n",
   "import numpy as np\n",
   "import matplotlib.pylab as plt \n",
   "import matplotlib.ticker as ticker\n",
   "from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence\n",
   "DEVICE=None"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": 2,
  "metadata": {},
  "outputs": [],
  "source": [
   "import os\n",
   "from matplotlib import font_manager, rc\n",
   "\n",
   "# Nanum Gothic font (presumably for Korean text in plots). It is only applied\n",
   "# when installed, so the notebook still runs on machines without it.\n",
   "FONT_PATH = '/usr/share/fonts/truetype/nanum/NanumGothicLight.ttf'\n",
   "font_name = None\n",
   "if os.path.exists(FONT_PATH):\n",
   "    font_name = font_manager.FontProperties(fname=FONT_PATH).get_name()\n",
   "    rc('font', family=font_name)"
  ]
 },
 {
  "cell_type": "markdown",
  "metadata": {},
  "source": [
   "---"
  ]
 },
 {
  "cell_type": "markdown",
  "metadata": {},
  "source": [
   "## 신경망 번역기의 목적\n",
   "\n",
   "소스(source)문장을 타겟(target)문장으로 변환하는 것\n",
   "\n",
   "Examples:\n",
   "\n",
   "|source|target|\n",
   "|--|--|\n",
   "| Nice to meet you| 만나서 반갑습니다 |\n",
   "| I am very happy to meet you | 만나서 참 반가워요| "
  ]
 },
 {
  "cell_type": "code",
  "execution_count": 3,
  "metadata": {},
  "outputs": [],
  "source": [
   "flatten = lambda d: [t for s in d for t in s]\n",
   "\n",
   "def build_vocab(data, start_tkn=False):\n",
   "    \"\"\"Build a token -> index vocabulary from tokenized sentences.\n",
   "\n",
   "    Special tokens take the first indices. The remaining tokens follow in\n",
   "    first-occurrence order, so indices are identical on every run. Iterating\n",
   "    a set of strings is not reproducible: str hashing is randomized per process.\n",
   "    \"\"\"\n",
   "    if start_tkn:\n",
   "        vocab = {'': 0, '': 1, '': 2, '': 3}\n",
   "    else:\n",
   "        vocab = {'': 0, '': 1}\n",
   "\n",
   "    # dict.fromkeys removes duplicates while keeping first-seen order\n",
   "    for t in dict.fromkeys(flatten(data)):\n",
   "        if t not in vocab:\n",
   "            vocab[t] = len(vocab)\n",
   "    return vocab\n",
   "\n",
   "def add_pad(data, start_tkn=False):\n",
   "    \"\"\"Pad every sentence in the batch with the padding token so lengths match.\n",
   "\n",
   "    With start_tkn=True, each sentence is first wrapped with the start/end tokens.\n",
   "    \"\"\"\n",
   "    if start_tkn:\n",
   "        data = [[''] + sent + [''] for sent in data]\n",
   "    max_len = max(len(sent) for sent in data)\n",
   "    return [sent + [''] * (max_len - len(sent)) for sent in data]\n",
   "\n",
   "def numericalize(data, vocab):\n",
   "    \"\"\"Map each token to its vocabulary index (unseen tokens -> unknown-token index).\n",
   "\n",
   "    Returns nested lists of ints; preprocess() turns them into a tensor.\n",
   "    \"\"\"\n",
   "    unk_idx = vocab.get('')\n",
   "    return [[vocab.get(t, unk_idx) for t in sent] for sent in data]\n",
   "\n",
   "def preprocess(data, vocab, start_tkn=False):\n",
   "    \"\"\"Pad and numericalize a batch of tokenized sentences into a LongTensor (B, T).\"\"\"\n",
   "    data = add_pad(data, start_tkn=start_tkn)\n",
   "    data = numericalize(data, vocab)\n",
   "    # TensorFlow port: drop torch.LongTensor below and return the numericalized\n",
   "    # `data` as a numpy array instead, then continue from there.\n",
   "    return torch.LongTensor(data)\n",
   "\n",
   "def build_batch(src, trg, src_vocab, trg_vocab, is_sort=False):\n",
   "    \"\"\"Turn parallel tokenized sentences into padded LongTensor batches.\n",
   "\n",
   "    With is_sort=True, pairs are sorted by source length, longest first (the\n",
   "    Encoder's pack_padded_sequence call needs this), and the result is\n",
   "    ((src, src_lengths), trg). Otherwise the result is (src, trg).\n",
   "    Targets get start/end tokens; sources do not.\n",
   "    \"\"\"\n",
   "    if is_sort:\n",
   "        sorted_data = sorted(zip(src, trg), key=lambda x: len(x[0]), reverse=True)\n",
   "        src, trg = zip(*sorted_data)\n",
   "    src = preprocess(src, src_vocab, start_tkn=False)\n",
   "    trg = preprocess(trg, trg_vocab, start_tkn=True)\n",
   "    if is_sort:\n",
   "        # lengths = number of non-padding tokens in each source sentence\n",
   "        return (src, src.ne(src_vocab.get('')).sum(1)), trg\n",
   "    return src, trg"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": 4,
  "metadata": {},
  "outputs": [],
  "source": [
   "dataset = \"\"\"Nice to meet you > 만나서 반가워요 \n I am very happy to meet you > 만나서 참 반가워요\"\"\".splitlines()\n",
   "dataset = [s.strip().split('>') for s in dataset]\n",
   "src, trg = [[sent.split() for sent in x] for x in zip(*dataset)]\n",
   "src_vocab = build_vocab(src)\n",
   "trg_vocab = build_vocab(trg, start_tkn=True)\n",
   "(inputs, lengths), targets = build_batch(src, trg, src_vocab, trg_vocab, is_sort=True)\n",
   "# index -> token lookup for the target vocabulary (needed later)\n",
   "trg_itos = sorted([(v, k) for k, v in trg_vocab.items()], key=lambda x: x[0])\n",
   "trg_itos = [x[1] for x in trg_itos]"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": 5,
  "metadata": {},
  "outputs": [
   {
    "data": {
     "text/plain": [
      "tensor([[2, 6, 9, 7, 5, 8, 3],\n",
      " [4, 5, 8, 3, 1, 1, 1]])"
     ]
    },
    "execution_count": 5,
    "metadata": {},
    "output_type":
"execute_result"
   }
  ],
  "source": [
   "inputs"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": 6,
  "metadata": {},
  "outputs": [
   {
    "data": {
     "text/plain": [
      "tensor([7, 4])"
     ]
    },
    "execution_count": 6,
    "metadata": {},
    "output_type": "execute_result"
   }
  ],
  "source": [
   "lengths"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": 7,
  "metadata": {},
  "outputs": [
   {
    "data": {
     "text/plain": [
      "tensor([[2, 4, 5, 6, 3],\n",
      " [2, 4, 6, 3, 1]])"
     ]
    },
    "execution_count": 7,
    "metadata": {},
    "output_type": "execute_result"
   }
  ],
  "source": [
   "targets"
  ]
 },
 {
  "cell_type": "markdown",
  "metadata": {},
  "source": [
   "## Model Structure\n",
   "\n",
   "![model_structure](./pics/translation.png)"
  ]
 },
 {
  "cell_type": "markdown",
  "metadata": {},
  "source": [
   "## Encoder"
  ]
 },
 {
  "cell_type": "markdown",
  "metadata": {},
  "source": [
   "![encoder](./pics/translation_encoder.png)"
  ]
 },
 {
  "cell_type": "markdown",
  "metadata": {},
  "source": [
   "### Hyperparameters for Encoder"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": 8,
  "metadata": {},
  "outputs": [],
  "source": [
   "EMBED = 20        # embedding_size\n",
   "HIDDEN = 60       # hidden_size\n",
   "ENC_N_LAYER = 3   # encoder number of layers\n",
   "L_NORM = True     # whether to use layernorm"
  ]
 },
 {
  "cell_type": "markdown",
  "metadata": {},
  "source": [
   "### Encoder Model"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": 9,
  "metadata": {},
  "outputs": [],
  "source": [
   "class Encoder(nn.Module):\n",
   "    \"\"\"GRU encoder: embedding -> optional LayerNorm -> (bi)GRU over a packed batch.\n",
   "\n",
   "    Args:\n",
   "        vocab_size: source vocabulary size.\n",
   "        embed_size: embedding size (M).\n",
   "        hidden_size: GRU hidden size per direction (H).\n",
   "        n_layers: number of stacked GRU layers.\n",
   "        layernorm: apply LayerNorm to the embeddings.\n",
   "        bidirec: use a bidirectional GRU (n_direction = 2).\n",
   "    \"\"\"\n",
   "    def __init__(self, vocab_size, embed_size, hidden_size, n_layers, layernorm=False, bidirec=False):\n",
   "        super(Encoder, self).__init__()\n",
   "        self.hidden_size = hidden_size\n",
   "        self.n_layers = n_layers\n",
   "        self.n_direction = 2 if bidirec else 1\n",
   "        self.layernorm = layernorm\n",
   "        self.embedding = nn.Embedding(vocab_size, embed_size)\n",
   "        self.gru = nn.GRU(embed_size, hidden_size, n_layers, bidirectional=bidirec,\n",
   "                          batch_first=True)\n",
   "        if layernorm:\n",
   "            self.l_norm = nn.LayerNorm(embed_size)\n",
   "\n",
   "    def forward(self, inputs, lengths):\n",
   "        \"\"\"\n",
   "        Inputs:\n",
   "        - inputs: B, T_e (padded token ids; rows sorted by length, longest first)\n",
   "        - lengths: B, unpadded lengths as a list (a 1-D tensor is also accepted)\n",
   "        Outputs:\n",
   "        - outputs: B, max(lengths), n_directions*H (zeros past each length)\n",
   "        - hiddens: 1, B, n_directions*H (last layer, directions concatenated)\n",
   "        \"\"\"\n",
   "        # B: batch_size, T_e: enc_length, M: embed_size, H: hidden_size\n",
   "        if torch.is_tensor(lengths):\n",
   "            lengths = lengths.tolist()\n",
   "        assert isinstance(lengths, list), \"lengths must be a list type\"\n",
   "        inputs = self.embedding(inputs)  # (B, T_e) > (B, T_e, M)\n",
   "        if self.layernorm:\n",
   "            inputs = self.l_norm(inputs)\n",
   "\n",
   "        packed_inputs = pack_padded_sequence(inputs, lengths, batch_first=True)\n",
   "        # packed_inputs.data: (sum(lengths), M)\n",
   "        packed_outputs, hiddens = self.gru(packed_inputs)\n",
   "        # packed_outputs.data: (sum(lengths), n_directions*H)\n",
   "        # hiddens: (n_layers*n_directions, B, H), stacked layer by layer\n",
   "        outputs, _ = pad_packed_sequence(packed_outputs, batch_first=True)\n",
   "        # outputs: (B, max(lengths), n_directions*H)\n",
   "        # the last n_direction rows of hiddens belong to the top layer\n",
   "        hiddens = torch.cat(list(hiddens[-self.n_direction:]), 1).unsqueeze(0)\n",
   "        # hiddens: (1, B, n_directions*H)\n",
   "        return outputs, hiddens"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": 10,
  "metadata": {},
  "outputs": [
   {
    "data": {
     "text/plain": [
      "(torch.Size([2, 7, 120]), torch.Size([1, 2, 120]))"
     ]
    },
    "execution_count": 10,
    "metadata": {},
    "output_type": "execute_result"
   }
  ],
  "source": [
   "encoder = Encoder(len(src_vocab), EMBED, HIDDEN, ENC_N_LAYER, L_NORM, bidirec=True)\n",
   "enc_output, enc_hidden = encoder(inputs, lengths.tolist())\n",
   "enc_output.size(), enc_hidden.size()"
  ]
 },
 {
  "cell_type": "markdown",
  "metadata": {},
  "source": [
   "## Decoder"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": 11,
  "metadata": {},
  "outputs": [
   {
    "data": {
     "text/plain": [
      "torch.Size([2, 1, 20])"
     ]
    },
    "execution_count": 11,
    "metadata": {},
    "output_type": "execute_result"
   }
  ],
  "source": [
   "dec_embedding = nn.Embedding(len(trg_vocab), EMBED)\n",
   "# 2 is the start-token index in trg_vocab (the third special token)\n",
   "sos = torch.LongTensor([2]*inputs.size(0)).unsqueeze(1)\n",
   "dec_input = dec_embedding(sos)\n",
   "dec_input.size()"
  ]
 },
 {
  "cell_type": "markdown",
  "metadata": {},
"source": [
   "### Hyperparameters for Decoder"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": 12,
  "metadata": {},
  "outputs": [],
  "source": [
   "DEC_N_LAYER = 1     # decoder number of layers\n",
   "DROP_RATE = 0.2     # dropout applied to decoder input embeddings\n",
   "METHOD = 'general'  # attention scoring method ('dot' or 'general')\n",
   "TF = True           # teacher forcing during training\n",
   "RETURN_W = True     # decoder also returns attention weights"
  ]
 },
 {
  "cell_type": "markdown",
  "metadata": {},
  "source": [
   "### Attention"
  ]
 },
 {
  "cell_type": "markdown",
  "metadata": {},
  "source": [
   "![attention](./pics/translation_attention.png)"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": 13,
  "metadata": {},
  "outputs": [],
  "source": [
   "class Attention(nn.Module):\n",
   "    \"\"\"Attention over encoder outputs: score -> masked softmax -> weighted sum.\n",
   "\n",
   "    * hidden_size: decoder hidden_size (H_d = encoder_gru_direction * H)\n",
   "    methods:\n",
   "    - 'dot': dot product between hidden and encoder_outputs\n",
   "    - 'general': encoder_outputs through a linear layer, then dot product\n",
   "    - 'concat': concat (hidden, encoder_outputs) ***NOT YET***\n",
   "    - 'paper': concat + tanh ***NOT YET***\n",
   "    Methods that are not implemented raise at construction, instead of\n",
   "    making forward() fail later on a None score.\n",
   "    \"\"\"\n",
   "    _IMPLEMENTED = ('dot', 'general')\n",
   "    _PLANNED = ('concat', 'paper')\n",
   "\n",
   "    def __init__(self, hidden_size, method='general', device='cpu'):\n",
   "        super(Attention, self).__init__()\n",
   "        if method in self._PLANNED:\n",
   "            raise NotImplementedError(\"attention method {!r} is not implemented yet\".format(method))\n",
   "        if method not in self._IMPLEMENTED:\n",
   "            raise ValueError(\"unknown attention method {!r}, expected one of {}\".format(method, self._IMPLEMENTED))\n",
   "        self.method = method\n",
   "        self.device = device\n",
   "        self.hidden_size = hidden_size\n",
   "        if self.method == 'general':\n",
   "            self.linear = nn.Linear(hidden_size, hidden_size)\n",
   "\n",
   "    def forward(self, hiddens, enc_outputs, enc_lengths=None, return_weight=False):\n",
   "        \"\"\"\n",
   "        Inputs:\n",
   "        - hiddens(previous_hiddens): B, T_d (1 while decoding), H_d\n",
   "        - enc_outputs(enc_outputs): B, T_e, H_d\n",
   "        - enc_lengths: real (unpadded) lengths of encoder outputs, list of length B\n",
   "        - return_weight: also return the attention weights (alphas)\n",
   "        Outputs:\n",
   "        - contexts: B, T_d, H_d\n",
   "        - attns: B, T_d, T_e (only when return_weight is True)\n",
   "        \"\"\"\n",
   "        B, T_d, _ = hiddens.size()\n",
   "        T_e = enc_outputs.size(1)\n",
   "\n",
   "        score = self.get_score(hiddens, enc_outputs)  # score: B, T_d, T_e\n",
   "        if enc_lengths is not None:\n",
   "            # padded encoder positions get -inf, so softmax gives them zero weight\n",
   "            mask = self.get_mask(B, T_d, T_e, enc_lengths)  # mask: B, T_d, T_e\n",
   "            score = score.masked_fill(mask, float('-inf'))\n",
   "\n",
   "        attns = torch.softmax(score, dim=2)  # attns: B, T_d, T_e\n",
   "        contexts = attns.bmm(enc_outputs)  # contexts: B, T_d, H_d\n",
   "        if return_weight:\n",
   "            return contexts, attns\n",
   "        return contexts\n",
   "\n",
   "    def get_score(self, hid, out):\n",
   "        \"\"\"\n",
   "        Inputs:\n",
   "        - hid(previous_hiddens): B, T_d, H_d\n",
   "        - out(enc_outputs): B, T_e, H_d\n",
   "        Outputs:\n",
   "        - score: B, T_d, T_e\n",
   "        \"\"\"\n",
   "        if self.method == 'dot':\n",
   "            # bmm: (B, T_d, H_d) * (B, H_d, T_e) = (B, T_d, T_e)\n",
   "            return hid.bmm(out.transpose(1, 2))\n",
   "        if self.method == 'general':\n",
   "            # linear: (B, T_e, H_d) > (B, T_e, H_d)\n",
   "            # bmm: (B, T_d, H_d) * (B, H_d, T_e) = (B, T_d, T_e)\n",
   "            return hid.bmm(self.linear(out).transpose(1, 2))\n",
   "        raise NotImplementedError(\"attention method {!r} is not implemented\".format(self.method))\n",
   "\n",
   "    def get_mask(self, B, T_d, T_e, lengths):\n",
   "        \"\"\"Return a bool mask (B, T_d, T_e) that is True on padded encoder positions.\n",
   "\n",
   "        masked_fill needs a bool mask; uint8 masks are deprecated and rejected\n",
   "        by newer PyTorch releases.\n",
   "        \"\"\"\n",
   "        assert isinstance(lengths, list), \"lengths must be list type\"\n",
   "        mask = torch.zeros(B, T_d, T_e, dtype=torch.bool, device=self.device)\n",
   "        for i, x in enumerate(lengths):\n",
   "            if x < T_e:\n",
   "                mask[i, :, x:] = True\n",
   "        return mask"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": 14,
  "metadata": {},
  "outputs": [
   {
    "data": {
     "text/plain": [
      "(torch.Size([2, 1, 120]), torch.Size([2, 1, 7]))"
     ]
    },
    "execution_count": 14,
    "metadata": {},
    "output_type": "execute_result"
   }
  ],
  "source": [
   "attention = Attention(encoder.n_direction*HIDDEN, method='general', device=DEVICE).to(DEVICE)\n",
   "contexts, attns = attention(enc_hidden.transpose(0, 1), enc_output, lengths.tolist(), return_weight=True)\n",
   "contexts.size(), attns.size()"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": 15,
  "metadata": {
   "scrolled": true
  },
  "outputs": [
   {
    "data": {
     "image/png":
"iVBORw0KGgoAAAANSUhEUgAAAV8AAADwCAYAAACniGcOAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAEDtJREFUeJzt3X+onmd9x/H3N2ltR3/mh9BN1la2WFidsDUqZiAEIiSbRGchlWUdDEatgozqwIoTN1t1lMn+sy7uj7m0O39IoKko07gWM2oJTW0YKLWpYDeZG01C2Fptac757I/znORpes557qd5zv0jvl/lgXP//nIaPue6r+e6r7uSIElq17quC5CkX0aGryR1wPCVpA4YvpLUAcNXkjpg+EpSBwxfSeqA4StJHTB8JakDl3RdgCRdiJ07d+bEiRON9n3yySe/lWTnGpfUiOEradBOnDjBE0880WjfdevWbV7jchozfCUN3sIA56gxfCUNWoAhThBm+EoauBBmE75VtRe4DZgHHk9y33nb7wcWgI3AN5I8MFr/HeDZsV3vTnJ6tWsZvpKGLTC/cOHhW1VXAbcDu5KkqvZX1ZYkx89eKvnwaN8CDgMPjG27c5rrGb6SBi1M1ee7uaqOji3vS7Jv9PM24FDO9WEcBLYDx3mty4BTY8svVNU9wI3A4SRfmVSI4Stp8Kbo8z2RZOsK2zbx6kA9BWxZYd97gbNdEkneD2dbxPdX1Y+TPLJaIT5kIWnwkjT6THAS2DC2vHG07lWq6i7gqSSPLVNHgK8Db5t0McNX0qAlYaHhZ4IjwI5R6xVgN4v9umdV1UeAF5M8uMp53g1MHHhst4OkwZvFULMkp6tqPzBXVWeAY0meXtpeVduAu4FvVtWXR6s/neT5qvoicCVwOXBkuVbx+QxfSYMWYH5G43yTzAFz4+uq6gCwJ8n3gOtXOO7j017L8JU0eGv5kEWSW9fivIavpMHz8WJJaluzkQy9Y/hKGjTndpCkjswvLHRdwtQMX0kDN7uJddpk+EoatARmMK9O6wxfSYNnn68kdcDwlaSWTTmlZG8YvpKGLXG0gyR1wW4HSWpZwKFmktQFh5pJUgfsdpCkDhi+ktSyONpBkrphy1eSWuZDFpLUEYeaSVIHHGomSS1LwoJfuElS++zzlaQOONpBkjpg+EpSy5LY7SBJXXComSS1LMD8AMeaGb6SBs8+X0nqgH2+ktS2xJavJLUt2O0gSZ2w20GSOmD4SlLLnM9Xkrowwy/cqmovcBswDzye5L7ztt8PLAAbgW8keWC0fgdwF/Ai8NMkH5t0LcNX0uBN0fLdXFVHx5b3JdkHUFVXAbcDu5KkqvZX1ZYkx5d2TvLh0b4FHAYeGP38SeD3k7xcVfdW1XuSHFqtEMNX0qBNOdrhRJKtK2zbBhzKuZMdBLYDx5fZ9zLg1OjntwA/TPLyaPkh4AOA4Svp4jajtxdv4lygMvp5ywr73gssdUksd9ymSRdb9zoKlKQeSeP/JjgJbBhb3jha9ypVdRfwVJLHpjnufIavpEFLmn8mOALsGPXhAuxmsV/3rKr6CPBikgfHVj8LvLWqLhstvw/47qSL2e0gafBmMdQsyemq2g/MVdUZ4FiSp5e2V9U24G7gm1X15dHqTyd5vqruAR6sqheA54FvT7qe4Stp8GY11CzJHDA3vq6qDgB7knwPuH6F4x4FHp3mWr0N30nj7bpSVeuBzwK3JNnZdT3jVhqD2LWq+hJwKXAF8EySv+q2onOq6hLgn4D/S/KhrusBqKqnWLwFBjgDfDQ9mbygqn4D+NRocR74TJL/6rCkNX/IIsmta3HeXoZvk/F2HXov8DDwzq4LOd9yYxC7rWhRko8s/VxVX62qm5L8qMuaxvwl8I/Ano7rGHcyyZ1dF3G+0b+rvwE+lOTUpP1b46vjZ2qa8XatSnIQ4FyffC+Nj0HsjaraALwR+J+uawGoqj8CjgLPdF3LedZ
X1RdYvMX9WpKHui5o5O3AfwKfHzWQHk3yDx3XtKgfNwZT6Wv4TjPeTq81Pgaxc1X1m8Bfs/hH9a4kpzsuiar6HeC6JP9cVTd2XM6rJNkOUFWXAl+rqh/05K7vRuCtwO4kL1XV/VX1oyT/1nFdZICvEerrULPXNW5Oy45B7FySZ5PsZfEP6N6quq7rmoAPAjeNvrX+HPB7o2FEvZHkFRafkrq561pGfs7iHelLo+WHgVs6rOesGQ01a1Vfw3fieDu91gpjEHsjyRlgPfCGHtTyiSQfGvWtfgp4LMmXuq5rGe8CjnVdxMiTwDvGlt8J/HtHtZy1GKxp9OmTXnY7TBpv1xOvdF3AuNXGIHZYFlX1u8DHgBeAq4EDSf6jy5qWMc/iqIJeqKqvAr8ArgQeSvKTbitalORnVfXtqppjcfaunyR5pOu6YJhvsqghFi1JS9580035zJf+vtG+f7pj+5OrTKzTql62fCWpqaVuh6ExfCUNnuErSV0wfCWpfQPM3t4ONQOgqu7ouoaV9LU265peX2uzroYSFuYXGn36pNfhC/Trf/Kr9bU265peX2uzrgaWXiPkOF9JalnfgrWJmYfvr1xxZa65duLrixq5+pqNXPemG3r5W51lbevWzW6Snms2bOTXfv3G3v3OrtmwaWZ1zXpSo2s2bOJN17/5gmubdQDM8nc2S1fP8N/Yz3763Ikkb7zQ8xi+wDXXbuJP7vzErE97Ubv8issm79SBWtfPXqn1l6zvuoRlzZ+Z77qEFfWtv3PJvX/xZ89d8EkSGODEOnY7SBo8W76S1LIAC7Z8JallPl4sSd0Y4mTqhq+kgevfGN4mDF9Jg2f4SlLLnFJSkjqSecNXklpny1eS2tbDSXOaMHwlDZ7hK0ktW5pScmgMX0nDFkhPJw5ajeEraeDs85WkTgwwe5uFb1XtBW4D5oHHk9y3plVJ0hQuypZvVV0F3A7sSpKq2l9VW5IcX/vyJGl1ycU7sc424FDO/Wk5CGwHzobv6G2md8Di63UkqU2zavlOusuvqvXAZ4FbkuwcW/8d4NmxXe9Ocnq1azUJ303AqbHlU8CW8R2S7AP2Ab1955qki1VYWLjw0Q4N7/LfCzwMvPM1VSR3TnO9JuF7Erh5bHnjaJ0kdW+6iXU2V9XRseV9o8YjNLjLT3IQln2J6wtVdQ9wI3A4yVcmFdIkfI8Af15Vfzcqajfw+QbHSVI7mvf5nkiydYVtE+/yV5Lk/QC1mMr3V9WPkzyy2jETwzfJ6araD8xV1RngWJKnmxQkSWtt8Qm3mZzqgu/yR90VXwfeBqwavo3eDZ5kLskHk/xxkr+dphhJWmsZTa4z6TPBEWBHnetT2A0cfh3lvBt4YtJOPmQhadgSFmbwePGUd/mvjC9U1ReBK4HLgSNJHpt0PcNX0uDNaqhZkjlgbnxdVR0A9iSZH9tv13nHfXzaaxm+kgZtrWc1S3LrWpzX8JU0bDP8xq1Nhq+kgXNWM0nqRIY3na/hK2ngwkweL26b4Stp0HyNkCR1xPCVpNblop3PV5L6a7pZzXrD8JU0fIavJLUrwILdDvCm6zZz7yfvmPVpL9gbLunv35mfv/xy1yUs639feqnrEpb1q9de23UJ6pOL+B1uktRjPuEmSZ0wfCWpA4avJLUsgcxgMvW2Gb6SBm+ADV/DV9LQ+YWbJHXC8JWktvl4sSS1L/iQhSR1IMTJ1CWpZXY7SFI3Bpi9hq+k4bPPV5Ja5jvcJKkL9vlKUhcyyFfHr2uyU1Wtr6rPVdW/rHVBkjStLKTRp08ahS/wXuBhbClL6pvFTt9mnx5pFKZJDgJU1dpWI0lTWsreoZlJS7aq7gDuALj++utncUpJamyIX7g17XZYVZJ9SbYm2bp58+ZZnFKSmklYmF9o9OkT+3AlDd4QW77Thu8ra1KFJL1Os3zIoqr2ArcB88DjSe47b/t64LPALUl2jq3fAdwFvAj8NMnHJl1rqm6HJLum2V+S2pCk0Wc1VXU
VcDvwviR/CPx2VW05b7fXjPyqxZEInwQ+kGQP8POqes+kmmfS5ytJ3Wk4zGwxfDdX1dGxzx1jJ9oGHMq5lD4IbH/VlZKDSY6cV8BbgB8meXm0/ND5xy3HPl9JwxZI8+/STiTZusK2TcCpseVTwPkt36bHbZp0kOErafBm9HjxSeDmseWNo3VNjtsw7XF2O0gatKUv3C60zxc4Auyoc0+T7QYONyjhWeCtVXXZaPl9wHcnHWTLV9KwzWhWsySnq2o/MFdVZ4BjSZ5eYfdXxo6br6p7gAer6gXgeeDbk65n+EoauNlNmpNkDpgbX1dVB4A9SebH9tt13nGPAo9Ocy3DV9LwreFDFkluXYvzGr6SBi9c/E+4SVKvJGFhYX7yjj1j+EoavF+GuR0kqXcMX0nqgOErSS1bfICiX3P1NmH4Sho8wxf4/ve/f+KySy99bkan2wycmNG5Zq2vtVnX9Ppa2y9DXTfM4iR2OwBJ3jirc1XV0VVmIOpUX2uzrun1tTbras7wlaTW2ecrSa3LjCbWaVvfw3df1wWsoq+1Wdf0+lqbdTU0xPCtIRYtSUuuvnpT3vH2P2i0778+sv/JvvRX973lK0kTBft8Jal1Q7yDN3wlDZpfuElSJxq9n613DF9Jg+d8vpLUAVu+ktS2xU7frquYmuEradCC73CTpE44t4Mktc7RDpLUiYUFW76S1KrF79sMX0lqmd0OktQNw1eS2udQM0nqgN0OktSyJM7tIEldsOUrSR0wfCWpA4avJLUuMKOHLKpqL3AbMA88nuS+Jtur6ingyGi3M8BHM+EvguEradASWJhB+FbVVcDtwK4kqar9VbUlyfEG208muXOa66274IolqWNJGn2AzVV1dOxzx9hptgGHxlqsB4HtDbevr6ovVNWDVfX+JjXb8pU0cJlmbocTSbausG0TcGps+RSwpcn2JNsBqupS4GtV9YOlFvNKbPlKGrwpWr6rOQlsGFveOFrXdDtJXgEOATdPupjhK2nwZhS+R4AdVVWj5d3A4Sm2L3kXcGzSxex2kDRoi1NKXvhQsySnq2o/MFdVZ4BjSZ5usr2qvgr8ArgSeCjJTyZdr4Y4Pk6Sllx++RW54YbfarTvM88cfXKVPt9lVdUBYE+SmT7DbMtX0uCtZSMyya1rcV7DV9LgDfEO3vCVNHC+yUKSWuc73CSpI7Z8Jal1Ib46XpLa5zvcJKkD9vlKUstm9YRb2wxfSQPnUDNJ6sSCX7hJUvvs85Wkti12+nZdxdQMX0mDFhxqJkmd8As3SeqAfb6S1Lo42kGS2uZDFpLUEcNXkloXsM9XktrnUDNJ6oDdDpLUsiQsLMz0re6tMHwlDZ4tX0nqgOErSR0wfCWpC4avJLUrCQvxCzdJap3dDpLUAcNXklrnCzQlqRPO5ytJLXNKSUnqRGz5SlIXDF9J6sCsuh2qai9wGzAPPJ7kvibbJx23HMNX0tB9K8nmhvteXlVHx5b3JdkHUFVXAbcDu5KkqvZX1ZYkx1fbDvz3asetxPCVNGhJds7oVNuAQznXjD4IbAeOT9j+3ITjlrVuRkVL0tBtAk6NLZ8arZu0fdJxyzJ8JWnRSWDD2PLG0bpJ2ycdtyzDV5IWHQF2VFWNlncDhxtsn3TcsuzzlSQgyemq2g/MVdUZ4FiSp5tsX+24ldQQnwyRpLZU1QFgTzLbeSsNX0nqgH2+ktQBw1eSOmD4SlIHDF9J6oDhK0kd+H+2vUjYm60KNQAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ],
  "source": [
   "fig = plt.figure()\n",
   "ax = fig.add_subplot(111)\n",
   "cax = ax.matshow(attns.detach().squeeze(1).numpy(), cmap='bone')\n",
   "fig.colorbar(cax)\n",
   "plt.show()"
  ]
 },
 {
  "cell_type": "markdown",
  "metadata": {},
  "source": [
   "### Decoder model"
  ]
 },
 {
  "cell_type": "markdown",
  "metadata": {},
  "source": [
   "![decoder](./pics/translation_decoder.png)"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": 16,
  "metadata": {},
  "outputs": [],
  "source": [
   "class Decoder(nn.Module):\n",
   "    \"\"\"Attention-based GRU decoder: s{i} = f([y{i-1}; c{i}], s{i-1}).\n",
   "\n",
   "    Args:\n",
   "        vocab_size: target vocabulary size.\n",
   "        embed_size: target embedding size (M_d).\n",
   "        hidden_size: decoder hidden size (H_d); must equal the encoder output size,\n",
   "            because attention multiplies decoder states with encoder outputs.\n",
   "        n_layers: number of GRU layers.\n",
   "        sos_idx: index of the start token fed at the first step.\n",
   "        drop_rate: dropout applied to every input embedding.\n",
   "        layernorm: apply LayerNorm to input embeddings.\n",
   "        method: attention scoring method ('dot' or 'general').\n",
   "        teacher_force: feed ground-truth targets while training.\n",
   "        device: device for tensors created inside the module.\n",
   "        return_w: forward() also returns attention weights.\n",
   "    \"\"\"\n",
   "    def __init__(self, vocab_size, embed_size, hidden_size, n_layers=1, sos_idx=2, drop_rate=0.0, layernorm=False, method='general', teacher_force=False, device='cpu', return_w=False):\n",
   "        super(Decoder, self).__init__()\n",
   "        self.vocab_size = vocab_size\n",
   "        self.n_layers = n_layers\n",
   "        self.hidden_size = hidden_size\n",
   "        self.device = device\n",
   "        self.sos_idx = sos_idx\n",
   "        self.return_w = return_w\n",
   "        self.teacher_force = teacher_force\n",
   "        self.layernorm = layernorm\n",
   "\n",
   "        self.embedding = nn.Embedding(vocab_size, embed_size)\n",
   "        self.dropout = nn.Dropout(drop_rate)\n",
   "        self.attention = Attention(hidden_size, method=method, device=device)\n",
   "        self.gru = nn.GRU(embed_size+hidden_size, hidden_size, n_layers, bidirectional=False,\n",
   "                          batch_first=True)\n",
   "        self.linear = nn.Linear(2*hidden_size, vocab_size)\n",
   "        if layernorm:\n",
   "            self.l_norm = nn.LayerNorm(embed_size)\n",
   "\n",
   "    def start_token(self, batch_size):\n",
   "        \"\"\"Start-token ids for every sequence in the batch: (B, 1).\"\"\"\n",
   "        sos = torch.LongTensor([self.sos_idx]*batch_size).unsqueeze(1).to(self.device)\n",
   "        return sos\n",
   "\n",
   "    def init_hiddens(self, batch_size):\n",
   "        \"\"\"Zero initial hidden state of shape (n_layers, B, H_d), as nn.GRU expects.\n",
   "\n",
   "        With batch_first=True, nn.GRU still takes h_0 as (num_layers, batch, hidden).\n",
   "        \"\"\"\n",
   "        return torch.zeros(self.n_layers, batch_size, self.hidden_size).to(self.device)\n",
   "\n",
   "    def forward(self, hiddens, enc_output, enc_lengths=None, max_len=None, targets=None,\n",
   "                is_eval=False, is_test=False, stop_idx=3):\n",
   "        \"\"\"\n",
   "        * H_d: decoder hidden_size = encoder_gru_directions * H\n",
   "        * M_d: decoder embedding_size\n",
   "        Inputs:\n",
   "        - hiddens: last encoder hidden at time 0 = 1, B, H_d\n",
   "        - enc_output: encoder output = B, T_e, H_d\n",
   "        - enc_lengths: encoder lengths = B (list), used to mask padded positions\n",
   "        - max_len: max length of target = T_d, including the start token;\n",
   "          defaults to targets.size(1) when targets are given\n",
   "        - targets: B, T_d target ids (required for teacher forcing in training)\n",
   "        - is_eval: feed back predictions instead of targets\n",
   "        - is_test: greedy decoding; stops once every sequence predicts stop_idx\n",
   "        Outputs:\n",
   "        - scores: results of all predictions = B*(T_d-1), vocab_size\n",
   "          (fewer steps if test decoding stops early)\n",
   "        - attn_weights: attention weight for all batches = B, T_d-1, T_e (if return_w)\n",
   "        \"\"\"\n",
   "        if is_test:\n",
   "            is_eval = True\n",
   "        if max_len is None:\n",
   "            if targets is None:\n",
   "                raise ValueError(\"max_len must be given when targets is None\")\n",
   "            max_len = targets.size(1)\n",
   "        inputs = self.start_token(hiddens.size(1))  # (B, 1)\n",
   "        inputs = self.embedding(inputs)  # (B, 1, M_d)\n",
   "        if self.layernorm:\n",
   "            inputs = self.l_norm(inputs)\n",
   "        inputs = self.dropout(inputs)\n",
   "        # match layer size: (1, B, H_d) > (n_layers, B, H_d)\n",
   "        if hiddens.size(0) != self.n_layers:\n",
   "            hiddens = hiddens.repeat(self.n_layers, 1, 1)\n",
   "        # collect scores for every decoded step\n",
   "        scores = []\n",
   "        attn_weights = []\n",
   "        for i in range(1, max_len):\n",
   "            # contexts[c{i}] = alpha(hiddens[s{i-1}], encoder_output[h])\n",
   "            # select last hidden: (1, B, H_d) > transpose: (B, 1, H_d) > attention\n",
   "            contexts = self.attention(hiddens[-1:, :].transpose(0, 1), enc_output, enc_lengths,\n",
   "                                      return_weight=self.return_w)\n",
   "\n",
   "            if self.return_w:\n",
   "                attns = contexts[1]  # attns: (B, seq_len=1, T_e)\n",
   "                contexts = contexts[0]  # contexts: (B, seq_len=1, H_d)\n",
   "                attn_weights.append(attns)\n",
   "\n",
   "            # gru_inputs = concat(embedded_token[y{i-1}], contexts[c{i}]): (B, seq_len=1, M_d+H_d)\n",
   "            gru_inputs = torch.cat((inputs, contexts), 2)\n",
   "\n",
   "            # gru: s{i} = f(gru_inputs, s{i-1})\n",
   "            # (B, 1, M_d+H_d) > (n_layers, B, H_d)\n",
   "            _, hiddens = self.gru(gru_inputs, hiddens)\n",
   "\n",
   "            # scores = g(s{i}, c{i})\n",
   "            # select last hidden: (1, B, H_d) > transpose: (B, 1, H_d) > concat: (B, 1, H_d + H_d) >\n",
   "            # output linear : (B, seq_len=1, vocab_size)\n",
   "            score = self.linear(torch.cat((hiddens[-1:, :].transpose(0, 1), contexts), 2))\n",
   "            scores.append(score)\n",
   "\n",
   "            if self.teacher_force and not is_eval:\n",
   "                selected_targets = targets[:, i].unsqueeze(1)\n",
   "            else:\n",
   "                selected_targets = None\n",
   "\n",
   "            inputs, stop_decode = self.decode(is_tf=self.teacher_force,\n",
   "                                              is_eval=is_eval,\n",
   "                                              is_test=is_test,\n",
   "                                              score=score,\n",
   "                                              targets=selected_targets,\n",
   "                                              stop_idx=stop_idx)\n",
   "            if stop_decode:\n",
   "                break\n",
   "\n",
   "        scores = torch.cat(scores, 1).view(-1, self.vocab_size)  # (B, T_d-1, vocab_size) > (B*(T_d-1), vocab_size)\n",
   "        if self.return_w:\n",
   "            return scores, torch.cat(attn_weights, 1)  # (B, T_d-1, T_e)\n",
   "        return scores\n",
   "\n",
   "    def decode(self, is_tf, is_eval, is_test, score, targets, stop_idx):\n",
   "        \"\"\"Choose and embed the next decoder input.\n",
   "\n",
   "        - for validation: if is_tf, set 'is_eval' True, else False\n",
   "        - for test evaluation: set 'is_tf' False and set 'is_eval' True\n",
   "        Teacher forcing (training): embed the given targets (B, 1).\n",
   "        Otherwise: embed the greedy prediction argmax(score) (B, 1).\n",
   "        In test mode, stop_decode is True once every sequence in the batch\n",
   "        predicted stop_idx at this step (works for any batch size).\n",
   "        Returns: (inputs: B, 1, M_d, stop_decode: bool)\n",
   "        \"\"\"\n",
   "        stop_decode = False\n",
   "        if is_test:\n",
   "            # test: greedy prediction\n",
   "            preds = score.max(2)[1]\n",
   "            if (preds == stop_idx).all().item():\n",
   "                stop_decode = True\n",
   "            inputs = self.embedding(preds)\n",
   "        else:\n",
   "            # train & valid\n",
   "            if is_tf and not is_eval:\n",
   "                assert targets is not None, \"target must not be None in teacher force mode\"\n",
   "                inputs = self.embedding(targets)\n",
   "            else:\n",
   "                preds = score.max(2)[1]\n",
   "                inputs = self.embedding(preds)\n",
   "\n",
   "        if self.layernorm:\n",
   "            inputs = self.l_norm(inputs)\n",
   "        inputs = self.dropout(inputs)\n",
   "        return inputs, stop_decode"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": 17,
  "metadata": {},
  "outputs": [],
  "source": [
   "decoder = Decoder(len(trg_vocab), EMBED, encoder.n_direction*HIDDEN, n_layers=DEC_N_LAYER,\n",
   "                  drop_rate=DROP_RATE, method=METHOD, layernorm=L_NORM, \n",
   "                  sos_idx=trg_vocab[''], teacher_force=TF, \n",
   "                  return_w=RETURN_W, device=DEVICE)"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": 18,
  "metadata": {},
  "outputs": [
   {
    "data": {
     "text/plain": [
      "(torch.Size([8, 7]), torch.Size([2, 4, 7]))"
     ]
    },
    "execution_count": 18,
    "metadata": {},
    "output_type": "execute_result"
   }
  ],
  "source": [ "output, attns 
= decoder(hiddens=enc_hidden, enc_output=enc_output, max_len=targets.size(1),\n", " targets=targets)\n", "output.size(), attns.size()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Train" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "LR = 0.01\n", "LAMBDA = 0.00001\n", "DECLR = 5.0\n", "STEP = 5" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "loss_function = nn.CrossEntropyLoss(ignore_index=trg_vocab[''])\n", "enc_optimizer = optim.Adam(encoder.parameters(), \n", " lr=LR, \n", " weight_decay=LAMBDA)\n", "dec_optimizer = optim.Adam(decoder.parameters(), \n", " lr=LR * DECLR, \n", " weight_decay=LAMBDA)" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "STEP: 0, loss: 1.9032\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWwAAADwCAYAAAAkTF41AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAFl9JREFUeJzt3X+sHfV95vH3Y5OYGlAW2wkUYoekdVcov2NrqyVSd5FcCUcNoWEXs0tJLSQM7S5bhW4jQgE5JT8oqzZdIkpr1oXYRKa0tEA2bFKytUIgrIsB76Lll52GEJpG+NqBdQw19j3P/nHmmPHNvT5z7Ln3znieFzrCM/Odmc+5wMPc73znO7JNREQ035zZLiAiIqpJYEdEtEQCOyKiJRLYEREtkcCOiGiJBHZEREsksCMiWiKBHRHREgnsiIiWOG62C4iImGnnnHOOx8bGKrV97LHHvmH7nGkuqZIEdkR0ztjYGI8++miltnPmzFk0zeVUlsCOiE7qtXAepQR2RHSOgTZOfJfAjogOMiaBHRHRfIbxXgI7IqLxTH192JIuAlYB48Ajtm+csP0WoAcsAL5m+45i/TeBHaWmV9l++XDnSmBHRCfV0Yct6STgYmClbUvaKGmp7e2l8/xG0VbAg8AdpW2Xj3K+PDgTEZ1ku9IHWCRpa+mzpnSYs4AH/Eb63wucPcUp5wG7S8s/kXR9EfKXVqk5V9gR0Tm2R+kSGbO9fIptCzk0hHcDS6do+1ngYHeJ7fPg4JX3LZK+a/tvD1dIrrAjopNGuMI+nF3AyaXlBcW6Q0j6JPCE7YcnqcPAV4H3DTtZAjsiOsfAuF3pM8QWYEVxlQxwLv1+6oMk/Saw1/ZXDnOcXwKGPnqZLpGI6KQ6bjraflnSRmCTpAPANtvPDLZLOgu4Crhf0p8Uq6+1vVPSHwAnAscDWya7+p4ogR0RnVTXsD7bm4BN5XWS7gYusP0dYMkU+/32qO
dKYEdE91Trnz6Kw/v86ThuAjsiOidziUREtMh4rzfbJYysVaNEJC2RtFfS0tK6fyfpYkmHuwM7KyTdKunU2a6jDsfSd4kYTP5U5a8maVVg06/3f9IfgD4wF5hr+6LZKemw3sSx81vMsfRdouNs6FX8NEkb/wN8CXhU0oW27xyslHS/7Y9IWgR8DniVfsD/LrAcuAT4MfCK7eumu0hJq4F/AayVtAG4FNgL/DPgDtv/fZrP/w7gz4DHgUXAd4APAvuB/bb/s6TPAW+hP7Tov9l+SNJ7gU8BY/T/Z/hJ+nMlDL7LTbb/T001/jHwh7Z3SFpVnG8lE35Okq4Fvmn7kWK/+21/pI4aorvShz1zbgXulnR/ad2bi7//PvD7tp+Dg499XgessN2T9AVJy2w/Np0F2r5d0r8G1hY1fdb205LmAF+T9LDtH09jCQIO2P4dAEkvAMuK8Z9fLmYY22P7dyUdR/9Jq5VFrf/W9l5JlwPnlb+L7RdrrHETcCH935g+Tv/fx2sm/pwofosq7ffmnzpSxIgS2DOkCN7fAz4DTAzexYOwLrwVOAX4fPEw0ikc+ijpTDjF9tNwsPZtwM8BW6f5vD8o/Xm77Z3Fn/cApwMfkHRDsW5f8fefB64tflZvAYYO5j8KDwG/I+kU+r/9/PwUP6eJ5k6yLqKyOqdXnUmtDGwA209I+jVgGfC/S5v+UdJ7bT9ZLI8BPwKutj3Tt4XH6f+M/1HSmaUrx/cDX5jhWib+22ngf9n+ownr/x74jO3XJqwffJf6CupPR7mN/hX2BmDNFD+nV4Cfhf6N58GfI46Y3cpRIm0L7PHiM7AWeJJ+P+3rxbpPAf9F0p6i7dXAfwX+XNIu+t0E/3GG6v0W8EfAXwOflvQq/avWm23/v2k+98Sf1f4J2/4K+C1Jf0b/6vqhYq6DzwAbJY3Rv9F4he1XB99F0nrbX62xzo3A14E1wHeBGyb+nCTdCdwq6V8V36OWPvTotjZ2iaiNRcexQ9LPAb9m+zOzXUt0x3ve/37/1Te+XqntP//Z0x47zPSqM6ptV9hxDJH0b4BfBX5rtmuJ7mnakL0qEtgxa2z/JfCXs11HdFMbexcS2BHRSQnsiIgWcEtHibTt0fShJrwgs1GaWlvqGk1T64Lm1tbEump6RdiMOuYCm/7wsKZqam2pazRNrQuaW1uj6ho8OFPl0yTpEomITmraTHxVNCKwFy5a5CVLJn2LzsgWL17MBz/0oVr+SdT969DbFy/mAx/8YC0HnTOnvl+OFi9ezIeWLaulLg1vUtmSJUtYVlNdjz/+eB2HOUhSY/9rb2ptNdY1ZvutR3uQDOs7QkuWLOFbD0/nlBVHZt/+/cMbzZIT5s2b7RImddzcZk7zMe9Nb5rtEqI+3z/aA9im18Kbjo0I7IiImda0/ukqEtgR0UlNGwFSRQI7IjopgR0R0QJu4JC9KhLYEdFJGdYXEdECBsZbOK4vgR0RnZQ+7IiIlkgfdkREGzRwYqcqEtgR0TkmXSIREa2RLpGIiJZIYEdEtMBgPuy2SWBHRPd05aajpLfZfmlIm1Nt/+jIy4qImF51XWFLughYBYwDj9i+ccL2W4AesAD4mu07ivUrgE8Ce4EXbV857FwjzYIv6TTgttLynZIWTtL0FklnjHLsiIiZMhglcrTvdJR0EnAx8DHbvwq8V9LSQ85l/4bt/wD8e+CyYj8BnwY+bvsC4FVJvzys7sqBLel0YD1webG8AJhre9ckzS8Dbpb0zqrHj4iYSeO9XqUPsEjS1tKn/H7Ks4AH/Eay3wucPcUp5wG7iz//AvCU7X3F8j2H2e+gSl0iRVjfClxq+8Vi9YXAXZLmA38M/Bh4xfZa2y9JWg3cJuk/2f77SY65huLFnIsXL65SRkRETTzK5E9jtpdPsW0hb4QwxZ+XTtH2s8Cgu2Sy/SbrrTjE0CvsostjYlgDfBS4D3h7cZwrba8dbLS9E1gN3FR0pRzC9jrby20vX7ho0bAyIi
JqY1f/DLELOLm0vKBYdwhJnwSesP3wKPtNVKVLZB79zvRXSidfCrxge5/t5+j3a98kadWEffcCrwPzK5wnImLG9Io5sYd9htgCrCj6pAHOBR4sN5D0m8Be218prd4BvEfS4OWsHwO+NexkQ7tEbP9Q0tXABkm/bnsP8AngjlKbzcBmSf9D0v229xRdJV8GbrC9Y9h5IiJmUh3D+my/LGkjsEnSAWCb7WcG2yWdBVwF3C/pT4rV19reKel64CuSfgLsBP5m2Pkq9WHbflLSNfRDezXwYeC6oqAz6d/tfA14vgjrE4ANwOdtP1bpm0dEzJA6H5yxvQnYVF4n6W7gAtvfAZZMsd9mYPMo56o8Dtv2U5Kuon8389uDu6K2n6Z/xV22Hrje9rZRiomImBE2vf4IkGk6vM+fjuOO9OCM7WclXWL7e0OaXmt7+1HUFRExvbrwpGOFsCZhHRFN57wiLCKiHVp4gZ3Ajoju6Y+xbl9iJ7AjopMS2BERrWB649M3SmS6JLAjonPSJRIR0SIJ7IiItkhgR0S0QwvzOoEdER3k3HSMiGiFwSvC2qYRgT3e6/HKq6/Odhk/5aSf+ZnZLmFK8+fNG95oFrx+4MBslxBRSQI7IqIlEtgREW1gQyZ/iohoh1xhR0S0gIFerrAjIlogj6ZHRLRHXmAQEdEKzhV2RERbJLAjIlog06tGRLSIxxPYERGtkCvsiIg2cG46RkS0RgI7IqIFMr1qRERbGJwXGEREtEE7+7DnjLqDpLdVaHPqkZUTETEz+mOxh3+aZKTAlnQacFtp+U5JCydpeoukM46utIiI6eNipMiwT5NUDmxJpwPrgcuL5QXAXNu7Jml+GXCzpHfWUmVERI3s/uRPVT5NUimwi7C+FbjU9g+K1RcCd0maL+l2SV+UtBbA9kvAauBLkt41xTHXSNoqaevuXZNlfkTE9GnjFfbQm45Fl8cgrP+htOmjwHnAO+gH/5UufTvbOyWtBm6XtMb2D8vHtb0OWAfwvg98oFk/lYg4xpler55RIpIuAlYB48Ajtm+csH0u8HvAMtvnlNZ/E9hRanqV7ZcPd64qo0TmFYW8UjrRUuAF2/uA5yTdBtwk6SHbf17ady/wOjC/wnkiImZGTZM/SToJuBhYaduSNkpaant7qdmvAPcBv/hTZdiXj3K+oV0ixZXx1cCGojiATwB3lNpstn0FsHrQRtJ8YANwg+0dREQ0Sc/VPrBo0H1bfNaUjnIW8ECpd+Fe4OzyaWzfa3vLJBX8RNL1RchfWqXkSuOwbT8p6Rr6ob0a+DBwHYCkM4FPA68Bz9veI+kE+mH9eduPVTlHRMRM6T/pWLn5mO3lU2xbCOwuLe8GllaqwT4PQJLoj6z7ru2/Pdw+lR+csf2UpKuAe4BvD/6PYvtp+lfcZeuB621vq3r8iIiZVNMNxV3Au0vLC4p1o9RhSV8F3gccNrBHGodt+1ngEtvXDWl6bcI6IhrLpjfeq/QZYguworhKBjgXePAIKvol4NFhjUZ+NN329yq02T6sTUTEbKrjCtv2y5I2ApskHQC22X5miub7ywuS/gA4ETge2GL74WHny1wiEdE5dc7WZ3sTsKm8TtLdwAW2x0vtVk7Y77dHPVcCOyK6Z8S7jiMf3j5/Oo6bwI6IDmreU4xVJLAjopPcvumwE9gR0UGmtkfTZ1ICOyI6J68Ii4hokQR2REQrNG+u6yoS2BHRPTXN1jfTEtgR0U0J7IiI5jPQS5fIkTlu7lwWnnjibJfxU0xz/4Hu+ad/mu0SJnVgfHx4o4jZVrzTsW0aEdgRETMrTzpGRLRGAjsioiUS2BERLWCDh7+coHES2BHRSS28wE5gR0QX5aZjRERrJLAjItogj6ZHRLSDyYMzEREtYZwXGEREtEC6RCIi2qOFeZ3AjohuSh92REQL5J2OERFtkT7siIi2ML0WjhKZM+oOkt5Woc2pR1ZORMTMcM+VPk0yUmBLOg24rbR8p6SFkzS9RdIZR1
daRMQ06XdiV/s0SOXAlnQ6sB64vFheAMy1vWuS5pcBN0t6Zy1VRkTUqKV5XS2wi7C+FbjU9g+K1RcCd0maL+l2SV+UtBbA9kvAauBLkt41xTHXSNoqaevYzp1H+z0iIkZiu9KnSYbedCy6PAZh/Q+lTR8FzgPeQT/4r3Tp29neKWk1cLukNbZ/WD6u7XXAOoAPLVvWrJ9KRBzbbHotfIFBlSvsecA48MpghaSlwAu299l+jn6/9k2SVk3Ydy/wOjC/pnojImrRxivsoYFdXBlfDWyQdFKx+hPAHaU2m21fAawetJE0H9gA3GB7R+2VR0QcocGDM8dcYAPYfhK4hn5ovwX4MPAQgKQzJW2Q9KfA87b3SDoB2Eg/rP9ummqPiDhidQW2pIsk3SfpryV9apLtcyV9TtLXJ6xfIelrku6S9IdVaq784IztpyRdBdwDfHvQX237afpX3GXrgettb6t6/IiImVPPEJCiR+FiYKVtS9ooaant7aVmvwLcB/xiaT8BnwY+YnufpM9K+mXbDxzufCM96Wj7WUmX2P7ekKbXTig4IqI5DK5+z3GRpK2l5XXFoAmAs4AHSgMu7gXOBg7mn+17AfoZfdAvAE/Z3lcs3wN8HKgvsIuTDwtrEtYR0XQjPJo+Znv5FNsWArtLy7uBpRWOOdl+kz2EeIjMJRIRnVPjbH27gHeXlhcU66rsd/Ko+408l0hEROu5tpuOW4AVeqO/41zgwQoV7ADeI2lesfwx4FvDdsoVdkR0UD0TO9l+WdJGYJOkA8A2289M0Xx/ab9xSdcDX5H0E2An8DfDzpfAjohuqmmMte1NwKbyOkl3AxfYHi+1Wzlhv83A5lHOlcCOiE4y0/dQjO3zp+O4CeyI6Bzb9Hrjwxs2TAI7IjqpaY+dV5HAjohOSmBHRLREAjsiogX6Y6zbNx92AvswXnt9//BGs+S4uXNnu4RJnTBv3vBGEQ2QwI6IaIl0iUREtEQCOyKiFdKHHRHRCnausCMiWiOBHRHRCsbVX2DQGAnsiOgkk8COiGiFdIlERLRAbjpGRLRGpdd/NU4COyI6KfNhR0S0RK6wIyLaoN+JPdtVjCyBHRGdY6b3nY7TJYEdEZ2UuUQiIlqhI6NEJL3N9ktD2pxq+0dHXlZExPTqtfDR9DmjNJZ0GnBbaflOSQsnaXqLpDOOrrSIiOnRv+fYq/RpksqBLel0YD1webG8AJhre9ckzS8Dbpb0zlqqjIiolYv3Og7/NEmlwC7C+lbgUts/KFZfCNwlab6k2yV9UdJagKLLZDXwJUnvmuKYayRtlbR1bOfOo/0eERGjGQztG/ZpkKGBXXR5DML6xdKmjwL3AW8vjnOl7bWDjbZ30g/tm4qulEPYXmd7ue3li9761qP6EhERo3LFv5qkyhX2PGAceGWwQtJS4AXb+2w/R79f+yZJqybsuxd4HZhfU70REbU4JrtEbP8QuBrYIOmkYvUngDtKbTbbvgJYPWgjaT6wAbjB9o7aK4+IOEK26fXGK32apFIftu0ngWvoh/ZbgA8DDwFIOlPSBkl/Cjxve4+kE4CN9MP676ap9oiII9bGK+zK47BtPyXpKuAe4Nsuvontp+lfcZetB663va22SiMiatS0MK5ipAdnbD8r6RLb3xvS9Frb24+iroiIaXXMBzZAhbAmYR0RzWZo2EMxVWQukYjoHBt6NQW2pIuAVfRH0z1i+8Yq2yU9AWwpmh0ArvCQy/4EdkR00ghdIoskbS0tr7O9DqAYFXcxsNK2JW2UtHTQyzBk+y7bl49ScwI7IjrIo8wTMmZ7+RTbzgIeKF0Z3wucDWyvsH2upC8AS4C/sH3PsEIS2BHRSTXddFwI7C4t7waWVtlu+2wASW8C/kLS/x12/2+k2foiIo4VNY3D3gWcXFpeUKyruh3b+4EHgHcPO1kCOyI6pz+vUy2BvQVYIUnF8rnAgyNsH/iXwNDnVtIlEhEdZOyjf+zc9suSNgKbJB0Attl+psp2SV
8GXgNOBO6x/fyw8yWwI6KT6npwxvYmYFN5naS7gQtsj0+2vdjv10c9VwI7IjppOp90tH3+dBw3gR0RHdS8iZ2qSGBHROcM3unYNo0I7Ccef3zshOOP/35Nh1sEjNV0rLo1tbbUNZqm1gXNra3Out5Rx0FyhX2EbNf2jjBJWw/zVNKsamptqWs0Ta0Lmltb8+oy7uUKOyKiFZr2vsYqEtgR0Unpw26GdbNdwGE0tbbUNZqm1gXNra1RdQ2edGwbtbHoiIij8eY3H+9TTjmjUtsXX3z2sab0vx+LV9gREUP1ctMxIqId0ocdEdEG/U7s2a5iZAnsiOgck2F9ERGt0cYBFwnsiOik9GFHRLSCM0okIqIN2vrgTAI7IjopgR0R0QqG9GFHRLRDhvVFRLREukQiIlrANr3e+GyXMbIEdkR0Uq6wIyJaIoEdEdESCeyIiLZIYEdENJ9tes5Nx4iIVkiXSERESySwIyJawQnsiIi2yHzYEREtkOlVIyJaw7nCjohoiwR2RERL1NUlIukiYBUwDjxi+8Yq24ftN5kEdkR00TdsL6rY9nhJW0vL62yvA5B0EnAxsNK2JW2UtNT29sNtB350uP2mksCOiM6xfU5NhzoLeMBvXK7fC5wNbB+y/ftD9pvUnJqKjojoooXA7tLy7mLdsO3D9ptUAjsi4sjtAk4uLS8o1g3bPmy/SSWwIyKO3BZghSQVy+cCD1bYPmy/SaUPOyLiCNl+WdJGYJOkA8A2289U2X64/aaiNj7tExHRZJLuBi6w653DNYEdEdES6cOOiGiJBHZEREsksCMiWiKBHRHREgnsiIiW+P/XEhOQlMa3/wAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "STEP: 1, loss: 0.7809\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXgAAADwCAYAAAAKCX+nAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAFj1JREFUeJzt3X2wXXV97/H3h8jDhBIgCSbILaKSWkapbUl7p7RzK3eYuVCVYlFCG6MZRyN1tFanWEvBi08t0jva0qptEIuJND7Rik+jQ31ClAKhpOVWHlsVqCg5iVgELiZnf+4fe21Y7JyHtfdeZ++VdT4vZk/2Wuu3f/u7T8J3/853rfX7yTYREdE+B0w6gIiIWBhJ8BERLZUEHxHRUknwEREtlQQfEdFSSfARES2VBB8R0VJJ8BERLZUEHxHRUk+ZdAAREfur0047zVNTU5Xa3nzzzV+0fdoCh/QkSfAREUOampripptuqtT2gAMOWLnA4ewjCT4iYgSdBs/nlQQfETEkA02esDEJPiJiaMYkwUdEtI9hupMEHxHROiY1+IiI1koNPiKipZLgIyJayHZKNBERbZURfERECxmYToKPiGinjOAjIloqNfiIiDayM4KPiGijzEUTEdFi053OpEOY1aJd0UnSsZIelrSmtO+3JW2QdOUkY5uJpMskrZ50HKNqy+eI6HLl/yZh0SZ4up/9S8A7S/uWAEtsr59MSHM6kHb8xtWWzxGBDZ2Kj0lY7P+jPQDcJOkc2x/t7ZT0edu/IWkl8C7gEbpfCH8MrAVeCfwQ+JHtty50kJI2Ar8MXCRpC/Bq4GHgCOAjtj+7gO/9dOBDwD8DK4FvAr8A7AH22P4DSe8CDgd+Cvig7esknQi8GZii+8X5RmBD6XNcavtfa4rx/cB7bN8taV3xfqfT9zOSdCHwj7avL173edu/UUcMsXilBt9slwFXSfp8ad9BxZ/vBt5t+04ASQLeCpxquyPpTyWdZPvmhQzQ9hWSng9cVMT0Ttu3SToA+Jykb9j+4QK9vYC9ts8DkHQPcJLtnZI+LGk98JDtP5b0FOAzdJPru4GX2n5Y0rnAmeXPYfu+GmPcBpxD97ex36L77/qC/p8RxW9opdcdtE9PEQNKgm+wIlG/HXgb0J+of7qX3AtHAauAP+nmelYBR44l0Cessn0bPB77DuBZwPYFfM97S8/vsr2zeP4QcAzw85IuLvY9Vvx5PHBh8XM6HPjGAsZ3HXCepFV0f7M6fpafUb8lM+yLqCzTBe8HbN8i6WXAScC/lA7dL+lE27cW21PA94HzbY/71Pk03b+v+yWdUBqdPg/40zHG0f+v2cA/2f7zvv3/AbzN9qN9+3ufo76AbBdJ/J3AFmDTLD+jHwFHQ/cke+95xNDsRl9Fs5gT/HTx6LkIuJVurfknxb43A38m6aGi7fnAXwAfk7SLbunidWOK92vAnwP/APyRpEfojozfZ/u/FvB9+39Oe/qO/T3wBkkfojt6v872lXR/I9oqaYruidXX236k9zkkXW77MzXGuRX4ArAJ+Hfg4v6fkaSPApdJ+vXic9RyDiAWtyaXaNTk4CKqkvQs4GW23zbpWGLxeO7znue//+IXKrV99tFPu9n22gUO6UkW8wg+WkLSS4AXA2+YdCyx+NR5CWRx0cI6ur8dX2/7kr7jtwA3FJt76f5mPGsESfCx37P9SeCTk44jFqe6qiCSDqN7KfHpxXmlrZLW2L6r1GyX7XOr9rmYb3SKiBiZiwnH5nsAKyVtLz029XV1MnBNaUR+NXBKX5slxeXZV0o6c77YMoKPiBiSB7uK
ZmqeGvwKYHdpezewptzA9ikAkg4EPiHp3/pG+E+SEfwcZviGbYTENbimxpa4BtPEuAYYwc9nF0++r2Z5sW+m99wDXAM8Z64Ok+Dn1rh/TIXENbimxpa4BtOouHo3OlV5VHADcGpxxzzAGcC1c7T/FWDHXB2mRBMRMYK6Zoq0/aCkrcA2SXuBHbZvL7eR9GHgUbrzPn3K9nfm6rN1CX7ZEUf4qKPruUFx5erVPOuEE2r523vohw/V0Q0AP3XYETx19X+r7eKs//dI/82mwzn44ENZtmxFbXHt3btn/kYVHXjgwSxduqyW2B59tL6/SwBJjbwZZZHENWX7qFE6qPMySdvb6M6t9DhJVwFn2562/YpB+mtdgj/q6KO5ZMuWSYexjy997CuTDmFWd9xy6/yNJmD37u9NOoQZ7djx5UmHEPX57igvtk1ngacqsH3WsK9tXYKPiBinTDYWEdFSTZ7uJQk+ImIESfARES3k6pdATkQSfETECCa1oHYVSfAREUMyMD2pFbUrSIKPiBhBavARES2VGnxERBtVn0hsIpLgIyKGZFKiiYhorZRoIiJaKgk+IqKFevPBN1USfETEsBp+knXBV3SS9Il5jl8uadmo/URETEKNKzrVrrYEL+kWSV8tHrdL6q0GfmBx/EJJJ87w0iX9cUhaL+mcvnYH1hVrREQdelfR1LQma+3qLNF81/aZAJI2Aof3HV9SPKpYBTRzFYqIiJLpBV7wYxTjqMEfL+kP6Cbtqv47sF3SM4EPFfsOma1xsdL6JugusxcRMR5u9GRjddbgn9kr0QB/COwq9t9t+/8AP5jjtZ+V9HsAko4EDgU22f4P28+3/Xzg+7O92PZm22ttr112xBF1fJaIiHnZ1R+TUNsI3vbPSZJnLzYZOKQ4oXo48DN0VwcHeKHtB4vn/xt4O/CLkjbY3lpXjBERdVtMl0l+AnhJ374vFn9eB7wbeAy4H/gW8JlyQ0m/A+y2fSNwo6S/kvQt2zfXHGdERC2afJlk3Ql+n/5sf6D480vAl/qPSypv/ovtvyu99nXqaxAR0RSL7UanZUUNvszAObZnq8HvBToAtv+t/2Cp5LOnriAjImph01ksV9HY/p9DvOZVFdu9dPCIIiIW2CIawUdELCrOkn0REe3U4AF8EnxExLC617g3N8MnwUdEjCAJPiKilUxnepFcRRMRsZjUXaKRtB5YB0wD19u+ZIY2TwG2AA/Zfs1c/S34fPAREW1W13TBkg4DNgC/afvFwImS1szQ9ALgCirMzpsEHxExiuqzja2UtL302NTX08nANaWbO68GTik3KKZz2Q7cWSW0lGgiIkYwQIVmyvbaOY6vAHaXtncDj4/gJf0CsNr230k6rsobJsFHRAzLtZ5k3QU8p7S9nCemXQc4BzhC0l8Dh9Gdcfe1tt8/W4dJ8BERQ+ot2VeTG4A3SHpvUaY5A/iTx9/L/sPe82IEf8FcyR1amOCPPPRQzvqlX5p0GPv48se/OukQZnX//f8+6RBmtHPnvZMOIWJedSV42w9K2gpsk7QX2GH79lmaT9OdqHFOrUvwERHjVOdlkra3AdvK+yRdBZxte7rU7l7g3Pn6S4KPiBiWDQs82Zjts4Z9bRJ8RMQIMlVBREQLGehkuuCIiBbKbJIREe2VBT8iIlqp2jwzk5IEHxExgiT4iIgWyopOEREt5ukk+IiIVsoIPiKijSou5jEpSfARESNIgo+IaKGapwuuXRJ8RMSwDK5vwY/aJcFHRAyt2TX4oRfdlvSJeY5/UNKyUfuJiGiy6mtuj9+8I3hJH7d9dvH8FcAPbX8aOLDY92dAbwmlFcD5tj9T9H1AqZ83Ay+guxLJHcDv2d7T66fU7uXAbwMqHv8FvNX2bSN8zoiIBdHkEXyVEs1xkt5SPD8JuLJ80PZ5veeSzp2pT0k/Czzb9q8X2xuBVwAf7Gt3CvB84EW29xb7ngZ8Eji50ieKiBgTu9mTjVUp0dxj+2LbFwOfK+3/NUlf
lfQzpX3/A7huhj4eBFZLOroo25wE3D/L+3Xonpzucd/2PiRtkrRd0vadO3fO93kiImrj4lr4+R6TUGUEv2SW59fZPrO3Iem5dMs35Qz72aLEc6mkdwHvAQ4BPma7/GUBgO2vSPrp4nXlEs2r5grQ9mZgM8DatWub+3UaES1jOp39+yqaH0v6Gt1R9DTwhv4GklYAFwOv7Dv0wmKl8HXAKuBG4FhglaTzgJtmeL/7gb/qfQFIek/q7xHRSPv7ZGO2N0j6pO2X9B36MoCkXwYuAH7f9gOz9PGxou1xwO8Dl9Et97yA7oi+7PC+fcfO+ykiIialwTX4qtfB79PO9qXF0/8FnGP7kapvWrQ9BUDSp4o/Xw2sBw7tbqpXllkt6avF89fZ/r9V3yciYiF172SddBSzq5zgS0m2x3QT+ztmec1euidMyzoz7NsDYPsyuiP7iIj9xn5dogGw/cJBO7a9z4lR2/cAb+rb99JB+46IaASbTqYqiIhop/1+BB8REfvKbJIREW3V8LOsSfAREUNr9mySSfARESNwc8+xJsFHRAzN1DpVgaT1wDq6swZcb/uSvuPvpzsD76HAnbYvmqu/JPiIiCHVeZJV0mHABuB025a0VdIa23c9/n72a0vtPyzp2bbvmK3PoRf8iIiIgWaTXNmb9bZ4bOrr6mTgGj/xjXE1xR3//SQdCRwF/GCu2DKCj4gYmgeZD37K9to5jq8Adpe2dwNryg0kHQ+8je6XwRttPzjXG2YEHxExLNc6H/wu4MjS9vJi3xNvZ99tez3dxL9e0uq5OkyCj4gYRX2Lst4AnFqshQFwBnDtzG/pvXTX5zhorg5ToomIGJKBTk3TBRdrZ2wFtknaC+ywfXvvuKRfpDuX14+BZcBVxfxes2pdgr/v/gc47+3vm3QY+3jGc4+bdAiz2nnvr006hBndfcctkw5hRlNT9006hGiKmtdktb0N2FbeJ+kq4Gzb/wy8bJD+WpfgIyLGZ+HvZLV91rCvTYKPiBhBpiqIiGipJPiIiBaywVnwIyKinRo8gE+Cj4gYXqYLjohorST4iIg2chJ8REQrmXpvdKpbEnxExNCMa1zwo25J8BERw0qJJiKivRqc35PgIyJGkRp8REQL1bkm60JIgo+IGFZq8BERbWU6Db6KZmJL9kl6aoU2c643GBExae640mMSJpLgJT0N+NvS9kclrZih6QckHTeuuCIiBtItwte1Jmvtxp7gJR0DXA6cW2wvB5bY3jVD89cA75P0jDGGGBFRScPz+3gTfJHcLwNebfveYvc5wMclLZV0haT3SroIwPYDwEbgLyU9c45+N0naLmn7Iw//eGE/REREie1Kj0kY20nWogTTS+7/WTr0IuBM4Ol0v3De5NJPw/ZOSRuBKyRtsv29/r5tbwY2A6w+5tjmntKOiHax6TR4wY9xjuAPBqaBH/V2SFoD3GP7Mdt30q3LXyppXd9rHwZ+AiwdV7AREVU0eQQ/tgRfjLzPB7ZIOqzY/XLgI6U2X7H9emBjr42kpcAW4GLbd48r3oiI+fRudFr0CR7A9q3ABXST/OHArwLXAUg6QdIWSX8DfMf2Q5IOBbbSTe43jjPWiIgqmpzgx36jk+1vSXoL8Cng6716u+3b6I7oyy4H3mF7x5jDjIioYIKXyFQwkTtZbd8h6ZW2vz1P0wtt3zWWoCIiBmVwc8+xTm6qggrJnST3iGi6Jk9VkLloIiKGlNkkIyLaKrNJRkS01eQmEqsiCT4iYhQ1juAlrQfW0b0p9Hrbl/Qd/wDQAZYDn7P9kX17eUISfETECEw9Cb64uXMDcLptS9oqaU35YhPbv1u0FXAtpRtFZ5IEHxExJNt0OtNVm6+UtL20vbmYR6vnZOCa0lxcVwOnADNdTXgwsHu+N0yCj4gYwQAnWadsr53j+AqenLR3A2tmaftO4JJZjj0uCT4iYgQ1XkWzC3hOaXt5se9JJL0RuMX2N+brcGJL9kVEtEGNc9HcAJxa1NcB
zqBbZ3+cpNcCD9u+skqHGcFHRAypm7zruZPV9oOStgLbJO0Fdti+vXdc0snAW4DPS/rrYveFtnfO1mfrEvxjjzzG3bc0b1bhH3z3B5MOYVYrjl4+6RBmdMzxL5h0CDO6Zcc/TjqEaJC6Eny3L28DtpX3SboKONv2N4FjB+mvdQk+ImKcFvpOVttnDfvaJPiIiBFkqoKIiFaqrwa/EJLgIyKG5Ew2FhHRXknwERGtZJwFPyIi2skkwUdEtFJKNBERLZSTrBERrVV5npmJSIKPiBjBAPPBj10SfETECDKCj4hoo24RftJRzCoJPiJiSKa+NVkXQhJ8RMQIMhdNREQr5SqaGUl6qu0H5mmz2vb3xxVTRMSgOg2eqmAia7JKehrwt6Xtj0paMUPTD0g6blxxRUQMonuOtVPpMQljT/CSjgEuB84ttpcDS2zvs3o48BrgfZKeMcYQIyIqqrbg9qTKOGNN8EVyvwx4te17i93nAB+XtFTSFZLeK+kigKKEsxH4S0nPnKPfTZK2S9r+2GOPLuyHiIgo610qOd9jAsZWgy9KML3k/p+lQy8CzgSeTvcL500ufd3Z3ilpI3CFpE22v9fft+3NwGaAI49c1dwzHhHROk2+THKcI/iDgWngR70dktYA99h+zPaddOvyl0pa1/fah4GfAEvHFWxERBUp0QDFyPt8YIukw4rdLwc+UmrzFduvBzb22khaCmwBLrZ997jijYiYj206nelKj0kYaw3e9q3ABXST/OHArwLXAUg6QdIWSX8DfMf2Q5IOBbbSTe43jjPWiIgqmjyCH/t18La/JektwKeAr/fq7bZvozuiL7sceIftHWMOMyKiktzo1Mf2HZJeafvb8zS90PZdYwkqImIISfAzqJDcSXKPiGYzZC6aiIj2saGTBB8R0U4p0UREtJIzXXBERFs1eQQ/kdkkIyLaos7r4CWtl/RpSf8g6c0zHF8i6V2SvlClvyT4iIghdecRq5zgV/YmRSwem8p9FXfvbwB+0/aLgROL6VzKXgh8morVl5RoIiKGZuzK0xBM2V47x/GTgWtKky1eDZwCPH65uO2rASRVesMk+IiIEdRYg18B7C5t7wb6R/ADSYKPiBhBjQl+F/Cc0vbyYt/QUoOPiBharSs63QCcqifqL2cA144SXUbwERFD6q3JWk9fflDSVmCbpL3ADtu3z9J8T5U+1eRrOIchaSfw3Zq6WwlM1dRXnRLX4JoaW+IaTN1xPd32UcO++OCDl/qYY6qVyb/97X+9eZ6TrDOSdBVwtgc4m9vTuhH8KH9Z/SRtH+YvZKElrsE1NbbENZjmxWXcWdg7WW2fNexrW5fgIyLGqclrsibBR0SMIHPR7L82TzqAWSSuwTU1tsQ1mEbF1buTtalad5I1ImJcDjroEK9adVyltvfdd8dQJ1lHkRF8RMQIOgt8knUUSfARESNIDT4ioo26RfhJRzGrJPiIiCGZXCYZEdFaTb5QJQk+ImIEqcFHRLSScxVNREQbNf1GpyT4iIgRJMFHRLSSITX4iIh2ymWSEREtlRJNREQL2abTGXihpbFJgo+IGEFG8BERLZUEHxHRUknwERFtlQQfEdE+tuk4J1kjIlopJZqIiJZKgo+IaCUnwUdEtFXmg4+IaKFMFxwR0VrOCD4ioq2S4CMiWqrOEo2k9cA6YBq43vYlgxzvlwQfETG8L9peWbHtIZK2l7Y3297c25B0GLABON22JW2VtMb2XVWOzyQJPiJiSLZPq7G7k4Fr/MSvBFcDpwB3VTy+jwNqDC4iIoa3Athd2t5d7Kt6fB9J8BERzbALOLK0vbzYV/X4PpLgIyKa4QbgVEkqts8Arh3g+D5Sg4+IaADbD0raCmyTtBfYYfv2qsdnoibfhRURsdhJugo42x58XuIk+IiIlkoNPiKipZLgIyJaKgk+IqKlkuAjIloqCT4ioqX+PwAsOGMJPi96AAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "STEP: 2, loss: 3.3569\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAW8AAADwCAYAAADPe+U2AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAExtJREFUeJzt3X+QXlV9x/HPZwNROzIKhPLDCpnRWDpK0bLVDjqWdHAKaiOUFrQx044/AjhUawvU8kOh4i/sUKr4owsoEnCtChTRjEjboTSISKKoM4pAbRGqQjYxlAIK2f30j+du8rB5nt3ngbvPfc7u+5W5s3vPuXvud5fwzdlzzz3HSQQAKMtI0wEAAPpH8gaAApG8AaBAJG8AKBDJGwAKRPIGgAKRvAGgQCRvACgQyRsACrRb0wEAwDA66qijMjEx0dO1mzZtuj7JUfMc0hOQvAGgg4mJCd122209XTsyMrJsnsPZBckbALqYGuK1n0jeANBBJA3zwn0kbwDoKIpI3gBQlkiTUyRvAChKxJg3ABSJMW8AKBDJGwAKk4RhEwAoET1vAChMJE2SvAGgPPS8AaBAjHkDQGkSet4AUBrWNgGAQk1OTTUdQlcLcicd2wfaftj2irayN9heY/vKJmPrxPbFtvdrOo6naqF8H0BLev7ThAWZvNX6vv5V0nltZUskLUmyupmQZrW7FsZvQQvl+wCUSFM9Hk1YyP+jPSDpNtuvT/K56ULb65O82vYySe+T9Ihayf5MSaOS3iTp55IeTPLu+Q7S9p9Jeqmkc2xfLumtkh6W9GxJVyT58jze+yBJn5L0LUnLJH1d0kskPS7p8SSn2n6fpGdJeqakS5JssH2IpNMlTaj1j+I7Ja1p+z4+kuS7NcX4cUkXJLnb9gnV/Y7WjJ+R7bMl/UuSW6qvW5/k1XXEgMWLMe/mXCzpKtvr28qWVh8/JOlDSe6UJNuW9G5JRyaZsv0B24cl2TSfASa5zPYRks6pYjovyQ9sj0j6iu2bk/x8nm5vSduTnCZJtn8s6bAkm21/xvZqSQ8lOdP2bpKuUytxfkjSHyd52PZJko5p/z6S3FdjjOOSXq/Wb1F/qNbf2bNm/oxU/WbV9nVLd2kJ6BPJuyFVEv5bSedKmpmEnzuduCv7SNpX0vtbeVz7StpzIIHutG+SH0g7Yr9d0vMkbZzHe97b9vldSTZXnz8k6TmSXmz7g1XZL6uPz5d0dvVzepakm+cxvg2STrO9r1q/ET2/y89opiUdyoCesSRsw5J82/YbJR0m6TttVT+1fUiS71XnE5J+JumMJIN+xDyp1n+Ln9r+jbZe5aGSPjDAOGb+TY2kbyS5cEb5jySdm+TRGeXT30d9ASWpEvR5ki6XtLbLz+hBSftLrQfW058DT1oy1LNNFmrynqyOaedI+p5aY7uPVWWnS/qw7Yeqa8+Q9A+S/sn2FrWGE04ZULz/LulCSddI+hvbj6jVo/1Ykv+dx/vO/Dk9PqPuaknvsP0ptXrdG5JcqdZvMutsT6j1kPLPkzwy/X3YvjTJdTXGuU7SVyWtlfSfkj4482dk+3OSLrb9u9X3UcuYOxa3YR428TAHB0iS7edJemOSc5uOBYvHiw49NFdf/9Werv31/Q/YlGR0nkN6goXa88YCYfuPJB0r6R1Nx4LFZ4i3sCR5Y7gl+aKkLzYdBxanYR6ZIHkDQBckbwAoTIZ8tslCfT1+TrbXNh1DJ8TVv2GNjbj6M4xxpVoWdq6jCYs2eas17WwYEVf/hjU24urPUMU1/ZJOL0cTGDYBgC6aWjGwF0Ul72XLlmX58uW1tHXg
gQdqdHS0lv8ymzbVu/yJ7aH8GzOscUnDGxtx9afmuCaS7PNUGmCqYE2WL1+ub952W9Nh7GL33XZvOoSupqYm574IWJjueSpfnERTQ/zAsqjkDQCDxMJUAFAg5nkDQIFI3gBQmDQ4DbAXJG8A6IKpggBQmEiaHOK5giRvAOiizjHvak/YE9Ta6OSWJOfPqH+nWjt+PabWNn4nV5ucdLSYX48HgFnV9Xq87T0krZH0uiTHSjrE9oq2+mertfn5G5O8SdL3Jb1qtjZJ3gDQSY+LUlW982W2N7YdM9dpOVzSDdnZlb9W0sq2+gfV2sN2f9vPkHSQWptvd8WwCQB0EPU1bDIxxzZoe0va2na+VdKOnne10fanJb1N0hZJNyfZMtsN6XkDQBc1riq4RdKebed7VWWSJNu/Kem1Sc5OcqGkR22/ZbYGSd4A0EWNyftWSUfadnW+StJNbfX7S3Lb+aOSls/WIMMmANDB9HretbSVbLO9TtK47e2Sbk9yR9slX5P0StuXS/qlpF+R9PbZ2iR5A0AnNe+Sk2Rc0nh7me2rJB2fZFLSmf20N/BhE9sH2L6xOla2lX9h0LEAwGzmeyedJMdVibtvTfS8T9TOKTAvt70xyUOShndRbACLTp+zTQauieR9iVpvDy2RdIWkixqIAQDmxO7xbZLcq9YOF6dKGkuyrap6RTWU8oL2622vnZ74vnnz5kGHC2DRSs9/mtDEmPfBkj4v6W5Jr7G9qqrakOSIJHe2X59kLMloktF99nlK29EBQM+S3o8mDHTYpOpVv1nSKUnut71U0jFV9fpBxgIAc2E970rVqz7N9jW2d7xtZPtkSdcNMhYAmAsPLNvY3k3SkiRHtJWNSLpa0gWDjgcAOqnzJZ350MRsk0lJh9q+sa1sRNIdnS8HgAYkmhri2SYDT97VkogHDfq+ANA3et4AUJ6wDRoAlGeIO94kbwDopDWHe3izN8kbALogeQNAcaKpSWabAEBRGDYBgEKRvAGgRCRvACjPEOdukjcAdBQeWAJAcdgGrWYjdtMh7GKYF68B8OSRvAGgQCRvAChNIrEwFQCUh543ABQmkqboeQNAYXg9HgDKxGYMAFCc0PMGgBKRvAGgMCwJCwCFyiTJGwCKQ88bAEoTHlgCQJFI3gBQGJaEBYASRUqNmzHYXi3pBEmTkm5Jcv6M+udJOrM6nZT0niQ/6dYeyRsAOqpvzNv2HpLWSDo6SWyvs70iyV1VvSV9UNKJSbb20uZILZH1wfYBtm+sjpVt5V8YdCwAMJvWXO+5D0nLbG9sO9bOaOpwSTdk578G10pa2Vb/25LulfR+21fafstcsTXR8z5R0obq85fb3pjkIUm7NxALAHTVR897IsnoLPV7S2rvUW+VtKLtfLmkF0laleQXtj9h+4dJ/qNbg00k70skLamOKyRd1EAMADCrpNaFqbZIemHb+V5V2bRH1OqZ/6I6/5KkwyR1Td4DHzZJcq+keySdKmksybaq6hXVUMoL2q+3vXb6V5HNmzcPOlwAi1iqud5zHT24VdKR1di2JK2SdFNb/SZJL207f5mk787W4MB73rYPlvReSd+Q9BrbW5J8SdKGJMfMvD7JmKQxSRodHR3eeTsAFpjUtrl4km2210kat71d0u1J7mir/6ntr9kel/SwpP9O8m+ztTnQ5F31qt8s6ZQk99teKmk6Ya8fZCwAMKuaF6ZKMi5pvL3M9lWSjk8ymeRiSRf32t5Ak3eSOyWdZvsa23tOl9s+WdJ1g4wFAOY0z5sxJDnuyX5tE8Mmu0lakuSItrIRSVdLumDQ8QBAJ603LJuOorsmZptMSjrU9o1tZSOS7uh8OQA0g9fj21ST1A8a9H0BoC+Jpmp8Pb5uvB4PAF3Q8waAwrCqIACUaMifWJK8AaAjdtIBgCJleJ9XkrwBoKOottfj5wPJGwA64IElABSK5A0AxUmd63nXjuQNAJ3UvKpg3UjeANANyRsAyhJJUwyb1OPH992vU07/
u6bD2MXqPz2j6RC6uv4rn2k6hI62bXug6RA62r79saZDwLCodw/L2hWVvAFgcHjDEgCKRPIGgAKRvAGgMIkUNmMAgPIMcceb5A0AnfHAEgCKRPIGgNLwejwAlCfiJR0AKFAUNmMAgMIwbAIAZRri3E3yBoBuGPMGgMKwhyUAlIgxbwAoUTQ1xLNNRgZ9Q9sH2L6xOla2lX9h0LEAwGwylZ6OJjTR8z5R0obq85fb3pjkIUm7NxALAHTWGvRuOoqumkjel0haUh1XSLqogRgAYFZDnrsHP2yS5F5J90g6VdJYkm1V1SuqoZQXtF9ve63tjbY3Pvro/w06XACLWJKejl7YXm37S7avsX16l2t2s/1Z2/84V3tNjHkfLOnzku6W9Brbq6qqDUmOSHJn+/VJxpKMJhl9xjOeOehwASxWiaYmp3o65mJ7D0lrJL0uybGSDrG9osOlZ0m6TK2RiVkNdNik6lW/WdIpSe63vVTSMVX1+kHGAgBz6WOq4DLbG9vOx5KMtZ0fLumG7GzwWkkrJd01fYHtP5G0UdITOrDdDDR5V73q06pfG/acLrd9sqTrBhkLAMymz5d0JpKMzlK/t6StbedbJe3oedt+iaT9knzW9vJebjjwB5a2d5O0JMkRbWUjkq6WdMGg4wGAbmp8SWeLpBe2ne9VlU17vaRn2/6kpD0k/ZbttyX5eLcGm5htMinpUNs3tpWNSLqjgVgAoIvUOd3kVknvsP331dDJKknv33Gn5K+nP6963mfNlrilBpJ3FfhBg74vAPQlUmp6wTLJNtvrJI3b3i7p9iTdOqyTkrbP1SavxwNAF3W+Hp9kXNJ4e5ntqyQdn2Sy7bp7JZ00V3skbwDoYBCrCiY57sl+LckbADphVUEAKFFzi071guQNAN3Q8waA8kQkbwAoShJNTU3OfWFDSN4A0AUPLAGgQCRvACgQyRsACtPaaGF4NyAuKnlPPPA/GrvwrKbD2MVzn3tw0yF09cpXHt90CB0t+7VlTYfQ0dhHzmg6BAwRkjcAFIhhEwAoEMkbAIrDmDcAFCcsTAUAZSJ5A0BxotS4GUPdSN4A0EVE8gaA4jBsAgCF4YElABQpJG8AKBHreQNAgeh5A0BpWoPeTUfRFckbADqI2MMSAIrE2iYAUJwFNtvE9q8meeCp3tj2fkl+9lTbAYD5MjXEr8eP9HOx7QMkfbrfm9heY/sNM4p/x/Z7+m0LAAah9bxyqqejCT0nb9vPkXSppJOexH2WVMcOSf5Z0oTtc59EewAwz1LtYzn30YSekneVuC+W9NYk99q+1Pb+bfXrq4/n2P6o7Ytsf872qhntLLE9ZvvFkpTkY5J+Zvu9s9x7re2NtjcO8/gTgAVoerrgXEcD5kzetvfWzsR9X1U8sye9tO3zbyU5RdIaSX/R3pSkCyRdmuT26cIkn5B0n+2OOwsnGUsymmTUdi/fEwDUIj3+aUIvPe+nSZqU9GCnSttLZhTdJUlJHpeesJ7i2yXtJ+mbHZq5T9JwbicOYNEqetgkyU8knSHpctt7VMUPSpoeNnmZ1NM/PZdJWle1tYPtP5B0rKS/6i1kAJh/STQ1NdnT0YSexryTfE/SWWol8GepNYzyHtsflfT7kjZXl05Wx7TH28q3JvmypMdsv06SbB8j6bWS1iYZ3hVgACxKw9zz7nmed5Lv236XpMuSHKtW0p15zXtnnB9dfVzXVvZhSbL9e5JeJemk8CQSwBAa5tTU10s6SX5o+y9ruvePJZ1C4gYwrOpMT7ZXSzpBrZGIW5KcP6P+E2o9J9xL0leSXDFbe32/YZnkv/r9mi7t3F1HOwAwPyL1/gLOMtsb287HkoxNn1TPC9dIOjpJbK+zvSLJXTvulpxcXWtJN0mqN3kDwGKQSFO9J++JJKOz1B8u6Ya2kYZrJa1UNTtvhqdJ2jrXDUneANBFjcMme+uJCXmrpBVdrj1P0vld6nYgeQNAR6lz3ZItkl7Ydr5XVfYEtt8p6dtJbp6r
wb4WpgKAxaTGqYK3SjrSO18TX6XWuPYOtt8m6eEkV/bSID1vAOiirmGTJNtsr5M0bnu7pNuT3DFdb/twSe+StN72J6vis5Ns7tCcJJI3AHTUWnOqvqmCScYljbeX2b5K0vFJvi7pwH7aI3kDQEfRfL/4neS4J/u1JG8A6GKY3yEkeQNAFyRvACjOAtuAGAAWg+k9LIdVUck7ycTjj//ynpqaWyZpoo6GfvSj79TRzLTa4pJqja3WuGo2rLERV3/qjuugp9oAPe+aJNmnrraqPTFnW4ugEcTVv2GNjbj6M3xxRZmi5w0AxWlqf8pekLwBoAvGvIfT2NyXNIK4+jessRFXf4YqrrrfsKybhzk4AGjK0qVPz777Lu/p2vvu++GmQY/XL+aeNwDMaooHlgBQHsa8AaA0rUHvpqPoiuQNAB1ETBUEgCIN84QOkjcAdMGYNwAUJ8w2AYDSDPtLOiRvAOiC5A0AxYnEmDcAlIepggBQIIZNAKAwSTQ1Ndl0GF2RvAGgC3reAFAgkjcAFIjkDQAlInkDQFmSaCo8sASA4jBsAgAFInkDQHFC8gaAErGeNwAUhiVhAaBIoecNACUieQNAgeocNrG9WtIJkiYl3ZLk/H7qZyJ5A0Bn1ydZ1uO1T7e9se18LMnY9IntPSStkXR0ktheZ3tFkrt6qe+E5A0AHSQ5qsbmDpd0Q3Z25a+VtFLSXT3W72KkxuAAAJ3tLWlr2/nWqqzX+l2QvAFg/m2RtGfb+V5VWa/1uyB5A8D8u1XSkbZdna+SdFMf9btgzBsA5lmSbbbXSRq3vV3S7Unu6LW+Ew/zG0QAsJDZvkrS8Un/a8+SvAGgQIx5A0CBSN4AUCCSNwAUiOQNAAUieQNAgf4f1CA9Bjl2vwsAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "STEP: 3, loss: 0.1916\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXgAAADwCAYAAAAKCX+nAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAFKdJREFUeJzt3XuwpHV95/H3d8YBSheRGchwKS6rjiYVDUmYaC1JJbDF1kI0LFkVMDgJ5cbxUibGlBhXwWiiBsmWuq6J2Rk0LoOON4wEpTSsAckYFhlWorsK4moANxjm4hAWjDLnfPaPfg40PefSp0+f7p7nvF9U1/Rz+/W3zwyf/p3f8/TvqSRIktpn1bgLkCQtDwNeklrKgJekljLgJamlDHhJaikDXpJayoCXpJYy4CWppQx4SWqpJ4y7AEk6WJ111lnZvXt3X/vedtttn09y1jKX9DgGvCQNaPfu3dx666197btq1aqjlrmcAxjwkrQE0xM8n5cBL0kDCjDJEzYa8JI0sBAMeElqn8DUtAEvSa0THIOXpNZyDF6SWsqAl6QWSuIQjSS1lT14SWqhAFMGvCS1kz14SWopx+AlqY0Se/CS1EbORSNJLTY1PT3uEua0Yu/oVFUnVtVDVbWha92Lq2pTVX14nLXNpqq2VtUx465jqdryPqSO9P3fOKzYgKfz3r8AvK1r3WpgdZILx1PSvNbQjt+42vI+JBKY7vMxDiv9f7T7gVur6oIkH51ZWVXXJfnlqjoKeDvwMJ0PhDcBG4GXAt8HHkjy5uUusqouAp4DvKWqrgReBjwEPAW4KslnlvG1TwI+CPxP4Cjgb4GfAR4BHknyuqp6O3AE8C+AK5LsqKpnA68HdtP54HwtsKnrfbw3yVeHVOOfAu9K8q2qOr95vbPp+RlV1aXAf09yc3PcdUl+eRg1aOVyDH6ybQWurqrrutYd0vz5TuCdSb4JUFUFvBk4M8l0Vf1RVZ2a5LblLDDJh6rqdOAtTU1vS/KNqloFfLaqvpTk+8v08gXsT3IxQFXdA5yaZFdV/bequhB4MMmbquoJwLV0wvWdwIuSPFRVrwDO7X4fSb47xBq3AxfQ+W3s39P5d31J78+I5je0ruMOOaAlaZEM+AnWBPUfAG8FeoP6hJlwbxwNrAfe0cl61gNHjqTQx6xP8g14tPbbgacBO5fxNe/ten5Xkl3N8weB44GfrqrLmnU/bP58OnBp83M6AvjSMta3A7i4qtbT+c3q6XP8jHqtnmWd1DenCz4IJPlKVb0EOBX4u65N91XVs5N8rVneDXwPeGOSUZ86n6Lz93VfVf1EV+/0FOCPRlhH77/mAP8jyXt61n8beGuSH/Ssn3kfwysoSRPibwOuBDbP8TN6ADgWOifZZ55LA0sm+iqalRzwU81jxluAr9EZa/5Rs+71wB9X1YPNvm8E/jPwsaraQ2fo4tUjqveLwHuAvwD+Y1U9TKdn/CdJ/mkZX7f35/RIz7ZPAa+pqg/S6b3vSPJhOr8Rbauq3XROrP5Wkodn3kdVfSDJtUOscxvwOWAz8H+Ay3p/RlX1UWBrVf1S8z6Gcg5AK9skD9HUJBcn9auqnga8JMlbx12LVo5nnXJKPvX5z/W17zOPPe62JBuXuaTHWck9eLVEVb0Q+FXgNeOuRSvPBN+S1YDXwS/JJ4FPjrsOrUyTPApiwEvSEhjwktRCmfCraFbyVAULqqrN465hNta1eJNam3UtziTWlWbK4IUe42DAz2/i/jE1rGvxJrU261qciapr5otO/TzGwSEaSVqCcc0U2Y/WBfyRa9fm+BNOGEpbxx1/
PM865ZSh/O3d+fU7htEMAKtWrWbNmkOH9q9q//4fLbxTn6pqYv+1T2pt1rU4Q65rd5Kjl9KAl0mO0PEnnMDVn+vviwej9Is//a/GXcKc7r//nnGXMIcJ/j9HbXH3Ug5OwvQEn2RtXcBL0ig52ZgktZTXwUtSSxnwktRCGeMlkP0w4CVpCbxMUpJaKMDUBF8nacBL0hI4Bi9JLeUYvCS10ZAnEquqC4Hz6dwO8+Ykl/dsfy2de0f/iM5N41/Z3ApzVk42JkkDCouaTfKoqtrZ9XjcxGlVdTiwCfh3SX4VeHZVbeja/hTgzCQvSfJS4OvAv5mvPnvwkrQEixii2b3APVlPA67PY78SXAOcAdzVLD8A3FdVxwL7gJOAK+Z7QQNekpZgiGPw64C9Xct7gUd78ElSVX8OvArYA3wpyZ75GnSIRpIGNOT54PcAR3Ytr23WAVBVPwU8P8mlSd4D/KCqfnO+Bg14SRpUn+PvfZ6IvQU4s6qqWT4HuKlr+7FAdS3/ADh5vgaXPeCr6hMLbP9AVT15qe1I0jgMqwefZB+wDdheVVcBX03SfSOJvwKmqurKqtoKvAR493xtDm0Mvqq+QuckAMAxdC7fuQFY02y/FPh0kq/1HLqang+a5lKhqSQf7Vq9Zli1StIwzFxFM7T2ku3A9u51VXU1cF6SKeBNi2lvmCdZ705yblPQRcARPdtXN49+rAd6PwgkaeJMLfMNP5K8YNBjR3EVzdOr6nV0QrtfzwV2VtVTgQ826w6ba+fmetLN0LnNniSNRiZ6srFhjsE/tapurKobgd/jsbO/30ryn4B/nOfYz1TVbwNU1ZHAk4DNSb6d5PQkpwPfm+vgJFuSbEyy8ch164bxXiRpQUn/j3EYWg8+yU9VVWXuAakAhzUnVI8AnkHnLDB0Lv3Z1zz/feAPgJ+tqk1Jtg2rRkkatpU0F80ngBf2rPt88+cO4J3AD4H76HzN9truHavq14C9Sb4MfLmq3ldVX09y25DrlKShWEmzSR7QXpL3N39+AfhC7/bHLvkE4O+SfKTr2FdXzw6SNClmvug0qYYd8E9uxuC7BbggyVxj8PuBaYAk/7t3Y9eQzyPDKlKShiJhepmvolmKoQZ8kn89wDHzftW2a78XLb4iSVpmK6gHL0krSrxlnyS10wR34A14SRpU5xr3yU14A16SlsCAl6RWCtNTK+QqGklaSRyikaQWM+Alqa0MeElqpwnOdwNekgYWT7JKUisN+5Z9w9a6gD9szRqeeeyx4y7jAPfff/e4S5C0DAx4SWopA16S2igBJxuTpHayBy9JLRRg2h68JLWQUxVIUnt5ww9JaqXYg5ektjLgJamFnC5YklosUwa8JLWSPXhJaqN4klWSWsuAl6QWcrpgSWqrQLzhhyS10XDH4KvqQuB8YAq4OcnlPdufBrypWZwCfj/JP8zV3sQFfFUdB3ykWXxrkhua9Z9I8qLxVSZJBxpWvlfV4cAm4OwkqaptVbUhyV3N9gIuA16eZG8/bU5cwAMvB3Y0z3++qnYmeRBYM8aaJGlWi+jBH1VVO7uWtyTZ0rV8GnB9HmvwGuAM4K5m+eeAe4F3NB8GNyS5Yr4XnMSAvwJY3TyuAt433nIkaXbJoiYb251k4zzb1wHdPfO9wIau5ZOBZwHnJPnnqnp/Vd2Z5G/manBVv5WNSpJ7gbuB19H5hNvXbPqFqrqxqp7Re0xVba6qnVW1c9euXaMsV9IKl+Za+IUefdgDHNm1vLZZN+NhOj38f26W/xI4db4GJy7gq+rHgY8D3wKeV1XnNJt2JDk9yTd7j0myJcnGJBuPPvroUZYraUUL09PTfT36cAtwZjPWDnAOcFPX9tuA53QtPxf46nwNTuIQzX8AXp3kH6vqEODcZv11Y6xJkg40xMnGkuyrqm3A9qraD9ye5I6u7fdV1V9V1XbgIeDvk/z1fG1OXMAnubiq/qKqHv1VpapeCVw7xrIkaXZDvOFHku3A9u51VXU1cF6S
qSRbga39tjdxAV9VTwBWJzm9a90q4FPAu8ZVlyT16nyTdZlfI3nBoMdOXMDTuXj/lKq6sWvdKuCO2XeXpPFxqoJFaK4BPWncdUjSghKmnapAktrJHrwktZCzSUpSW43iLOsSGPCSNDDv6CRJrZXJPcdqwEvSwEK/0xCMhQEvSQPyJKsktZgBL0mtlMXMBz9yBrwkDWqIs0kuBwNekpbCgJek9gkw7RCNJLXQ4u7JOnIGvCQNzG+ySlJrGfCS1FIGvCS1UALxhh+S1E4T3IE34CVpcJ5klaTWMuAlqY2cqkCS2in4RSdJaqkQb/ghSS3kEI0ktdcE57sBL0lL4Ri8JLWQ92SVpLaa8DH4VYMeWFWfWGD7FVX15KW2I0mTK0xPT/f1GIcFe/BV9fEk5zXPfwP4fpK/BNY06/4Y+Llm93XAG5Nc27S9qqud1wPPA6aAO4HfTvLITDtd+/068GKgmsc/AW9O8o0lvE9JWhYH+xj8yVX1hub5qcCHuzcmuXjmeVW9YrY2q+rHgWcm+aVm+SLgN4ArevY7Azgd+JUk+5t1xwGfBE7r6x1J0qh0BuHHXcWc+hmiuSfJZUkuAz7btf4XqurGqnpG17pfBHbM0sY+4JiqOrYZtjkVuG+O15um82ObkZ5lSZoIM/nez2Mc+unBr57j+Y4k584sVNWz6Azf7Ora5zPNEM97q+rtwLuAw4CPJen+sAAgyQ1VdUJzXPcQzW/OV2BVbQY2A5x44ol9vCVJGo5JPsnaT8D/v6r6Ip0PqyngNb07VNU64DLgpT2bnp9kX1WdD6wHvgycCKyvqouBW2d5vfuA9818AFTVuxYaf0+yBdgCsHHjxsn9aUtql4TpId7wo6ouBM6nk7U3J7l8ln2eAFwJPJjk5fO1t2DAJ9lUVZ9M8sKeTX/dvNhzgEuA30ly/xxtfKzZ92Tgd4CtdIZ7nkenR9/tiJ51dsklTaxF9OCPqqqdXctbms4pAFV1OLAJODtJqmpbVW1IcldPO5cAHwLOW+gF+70O/oD9kry3efpvgQuSPNxnWzT7ngFQVZ9u/nwZcCHwpM5izQzLHFNVNzbPX53kf/X7OpK0nBb5RafdSTbOs/004Po81uA1dHLy0YCvql8DdgLf7OcF+w74rpCdETrB/odzHLOfzgnTbtOzrHsEIMlWOj17STpoDHEMfh2wt2t5L7BhZqGqfgY4JslHmtGQBfUV8Eme33+Njx5zwInRJPcAv9uz7kWLbVuSJsNQL5HZA/xk1/LaZt2MC4CnVNWfAYcDP1tVr0ryp3M16FQFkjSoQIZ3jvUW4DVV9e5mmOYc4B2PvlTyezPPmx78JfOFOxjwkrQkw5qGoLnicBuwvar2A7cnuWOO3afoDIPPy4CXpAENezbJJNuB7d3rqupq4LwkU1373Qu8YqH2DHhJGtQIZpNM8oJBjzXgJWlgOegnG5MkzeUgn6pAkjSHTPBciAa8JA0oCdPTUwvvOCYGvCQtwcE+m6QkaQ4GvCS1lAEvSS2UhAxxroJha13A33nntzn99BePu4wDPPe5i56vbWQeeGD3uEuY1fofO2ncJczqizd9bNwlaIIY8JLUUg7RSFJLGfCS1EqOwUtSK2UEk40thQEvSUtgwEtSK4UM6YYfy8GAl6QlCAa8JLWSQzSS1EKeZJWk1ooBL0lt5XzwktRS9uAlqY06g/DjrmJOBrwkDSh4T1ZJai3nopGkVvIqmllV1Y8luX+BfY5J8r1R1SRJizU9wVMVrBrHi1bVccCfdy1/tKrWzbLr+6vq5FHVJUmL0TnHOt3XYxxGHvBVdTzwAeAVzfJaYHWSPbPs/nLgT6rqX46wREnqU5r7si78GIeRBnwT7luBlyW5t1l9AfDxqnpiVX2oqt5dVW8BaIZwLgL+S1U9dZ52N1fVzqra+cgjP1zeNyFJ3WYulVzoMQYjG4NvhmBmwv3/dm36FeBc4CQ6Hzi/m66PuyS7quoi4ENVtTnJP/S2nWQLsAXg
8MPXTu4ZD0mtM8mXSY6yB38oMAU8MLOiqjYA9yT5YZJv0hmXf29Vnd9z7EPAj4AnjqpYSeqHQzRA0/N+I3BlVR3erP514KqufW5I8lvARTP7VNUTgSuBy5J8a1T1StJCkjA9PdXXYxxGOgaf5GvAJXRC/gjg54EdAFX1E1V1ZVX9V+DvkzxYVU8CttEJ9y+PslZJ6sck9+BHfh18kq9X1RuATwN/MzPenuQbdHr03T4A/GGS20dcpiT1xS869UhyZ1W9NMl3Ftj10iR3jaQoSRrAMAO+qi4EzqdzvvLmJJf3bH8/MA2sBT6b5KoDW3nM2L7J2ke4Y7hLmmyBIX2JqTnvuAk4O0mqaltVbejOwSSvbPYt4Ca6zmHOxrloJGlACUz3H/BHVdXOruUtzSXeM04Dru+6TPwa4Axgto7uocDehV7QgJekJVjEEM3uJBvn2b6Ox4f2XmDDHPu+Dbh8jm2PMuAlaWAZ5jwze4Cf7Fpe26x7nKp6LfCVJF9aqMGxTDYmSW0xxMskbwHObMbXAc6hM87+qKp6FfBQkg/306A9eElagmFdRZNkX1VtA7ZX1X7g9iR3zGyvqtOANwDXVdWfNasvTbJrrjYNeEkaUGceseFdJplkO7C9e11VXQ2cl+RvgRMX054BL0kDC8nyTkOQ5AWDHmvAS9IS+E1WSWopA16SWsmbbktSK83ck3VS1SR/+gyiqnYBdw+puaOA3UNqa5isa/EmtTbrWpxh13VSkqMHPfjQQ5+Y44+f68umj/ed73z1tgW+yTp0revBL+Uvq1dV7Rz1X0g/rGvxJrU261qcyasrZHpye/CtC3hJGqVJvierAS9JSzDJY/AG/Py2LLzLWFjX4k1qbda1OBNV17C/yTpsrTvJKkmjcsghh2X9+pP72ve7373Tk6ySdDCZ9iSrJLWTY/CS1EadQfhxVzEnA16SBhS8TFKSWmuSL1Qx4CVpCRyDl6RWilfRSFIbTfoXnQx4SVoCA16SWingGLwktZOXSUpSSzlEI0ktlITp6alxlzEnA16SlsAevCS1lAEvSS1lwEtSWxnwktQ+SZiOJ1klqZUcopGkljLgJamVYsBLUls5H7wktZDTBUtSa8UevCS1lQEvSS01zCGaqroQOB+YAm5Ocvlitvcy4CVpcJ9PclSf+x5WVTu7lrck2TKzUFWHA5uAs5OkqrZV1YYkd/WzfTYGvCQNKMlZQ2zuNOD6PPYrwTXAGcBdfW4/wKohFidJGtw6YG/X8t5mXb/bD2DAS9Jk2AMc2bW8tlnX7/YDGPCSNBluAc6sqmqWzwFuWsT2AzgGL0kTIMm+qtoGbK+q/cDtSe7od/tsapK/hSVJK11VXQ2clyx+XmIDXpJayjF4SWopA16SWsqAl6SWMuAlqaUMeElqqf8PMeSl1pFIgasAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "STEP: 4, loss: 0.1557\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXgAAADwCAYAAAAKCX+nAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAFapJREFUeJzt3X2wXVV9xvHvk4gwQASSaIAWpGp8GUWrpHaGdpR06AgVKa1C0BhhGI3UUQGnWIugWF8asUMttbUNYCmRRgRakZeqVEEMpUAotLYCYl9E21hJQjAGVHLP0z/2PmRz7tu55+57z86+z4c5c89ee521f+fm8jvrrL322rJNRES0z7xhBxARETMjCT4ioqWS4CMiWioJPiKipZLgIyJaKgk+IqKlkuAjIloqCT4ioqWS4CMiWuppww4gImJ3dcwxx3jz5s191b377ru/bPuYGQ7pKZLgIyIGtHnzZu66666+6s6bN2/xDIczShJ8RMQ0dBq8nlcSfETEgAw0ecHGJPiIiIEZkwQfEdE+hpFOEnxEROuYjMFHRLRWxuAjIloqCT4iooVsZ4gmIqKt0oOPiGghAyNJ8BER7ZQefERES2UMPiKijez04CMi2ihr0UREtNhIpzPsEMY1Z+/oJOlQSTskLa2UvVHSKklXDDO2sUi6WNKBw45jutryPiIK7vu/YZizCZ7ivX8V+EilbD4w3/bK4YQ0oT1oxzeutryPCGzo9PkYhrn+P9oPgbsknWz7c91CSTfa/g1Ji4GPAo9RfCC8H1gGnAY8Ajxq+wMzHaSkU4FXAudLuhx4G7AD2B/4rO3rZ/DYzwY+A/wzsBj4R+DlwBPAE7Z/V9JHgf2AfYFLbG+QdDjwXmAzxQfnWcCqyvu4yPa/1hTjnwMX2v6OpBXl8Y6l53ck6TzgH2zfXr7uRtu/UUcMMXdlDL7ZLgaukXRjpezp5c+PAx+3/W0ASQI+ABxtuyPpDyUdYfvumQzQ9mWSjgLOL2P6iO37JM0DbpB0m+1HZujwAnbaPhtA0kPAEbYflvTXklYC222/X9LTgOsokuvHgRNt75B0OnBC9X3Y/n6NMa4HTqb4NvbbFH/X5/b+jii/oVVe9/RRLUVMURJ8g5WJ+g+ADwG9ifqQbnIvPRNYAnysyPUsAQ6YlUB3WWL7Pngy9nuB5wIbZ/CY36s8f9D2w+Xz7cDPAb8oaU1Z9tPy5/OA88rf037AbTMY3wbgbElLKL5ZPW+c31Gv+WOURfQtywXvBmzfI+nNwBHAv1R2bZJ0uO1vltubgR8A59ie7VPnIxT/XpskvajSO30Z8IezGEfvX7OBf7L9yZ7y/wQ+ZPvxnvLu+6gvINtlEv8IcDmwepzf0aPAQVCcZO8+jxiY3ehZNHM5wY+Uj67zgW9SjDX/rCx7L/AJSdvLuucAfwJcKWkLxdDFO2cp3q8DnwT+Dvh9SY9R9Iz/zPaPZvC4vb+nJ3r2/S1whqTPUPTeN9i+guIb0TpJmylOrL7L9mPd9yHpUtvX1RjnOuBLwGrgP4A1vb8jSZ8DLpb06vJ91HIOIOa2Jg/RqMnBRfRL0nOBN9v+0LBjibnjJS97mf/2y1/qq+4LDjr4btvLZjikp5jLPfhoCUlvAH4LOGPYscTc0+BbsibBx+7P9tXA1cOOI+amJo+CJMFHRExDEnxERAu54bNo5vJSBZOStHrYMYwlcU1dU2NLXFPTxLhcLhk82WMYkuAn1rg/plLimrqmxpa4pqZRcXUvdOrnMQwZoomImIZhrRTZj9Yl+L33XeD9Fy6upa39DljEwYf+Qi3/etsfebSO
ZgDYc8+9WbBgYW1/VT/5yY5a2pk3bz577LFnbXGNjOysqykkMW/e/Fpiq/siZkmNzBBzJK7Ntp85nQYyTXIW7b9wMW8/+/xhhzHKzVf//bBDGNd99//TsEMY0/btW4cdwpgef3z7sEOI+nx3Oi+2TafBJ1lbl+AjImZTFhuLiGipzIOPiGipJPiIiBbyEKdA9iMJPiJiGjJNMiKihQyMNHieZBJ8RMQ0ZAw+IqKl6hyDL29iv4Libmm3276gZ/9ZFLcW/RnFPYV/p7xT2piyFk1ExKD6XGis7OUvlrSx8njKujqSFgCrgN+0/VvA4ZKWVvbvDxxt+822TwO+Bfz6ROGlBx8RMSAzpSGazZPcsu9I4CbvavBaYDnwYLn9KLBJ0kHANuDZwCUTHTAJPiJiGmocolkEVNfn2Ao82YO3bUl/BbwD2ALcZnvLRA1miCYiYhpqXC54C3BAZXthWQaApJcCx9k+z/YngcclvXWiBpPgIyIGVPN68HcAR0tSuX08cGtl/0GAKtuPA4dN1GCGaCIiBlXj3Zpsb5O0DlgvaSdwr+37K1W+ArxK0uXAT4G9gXdP1OaMJ3hJV9k+cYL9lwJn2f7RdNqJiBiGOqdJ2l4PrK+WSboGOMn2CPD+qbRX2xCNpHsk3VI+7pe0vNy1R7n/PEmHj/HS+b1xSFop6eSeenvUFWtERB26s2hm8p6stl9fJvcpq7MH/13bJwBIOhXYr2f//PLRjyXAN+sLLSJiZozM8Rt+PE/S71Ik7X79MrBR0nOAz5Rle41XubxgYDUUt9mLiJgdbvRiY3XOonlOd4gG+D12Te/5ju0/Av5vgtdeL+ndAJIOAPYBVtv+T9tH2T4K+MF4L7a91vYy28v23ndBHe8lImJSdv+PYaitB2/7pZLk8QebDOwl6RkUwzfPp5jmA8Xczm3l8w8CfwC8QtIq2+vqijEiom5zaT34q4A39JR9ufy5Afg4xfSeTRTrKFxXrSjpTcBW23cCd0r6lKRv2b675jgjImoxl1aTHNWe7U+XP78KfLV3/645/QD8i+2/qbz2neqpEBHRFN0LnZqq7gT/jHIMvsrAybbHG4PfCXQAbP97787KkM8TdQUZEVELm85cmUVj+9cGeM2EaylU6uUip4honjnUg4+ImFOcW/ZFRLRTgzvwSfAREYMq5rg3N8MnwUdETEMSfEREK5nOyByZRRMRMZdkiCYiosWS4CMi2ioJPiKinRqc35PgIyIG5pxkjYhope4t+5qqdQn+4Gct4oPvOmXYYYzyoTNOG3YI47Kb2wOJaLok+IiIlkqCj4hoIxuy2FhERDulBx8R0UIGOunBR0S0UJYqiIhor9zwIyKilZwefEREWyXBR0S0UJYLjohoMY8kwUdEtFJ68BERbeScZI2IaK0k+IiIFspywRERbWVwjTf8kLQSWAGMALfbvqBn/3OB95ebI8AHbf/veO0lwUdEDKy+MXhJC4BVwLG2LWmdpKW2Hyz3C1gDvN321n7anDeNYK6aZP8lkp4x3XYiIpqsmAs/+QNYLGlj5bG6p6kjgZu86xPjWmB5Zf8vAd8DPibpCklvnSy2SXvwkj5v+6Ty+SnAI7a/COxRln2iPDDAIuAc29eVbc+rtPNe4LUUXyseAN5t+4luO5V6bwHeCKh8/Aj4gO37Jos1ImK2TaEHv9n2sgn2LwKqPfOtwNLK9mHAS4Djbf9E0qclPWD7G+M12M8QzWGS3lc+PwK4orrT9tnd55JOH6tNSS8EXmD71eX2qcApwCU99ZYDRwGvs72zLDsYuJri0y0iojHsWhcb2wK8uLK9sCzreoyih/+TcvuLFDl53ATfzxDNQ7bX2F4D3FAp/1VJt0h6fqXsVcCGMdrYBhwo6aBy2OYIYNM4x+tQnJzucs/2KJJWd7/2PPzww5O9n4iI2ricCz/Zow93AEeXY+0AxwO3VvbfDbyysv3LwL9O1GA/Pfj54zzfYPuE7oakl1AM31Qz7PXlEM9F
kj4KXAjsBVxpu/phAYDtmyUdUr6uOkQz4ViT7bXAWoBly5Y1d85SRLSM6XTqmUVje5ukdcB6STuBe23fX9m/SdJXJK0HdgD/bftrE7XZT4L/saSvU/SiR4AzeitIWkRxdve0nl3HlUGvAJYAdwKHAksknQ3cNcbxNgGf6n4ASLow4+8R0Ug1LzZmez2wvlom6RrgJNsjti8GLu63vUkTvO1Vkq62/YaeXV8rD/5K4FzgTNs/HKeNK8u6hwFnlgHeQHHSda+e6vv1lB066buIiBiWGb7hh+3XD/rafufBj6pn+6Ly6WuAk20/1u9By7rLASR9ofz5NmAlsE+x+eQUoAMl3VI+f6ftf+v3OBERM6m4knXYUYyv7wRfSbJdpkjsHx7nNTspTphWdcYoewJgql89IiKaYLdfqsD2cVNt2PaoE6O2HwLe01N24lTbjohoBJtOjUsV1C1LFURETMNu34OPiIjRsppkRERbNfwsaxJ8RMTAckeniIjWcnPPsSbBR0QMzNS2VMFMSIKPiBhQTrJGRLRYEnxERCu5zvXga5cEHxExqJpXk6xbEnxExHQkwUdEtI+BToZoZs9Ip8MjO3YMO4xR9tlnv2GHMK4f//iRYYcQsXuq956stWtdgo+ImD25kjUiorWS4CMiWioJPiKihWxwbvgREdFODe7AJ8FHRAwuJ1kjIlorCT4ioo2yVEFERDuZXOgUEdFSxrnhR0REC2WIJiKivRqc35PgIyKmI2PwEREtlHuyRkS0VcPH4OcN+kJJV02y/xJJz5huOxERzWU6nU5fj2GYtAcv6fO2TyqfnwI8YvuLwB5l2SeAXyqrLwLOsX1d2fa8SjvvBV4LjAAPAO+2/US3nUq9twBvBFQ+fgR8wPZ903ifEREzYncfgz9M0vvK50cAV1R32j67+1zS6WO1KemFwAtsv7rcPhU4Bbikp95y4CjgdbZ3lmUHA1cDR/b1jiIiZksxCD/sKMbVzxDNQ7bX2F4D3FAp/1VJt0h6fqXsVcCGMdrYBhwo6aBy2OYIYNM4x+tQ/Nq63LMdEdEI3fzez2MY+unBzx/n+QbbJ3Q3JL2EYvjm4Uqd68shnoskfRS4ENgLuNJ29cMCANs3SzqkfF11iOatEwUoaTWwGuDnDzmkj7cUEVGPOk+ySloJrKAYyr7d9gVj1HkacDmw3fbbJ2qvnwT/Y0lfp/iwGgHOGOOAi4A1wGk9u46zvU3SCmAJcCdwKLBE0tnAXWMcbxPwqe4HgKQLJxt/t70WWAvw8le8Ir39iJgdNp2abvghaQGwCjjWtiWtk7TU9oM9Vc8FLgNOmqzNSRO87VWSrrb9hp5dXyuDemV5wDNt/3CcNq4s6x4GnAlcTDHc81qKHn3Vfj1lh04WY0TEsEyhB79Y0sbK9tqyc9p1JHCTdzV4LbAceDLBS3oTsBH4dj8H7Hce/Kh6ti8qn74GONn2Y322RVl3OYCkL5Q/3wasBPYpNtUdljlQ0i3l83fa/rd+jxMRMZOmeKHTZtvLJti/CNha2d4KLO1uSHo5cKDtvyk7y5PqO8FXkmyXKRL7h8d5zU6KE6ZVnTHKngCwfTFFzz4iYrdR4xj8FuDFle2FZVnXycD+kv4CWAC8QtI7bP/5eA32leBtHzfVSG2POjFq+yHgPT1lJ0617YiIZqh1iswdwBmS/rgcpjke+NiTR7J/r/u87MGfO1FyhyxVEBExOINruki1nJCyDlgvaSdwr+37x6k+QjFKMqEk+IiIaahzGQLb64H11TJJ1wAn2R6p1PsecPpk7SXBR0QMaDZWk7T9+kFfmwQfETGohq8mmQQfETEw7/aLjUVExHjSg4+IaCc3eC3EJPiIiAHZptMZmbzikCTBR0RMQ06yRkS0VBJ8RERLJcFHRLSQbVzXWgUzoHUJ/t577mHhvvsOO4zdjIYdwJia+j9OcbOxiEJT/06hhQk+ImI2ZYgmIqKlkuAjIlopY/AREa3kLDYWEdFeSfAREa1k
XOMNP+qWBB8RMQ0mCT4iopUyRBMR0UI5yRoR0VpOgo+IaKusBx8R0VLpwUdEtFExCD/sKMaVBB8RMSCTe7JGRLRW1qKJiGilzKIZk6Rn2f7hJHUOtP2D2YopImKqOg1eqmDeMA4q6WDgryrbn5O0aIyqn5Z02GzFFRExFcU51k5fj2GY9QQv6eeAS4HTy+2FwHzbW8ao/nbgzyT9wiyGGBHRJ5f3ZZ38MQyzmuDL5H4x8Dbb3yuLTwY+L2lvSZdJ+mNJ5wOUQzinAn8q6TkTtLta0kZJG2f2HURE9OhOlZzsMQSzNgZfDsF0k/v/VHa9DjgBeDbFB857XPm4s/2wpFOByySttv2/vW3bXgusLY/T3DMeEdE6TZ4mOZs9+D2BEeDRboGkpcBDtn9q+9sU4/IXSVrR89odwM+AvWcr2IiIfmSIBih73ucAl0taUBa/Bfhspc7Ntt8FnNqtI2lv4HJgje3vzFa8ERGTsU2nM9LXYxhmdQze9jeBcymS/H7ArwAbACS9SNLlkv4S+G/b2yXtA6yjSO53zmasERH9aHIPftbnwdv+lqT3AV8AvtEdb7d9H0WPvupS4MO2753lMCMi+pILnXrYfkDSabb/a5Kq59l+cFaCiogYQJ0JXtJKYAXF+crbbV/Qs//TQAdYCNxg+7OjW9llaFey9pHcSXKPiGYz1HQRU3necRVwrG1LWidpaTUP2v6dsq6AW6mcwxxL1qKJiBiQDZ3+E/zinmt11pZTvLuOBG6qTBO/FlgOjNXR3RPYOtkBk+AjIqZhCkM0m20vm2D/Ip6atLcCS8ep+xHggnH2PSkJPiJiYK5znZktwIsr2wvLsqeQdBZwj+3bJmtwKIuNRUS0RY3TJO8Aji7H1wGOpxhnf5KkdwA7bF/RT4PpwUdETENds2hsb5O0DlgvaSdwr+37u/slHQm8D7hR0l+UxefZfni8NpPgIyIGVKwjVt80SdvrgfXVMknXACfZ/kfg0Km0lwQfETEwY8/sMgS2Xz/oa5PgIyKmIVeyRkS0VBJ8REQr5abbERGt1L0na1O1McFvBr5bU1uLy/aapua4auuB1BrXrunAtZgj/5a1mStxPXu6DaQHP4tsP7OutiRtnOTS4qFIXFPX1NgS19Q0Ly7jTnrwERGt1OR7sibBR0RMQ8bgd19rJ68yFIlr6poaW+KamkbFVfeVrHVTk4OLiGiypz99Ly9Zclhfdb///Qfunu3zB+nBR0RMQycnWSMi2ilj8BERbVQMwg87inElwUdEDMhkmmRERGs1eaJKEnxExDRkDD4iopWcWTQREW3U9AudkuAjIqYhCT4iopUMGYOPiGinTJOMiGipDNFERLSQbTqdkWGHMa4k+IiIaUgPPiKipZLgIyJaKgk+IqKtkuAjItrHNh3nJGtERCtliCYioqWS4CMiWslJ8BERbZX14CMiWijLBUdEtJbTg4+IaKsk+IiIlqpziEbSSmAFMALcbvuCqezvlQQfETG4L9te3GfdvSRtrGyvtb22uyFpAbAKONa2Ja2TtNT2g/3sH0sSfETEgGwfU2NzRwI3eddXgmuB5cCDfe4fZV6NwUVExOAWAVsr21vLsn73j5IEHxHRDFuAAyrbC8uyfvePkgQfEdEMdwBHS1K5fTxw6xT2j5Ix+IiIBrC9TdI6YL2kncC9tu/vd/9Y1OSrsCIi5jpJ1wAn2VNflzgJPiKipTIGHxHRUknwEREtlQQfEdFSSfARES2VBB8R0VL/DxH7ECe4NysJAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "encoder.train()\n", "decoder.train()\n", "losses = []\n", "# After the length-sort in build_batch, batch row 1 holds the sentence src[0];\n", "# that example's attention map is plotted at every step.\n", "SHOW_IDX = 1\n", "for i in range(STEP):\n", "    encoder.zero_grad()\n", "    decoder.zero_grad()\n", "    enc_output, enc_hidden = encoder(inputs, lengths.tolist())\n", "    outputs, attns = decoder(enc_hidden, enc_output, lengths.tolist(),\n", "                             targets.size(1), targets, is_eval=False)\n", "\n", "    # Targets drop the start token and are flattened batch-major, so the decoder\n", "    # logits are expected in the same (batch * (trg_len - 1), vocab) layout.\n", "    loss = loss_function(outputs, targets[:, 1:].contiguous().view(-1))\n", "    losses.append(loss.item())\n", "\n", "    # check training & show attentions\n", "    print('STEP: {}, loss: {:.4f}'.format(i, loss.item()))\n", "    # Regroup the flat predictions per sentence so the y-axis labels belong to\n", "    # the same batch row as the plotted attention map.\n", "    preds = outputs.max(1)[1].view(targets.size(0), -1)\n", "    translated = [trg_itos[idx] for idx in preds[SHOW_IDX].tolist()]\n", "    fig, ax = plt.subplots()\n", "    cax = ax.matshow(attns[SHOW_IDX].detach().numpy(), cmap='bone')\n", "    fig.colorbar(cax)\n", "    # One explicit tick per token keeps every label on its own row/column.\n", "    ax.set_xticks(range(len(src[0])))\n", "    ax.set_xticklabels(src[0])\n", "    ax.set_yticks(range(len(translated)))\n", "    ax.set_yticklabels(translated)\n", "    plt.show()\n", "    plt.close(fig)\n", "\n", "    loss.backward()\n", "    enc_optimizer.step()\n", "    dec_optimizer.step()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Test" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### IWSLT 데이터셋\n", "\n", "IWSLT 2016 TED talk 번역 문제입니다. 독어-영어의 번역이 목적입니다. \n", "\n", "훈련시간이 오래 걸리기 때문에, 미리 훈련을 시켰습니다. 
훈련 하이퍼파라미터는 아래와 같습니다.\n", "\n", "\n", "\n", "| 변수명 |1차 훈련 | 2차 훈련 | 3차 훈련 | 설명 | \n", "|--|--|--|--|--|\n", "| BATCH| 50 | 50 | 50 | 배치 사이즈 | \n", "| MAX_LEN | 30 | 30 | 30 | 훈련 시킬 문장의 최대 길이 |\n", "| MIN_FREQ | 2 | 2 | 2 | 최소 단어 출현 횟수 |\n", "| EMBED | 256 | 256 | 256 | 임베딩 크기 |\n", "| HIDDEN | 512 | 512 | 512 | 히든 크기 |\n", "| ENC_N_LAYER | 3 | 3 | 3 | 인코더 층 개수 |\n", "| DEC_N_LAYER | 1 | 1 | 1 | 디코더 층 개수 |\n", "| L_NORM | True | True | True | 임베딩 후, layer normalization |\n", "| DROP_RATE | 0.2 | 0.2 | 0.2 | 임베딩 후, dropout 확률 |\n", "| METHOD | general | general | general | attention 방법 |\n", "| LAMBDA | 0.00001 | 0.00001 | 0.0001 | weight decay rate |\n", "| LR | 0.001 | 0.0001 | 1.0 | 학습률 |\n", "| DECLR | 5.0 | 5.0 | - | decoder의 학습강도, 학습률에 곱해준다 |\n", "| OPTIM | adam | adam | adelta | 최적화 알고리즘 방법 |\n", "| STEP | 30 | 20 | 20 | 훈련 step 기준 1/3, 3/4 지점에서 각각 0.1 을 곱해서 학습률을 조정합니다.|\n", "| TF | True | True | True | teacher forcing, 모델에게 다음 토큰을 알려주는지 여부 |\n", "\n", "각 학습 세이브 체크포인트\n", "\n", "* 1차 훈련: [8/30] (train) loss 2.4335 (valid) loss 5.6971 \n", "* 2차 훈련: [1/30] (train) loss 2.3545 (valid) loss 5.6575 \n", "* 3차 훈련: [6/20] (train) loss 1.9401 (valid) loss 5.4970 " ] },
{ "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['runtrain_big.sh', 'runtrain_big2.sh', 'runtrain_loaded2.sh', 'runtrain_loaded.sh', 'runtrain.sh', 'runtrain_wmt.sh']\n" ] } ], "source": [ "import sys\n", "import os\n", "# make the code in ./model importable (settings, train, ...)\n", "sys.path.append(os.getcwd() + '/model')\n", "# list the available run scripts\n", "print([d for d in os.listdir(path='./model') if d.endswith('.sh')])" ] },
{ "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "import importlib\n", "import torch\n", "import numpy as np\n", "import matplotlib.pylab as plt \n", "import matplotlib.ticker as ticker\n", "import settings\n", "from train import import_data, build_model, validation\n", "\n", "%run ./model/build_config.py -r \"./model/runtrain_big2.sh\" -c \"./model/settings.py\" -n\n", "importlib.reload(settings)\n", "DEVICE = None" ] },
{ "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Source Language: 51823 words, Target Language: 30980 words\n", "Training Examples: 189416, Validation Examples: 946\n", "Building Model ...\n" ] } ], "source": [ "SRC, TRG, train, _, test, _, _, test_loader = import_data(config=settings, device=DEVICE, is_test=True)\n", "enc, dec, loss_function, *_ = build_model(config=settings, src_field=SRC, trg_field=TRG, device=DEVICE)" ] },
{ "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "test loss is 6.1745\n" ] } ], "source": [ "# map every stored tensor to CPU, whichever device the checkpoint was saved from\n", "enc.load_state_dict(torch.load(settings.SAVE_ENC_PATH, map_location='cpu'))\n", "dec.load_state_dict(torch.load(settings.SAVE_DEC_PATH, map_location='cpu'))\n", "print(\"test loss is {:.4f}\".format(validation(settings, enc, dec, test_loader, loss_function)))" ] },
{ "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "def evaluate(test_data, max_len, src_field, trg_field, stop_token=None):\n", "    \"\"\"Translate one random example of `test_data` with the global `enc` / `dec` and print it.\n", "\n", "    Args:\n", "        test_data: dataset with an `.examples` list; each example has `.src` / `.trg` token lists.\n", "        max_len: maximum number of decoding steps handed to the decoder.\n", "        src_field, trg_field: source / target fields (numericalize, vocab.itos, vocab.stoi).\n", "        stop_token: token that ends decoding. None (default) means `trg_field.eos_token`\n", "            (assumes a torchtext-style Field; pass the token explicitly otherwise).\n", "\n", "    Returns:\n", "        (attention weights with dim 0 squeezed, (src_sent, trg_sent, translated)), where\n", "        `translated` is the full predicted token list, stop token included if it was emitted.\n", "    \"\"\"\n", "    if stop_token is None:\n", "        stop_token = trg_field.eos_token\n", "    test_ex = np.random.choice(test_data.examples)\n", "    src_sent = test_ex.src\n", "    trg_sent = test_ex.trg\n", "\n", "    # inference only: disable dropout and autograd, then restore the previous train/eval modes\n", "    was_training = (enc.training, dec.training)\n", "    enc.eval()\n", "    dec.eval()\n", "    try:\n", "        with torch.no_grad():\n", "            inputs, lengths = src_field.numericalize(([src_sent], [len(src_sent)]))\n", "            enc_output, enc_hidden = enc(inputs, lengths.tolist())\n", "            outputs, attns = dec(enc_hidden, enc_output, lengths.tolist(), max_len,\n", "                                 targets=None, is_test=True,\n", "                                 stop_idx=trg_field.vocab.stoi[stop_token])\n", "    finally:\n", "        enc.train(was_training[0])\n", "        dec.train(was_training[1])\n", "\n", "    preds = outputs.max(1)[1]\n", "    translated = [trg_field.vocab.itos[idx] for idx in preds.tolist()]\n", "    # hide the stop token in the printout only if decoding actually emitted it;\n", "    # an unconditional [:-1] would drop a real word when decoding stops at max_len\n", "    shown = translated[:-1] if translated and translated[-1] == stop_token else translated\n", "\n", "    print(\" < input: {}\".format(' '.join(src_sent)))\n", "    print(\" > target: {}\".format(' '.join(trg_sent)))\n", "    print(\" > predict: {}\".format(' '.join(shown)))\n", "    return attns.squeeze(0).detach(), (src_sent, trg_sent, translated)" ] },
{ "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "def show_attention(input_sentence, output_words, attentions):\n", "    \"\"\"Plot `attentions` as a matrix: source tokens on the x axis, predicted tokens on the y axis.\"\"\"\n", "    # Set up figure with colorbar\n", "    fig = plt.figure()\n", "    ax = fig.add_subplot(111)\n", "    cax = ax.matshow(attentions.numpy(), cmap='bone')\n", "    fig.colorbar(cax)\n", "\n", "    # Set up axes\n", "    ax.set_xticklabels([''] + input_sentence + [''], rotation=90)\n", "    ax.set_yticklabels([''] + output_words)\n", "\n", "    # Show label at every tick\n", "    ax.xaxis.set_major_locator(ticker.MultipleLocator(1))\n", "    ax.yaxis.set_major_locator(ticker.MultipleLocator(1))\n", "\n", "    plt.show()\n", "    plt.close(fig)" ] },
{ "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "def eval_and_show(dataset, max_len, src_field, trg_field, stop_token=None):\n", "    \"\"\"Translate a random example of `dataset` with `evaluate` and plot its attention.\n", "\n", "    stop_token=None lets `evaluate` fall back to the target field's end-of-sentence token.\n", "    \"\"\"\n", "    attns, sents = evaluate(dataset, max_len, src_field, trg_field, stop_token)\n", "    show_attention(sents[0], sents[2], attns)" ] },
{ "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " < input: und so würden zwei davon , in den körper implantiert , weniger als einen zehner wiegen .\n", " > target: so two of them implanted in the body would weigh less than a dime .\n", " > predict: and so , that would two two of them in the body , less than a car .\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "eval_and_show(train, settings.MAX_LEN, SRC, TRG)" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " < input: ich habe das privileg , diesen 20ern wie emma jeden tag mitzuteilen : die 30er sind nicht die neuen 20er , macht euch also auf , erlangt ein identitätskapital , nutzt eure schwachen bande sucht euch eure familie selbst aus .\n", " > target: it 's what i now have the privilege of saying to twentysomethings like emma every single day : thirty is not the new 20 , so claim your adulthood , get some identity capital , use your weak ties , pick your family .\n", " > predict: i 've been taking the privilege , which is the day that i met every day : every new new , so you listen to a new new , not so much , your own friends .\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "eval_and_show(test, settings.MAX_LEN, SRC, TRG)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.5" } }, "nbformat": 4, "nbformat_minor": 2 }