{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Word-level LSTM text generation\n",
    "\n",
    "Trains a small Embedding + LSTM next-word model on the first lines of Project Gutenberg ebook #11 (cached locally as `alice.txt`), then generates text starting from the seed word `alice's`, first by random sampling after every epoch and finally by greedy decoding."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "from keras.models import Sequential\n",
    "from keras.layers import Dense, Embedding, LSTM\n",
    "from keras.utils import np_utils\n",
    "from keras.utils.data_utils import get_file\n",
    "from keras.preprocessing import sequence\n",
    "from keras.preprocessing.text import Tokenizer\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "# Configuration: everything a reader might want to tune lives here.\n",
    "SEED = 13\n",
    "TEXT_URL = 'https://www.gutenberg.org/cache/epub/11/pg11.txt'\n",
    "NUM_LINES = 50        # number of leading lines of the text used for training\n",
    "EMBEDDING_DIM = 128\n",
    "LSTM_UNITS = 128\n",
    "NUM_EPOCHS = 30\n",
    "SEED_TEXT = \"alice's\"  # first word of every generated sentence\n",
    "\n",
    "np.random.seed(SEED)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load and tokenize the text"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "path = get_file('alice.txt', origin=TEXT_URL)\n",
    "\n",
    "# The context manager closes the file handle. 'utf-8-sig' decodes UTF-8 and\n",
    "# drops a leading byte-order mark if one is present, so decoding does not\n",
    "# depend on the platform's default locale and a BOM cannot stick to the\n",
    "# first token.\n",
    "with open(path, encoding='utf-8-sig') as text_file:\n",
    "    lines = text_file.readlines()[:NUM_LINES]\n",
    "\n",
    "tokenizer = Tokenizer()\n",
    "tokenizer.fit_on_texts(lines)\n",
    "\n",
    "# Keep only lines with at least two tokens: a shorter line yields no\n",
    "# (prefix -> next word) training pair.\n",
    "doc = [seq for seq in tokenizer.texts_to_sequences(lines) if len(seq) > 1]\n",
    "\n",
    "# Total number of (prefix, next word) training pairs.\n",
    "words_size = sum(len(words) - 1 for words in doc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Longest prefix fed to the model (a sentence of n tokens has prefixes up to n - 1).\n",
    "maxlen = max(len(x) - 1 for x in doc)\n",
    "# +1 because Tokenizer word indices start at 1; index 0 is used for padding.\n",
    "vocab_size = len(tokenizer.word_index) + 1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training data generator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def generate_data(X, maxlen, V):\n",
    "    \"\"\"Yield one `(inputs, targets)` training batch per sentence.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    X : iterable of token-id sequences.\n",
    "    maxlen : length every input prefix is padded (or truncated) to.\n",
    "    V : number of classes used for the one-hot targets.\n",
    "\n",
    "    For a sentence `w0 .. wn` the batch has one row per `i` in `1 .. n`:\n",
    "    the input is the prefix `w0 .. w(i-1)` passed through `pad_sequences`,\n",
    "    and the target is the one-hot encoding of `wi`.\n",
    "    Sentences with fewer than two tokens produce no pairs and are skipped.\n",
    "    \"\"\"\n",
    "    for sentence in X:\n",
    "        if len(sentence) < 2:\n",
    "            continue  # nothing to predict; avoid handing Keras empty arrays\n",
    "        inputs = [sentence[:i] for i in range(1, len(sentence))]\n",
    "        targets = sentence[1:]\n",
    "        y = np_utils.to_categorical(targets, V)\n",
    "        inputs_sequence = sequence.pad_sequences(inputs, maxlen=maxlen)\n",
    "        yield (inputs_sequence, y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Sampling and generation helpers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sample(p):\n",
    "    \"\"\"Draw one class index at random according to the probabilities `p`.\n",
    "\n",
    "    `p` is converted to float64 and renormalised into a new array first:\n",
    "    float32 softmax output can sum to slightly more than 1, which makes\n",
    "    `np.random.multinomial` raise ValueError. The caller's array is not\n",
    "    modified.\n",
    "    \"\"\"\n",
    "    p = np.asarray(p, dtype=np.float64)\n",
    "    p = p / p.sum()\n",
    "    # A single multinomial trial is a one-hot vector; argmax is the drawn index.\n",
    "    return int(np.argmax(np.random.multinomial(1, p)))\n",
    "\n",
    "\n",
    "def generate_text(model, tokenizer, maxlen, seed_text, choose_word):\n",
    "    \"\"\"Extend `seed_text` with up to `maxlen` words predicted by `model`.\n",
    "\n",
    "    `choose_word` maps the predicted probability vector to a word index,\n",
    "    e.g. `sample` for random sampling or `np.argmax` for greedy decoding.\n",
    "    An index that has no entry in `tokenizer.word_index` adds no word for\n",
    "    that step.\n",
    "    \"\"\"\n",
    "    # Build the reverse lookup once instead of scanning word_index per word.\n",
    "    index_to_word = {index: word for word, index in tokenizer.word_index.items()}\n",
    "    text = seed_text\n",
    "    for _ in range(maxlen):\n",
    "        encoded = sequence.pad_sequences(tokenizer.texts_to_sequences([text]), maxlen=maxlen)\n",
    "        probabilities = model.predict(encoded, verbose=0)[0]\n",
    "        word = index_to_word.get(int(choose_word(probabilities)))\n",
    "        if word is not None:\n",
    "            text += ' ' + word\n",
    "    return text"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = Sequential()\n",
    "model.add(Embedding(vocab_size, EMBEDDING_DIM, input_length=maxlen))\n",
    "model.add(LSTM(LSTM_UNITS, return_sequences=False))\n",
    "model.add(Dense(vocab_size, activation='softmax'))\n",
    "model.compile(loss='categorical_crossentropy', optimizer='adadelta')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train, sampling a sentence after every epoch\n",
    "\n",
    "Each line of output is the epoch number followed by a randomly sampled sentence. The saved output below comes from an earlier run; samples are random, so a re-run will print different sentences."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 alice's for 25 3 her hot bank author what 'without pleasure feel posting at thought\n",
      "1 alice's mind ' into no conversations and pictures sitting as away the re in copy\n",
      "2 alice's anyone get and hole once hot it 25 language was worth rabbit chapter\n",
      "3 alice's mind carroll a ' anywhere terms date thought org up restrictions was license carroll\n",
      "4 alice's millennium well carroll gutenberg peeped 25 project feel 11 edition up online copy i\n",
      "5 alice's updated project use restrictions the www carroll she edition project you stupid stupid feel\n",
      "6 alice's 3 carroll in very english ' ' it with her fulcrum well updated is\n",
      "7 alice's hole 0 making edition and well it and very adventures had sister in own\n",
      "8 alice's in pleasure this was her the anywhere daisy updated restrictions whatsoever 11 0 release\n",
      "9 alice's gutenberg no carroll ' or alice's ' millennium carroll she had ebook carroll june\n",
      "10 alice's ebook december was use is june had is wonderland 2008 beginning 20 wonderland do\n",
      "11 alice's 0 ebook the ' in down her carroll trouble tired away december by fulcrum\n",
      "12 alice's is 3 nothing wonderland of online sitting posting in it so for her 3\n",
      "13 alice's in project restrictions 20 of terms 2011 this adventures gutenberg away trouble do is\n",
      "14 alice's carroll march by 11 2011 would adventures tired english by of 1994 'and of\n",
      "15 alice's alice's gutenberg's 20 date terms www online carroll wonderland worth in away give\n",
      "16 alice's carroll english no she restrictions hole the it her ' wonderland 0 date millennium\n",
      "17 alice's carroll no 0 the would alice's alice i but whether included 'and carroll gutenberg's\n",
      "18 alice's adventures no the it lewis i lewis 2008 was day org a you made\n",
      "19 alice's ebook ' june mind daisy ' this to anywhere is by the carroll down\n",
      "20 alice's you ' in the june gutenberg's english stupid english to 11 beginning re a\n",
      "21 alice's 1994 in be feel june carroll in adventures 2011 once is updated the be\n",
      "22 alice's ebook no lewis or stupid adventures alice's june to december or adventures sitting own\n",
      "23 alice's i march you rabbit and millennium anyone march or is the but away pleasure\n",
      "24 alice's millennium use conversations the 11 day it as adventures hole chapter ' get away\n",
      "25 alice's a ' adventures hole carroll updated english included and it as but or rabbit\n",
      "26 alice's under carroll carroll by gutenberg her december terms in for license december considering but\n",
      "27 alice's fulcrum carroll nothing made do you march english or worth www away www stupid\n",
      "28 alice's june carroll of getting alice's of alice's anyone fulcrum pleasure tired with 'and terms\n",
      "29 alice's edition june ' english worth in no under adventures carroll 11 is sitting into\n"
     ]
    }
   ],
   "source": [
    "for i in range(NUM_EPOCHS):\n",
    "    # One gradient step per sentence.\n",
    "    for x, y in generate_data(doc, maxlen, vocab_size):\n",
    "        model.train_on_batch(x, y)\n",
    "\n",
    "    # Sample a sentence after every epoch to watch the model evolve.\n",
    "    in_words = generate_text(model, tokenizer, maxlen, SEED_TEXT, sample)\n",
    "    print(i, in_words)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Greedy generation\n",
    "\n",
    "Always pick the most probable next word instead of sampling."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "alice's carroll carroll carroll carroll carroll in in the the the the the the or\n"
     ]
    }
   ],
   "source": [
    "# np.argmax over the softmax output selects the same class that\n",
    "# Sequential.predict_classes used to return, without relying on that\n",
    "# method, which later Keras releases removed.\n",
    "in_words = generate_text(model, tokenizer, maxlen, SEED_TEXT, np.argmax)\n",
    "print(in_words)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.1"
  },
  "toc": {
   "toc_cell": false,
   "toc_number_sections": true,
   "toc_threshold": 6,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}