{ "cells": [ { "cell_type": "markdown", "id": "74f2b62b-44f3-4a73-987c-f98bd3fbd016", "metadata": {}, "source": [ "# 1. Assume that we only want to use the input at time step $t'$ to predict the output at time step $t\\gt t'$. What are the best values for the reset and update gates for each time step?" ] },
{ "cell_type": "markdown", "id": "1bf88dea-fcd7-4d2b-a9a0-f89d501a6ded", "metadata": {}, "source": [ "If we only want to use the input at time step $t'$ to predict the output at time step $t\\gt t'$, then the hidden state should be written from the input $x_{t'}$ alone at time step $t'$ and then carried unchanged to time step $t$. Using the GRU update from reference (1),\n", "\n", "$$\\tilde{h}_k = \\tanh(x_k W_{xh} + (r_k \\odot h_{k-1}) W_{hh} + b_h), \\qquad h_k = z_k \\odot h_{k-1} + (1 - z_k) \\odot \\tilde{h}_k,$$\n", "\n", "$z_k=1$ keeps the old hidden state, $z_k=0$ replaces it with the candidate hidden state, and $r_k=0$ removes the previous hidden state from the candidate. Therefore, the best values for the reset and update gates are:\n", "\n", "- For $k \\lt t'$, the gate values do not matter, because the hidden state is overwritten at time step $t'$. For example, $z_k=1$ simply keeps the initial value $h_0$ until time step $t'$.\n", "- For $k = t'$, we set $z_{t'}=0$ and $r_{t'}=0$. The reset gate removes $h_{t'-1}$ from the candidate hidden state, so the candidate becomes a function of the input at time step $t'$ only, and the update gate replaces the hidden state with this candidate: $h_{t'}=\\tilde{h}_{t'}=\\tanh(x_{t'} W_{xh} + b_h)$.\n", "- For $t' \\lt k \\le t$, we set $z_k=1$. The hidden state is copied forward unchanged, so $h_k=h_{t'}$ and the inputs $x_k$ are ignored. The value of $r_k$ is irrelevant, because the candidate hidden state is never used.\n", "\n", "This way, the hidden state at time step $t$ depends only on the input at time step $t'$ and not on other inputs. Because the gates are sigmoid outputs, they can only approach 0 and 1, so in practice this holds approximately.\n", "\n", "- (1) 10.2. Gated Recurrent Units (GRU) — Dive into Deep .... https://d2l.ai/chapter_recurrent-modern/gru.html.\n", "- (2) GRU unit: difference between Update and Forget gates. https://stats.stackexchange.com/questions/586668/gru-unit-difference-between-update-and-forget-gates.\n", "- (3) Gated Recurrent Unit Networks - GeeksforGeeks. https://www.geeksforgeeks.org/gated-recurrent-unit-networks/.\n", "- (4) Gated Recurrent Unit – What Is It And How To Learn .... https://analyticsindiamag.com/gated-recurrent-unit-what-is-it-and-how-to-learn/.\n", "- (5) Introduction to the GRU (Gated Recurrent Unit) structure - CSDN blog. https://blog.csdn.net/fangfanglovezhou/article/details/122560797." ] },
{ "cell_type": "markdown", "id": "bc598722-0257-41e5-b40b-41075c3d3e1d", "metadata": {}, "source": [ "# 2. Adjust the hyperparameters and analyze their influence on running time, perplexity, and the output sequence." 
] }, { "cell_type": "code", "execution_count": 1, "id": "69537900-b85e-48ff-aeb2-e51c2b7fe5fe", "metadata": { "tags": [] }, "outputs": [], "source": [ "import sys\n", "import warnings\n", "\n", "import numpy as np  # used by stat_val\n", "import torch\n", "import torch.nn as nn\n", "\n", "# Make the d2l utilities under d2l_utils importable\n", "sys.path.append('/home/jovyan/work/d2l_solutions/notebooks/exercises/d2l_utils/')\n", "import d2l\n", "from torchsummary import summary\n", "from sklearn.model_selection import ParameterGrid\n", "\n", "warnings.filterwarnings(\"ignore\")\n", "\n", "class GRU(d2l.RNN):\n", "    def __init__(self, num_inputs, num_hiddens):\n", "        d2l.Module.__init__(self)\n", "        self.save_hyperparameters()\n", "        self.rnn = nn.GRU(num_inputs, num_hiddens)\n", "\n", "def stat_val(model, data):\n", "    # Average the validation loss over all batches, then exponentiate\n", "    # to obtain the perplexity\n", "    ppls = []\n", "    for batch in iter(data.get_dataloader(False)):\n", "        ppls.append(model.validation_step(batch, plot_flag=False).detach().cpu().numpy())\n", "    return np.exp(np.mean(ppls))\n", "\n", "def experient(data_class=d2l.TimeMachine, num_steps=32, num_hiddens=32, lr=1):\n", "    # Train a GRU language model and return its validation perplexity\n", "    data = data_class(batch_size=1024, num_steps=num_steps)\n", "    gru = GRU(num_inputs=len(data.vocab), num_hiddens=num_hiddens)\n", "    model = d2l.RNNLM(gru, vocab_size=len(data.vocab), lr=lr)\n", "    trainer = d2l.Trainer(max_epochs=100, gradient_clip_val=1) #, num_gpus=1\n", "    trainer.fit(model, data)\n", "    return stat_val(model, data)" ] },
{ "cell_type": "code", "execution_count": null, "id": "a34a1342-bf71-4a8b-94fb-c0088e8b1563", "metadata": { "tags": [] }, "outputs": [ { "data": { "image/svg+xml": [ " 2023-09-05T06:14:14.177673\n", " image/svg+xml\n", " Matplotlib v3.4.0, https://matplotlib.org/\n", " 
\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "param_grid = {'num_steps': [8, 16, 32, 64, 128],\n", "              'num_hiddens': [8, 16, 32, 64, 128],\n", "              'lr': [0.01, 0.1, 1, 10]}\n", "param_grid_obj = ParameterGrid(param_grid)\n", "ppls = []\n", "for params in param_grid_obj:\n", "    ppl = experient(**params)\n", "    ppls.append(ppl)\n", "    print(params, ppl)" ] },
{ "cell_type": "markdown", "id": "aba93657-7f18-4a89-b247-9aee8ce115dc", "metadata": {}, "source": [ "# 3. Compare runtime, perplexity, and the output strings for rnn.RNN and rnn.GRU implementations with each other." ] },
{ "cell_type": "markdown", "id": "f4cd6b6a-e9ff-4e38-b177-96ea921047e6", "metadata": {}, "source": [ "# 4. What happens if you implement only parts of a GRU, e.g., with only a reset gate or only an update gate?" ] },
{ "cell_type": "code", "execution_count": null, "id": "f178dc3a-9d70-4516-9199-d98d5cc42caa", "metadata": {}, "outputs": [], "source": [ "class ResetGRUScratch(d2l.Module):\n", "    # GRU variant that keeps only the reset gate\n", "    def __init__(self, num_inputs, num_hiddens, sigma=0.01):\n", "        super().__init__()\n", "        self.save_hyperparameters()\n", "\n", "        init_weight = lambda *shape: nn.Parameter(torch.randn(*shape) * sigma)\n", "        triple = lambda: (init_weight(num_inputs, num_hiddens),\n", "                          init_weight(num_hiddens, num_hiddens),\n", "                          nn.Parameter(torch.zeros(num_hiddens)))\n", "        # self.W_xz, self.W_hz, self.b_z = triple()  # Update gate\n", "        self.W_xr, self.W_hr, self.b_r = triple()  # Reset gate\n", "        self.W_xh, self.W_hh, self.b_h = triple()  # Candidate hidden state\n", "\n", "    def forward(self, inputs, H=None):\n", "        if H is None:\n", "            # Initial state with shape: (batch_size, num_hiddens)\n", "            H = torch.zeros((inputs.shape[1], self.num_hiddens),\n", "                            device=inputs.device)\n", "        outputs = []\n", "        for X in inputs:\n", "            # Z = torch.sigmoid(torch.matmul(X, self.W_xz) +\n", "            #                   torch.matmul(H, self.W_hz) + self.b_z)\n", "            R = torch.sigmoid(torch.matmul(X, self.W_xr) +\n", "                              torch.matmul(H, self.W_hr) + self.b_r)\n", "            H_tilde = torch.tanh(torch.matmul(X, self.W_xh) +\n", "                                 torch.matmul(R * H, self.W_hh) + self.b_h)\n", "            # No update gate: the candidate always replaces the state (Z = 0)\n", "            H = H_tilde\n", "            outputs.append(H)\n", "        return outputs, H\n", "\n", "class UpdateGRUScratch(d2l.Module):\n", "    # GRU variant that keeps only the update gate\n", "    def __init__(self, num_inputs, num_hiddens, sigma=0.01):\n", "        super().__init__()\n", "        self.save_hyperparameters()\n", "\n", "        init_weight = lambda *shape: nn.Parameter(torch.randn(*shape) * sigma)\n", "        triple = lambda: (init_weight(num_inputs, num_hiddens),\n", "                          init_weight(num_hiddens, num_hiddens),\n", "                          nn.Parameter(torch.zeros(num_hiddens)))\n", "        self.W_xz, self.W_hz, self.b_z = triple()  # Update gate\n", "        # self.W_xr, self.W_hr, self.b_r = triple()  # Reset gate\n", "        self.W_xh, self.W_hh, self.b_h = triple()  # Candidate hidden state\n", "\n", "    def forward(self, inputs, H=None):\n", "        if H is None:\n", "            # Initial state with shape: (batch_size, num_hiddens)\n", "            H = torch.zeros((inputs.shape[1], self.num_hiddens),\n", "                            device=inputs.device)\n", "        outputs = []\n", "        for X in inputs:\n", "            Z = torch.sigmoid(torch.matmul(X, self.W_xz) +\n", "                              torch.matmul(H, self.W_hz) + self.b_z)\n", "            # R = torch.sigmoid(torch.matmul(X, self.W_xr) +\n", "            #                   torch.matmul(H, self.W_hr) + self.b_r)\n", "            # No reset gate: the candidate sees the full previous state (R = 1)\n", "            H_tilde = torch.tanh(torch.matmul(X, self.W_xh) +\n", "                                 torch.matmul(H, self.W_hh) + self.b_h)\n", "            H = Z * H + (1 - Z) * H_tilde\n", "            outputs.append(H)\n", "        return outputs, H" ] },
{ "cell_type": "code", "execution_count": null, "id": "81661187-9fcb-43c2-8fbc-4ce720d18989", "metadata": {}, "outputs": [], "source": [ "data = d2l.TimeMachine(batch_size=1024, num_steps=32)\n", "# Use UpdateGRUScratch here to train the update-gate-only variant\n", "gru = ResetGRUScratch(num_inputs=len(data.vocab), num_hiddens=32)\n", "model = d2l.RNNLMScratch(gru, vocab_size=len(data.vocab), lr=4)\n", "trainer = d2l.Trainer(max_epochs=50, gradient_clip_val=1, num_gpus=1)\n", "trainer.fit(model, data)" ] } ], "metadata": { "kernelspec": { "display_name": "Python [conda env:d2l]", 
"language": "python", "name": "conda-env-d2l-py" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.4" } }, "nbformat": 4, "nbformat_minor": 5 }