{
"cells": [
{
"cell_type": "markdown",
"id": "bfb9b0de",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"source": [
"# Long Short-Term Memory (LSTM)\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "af24541f",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:53:39.252438Z",
"iopub.status.busy": "2023-08-18T19:53:39.251775Z",
"iopub.status.idle": "2023-08-18T19:53:42.244563Z",
"shell.execute_reply": "2023-08-18T19:53:42.243546Z"
},
"origin_pos": 3,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"from d2l import torch as d2l"
]
},
{
"cell_type": "markdown",
"id": "9ed5d075",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Initializing Model Parameters"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "344044a5",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:53:42.249641Z",
"iopub.status.busy": "2023-08-18T19:53:42.248681Z",
"iopub.status.idle": "2023-08-18T19:53:42.259080Z",
"shell.execute_reply": "2023-08-18T19:53:42.257966Z"
},
"origin_pos": 7,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"class LSTMScratch(d2l.Module):\n",
"    \"\"\"LSTM layer whose parameters are created by hand.\n",
"\n",
"    For each of the four suffixes ``i``, ``f``, ``o`` and ``c`` the module\n",
"    owns an input weight ``W_x*`` of shape (num_inputs, num_hiddens), a\n",
"    recurrent weight ``W_h*`` of shape (num_hiddens, num_hiddens) and a\n",
"    zero-initialized bias ``b_*`` of shape (num_hiddens,).\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, num_inputs, num_hiddens, sigma=0.01):\n",
"        \"\"\"Create the twelve trainable parameters.\n",
"\n",
"        Args:\n",
"            num_inputs: Number of rows of every ``W_x*`` matrix.\n",
"            num_hiddens: Number of hidden units.\n",
"            sigma: Factor applied to the standard-normal weight samples.\n",
"        \"\"\"\n",
"        super().__init__()\n",
"        # Keep this call ahead of any further local definitions.\n",
"        # NOTE(review): presumably it stores the constructor arguments on\n",
"        # the instance (`forward` reads `self.num_hiddens`) - confirm in d2l.\n",
"        self.save_hyperparameters()\n",
"\n",
"        def scaled_normal(rows, cols):\n",
"            # Standard-normal samples scaled by sigma.\n",
"            return nn.Parameter(sigma * torch.randn(rows, cols))\n",
"\n",
"        def gate_parameters():\n",
"            # Creation order (input weight, recurrent weight, bias) is kept\n",
"            # so the random number stream is consumed in the same sequence.\n",
"            w_input = scaled_normal(num_inputs, num_hiddens)\n",
"            w_hidden = scaled_normal(num_hiddens, num_hiddens)\n",
"            bias = nn.Parameter(torch.zeros(num_hiddens))\n",
"            return w_input, w_hidden, bias\n",
"\n",
"        # Registers W_xi, W_hi, b_i, then the f, o and c triples, in this order.\n",
"        for suffix in ('i', 'f', 'o', 'c'):\n",
"            w_input, w_hidden, bias = gate_parameters()\n",
"            setattr(self, 'W_x' + suffix, w_input)\n",
"            setattr(self, 'W_h' + suffix, w_hidden)\n",
"            setattr(self, 'b_' + suffix, bias)"
]
},
{
"cell_type": "markdown",
"id": "82c645a9",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Defining the model: the forward computation"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "3284d4fa",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:53:42.263023Z",
"iopub.status.busy": "2023-08-18T19:53:42.262354Z",
"iopub.status.idle": "2023-08-18T19:53:42.269844Z",
"shell.execute_reply": "2023-08-18T19:53:42.269034Z"
},
"origin_pos": 11,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"@d2l.add_to_class(LSTMScratch)\n",
"def forward(self, inputs, H_C=None):\n",
"    \"\"\"Run the LSTM over a sequence, one time step at a time.\n",
"\n",
"    Args:\n",
"        inputs: Iterated along its first dimension; ``inputs.shape[1]`` is\n",
"            used as the number of rows of the initial state.\n",
"        H_C: Optional ``(H, C)`` pair holding the hidden state and the\n",
"            memory cell state. When None, both start as zeros of shape\n",
"            (inputs.shape[1], num_hiddens) on ``inputs.device``.\n",
"\n",
"    Returns:\n",
"        ``(outputs, (H, C))``: the list of hidden states, one per step,\n",
"        and the final state pair.\n",
"    \"\"\"\n",
"    if H_C is not None:\n",
"        H, C = H_C\n",
"    else:\n",
"        state_shape = (inputs.shape[1], self.num_hiddens)\n",
"        H = torch.zeros(state_shape, device=inputs.device)\n",
"        C = torch.zeros(state_shape, device=inputs.device)\n",
"\n",
"    def affine(X, H, W_x, W_h, b):\n",
"        # Pre-activation shared by the three gates and the candidate.\n",
"        return X @ W_x + H @ W_h + b\n",
"\n",
"    outputs = []\n",
"    for X in inputs:\n",
"        input_gate = torch.sigmoid(\n",
"            affine(X, H, self.W_xi, self.W_hi, self.b_i))\n",
"        forget_gate = torch.sigmoid(\n",
"            affine(X, H, self.W_xf, self.W_hf, self.b_f))\n",
"        output_gate = torch.sigmoid(\n",
"            affine(X, H, self.W_xo, self.W_ho, self.b_o))\n",
"        candidate = torch.tanh(\n",
"            affine(X, H, self.W_xc, self.W_hc, self.b_c))\n",
"        # All four values above use the previous H; only now update the state.\n",
"        C = forget_gate * C + input_gate * candidate\n",
"        H = output_gate * torch.tanh(C)\n",
"        outputs.append(H)\n",
"    return outputs, (H, C)"
]
},
{
"cell_type": "markdown",
"id": "87662674",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Training"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "3c605094",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:53:42.273652Z",
"iopub.status.busy": "2023-08-18T19:53:42.273097Z",
"iopub.status.idle": "2023-08-18T19:55:28.400186Z",
"shell.execute_reply": "2023-08-18T19:55:28.399180Z"
},
"origin_pos": 14,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
"