{
"cells": [
{
"cell_type": "markdown",
"id": "aff7c9d8",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"source": [
"# Gated Recurrent Units (GRU)\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "6851ec0b",
"metadata": {
"attributes": {
"classes": [],
"id": "",
"n": "7"
},
"execution": {
"iopub.execute_input": "2023-08-18T19:50:04.809302Z",
"iopub.status.busy": "2023-08-18T19:50:04.808778Z",
"iopub.status.idle": "2023-08-18T19:50:07.808414Z",
"shell.execute_reply": "2023-08-18T19:50:07.807417Z"
},
"origin_pos": 3,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"from d2l import torch as d2l"
]
},
{
"cell_type": "markdown",
"id": "9e6d53ee",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"source": [
"Initializing Model Parameters"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "1f2fcd5e",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:50:07.813979Z",
"iopub.status.busy": "2023-08-18T19:50:07.813174Z",
"iopub.status.idle": "2023-08-18T19:50:07.819841Z",
"shell.execute_reply": "2023-08-18T19:50:07.818739Z"
},
"origin_pos": 7,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"class GRUScratch(d2l.Module):\n",
"    \"\"\"GRU layer implemented from scratch with explicit weight tensors.\n",
"\n",
"    Args:\n",
"        num_inputs: Number of input features (rows of each input-to-hidden\n",
"            weight matrix).\n",
"        num_hiddens: Size of the hidden state.\n",
"        sigma: Scale applied to the standard-normal initial weights.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, num_inputs, num_hiddens, sigma=0.01):\n",
"        super().__init__()\n",
"        # d2l helper; presumably exposes the constructor arguments as\n",
"        # attributes (forward reads self.num_hiddens) -- confirm in d2l.\n",
"        self.save_hyperparameters()\n",
"\n",
"        def init_weight(*shape):\n",
"            # Standard-normal samples scaled by sigma.\n",
"            return nn.Parameter(torch.randn(*shape) * sigma)\n",
"\n",
"        def triple():\n",
"            # (input-to-hidden weight, hidden-to-hidden weight, zero bias)\n",
"            return (init_weight(num_inputs, num_hiddens),\n",
"                    init_weight(num_hiddens, num_hiddens),\n",
"                    nn.Parameter(torch.zeros(num_hiddens)))\n",
"\n",
"        self.W_xz, self.W_hz, self.b_z = triple()  # update gate\n",
"        self.W_xr, self.W_hr, self.b_r = triple()  # reset gate\n",
"        self.W_xh, self.W_hh, self.b_h = triple()  # candidate hidden state"
]
},
{
"cell_type": "markdown",
"id": "6bcdedca",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Define the GRU forward computation"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "78b86a43",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:50:07.824621Z",
"iopub.status.busy": "2023-08-18T19:50:07.823909Z",
"iopub.status.idle": "2023-08-18T19:50:07.830603Z",
"shell.execute_reply": "2023-08-18T19:50:07.829486Z"
},
"origin_pos": 10,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"@d2l.add_to_class(GRUScratch)\n",
"def forward(self, inputs, H=None):\n",
"    \"\"\"Run the GRU over every step of `inputs`.\n",
"\n",
"    Args:\n",
"        inputs: Tensor iterated along its first axis. Each step `X` is\n",
"            multiplied by the (num_inputs, num_hiddens) input weights and\n",
"            `inputs.shape[1]` sizes the default state, so the layout is\n",
"            (num_steps, batch_size, num_inputs).\n",
"        H: Optional initial hidden state; zeros of shape\n",
"            (inputs.shape[1], num_hiddens) on `inputs.device` when omitted.\n",
"\n",
"    Returns:\n",
"        Tuple `(outputs, H)`: the list of hidden states, one per step, and\n",
"        the final hidden state.\n",
"    \"\"\"\n",
"    if H is None:\n",
"        # Default initial state: zeros of shape (inputs.shape[1], num_hiddens)\n",
"        H = torch.zeros((inputs.shape[1], self.num_hiddens),\n",
"                        device=inputs.device)\n",
"    outputs = []\n",
"    for X in inputs:\n",
"        # Update gate\n",
"        Z = torch.sigmoid(torch.matmul(X, self.W_xz) +\n",
"                          torch.matmul(H, self.W_hz) + self.b_z)\n",
"        # Reset gate\n",
"        R = torch.sigmoid(torch.matmul(X, self.W_xr) +\n",
"                          torch.matmul(H, self.W_hr) + self.b_r)\n",
"        # Candidate state: the previous state is scaled by R first\n",
"        H_tilde = torch.tanh(torch.matmul(X, self.W_xh) +\n",
"                             torch.matmul(R * H, self.W_hh) + self.b_h)\n",
"        # Z interpolates between the old state and the candidate\n",
"        H = Z * H + (1 - Z) * H_tilde\n",
"        outputs.append(H)\n",
"    return outputs, H"
]
},