{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"# Gated Recurrent Units (GRU)"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"ExecuteTime": {
"end_time": "2019-07-03T23:13:42.615802Z",
"start_time": "2019-07-03T23:13:40.417672Z"
},
"attributes": {
"classes": [],
"id": "",
"n": "6"
}
},
"outputs": [],
"source": [
"import d2l\n",
"from mxnet import np, npx\n",
"from mxnet.gluon import rnn\n",
"npx.set_np()\n",
"\n",
"# Minibatch size and sequence length passed to d2l.load_data_time_machine\n",
"batch_size = 32\n",
"num_steps = 35\n",
"train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Initialize model parameters"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"ExecuteTime": {
"end_time": "2019-07-03T23:13:42.624368Z",
"start_time": "2019-07-03T23:13:42.618052Z"
},
"attributes": {
"classes": [],
"id": "",
"n": "2"
}
},
"outputs": [],
"source": [
"def get_params(vocab_size, num_hiddens, ctx):\n",
"    \"\"\"Create the GRU parameters and attach a gradient buffer to each.\n",
"\n",
"    Returns the list [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h,\n",
"    W_hq, b_q]; `gru` unpacks it positionally, so the order matters.\n",
"    \"\"\"\n",
"    num_inputs = num_outputs = vocab_size\n",
"\n",
"    def normal(shape):\n",
"        # Gaussian weights with standard deviation 0.01\n",
"        return np.random.normal(scale=0.01, size=shape, ctx=ctx)\n",
"\n",
"    def gate_params():\n",
"        # (input-to-hidden weight, hidden-to-hidden weight, zero bias)\n",
"        return (normal((num_inputs, num_hiddens)),\n",
"                normal((num_hiddens, num_hiddens)),\n",
"                np.zeros(num_hiddens, ctx=ctx))\n",
"\n",
"    W_xz, W_hz, b_z = gate_params()  # Update gate\n",
"    W_xr, W_hr, b_r = gate_params()  # Reset gate\n",
"    W_xh, W_hh, b_h = gate_params()  # Candidate hidden state\n",
"    # Output layer\n",
"    W_hq = normal((num_hiddens, num_outputs))\n",
"    b_q = np.zeros(num_outputs, ctx=ctx)\n",
"    params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]\n",
"    # Attach a gradient buffer to every parameter (MXNet attach_grad)\n",
"    for param in params:\n",
"        param.attach_grad()\n",
"    return params"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Initialize the hidden state"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"ExecuteTime": {
"end_time": "2019-07-03T23:13:42.629629Z",
"start_time": "2019-07-03T23:13:42.626280Z"
},
"attributes": {
"classes": [],
"id": "",
"n": "3"
}
},
"outputs": [],
"source": [
"def init_gru_state(batch_size, num_hiddens, ctx):\n",
"    \"\"\"Return the initial GRU state: a 1-tuple holding a zero\n",
"    hidden state of shape (batch_size, num_hiddens).\"\"\"\n",
"    H = np.zeros((batch_size, num_hiddens), ctx=ctx)\n",
"    return (H,)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
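"The model: at each time step $t$ ($\\odot$ denotes the elementwise product)\n",
"\n",
"$$\\begin{aligned}\n",
"Z_t &= \\sigma(X_t W_{xz} + H_{t-1} W_{hz} + b_z)\\\\\n",
"R_t &= \\sigma(X_t W_{xr} + H_{t-1} W_{hr} + b_r)\\\\\n",
"\\tilde{H}_t &= \\tanh(X_t W_{xh} + (R_t \\odot H_{t-1}) W_{hh} + b_h)\\\\\n",
"H_t &= Z_t \\odot H_{t-1} + (1 - Z_t) \\odot \\tilde{H}_t\\\\\n",
"Y_t &= H_t W_{hq} + b_q\n",
"\\end{aligned}$$"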
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"ExecuteTime": {
"end_time": "2019-07-03T23:13:42.637496Z",
"start_time": "2019-07-03T23:13:42.631549Z"
},
"attributes": {
"classes": [],
"id": "",
"n": "4"
}
},
"outputs": [],
"source": [
"def gru(inputs, state, params):\n",
"    \"\"\"Run the GRU cell over `inputs`, iterating its leading axis\n",
"    (presumably time steps -- confirm against the data iterator).\n",
"\n",
"    Returns the per-step output-layer results concatenated along axis 0,\n",
"    and the final hidden state wrapped in a 1-tuple.\n",
"    \"\"\"\n",
"    (W_xz, W_hz, b_z, W_xr, W_hr, b_r,\n",
"     W_xh, W_hh, b_h, W_hq, b_q) = params\n",
"    (H,) = state\n",
"    step_outputs = []\n",
"    for X in inputs:\n",
"        # Update gate Z and reset gate R\n",
"        Z = npx.sigmoid(np.dot(X, W_xz) + np.dot(H, W_hz) + b_z)\n",
"        R = npx.sigmoid(np.dot(X, W_xr) + np.dot(H, W_hr) + b_r)\n",
"        # Candidate state: the reset gate scales the previous hidden state\n",
"        H_tilde = np.tanh(np.dot(X, W_xh) + np.dot(R * H, W_hh) + b_h)\n",
"        # Update gate interpolates between the old state and the candidate\n",
"        H = Z * H + (1 - Z) * H_tilde\n",
"        step_outputs.append(np.dot(H, W_hq) + b_q)\n",
"    return np.concatenate(step_outputs, axis=0), (H,)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Training"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"ExecuteTime": {
"end_time": "2019-07-03T23:14:22.156668Z",
"start_time": "2019-07-03T23:13:42.638985Z"
},
"attributes": {
"classes": [],
"id": "",
"n": "3"
},
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Perplexity 10.6, 13809 tokens/sec on gpu(0)\n",
"time travellere the the the the the the the the the the the the \n",
"travellere the the the the the the the the the the the the \n"
]
},
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"vocab_size, num_hiddens, ctx = len(vocab), 256, d2l.try_gpu()\n",
"num_epochs, lr = 50, 1\n",
"# From-scratch GRU assembled from get_params / init_gru_state / gru above\n",
"model = d2l.RNNModelScratch(vocab_size, num_hiddens, ctx, get_params,\n",
"                            init_gru_state, gru)\n",
"d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, ctx)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Concise implementation with Gluon's built-in `rnn.GRU` layer"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"ExecuteTime": {
"end_time": "2019-07-03T23:15:13.493113Z",
"start_time": "2019-07-03T23:14:22.158378Z"
},
"attributes": {
"classes": [],
"id": "",
"n": "9"
},
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Perplexity 1.1, 99504 tokens/sec on gpu(0)\n",
"time traveller sit as he ghe wis our some said filby can exputt\n",
"traveller sut esh uighay an all disconestwink we caur the s\n"
]
},
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Gluon's built-in GRU layer, same hidden size as the scratch model\n",
"gru_layer = rnn.GRU(num_hiddens)\n",
"model = d2l.RNNModel(gru_layer, len(vocab))\n",
"# Note: trains for 10x the epochs used for the from-scratch model above\n",
"d2l.train_ch8(model, train_iter, vocab, lr, 10 * num_epochs, ctx)"
]
}
],
"metadata": {
"celltoolbar": "Slideshow",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.1"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 2
}