{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 關於Keras序列到序列學習的十分鐘介紹\n", "\n", "我經常看到有人問這個問題 - 如何在Keras中實現RNN序列到序列(sequence-to-sequence)的學習?\n", "\n", "這篇文章是對\"sequence-to-sequence\"一個簡短的介紹。\n", "\n", "請注意,這篇文章假設你已經有一些遞歸網絡(recurrent networks)和Keras的經驗。\n", "\n", "![seq2seq](http://pytorch.org/tutorials/_images/seq2seq.png)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 什麼是從序列到序列 (seq2seq) 的學習?" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "序列到序列(Seq2Seq)學習是關於訓練模型以將來自一個領域(例如,英語的句子)的序列轉換成另一個領域(例如翻譯成中文的相同句子)的序列的模型。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```\n", "\"the cat sat on the mat\" -> [Seq2Seq model] -> \"那隻貓坐在地毯上\"\n", "\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "這可以用於機器翻譯或任何Q&A(根據自然語言問題生成自然語言答案) - 通常,只要您需要生成文本,就可以使用它。\n", "\n", "有多種方式來處理這樣的任務,或者使用RNN或者使用一維的卷積網絡(convnets)。這裡我們將重點放在RNN的使用。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 簡單的案例:當輸入和輸出序列具有相同的長度時\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "當輸入序列和輸出序列長度相同時,您可以簡單地用Keras LSTM或GRU層(或其堆疊)來實現這些模型。以下的示範就是這種情況,它顯示瞭如何教導RNN學習如何對數字進行相加(加法):\n", "\n", "![addition](https://blog.keras.io/img/seq2seq/addition-rnn.png)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### STEP 1. 
Import the required libraries" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using TensorFlow backend.\n" ] } ], "source": [ "from keras.models import Sequential\n", "from keras import layers\n", "from keras.utils import plot_model\n", "import numpy as np\n", "from six.moves import range\n", "from IPython.display import Image" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": true }, "outputs": [], "source": [ "class CharacterTable(object):\n", " \"\"\"\n", " Given a set of characters:\n", " + Encode them to a one-hot integer representation\n", " + Decode a one-hot (or integer) representation back to the original characters\n", " + Decode a vector of character probabilities to the most likely character\n", " \"\"\"\n", " def __init__(self, chars):\n", " \"\"\"Initialize the character table.\n", " \n", " # Arguments:\n", " chars: characters that can appear in the input\n", " \"\"\"\n", " self.chars = sorted(set(chars))\n", " self.char_indices = dict((c, i) for i, c in enumerate(self.chars))\n", " self.indices_char = dict((i, c) for i, c in enumerate(self.chars))\n", " \n", " def encode(self, C, num_rows):\n", " \"\"\"One-hot encode the given string C.\n", " \n", " # Arguments:\n", " C: string to be encoded\n", " num_rows: maximum number of rows in the returned one-hot encoding. This is\n", " used to make sure every input yields an output with the same number of rows.\n", " \"\"\"\n", " x = np.zeros((num_rows, len(self.chars)))\n", " for i, c in enumerate(C):\n", " x[i, self.char_indices[c]] = 1\n", " return x\n", " \n", " def decode(self, x, calc_argmax=True):\n", " \"\"\"Decode the given encoding (vector) back to characters.\n", " \n", " # Arguments:\n", " x: a 2D array of character probabilities, or a vector of character indices\n", " calc_argmax: whether to use argmax to find the character index with the highest probability\n", " \"\"\"\n", " if calc_argmax:\n", " x = x.argmax(axis=-1)\n", " return ''.join(self.indices_char[i] for i in x)\n", " \n", "class colors:\n", " ok = '\\033[92m'\n", " fail = '\\033[91m'\n", " close = '\\033[0m'" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### STEP 2. 
Parameters and training-data generation" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Generating data...\n", "Total addition questions: 50000\n" ] } ], "source": [ "# Parameters for the model and the dataset\n", "TRAINING_SIZE = 50000 # number of samples in the training dataset\n", "DIGITS = 3 # max number of digits in each addend\n", "INVERT = True # reverse the question strings (see below)\n", "\n", "# Maximum length of the input 'int + int' (e.g. '345+678')\n", "MAXLEN = DIGITS + 1 + DIGITS\n", "\n", "# All the characters we will use (digits, the plus sign, and the space for padding)\n", "chars = '0123456789+ '\n", "ctable = CharacterTable(chars) # create a CharacterTable instance\n", "\n", "questions = [] # training questions \"xxx+yyy\"\n", "expected = [] # training labels (the sums)\n", "seen = set()\n", "\n", "print('Generating data...') # generate the training data\n", "\n", "while len(questions) < TRAINING_SIZE:\n", " # Generate a random integer with 1 to DIGITS digits\n", " f = lambda: int(''.join(np.random.choice(list('0123456789'))\n", " for i in range(np.random.randint(1, DIGITS+1))))\n", " a, b = f(), f()\n", " # Skip any questions we've already seen, as well as any such that x+Y == Y+x (hence the sorting)\n", " key = tuple(sorted((a, b)))\n", " if key in seen:\n", " continue \n", " seen.add(key)\n", " \n", " # Pad the question with spaces when it is shorter than MAXLEN\n", " q = '{}+{}'.format(a, b)\n", " query = q + ' ' * (MAXLEN - len(q))\n", " ans = str(a + b)\n", " \n", " # Answers can be at most DIGITS + 1 characters long\n", " ans += ' ' * (DIGITS + 1 - len(ans))\n", " if INVERT:\n", " # Reverse the question string, e.g. 
'12+345' becomes '543+21'\n", " query = query[::-1]\n", " questions.append(query)\n", " expected.append(ans)\n", " \n", "print('Total addition questions:', len(questions))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### STEP 3. Data preprocessing" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Vectorization...\n", "Feature data: (50000, 7, 12)\n", "Label data: (50000, 4, 12)\n", "Training Data:\n", "(45000, 7, 12)\n", "(45000, 4, 12)\n", "Validation Data:\n", "(5000, 7, 12)\n", "(5000, 4, 12)\n" ] } ], "source": [ "# Vectorize the data into the structure an LSTM expects -> [samples, timesteps, features]\n", "print('Vectorization...')\n", "x = np.zeros((len(questions), MAXLEN, len(chars)), dtype=np.bool) # initialize a 3D numpy ndarray (feature data)\n", "y = np.zeros((len(questions), DIGITS + 1, len(chars)), dtype=np.bool) # initialize a 3D numpy ndarray (label data)\n", "\n", "# One-hot encode the \"feature\" strings into the structure the LSTM expects -> [samples, timesteps, features]\n", "for i, sentence in enumerate(questions):\n", " x[i] = ctable.encode(sentence, MAXLEN) # <--- each string becomes a (timesteps, features) matrix\n", "\n", "print(\"Feature data: \", x.shape)\n", "\n", "# One-hot encode the \"label\" strings into the structure the LSTM expects -> [samples, timesteps, features]\n", "for i, sentence in enumerate(expected):\n", " y[i] = ctable.encode(sentence, DIGITS + 1) # <--- each string becomes a (timesteps, features) matrix\n", "\n", "print(\"Label data: \", y.shape)\n", "\n", "# Shuffle (x, y)\n", "indices = np.arange(len(y))\n", "np.random.shuffle(indices)\n", "x = x[indices]\n", "y = y[indices]\n", "\n", "# Hold out 10% of the data for validation\n", "split_at = len(x) - len(x) // 10\n", "(x_train, x_val) = x[:split_at], x[split_at:]\n", "(y_train, y_val) = y[:split_at], y[split_at:]\n", "\n", "print('Training Data:')\n", "print(x_train.shape)\n", "print(y_train.shape)\n", "\n", "print('Validation Data:')\n", "print(x_val.shape)\n", "print(y_val.shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### STEP 4. Build the network architecture" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { 
"name": "stdout", "output_type": "stream", "text": [ "Build model...\n", "_________________________________________________________________\n", "Layer (type) Output Shape Param # \n", "=================================================================\n", "lstm_1 (LSTM) (None, 128) 72192 \n", "_________________________________________________________________\n", "repeat_vector_1 (RepeatVecto (None, 4, 128) 0 \n", "_________________________________________________________________\n", "lstm_2 (LSTM) (None, 4, 128) 131584 \n", "_________________________________________________________________\n", "time_distributed_1 (TimeDist (None, 4, 12) 1548 \n", "_________________________________________________________________\n", "activation_1 (Activation) (None, 4, 12) 0 \n", "=================================================================\n", "Total params: 205,324\n", "Trainable params: 205,324\n", "Non-trainable params: 0\n", "_________________________________________________________________\n" ] } ], "source": [ "# 可以試著替代其它種的rnn units, 比如,GRU或SimpleRNN\n", "RNN = layers.LSTM\n", "HIDDEN_SIZE = 128\n", "BATCH_SIZE = 128\n", "LAYERS = 1\n", "\n", "print('Build model...')\n", "model = Sequential()\n", "\n", "# ===== 編碼 (encoder) ====\n", "\n", "# 使用RNN“編碼”輸入序列,產生HIDDEN_SIZE的輸出。\n", "# 注意:在輸入序列長度可變的情況下,使用input_shape =(None,num_features)\n", "model.add(RNN(HIDDEN_SIZE, input_shape=(MAXLEN, len(chars)))) # MAXLEN代表是timesteps, 而len(chars)是one-hot編碼的features\n", "\n", "# 作為解碼器RNN的輸入,重複提供每個時間步的RNN的最後一個隱藏狀態。\n", "# 重複“DIGITS + 1”次,因為這是最大輸出長度,例如當DIGITS = 3時,最大輸出是999 + 999 = 1998(長度為4)。\n", "model.add(layers.RepeatVector(DIGITS+1))\n", "\n", "# ==== 解碼 (decoder) ====\n", "# 解碼器RNN可以是多層堆疊或單層。\n", "for _ in range(LAYERS):\n", " # 通過將return_sequences設置為True,不僅返回最後一個輸出,而且還以(num_samples,timesteps,output_dim)\n", " # 的形式返回所有輸出。這是必要的,因為下面的TimeDistributed需要第一個維度是時間步長。\n", " model.add(RNN(HIDDEN_SIZE, return_sequences=True))\n", "\n", "# 對輸入的每個時間片推送到密集層來對於輸出序列的每一時間步,決定選擇哪個字符。\n", 
"model.add(layers.TimeDistributed(layers.Dense(len(chars))))\n", "\n", "model.add(layers.Activation('softmax'))\n", "model.compile(loss='categorical_crossentropy',\n", " optimizer='adam',\n", " metrics=['accuracy'])\n", "\n", "model.summary()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### STEP 5.訓練模型/驗證評估\n", "\n", " 我們將進行50次的訓練,並且在每次訓練之後就進行檢查。" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "--------------------------------------------------\n", "Iteration 1\n", "Train on 45000 samples, validate on 5000 samples\n", "Epoch 1/1\n", "45000/45000 [==============================] - 10s 221us/step - loss: 1.8837 - acc: 0.3226 - val_loss: 1.7765 - val_acc: 0.3463\n", "Q 249+82 T 331 \u001b[91m☒\u001b[0m 109 \n", "Q 800+51 T 851 \u001b[91m☒\u001b[0m 109 \n", "Q 6+346 T 352 \u001b[91m☒\u001b[0m 70 \n", "Q 95+816 T 911 \u001b[91m☒\u001b[0m 109 \n", "Q 2+116 T 118 \u001b[91m☒\u001b[0m 22 \n", "Q 3+874 T 877 \u001b[91m☒\u001b[0m 10 \n", "Q 34+868 T 902 \u001b[91m☒\u001b[0m 109 \n", "Q 1+118 T 119 \u001b[91m☒\u001b[0m 22 \n", "Q 68+909 T 977 \u001b[91m☒\u001b[0m 100 \n", "Q 926+78 T 1004 \u001b[91m☒\u001b[0m 100 \n", "\n", "--------------------------------------------------\n", "Iteration 2\n", "Train on 45000 samples, validate on 5000 samples\n", "Epoch 1/1\n", "45000/45000 [==============================] - 7s 166us/step - loss: 1.7156 - acc: 0.3660 - val_loss: 1.6534 - val_acc: 0.3875\n", "Q 76+64 T 140 \u001b[91m☒\u001b[0m 167 \n", "Q 49+497 T 546 \u001b[91m☒\u001b[0m 409 \n", "Q 235+0 T 235 \u001b[91m☒\u001b[0m 110 \n", "Q 317+97 T 414 \u001b[91m☒\u001b[0m 709 \n", "Q 588+7 T 595 \u001b[91m☒\u001b[0m 889 \n", "Q 745+9 T 754 \u001b[91m☒\u001b[0m 154 \n", "Q 37+37 T 74 \u001b[91m☒\u001b[0m 33 \n", "Q 54+87 T 141 \u001b[91m☒\u001b[0m 154 \n", "Q 660+775 T 1435 \u001b[91m☒\u001b[0m 1477\n", "Q 472+903 T 1375 \u001b[91m☒\u001b[0m 1329\n", "\n", 
"--------------------------------------------------\n", "Iteration 3\n", "Train on 45000 samples, validate on 5000 samples\n", "Epoch 1/1\n", "45000/45000 [==============================] - 8s 174us/step - loss: 1.5789 - acc: 0.4064 - val_loss: 1.5200 - val_acc: 0.4273\n", "Q 85+6 T 91 \u001b[91m☒\u001b[0m 10 \n", "Q 49+517 T 566 \u001b[91m☒\u001b[0m 540 \n", "Q 55+474 T 529 \u001b[91m☒\u001b[0m 540 \n", "Q 984+369 T 1353 \u001b[91m☒\u001b[0m 1331\n", "Q 52+490 T 542 \u001b[91m☒\u001b[0m 490 \n", "Q 540+96 T 636 \u001b[91m☒\u001b[0m 504 \n", "Q 175+861 T 1036 \u001b[91m☒\u001b[0m 1104\n", "Q 8+285 T 293 \u001b[91m☒\u001b[0m 882 \n", "Q 467+47 T 514 \u001b[91m☒\u001b[0m 444 \n", "Q 467+47 T 514 \u001b[91m☒\u001b[0m 444 \n", "\n", "--------------------------------------------------\n", "Iteration 4\n", "Train on 45000 samples, validate on 5000 samples\n", "Epoch 1/1\n", "45000/45000 [==============================] - 7s 162us/step - loss: 1.4369 - acc: 0.4614 - val_loss: 1.3570 - val_acc: 0.4923\n", "Q 539+109 T 648 \u001b[91m☒\u001b[0m 682 \n", "Q 58+828 T 886 \u001b[91m☒\u001b[0m 883 \n", "Q 408+15 T 423 \u001b[91m☒\u001b[0m 444 \n", "Q 16+588 T 604 \u001b[91m☒\u001b[0m 623 \n", "Q 716+93 T 809 \u001b[91m☒\u001b[0m 777 \n", "Q 3+694 T 697 \u001b[91m☒\u001b[0m 664 \n", "Q 870+40 T 910 \u001b[91m☒\u001b[0m 884 \n", "Q 788+87 T 875 \u001b[91m☒\u001b[0m 884 \n", "Q 0+205 T 205 \u001b[91m☒\u001b[0m 222 \n", "Q 946+68 T 1014 \u001b[91m☒\u001b[0m 1004\n", "\n", "--------------------------------------------------\n", "Iteration 5\n", "Train on 45000 samples, validate on 5000 samples\n", "Epoch 1/1\n", "45000/45000 [==============================] - 7s 156us/step - loss: 1.2804 - acc: 0.5239 - val_loss: 1.2137 - val_acc: 0.5504\n", "Q 492+88 T 580 \u001b[91m☒\u001b[0m 552 \n", "Q 6+685 T 691 \u001b[91m☒\u001b[0m 675 \n", "Q 228+651 T 879 \u001b[91m☒\u001b[0m 991 \n", "Q 75+688 T 763 \u001b[91m☒\u001b[0m 743 \n", "Q 14+640 T 654 \u001b[91m☒\u001b[0m 679 \n", "Q 21+291 T 312 
\u001b[91m☒\u001b[0m 203 \n", "Q 57+858 T 915 \u001b[91m☒\u001b[0m 842 \n", "Q 202+37 T 239 \u001b[91m☒\u001b[0m 243 \n", "Q 391+810 T 1201 \u001b[91m☒\u001b[0m 1111\n", "Q 90+987 T 1077 \u001b[91m☒\u001b[0m 1051\n", "\n", "--------------------------------------------------\n", "Iteration 6\n", "Train on 45000 samples, validate on 5000 samples\n", "Epoch 1/1\n", "45000/45000 [==============================] - 7s 159us/step - loss: 1.1528 - acc: 0.5759 - val_loss: 1.0907 - val_acc: 0.6036\n", "Q 788+704 T 1492 \u001b[91m☒\u001b[0m 1485\n", "Q 396+6 T 402 \u001b[91m☒\u001b[0m 490 \n", "Q 449+97 T 546 \u001b[91m☒\u001b[0m 548 \n", "Q 87+92 T 179 \u001b[91m☒\u001b[0m 188 \n", "Q 82+266 T 348 \u001b[91m☒\u001b[0m 355 \n", "Q 596+405 T 1001 \u001b[91m☒\u001b[0m 100 \n", "Q 50+24 T 74 \u001b[91m☒\u001b[0m 80 \n", "Q 354+17 T 371 \u001b[91m☒\u001b[0m 370 \n", "Q 42+381 T 423 \u001b[91m☒\u001b[0m 425 \n", "Q 450+76 T 526 \u001b[91m☒\u001b[0m 525 \n", "\n", "--------------------------------------------------\n", "Iteration 7\n", "Train on 45000 samples, validate on 5000 samples\n", "Epoch 1/1\n", "45000/45000 [==============================] - 8s 176us/step - loss: 1.0391 - acc: 0.6230 - val_loss: 1.0094 - val_acc: 0.6263\n", "Q 66+585 T 651 \u001b[91m☒\u001b[0m 641 \n", "Q 968+25 T 993 \u001b[91m☒\u001b[0m 990 \n", "Q 24+831 T 855 \u001b[91m☒\u001b[0m 864 \n", "Q 364+68 T 432 \u001b[91m☒\u001b[0m 434 \n", "Q 40+62 T 102 \u001b[91m☒\u001b[0m 110 \n", "Q 151+118 T 269 \u001b[91m☒\u001b[0m 299 \n", "Q 907+9 T 916 \u001b[91m☒\u001b[0m 913 \n", "Q 20+726 T 746 \u001b[91m☒\u001b[0m 745 \n", "Q 895+36 T 931 \u001b[91m☒\u001b[0m 939 \n", "Q 71+709 T 780 \u001b[91m☒\u001b[0m 789 \n", "\n", "--------------------------------------------------\n", "Iteration 8\n", "Train on 45000 samples, validate on 5000 samples\n", "Epoch 1/1\n", "45000/45000 [==============================] - 8s 173us/step - loss: 0.9626 - acc: 0.6523 - val_loss: 0.9241 - val_acc: 0.6655\n", "Q 504+15 T 519 
\u001b[91m☒\u001b[0m 528 \n", "Q 903+334 T 1237 \u001b[91m☒\u001b[0m 1253\n", "Q 896+809 T 1705 \u001b[91m☒\u001b[0m 1777\n", "Q 330+338 T 668 \u001b[91m☒\u001b[0m 675 \n", "Q 43+63 T 106 \u001b[91m☒\u001b[0m 102 \n", "Q 12+682 T 694 \u001b[91m☒\u001b[0m 698 \n", "Q 135+43 T 178 \u001b[91m☒\u001b[0m 174 \n", "Q 121+64 T 185 \u001b[91m☒\u001b[0m 184 \n", "Q 459+148 T 607 \u001b[91m☒\u001b[0m 610 \n", "Q 7+34 T 41 \u001b[91m☒\u001b[0m 33 \n", "\n", "--------------------------------------------------\n", "Iteration 9\n", "Train on 45000 samples, validate on 5000 samples\n", "Epoch 1/1\n", "45000/45000 [==============================] - 8s 177us/step - loss: 0.8857 - acc: 0.6836 - val_loss: 0.8589 - val_acc: 0.6865\n", "Q 31+486 T 517 \u001b[91m☒\u001b[0m 518 \n", "Q 843+187 T 1030 \u001b[91m☒\u001b[0m 1008\n", "Q 901+694 T 1595 \u001b[91m☒\u001b[0m 1690\n", "Q 155+836 T 991 \u001b[91m☒\u001b[0m 900 \n", "Q 8+253 T 261 \u001b[91m☒\u001b[0m 264 \n", "Q 8+428 T 436 \u001b[91m☒\u001b[0m 441 \n", "Q 254+59 T 313 \u001b[91m☒\u001b[0m 316 \n", "Q 57+718 T 775 \u001b[91m☒\u001b[0m 776 \n", "Q 406+13 T 419 \u001b[91m☒\u001b[0m 418 \n", "Q 836+28 T 864 \u001b[91m☒\u001b[0m 867 \n", "\n", "--------------------------------------------------\n", "Iteration 10\n", "Train on 45000 samples, validate on 5000 samples\n", "Epoch 1/1\n", "45000/45000 [==============================] - 7s 162us/step - loss: 0.8077 - acc: 0.7143 - val_loss: 0.7743 - val_acc: 0.7218\n", "Q 68+41 T 109 \u001b[91m☒\u001b[0m 102 \n", "Q 866+607 T 1473 \u001b[91m☒\u001b[0m 1470\n", "Q 552+35 T 587 \u001b[91m☒\u001b[0m 589 \n", "Q 339+68 T 407 \u001b[91m☒\u001b[0m 406 \n", "Q 78+672 T 750 \u001b[91m☒\u001b[0m 759 \n", "Q 773+45 T 818 \u001b[91m☒\u001b[0m 810 \n", "Q 953+28 T 981 \u001b[91m☒\u001b[0m 976 \n", "Q 474+927 T 1401 \u001b[91m☒\u001b[0m 1399\n", "Q 86+53 T 139 \u001b[91m☒\u001b[0m 136 \n", "Q 66+973 T 1039 \u001b[91m☒\u001b[0m 1036\n", "\n", "--------------------------------------------------\n", 
"Iteration 11\n", "Train on 45000 samples, validate on 5000 samples\n", "Epoch 1/1\n", "45000/45000 [==============================] - 7s 165us/step - loss: 0.7308 - acc: 0.7399 - val_loss: 0.6885 - val_acc: 0.7488\n", "Q 950+1 T 951 \u001b[92m☑\u001b[0m 951 \n", "Q 739+230 T 969 \u001b[91m☒\u001b[0m 975 \n", "Q 406+3 T 409 \u001b[91m☒\u001b[0m 408 \n", "Q 11+228 T 239 \u001b[91m☒\u001b[0m 249 \n", "Q 59+512 T 571 \u001b[91m☒\u001b[0m 579 \n", "Q 155+836 T 991 \u001b[91m☒\u001b[0m 997 \n", "Q 300+3 T 303 \u001b[92m☑\u001b[0m 303 \n", "Q 96+639 T 735 \u001b[91m☒\u001b[0m 731 \n", "Q 469+19 T 488 \u001b[91m☒\u001b[0m 484 \n", "Q 5+184 T 189 \u001b[92m☑\u001b[0m 189 \n", "\n", "--------------------------------------------------\n", "Iteration 12\n", "Train on 45000 samples, validate on 5000 samples\n", "Epoch 1/1\n", "45000/45000 [==============================] - 7s 165us/step - loss: 0.6097 - acc: 0.7817 - val_loss: 0.5211 - val_acc: 0.8122\n", "Q 776+284 T 1060 \u001b[91m☒\u001b[0m 1050\n", "Q 433+633 T 1066 \u001b[92m☑\u001b[0m 1066\n", "Q 1+176 T 177 \u001b[91m☒\u001b[0m 178 \n", "Q 465+125 T 590 \u001b[92m☑\u001b[0m 590 \n", "Q 55+69 T 124 \u001b[92m☑\u001b[0m 124 \n", "Q 447+303 T 750 \u001b[91m☒\u001b[0m 759 \n", "Q 478+555 T 1033 \u001b[91m☒\u001b[0m 1043\n", "Q 390+29 T 419 \u001b[91m☒\u001b[0m 418 \n", "Q 871+465 T 1336 \u001b[91m☒\u001b[0m 1347\n", "Q 787+36 T 823 \u001b[91m☒\u001b[0m 822 \n", "\n", "--------------------------------------------------\n", "Iteration 13\n", "Train on 45000 samples, validate on 5000 samples\n", "Epoch 1/1\n", "45000/45000 [==============================] - 8s 172us/step - loss: 0.4288 - acc: 0.8542 - val_loss: 0.3511 - val_acc: 0.8901\n", "Q 31+749 T 780 \u001b[92m☑\u001b[0m 780 \n", "Q 293+59 T 352 \u001b[92m☑\u001b[0m 352 \n", "Q 56+197 T 253 \u001b[91m☒\u001b[0m 255 \n", "Q 528+209 T 737 \u001b[91m☒\u001b[0m 738 \n", "Q 887+79 T 966 \u001b[92m☑\u001b[0m 966 \n", "Q 476+563 T 1039 \u001b[92m☑\u001b[0m 1039\n", "Q 302+115 T 
417 \u001b[91m☒\u001b[0m 438 \n", "Q 1+25 T 26 \u001b[91m☒\u001b[0m 36 \n", "Q 737+17 T 754 \u001b[91m☒\u001b[0m 755 \n", "Q 781+309 T 1090 \u001b[91m☒\u001b[0m 1099\n", "\n", "--------------------------------------------------\n", "Iteration 14\n", "Train on 45000 samples, validate on 5000 samples\n", "Epoch 1/1\n", "45000/45000 [==============================] - 8s 176us/step - loss: 0.2916 - acc: 0.9216 - val_loss: 0.2544 - val_acc: 0.9331\n", "Q 643+33 T 676 \u001b[92m☑\u001b[0m 676 \n", "Q 21+623 T 644 \u001b[92m☑\u001b[0m 644 \n", "Q 35+84 T 119 \u001b[92m☑\u001b[0m 119 \n", "Q 57+989 T 1046 \u001b[92m☑\u001b[0m 1046\n", "Q 199+941 T 1140 \u001b[91m☒\u001b[0m 1130\n", "Q 259+0 T 259 \u001b[91m☒\u001b[0m 250 \n", "Q 99+448 T 547 \u001b[92m☑\u001b[0m 547 \n", "Q 5+459 T 464 \u001b[91m☒\u001b[0m 463 \n", "Q 10+0 T 10 \u001b[91m☒\u001b[0m 11 \n", "Q 315+51 T 366 \u001b[92m☑\u001b[0m 366 \n", "\n", "--------------------------------------------------\n", "Iteration 15\n", "Train on 45000 samples, validate on 5000 samples\n", "Epoch 1/1\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "45000/45000 [==============================] - 8s 169us/step - loss: 0.2108 - acc: 0.9509 - val_loss: 0.1777 - val_acc: 0.9599\n", "Q 308+9 T 317 \u001b[92m☑\u001b[0m 317 \n", "Q 327+674 T 1001 \u001b[92m☑\u001b[0m 1001\n", "Q 106+579 T 685 \u001b[92m☑\u001b[0m 685 \n", "Q 14+957 T 971 \u001b[92m☑\u001b[0m 971 \n", "Q 860+158 T 1018 \u001b[91m☒\u001b[0m 1008\n", "Q 1+166 T 167 \u001b[92m☑\u001b[0m 167 \n", "Q 931+806 T 1737 \u001b[91m☒\u001b[0m 1747\n", "Q 754+5 T 759 \u001b[92m☑\u001b[0m 759 \n", "Q 636+976 T 1612 \u001b[92m☑\u001b[0m 1612\n", "Q 750+83 T 833 \u001b[92m☑\u001b[0m 833 \n", "\n", "--------------------------------------------------\n", "Iteration 16\n", "Train on 45000 samples, validate on 5000 samples\n", "Epoch 1/1\n", "45000/45000 [==============================] - 8s 173us/step - loss: 0.1572 - acc: 0.9669 - val_loss: 0.1312 - val_acc: 0.9752\n", "Q 
40+379 T 419 \u001b[92m☑\u001b[0m 419 \n", "Q 592+845 T 1437 \u001b[92m☑\u001b[0m 1437\n", "Q 81+756 T 837 \u001b[92m☑\u001b[0m 837 \n", "Q 513+25 T 538 \u001b[92m☑\u001b[0m 538 \n", "Q 62+339 T 401 \u001b[92m☑\u001b[0m 401 \n", "Q 462+70 T 532 \u001b[92m☑\u001b[0m 532 \n", "Q 563+658 T 1221 \u001b[92m☑\u001b[0m 1221\n", "Q 366+29 T 395 \u001b[92m☑\u001b[0m 395 \n", "Q 98+474 T 572 \u001b[92m☑\u001b[0m 572 \n", "Q 495+91 T 586 \u001b[92m☑\u001b[0m 586 \n", "\n", "--------------------------------------------------\n", "Iteration 17\n", "Train on 45000 samples, validate on 5000 samples\n", "Epoch 1/1\n", "45000/45000 [==============================] - 7s 159us/step - loss: 0.1106 - acc: 0.9813 - val_loss: 0.1020 - val_acc: 0.9818\n", "Q 13+861 T 874 \u001b[92m☑\u001b[0m 874 \n", "Q 103+89 T 192 \u001b[92m☑\u001b[0m 192 \n", "Q 99+88 T 187 \u001b[92m☑\u001b[0m 187 \n", "Q 870+91 T 961 \u001b[92m☑\u001b[0m 961 \n", "Q 447+303 T 750 \u001b[92m☑\u001b[0m 750 \n", "Q 88+21 T 109 \u001b[92m☑\u001b[0m 109 \n", "Q 711+282 T 993 \u001b[92m☑\u001b[0m 993 \n", "Q 7+154 T 161 \u001b[92m☑\u001b[0m 161 \n", "Q 536+17 T 553 \u001b[92m☑\u001b[0m 553 \n", "Q 698+25 T 723 \u001b[92m☑\u001b[0m 723 \n", "\n", "--------------------------------------------------\n", "Iteration 18\n", "Train on 45000 samples, validate on 5000 samples\n", "Epoch 1/1\n", "45000/45000 [==============================] - 8s 182us/step - loss: 0.0855 - acc: 0.9867 - val_loss: 0.0846 - val_acc: 0.9838\n", "Q 91+761 T 852 \u001b[92m☑\u001b[0m 852 \n", "Q 526+3 T 529 \u001b[92m☑\u001b[0m 529 \n", "Q 492+41 T 533 \u001b[92m☑\u001b[0m 533 \n", "Q 571+35 T 606 \u001b[92m☑\u001b[0m 606 \n", "Q 46+426 T 472 \u001b[92m☑\u001b[0m 472 \n", "Q 857+1 T 858 \u001b[92m☑\u001b[0m 858 \n", "Q 449+67 T 516 \u001b[92m☑\u001b[0m 516 \n", "Q 675+78 T 753 \u001b[92m☑\u001b[0m 753 \n", "Q 18+79 T 97 \u001b[92m☑\u001b[0m 97 \n", "Q 908+134 T 1042 \u001b[91m☒\u001b[0m 1041\n", "\n", 
"--------------------------------------------------\n", "Iteration 19\n", "Train on 45000 samples, validate on 5000 samples\n", "Epoch 1/1\n", "45000/45000 [==============================] - 8s 174us/step - loss: 0.0688 - acc: 0.9897 - val_loss: 0.0710 - val_acc: 0.9865\n", "Q 886+310 T 1196 \u001b[92m☑\u001b[0m 1196\n", "Q 914+108 T 1022 \u001b[92m☑\u001b[0m 1022\n", "Q 913+533 T 1446 \u001b[92m☑\u001b[0m 1446\n", "Q 454+207 T 661 \u001b[92m☑\u001b[0m 661 \n", "Q 614+43 T 657 \u001b[92m☑\u001b[0m 657 \n", "Q 824+134 T 958 \u001b[92m☑\u001b[0m 958 \n", "Q 70+826 T 896 \u001b[92m☑\u001b[0m 896 \n", "Q 88+70 T 158 \u001b[92m☑\u001b[0m 158 \n", "Q 96+843 T 939 \u001b[91m☒\u001b[0m 949 \n", "Q 78+920 T 998 \u001b[92m☑\u001b[0m 998 \n", "\n", "--------------------------------------------------\n", "Iteration 20\n", "Train on 45000 samples, validate on 5000 samples\n", "Epoch 1/1\n", "45000/45000 [==============================] - 8s 169us/step - loss: 0.0612 - acc: 0.9901 - val_loss: 0.0815 - val_acc: 0.9795\n", "Q 945+1 T 946 \u001b[92m☑\u001b[0m 946 \n", "Q 454+930 T 1384 \u001b[92m☑\u001b[0m 1384\n", "Q 667+100 T 767 \u001b[91m☒\u001b[0m 768 \n", "Q 938+13 T 951 \u001b[92m☑\u001b[0m 951 \n", "Q 836+28 T 864 \u001b[92m☑\u001b[0m 864 \n", "Q 38+4 T 42 \u001b[92m☑\u001b[0m 42 \n", "Q 348+865 T 1213 \u001b[92m☑\u001b[0m 1213\n", "Q 891+365 T 1256 \u001b[92m☑\u001b[0m 1256\n", "Q 328+93 T 421 \u001b[92m☑\u001b[0m 421 \n", "Q 181+336 T 517 \u001b[92m☑\u001b[0m 517 \n", "\n", "--------------------------------------------------\n", "Iteration 21\n", "Train on 45000 samples, validate on 5000 samples\n", "Epoch 1/1\n", "45000/45000 [==============================] - 7s 158us/step - loss: 0.0587 - acc: 0.9886 - val_loss: 0.0454 - val_acc: 0.9929\n", "Q 843+38 T 881 \u001b[92m☑\u001b[0m 881 \n", "Q 98+827 T 925 \u001b[92m☑\u001b[0m 925 \n", "Q 25+726 T 751 \u001b[92m☑\u001b[0m 751 \n", "Q 322+21 T 343 \u001b[92m☑\u001b[0m 343 \n", "Q 148+13 T 161 \u001b[92m☑\u001b[0m 161 \n", "Q 
418+587 T 1005 \u001b[92m☑\u001b[0m 1005\n", "Q 43+472 T 515 \u001b[92m☑\u001b[0m 515 \n", "Q 1+808 T 809 \u001b[92m☑\u001b[0m 809 \n", "Q 112+16 T 128 \u001b[92m☑\u001b[0m 128 \n", "Q 218+763 T 981 \u001b[92m☑\u001b[0m 981 \n", "\n", "--------------------------------------------------\n", "Iteration 22\n", "Train on 45000 samples, validate on 5000 samples\n", "Epoch 1/1\n", "45000/45000 [==============================] - 7s 164us/step - loss: 0.0360 - acc: 0.9963 - val_loss: 0.0838 - val_acc: 0.9748\n", "Q 965+110 T 1075 \u001b[92m☑\u001b[0m 1075\n", "Q 246+323 T 569 \u001b[92m☑\u001b[0m 569 \n", "Q 939+6 T 945 \u001b[92m☑\u001b[0m 945 \n", "Q 78+743 T 821 \u001b[92m☑\u001b[0m 821 \n", "Q 0+978 T 978 \u001b[92m☑\u001b[0m 978 \n", "Q 54+205 T 259 \u001b[92m☑\u001b[0m 259 \n", "Q 29+26 T 55 \u001b[92m☑\u001b[0m 55 \n", "Q 474+5 T 479 \u001b[92m☑\u001b[0m 479 \n", "Q 93+366 T 459 \u001b[92m☑\u001b[0m 459 \n", "Q 80+429 T 509 \u001b[92m☑\u001b[0m 509 \n", "\n", "--------------------------------------------------\n", "Iteration 23\n", "Train on 45000 samples, validate on 5000 samples\n", "Epoch 1/1\n", "45000/45000 [==============================] - 8s 168us/step - loss: 0.0381 - acc: 0.9938 - val_loss: 0.0551 - val_acc: 0.9863\n", "Q 89+974 T 1063 \u001b[92m☑\u001b[0m 1063\n", "Q 89+35 T 124 \u001b[92m☑\u001b[0m 124 \n", "Q 289+532 T 821 \u001b[92m☑\u001b[0m 821 \n", "Q 46+21 T 67 \u001b[92m☑\u001b[0m 67 \n", "Q 883+565 T 1448 \u001b[92m☑\u001b[0m 1448\n", "Q 53+454 T 507 \u001b[92m☑\u001b[0m 507 \n", "Q 60+97 T 157 \u001b[92m☑\u001b[0m 157 \n", "Q 580+4 T 584 \u001b[92m☑\u001b[0m 584 \n", "Q 18+58 T 76 \u001b[92m☑\u001b[0m 76 \n", "Q 57+579 T 636 \u001b[92m☑\u001b[0m 636 \n", "\n", "--------------------------------------------------\n", "Iteration 24\n", "Train on 45000 samples, validate on 5000 samples\n", "Epoch 1/1\n", "45000/45000 [==============================] - 7s 160us/step - loss: 0.0595 - acc: 0.9844 - val_loss: 0.0280 - val_acc: 0.9964\n", "Q 6+323 T 329 
\u001b[91m☒\u001b[0m 339 \n", "Q 515+237 T 752 \u001b[92m☑\u001b[0m 752 \n", "Q 90+8 T 98 \u001b[92m☑\u001b[0m 98 \n", "Q 256+88 T 344 \u001b[92m☑\u001b[0m 344 \n", "Q 566+500 T 1066 \u001b[92m☑\u001b[0m 1066\n", "Q 9+739 T 748 \u001b[92m☑\u001b[0m 748 \n", "Q 3+500 T 503 \u001b[92m☑\u001b[0m 503 \n", "Q 782+527 T 1309 \u001b[92m☑\u001b[0m 1309\n", "Q 97+422 T 519 \u001b[92m☑\u001b[0m 519 \n", "Q 87+65 T 152 \u001b[92m☑\u001b[0m 152 \n", "\n", "--------------------------------------------------\n", "Iteration 25\n", "Train on 45000 samples, validate on 5000 samples\n", "Epoch 1/1\n", "45000/45000 [==============================] - 8s 169us/step - loss: 0.0210 - acc: 0.9986 - val_loss: 0.0251 - val_acc: 0.9962\n", "Q 25+726 T 751 \u001b[92m☑\u001b[0m 751 \n", "Q 740+81 T 821 \u001b[92m☑\u001b[0m 821 \n", "Q 135+70 T 205 \u001b[92m☑\u001b[0m 205 \n", "Q 865+452 T 1317 \u001b[92m☑\u001b[0m 1317\n", "Q 13+51 T 64 \u001b[92m☑\u001b[0m 64 \n", "Q 13+908 T 921 \u001b[92m☑\u001b[0m 921 \n", "Q 90+637 T 727 \u001b[92m☑\u001b[0m 727 \n", "Q 224+25 T 249 \u001b[92m☑\u001b[0m 249 \n", "Q 769+98 T 867 \u001b[92m☑\u001b[0m 867 \n", "Q 951+412 T 1363 \u001b[92m☑\u001b[0m 1363\n", "\n", "--------------------------------------------------\n", "Iteration 26\n", "Train on 45000 samples, validate on 5000 samples\n", "Epoch 1/1\n", "45000/45000 [==============================] - 8s 171us/step - loss: 0.0182 - acc: 0.9987 - val_loss: 0.0244 - val_acc: 0.9960\n", "Q 56+728 T 784 \u001b[92m☑\u001b[0m 784 \n", "Q 65+26 T 91 \u001b[92m☑\u001b[0m 91 \n", "Q 252+556 T 808 \u001b[92m☑\u001b[0m 808 \n", "Q 843+187 T 1030 \u001b[92m☑\u001b[0m 1030\n", "Q 93+468 T 561 \u001b[92m☑\u001b[0m 561 \n", "Q 56+242 T 298 \u001b[92m☑\u001b[0m 298 \n", "Q 206+8 T 214 \u001b[92m☑\u001b[0m 214 \n", "Q 659+119 T 778 \u001b[92m☑\u001b[0m 778 \n", "Q 377+113 T 490 \u001b[92m☑\u001b[0m 490 \n", "Q 497+67 T 564 \u001b[92m☑\u001b[0m 564 \n", "\n", "--------------------------------------------------\n", "Iteration 
27\n", "Train on 45000 samples, validate on 5000 samples\n", "Epoch 1/1\n", "45000/45000 [==============================] - 7s 149us/step - loss: 0.0437 - acc: 0.9883 - val_loss: 0.0259 - val_acc: 0.9952\n", "Q 321+268 T 589 \u001b[92m☑\u001b[0m 589 \n", "Q 202+37 T 239 \u001b[92m☑\u001b[0m 239 \n", "Q 860+410 T 1270 \u001b[92m☑\u001b[0m 1270\n", "Q 33+371 T 404 \u001b[92m☑\u001b[0m 404 \n", "Q 581+85 T 666 \u001b[92m☑\u001b[0m 666 \n", "Q 66+841 T 907 \u001b[92m☑\u001b[0m 907 \n", "Q 653+28 T 681 \u001b[92m☑\u001b[0m 681 \n", "Q 284+357 T 641 \u001b[92m☑\u001b[0m 641 \n", "Q 969+128 T 1097 \u001b[92m☑\u001b[0m 1097\n", "Q 308+9 T 317 \u001b[92m☑\u001b[0m 317 \n", "\n", "--------------------------------------------------\n", "Iteration 28\n", "Train on 45000 samples, validate on 5000 samples\n", "Epoch 1/1\n", "45000/45000 [==============================] - 7s 164us/step - loss: 0.0145 - acc: 0.9991 - val_loss: 0.0176 - val_acc: 0.9976\n", "Q 390+693 T 1083 \u001b[92m☑\u001b[0m 1083\n", "Q 67+53 T 120 \u001b[92m☑\u001b[0m 120 \n", "Q 8+839 T 847 \u001b[92m☑\u001b[0m 847 \n", "Q 614+9 T 623 \u001b[92m☑\u001b[0m 623 \n", "Q 792+841 T 1633 \u001b[92m☑\u001b[0m 1633\n", "Q 291+688 T 979 \u001b[92m☑\u001b[0m 979 \n", "Q 9+311 T 320 \u001b[92m☑\u001b[0m 320 \n", "Q 801+14 T 815 \u001b[92m☑\u001b[0m 815 \n", "Q 937+1 T 938 \u001b[92m☑\u001b[0m 938 \n", "Q 978+821 T 1799 \u001b[92m☑\u001b[0m 1799\n", "\n", "--------------------------------------------------\n", "Iteration 29\n", "Train on 45000 samples, validate on 5000 samples\n", "Epoch 1/1\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "45000/45000 [==============================] - 7s 161us/step - loss: 0.0118 - acc: 0.9993 - val_loss: 0.0133 - val_acc: 0.9986\n", "Q 67+717 T 784 \u001b[92m☑\u001b[0m 784 \n", "Q 3+26 T 29 \u001b[92m☑\u001b[0m 29 \n", "Q 149+952 T 1101 \u001b[92m☑\u001b[0m 1101\n", "Q 897+453 T 1350 \u001b[92m☑\u001b[0m 1350\n", "Q 93+308 T 401 \u001b[92m☑\u001b[0m 401 \n", "Q 3+197 T 
200 \u001b[92m☑\u001b[0m 200 \n", "Q 7+421 T 428 \u001b[92m☑\u001b[0m 428 \n", "Q 354+220 T 574 \u001b[92m☑\u001b[0m 574 \n", "Q 595+938 T 1533 \u001b[92m☑\u001b[0m 1533\n", "Q 97+782 T 879 \u001b[92m☑\u001b[0m 879 \n" ] } ], "source": [ "for iteration in range(1, 30):\n", " print()\n", " print('-' * 50)\n", " print('Iteration', iteration)\n", " model.fit(x_train, y_train,\n", " batch_size=BATCH_SIZE,\n", " epochs=1,\n", " validation_data=(x_val, y_val))\n", " \n", " for i in range(10):\n", " ind = np.random.randint(0, len(x_val))\n", " rowx, rowy = x_val[np.array([ind])], y_val[np.array([ind])]\n", " preds = model.predict_classes(rowx, verbose=0)\n", " \n", " q = ctable.decode(rowx[0])\n", " correct = ctable.decode(rowy[0])\n", " guess = ctable.decode(preds[0], calc_argmax=False)\n", " print('Q', q[::-1] if INVERT else q, end=' ')\n", " print('T', correct, end=' ')\n", " if correct == guess:\n", " print(colors.ok + '☑' + colors.close, end=' ')\n", " else:\n", " print(colors.fail + '☒' + colors.close, end=' ')\n", " print(guess)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "我們可以看到在30次的訓練循環之後,我們己經可以在驗證準確性上達到99.8%的程度。\n", "\n", "以上方法的一個先行條件是它假設:給定固定長度的序列當輸入[... t]有可能生成固定長度的目標[...t]序列。\n", "\n", "這在某些情況下可行,但不適用於大多數使用情境。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 一般情境:序列到序列(seq-to-seq)的典型範例" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "在一般情況下,輸入序列和輸出序列具有不同的長度(例如機器翻譯),並且為了開始預測目標,需要整個輸入序列。這需要更高級的設置,這是人們在沒有更多的上下文的情況下提到“序列到序列模型”時經常提到的。這是如何工作的:\n", "1. RNN層(或多個RNN層的堆疊)作為“編碼器(encoder)”:它處理輸入序列並返回其自身的內部狀態。請注意,我們丟棄編碼器RNN的輸出,只保留它的內部狀態。這個狀態將作為下一步解碼器的“上下文”或“條件”。\n", "2. 另一個RNN層(或多個RNN層的堆疊)充當“解碼器(decoder)”:對給定的目標序列的先前字符進行訓練,以預測目標序列的下一個字符。具體而言,訓練是將目標序列轉換成相同的序列偏移(offset)一個步驟的過程,這種情況稱為“教師強制(teacher forcing)”的訓練過程。重要的是,解碼器(decoder)使用來自編碼器(encoder)的狀態向量作為初始狀態,這是解碼器如何獲得關於它應該產生何種產出的關鍵資訊。實際上,解碼器學習以輸入序列為條件生成給定目標[... 
t + 1 ...] given the target [... t].\n", "\n", "![teacher_forcing](https://blog.keras.io/img/seq2seq/seq2seq-teacher-forcing.png)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In inference mode, i.e. when we want to decode unknown input sequences, we go through a slightly different process:\n", "1. Encode the input sequence into state vectors (hidden state vectors).\n", "2. Start with a target sequence of size 1 (just the start-of-sequence character).\n", "3. Feed the state vectors and the 1-char target sequence to the decoder to produce a prediction for the next character.\n", "4. Sample the next character using these predictions (we simply use argmax).\n", "5. Append the sampled character to the target sequence.\n", "6. Repeat until we generate the end-of-sequence character or we hit the character limit.\n", "\n", "![seq-to-seq-decoder](https://blog.keras.io/img/seq2seq/seq2seq-inference.png)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The same process can also be used to train a Seq2Seq network without \"teacher forcing\", i.e. by reinjecting the decoder's predictions back into the decoder." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's illustrate these ideas with actual code.\n", "\n", "For our example we will use a dataset of English sentences paired with their Chinese translations, which you can download from [[manythings.org/anki](http://manythings.org/anki)].\n", "The file to download is called cmn-eng.zip (Simplified Chinese paired with English). To make it easier to follow along, I have already converted the Simplified Chinese into a Traditional Chinese version (cmn-tw.txt), which you can get from [Github](https://github.com/erhwenkuo/deep-learning-with-keras-notebooks/blob/master/assets/data/cmn-tw.txt). We will implement a character-level sequence-to-sequence model, processing the input character by character and generating the output character by character. Another option is a word-level model, which tends to be more common for machine translation. At the end of this article you will find some reference links about using embedding layers to turn our model into a word-level one." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Data preparation\n", "* Download the cmn-tw.txt file from [Github](https://github.com/erhwenkuo/deep-learning-with-keras-notebooks/blob/master/assets/data/cmn-tw.txt).\n", "* Create a new subdirectory \"data\" under the directory containing this Jupyter Notebook.\n", "* Copy the downloaded data file into the \"data\" directory.\n", "\n", "In the end your directory structure should look like this:\n", "```\n", "xxx.ipynb\n", "data/ \n", "└── cmn-tw.txt\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here is a summary of our process:\n", "1. 
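The six inference steps above can be sketched as a plain greedy decoding loop. This is a minimal, self-contained illustration only (not the notebook's own code): `fake_decoder_step` is a made-up stand-in for the trained decoder LSTM, and the tiny `target_chars` vocabulary replaces the real character set.

```python
import numpy as np

# Toy stand-ins (assumptions): '\t' is the start-of-sequence character,
# '\n' is the end-of-sequence character, as in the notebook's data prep.
target_chars = ['\t', '\n', 'h', 'i']
target_token_index = {c: i for i, c in enumerate(target_chars)}
reverse_index = {i: c for c, i in target_token_index.items()}
max_decoder_seq_length = 10

def fake_decoder_step(char_id, states):
    """Stub for the real decoder LSTM step: returns next-char probabilities.
    It deterministically spells out "hi" and then emits the EOS character."""
    script = {target_token_index['\t']: 'h',
              target_token_index['h']: 'i',
              target_token_index['i']: '\n'}
    probs = np.zeros(len(target_chars))
    probs[target_token_index[script.get(char_id, '\n')]] = 1.0
    return probs, states

def decode_sequence(initial_states):
    # Steps 1-2: start from the encoder states and a 1-char target sequence (SOS)
    states = initial_states
    char_id = target_token_index['\t']
    decoded = ''
    while True:
        # Step 3: feed the current character and states to the decoder
        probs, states = fake_decoder_step(char_id, states)
        # Step 4: sample the next character with argmax
        char_id = int(np.argmax(probs))
        ch = reverse_index[char_id]
        # Step 6: stop on EOS or when the character limit is reached
        if ch == '\n' or len(decoded) >= max_decoder_seq_length:
            break
        # Step 5: append the sampled character to the target sequence
        decoded += ch
    return decoded

print(decode_sequence(initial_states=None))  # prints "hi"
```

With the real model, `fake_decoder_step` would be replaced by a call to an inference-mode decoder that takes the one-hot encoded character plus the `[state_h, state_c]` vectors and returns softmax probabilities plus updated states.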
Convert the sentences into 3 Numpy arrays, `encoder_input_data`, `decoder_input_data`, `decoder_target_data`:\n", " * `encoder_input_data` is a 3D array of shape (num_pairs, max_english_sentence_length, num_english_characters) containing the one-hot vectorization of the English sentences.\n", " * `decoder_input_data` is a 3D array of shape (num_pairs, max_chinese_sentence_length, num_chinese_characters) containing the one-hot vectorization of the Chinese sentences.\n", " * `decoder_target_data` is the same as `decoder_input_data` but offset by one timestep: `decoder_target_data`[:,t,:] will be the same as `decoder_input_data`[:,t+1,:].\n", " \n", "2. Train a basic LSTM-based Seq2Seq model to predict `decoder_target_data` given `encoder_input_data` and `decoder_input_data`. Our model uses teacher forcing.\n", "3. Decode some sentences to check that the model is working (i.e. turn samples from `encoder_input_data` into the corresponding samples from `decoder_target_data`).\n", "\n", "\n", "The overall network architecture is illustrated in the following figure:\n", "\n", "![seq2seq](https://upload-images.jianshu.io/upload_images/1667471-dc52883e89b07014.png?imageMogr2/auto-orient/strip%7CimageView2/2/w/646)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### STEP 1. Import the libraries" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": true }, "outputs": [], "source": [ "from keras.models import Model\n", "from keras.layers import Input, LSTM, Dense\n", "import numpy as np\n", "import os\n", "\n", "# Root directory of the project\n", "ROOT_DIR = os.getcwd()\n", "\n", "# Directory holding the training data\n", "DATA_PATH = os.path.join(ROOT_DIR, \"data\")\n", "\n", "# Training data file\n", "DATA_FILE = os.path.join(DATA_PATH, \"cmn-tw.txt\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### STEP 2. 
Set up the parameters" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "collapsed": true }, "outputs": [], "source": [ "batch_size = 64 # Batch size for training\n", "epochs = 100 # Number of training epochs\n", "latent_dim = 256 # Dimensionality of the encoding latent space\n", "num_samples = 10000 # Number of samples to train on" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### STEP 3. Data preprocessing" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Number of samples: 10000\n", "Number of unique input tokens: 73\n", "Number of unique output tokens: 2165\n", "Max sequence length for inputs: 33\n", "Max sequence length for outputs: 22\n" ] } ], "source": [ "# Vectorize the data\n", "input_texts = []\n", "target_texts = []\n", "input_characters = set() # English character set\n", "target_characters = set() # Chinese character set\n", "lines = open(DATA_FILE, mode=\"r\", encoding=\"utf-8\").read().split('\\n')\n", "\n", "# Read and process the file line by line\n", "for line in lines[: min(num_samples, len(lines)-1)]:\n", " input_text, target_text = line.split('\\t')\n", " \n", " # We use \"tab\" as the \"start of sequence [SOS]\" character for the targets, and \"\\n\" as the \"end of sequence [EOS]\" character. <-- **important\n", " target_text = '\\t' + target_text + '\\n'\n", " \n", " input_texts.append(input_text)\n", " target_texts.append(target_text)\n", " \n", " for char in input_text:\n", " if char not in input_characters:\n", " input_characters.add(char)\n", " for char in target_text:\n", " if char not in target_characters:\n", " target_characters.add(char)\n", " \n", "input_characters = sorted(list(input_characters)) # Full set of input characters\n", "target_characters = sorted(list(target_characters)) # Full set of target characters\n", "\n", "num_encoder_tokens = len(input_characters) # Number of distinct input characters\n", "num_decoder_tokens = len(target_characters) # Number of distinct target characters\n", "\n", "max_encoder_seq_length = max([len(txt) for txt in input_texts]) # Length of the longest input sentence\n", "max_decoder_seq_length = max([len(txt) for txt in target_texts]) # Length of the longest target sentence\n", "\n", "print('Number of samples:', len(input_texts))\n", "print('Number of unique input tokens:', 
num_encoder_tokens)\n", "print('Number of unique output tokens:', num_decoder_tokens)\n", "print('Max sequence length for inputs:', max_encoder_seq_length)\n", "print('Max sequence length for outputs:', max_decoder_seq_length)\n", "\n", "# Index dictionary for the input characters\n", "input_token_index = dict(\n", " [(char, i) for i, char in enumerate(input_characters)])\n", "\n", "# Index dictionary for the target characters\n", "target_token_index = dict(\n", " [(char, i) for i, char in enumerate(target_characters)])\n", "\n", "# 3D array of shape (num_pairs, max_english_sentence_length, num_english_characters) holding the one-hot vectorized English sentences\n", "encoder_input_data = np.zeros(\n", " (len(input_texts), max_encoder_seq_length, num_encoder_tokens),\n", " dtype='float32')\n", "\n", "# 3D array of shape (num_pairs, max_chinese_sentence_length, num_chinese_characters) holding the one-hot vectorized Chinese sentences\n", "decoder_input_data = np.zeros(\n", " (len(input_texts), max_decoder_seq_length, num_decoder_tokens),\n", " dtype='float32')\n", "\n", "# decoder_target_data is the same as decoder_input_data but offset by one timestep:\n", "# decoder_target_data[:, t, :] will be the same as decoder_input_data[:, t+1, :]\n", "decoder_target_data = np.zeros(\n", " (len(input_texts), max_decoder_seq_length, num_decoder_tokens),\n", " dtype='float32')\n", "\n", "# Convert the data into the tensor structures used for training <-- important\n", "for i, (input_text, target_text) in enumerate(zip(input_texts, target_texts)):\n", " for t, char in enumerate(input_text):\n", " encoder_input_data[i, t, input_token_index[char]] = 1.\n", " \n", " for t, char in enumerate(target_text):\n", " # decoder_target_data is ahead of decoder_input_data by one timestep\n", " decoder_input_data[i, t, target_token_index[char]] = 1.\n", " if t > 0:\n", " # decoder_target_data will be ahead by one timestep\n", " # and will not include the start character.\n", " decoder_target_data[i, t - 1, target_token_index[char]] = 1." 
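The one-step offset between the decoder input and target tensors can be verified on a toy example. This is a minimal standalone sketch (not part of the original notebook): `chars`, `idx`, and the two small arrays below are made-up stand-ins for `target_characters`, `target_token_index`, `decoder_input_data`, and `decoder_target_data`.

```python
import numpy as np

# Toy vocabulary standing in for target_characters: '\t' = SOS, '\n' = EOS
chars = ['\t', '\n', 'a', 'b']
idx = {c: i for i, c in enumerate(chars)}

target_text = '\tab\n'  # SOS + "ab" + EOS, as built in the loop above
max_len, num_tokens = len(target_text), len(chars)

decoder_input = np.zeros((1, max_len, num_tokens), dtype='float32')
decoder_target = np.zeros((1, max_len, num_tokens), dtype='float32')
for t, ch in enumerate(target_text):
    decoder_input[0, t, idx[ch]] = 1.
    if t > 0:
        decoder_target[0, t - 1, idx[ch]] = 1.  # shifted one step back

# The target at step t equals the input at step t+1 (teacher forcing offset),
# and the target never contains the start character.
assert (decoder_target[0, :-1] == decoder_input[0, 1:]).all()
```

This is exactly the relationship stated earlier: `decoder_target_data[:, t, :]` matches `decoder_input_data[:, t+1, :]`.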
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### STEP 4. Build the network architecture\n", "\n", "![seq2seq_translation](https://camo.githubusercontent.com/44a4c60ee9446a14effc6057a16c9f12b61102b5/68747470733a2f2f692e696d6775722e636f6d2f4e7a766c4733582e706e67)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "__________________________________________________________________________________________________\n", "Layer (type) Output Shape Param # Connected to \n", "==================================================================================================\n", "encoder_input (InputLayer) (None, None, 73) 0 \n", "__________________________________________________________________________________________________\n", "decoder_input (InputLayer) (None, None, 2165) 0 \n", "__________________________________________________________________________________________________\n", "encoder_lstm (LSTM) [(None, 256), (None, 337920 encoder_input[0][0] \n", "__________________________________________________________________________________________________\n", "decoder_lstm (LSTM) [(None, None, 256), 2480128 decoder_input[0][0] \n", " encoder_lstm[0][1] \n", " encoder_lstm[0][2] \n", "__________________________________________________________________________________________________\n", "decoder_output (Dense) (None, None, 2165) 556405 decoder_lstm[0][0] \n", "==================================================================================================\n", "Total params: 3,374,453\n", "Trainable params: 3,374,453\n", "Non-trainable params: 0\n", "__________________________________________________________________________________________________\n" ] } ], "source": [ "# ===== Encoder ====\n", "\n", "# Define the input sequence\n", "# Note: since the input sequence length (timesteps) is variable, use input_shape = (None, num_features)\n", "encoder_inputs = Input(shape=(None, num_encoder_tokens), name='encoder_input') \n", "encoder = LSTM(latent_dim, 
return_state=True, name='encoder_lstm') # We need the LSTM's internal states, so we set \"return_state=True\"\n", "encoder_outputs, state_h, state_c = encoder(encoder_inputs)\n", "\n", "# We discard `encoder_outputs` because we only need the internal states of the LSTM cell\n", "encoder_states = [state_h, state_c]\n", "\n", "# ==== Decoder ====\n", "\n", "# Set up the decoder\n", "# Note: since the output sequence length (timesteps) is variable, use input_shape = (None, num_features)\n", "decoder_inputs = Input(shape=(None, num_decoder_tokens), name='decoder_input')\n", "\n", "# We set up our decoder to return full output sequences and to also return its internal states\n", "decoder_lstm = LSTM(latent_dim, return_sequences=True, return_state=True, name='decoder_lstm')\n", "\n", "# We don't use the returned states during training, but we will use them for inference\n", "# **The decoder's initial state is the final state of the encoder**\n", "decoder_outputs, _, _ = decoder_lstm(decoder_inputs,\n", " initial_state=encoder_states) # We use `encoder_states` as the initial state <-- important\n", "\n", "# Attach a dense layer whose softmax gives the probability of each possible character\n", "decoder_dense = Dense(num_decoder_tokens, activation='softmax', name='decoder_output')\n", "decoder_outputs = decoder_dense(decoder_outputs)\n", "\n", "# Define a model that takes `encoder_input_data` & `decoder_input_data` as inputs and outputs `decoder_target_data`\n", "model = Model([encoder_inputs, decoder_inputs], decoder_outputs)\n", "\n", "# Print the model structure\n", "model.summary()" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAccAAAFgCAYAAADO5bLkAAAABmJLR0QA/wD/AP+gvaeTAAAgAElE\nQVR4nO3df3Ac5X3H8c9ZsvkVYuOA8TiNUwLYaSlR2rRYNiQG10ANnDCNf0lCgAG7p8ykDR3TAj2F\ntGaGyXAeGOpWimQyASpLsR1sdIBpplaIiS2Z4M65TIbIA05PxG1OnU7ukrjU9Y+nfzi72XvuTrqT\n7m7vpPdr5sa+3Wd3v7fa3c/t7nN3AWOMEQAAcE3zuwAAACoN4QgAgIVwBADAQjgCAGCptQf87Gc/\n00MPPaQzZ874UQ8wpbS0tCgYDPpdBgBLxpljf3+/ent7/agFmFJ27tzJvgZUqIwzR8eOHTvKWQcw\n5TQ3N/tdAoAcuOcIAICFcAQAwEI4AgBgIRwBALAQjgAAWAhHAAAshCMAABbCEQAAC+EIAICFcAQA\nwEI4AgBgIRwBALAQjgAAWAhHAAAsUyYcR0ZG1Nvbq4aGhrIts62tTW1tbWVbHgCgOHL+nuNk8/jj\nj6ujo8PvMsoqlUpp1qxZMsbkPU0gEMg6vJB5FItdfyXVBmBymzJnju3t7WVf5ubNm7V58+ayL9ex\nf//+gqcxxiiZTLrPk8mkb+Fj12+MUSKRcJ/7WRuAyW3KhONUk0ql1NXVNa5pZ86cmfX/5ZSr/jlz\n5rj/96s2AJNf0cJxZGREW7ZsUSAQUENDg/r7+93h3nt90WjUbTM8PJw2j1Qqpd7eXgUCAQUCgawH\nx2xtRkZGRm3X0NCgo0ePFlx3NBpVQ0ODUqmUWltbC7p/aL/ufNaDd5mS1NXVpUAgoNbW1rT6ndfu\nvcxoD4tEIopGo2njpPHfB62U+gvhBKwzfVtbW9rf23ls2bLFncY7zvu6SrGNAKhgxtLd3W2yDB5V\nIpEwwWDQ9PT0GGOM2bdvn5FkYrGYCQaDRpKRZAYGBowxxsTjcSPJhEKhtPkEg0ETDofd56FQKO25\n06azszNtucFg0CSTyYx2oVDIHd7T0+PWMZ66Y7FYRr2j8U5vP8+1Hpzx3jbJZNKEQiEjyQwNDbl1\n26/FmZd3mP3cGGPC4XDGOs3GnrZS6h9tuM1ZbiKRyKh1YGAg6zbovNZEIuHWWqptpKmpyTQ1NeXd\nHkD5FCUcneBJm7HkHoSzHczsYc48nIOSMecOYMFg0H3uHJjsNpLcg5cxxvT19aUdjI05d5DOtcyx\n6raDN1/5HOzzaROLxYwkE4lEJjyv8dZeSfXn+7rC4XBaWNnTRSIRI8nE4/G0Wr3bUim3EcIRqFxF\nCUfvO2j7YUx+B0JnHqNxzgS8nNDzhmi2dqMts5C6C1GscCz2vMZTeyXVX+jrisfjbhB6p3NC27kS\nYcy5wPSGZSm3EcIRqFwBY9K7+23fvl3Nzc3j6v6fa5ps4+1hY81jtDb5zqvQZeZT02jyWV4xax/P\nOs239kqqv5DX1dXVpWg0qkgkooULF2ZM19raqo6ODreH7iOPPJLWs7mU20hzc7Mkqbu7u+BpAZRW\nUXur5ur0ko9gMChJOnLkyJhtsnXACYVC4172ROoup4m8xkpQrvpbW1slSb29vdq4caO2bt2qBQsW\njFrT3r17tX//ft17771Z21XLNgKgOIoSjp2dnZKkF198UalUStJvevjlywm+jo4Odx7Dw8PugU6S\nmpqaJEnHjh1zhzltV69enVHPaEFbrLrLwTkw33bbbT5XMj7lrH9wcFBLly6VJDU2NkqS5s+fn7N9\nXV2dQqGQGhsb1dXVpfr6+rTx1bKNACgy+zrreHurKss9mXg8njbO6bTg7Rxj9wr0Th8KhTI61Ti9\nU53penp6MnoIOj0Tg8Gge//I6czjzLeQusfDO30ikch7PTjPnU4hyWTShMPhtHuqxpiMHqBOxyTv\n63PWZyKRcDvD5NNb1VuXU2ul1D/a38WZRywWS5s+Ho+boaGhj
Frt6bz3Hh2l3Ea45whUrqKEozHn\nAikcDrsHNyeU7INKrmHGnDsQOfMIh8Npweht09nZmXYQztZTMB6PuwfgUCiU1iXfe3DMp277wJ6P\nbAfUfNaD83/vRwU6OzszXmM8HnfH9/X1GWNMxutzOpyEw2F32FjhOFbdftafb23Osuzpnd6r3g43\njmAwmHV7c2otxTZCOAKVqygdclA8E+0E5LdqrD+VSmV0xCkHOuQAlYuvj8OUt2PHjrR71gBAOFYQ\nby/cbD1yK1011d/W1pb2NXHLli3zuyQAFWTK/GRVseT7HZ/juax4+eWXp/2/mi5NStVVv9ODtbOz\nUxs2bPC5GgCVhnAsUCkP+JUcJvmopvo3bNhAKALIicuqAABYCEcAACyEIwAAFsIRAAAL4QgAgIVw\nBADAQjgCAGAhHAEAsBCOAABYCEcAACyEIwAAFsIRAAAL4QgAgCXnr3KsWbOmnHUAU87OnTvV1NTk\ndxkAssg4c1y2bJnWrVvnRy2YgP3791f8Dwwj3erVq9nXgAoVMNX0I3zIKRAIqLu7mzMRACgC7jkC\nAGAhHAEAsBCOAABYCEcAACyEIwAAFsIRAAAL4QgAgIVwBADAQjgCAGAhHAEAsBCOAABYCEcAACyE\nIwAAFsIRAAAL4QgAgIVwBADAQjgCAGAhHAEAsBCOAABYCEcAACyEIwAAFsIRAAAL4QgAgIVwBADA\nQjgCAGAhHAEAsBCOAABYCEcAACyEIwAAFsIRAAAL4QgAgIVwBADAQjgCAGAJGGOM30WgMN/5znf0\n6KOPat68ee6wAwcOaOHChbr00kslSclkUjfccIO2bt3qV5kAULUIxyrU1tamJ554Iq+2/HkBoHBc\nVq1CjY2NY7aZPn26vva1r5W+GACYhDhzrFK/93u/px/96Eejtvnxj3+shQsXlqkiAJg8OHOsUnff\nfbemT5+edVwgENBnPvMZghEAxolwrFKNjY06ffp01nE1NTW69957y1wRAEweXFatYvX19frhD3+o\ns2fPpg0PBAL64IMP9PGPf9ynygCgunHmWMXuvfdeBQKBtGHTpk3TkiVLCEYAmADCsYqtWrUqY1gg\nENA999zjQzUAMHkQjlXssssu00033aSamhp3WCAQyBqaAID8EY5V7p577nE/6F9TU6Obb75Zs2fP\n9rkqAKhuhGOVW7lypfuRDmOM7r77bp8rAoDqRzhWuYsvvli33367JGnGjBm68847fa4IAKpfbalm\nfPr0afX19enMmTOlWgR+7VOf+pT772uvveZzNVNDfX29PvGJT5Rk3h988IEGBwdLMm8AmbLuz6ZE\ndu/ebSTx4DEpH+vXry/VrmPWr1/v++vjwWMqPbLtzyU7c/yf//kfSeJXITDpNDc36+TJkyWb/8mT\nJ9XU1KTu7u6SLQPAObn2Z+45AgBgIRwBALAQjgAAWAhHAAAshCMAABbCEQAAC+EIAICFcAQAwEI4\nAgBgIRwBALAQjgAAWAhHAAAshCMAABbCEQAAC+FYoJGREfX29qqhoWFSLg+TRyVuO37U1NbWpra2\ntrItD5MD4Vigxx9/XI2NjYpGoxW7vFQqpUAgUMKqfiMQCGR9jGZwcFCtra0KBAJqbW1Vf39/Rs25\n5pvvY3BwcNTlF1JvtSr3tpqPSqyp1MazP45nvyqVfPfNyYZwLFB7e3vFL2///v0lqCQ7Y4wSiYT7\nPJlMjvoD14ODg1q8eLGWLl0qY4za29v1sY99TC0tLRlte3p6ZIxxH95lOo+enh53WDwed9s8//zz\nOWvwjkskEpP2B7nLva3mw4+aNm/erM2bN5d9uY7x7I/GGCWTSff5WPtVKdn1F7rPVyvCcZJJpVLq\n6uoq6zLnzJnj/n/mzJmjtnWCad26de6wurq6rAcvb5tcVqxY4f5//vz5kqRIJKKOjg4NDw9ntB8e\nHtZVV12VtXag2CayP3r3p
bH2q1LJVX8h+3y1qrhwHBkZ0ZYtWxQIBNTQ0KD+/n53uPdeRTQaddvY\nB8FUKqXe3l73dD/bHzdbm5GRkVHbNTQ06OjRowXXHY1G1dDQoFQqpdbW1qLc/3CW5dTtXNaIRCLu\nJSvnteVad62tre66c16jd5hU/Ps1x48flyQdOXIkbXhdXV3ac+9Z4GhmzpyZ0Xb58uWSpIMHD2a0\nP3jwoDt+spnotpptPtW+/9jbfj7HEe8yJamrq8vdN7z1Z7ukaA/Ltj9K49+vKqX+QjgB60zf1taW\n9vd2Hlu2bHGn8Y7zvq5yHmNlSqS7u9sUOvtEImGCwaDp6ekxxhizb98+I8nEYjETDAaNJCPJDAwM\nGGOMicfjRpIJhUJp8wkGgyYcDrvPQ6FQ2nOnTWdnZ9pyg8GgSSaTGe1CoZA7vKenx61jPHXHYrGM\nesdiLy8SiZh4PG6MMSaZTJpwOJw23m7vrSEWixljjBkYGHDX3WjrMxwOZ6y7fGrMJRaLuW07Ozsz\n1vdEl+GMD4VCWds6ry3ferNpamoyTU1N45q2lPOf6Lbqnc9k2X+809vPc233znhvm2Qy6W5TQ0ND\nbt32a3HmNdr+aMz496tKqX+04TZnuYlEIqNW73HIFgwGTSKRcGst1TaSa3+rqHB0dhwvSe5GlO2P\nYQ9z5uGsVGPO/QGCwaD73FmxdhtJ7so3xpi+vr60jcmYcxtZrmWOVXchQTDaa7RrdzbyXO0nOmw8\nNY5maGjI3WGcdZ7PuikkHJ2/sXNwMOZcMO/bt6/gem2VGI7F2lanyv4z1nafrY3zxi4SiUx4XuOt\nvZLqz/d1hcPhtLCyp4tEIkaS+4bfqdW7LZVyG6mKcPS+A7AfxuT3h3TmMZpsZxXOTus9COQ6+8i1\nzELqLoQ9vVNXrlCp9HB0DAwMpIVkX1/fhJdh79T2WfBE6nVUYjgWa1udCvtPMQOhmsKx2PUX+rri\n8bgbhN7pnNB2rkQYk351zJjSbiNVEY5jvcBibYgT2XDGs8xi79xDQ0NpG4v3nWCu5VViODqcM5Ox\nArLQcHTebcbjcZNIJNLeiU62cCzntlrt+0+lhMt4aq+k+gt5XZ2dnSYYDJqhoaGs0zlvpJLJpHsJ\nuJBllWJ/rrgOOZJy3rTPRzAYlJTZ4SNbm2wdCEKh0LiXPZG6C7FgwQL19fUpFospFApp06ZNaTez\nK1Fra6ukczf1U6lU2rj6+npt3bpVkor64fAlS5ZIOtcJp7+/332O3NvqVNh/Jmoir7ESlKt+Z5/v\n7e3Vxo0btXXrVi1YsGDUmvbu3av9+/fr3nvvzdqunNtIRYVjZ2enJOnFF190D6BOD6V8OTtuR0eH\nO4/h4WH3DyVJTU1NkqRjx465w5y2q1evzqhntANFseouhBMwdXV1am9vVywW06ZNm0qyrGIYHBzU\n0qVL3eeHDx/OaON8DMP5+xXD/PnzFQ6H1djYqOPHj7vLmIyKta1Ohf1nvJwD82233eZzJeNTzvq9\n+3xjY6Mkjbr/1dXVKRQKqbGxUV1dXaqvr08b78s2Mq7z0DyMt7eqslxTdi6LOc+d+2zem/t2rybv\n9KFQKKNTgNO7zpmup6cn41Te6VkVDAbd699OZwRnvoXUPR7e6Z1apXM3op2anGv5Duf1JxIJE4lE\nsq67bPPNNiyfXnWjvUano4bTG9Jpt2/fvrS/o3MJ1Ntrcqz1kKuNd7xzP8M733zmNZpKvKxajG3V\nGT+Z9598jyPOc+dSvNMr3HtP1RiT0QPU2d69r8/eH43Jb7/y1pVtv/Wz/kL2eWf6eDyedlnV3vec\n6bz3Hh2l3Eaq4p6jMed2KOejCaFQyN2p7JWSa5gx51akM49wOJy2Y3vbdHZ2pm1E2Tq3xONxdwMK\nhUJpXYq9f9x86rY3zHzket3Ohipl3nN0AiEcDmfdqApZn2PtxNk22GwPZ9068x0aGkpb/7n
+TqMt\nY6w2jmw95Uab11gqMRyNmfi26pjM+0+h+4L3owLZPnoUj8cz7pfbr8/eH40p3n7lR/2F7vP29E7v\nVXu7c5ad6zhQqm0k1/4W+PXMi2779u1qbm5WiWYP+Ka5uVmS1N3dXZXzx9icD7tX6/GrGutPpVJ6\n5JFHyv4Vg7n2t4q65wgAmJp27NiRds/ab4QjAHh4e+Fm65Fb6aqp/ra2trSviVu2bJnfJblq/S5g\nqsr3Owqr6bIIUC6l3H8uv/zytP9X2z5YTfU7PVg7Ozu1YcMGn6tJRzj6pJI3WKDSlXL/qfZ9s5rq\n37BhQ8WFooPLqgAAWAhHAAAshCMAABbCEQAAC+EIAICFcAQAwEI4AgBgIRwBALAQjgAAWAhHAAAs\nhCMAABbCEQAAC+EIAICl5L/KsXPnzlIvAiirnTt3lvxHWXfu3KmVK1eWdBkAcu/PJQvHq666SpK0\nZs2aUi0C8M0VV1xR0nmfOnWKfQcok2z7c8BU049/oSL84Ac/0F133aXf/u3f1ssvv6x58+b5XRLg\nOnPmjB5++GE988wz+upXv6qvfe1rfpeEKkQ4Ylzef/993XHHHfrVr36laDSqz372s36XBOgXv/iF\n1q1bpzfeeEPf/OY3tW7dOr9LQpWiQw7G5corr9TBgwe1cOFCff7zn1c0GvW7JExxx44d05IlS3Tk\nyBG98cYbBCMmhHDEuF1yySXau3evmpqadNddd+mpp57yuyRMUW+++aYWLVqk8847T4cOHdJ1113n\nd0mocoQjJmT69On6xje+oaeeekqPPvqoNmzYoFOnTvldFqaQb37zm1q+fLmWLl2q/fv367d+67f8\nLgmTAOGIonjooYe0e/du9fb2asWKFfr5z3/ud0mY5JyONw8++KAefvhh7dy5UxdddJHfZWGSoEMO\niioWi+nOO+/U+eefr1deeUVXX3213yVhEvrlL3+ppqYm/cu//Iuee+45NTU1+V0SJhnOHFFUn/3s\nZ3Xo0CHNmjVL9fX1+v73v+93SZhkfvKTn+j666/X22+/re9973sEI0qCcETRzZ07V2+88YaWLVum\nW265Rd/61rf8LgmTxA9+8APV19erpqZGhw4dUn19vd8lYZIiHFESF1xwgXbs2KGHH35Y999/v/7q\nr/5KZ8+e9bssVLHnn39ey5cv15IlS/Tmm29q/vz5fpeESYxwRMkEAgE98cQTeuGFF/Tss89q1apV\nOnHihN9locqcPXtWf/3Xf6377rtPDz30kL7zne/oIx/5iN9lYZKjQw7KwvnKuU9+8pPq6+vjK+eQ\nl1/96le6++679frrr6urq0stLS1+l4QpgnBE2bz//vu68847lUwm9fLLL+tzn/uc3yWhgg0PD6uh\noUH/+Z//qd27d2vJkiV+l4QphMuqKJsrr7xSBw4c0DXXXKOlS5dq9+7dfpeECjUwMKDrrrtOZ8+e\n1VtvvUUwouwIR5TVzJkz9eqrr6qlpUWrVq3iK+eQ4Z/+6Z+0bNkyXXfddTp48KA++clP+l0SpiDC\nEWVXW1ur9vZ2Pf3003r00Ud1//336//+7//8Lgs+M8boscce0z333KMvf/nL2rNnDx1v4BvuOcJX\ne/fu1dq1a/UHf/AHeumllzR79my/S4IPTpw4oZaWFr322mvq6OjQfffd53dJmOIIR/junXfe0R13\n3KHzzjtPr776Kl85N8X89Kc/VTAY1PHjx/XSSy/phhtu8LskgMuq8N+1116rQ4cO6dJLL9WiRYu0\nb98+v0tCmRw6dEh/9Ed/pNOnT+vQoUMEIyoG4YiKMHfuXPX39+vWW2/VihUrtG3bNr9LQon19PTo\nxhtv1Oc+9zkdPHhQV1xxhd8lAS7CERXj/PPP1/bt2/XII49o48aNfOXcJGWMUVtbm5qbm/WlL31J\nL7/8si6++GK/ywLScM8RFWn79u164IEHdPPNN2v79u3
0WpwkTpw4ofvuu099fX36h3/4Bz344IN+\nlwRkRTiiYg0MDGjlypWaN2+eotEov/Be5X76059q5cqVisfj2rVrl5YuXep3SUBOXFZFxVq8eLEG\nBwd16tQpLVq0SIcPH/a7JIzT22+/rUWLFunDDz/U4OAgwYiKRziiol1xxRU6cOCA6urq9PnPf167\ndu3yuyQUaMeOHfrCF76gz3zmMzp48KCuvPJKv0sCxkQ4ouLNnDlT0WhUDzzwgNasWaMnn3zS75KQ\nB2OM/vZv/1br1q3Txo0b9corr2jmzJl+lwXkpdbvAoB81NTU6O///u+1cOFCfeUrX9HRo0f1jW98\nQzNmzPC7NGTx4Ycfav369XrppZfU0dGhjRs3+l0SUBA65KDq/PM//7PWrl2ra6+9Vrt379all17q\nd0nw+I//+A+tXLlS77//vnbt2qWbbrrJ75KAgnFZFVXn1ltv1YEDB3T8+HHV19frxz/+sd8l4dcO\nHz6sRYsW6Ze//KUOHTpEMKJqEY6oStdcc40GBwc1Z84cLVmyhK+cqwDOxzN+93d/VwMDA7rqqqv8\nLgkYN8IRVWvOnDnq7+/X7bffrj/5kz9RR0eH3yVNScYYPfHEE1qzZo3Wr1+vV199VbNmzfK7LGBC\n6JCDqnb++efrhRde0IIFC/SlL31JR48e1VNPPaWamhq/S5sS/vd//1f333+/du7cqX/8x39UKBTy\nuySgKOiQg0mjt7dX69ev5yvnyuRnP/uZVq5cqaNHj2rHjh1avny53yUBRUM4YlIZHBzUXXfdpcsu\nu0yvvPKK5s+f73dJk1IsFlMwGNQFF1ygV155RQsWLPC7JKCouOeISaW+vl6HDh2SJC1atEhvvfWW\nzxVNPrt379YNN9yghQsX6tChQwQjJiXCEZPO/PnzdeDAAf3+7/++brzxRu3cuTNn2w8//LCMlVW+\nU6dOjTr+ySef1Be/+EW1tLTo9ddf1yWXXFKmyoDyIhwxKV188cWKRqP6sz/7M61du1abN2+WfQfh\ntdde04UXXqhnn33Wpyorz/XXX69AIKD/+q//Sht+8uRJtbS06Ktf/aqeffZZtbe3q7aW/nyYvLjn\niEmvo6NDX/7yl7Vu3Tpt27ZN5513nt555x0tXrxYJ06c0Ec+8hEdO3ZMl112md+l+mrv3r267bbb\nJJ27JL1//37NmDFDIyMjWrlypd599119+9vf1i233OJzpUDpEY6YEr773e9qzZo1uvbaa9XV1aU/\n/uM/1sjIiE6fPq3p06erpaVFzz33nN9l+ubkyZP69Kc/rQ8++EBnzpxRbW2t1q1bp4cffljBYFAz\nZsxQNBrVpz/9ab9LBcqCcMSU8e677+qOO+5QKpXSL37xi7T7a4FAQG+99Zb+8A//0McK/fPkk0+q\nra1NZ86ccYcFAgFddtlluuaaa7Rr1y7Nnj3bxwqB8iIcMWUYY7Rq1Sq9/PLLaSEgSbW1taqrq9MP\nf/hDBQIBnyr0x/DwsBYsWKCTJ09mjJs2bZp27dqlu+66y4fKAP/QIQdTxt/93d9pz549GcEoSadP\nn9a//uu/6oUXXvChMn899NBDOnv2bNZxxhg1NzfrnXfeKXNVgL84c8SU0N3drZaWloweq16BQECz\nZ8/WsWPH9NGPfrSM1fnnu9/9rm699dZR29TW1urSSy/Vv/3bv035TkuYOghHTHpnz57N+7tWa2tr\n9Rd/8ReKRCIlrsp/p06d0u/8zu/o3//937OeTdtuvPFGfe973ytDZYD/uKyKSW/atGnav3+/Ghsb\nNWPGDNXU1GjatOyb/unTp/XMM89Mid+I3LJly6jBWFNTo0AgoIsuukgPPPCAtm3bVuYKAf9w5ogp\nJZVK6dvf/ra6urr09ttva/r06RnfCjN9+nQtXrxY3//+932qsvSOHz+uq6++Ous3BNXW1urs2bO6\n+eabdf/997vfoQp
MJYQjpqwf/ehHeu655/Stb31LyWRSNTU1On36tDt+x44dWr16tY8Vls6aNWu0\nZ88e942B8ybhyiuv1IYNG9TS0qJ58+b5XCXgH8IRU96pU6f02muvadu2bdq7d68CgYAbkidOnNCF\nF17oc4XF9frrr2vFihWSznVCuvDCC9XU1KT169dr8eLFPlcHVAbCEVm99dZbWrRokd9lACXxN3/z\nN3riiSf8LgMVjG8ORlbvvfeepHOXFqeqeDyuuXPn6rzzzvO7lKL67//+b506dUpz5871uxRfNDc3\n6yc/+YnfZaDCEY4Y1WS954apa8+ePX6XgCrARzkAALAQjgAAWAhHAAAshCMAABbCEQAAC+EIAICF\ncAQAwEI4AgBgIRwBALAQjgAAWAhHAAAshCMAABbCEQAAC+EIAICFcERJjYyMqLe3Vw0NDX6X4ip3\nTZW4DgCMjnBEST3++ONqbGxUNBr1uxRXuWsaz/JSqZQCgUAJq/qNQCCQ9TGawcFBtba2KhAIqLW1\nVf39/Rk155pvvo/BwcFRl19IvUChCEeUVHt7u98lZCh3TeNZ3v79+0tQSXbGGCUSCfd5MpmUMSZn\n+8HBQS1evFhLly6VMUbt7e362Mc+ppaWloy2PT09Msa4D+8ynUdPT487LB6Pu22ef/75nDV4xyUS\niVHrBcaDcAQqTCqVUldXV1mXOWfOHPf/M2fOHLWtE0zr1q1zh9XV1Wnz5s0Zbb1tclmxYoX7//nz\n50uSIpGIOjo6NDw8nNF+eHhYV111VdbagWIhHFFUqVRKvb29CgQCamho0NGjR7O2GxkZ0ZYtW9x2\n/f39OecTCASyhkW2NiMjI0WvaWRkRNFoVA0NDUqlUmptbVVbW1uhqyaDsyynbufSYCQScS/BOq/N\nvm8ZjUbdS5pOgDiv0TtMktra2opSr+P48eOSpCNHjqQNr6urS3vuPQsczcyZMzPaLl++XJJ08ODB\njPYHDx50xwMlY4Asuru7zXg2j2AwaEKhkEkmk8YYY3p6eoyktHklEgkTDAZNT0+PMcaYffv2GUkm\nFoulzSccDrvPQ6FQ2nOnTWdnZ9o8g8Ggu+xi1RQMBt32AwMDJhaLmVAoVNB6sZcXiURMPB43xhiT\nTCZNOBxOG2+399bgrKeBgQEjyYRCITMwMGCMMSYej7vDHOFwOGPd5VNjLrFYzG3b2dmZsb4nugxn\nfCgUytrWeW351mtramoyTU1NBU+HqYVwRFbjCce+vj4jyQwNDbnDkslkxq4kGUEAAAu7SURBVEHM\nCScvSe4B3BmfSCTc8QMDAyYYDLrPnfCy20hyA66YNTntCwkCe152+HlrTyQSo4bjRIeNp8bRDA0N\nueHlrPN81k0h4ej8jZ3gN+ZcMO/bt6/ger0IR+SDcERW4wnHXO/0RzsLsh/e8YUuywk9b4gWq6bx\nHohzLc+pK1eoVHo4OgYGBtJCsq+vb8LLsN8k2GfBE6nXGMIR+SEckdV4wjHXwSrbWdNo8873ADqR\nZZWipkLqHRoaSgvkSCQy5vIqMRwdzpn9WAFZaDg6Z/TxeNwkEom0qwKEI0qJDjnwTa6OMcFgUFJm\nh49sbbJ1wAmFQkWvqdgWLFigvr4+xWIxhUIhbdq0SVu2bCnLssertbVV0rlOQqlUKm1cfX29tm7d\nKklF/bKDJUuWSDrXCae/v999DpQa4Yii6ezslDR6qHnbvfjii+5B1ukpKv0m+Do6Otzxw8PD7sFZ\nkpqamiRJx44dc4c5bVevXl30morNCZi6ujq1t7crFotp06ZNJVlWMQwODmrp0qXu88OHD2e0cT6G\n4fz9imH+/PkKh8NqbGzU8ePH3WUAJef3qSsq03guqzo9JYPBoNsT0+lUIc+9I6fzif1wpnF6jnrH\nhUKhjE41Tu9Up2NLT09PRi/SYtTkHTce3umdWqVznX2cmuLxeNqlVef1JxIJE4lE0
ubh3KPMNt9s\nw/LprTraa3Q6Ojm9ZJ12+/btc2tJJpPuJVBvr+Ox1kOuNt7xTu9Y73zzmVcuXFZFPghHZDXej3LE\n43G3g0YoFEr7iIT3IBaPx92PL4RCITckHIlEwh0fDofTgtHbprOz0z1I5urcMtGavGHp7eyTLztw\nnWFO8CnLPUcnEMLhcNbgHm2+9rCxwjHbm4JsD2fdOvMdGhpKW/+5/k6jLWOsNg7vm5585jUawhH5\nCBjD9y4h0/bt29Xc3MzXcmHSaW5uliR1d3f7XAkqGfccAQCwEI4AAFhq/S4AqFb5/kwSl6aB6kM4\nAuNE6AGTF5dVAQCwEI4AAFgIRwAALIQjAAAWwhEAAAvhCACAhXAEAMBCOAIAYCEcAQCwEI4AAFgI\nRwAALIQjAAAWwhEAAAu/yoGsLrzwQkn5/ywTUE3Wr1/vdwmocAHD7+4gi9OnT6uvr09nzpzxu5Qp\nYc2aNfrzP/9z3XDDDX6XMiXU19frE5/4hN9loIIRjkAFCAQC6u7uVlNTk9+lABD3HAEAyEA4AgBg\nIRwBALAQjgAAWAhHAAAshCMAABbCEQAAC+EIAICFcAQAwEI4AgBgIRwBALAQjgAAWAhHAAAshCMA\nABbCEQAAC+EIAICFcAQAwEI4AgBgIRwBALAQjgAAWAhHAAAshCMAABbCEQAAC+EIAICFcAQAwEI4\nAgBgIRwBALAQjgAAWAhHAAAshCMAABbCEQAAC+EIAICl1u8CgKno5z//ecawEydOpA2/6KKLNGPG\njHKWBeDXAsYY43cRwFTyyCOP6Otf//qY7WbMmKGTJ0+WoSIANi6rAmX2qU99Kq92V199dYkrAZAL\n4QiU2apVq1RbO/odjZqaGv3lX/5lmSoCYCMcgTKbPXu2br75ZtXU1ORsM23aNP3pn/5pGasC4EU4\nAj64++67let2f21trVasWKFZs2aVuSoADsIR8MGdd96ZsyfqmTNn1NLSUuaKAHgRjoAPLrroIq1c\nuVLTp0/PGHf++efr9ttv96EqAA7CEfBJc3OzTp06lTZs+vTp+uIXv6gLLrjAp6oASIQj4JtbbrlF\nH/3oR9OGnTp1Ss3NzT5VBMBBOAI+mTFjhtauXZt2afWSSy7R8uXLfawKgEQ4Ar7yXlqdPn261q1b\nN+ZnIAGUHl8fB/jo7NmzmjdvnhKJhCTpzTff1A033OBzVQA4cwR8NG3aNPce47x583T99df7XBEA\niV/lmPIee+wxvffee36XMaU5v8Rx9uxZrV271udqpraamho9/fTTmjt3rt+lwGdcVp3iAoGAJGn1\n6tU+VzK1vfvuu/r4xz+e0XsV5bVz5051d3erqanJ71LgM84cwcEA+DXnzSLAPUcAACyEIwAAFsIR\nAAAL4QgAgIVwBADAQjgCAGAhHAEAsBCOAABYCEcAACyEIwAAFsIRAAAL4QgAgIVwBADAQjgCAGAh\nHDFhIyMj6u3tVUNDg9+luCqxJgDVg99zxIQ9/vjj6ujo8LuMNJVYUyFSqZRmzZqlUv0W+XjmP9pv\nHUYiES1YsEBf+MIXNHPmzGKUCPiKM0dMWHt7u98lZKjEmgqxf//+ipu/MUaJRMJ9nkwmZYyRMUbL\nly9XV1eXWlpaNDIyUsxSAV8QjkCFSaVS6urqqsj5z5kzx/2/9wyxrq5O27ZtkyQ9+OCDSqVSEysS\n8BnhiIKlUin19vYqEAiooaFBR48ezdpuZGREW7Zscdv19/fnnE8gEMh6wM7WJtuZyURrGhkZUTQa\nVUNDg1KplFpbW9XW1lboqhmzXme49xKlPSwSiSgajaaN89YnSV1dXQoEAmptbU17reOdvyS1tbWN\n6zU75syZo6985SuKRqMZZ6ajrXfvveFoNOq2GR4eTpuHM72zTu3LvGNtb0BBDKY0Saa7u7ugaYLB\noAmFQiaZTBpjjOnp6TGSjHdzSiQSJhgMmp6eH
mOMMfv27TOSTCwWS5tPOBx2n4dCobTnTpvOzs60\neQaDQXfZxaopGAy67QcGBkwsFjOhUKig9ZJPvYlEIqOueDyeMSzXc6c+Y4xJJpMmFAoZSWZoaGhC\n8zfGmHA4nLH+s8k2rSOZTBpJaeuukPXurdc7j0gkYuLxuLuMcDhc8PaWj/HsD5icCMcprtCDQV9f\nX9rB2JjfHBC9BysnnOxlOQdfZ3wikXDHDwwMmGAw6D53DnB2G0nuQbCYNTnt7eDNV771ZguXfMIr\n27BYLGYkmUgkMuH552usace73seq17tenTcB+S4jX4QjHITjFFfowcA5U8k2H+9w7xmB/fCOL3RZ\nTuh5Q7RYNU0kMAqpt5jhON5pyxmO41nv9jBn3fb09GR98zLWMgp5bYQjjCEcp7xCDwYTOUDnM59i\nLqsUNZWq3skSjs6bAe8Z23jWuz1saGgoLQC9Z8r5LCNfhCMcdMhBSeXqGBMMBiVJR44cyTmt0yZb\nB5xQKFT0miaqVPXmo9Tzz9fhw4clSTfddFPGuIms9wULFqivr0+xWEyhUEibNm3Sli1biroMwItw\nREE6OzsljR5q3nYvvvii263f6U0o/SZIOjo63PHDw8NqbW1159HU1CRJOnbsmDvMabt69eqi1zRR\n+dZbTE4Y3HbbbSWZfyFGRkb0zDPPKBgMatmyZe7wYqz3QCCgVCqluro6tbe3KxaLadOmTUVdBpDG\n71NX+EsFXkZyehIGg0G396DTEUX6TQ9Db69J78OZxuld6B0XCoUyOtU4vT2dzhg9PT0ZvUiLUVO2\nXp6Fyrdeu4ep02nHW6uzbhKJhHsJ0WnjdO5xem1672dOZP759Fb1dnTy3vtzep56X7sj3/XuzM+7\nDGde+vWlWufvG4/H0y6tjrW95avQ/QGTF+E4xY3nYBCPx90DcCgUSutG7z0wxuNxt8t9KBTKOFAl\nEgl3fDgcTgtGb5vOzs60YMjWIWOiNXkPqHbYFCKfeuPxuBtOfX19xhiTUavTCzUcDqcFhJT+EYjO\nzs6izX+scMwWPs4jEom4H8XIJp/17rwxyTXMCXJnefkuoxCEIxwBY0r05Y2oCoFAQN3d3e4lQVQu\n50Pv7LKlw/4AB/ccAQCwEI5AFfD2gOWLvYHS4yergFGM9jNNXqW+1Hn55Zen/Z9Lq0BpEY7AKCol\nhCqlDmCq4LIqAAAWwhEAAAvhCACAhXAEAMBCOAIAYCEcAQCwEI4AAFgIRwAALIQjAAAWwhEAAAvh\nCACAhXAEAMBCOAIAYOFXOaDm5mbt2bPH7zIAoGIEDL+FM6U99thjeu+99/wuA6gINTU1evrppzV3\n7ly/S4HPCEcAACzccwQAwEI4AgBgIRwBALAQjgAAWP4fUlerrIUOTMgAAAAASUVORK5CYII=\n", "text/plain": [ "" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from keras.utils import plot_model\n", "from IPython.display import Image\n", "\n", "# 產生網絡拓撲圖\n", "plot_model(model, to_file='seq2seq_graph.png')\n", "Image('seq2seq_graph.png')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### STEP 5.訓練模型" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Train on 8000 samples, validate on 2000 
samples\n", "Epoch 1/100\n", "8000/8000 [==============================] - 10s 1ms/step - loss: 1.9829 - val_loss: 2.4389\n", "Epoch 2/100\n", "8000/8000 [==============================] - 8s 975us/step - loss: 1.8537 - val_loss: 2.3435\n", "Epoch 3/100\n", "8000/8000 [==============================] - 8s 979us/step - loss: 1.7435 - val_loss: 2.2536\n", "Epoch 4/100\n", "8000/8000 [==============================] - 8s 1ms/step - loss: 1.6472 - val_loss: 2.1720\n", "Epoch 5/100\n", "8000/8000 [==============================] - 8s 998us/step - loss: 1.5625 - val_loss: 2.0774\n", "Epoch 6/100\n", "8000/8000 [==============================] - 8s 1ms/step - loss: 1.4880 - val_loss: 2.0484\n", "Epoch 7/100\n", "8000/8000 [==============================] - 8s 992us/step - loss: 1.4232 - val_loss: 1.9879\n", "Epoch 8/100\n", "8000/8000 [==============================] - 8s 988us/step - loss: 1.3615 - val_loss: 1.9317\n", "Epoch 9/100\n", "8000/8000 [==============================] - 8s 976us/step - loss: 1.3076 - val_loss: 1.8757\n", "Epoch 10/100\n", "8000/8000 [==============================] - 8s 973us/step - loss: 1.2638 - val_loss: 1.8554\n", "Epoch 11/100\n", "8000/8000 [==============================] - 8s 995us/step - loss: 1.2198 - val_loss: 1.8237\n", "Epoch 12/100\n", "8000/8000 [==============================] - 8s 988us/step - loss: 1.1824 - val_loss: 1.8005\n", "Epoch 13/100\n", "8000/8000 [==============================] - 8s 992us/step - loss: 1.1469 - val_loss: 1.7941\n", "Epoch 14/100\n", "8000/8000 [==============================] - 8s 996us/step - loss: 1.1114 - val_loss: 1.7713\n", "Epoch 15/100\n", "8000/8000 [==============================] - 8s 976us/step - loss: 1.0787 - val_loss: 1.7613\n", "Epoch 16/100\n", "8000/8000 [==============================] - 8s 984us/step - loss: 1.0488 - val_loss: 1.7438\n", "Epoch 17/100\n", "8000/8000 [==============================] - 8s 974us/step - loss: 1.0203 - val_loss: 1.7418\n", "Epoch 18/100\n", "8000/8000 
[==============================] - 8s 987us/step - loss: 0.9918 - val_loss: 1.7386\n", "Epoch 19/100\n", "8000/8000 [==============================] - 8s 991us/step - loss: 0.9660 - val_loss: 1.7267\n", "Epoch 20/100\n", "8000/8000 [==============================] - 8s 1ms/step - loss: 0.9398 - val_loss: 1.7283\n", "Epoch 21/100\n", "8000/8000 [==============================] - 8s 992us/step - loss: 0.9131 - val_loss: 1.7163\n", "Epoch 22/100\n", "8000/8000 [==============================] - 8s 988us/step - loss: 0.8890 - val_loss: 1.7248\n", "Epoch 23/100\n", "8000/8000 [==============================] - 8s 1ms/step - loss: 0.8657 - val_loss: 1.7248\n", "Epoch 24/100\n", "8000/8000 [==============================] - 8s 996us/step - loss: 0.8416 - val_loss: 1.7170\n", "Epoch 25/100\n", "8000/8000 [==============================] - 8s 988us/step - loss: 0.8197 - val_loss: 1.7226\n", "Epoch 26/100\n", "8000/8000 [==============================] - 8s 998us/step - loss: 0.7981 - val_loss: 1.7321\n", "Epoch 27/100\n", "8000/8000 [==============================] - 8s 1ms/step - loss: 0.7783 - val_loss: 1.7290\n", "Epoch 28/100\n", "8000/8000 [==============================] - 8s 988us/step - loss: 0.7576 - val_loss: 1.7283\n", "Epoch 29/100\n", "8000/8000 [==============================] - 8s 1ms/step - loss: 0.7374 - val_loss: 1.7356\n", "Epoch 30/100\n", "8000/8000 [==============================] - 8s 978us/step - loss: 0.7193 - val_loss: 1.7331\n", "Epoch 31/100\n", "8000/8000 [==============================] - 8s 1ms/step - loss: 0.7000 - val_loss: 1.7490\n", "Epoch 32/100\n", "8000/8000 [==============================] - 8s 998us/step - loss: 0.6823 - val_loss: 1.7447\n", "Epoch 33/100\n", "8000/8000 [==============================] - 8s 988us/step - loss: 0.6643 - val_loss: 1.7392\n", "Epoch 34/100\n", "8000/8000 [==============================] - 8s 992us/step - loss: 0.6465 - val_loss: 1.7528\n", "Epoch 35/100\n", "8000/8000 [==============================] - 8s 
996us/step - loss: 0.6305 - val_loss: 1.7575\n", "Epoch 36/100\n", "8000/8000 [==============================] - 8s 1ms/step - loss: 0.6143 - val_loss: 1.7546\n", "Epoch 37/100\n", "8000/8000 [==============================] - 8s 1ms/step - loss: 0.5979 - val_loss: 1.7673\n", "Epoch 38/100\n", "8000/8000 [==============================] - 8s 1ms/step - loss: 0.5822 - val_loss: 1.7846\n", "Epoch 39/100\n", "8000/8000 [==============================] - 8s 990us/step - loss: 0.5678 - val_loss: 1.7974\n", "Epoch 40/100\n", "8000/8000 [==============================] - 8s 996us/step - loss: 0.5528 - val_loss: 1.7874\n", "Epoch 41/100\n", "8000/8000 [==============================] - 8s 988us/step - loss: 0.5390 - val_loss: 1.7983\n", "Epoch 42/100\n", "8000/8000 [==============================] - 8s 980us/step - loss: 0.5263 - val_loss: 1.8070\n", "Epoch 43/100\n", "8000/8000 [==============================] - 8s 992us/step - loss: 0.5123 - val_loss: 1.8127\n", "Epoch 44/100\n", "8000/8000 [==============================] - 8s 994us/step - loss: 0.4993 - val_loss: 1.8130\n", "Epoch 45/100\n", "8000/8000 [==============================] - 8s 1ms/step - loss: 0.4870 - val_loss: 1.8185\n", "Epoch 46/100\n", "8000/8000 [==============================] - 8s 996us/step - loss: 0.4748 - val_loss: 1.8358\n", "Epoch 47/100\n", "8000/8000 [==============================] - 8s 990us/step - loss: 0.4635 - val_loss: 1.8333\n", "Epoch 48/100\n", "8000/8000 [==============================] - 8s 1ms/step - loss: 0.4512 - val_loss: 1.8440\n", "Epoch 49/100\n", "8000/8000 [==============================] - 8s 1ms/step - loss: 0.4402 - val_loss: 1.8451\n", "Epoch 50/100\n", "8000/8000 [==============================] - 8s 999us/step - loss: 0.4290 - val_loss: 1.8525\n", "Epoch 51/100\n", "8000/8000 [==============================] - 8s 990us/step - loss: 0.4182 - val_loss: 1.8656\n", "Epoch 52/100\n", "8000/8000 [==============================] - 8s 998us/step - loss: 0.4083 - val_loss: 
1.8839\n", "Epoch 53/100\n", "8000/8000 [==============================] - 8s 993us/step - loss: 0.3982 - val_loss: 1.8905\n", "Epoch 54/100\n", "8000/8000 [==============================] - 8s 1ms/step - loss: 0.3881 - val_loss: 1.8920\n", "Epoch 55/100\n", "8000/8000 [==============================] - 8s 1ms/step - loss: 0.3787 - val_loss: 1.8985\n", "Epoch 56/100\n", "8000/8000 [==============================] - 8s 991us/step - loss: 0.3697 - val_loss: 1.9039\n", "Epoch 57/100\n", "8000/8000 [==============================] - 8s 988us/step - loss: 0.3608 - val_loss: 1.9082\n", "Epoch 58/100\n", "8000/8000 [==============================] - 8s 987us/step - loss: 0.3525 - val_loss: 1.9121\n", "Epoch 59/100\n", "8000/8000 [==============================] - 8s 992us/step - loss: 0.3441 - val_loss: 1.9194\n", "Epoch 60/100\n", "8000/8000 [==============================] - 8s 983us/step - loss: 0.3347 - val_loss: 1.9338\n", "Epoch 61/100\n", "8000/8000 [==============================] - 8s 1ms/step - loss: 0.3270 - val_loss: 1.9478\n", "Epoch 62/100\n", "8000/8000 [==============================] - 8s 983us/step - loss: 0.3192 - val_loss: 1.9383\n", "Epoch 63/100\n", "8000/8000 [==============================] - 8s 998us/step - loss: 0.3114 - val_loss: 1.9512\n", "Epoch 64/100\n", "8000/8000 [==============================] - 8s 996us/step - loss: 0.3046 - val_loss: 1.9562\n", "Epoch 65/100\n", "8000/8000 [==============================] - 8s 986us/step - loss: 0.2970 - val_loss: 1.9666\n", "Epoch 66/100\n", "8000/8000 [==============================] - 8s 989us/step - loss: 0.2905 - val_loss: 1.9733\n", "Epoch 67/100\n", "8000/8000 [==============================] - 8s 1ms/step - loss: 0.2846 - val_loss: 1.9765\n", "Epoch 68/100\n", "8000/8000 [==============================] - 8s 1ms/step - loss: 0.2743 - val_loss: 1.9953\n", "Epoch 69/100\n", "8000/8000 [==============================] - 8s 1ms/step - loss: 0.2688 - val_loss: 2.0060\n", "Epoch 70/100\n", "8000/8000 
[==============================] - 8s 981us/step - loss: 0.2632 - val_loss: 2.0008\n", "Epoch 71/100\n", "8000/8000 [==============================] - 8s 984us/step - loss: 0.2570 - val_loss: 2.0049\n", "Epoch 72/100\n", "8000/8000 [==============================] - 8s 984us/step - loss: 0.2501 - val_loss: 2.0082\n", "Epoch 73/100\n", "8000/8000 [==============================] - 8s 992us/step - loss: 0.2447 - val_loss: 2.0196\n", "Epoch 74/100\n", "8000/8000 [==============================] - 8s 987us/step - loss: 0.2384 - val_loss: 2.0287\n", "Epoch 75/100\n", "8000/8000 [==============================] - 8s 998us/step - loss: 0.2325 - val_loss: 2.0356\n", "Epoch 76/100\n", "8000/8000 [==============================] - 8s 1ms/step - loss: 0.2270 - val_loss: 2.0414\n", "Epoch 77/100\n", "8000/8000 [==============================] - 8s 988us/step - loss: 0.2209 - val_loss: 2.0477\n", "Epoch 78/100\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "8000/8000 [==============================] - 8s 997us/step - loss: 0.2157 - val_loss: 2.0530\n", "Epoch 79/100\n", "8000/8000 [==============================] - 8s 985us/step - loss: 0.2108 - val_loss: 2.0583\n", "Epoch 80/100\n", "8000/8000 [==============================] - 8s 1ms/step - loss: 0.2053 - val_loss: 2.0610\n", "Epoch 81/100\n", "8000/8000 [==============================] - 8s 1ms/step - loss: 0.2047 - val_loss: 2.0742\n", "Epoch 82/100\n", "8000/8000 [==============================] - 8s 995us/step - loss: 0.1951 - val_loss: 2.0844\n", "Epoch 83/100\n", "8000/8000 [==============================] - 8s 988us/step - loss: 0.1905 - val_loss: 2.0889\n", "Epoch 84/100\n", "8000/8000 [==============================] - 8s 984us/step - loss: 0.1858 - val_loss: 2.1018\n", "Epoch 85/100\n", "8000/8000 [==============================] - 8s 984us/step - loss: 0.1809 - val_loss: 2.1055\n", "Epoch 86/100\n", "8000/8000 [==============================] - 8s 979us/step - loss: 0.1763 - val_loss: 2.1158\n", 
"Epoch 87/100\n", "8000/8000 [==============================] - 8s 987us/step - loss: 0.1710 - val_loss: 2.1169\n", "Epoch 88/100\n", "8000/8000 [==============================] - 8s 998us/step - loss: 0.1687 - val_loss: 2.1181\n", "Epoch 89/100\n", "8000/8000 [==============================] - 8s 985us/step - loss: 0.1627 - val_loss: 2.1390\n", "Epoch 90/100\n", "8000/8000 [==============================] - 8s 1ms/step - loss: 0.1588 - val_loss: 2.1428\n", "Epoch 91/100\n", "8000/8000 [==============================] - 8s 999us/step - loss: 0.1538 - val_loss: 2.1437\n", "Epoch 92/100\n", "8000/8000 [==============================] - 8s 1ms/step - loss: 0.1509 - val_loss: 2.1550\n", "Epoch 93/100\n", "8000/8000 [==============================] - 8s 996us/step - loss: 0.1464 - val_loss: 2.1528\n", "Epoch 94/100\n", "8000/8000 [==============================] - 8s 1ms/step - loss: 0.1424 - val_loss: 2.1584\n", "Epoch 95/100\n", "8000/8000 [==============================] - 8s 991us/step - loss: 0.1383 - val_loss: 2.1770\n", "Epoch 96/100\n", "8000/8000 [==============================] - 8s 1ms/step - loss: 0.1348 - val_loss: 2.1693\n", "Epoch 97/100\n", "8000/8000 [==============================] - 8s 989us/step - loss: 0.1310 - val_loss: 2.1808\n", "Epoch 98/100\n", "8000/8000 [==============================] - 8s 977us/step - loss: 0.1272 - val_loss: 2.1906\n", "Epoch 99/100\n", "8000/8000 [==============================] - 8s 990us/step - loss: 0.1233 - val_loss: 2.1884\n", "Epoch 100/100\n", "8000/8000 [==============================] - 8s 995us/step - loss: 0.1200 - val_loss: 2.1978\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "C:\\Users\\8703147\\AppData\\Local\\Continuum\\anaconda3\\envs\\ml\\lib\\site-packages\\keras\\engine\\topology.py:2344: UserWarning: Layer decoder_lstm was passed non-serializable keyword arguments: {'initial_state': [, ]}. 
They will not be included in the serialized model (and thus will be missing at deserialization time).\n", " str(node.arguments) + '. They will not be included '\n" ] } ], "source": [ "# Configure the model for training\n", "model.compile(optimizer='rmsprop', loss='categorical_crossentropy')\n", "\n", "# Start training\n", "model.fit([encoder_input_data, decoder_input_data], decoder_target_data,\n", " batch_size=batch_size,\n", " epochs=epochs,\n", " validation_split=0.2)\n", "\n", "# Save the model\n", "model.save('s2s.h5')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### STEP 6. Model prediction (inference)\n", "\n", "The inference stage works as follows:\n", "\n", "1. Encode the input sequence and retrieve the initial decoder state\n", "2. Run one step of the decoder with this initial state and a \"start-of-sequence\" token as the target. The output will be the next target token\n", "3. Feed the sampled token and the updated states back in, and repeat until a stop token is produced\n", "\n", "\n", "![seq2seq_predict](https://4.bp.blogspot.com/-6DALk3-hPtA/WO04i5GgXLI/AAAAAAAABtc/2t9mYz4nQDg9jLoHdTkywDUfxIOFJfC_gCLcB/s640/Seq2SeqDiagram.gif)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "-\n", "Input sentence: Hi.\n", "Decoded sentence: 你好。\n", "\n", "-\n", "Input sentence: Hi.\n", "Decoded sentence: 你好。\n", "\n", "-\n", "Input sentence: Run.\n", "Decoded sentence: 你用跑的。\n", "\n", "-\n", "Input sentence: Wait!\n", "Decoded sentence: 等等!\n", "\n", "-\n", "Input sentence: Hello!\n", "Decoded sentence: 你好。\n", "\n", "-\n", "Input sentence: I try.\n", "Decoded sentence: 讓我來。\n", "\n", "-\n", "Input sentence: I won!\n", "Decoded sentence: 我贏了。\n", "\n", "-\n", "Input sentence: Oh no!\n", "Decoded sentence: 不會吧。\n", "\n", "-\n", "Input sentence: Cheers!\n", "Decoded sentence: 乾杯!\n", "\n", "-\n", "Input sentence: He ran.\n", "Decoded sentence: 他跑了。\n", "\n", "-\n", "Input sentence: Hop in.\n", "Decoded sentence: 等一下!\n", "\n", "-\n", "Input sentence: I lost.\n", "Decoded sentence: 我吃了這個蘋果。\n", "\n", "-\n", "Input sentence: I quit.\n", "Decoded sentence: 我退出。\n", "\n", "-\n", "Input sentence: I'm OK.\n", "Decoded sentence: 我沒事。\n", "\n", "-\n", "Input 
sentence: Listen.\n", "Decoded sentence: 聽著。\n", "\n", "-\n", "Input sentence: No way!\n", "Decoded sentence: 沒門!\n", "\n", "-\n", "Input sentence: No way!\n", "Decoded sentence: 沒門!\n", "\n", "-\n", "Input sentence: Really?\n", "Decoded sentence: 你確定?\n", "\n", "-\n", "Input sentence: Try it.\n", "Decoded sentence: 一個您方的。\n", "\n", "-\n", "Input sentence: We try.\n", "Decoded sentence: 我們來試試。\n", "\n", "-\n", "Input sentence: Why me?\n", "Decoded sentence: 為什麼是我?\n", "\n", "-\n", "Input sentence: Ask Tom.\n", "Decoded sentence: 去問湯姆。\n", "\n", "-\n", "Input sentence: Be calm.\n", "Decoded sentence: 冷靜點。\n", "\n", "-\n", "Input sentence: Be fair.\n", "Decoded sentence: 公平點。\n", "\n", "-\n", "Input sentence: Be kind.\n", "Decoded sentence: 放鬆點吧。\n", "\n", "-\n", "Input sentence: Be nice.\n", "Decoded sentence: 和氣點。\n", "\n", "-\n", "Input sentence: Call me.\n", "Decoded sentence: 聯繫我。\n", "\n", "-\n", "Input sentence: Call us.\n", "Decoded sentence: 聯繫我們。\n", "\n", "-\n", "Input sentence: Come in.\n", "Decoded sentence: 快點。\n", "\n", "-\n", "Input sentence: Get Tom.\n", "Decoded sentence: 滾!\n", "\n", "-\n", "Input sentence: Get out!\n", "Decoded sentence: 滾出去!\n", "\n", "-\n", "Input sentence: Go away!\n", "Decoded sentence: 走開!\n", "\n", "-\n", "Input sentence: Go away!\n", "Decoded sentence: 走開!\n", "\n", "-\n", "Input sentence: Go away.\n", "Decoded sentence: 走開!\n", "\n", "-\n", "Input sentence: Goodbye!\n", "Decoded sentence: 你用跑的。\n", "\n", "-\n", "Input sentence: Goodbye!\n", "Decoded sentence: 你用跑的。\n", "\n", "-\n", "Input sentence: Hang on!\n", "Decoded sentence: 等一下!\n", "\n", "-\n", "Input sentence: He came.\n", "Decoded sentence: 他來了。\n", "\n", "-\n", "Input sentence: He runs.\n", "Decoded sentence: 他跑。\n", "\n", "-\n", "Input sentence: Help me.\n", "Decoded sentence: 幫我一下。\n", "\n", "-\n", "Input sentence: Hold on.\n", "Decoded sentence: 堅持。\n", "\n", "-\n", "Input sentence: Hug Tom.\n", "Decoded sentence: 抱抱湯姆!\n", "\n", "-\n", "Input sentence: I 
agree.\n", "Decoded sentence: 我同意。\n", "\n", "-\n", "Input sentence: I'm ill.\n", "Decoded sentence: 我生病了。\n", "\n", "-\n", "Input sentence: I'm old.\n", "Decoded sentence: 我生病了。\n", "\n", "-\n", "Input sentence: It's OK.\n", "Decoded sentence: 沒關係。\n", "\n", "-\n", "Input sentence: It's me.\n", "Decoded sentence: 是該上個的子。\n", "\n", "-\n", "Input sentence: Join us.\n", "Decoded sentence: 來加入我們吧。\n", "\n", "-\n", "Input sentence: Keep it.\n", "Decoded sentence: 留著吧。\n", "\n", "-\n", "Input sentence: Kiss me.\n", "Decoded sentence: 吻我。\n", "\n", "-\n", "Input sentence: Perfect!\n", "Decoded sentence: 完美!\n", "\n", "-\n", "Input sentence: See you.\n", "Decoded sentence: 再見!\n", "\n", "-\n", "Input sentence: Shut up!\n", "Decoded sentence: 閉嘴!\n", "\n", "-\n", "Input sentence: Skip it.\n", "Decoded sentence: 不管它。\n", "\n", "-\n", "Input sentence: Take it.\n", "Decoded sentence: 拿走吧。\n", "\n", "-\n", "Input sentence: Wake up!\n", "Decoded sentence: 醒醒!\n", "\n", "-\n", "Input sentence: Wash up.\n", "Decoded sentence: 去清洗一下。\n", "\n", "-\n", "Input sentence: We know.\n", "Decoded sentence: 我們什麼都沒?\n", "\n", "-\n", "Input sentence: Welcome.\n", "Decoded sentence: 歡迎。\n", "\n", "-\n", "Input sentence: Who won?\n", "Decoded sentence: 誰贏了?\n", "\n", "-\n", "Input sentence: Why not?\n", "Decoded sentence: 為什麼不?\n", "\n", "-\n", "Input sentence: You run.\n", "Decoded sentence: 你跑。\n", "\n", "-\n", "Input sentence: Back off.\n", "Decoded sentence: 你用跑的。\n", "\n", "-\n", "Input sentence: Be still.\n", "Decoded sentence: 靜靜的,別動。\n", "\n", "-\n", "Input sentence: Cuff him.\n", "Decoded sentence: 把他銬上。\n", "\n", "-\n", "Input sentence: Drive on.\n", "Decoded sentence: 往前開。\n", "\n", "-\n", "Input sentence: Get away!\n", "Decoded sentence: 滾!\n", "\n", "-\n", "Input sentence: Get away!\n", "Decoded sentence: 滾!\n", "\n", "-\n", "Input sentence: Get down!\n", "Decoded sentence: 趴下!\n", "\n", "-\n", "Input sentence: Get lost!\n", "Decoded sentence: 滾!\n", "\n", "-\n", "Input sentence: 
Get real.\n", "Decoded sentence: 醒醒吧。\n", "\n", "-\n", "Input sentence: Grab Tom.\n", "Decoded sentence: 抓住湯姆。\n", "\n", "-\n", "Input sentence: Grab him.\n", "Decoded sentence: 抓住他。\n", "\n", "-\n", "Input sentence: Have fun.\n", "Decoded sentence: 玩得開心。\n", "\n", "-\n", "Input sentence: He tries.\n", "Decoded sentence: 他很容易。\n", "\n", "-\n", "Input sentence: Humor me.\n", "Decoded sentence: 你就是湯姆的主意。\n", "\n", "-\n", "Input sentence: Hurry up.\n", "Decoded sentence: 趕快!\n", "\n", "-\n", "Input sentence: Hurry up.\n", "Decoded sentence: 趕快!\n", "\n", "-\n", "Input sentence: I forgot.\n", "Decoded sentence: 我忘了。\n", "\n", "-\n", "Input sentence: I resign.\n", "Decoded sentence: 我放棄。\n", "\n", "-\n", "Input sentence: I'll pay.\n", "Decoded sentence: 我來付錢。\n", "\n", "-\n", "Input sentence: I'm busy.\n", "Decoded sentence: 我很忙。\n", "\n", "-\n", "Input sentence: I'm cold.\n", "Decoded sentence: 我生病了。\n", "\n", "-\n", "Input sentence: I'm fine.\n", "Decoded sentence: 我很好。\n", "\n", "-\n", "Input sentence: I'm full.\n", "Decoded sentence: 我吃飽了。\n", "\n", "-\n", "Input sentence: I'm sick.\n", "Decoded sentence: 我生病了。\n", "\n", "-\n", "Input sentence: I'm sick.\n", "Decoded sentence: 我生病了。\n", "\n", "-\n", "Input sentence: Leave me.\n", "Decoded sentence: 讓我一個人呆會兒。\n", "\n", "-\n", "Input sentence: Let's go!\n", "Decoded sentence: 我們開始吧!\n", "\n", "-\n", "Input sentence: Let's go!\n", "Decoded sentence: 我們開始吧!\n", "\n", "-\n", "Input sentence: Let's go!\n", "Decoded sentence: 我們開始吧!\n", "\n", "-\n", "Input sentence: Look out!\n", "Decoded sentence: 當心!\n", "\n", "-\n", "Input sentence: She runs.\n", "Decoded sentence: 她試過了。\n", "\n", "-\n", "Input sentence: Stand up.\n", "Decoded sentence: 起立。\n", "\n", "-\n", "Input sentence: They won.\n", "Decoded sentence: 他們不錯。\n", "\n", "-\n", "Input sentence: Tom died.\n", "Decoded sentence: 湯姆去世了。\n", "\n", "-\n", "Input sentence: Tom quit.\n", "Decoded sentence: 湯姆不干了。\n", "\n", "-\n", "Input sentence: Tom swam.\n", "Decoded 
sentence: 湯姆游泳了。\n", "\n", "-\n", "Input sentence: Trust me.\n", "Decoded sentence: 相信我。\n", "\n", "-\n", "Input sentence: Try hard.\n", "Decoded sentence: 努力。\n", "\n" ] } ], "source": [ "# Define the sampling models for inference\n", "\n", "# Define the encoder model\n", "encoder_model = Model(encoder_inputs, encoder_states)\n", "\n", "# Define inputs for the decoder LSTM cell's initial states\n", "decoder_state_input_h = Input(shape=(latent_dim,))\n", "decoder_state_input_c = Input(shape=(latent_dim,))\n", "decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]\n", "\n", "# Run the decoder LSTM one step at a time\n", "decoder_outputs, state_h, state_c = decoder_lstm(\n", " decoder_inputs, initial_state=decoder_states_inputs) # use `decoder_states_inputs` as the initial state\n", "\n", "decoder_states = [state_h, state_c]\n", "decoder_outputs = decoder_dense(decoder_outputs)\n", "\n", "# Define the decoder model\n", "decoder_model = Model(\n", " [decoder_inputs] + decoder_states_inputs,\n", " [decoder_outputs] + decoder_states)\n", "\n", "\n", "# Reverse-lookup token indices to decode sequences back to something readable.\n", "reverse_input_char_index = dict(\n", " (i, char) for char, i in input_token_index.items())\n", "\n", "reverse_target_char_index = dict(\n", " (i, char) for char, i in target_token_index.items())\n", "\n", "# Decode an input sequence\n", "def decode_sequence(input_seq):\n", " # Encode the input as state vectors\n", " states_value = encoder_model.predict(input_seq)\n", " \n", " # Generate an empty target sequence of length 1\n", " target_seq = np.zeros((1, 1, num_decoder_tokens))\n", " \n", " # Populate the first target token with the start-of-sequence [SOS] character, \"\\t\" in this example\n", " target_seq[0, 0, target_token_index['\\t']] = 1.\n", "\n", " # Sampling loop for a batch of sequences\n", " stop_condition = False\n", " decoded_sentence = ''\n", " while not stop_condition:\n", " output_tokens, h, c = decoder_model.predict(\n", " [target_seq] + states_value)\n", "\n", " # Sample a token\n", " sampled_token_index = np.argmax(output_tokens[0, -1, :])\n", " sampled_char = reverse_target_char_index[sampled_token_index]\n", " decoded_sentence += sampled_char\n", "\n", " # Exit condition: either hit max length or find the end-of-sequence [EOS] character, \"\\n\" in this example\n", " if (sampled_char == '\n' or\n", " len(decoded_sentence) > max_decoder_seq_length):\n", " stop_condition = True\n", "\n", " # Update the target sequence (of length 1)\n", " target_seq = np.zeros((1, 1, num_decoder_tokens))\n", " target_seq[0, 0, sampled_token_index] = 1.\n", "\n", " # Update states\n", " states_value = [h, c]\n", "\n", " return decoded_sentence\n", "\n", "\n", "for seq_index in range(100):\n", " # Take one sequence from the training set and try decoding it\n", " input_seq = encoder_input_data[seq_index: seq_index + 1]\n", " decoded_sentence = decode_sequence(input_seq)\n", " print('-')\n", " print('Input sentence:', input_texts[seq_index])\n", " print('Decoded sentence:', decoded_sentence)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "References:\n", "* [A ten-minute introduction to sequence-to-sequence learning in Keras](https://blog.keras.io/a-ten-minute-introduction-to-sequence-to-sequence-learning-in-keras.html)" ] }, { "cell_type": "markdown", "metadata": { "collapsed": true }, "source": [ "MIT License\n", "\n", "Copyright (c) 2018 Erhwen Kuo\n", "\n", "Copyright (c) 2017 François Chollet\n", "\n", "Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the \"Software\"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:\n", "\n", "The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.\n", "\n", "THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.5.4" } }, "nbformat": 4, "nbformat_minor": 2 }