{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "f:\\pythonprojects\\lenv\\test\\lib\\site-packages\\h5py\\__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n", " from ._conv import register_converters as _register_converters\n", "Using TensorFlow backend.\n" ] } ], "source": [ "import random\n", "import os\n", "\n", "import keras\n", "import numpy as np\n", "from keras.callbacks import LambdaCallback\n", "from keras.models import Input, Model, load_model\n", "from keras.layers import LSTM, Dropout, Dense\n", "from keras.optimizers import Adam\n", "\n", "from data_utils import *\n" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": true }, "outputs": [], "source": [
 "class PoetryModel(object):\n",
 "    '''LSTM-based poem generator built with Keras.\n",
 "\n",
 "    On construction the corpus is preprocessed with `preprocess_file(config)`;\n",
 "    if `config.weight_file` exists (and `loaded_model` is True) the saved model\n",
 "    is loaded with `load_model`, otherwise `self.train()` is called.\n",
 "    '''\n",
 "\n",
 "    def __init__(self, config):\n",
 "        self.model = None\n",
 "        self.do_train = True\n",
 "        # When True, reuse an existing model file instead of retraining.\n",
 "        self.loaded_model = True\n",
 "        self.config = config\n",
 "\n",
 "        # Preprocess the corpus files. Presumably returns the word->index and\n",
 "        # index->word lookups, the vocabulary and the raw corpus text -- TODO\n",
 "        # confirm against data_utils.preprocess_file.\n",
 "        self.word2numF, self.num2word, self.words, self.files_content = preprocess_file(self.config)\n",
 "\n",
 "        # List of poems: the corpus text uses ']' as the poem separator.\n",
 "        self.poems = self.files_content.split(']')\n",
 "        # Total number of poems.\n",
 "        self.poems_num = len(self.poems)\n",
 "\n",
 "        # Load the model if the model file exists, otherwise train from scratch.\n",
 "        if os.path.exists(self.config.weight_file) and self.loaded_model:\n",
 "            self.model = load_model(self.config.weight_file)\n",
 "        else:\n",
 "            self.train()\n",
 "\n",
 "    def build_model(self):\n",
 "        '''Build and compile the network, storing it in self.model.\n",
 "\n",
 "        Architecture: LSTM(512, return_sequences=True) -> Dropout(0.6) ->\n",
 "        LSTM(256) -> Dropout(0.6) -> Dense(len(self.words), softmax), compiled\n",
 "        with categorical cross-entropy and Adam(lr=config.learning_rate).\n",
 "        '''\n",
 "        print('building model')\n",
 "\n",
 "        # Input: max_len time steps, each a vector of length len(self.words).\n",
 "        input_tensor = Input(shape=(self.config.max_len, len(self.words)))\n",
 "        lstm = LSTM(512, return_sequences=True)(input_tensor)\n",
 "        dropout = Dropout(0.6)(lstm)\n",
 "        lstm = LSTM(256)(dropout)\n",
 "        dropout = Dropout(0.6)(lstm)\n",
 "        dense = Dense(len(self.words), activation='softmax')(dropout)\n",
 "        self.model = Model(inputs=input_tensor, outputs=dense)\n",
 "        optimizer = Adam(lr=self.config.learning_rate)\n",
 "        self.model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])\n",
 "\n",
 "    def sample(self, preds, temperature=1.0):\n",
 "        '''Draw one index from `preds` after temperature re-weighting.\n",
 "\n",
 "        Each weight is raised to the power 1/temperature and the result is\n",
 "        renormalised. temperature < 1 therefore sharpens the distribution\n",
 "        (more conservative: already-likely entries become likelier),\n",
 "        temperature == 1 leaves it unchanged, and temperature > 1 flattens it\n",
 "        (more diverse, more open output).\n",
 "\n",
 "        Args:\n",
 "            preds: 1-D sequence of non-negative weights (presumably one softmax\n",
 "                output of the model -- TODO confirm against the caller).\n",
 "            temperature: strictly positive float.\n",
 "\n",
 "        Returns:\n",
 "            int: the sampled index into `preds`.\n",
 "\n",
 "        Raises:\n",
 "            ValueError: if temperature is not strictly positive.\n",
 "        '''\n",
 "        if temperature <= 0:\n",
 "            raise ValueError('temperature must be > 0, got {}'.format(temperature))\n",
 "        preds = np.asarray(preds).astype('float64')\n",
 "        # Re-weight with p ** (1/T), then renormalise so the weights sum to 1.\n",
 "        weighted = np.power(preds, 1. / temperature)\n",
 "        probs = weighted / np.sum(weighted)\n",
 "        return int(np.random.choice(len(probs), p=probs))\n",
 "\n",
 "    def generate_sample_result(self, epoch, logs):\n",
 "        '''Print and log sample generations every 4th epoch.\n",
 "\n",
 "        The signature matches a Keras epoch-end callback (epoch, logs); presumably\n",
 "        wired up via LambdaCallback(on_epoch_end=...). `logs` is unused.\n",
 "        When epoch % 4 == 0, one poem is generated with predict_random for each\n",
 "        temperature in (0.7, 1.0, 1.3). Each poem is printed and appended, after\n",
 "        an epoch header, to out/out.txt.\n",
 "        '''\n",
 "        if epoch % 4 != 0:\n",
 "            return\n",
 "\n",
 "        # open(..., 'a') raises FileNotFoundError if out/ does not exist yet.\n",
 "        os.makedirs('out', exist_ok=True)\n",
 "        with open('out/out.txt', 'a', encoding='utf-8') as f:\n",
 "            f.write('==================Epoch {}=====================\\n'.format(epoch))\n",
 "            print(\"\\n==================Epoch {}=====================\".format(epoch))\n",
 "            for diversity in [0.7, 1.0, 1.3]:\n",
 "                print(\"------------Diversity {}--------------\".format(diversity))\n",
 "                generate = self.predict_random(temperature=diversity)\n",
 "                print(generate)\n",
 "                # Also keep the training-time predictions in the text file.\n",
 "                f.write(generate + '\\n')\n",
 "\n",
 "    def predict_random(self, temperature=1):\n",
 "        '''Continue the opening of a randomly chosen corpus poem.\n",
 "\n",
 "        Takes the first config.max_len characters of a random entry of\n",
 "        self.poems and passes them to predict_sen (the original intent is a\n",
 "        five-character quatrain).\n",
 "\n",
 "        Returns:\n",
 "            The generated text, or None (after printing a message) when no\n",
 "            model is loaded.\n",
 "        '''\n",
 "        if not self.model:\n",
 "            print('model not loaded')\n",
 "            return\n",
 "\n",
 "        # randrange excludes the upper bound; randint(0, poems_num) could pick\n",
 "        # index == poems_num and raise IndexError.\n",
 "        index = random.randrange(self.poems_num)\n",
 "        sentence = self.poems[index][: self.config.max_len]\n",
 "        generate = self.predict_sen(sentence, temperature=temperature)\n",
 "        return generate\n",
 "\n",
 "    def predict_first(self, char,temperature =1):\n",
 "        '''Generate a five-character quatrain whose first character is `char`.\n",
 "\n",
 "        Returns None (after printing a message) when no model is loaded.\n",
 "        '''\n",
 "        if not self.model:\n",
 "            print('model not loaded')\n",
 "            return\n",
 "\n",
 "        # randrange excludes the upper bound; randint(0, poems_num) could pick\n",
 "        # index == poems_num and raise IndexError.\n",
 "        index = random.randrange(self.poems_num)\n",
 "        # Seed input: the last (max_len - 1) characters of a random poem followed\n",
 "        # by `char`, i.e. max_len characters in total.\n",
 "        
sentence = self.poems[index][1-self.config.max_len:] + char\n", " generate = str(char)\n", "# print('first line = ',sentence)\n", " # 直接预测后面23个字符\n", " generate += self._preds(sentence,length=23,temperature=temperature)\n", " return generate\n", " \n", " def predict_sen(self, text,temperature =1):\n", " '''根据给出的前max_len个字,生成诗句'''\n", " '''此例中,即根据给出的第一句诗句(含逗号),来生成古诗'''\n", " if not self.model:\n", " return\n", " max_len = self.config.max_len\n", " if len(text)