{
"cells": [
{
"cell_type": "markdown",
"id": "aff7c9d8",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"source": [
"# Gated Recurrent Units (GRU)\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "6851ec0b",
"metadata": {
"attributes": {
"classes": [],
"id": "",
"n": "7"
},
"execution": {
"iopub.execute_input": "2023-08-18T19:50:04.809302Z",
"iopub.status.busy": "2023-08-18T19:50:04.808778Z",
"iopub.status.idle": "2023-08-18T19:50:07.808414Z",
"shell.execute_reply": "2023-08-18T19:50:07.807417Z"
},
"origin_pos": 3,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"from d2l import torch as d2l"
]
},
{
"cell_type": "markdown",
"id": "9e6d53ee",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"source": [
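"Initializing Model Parameters: the update gate, the reset gate, and the candidate hidden state each get an input weight, a hidden-state weight, and a bias"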
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "1f2fcd5e",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:50:07.813979Z",
"iopub.status.busy": "2023-08-18T19:50:07.813174Z",
"iopub.status.idle": "2023-08-18T19:50:07.819841Z",
"shell.execute_reply": "2023-08-18T19:50:07.818739Z"
},
"origin_pos": 7,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"class GRUScratch(d2l.Module):\n",
"    \"\"\"GRU layer written from scratch with explicitly stored parameters.\n",
"\n",
"    Three parameter triples are kept, suffixed `z`, `r` and `h`. Each triple\n",
"    holds an input weight `W_x*` of shape (num_inputs, num_hiddens), a\n",
"    recurrent weight `W_h*` of shape (num_hiddens, num_hiddens) and a bias\n",
"    `b_*` of shape (num_hiddens,).\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, num_inputs, num_hiddens, sigma=0.01):\n",
"        \"\"\"Create the nine learnable parameters.\n",
"\n",
"        Args:\n",
"            num_inputs: Number of rows of every `W_x*` matrix.\n",
"            num_hiddens: Number of columns of every weight matrix and\n",
"                length of every bias vector.\n",
"            sigma: Factor multiplied into the standard-normal draws used\n",
"                for the weight matrices. Biases start at zero.\n",
"        \"\"\"\n",
"        super().__init__()\n",
"        # d2l helper; presumably exposes the constructor arguments as\n",
"        # attributes (forward reads self.num_hiddens) -- TODO confirm.\n",
"        self.save_hyperparameters()\n",
"\n",
"        def scaled_normal(*shape):\n",
"            # Standard-normal tensor of the given shape, scaled by sigma.\n",
"            return nn.Parameter(torch.randn(*shape) * sigma)\n",
"\n",
"        def parameter_triple():\n",
"            # Draw order (input weight, then recurrent weight) is kept so\n",
"            # the random number generator is consumed in the same sequence.\n",
"            input_weight = scaled_normal(num_inputs, num_hiddens)\n",
"            recurrent_weight = scaled_normal(num_hiddens, num_hiddens)\n",
"            bias = nn.Parameter(torch.zeros(num_hiddens))\n",
"            return input_weight, recurrent_weight, bias\n",
"\n",
"        # Update gate\n",
"        self.W_xz, self.W_hz, self.b_z = parameter_triple()\n",
"        # Reset gate\n",
"        self.W_xr, self.W_hr, self.b_r = parameter_triple()\n",
"        # Candidate hidden state\n",
"        self.W_xh, self.W_hh, self.b_h = parameter_triple()"
]
},
{
"cell_type": "markdown",
"id": "6bcdedca",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Define the GRU forward computation"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "78b86a43",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:50:07.824621Z",
"iopub.status.busy": "2023-08-18T19:50:07.823909Z",
"iopub.status.idle": "2023-08-18T19:50:07.830603Z",
"shell.execute_reply": "2023-08-18T19:50:07.829486Z"
},
"origin_pos": 10,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"@d2l.add_to_class(GRUScratch)\n",
"def forward(self, inputs, H=None):\n",
"    \"\"\"Run the GRU over a sequence, one step at a time.\n",
"\n",
"    Args:\n",
"        inputs: Tensor that is iterated along its first axis; its second\n",
"            axis sizes the default initial state. Assumed layout is\n",
"            (num_steps, batch_size, num_inputs) -- confirm with the caller.\n",
"        H: Optional initial hidden state. When None, a zero tensor of\n",
"            shape (inputs.shape[1], num_hiddens) is created on the device\n",
"            of `inputs`.\n",
"\n",
"    Returns:\n",
"        A tuple (outputs, H): the list of hidden states, one per step,\n",
"        and the hidden state after the last step.\n",
"    \"\"\"\n",
"    if H is None:\n",
"        state_shape = (inputs.shape[1], self.num_hiddens)\n",
"        H = torch.zeros(state_shape, device=inputs.device)\n",
"    outputs = []\n",
"    for X in inputs:\n",
"        # Update gate: how much of the previous state is carried over.\n",
"        update_pre = X @ self.W_xz + H @ self.W_hz + self.b_z\n",
"        Z = torch.sigmoid(update_pre)\n",
"        # Reset gate: scales the previous state inside the candidate.\n",
"        reset_pre = X @ self.W_xr + H @ self.W_hr + self.b_r\n",
"        R = torch.sigmoid(reset_pre)\n",
"        # Candidate hidden state, built from the reset-scaled state.\n",
"        candidate_pre = X @ self.W_xh + (R * H) @ self.W_hh + self.b_h\n",
"        H_tilde = torch.tanh(candidate_pre)\n",
"        # Blend of the previous state and the candidate, weighted by Z.\n",
"        H = Z * H + (1 - Z) * H_tilde\n",
"        outputs.append(H)\n",
"    return outputs, H"
]
},
{
"cell_type": "markdown",
"id": "f90925b9",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Training"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "ecd79fad",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:50:07.835201Z",
"iopub.status.busy": "2023-08-18T19:50:07.834646Z",
"iopub.status.idle": "2023-08-18T19:51:44.215275Z",
"shell.execute_reply": "2023-08-18T19:51:44.214117Z"
},
"origin_pos": 13,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
"