
{
"cell_type": "markdown",
"id": "f90925b9",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Training"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "ecd79fad",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:50:07.835201Z",
"iopub.status.busy": "2023-08-18T19:50:07.834646Z",
"iopub.status.idle": "2023-08-18T19:51:44.215275Z",
"shell.execute_reply": "2023-08-18T19:51:44.214117Z"
},
"origin_pos": 13,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Dataset and the from-scratch GRU language model\n",
"data = d2l.TimeMachine(batch_size=1024, num_steps=32)\n",
"vocab_size = len(data.vocab)\n",
"gru = GRUScratch(num_inputs=vocab_size, num_hiddens=32)\n",
"model = d2l.RNNLMScratch(gru, vocab_size=vocab_size, lr=4)\n",
"\n",
"# Train with gradient clipping on one GPU\n",
"trainer = d2l.Trainer(max_epochs=50, gradient_clip_val=1, num_gpus=1)\n",
"trainer.fit(model, data)"
]
},
{
"cell_type": "markdown",
"id": "97ccf919",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Concise Implementation"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "66c56966",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:51:44.237345Z",
"iopub.status.busy": "2023-08-18T19:51:44.237065Z",
"iopub.status.idle": "2023-08-18T19:52:51.996558Z",
"shell.execute_reply": "2023-08-18T19:52:51.995714Z"
},
"origin_pos": 18,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"class GRU(d2l.RNN):\n",
"    \"\"\"GRU layer backed by PyTorch's built-in `nn.GRU`.\n",
"\n",
"    Args:\n",
"        num_inputs: Number of input features per step (nn.GRU input size).\n",
"        num_hiddens: Size of the hidden state (nn.GRU hidden size).\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, num_inputs, num_hiddens):\n",
"        # Calls d2l.Module.__init__ directly rather than super().__init__();\n",
"        # presumably this skips d2l.RNN's own constructor so that self.rnn\n",
"        # can be set to an nn.GRU below -- confirm against d2l.RNN.\n",
"        d2l.Module.__init__(self)\n",
"        self.save_hyperparameters()\n",
"        self.rnn = nn.GRU(num_inputs, num_hiddens)\n",
"\n",
"# Same data and trainer as the scratch run, now with the built-in GRU layer\n",
"vocab_size = len(data.vocab)\n",
"gru = GRU(num_inputs=vocab_size, num_hiddens=32)\n",
"model = d2l.RNNLM(gru, vocab_size=vocab_size, lr=4)\n",
"trainer.fit(model, data)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "33f8aee3",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:52:52.004246Z",
"iopub.status.busy": "2023-08-18T19:52:52.003659Z",
"iopub.status.idle": "2023-08-18T19:52:52.029661Z",
"shell.execute_reply": "2023-08-18T19:52:52.028855Z"
},
"origin_pos": 20,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"text/plain": [
"'it has so it and the time '"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Continue the prefix 'it has' with 20 predicted characters\n",
"device = d2l.try_gpu()\n",
"model.predict('it has', 20, data.vocab, device)"
]
}
],
"metadata": {
"celltoolbar": "Slideshow",
"language_info": {
"name": "python"
},
"required_libs": [],
"rise": {
"autolaunch": true,
"enable_chalkboard": true,
"overlay": "",
"scroll": true
}
},
"nbformat": 4,
"nbformat_minor": 5
}