{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"# Long Short Term Memory (LSTM)"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"ExecuteTime": {
"end_time": "2019-07-03T23:11:54.519772Z",
"start_time": "2019-07-03T23:11:51.720287Z"
},
"attributes": {
"classes": [],
"id": "",
"n": "1"
}
},
"outputs": [],
"source": [
"import d2l\n",
"from mxnet import np, npx\n",
"from mxnet.gluon import rnn\n",
"npx.set_np()\n",
"\n",
"batch_size, num_steps = 32, 35\n",
"train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Initializing model parameters."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"ExecuteTime": {
"end_time": "2019-07-03T23:11:54.528595Z",
"start_time": "2019-07-03T23:11:54.522014Z"
},
"attributes": {
"classes": [],
"id": "",
"n": "2"
}
},
"outputs": [],
"source": [
"def get_lstm_params(vocab_size, num_hiddens, ctx):\n",
"    \"\"\"Allocate all LSTM parameters on `ctx` and attach gradients.\"\"\"\n",
"    num_inputs = num_outputs = vocab_size\n",
"\n",
"    def normal(shape):\n",
"        # Small random initialization, as in the other scratch RNN sections\n",
"        return np.random.normal(scale=0.01, size=shape, ctx=ctx)\n",
"\n",
"    def three():\n",
"        # One (input-to-hidden, hidden-to-hidden, bias) triple per gate\n",
"        return (normal((num_inputs, num_hiddens)),\n",
"                normal((num_hiddens, num_hiddens)),\n",
"                np.zeros(num_hiddens, ctx=ctx))\n",
"\n",
"    W_xi, W_hi, b_i = three()  # Input gate parameters\n",
"    W_xf, W_hf, b_f = three()  # Forget gate parameters\n",
"    W_xo, W_ho, b_o = three()  # Output gate parameters\n",
"    W_xc, W_hc, b_c = three()  # Candidate cell parameters\n",
"    # Output layer parameters\n",
"    W_hq = normal((num_hiddens, num_outputs))\n",
"    b_q = np.zeros(num_outputs, ctx=ctx)\n",
"    # Enable gradient computation for every parameter\n",
"    params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o,\n",
"              W_xc, W_hc, b_c, W_hq, b_q]\n",
"    for param in params:\n",
"        param.attach_grad()\n",
"    return params"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Initializing the hidden state."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"ExecuteTime": {
"end_time": "2019-07-03T23:11:54.535348Z",
"start_time": "2019-07-03T23:11:54.531181Z"
},
"attributes": {
"classes": [],
"id": "",
"n": "3"
}
},
"outputs": [],
"source": [
"def init_lstm_state(batch_size, num_hiddens, ctx):\n",
"    \"\"\"Return zero-filled (hidden state, memory cell) for a fresh sequence.\"\"\"\n",
"    shape = (batch_size, num_hiddens)\n",
"    return (np.zeros(shape=shape, ctx=ctx),\n",
"            np.zeros(shape=shape, ctx=ctx))"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Defining the LSTM model."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"ExecuteTime": {
"end_time": "2019-07-03T23:11:54.545037Z",
"start_time": "2019-07-03T23:11:54.537332Z"
},
"attributes": {
"classes": [],
"id": "",
"n": "4"
}
},
"outputs": [],
"source": [
"def lstm(inputs, state, params):\n",
"    \"\"\"Run the LSTM over `inputs` (one element per time step).\n",
"\n",
"    Returns the time-concatenated outputs and the final (H, C) state.\n",
"    \"\"\"\n",
"    (W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o,\n",
"     W_xc, W_hc, b_c, W_hq, b_q) = params\n",
"    H, C = state\n",
"    outputs = []\n",
"    for X in inputs:\n",
"        # Input, forget, and output gates, each squashed to (0, 1)\n",
"        i_gate = npx.sigmoid(np.dot(X, W_xi) + np.dot(H, W_hi) + b_i)\n",
"        f_gate = npx.sigmoid(np.dot(X, W_xf) + np.dot(H, W_hf) + b_f)\n",
"        o_gate = npx.sigmoid(np.dot(X, W_xo) + np.dot(H, W_ho) + b_o)\n",
"        # Candidate memory cell in (-1, 1)\n",
"        c_cand = np.tanh(np.dot(X, W_xc) + np.dot(H, W_hc) + b_c)\n",
"        # Blend old memory (via forget gate) with new candidate (via input gate)\n",
"        C = f_gate * C + i_gate * c_cand\n",
"        H = o_gate * np.tanh(C)\n",
"        outputs.append(np.dot(H, W_hq) + b_q)\n",
"    return np.concatenate(outputs, axis=0), (H, C)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Training"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"ExecuteTime": {
"end_time": "2019-07-03T23:12:41.455722Z",
"start_time": "2019-07-03T23:11:54.546597Z"
},
"attributes": {
"classes": [],
"id": "",
"n": "9"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Perplexity 14.4, 10748 tokens/sec on gpu(0)\n",
"time traveller te at at at at at at at at at at at at at at at a\n",
"traveller te at at at at at at at at at at at at at at at a\n"
]
},
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n",
"\n"
],
"text/plain": [
"