{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"# Gated Recurrent Units (GRU)"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"ExecuteTime": {
"end_time": "2019-07-03T23:13:42.615802Z",
"start_time": "2019-07-03T23:13:40.417672Z"
},
"attributes": {
"classes": [],
"id": "",
"n": "6"
}
},
"outputs": [],
"source": [
"import d2l\n",
"from mxnet import np, npx\n",
"from mxnet.gluon import rnn\n",
"npx.set_np()\n",
"\n",
"batch_size, num_steps = 32, 35\n",
"train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Initialize model parameters"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"ExecuteTime": {
"end_time": "2019-07-03T23:13:42.624368Z",
"start_time": "2019-07-03T23:13:42.618052Z"
},
"attributes": {
"classes": [],
"id": "",
"n": "2"
}
},
"outputs": [],
"source": [
"def get_params(vocab_size, num_hiddens, ctx):\n",
" num_inputs = num_outputs = vocab_size\n",
" normal = lambda shape : np.random.normal(scale=0.01, size=shape, ctx=ctx)\n",
" three = lambda : (normal((num_inputs, num_hiddens)),\n",
" normal((num_hiddens, num_hiddens)),\n",
" np.zeros(num_hiddens, ctx=ctx))\n",
" W_xz, W_hz, b_z = three() # Update gate parameter\n",
" W_xr, W_hr, b_r = three() # Reset gate parameter\n",
" W_xh, W_hh, b_h = three() # Candidate hidden state parameter\n",
" # Output layer parameters\n",
" W_hq = normal((num_hiddens, num_outputs))\n",
" b_q = np.zeros(num_outputs, ctx=ctx)\n",
" # Create gradient\n",
" params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]\n",
" for param in params:\n",
" param.attach_grad()\n",
" return params"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Initialize the hidden state"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"ExecuteTime": {
"end_time": "2019-07-03T23:13:42.629629Z",
"start_time": "2019-07-03T23:13:42.626280Z"
},
"attributes": {
"classes": [],
"id": "",
"n": "3"
}
},
"outputs": [],
"source": [
"def init_gru_state(batch_size, num_hiddens, ctx):\n",
" return (np.zeros(shape=(batch_size, num_hiddens), ctx=ctx), )"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"The model"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"ExecuteTime": {
"end_time": "2019-07-03T23:13:42.637496Z",
"start_time": "2019-07-03T23:13:42.631549Z"
},
"attributes": {
"classes": [],
"id": "",
"n": "4"
}
},
"outputs": [],
"source": [
"def gru(inputs, state, params):\n",
" W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = params\n",
" H, = state\n",
" outputs = []\n",
" for X in inputs:\n",
" Z = npx.sigmoid(np.dot(X, W_xz) + np.dot(H, W_hz) + b_z)\n",
" R = npx.sigmoid(np.dot(X, W_xr) + np.dot(H, W_hr) + b_r)\n",
" H_tilda = np.tanh(np.dot(X, W_xh) + np.dot(R * H, W_hh) + b_h)\n",
" H = Z * H + (1 - Z) * H_tilda\n",
" Y = np.dot(H, W_hq) + b_q\n",
" outputs.append(Y)\n",
" return np.concatenate(outputs, axis=0), (H,)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Training"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"ExecuteTime": {
"end_time": "2019-07-03T23:14:22.156668Z",
"start_time": "2019-07-03T23:13:42.638985Z"
},
"attributes": {
"classes": [],
"id": "",
"n": "3"
},
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Perplexity 10.6, 13809 tokens/sec on gpu(0)\n",
"time travellere the the the the the the the the the the the the \n",
"travellere the the the the the the the the the the the the \n"
]
},
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n",
"\n"
],
"text/plain": [
"