{ "cells": [ { "cell_type": "code", "execution_count": 36, "id": "1382ab54-6306-4cd0-98b0-203ddf1a108b", "metadata": { "tags": [] }, "outputs": [], "source": [ "DATA_HUB = dict()\n", "DATA_URL = 'http://d2l-data.s3-accelerate.amazonaws.com/'\n", "\n", "import numpy as np\n", "import torch\n", "import torchvision\n", "from PIL import Image\n", "from torch import nn\n", "from torch.nn import functional as F\n", "from torch.utils import data\n", "from torchvision import transforms\n", "\n", "nn_Module = nn.Module\n", "\n", "################# WARNING ################\n", "# The below part is generated automatically through:\n", "# d2lbook build lib\n", "# Don't edit it directly\n", "\n", "import collections\n", "import hashlib\n", "import math\n", "import os\n", "import random\n", "import re\n", "import shutil\n", "import sys\n", "import tarfile\n", "import time\n", "import zipfile\n", "from collections import defaultdict\n", "import pandas as pd\n", "import requests\n", "from IPython import display\n", "from matplotlib import pyplot as plt\n", "from matplotlib_inline import backend_inline\n", "\n", "d2l = sys.modules[__name__]\n", "\n", "import numpy as np\n", "import torch\n", "import torchvision\n", "from PIL import Image\n", "from torch import nn\n", "from torch.nn import functional as F\n", "from torch.utils import data\n", "from torchvision import transforms\n", "\n", "def use_svg_display():\n", " \"\"\"使用svg格式在Jupyter中显示绘图\n", "\n", " Defined in :numref:`sec_calculus`\"\"\"\n", " backend_inline.set_matplotlib_formats('svg')\n", "\n", "def set_figsize(figsize=(3.5, 2.5)):\n", " \"\"\"设置matplotlib的图表大小\n", "\n", " Defined in :numref:`sec_calculus`\"\"\"\n", " use_svg_display()\n", " d2l.plt.rcParams['figure.figsize'] = figsize\n", "\n", "def set_axes(axes, xlabel, ylabel, xlim, ylim, xscale, yscale, legend):\n", " \"\"\"设置matplotlib的轴\n", "\n", " Defined in :numref:`sec_calculus`\"\"\"\n", " axes.set_xlabel(xlabel)\n", " axes.set_ylabel(ylabel)\n", " 
axes.set_xscale(xscale)\n", " axes.set_yscale(yscale)\n", " axes.set_xlim(xlim)\n", " axes.set_ylim(ylim)\n", " if legend:\n", " axes.legend(legend)\n", " axes.grid()\n", "\n", "def plot(X, Y=None, xlabel=None, ylabel=None, legend=None, xlim=None,\n", " ylim=None, xscale='linear', yscale='linear',\n", " fmts=('-', 'm--', 'g-.', 'r:'), figsize=(3.5, 2.5), axes=None):\n", " \"\"\"绘制数据点\n", "\n", " Defined in :numref:`sec_calculus`\"\"\"\n", " if legend is None:\n", " legend = []\n", "\n", " set_figsize(figsize)\n", " axes = axes if axes else d2l.plt.gca()\n", "\n", " # 如果X有一个轴,输出True\n", " def has_one_axis(X):\n", " return (hasattr(X, \"ndim\") and X.ndim == 1 or isinstance(X, list)\n", " and not hasattr(X[0], \"__len__\"))\n", "\n", " if has_one_axis(X):\n", " X = [X]\n", " if Y is None:\n", " X, Y = [[]] * len(X), X\n", " elif has_one_axis(Y):\n", " Y = [Y]\n", " if len(X) != len(Y):\n", " X = X * len(Y)\n", " axes.cla()\n", " for x, y, fmt in zip(X, Y, fmts):\n", " if len(x):\n", " axes.plot(x, y, fmt)\n", " else:\n", " axes.plot(y, fmt)\n", " set_axes(axes, xlabel, ylabel, xlim, ylim, xscale, yscale, legend)\n", "\n", "class Timer:\n", " \"\"\"记录多次运行时间\"\"\"\n", " def __init__(self):\n", " \"\"\"Defined in :numref:`subsec_linear_model`\"\"\"\n", " self.times = []\n", " self.start()\n", "\n", " def start(self):\n", " \"\"\"启动计时器\"\"\"\n", " self.tik = time.time()\n", "\n", " def stop(self):\n", " \"\"\"停止计时器并将时间记录在列表中\"\"\"\n", " self.times.append(time.time() - self.tik)\n", " return self.times[-1]\n", "\n", " def avg(self):\n", " \"\"\"返回平均时间\"\"\"\n", " return sum(self.times) / len(self.times)\n", "\n", " def sum(self):\n", " \"\"\"返回时间总和\"\"\"\n", " return sum(self.times)\n", "\n", " def cumsum(self):\n", " \"\"\"返回累计时间\"\"\"\n", " return np.array(self.times).cumsum().tolist()\n", "\n", "def synthetic_data(w, b, num_examples):\n", " \"\"\"生成y=Xw+b+噪声\n", "\n", " Defined in :numref:`sec_linear_scratch`\"\"\"\n", " X = d2l.normal(0, 1, (num_examples, len(w)))\n", 
" y = d2l.matmul(X, w) + b\n", " y += d2l.normal(0, 0.01, y.shape)\n", " return X, d2l.reshape(y, (-1, 1))\n", "\n", "def linreg(X, w, b):\n", " \"\"\"线性回归模型\n", "\n", " Defined in :numref:`sec_linear_scratch`\"\"\"\n", " return d2l.matmul(X, w) + b\n", "\n", "def squared_loss(y_hat, y):\n", " \"\"\"均方损失\n", "\n", " Defined in :numref:`sec_linear_scratch`\"\"\"\n", " return (y_hat - d2l.reshape(y, y_hat.shape)) ** 2 / 2\n", "\n", "def sgd(params, lr, batch_size):\n", " \"\"\"小批量随机梯度下降\n", "\n", " Defined in :numref:`sec_linear_scratch`\"\"\"\n", " with torch.no_grad():\n", " for param in params:\n", " param -= lr * param.grad / batch_size\n", " param.grad.zero_()\n", "\n", "def load_array(data_arrays, batch_size, is_train=True):\n", " \"\"\"构造一个PyTorch数据迭代器\n", "\n", " Defined in :numref:`sec_linear_concise`\"\"\"\n", " dataset = data.TensorDataset(*data_arrays)\n", " return data.DataLoader(dataset, batch_size, shuffle=is_train)\n", "\n", "def get_fashion_mnist_labels(labels):\n", " \"\"\"返回Fashion-MNIST数据集的文本标签\n", "\n", " Defined in :numref:`sec_fashion_mnist`\"\"\"\n", " text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',\n", " 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']\n", " return [text_labels[int(i)] for i in labels]\n", "\n", "def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):\n", " \"\"\"绘制图像列表\n", "\n", " Defined in :numref:`sec_fashion_mnist`\"\"\"\n", " figsize = (num_cols * scale, num_rows * scale)\n", " _, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)\n", " axes = axes.flatten()\n", " for i, (ax, img) in enumerate(zip(axes, imgs)):\n", " if torch.is_tensor(img):\n", " # 图片张量\n", " ax.imshow(img.numpy())\n", " else:\n", " # PIL图片\n", " ax.imshow(img)\n", " ax.axes.get_xaxis().set_visible(False)\n", " ax.axes.get_yaxis().set_visible(False)\n", " if titles:\n", " ax.set_title(titles[i])\n", " return axes\n", "\n", "def get_dataloader_workers():\n", " \"\"\"使用4个进程来读取数据\n", "\n", " Defined in 
:numref:`sec_fashion_mnist`\"\"\"\n", " return 4\n", "\n", "def load_data_fashion_mnist(batch_size, resize=None):\n", " \"\"\"下载Fashion-MNIST数据集,然后将其加载到内存中\n", "\n", " Defined in :numref:`sec_fashion_mnist`\"\"\"\n", " trans = [transforms.ToTensor()]\n", " if resize:\n", " trans.insert(0, transforms.Resize(resize))\n", " trans = transforms.Compose(trans)\n", " mnist_train = torchvision.datasets.FashionMNIST(\n", " root=\"../data\", train=True, transform=trans, download=True)\n", " mnist_test = torchvision.datasets.FashionMNIST(\n", " root=\"../data\", train=False, transform=trans, download=True)\n", " return (data.DataLoader(mnist_train, batch_size, shuffle=True,\n", " num_workers=get_dataloader_workers()),\n", " data.DataLoader(mnist_test, batch_size, shuffle=False,\n", " num_workers=get_dataloader_workers()))\n", "\n", "def accuracy(y_hat, y):\n", " \"\"\"计算预测正确的数量\n", "\n", " Defined in :numref:`sec_softmax_scratch`\"\"\"\n", " if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:\n", " y_hat = d2l.argmax(y_hat, axis=1)\n", " cmp = d2l.astype(y_hat, y.dtype) == y\n", " return float(d2l.reduce_sum(d2l.astype(cmp, y.dtype)))\n", "\n", "def evaluate_accuracy(net, data_iter):\n", " \"\"\"计算在指定数据集上模型的精度\n", "\n", " Defined in :numref:`sec_softmax_scratch`\"\"\"\n", " if isinstance(net, torch.nn.Module):\n", " net.eval() # 将模型设置为评估模式\n", " metric = Accumulator(2) # 正确预测数、预测总数\n", " with torch.no_grad():\n", " for X, y in data_iter:\n", " metric.add(accuracy(net(X), y), d2l.size(y))\n", " return metric[0] / metric[1]\n", "\n", "class Accumulator:\n", " \"\"\"在n个变量上累加\"\"\"\n", " def __init__(self, n):\n", " \"\"\"Defined in :numref:`sec_softmax_scratch`\"\"\"\n", " self.data = [0.0] * n\n", "\n", " def add(self, *args):\n", " self.data = [a + float(b) for a, b in zip(self.data, args)]\n", "\n", " def reset(self):\n", " self.data = [0.0] * len(self.data)\n", "\n", " def __getitem__(self, idx):\n", " return self.data[idx]\n", "\n", "def train_epoch_ch3(net, train_iter, loss, 
updater):\n", " \"\"\"训练模型一个迭代周期(定义见第3章)\n", "\n", " Defined in :numref:`sec_softmax_scratch`\"\"\"\n", " # 将模型设置为训练模式\n", " if isinstance(net, torch.nn.Module):\n", " net.train()\n", " # 训练损失总和、训练准确度总和、样本数\n", " metric = Accumulator(3)\n", " for X, y in train_iter:\n", " # 计算梯度并更新参数\n", " y_hat = net(X)\n", " l = loss(y_hat, y)\n", " if isinstance(updater, torch.optim.Optimizer):\n", " # 使用PyTorch内置的优化器和损失函数\n", " updater.zero_grad()\n", " l.mean().backward()\n", " updater.step()\n", " else:\n", " # 使用定制的优化器和损失函数\n", " l.sum().backward()\n", " updater(X.shape[0])\n", " metric.add(float(l.sum()), accuracy(y_hat, y), y.numel())\n", " # 返回训练损失和训练精度\n", " return metric[0] / metric[2], metric[1] / metric[2]\n", "\n", "class Animator:\n", " \"\"\"在动画中绘制数据\"\"\"\n", " def __init__(self, xlabel=None, ylabel=None, legend=None, xlim=None,\n", " ylim=None, xscale='linear', yscale='linear',\n", " fmts=('-', 'm--', 'g-.', 'r:'), nrows=1, ncols=1,\n", " figsize=(3.5, 2.5)):\n", " \"\"\"Defined in :numref:`sec_softmax_scratch`\"\"\"\n", " # 增量地绘制多条线\n", " if legend is None:\n", " legend = []\n", " d2l.use_svg_display()\n", " self.fig, self.axes = d2l.plt.subplots(nrows, ncols, figsize=figsize)\n", " if nrows * ncols == 1:\n", " self.axes = [self.axes, ]\n", " # 使用lambda函数捕获参数\n", " self.config_axes = lambda: d2l.set_axes(\n", " self.axes[0], xlabel, ylabel, xlim, ylim, xscale, yscale, legend)\n", " self.X, self.Y, self.fmts = None, None, fmts\n", "\n", " def add(self, x, y):\n", " # 向图表中添加多个数据点\n", " if not hasattr(y, \"__len__\"):\n", " y = [y]\n", " n = len(y)\n", " if not hasattr(x, \"__len__\"):\n", " x = [x] * n\n", " if not self.X:\n", " self.X = [[] for _ in range(n)]\n", " if not self.Y:\n", " self.Y = [[] for _ in range(n)]\n", " for i, (a, b) in enumerate(zip(x, y)):\n", " if a is not None and b is not None:\n", " self.X[i].append(a)\n", " self.Y[i].append(b)\n", " self.axes[0].cla()\n", " for x, y, fmt in zip(self.X, self.Y, self.fmts):\n", " self.axes[0].plot(x, y, 
fmt)\n", " self.config_axes()\n", " display.display(self.fig)\n", " display.clear_output(wait=True)\n", "\n", "def train_ch3(net, train_iter, test_iter, loss, num_epochs, updater):\n", " \"\"\"训练模型(定义见第3章)\n", "\n", " Defined in :numref:`sec_softmax_scratch`\"\"\"\n", " animator = Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0.3, 0.9],\n", " legend=['train loss', 'train acc', 'test acc'])\n", " for epoch in range(num_epochs):\n", " train_metrics = train_epoch_ch3(net, train_iter, loss, updater)\n", " test_acc = evaluate_accuracy(net, test_iter)\n", " animator.add(epoch + 1, train_metrics + (test_acc,))\n", " train_loss, train_acc = train_metrics\n", " assert train_loss < 0.5, train_loss\n", " assert train_acc <= 1 and train_acc > 0.7, train_acc\n", " assert test_acc <= 1 and test_acc > 0.7, test_acc\n", "\n", "def predict_ch3(net, test_iter, n=6):\n", " \"\"\"预测标签(定义见第3章)\n", "\n", " Defined in :numref:`sec_softmax_scratch`\"\"\"\n", " for X, y in test_iter:\n", " break\n", " trues = d2l.get_fashion_mnist_labels(y)\n", " preds = d2l.get_fashion_mnist_labels(d2l.argmax(net(X), axis=1))\n", " titles = [true +'\\n' + pred for true, pred in zip(trues, preds)]\n", " d2l.show_images(\n", " d2l.reshape(X[0:n], (n, 28, 28)), 1, n, titles=titles[0:n])\n", "\n", "def evaluate_loss(net, data_iter, loss):\n", " \"\"\"评估给定数据集上模型的损失\n", "\n", " Defined in :numref:`sec_model_selection`\"\"\"\n", " metric = d2l.Accumulator(2) # 损失的总和,样本数量\n", " for X, y in data_iter:\n", " out = net(X)\n", " y = d2l.reshape(y, out.shape)\n", " l = loss(out, y)\n", " metric.add(d2l.reduce_sum(l), d2l.size(l))\n", " return metric[0] / metric[1]\n", "\n", "DATA_HUB = dict()\n", "DATA_URL = 'http://d2l-data.s3-accelerate.amazonaws.com/'\n", "\n", "def download(name, cache_dir=os.path.join('..', 'data')):\n", " \"\"\"下载一个DATA_HUB中的文件,返回本地文件名\n", "\n", " Defined in :numref:`sec_kaggle_house`\"\"\"\n", " assert name in DATA_HUB, f\"{name} 不存在于 {DATA_HUB}\"\n", " url, sha1_hash = 
DATA_HUB[name]\n", " os.makedirs(cache_dir, exist_ok=True)\n", " fname = os.path.join(cache_dir, url.split('/')[-1])\n", " if os.path.exists(fname):\n", " sha1 = hashlib.sha1()\n", " with open(fname, 'rb') as f:\n", " while True:\n", " data = f.read(1048576)\n", " if not data:\n", " break\n", " sha1.update(data)\n", " if sha1.hexdigest() == sha1_hash:\n", " return fname # 命中缓存\n", " print(f'正在从{url}下载{fname}...')\n", " r = requests.get(url, stream=True, verify=True)\n", " with open(fname, 'wb') as f:\n", " f.write(r.content)\n", " return fname\n", "\n", "def download_extract(name, folder=None):\n", " \"\"\"下载并解压zip/tar文件\n", "\n", " Defined in :numref:`sec_kaggle_house`\"\"\"\n", " fname = download(name)\n", " base_dir = os.path.dirname(fname)\n", " data_dir, ext = os.path.splitext(fname)\n", " if ext == '.zip':\n", " fp = zipfile.ZipFile(fname, 'r')\n", " elif ext in ('.tar', '.gz'):\n", " fp = tarfile.open(fname, 'r')\n", " else:\n", " assert False, '只有zip/tar文件可以被解压缩'\n", " fp.extractall(base_dir)\n", " return os.path.join(base_dir, folder) if folder else data_dir\n", "\n", "def download_all():\n", " \"\"\"下载DATA_HUB中的所有文件\n", "\n", " Defined in :numref:`sec_kaggle_house`\"\"\"\n", " for name in DATA_HUB:\n", " download(name)\n", "\n", "DATA_HUB['kaggle_house_train'] = (\n", " DATA_URL + 'kaggle_house_pred_train.csv',\n", " '585e9cc93e70b39160e7921475f9bcd7d31219ce')\n", "\n", "DATA_HUB['kaggle_house_test'] = (\n", " DATA_URL + 'kaggle_house_pred_test.csv',\n", " 'fa19780a7b011d9b009e8bff8e99922a8ee2eb90')\n", "\n", "def try_gpu(i=0):\n", " \"\"\"如果存在,则返回gpu(i),否则返回cpu()\n", "\n", " Defined in :numref:`sec_use_gpu`\"\"\"\n", " if torch.cuda.device_count() >= i + 1:\n", " return torch.device(f'cuda:{i}')\n", " return torch.device('cpu')\n", "\n", "def try_all_gpus():\n", " \"\"\"返回所有可用的GPU,如果没有GPU,则返回[cpu(),]\n", "\n", " Defined in :numref:`sec_use_gpu`\"\"\"\n", " devices = [torch.device(f'cuda:{i}')\n", " for i in range(torch.cuda.device_count())]\n", " return 
devices if devices else [torch.device('cpu')]\n", "\n", "def corr2d(X, K):\n", " \"\"\"计算二维互相关运算\n", "\n", " Defined in :numref:`sec_conv_layer`\"\"\"\n", " h, w = K.shape\n", " Y = d2l.zeros((X.shape[0] - h + 1, X.shape[1] - w + 1))\n", " for i in range(Y.shape[0]):\n", " for j in range(Y.shape[1]):\n", " Y[i, j] = d2l.reduce_sum((X[i: i + h, j: j + w] * K))\n", " return Y\n", "\n", "def evaluate_accuracy_gpu(net, data_iter, device=None):\n", " \"\"\"使用GPU计算模型在数据集上的精度\n", "\n", " Defined in :numref:`sec_lenet`\"\"\"\n", " if isinstance(net, nn.Module):\n", " net.eval() # 设置为评估模式\n", " if not device:\n", " device = next(iter(net.parameters())).device\n", " # 正确预测的数量,总预测的数量\n", " metric = d2l.Accumulator(2)\n", " with torch.no_grad():\n", " for X, y in data_iter:\n", " if isinstance(X, list):\n", " # BERT微调所需的(之后将介绍)\n", " X = [x.to(device) for x in X]\n", " else:\n", " X = X.to(device)\n", " y = y.to(device)\n", " metric.add(d2l.accuracy(net(X), y), d2l.size(y))\n", " return metric[0] / metric[1]\n", "\n", "def train_ch6(net, train_iter, test_iter, num_epochs, lr, device):\n", " \"\"\"用GPU训练模型(在第六章定义)\n", "\n", " Defined in :numref:`sec_lenet`\"\"\"\n", " def init_weights(m):\n", " if type(m) == nn.Linear or type(m) == nn.Conv2d:\n", " nn.init.xavier_uniform_(m.weight)\n", " net.apply(init_weights)\n", " print('training on', device)\n", " net.to(device)\n", " optimizer = torch.optim.SGD(net.parameters(), lr=lr)\n", " loss = nn.CrossEntropyLoss()\n", " animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],\n", " legend=['train loss', 'train acc', 'test acc'])\n", " timer, num_batches = d2l.Timer(), len(train_iter)\n", " for epoch in range(num_epochs):\n", " # 训练损失之和,训练准确率之和,样本数\n", " metric = d2l.Accumulator(3)\n", " net.train()\n", " for i, (X, y) in enumerate(train_iter):\n", " timer.start()\n", " optimizer.zero_grad()\n", " X, y = X.to(device), y.to(device)\n", " y_hat = net(X)\n", " l = loss(y_hat, y)\n", " l.backward()\n", " optimizer.step()\n", " with 
torch.no_grad():\n", " metric.add(l * X.shape[0], d2l.accuracy(y_hat, y), X.shape[0])\n", " timer.stop()\n", " train_l = metric[0] / metric[2]\n", " train_acc = metric[1] / metric[2]\n", " if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:\n", " animator.add(epoch + (i + 1) / num_batches,\n", " (train_l, train_acc, None))\n", " test_acc = evaluate_accuracy_gpu(net, test_iter)\n", " animator.add(epoch + 1, (None, None, test_acc))\n", " print(f'loss {train_l:.3f}, train acc {train_acc:.3f}, '\n", " f'test acc {test_acc:.3f}')\n", " print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec '\n", " f'on {str(device)}')\n", "\n", "class Residual(nn.Module):\n", " def __init__(self, input_channels, num_channels,\n", " use_1x1conv=False, strides=1):\n", " super().__init__()\n", " self.conv1 = nn.Conv2d(input_channels, num_channels,\n", " kernel_size=3, padding=1, stride=strides)\n", " self.conv2 = nn.Conv2d(num_channels, num_channels,\n", " kernel_size=3, padding=1)\n", " if use_1x1conv:\n", " self.conv3 = nn.Conv2d(input_channels, num_channels,\n", " kernel_size=1, stride=strides)\n", " else:\n", " self.conv3 = None\n", " self.bn1 = nn.BatchNorm2d(num_channels)\n", " self.bn2 = nn.BatchNorm2d(num_channels)\n", "\n", " def forward(self, X):\n", " Y = F.relu(self.bn1(self.conv1(X)))\n", " Y = self.bn2(self.conv2(Y))\n", " if self.conv3:\n", " X = self.conv3(X)\n", " Y += X\n", " return F.relu(Y)\n", "\n", "d2l.DATA_HUB['time_machine'] = (d2l.DATA_URL + 'timemachine.txt',\n", " '090b5e7e70c295757f55df93cb0a180b9691891a')\n", "\n", "def read_time_machine():\n", " \"\"\"将时间机器数据集加载到文本行的列表中\n", "\n", " Defined in :numref:`sec_text_preprocessing`\"\"\"\n", " with open(d2l.download('time_machine'), 'r') as f:\n", " lines = f.readlines()\n", " return [re.sub('[^A-Za-z]+', ' ', line).strip().lower() for line in lines]\n", "\n", "def tokenize(lines, token='word'):\n", " \"\"\"将文本行拆分为单词或字符词元\n", "\n", " Defined in :numref:`sec_text_preprocessing`\"\"\"\n", " if token 
== 'word':\n", " return [line.split() for line in lines]\n", " elif token == 'char':\n", " return [list(line) for line in lines]\n", " else:\n", " print('错误:未知词元类型:' + token)\n", "\n", "class Vocab:\n", " \"\"\"文本词表\"\"\"\n", " def __init__(self, tokens=None, min_freq=0, reserved_tokens=None):\n", " \"\"\"Defined in :numref:`sec_text_preprocessing`\"\"\"\n", " if tokens is None:\n", " tokens = []\n", " if reserved_tokens is None:\n", " reserved_tokens = []\n", " # 按出现频率排序\n", " counter = count_corpus(tokens)\n", " self._token_freqs = sorted(counter.items(), key=lambda x: x[1],\n", " reverse=True)\n", " # 未知词元的索引为0\n", " self.idx_to_token = [''] + reserved_tokens\n", " self.token_to_idx = {token: idx\n", " for idx, token in enumerate(self.idx_to_token)}\n", " for token, freq in self._token_freqs:\n", " if freq < min_freq:\n", " break\n", " if token not in self.token_to_idx:\n", " self.idx_to_token.append(token)\n", " self.token_to_idx[token] = len(self.idx_to_token) - 1\n", "\n", " def __len__(self):\n", " return len(self.idx_to_token)\n", "\n", " def __getitem__(self, tokens):\n", " if not isinstance(tokens, (list, tuple)):\n", " return self.token_to_idx.get(tokens, self.unk)\n", " return [self.__getitem__(token) for token in tokens]\n", "\n", " def to_tokens(self, indices):\n", " if not isinstance(indices, (list, tuple)):\n", " return self.idx_to_token[indices]\n", " return [self.idx_to_token[index] for index in indices]\n", "\n", " @property\n", " def unk(self): # 未知词元的索引为0\n", " return 0\n", "\n", " @property\n", " def token_freqs(self):\n", " return self._token_freqs\n", "\n", "def count_corpus(tokens):\n", " \"\"\"统计词元的频率\n", "\n", " Defined in :numref:`sec_text_preprocessing`\"\"\"\n", " # 这里的tokens是1D列表或2D列表\n", " if len(tokens) == 0 or isinstance(tokens[0], list):\n", " # 将词元列表展平成一个列表\n", " tokens = [token for line in tokens for token in line]\n", " return collections.Counter(tokens)\n", "\n", "def load_corpus_time_machine(max_tokens=-1):\n", " 
\"\"\"返回时光机器数据集的词元索引列表和词表\n", "\n", " Defined in :numref:`sec_text_preprocessing`\"\"\"\n", " lines = read_time_machine()\n", " tokens = tokenize(lines, 'char')\n", " vocab = Vocab(tokens)\n", " # 因为时光机器数据集中的每个文本行不一定是一个句子或一个段落,\n", " # 所以将所有文本行展平到一个列表中\n", " corpus = [vocab[token] for line in tokens for token in line]\n", " if max_tokens > 0:\n", " corpus = corpus[:max_tokens]\n", " return corpus, vocab\n", "\n", "def seq_data_iter_random(corpus, batch_size, num_steps):\n", " \"\"\"使用随机抽样生成一个小批量子序列\n", "\n", " Defined in :numref:`sec_language_model`\"\"\"\n", " # 从随机偏移量开始对序列进行分区,随机范围包括num_steps-1\n", " corpus = corpus[random.randint(0, num_steps - 1):]\n", " # 减去1,是因为我们需要考虑标签\n", " num_subseqs = (len(corpus) - 1) // num_steps\n", " # 长度为num_steps的子序列的起始索引\n", " initial_indices = list(range(0, num_subseqs * num_steps, num_steps))\n", " # 在随机抽样的迭代过程中,\n", " # 来自两个相邻的、随机的、小批量中的子序列不一定在原始序列上相邻\n", " random.shuffle(initial_indices)\n", "\n", " def data(pos):\n", " # 返回从pos位置开始的长度为num_steps的序列\n", " return corpus[pos: pos + num_steps]\n", "\n", " num_batches = num_subseqs // batch_size\n", " for i in range(0, batch_size * num_batches, batch_size):\n", " # 在这里,initial_indices包含子序列的随机起始索引\n", " initial_indices_per_batch = initial_indices[i: i + batch_size]\n", " X = [data(j) for j in initial_indices_per_batch]\n", " Y = [data(j + 1) for j in initial_indices_per_batch]\n", " yield d2l.tensor(X), d2l.tensor(Y)\n", "\n", "def seq_data_iter_sequential(corpus, batch_size, num_steps):\n", " \"\"\"使用顺序分区生成一个小批量子序列\n", "\n", " Defined in :numref:`sec_language_model`\"\"\"\n", " # 从随机偏移量开始划分序列\n", " offset = random.randint(0, num_steps)\n", " num_tokens = ((len(corpus) - offset - 1) // batch_size) * batch_size\n", " Xs = d2l.tensor(corpus[offset: offset + num_tokens])\n", " Ys = d2l.tensor(corpus[offset + 1: offset + 1 + num_tokens])\n", " Xs, Ys = Xs.reshape(batch_size, -1), Ys.reshape(batch_size, -1)\n", " num_batches = Xs.shape[1] // num_steps\n", " for i in range(0, num_steps * 
num_batches, num_steps):\n", " X = Xs[:, i: i + num_steps]\n", " Y = Ys[:, i: i + num_steps]\n", " yield X, Y\n", "\n", "class SeqDataLoader:\n", " \"\"\"加载序列数据的迭代器\"\"\"\n", " def __init__(self, batch_size, num_steps, use_random_iter, max_tokens):\n", " \"\"\"Defined in :numref:`sec_language_model`\"\"\"\n", " if use_random_iter:\n", " self.data_iter_fn = d2l.seq_data_iter_random\n", " else:\n", " self.data_iter_fn = d2l.seq_data_iter_sequential\n", " self.corpus, self.vocab = d2l.load_corpus_time_machine(max_tokens)\n", " self.batch_size, self.num_steps = batch_size, num_steps\n", "\n", " def __iter__(self):\n", " return self.data_iter_fn(self.corpus, self.batch_size, self.num_steps)\n", "\n", "def load_data_time_machine(batch_size, num_steps,\n", " use_random_iter=False, max_tokens=10000):\n", " \"\"\"返回时光机器数据集的迭代器和词表\n", "\n", " Defined in :numref:`sec_language_model`\"\"\"\n", " data_iter = SeqDataLoader(\n", " batch_size, num_steps, use_random_iter, max_tokens)\n", " return data_iter, data_iter.vocab\n", "\n", "class RNNModelScratch:\n", " \"\"\"从零开始实现的循环神经网络模型\"\"\"\n", " def __init__(self, vocab_size, num_hiddens, device,\n", " get_params, init_state, forward_fn):\n", " \"\"\"Defined in :numref:`sec_rnn_scratch`\"\"\"\n", " self.vocab_size, self.num_hiddens = vocab_size, num_hiddens\n", " self.params = get_params(vocab_size, num_hiddens, device)\n", " self.init_state, self.forward_fn = init_state, forward_fn\n", "\n", " def __call__(self, X, state):\n", " X = F.one_hot(X.T, self.vocab_size).type(torch.float32)\n", " return self.forward_fn(X, state, self.params)\n", "\n", " def begin_state(self, batch_size, device):\n", " return self.init_state(batch_size, self.num_hiddens, device)\n", "\n", "def predict_ch8(prefix, num_preds, net, vocab, device):\n", " \"\"\"在prefix后面生成新字符\n", "\n", " Defined in :numref:`sec_rnn_scratch`\"\"\"\n", " state = net.begin_state(batch_size=1, device=device)\n", " outputs = [vocab[prefix[0]]]\n", " get_input = lambda: 
d2l.reshape(d2l.tensor(\n", " [outputs[-1]], device=device), (1, 1))\n", " for y in prefix[1:]: # 预热期\n", " _, state = net(get_input(), state)\n", " outputs.append(vocab[y])\n", " for _ in range(num_preds): # 预测num_preds步\n", " y, state = net(get_input(), state)\n", " outputs.append(int(y.argmax(dim=1).reshape(1)))\n", " return ''.join([vocab.idx_to_token[i] for i in outputs])\n", "\n", "def grad_clipping(net, theta):\n", " \"\"\"裁剪梯度\n", "\n", " Defined in :numref:`sec_rnn_scratch`\"\"\"\n", " if isinstance(net, nn.Module):\n", " params = [p for p in net.parameters() if p.requires_grad]\n", " else:\n", " params = net.params\n", " norm = torch.sqrt(sum(torch.sum((p.grad ** 2)) for p in params))\n", " if norm > theta:\n", " for param in params:\n", " param.grad[:] *= theta / norm\n", "\n", "def train_epoch_ch8(net, train_iter, loss, updater, device, use_random_iter):\n", " \"\"\"训练网络一个迭代周期(定义见第8章)\n", "\n", " Defined in :numref:`sec_rnn_scratch`\"\"\"\n", " state, timer = None, d2l.Timer()\n", " metric = d2l.Accumulator(2) # 训练损失之和,词元数量\n", " for X, Y in train_iter:\n", " if state is None or use_random_iter:\n", " # 在第一次迭代或使用随机抽样时初始化state\n", " state = net.begin_state(batch_size=X.shape[0], device=device)\n", " else:\n", " if isinstance(net, nn.Module) and not isinstance(state, tuple):\n", " # state对于nn.GRU是个张量\n", " state.detach_()\n", " else:\n", " # state对于nn.LSTM或对于我们从零开始实现的模型是个张量\n", " for s in state:\n", " s.detach_()\n", " y = Y.T.reshape(-1)\n", " X, y = X.to(device), y.to(device)\n", " y_hat, state = net(X, state)\n", " l = loss(y_hat, y.long()).mean()\n", " if isinstance(updater, torch.optim.Optimizer):\n", " updater.zero_grad()\n", " l.backward()\n", " grad_clipping(net, 1)\n", " updater.step()\n", " else:\n", " l.backward()\n", " grad_clipping(net, 1)\n", " # 因为已经调用了mean函数\n", " updater(batch_size=1)\n", " metric.add(l * d2l.size(y), d2l.size(y))\n", " return math.exp(metric[0] / metric[1]), metric[1] / timer.stop()\n", "\n", "def train_ch8(net, 
train_iter, vocab, lr, num_epochs, device,\n", " use_random_iter=False):\n", " \"\"\"训练模型(定义见第8章)\n", "\n", " Defined in :numref:`sec_rnn_scratch`\"\"\"\n", " loss = nn.CrossEntropyLoss()\n", " animator = d2l.Animator(xlabel='epoch', ylabel='perplexity',\n", " legend=['train'], xlim=[10, num_epochs])\n", " # 初始化\n", " if isinstance(net, nn.Module):\n", " updater = torch.optim.SGD(net.parameters(), lr)\n", " else:\n", " updater = lambda batch_size: d2l.sgd(net.params, lr, batch_size)\n", " predict = lambda prefix: predict_ch8(prefix, 50, net, vocab, device)\n", " # 训练和预测\n", " for epoch in range(num_epochs):\n", " ppl, speed = train_epoch_ch8(\n", " net, train_iter, loss, updater, device, use_random_iter)\n", " if (epoch + 1) % 10 == 0:\n", " print(predict('time traveller'))\n", " animator.add(epoch + 1, [ppl])\n", " print(f'困惑度 {ppl:.1f}, {speed:.1f} 词元/秒 {str(device)}')\n", " print(predict('time traveller'))\n", " print(predict('traveller'))\n", "\n", "class RNNModel(nn.Module):\n", " \"\"\"循环神经网络模型\n", "\n", " Defined in :numref:`sec_rnn-concise`\"\"\"\n", " def __init__(self, rnn_layer, vocab_size, **kwargs):\n", " super(RNNModel, self).__init__(**kwargs)\n", " self.rnn = rnn_layer\n", " self.vocab_size = vocab_size\n", " self.num_hiddens = self.rnn.hidden_size\n", " # 如果RNN是双向的(之后将介绍),num_directions应该是2,否则应该是1\n", " if not self.rnn.bidirectional:\n", " self.num_directions = 1\n", " self.linear = nn.Linear(self.num_hiddens, self.vocab_size)\n", " else:\n", " self.num_directions = 2\n", " self.linear = nn.Linear(self.num_hiddens * 2, self.vocab_size)\n", "\n", " def forward(self, inputs, state):\n", " X = F.one_hot(inputs.T.long(), self.vocab_size)\n", " X = X.to(torch.float32)\n", " Y, state = self.rnn(X, state)\n", " # 全连接层首先将Y的形状改为(时间步数*批量大小,隐藏单元数)\n", " # 它的输出形状是(时间步数*批量大小,词表大小)。\n", " output = self.linear(Y.reshape((-1, Y.shape[-1])))\n", " return output, state\n", "\n", " def begin_state(self, device, batch_size=1):\n", " if not isinstance(self.rnn, 
nn.LSTM):\n", " # nn.GRU以张量作为隐状态\n", " return torch.zeros((self.num_directions * self.rnn.num_layers,\n", " batch_size, self.num_hiddens),\n", " device=device)\n", " else:\n", " # nn.LSTM以元组作为隐状态\n", " return (torch.zeros((\n", " self.num_directions * self.rnn.num_layers,\n", " batch_size, self.num_hiddens), device=device),\n", " torch.zeros((\n", " self.num_directions * self.rnn.num_layers,\n", " batch_size, self.num_hiddens), device=device))\n", "\n", "d2l.DATA_HUB['fra-eng'] = (d2l.DATA_URL + 'fra-eng.zip',\n", " '94646ad1522d915e7b0f9296181140edcf86a4f5')\n", "\n", "def read_data_nmt():\n", " \"\"\"载入“英语-法语”数据集\n", "\n", " Defined in :numref:`sec_machine_translation`\"\"\"\n", " data_dir = d2l.download_extract('fra-eng')\n", " with open(os.path.join(data_dir, 'fra.txt'), 'r',\n", " encoding='utf-8') as f:\n", " return f.read()\n", "\n", "def preprocess_nmt(text):\n", " \"\"\"预处理“英语-法语”数据集\n", "\n", " Defined in :numref:`sec_machine_translation`\"\"\"\n", " def no_space(char, prev_char):\n", " return char in set(',.!?') and prev_char != ' '\n", "\n", " # 使用空格替换不间断空格\n", " # 使用小写字母替换大写字母\n", " text = text.replace('\\u202f', ' ').replace('\\xa0', ' ').lower()\n", " # 在单词和标点符号之间插入空格\n", " out = [' ' + char if i > 0 and no_space(char, text[i - 1]) else char\n", " for i, char in enumerate(text)]\n", " return ''.join(out)\n", "\n", "def tokenize_nmt(text, num_examples=None):\n", " \"\"\"词元化“英语-法语”数据数据集\n", "\n", " Defined in :numref:`sec_machine_translation`\"\"\"\n", " source, target = [], []\n", " for i, line in enumerate(text.split('\\n')):\n", " if num_examples and i > num_examples:\n", " break\n", " parts = line.split('\\t')\n", " if len(parts) == 2:\n", " source.append(parts[0].split(' '))\n", " target.append(parts[1].split(' '))\n", " return source, target\n", "\n", "def show_list_len_pair_hist(legend, xlabel, ylabel, xlist, ylist):\n", " \"\"\"绘制列表长度对的直方图\n", "\n", " Defined in :numref:`sec_machine_translation`\"\"\"\n", " d2l.set_figsize()\n", " _, _, patches 
= d2l.plt.hist(\n", " [[len(l) for l in xlist], [len(l) for l in ylist]])\n", " d2l.plt.xlabel(xlabel)\n", " d2l.plt.ylabel(ylabel)\n", " for patch in patches[1].patches:\n", " patch.set_hatch('/')\n", " d2l.plt.legend(legend)\n", "\n", "def truncate_pad(line, num_steps, padding_token):\n", " \"\"\"截断或填充文本序列\n", "\n", " Defined in :numref:`sec_machine_translation`\"\"\"\n", " if len(line) > num_steps:\n", " return line[:num_steps] # 截断\n", " return line + [padding_token] * (num_steps - len(line)) # 填充\n", "\n", "def build_array_nmt(lines, vocab, num_steps):\n", " \"\"\"将机器翻译的文本序列转换成小批量\n", "\n", " Defined in :numref:`subsec_mt_data_loading`\"\"\"\n", " lines = [vocab[l] for l in lines]\n", " lines = [l + [vocab['']] for l in lines]\n", " array = d2l.tensor([truncate_pad(\n", " l, num_steps, vocab['']) for l in lines])\n", " valid_len = d2l.reduce_sum(\n", " d2l.astype(array != vocab[''], d2l.int32), 1)\n", " return array, valid_len\n", "\n", "def load_data_nmt(batch_size, num_steps, num_examples=600):\n", " \"\"\"返回翻译数据集的迭代器和词表\n", "\n", " Defined in :numref:`subsec_mt_data_loading`\"\"\"\n", " text = preprocess_nmt(read_data_nmt())\n", " source, target = tokenize_nmt(text, num_examples)\n", " src_vocab = d2l.Vocab(source, min_freq=2,\n", " reserved_tokens=['', '', ''])\n", " tgt_vocab = d2l.Vocab(target, min_freq=2,\n", " reserved_tokens=['', '', ''])\n", " src_array, src_valid_len = build_array_nmt(source, src_vocab, num_steps)\n", " tgt_array, tgt_valid_len = build_array_nmt(target, tgt_vocab, num_steps)\n", " data_arrays = (src_array, src_valid_len, tgt_array, tgt_valid_len)\n", " data_iter = d2l.load_array(data_arrays, batch_size)\n", " return data_iter, src_vocab, tgt_vocab\n", "\n", "class Encoder(nn.Module):\n", " \"\"\"编码器-解码器架构的基本编码器接口\"\"\"\n", " def __init__(self, **kwargs):\n", " super(Encoder, self).__init__(**kwargs)\n", "\n", " def forward(self, X, *args):\n", " raise NotImplementedError\n", "\n", "class Decoder(nn.Module):\n", " 
\"\"\"编码器-解码器架构的基本解码器接口\n", "\n", " Defined in :numref:`sec_encoder-decoder`\"\"\"\n", " def __init__(self, **kwargs):\n", " super(Decoder, self).__init__(**kwargs)\n", "\n", " def init_state(self, enc_outputs, *args):\n", " raise NotImplementedError\n", "\n", " def forward(self, X, state):\n", " raise NotImplementedError\n", "\n", "class EncoderDecoder(nn.Module):\n", " \"\"\"编码器-解码器架构的基类\n", "\n", " Defined in :numref:`sec_encoder-decoder`\"\"\"\n", " def __init__(self, encoder, decoder, **kwargs):\n", " super(EncoderDecoder, self).__init__(**kwargs)\n", " self.encoder = encoder\n", " self.decoder = decoder\n", "\n", " def forward(self, enc_X, dec_X, *args):\n", " enc_outputs = self.encoder(enc_X, *args)\n", " dec_state = self.decoder.init_state(enc_outputs, *args)\n", " return self.decoder(dec_X, dec_state)\n", "\n", "class Seq2SeqEncoder(d2l.Encoder):\n", " \"\"\"用于序列到序列学习的循环神经网络编码器\n", "\n", " Defined in :numref:`sec_seq2seq`\"\"\"\n", " def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,\n", " dropout=0, **kwargs):\n", " super(Seq2SeqEncoder, self).__init__(**kwargs)\n", " # 嵌入层\n", " self.embedding = nn.Embedding(vocab_size, embed_size)\n", " self.rnn = nn.GRU(embed_size, num_hiddens, num_layers,\n", " dropout=dropout)\n", "\n", " def forward(self, X, *args):\n", " # 输出'X'的形状:(batch_size,num_steps,embed_size)\n", " X = self.embedding(X)\n", " # 在循环神经网络模型中,第一个轴对应于时间步\n", " X = X.permute(1, 0, 2)\n", " # 如果未提及状态,则默认为0\n", " output, state = self.rnn(X)\n", " # output的形状:(num_steps,batch_size,num_hiddens)\n", " # state的形状:(num_layers,batch_size,num_hiddens)\n", " return output, state\n", "\n", "def sequence_mask(X, valid_len, value=0):\n", " \"\"\"在序列中屏蔽不相关的项\n", "\n", " Defined in :numref:`sec_seq2seq_decoder`\"\"\"\n", " maxlen = X.size(1)\n", " mask = torch.arange((maxlen), dtype=torch.float32,\n", " device=X.device)[None, :] < valid_len[:, None]\n", " X[~mask] = value\n", " return X\n", "\n", "class 
MaskedSoftmaxCELoss(nn.CrossEntropyLoss):\n", " \"\"\"带遮蔽的softmax交叉熵损失函数\n", "\n", " Defined in :numref:`sec_seq2seq_decoder`\"\"\"\n", " # pred的形状:(batch_size,num_steps,vocab_size)\n", " # label的形状:(batch_size,num_steps)\n", " # valid_len的形状:(batch_size,)\n", " def forward(self, pred, label, valid_len):\n", " weights = torch.ones_like(label)\n", " weights = sequence_mask(weights, valid_len)\n", " self.reduction='none'\n", " unweighted_loss = super(MaskedSoftmaxCELoss, self).forward(\n", " pred.permute(0, 2, 1), label)\n", " weighted_loss = (unweighted_loss * weights).mean(dim=1)\n", " return weighted_loss\n", "\n", "def train_seq2seq(net, data_iter, lr, num_epochs, tgt_vocab, device):\n", " \"\"\"训练序列到序列模型\n", "\n", " Defined in :numref:`sec_seq2seq_decoder`\"\"\"\n", " def xavier_init_weights(m):\n", " if type(m) == nn.Linear:\n", " nn.init.xavier_uniform_(m.weight)\n", " if type(m) == nn.GRU:\n", " for param in m._flat_weights_names:\n", " if \"weight\" in param:\n", " nn.init.xavier_uniform_(m._parameters[param])\n", "\n", " net.apply(xavier_init_weights)\n", " net.to(device)\n", " optimizer = torch.optim.Adam(net.parameters(), lr=lr)\n", " loss = MaskedSoftmaxCELoss()\n", " net.train()\n", " animator = d2l.Animator(xlabel='epoch', ylabel='loss',\n", " xlim=[10, num_epochs])\n", " for epoch in range(num_epochs):\n", " timer = d2l.Timer()\n", " metric = d2l.Accumulator(2) # 训练损失总和,词元数量\n", " for batch in data_iter:\n", " optimizer.zero_grad()\n", " X, X_valid_len, Y, Y_valid_len = [x.to(device) for x in batch]\n", " bos = torch.tensor([tgt_vocab['']] * Y.shape[0],\n", " device=device).reshape(-1, 1)\n", " dec_input = torch.cat([bos, Y[:, :-1]], 1) # 强制教学\n", " Y_hat, _ = net(X, dec_input, X_valid_len)\n", " l = loss(Y_hat, Y, Y_valid_len)\n", " l.sum().backward()\t# 损失函数的标量进行“反向传播”\n", " d2l.grad_clipping(net, 1)\n", " num_tokens = Y_valid_len.sum()\n", " optimizer.step()\n", " with torch.no_grad():\n", " metric.add(l.sum(), num_tokens)\n", " if (epoch + 1) % 10 
== 0:\n", " animator.add(epoch + 1, (metric[0] / metric[1],))\n", " print(f'loss {metric[0] / metric[1]:.3f}, {metric[1] / timer.stop():.1f} '\n", " f'tokens/sec on {str(device)}')\n", "\n", "def predict_seq2seq(net, src_sentence, src_vocab, tgt_vocab, num_steps,\n", " device, save_attention_weights=False):\n", " \"\"\"序列到序列模型的预测\n", "\n", " Defined in :numref:`sec_seq2seq_training`\"\"\"\n", " # 在预测时将net设置为评估模式\n", " net.eval()\n", " src_tokens = src_vocab[src_sentence.lower().split(' ')] + [\n", " src_vocab['']]\n", " enc_valid_len = torch.tensor([len(src_tokens)], device=device)\n", " src_tokens = d2l.truncate_pad(src_tokens, num_steps, src_vocab[''])\n", " # 添加批量轴\n", " enc_X = torch.unsqueeze(\n", " torch.tensor(src_tokens, dtype=torch.long, device=device), dim=0)\n", " enc_outputs = net.encoder(enc_X, enc_valid_len)\n", " dec_state = net.decoder.init_state(enc_outputs, enc_valid_len)\n", " # 添加批量轴\n", " dec_X = torch.unsqueeze(torch.tensor(\n", " [tgt_vocab['']], dtype=torch.long, device=device), dim=0)\n", " output_seq, attention_weight_seq = [], []\n", " for _ in range(num_steps):\n", " Y, dec_state = net.decoder(dec_X, dec_state)\n", " # 我们使用具有预测最高可能性的词元,作为解码器在下一时间步的输入\n", " dec_X = Y.argmax(dim=2)\n", " pred = dec_X.squeeze(dim=0).type(torch.int32).item()\n", " # 保存注意力权重(稍后讨论)\n", " if save_attention_weights:\n", " attention_weight_seq.append(net.decoder.attention_weights)\n", " # 一旦序列结束词元被预测,输出序列的生成就完成了\n", " if pred == tgt_vocab['']:\n", " break\n", " output_seq.append(pred)\n", " return ' '.join(tgt_vocab.to_tokens(output_seq)), attention_weight_seq\n", "\n", "def bleu(pred_seq, label_seq, k):\n", " \"\"\"计算BLEU\n", "\n", " Defined in :numref:`sec_seq2seq_training`\"\"\"\n", " pred_tokens, label_tokens = pred_seq.split(' '), label_seq.split(' ')\n", " len_pred, len_label = len(pred_tokens), len(label_tokens)\n", " score = math.exp(min(0, 1 - len_label / len_pred))\n", " for n in range(1, k + 1):\n", " num_matches, label_subs = 0, 
collections.defaultdict(int)\n", " for i in range(len_label - n + 1):\n", " label_subs[' '.join(label_tokens[i: i + n])] += 1\n", " for i in range(len_pred - n + 1):\n", " if label_subs[' '.join(pred_tokens[i: i + n])] > 0:\n", " num_matches += 1\n", " label_subs[' '.join(pred_tokens[i: i + n])] -= 1\n", " score *= math.pow(num_matches / (len_pred - n + 1), math.pow(0.5, n))\n", " return score\n", "\n", "def show_heatmaps(matrices, xlabel, ylabel, titles=None, figsize=(2.5, 2.5),\n", " cmap='Reds'):\n", " \"\"\"显示矩阵热图\n", "\n", " Defined in :numref:`sec_attention-cues`\"\"\"\n", " d2l.use_svg_display()\n", " num_rows, num_cols = matrices.shape[0], matrices.shape[1]\n", " fig, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize,\n", " sharex=True, sharey=True, squeeze=False)\n", " for i, (row_axes, row_matrices) in enumerate(zip(axes, matrices)):\n", " for j, (ax, matrix) in enumerate(zip(row_axes, row_matrices)):\n", " pcm = ax.imshow(d2l.numpy(matrix), cmap=cmap)\n", " if i == num_rows - 1:\n", " ax.set_xlabel(xlabel)\n", " if j == 0:\n", " ax.set_ylabel(ylabel)\n", " if titles:\n", " ax.set_title(titles[j])\n", " fig.colorbar(pcm, ax=axes, shrink=0.6);\n", "\n", "def masked_softmax(X, valid_lens):\n", " \"\"\"通过在最后一个轴上掩蔽元素来执行softmax操作\n", "\n", " Defined in :numref:`sec_attention-scoring-functions`\"\"\"\n", " # X:3D张量,valid_lens:1D或2D张量\n", " if valid_lens is None:\n", " return nn.functional.softmax(X, dim=-1)\n", " else:\n", " shape = X.shape\n", " if valid_lens.dim() == 1:\n", " valid_lens = torch.repeat_interleave(valid_lens, shape[1])\n", " else:\n", " valid_lens = valid_lens.reshape(-1)\n", " # 最后一轴上被掩蔽的元素使用一个非常大的负值替换,从而其softmax输出为0\n", " X = d2l.sequence_mask(X.reshape(-1, shape[-1]), valid_lens,\n", " value=-1e6)\n", " return nn.functional.softmax(X.reshape(shape), dim=-1)\n", "\n", "class AdditiveAttention(nn.Module):\n", " \"\"\"加性注意力\n", "\n", " Defined in :numref:`sec_attention-scoring-functions`\"\"\"\n", " def __init__(self, key_size, 
query_size, num_hiddens, dropout, **kwargs):\n", " super(AdditiveAttention, self).__init__(**kwargs)\n", " self.W_k = nn.Linear(key_size, num_hiddens, bias=False)\n", " self.W_q = nn.Linear(query_size, num_hiddens, bias=False)\n", " self.w_v = nn.Linear(num_hiddens, 1, bias=False)\n", " self.dropout = nn.Dropout(dropout)\n", "\n", " def forward(self, queries, keys, values, valid_lens):\n", " queries, keys = self.W_q(queries), self.W_k(keys)\n", " # 在维度扩展后,\n", " # queries的形状:(batch_size,查询的个数,1,num_hidden)\n", " # key的形状:(batch_size,1,“键-值”对的个数,num_hiddens)\n", " # 使用广播方式进行求和\n", " features = queries.unsqueeze(2) + keys.unsqueeze(1)\n", " features = torch.tanh(features)\n", " # self.w_v仅有一个输出,因此从形状中移除最后那个维度。\n", " # scores的形状:(batch_size,查询的个数,“键-值”对的个数)\n", " scores = self.w_v(features).squeeze(-1)\n", " self.attention_weights = masked_softmax(scores, valid_lens)\n", " # values的形状:(batch_size,“键-值”对的个数,值的维度)\n", " return torch.bmm(self.dropout(self.attention_weights), values)\n", "\n", "class DotProductAttention(nn.Module):\n", " \"\"\"缩放点积注意力\n", "\n", " Defined in :numref:`subsec_additive-attention`\"\"\"\n", " def __init__(self, dropout, **kwargs):\n", " super(DotProductAttention, self).__init__(**kwargs)\n", " self.dropout = nn.Dropout(dropout)\n", "\n", " # queries的形状:(batch_size,查询的个数,d)\n", " # keys的形状:(batch_size,“键-值”对的个数,d)\n", " # values的形状:(batch_size,“键-值”对的个数,值的维度)\n", " # valid_lens的形状:(batch_size,)或者(batch_size,查询的个数)\n", " def forward(self, queries, keys, values, valid_lens=None):\n", " d = queries.shape[-1]\n", " # 设置transpose_b=True为了交换keys的最后两个维度\n", " scores = torch.bmm(queries, keys.transpose(1,2)) / math.sqrt(d)\n", " self.attention_weights = masked_softmax(scores, valid_lens)\n", " return torch.bmm(self.dropout(self.attention_weights), values)\n", "\n", "class AttentionDecoder(d2l.Decoder):\n", " \"\"\"带有注意力机制解码器的基本接口\n", "\n", " Defined in :numref:`sec_seq2seq_attention`\"\"\"\n", " def __init__(self, **kwargs):\n", " 
super(AttentionDecoder, self).__init__(**kwargs)\n", "\n", " @property\n", " def attention_weights(self):\n", " raise NotImplementedError\n", "\n", "class MultiHeadAttention(nn.Module):\n", " \"\"\"多头注意力\n", "\n", " Defined in :numref:`sec_multihead-attention`\"\"\"\n", " def __init__(self, key_size, query_size, value_size, num_hiddens,\n", " num_heads, dropout, bias=False, **kwargs):\n", " super(MultiHeadAttention, self).__init__(**kwargs)\n", " self.num_heads = num_heads\n", " self.attention = d2l.DotProductAttention(dropout)\n", " self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)\n", " self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)\n", " self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)\n", " self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)\n", "\n", " def forward(self, queries, keys, values, valid_lens):\n", " # queries,keys,values的形状:\n", " # (batch_size,查询或者“键-值”对的个数,num_hiddens)\n", " # valid_lens 的形状:\n", " # (batch_size,)或(batch_size,查询的个数)\n", " # 经过变换后,输出的queries,keys,values 的形状:\n", " # (batch_size*num_heads,查询或者“键-值”对的个数,\n", " # num_hiddens/num_heads)\n", " queries = transpose_qkv(self.W_q(queries), self.num_heads)\n", " keys = transpose_qkv(self.W_k(keys), self.num_heads)\n", " values = transpose_qkv(self.W_v(values), self.num_heads)\n", "\n", " if valid_lens is not None:\n", " # 在轴0,将第一项(标量或者矢量)复制num_heads次,\n", " # 然后如此复制第二项,然后诸如此类。\n", " valid_lens = torch.repeat_interleave(\n", " valid_lens, repeats=self.num_heads, dim=0)\n", "\n", " # output的形状:(batch_size*num_heads,查询的个数,\n", " # num_hiddens/num_heads)\n", " output = self.attention(queries, keys, values, valid_lens)\n", "\n", " # output_concat的形状:(batch_size,查询的个数,num_hiddens)\n", " output_concat = transpose_output(output, self.num_heads)\n", " return self.W_o(output_concat)\n", "\n", "def transpose_qkv(X, num_heads):\n", " \"\"\"为了多注意力头的并行计算而变换形状\n", "\n", " Defined in :numref:`sec_multihead-attention`\"\"\"\n", " # 
输入X的形状:(batch_size,查询或者“键-值”对的个数,num_hiddens)\n", " # 输出X的形状:(batch_size,查询或者“键-值”对的个数,num_heads,\n", " # num_hiddens/num_heads)\n", " X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)\n", "\n", " # 输出X的形状:(batch_size,num_heads,查询或者“键-值”对的个数,\n", " # num_hiddens/num_heads)\n", " X = X.permute(0, 2, 1, 3)\n", "\n", " # 最终输出的形状:(batch_size*num_heads,查询或者“键-值”对的个数,\n", " # num_hiddens/num_heads)\n", " return X.reshape(-1, X.shape[2], X.shape[3])\n", "\n", "\n", "def transpose_output(X, num_heads):\n", " \"\"\"逆转transpose_qkv函数的操作\n", "\n", " Defined in :numref:`sec_multihead-attention`\"\"\"\n", " X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])\n", " X = X.permute(0, 2, 1, 3)\n", " return X.reshape(X.shape[0], X.shape[1], -1)\n", "\n", "class PositionalEncoding(nn.Module):\n", " \"\"\"位置编码\n", "\n", " Defined in :numref:`sec_self-attention-and-positional-encoding`\"\"\"\n", " def __init__(self, num_hiddens, dropout, max_len=1000):\n", " super(PositionalEncoding, self).__init__()\n", " self.dropout = nn.Dropout(dropout)\n", " # 创建一个足够长的P\n", " self.P = d2l.zeros((1, max_len, num_hiddens))\n", " X = d2l.arange(max_len, dtype=torch.float32).reshape(\n", " -1, 1) / torch.pow(10000, torch.arange(\n", " 0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)\n", " self.P[:, :, 0::2] = torch.sin(X)\n", " self.P[:, :, 1::2] = torch.cos(X)\n", "\n", " def forward(self, X):\n", " X = X + self.P[:, :X.shape[1], :].to(X.device)\n", " return self.dropout(X)\n", "\n", "class PositionWiseFFN(nn.Module):\n", " \"\"\"基于位置的前馈网络\n", "\n", " Defined in :numref:`sec_transformer`\"\"\"\n", " def __init__(self, ffn_num_input, ffn_num_hiddens, ffn_num_outputs,\n", " **kwargs):\n", " super(PositionWiseFFN, self).__init__(**kwargs)\n", " self.dense1 = nn.Linear(ffn_num_input, ffn_num_hiddens)\n", " self.relu = nn.ReLU()\n", " self.dense2 = nn.Linear(ffn_num_hiddens, ffn_num_outputs)\n", "\n", " def forward(self, X):\n", " return self.dense2(self.relu(self.dense1(X)))\n", "\n", "class 
AddNorm(nn.Module):\n", " \"\"\"残差连接后进行层规范化\n", "\n", " Defined in :numref:`sec_transformer`\"\"\"\n", " def __init__(self, normalized_shape, dropout, **kwargs):\n", " super(AddNorm, self).__init__(**kwargs)\n", " self.dropout = nn.Dropout(dropout)\n", " self.ln = nn.LayerNorm(normalized_shape)\n", "\n", " def forward(self, X, Y):\n", " return self.ln(self.dropout(Y) + X)\n", "\n", "class EncoderBlock(nn.Module):\n", " \"\"\"Transformer编码器块\n", "\n", " Defined in :numref:`sec_transformer`\"\"\"\n", " def __init__(self, key_size, query_size, value_size, num_hiddens,\n", " norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,\n", " dropout, use_bias=False, **kwargs):\n", " super(EncoderBlock, self).__init__(**kwargs)\n", " self.attention = d2l.MultiHeadAttention(\n", " key_size, query_size, value_size, num_hiddens, num_heads, dropout,\n", " use_bias)\n", " self.addnorm1 = AddNorm(norm_shape, dropout)\n", " self.ffn = PositionWiseFFN(\n", " ffn_num_input, ffn_num_hiddens, num_hiddens)\n", " self.addnorm2 = AddNorm(norm_shape, dropout)\n", "\n", " def forward(self, X, valid_lens):\n", " Y = self.addnorm1(X, self.attention(X, X, X, valid_lens))\n", " return self.addnorm2(Y, self.ffn(Y))\n", "\n", "class TransformerEncoder(d2l.Encoder):\n", " \"\"\"Transformer编码器\n", "\n", " Defined in :numref:`sec_transformer`\"\"\"\n", " def __init__(self, vocab_size, key_size, query_size, value_size,\n", " num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens,\n", " num_heads, num_layers, dropout, use_bias=False, **kwargs):\n", " super(TransformerEncoder, self).__init__(**kwargs)\n", " self.num_hiddens = num_hiddens\n", " self.embedding = nn.Embedding(vocab_size, num_hiddens)\n", " self.pos_encoding = d2l.PositionalEncoding(num_hiddens, dropout)\n", " self.blks = nn.Sequential()\n", " for i in range(num_layers):\n", " self.blks.add_module(\"block\"+str(i),\n", " EncoderBlock(key_size, query_size, value_size, num_hiddens,\n", " norm_shape, ffn_num_input, ffn_num_hiddens,\n", " 
num_heads, dropout, use_bias))\n", "\n", " def forward(self, X, valid_lens, *args):\n", " # 因为位置编码值在-1和1之间,\n", " # 因此嵌入值乘以嵌入维度的平方根进行缩放,\n", " # 然后再与位置编码相加。\n", " X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))\n", " self.attention_weights = [None] * len(self.blks)\n", " for i, blk in enumerate(self.blks):\n", " X = blk(X, valid_lens)\n", " self.attention_weights[\n", " i] = blk.attention.attention.attention_weights\n", " return X\n", "\n", "def annotate(text, xy, xytext):\n", " d2l.plt.gca().annotate(text, xy=xy, xytext=xytext,\n", " arrowprops=dict(arrowstyle='->'))\n", "\n", "def train_2d(trainer, steps=20, f_grad=None):\n", " \"\"\"用定制的训练机优化2D目标函数\n", "\n", " Defined in :numref:`subsec_gd-learningrate`\"\"\"\n", " # s1和s2是稍后将使用的内部状态变量\n", " x1, x2, s1, s2 = -5, -2, 0, 0\n", " results = [(x1, x2)]\n", " for i in range(steps):\n", " if f_grad:\n", " x1, x2, s1, s2 = trainer(x1, x2, s1, s2, f_grad)\n", " else:\n", " x1, x2, s1, s2 = trainer(x1, x2, s1, s2)\n", " results.append((x1, x2))\n", " print(f'epoch {i + 1}, x1: {float(x1):f}, x2: {float(x2):f}')\n", " return results\n", "\n", "def show_trace_2d(f, results):\n", " \"\"\"显示优化过程中2D变量的轨迹\n", "\n", " Defined in :numref:`subsec_gd-learningrate`\"\"\"\n", " d2l.set_figsize()\n", " d2l.plt.plot(*zip(*results), '-o', color='#ff7f0e')\n", " x1, x2 = d2l.meshgrid(d2l.arange(-5.5, 1.0, 0.1),\n", " d2l.arange(-3.0, 1.0, 0.1), indexing='ij')\n", " d2l.plt.contour(x1, x2, f(x1, x2), colors='#1f77b4')\n", " d2l.plt.xlabel('x1')\n", " d2l.plt.ylabel('x2')\n", "\n", "d2l.DATA_HUB['airfoil'] = (d2l.DATA_URL + 'airfoil_self_noise.dat',\n", " '76e5be1548fd8222e5074cf0faae75edff8cf93f')\n", "\n", "def get_data_ch11(batch_size=10, n=1500):\n", " \"\"\"Defined in :numref:`sec_minibatches`\"\"\"\n", " data = np.genfromtxt(d2l.download('airfoil'),\n", " dtype=np.float32, delimiter='\\t')\n", " data = torch.from_numpy((data - data.mean(axis=0)) / data.std(axis=0))\n", " data_iter = d2l.load_array((data[:n, 
:-1], data[:n, -1]),\n", " batch_size, is_train=True)\n", " return data_iter, data.shape[1]-1\n", "\n", "def train_ch11(trainer_fn, states, hyperparams, data_iter,\n", " feature_dim, num_epochs=2):\n", " \"\"\"Defined in :numref:`sec_minibatches`\"\"\"\n", " # 初始化模型\n", " w = torch.normal(mean=0.0, std=0.01, size=(feature_dim, 1),\n", " requires_grad=True)\n", " b = torch.zeros((1), requires_grad=True)\n", " net, loss = lambda X: d2l.linreg(X, w, b), d2l.squared_loss\n", " # 训练模型\n", " animator = d2l.Animator(xlabel='epoch', ylabel='loss',\n", " xlim=[0, num_epochs], ylim=[0.22, 0.35])\n", " n, timer = 0, d2l.Timer()\n", " for _ in range(num_epochs):\n", " for X, y in data_iter:\n", " l = loss(net(X), y).mean()\n", " l.backward()\n", " trainer_fn([w, b], states, hyperparams)\n", " n += X.shape[0]\n", " if n % 200 == 0:\n", " timer.stop()\n", " animator.add(n/X.shape[0]/len(data_iter),\n", " (d2l.evaluate_loss(net, data_iter, loss),))\n", " timer.start()\n", " print(f'loss: {animator.Y[0][-1]:.3f}, {timer.avg():.3f} sec/epoch')\n", " return timer.cumsum(), animator.Y[0]\n", "\n", "def train_concise_ch11(trainer_fn, hyperparams, data_iter, num_epochs=4):\n", " \"\"\"Defined in :numref:`sec_minibatches`\"\"\"\n", " # 初始化模型\n", " net = nn.Sequential(nn.Linear(5, 1))\n", " def init_weights(m):\n", " if type(m) == nn.Linear:\n", " torch.nn.init.normal_(m.weight, std=0.01)\n", " net.apply(init_weights)\n", "\n", " optimizer = trainer_fn(net.parameters(), **hyperparams)\n", " loss = nn.MSELoss(reduction='none')\n", " animator = d2l.Animator(xlabel='epoch', ylabel='loss',\n", " xlim=[0, num_epochs], ylim=[0.22, 0.35])\n", " n, timer = 0, d2l.Timer()\n", " for _ in range(num_epochs):\n", " for X, y in data_iter:\n", " optimizer.zero_grad()\n", " out = net(X)\n", " y = y.reshape(out.shape)\n", " l = loss(out, y)\n", " l.mean().backward()\n", " optimizer.step()\n", " n += X.shape[0]\n", " if n % 200 == 0:\n", " timer.stop()\n", " # MSELoss计算平方误差时不带系数1/2\n", " 
animator.add(n/X.shape[0]/len(data_iter),\n", " (d2l.evaluate_loss(net, data_iter, loss) / 2,))\n", " timer.start()\n", " print(f'loss: {animator.Y[0][-1]:.3f}, {timer.avg():.3f} sec/epoch')\n", "\n", "class Benchmark:\n", " \"\"\"用于测量运行时间\"\"\"\n", " def __init__(self, description='Done'):\n", " \"\"\"Defined in :numref:`sec_hybridize`\"\"\"\n", " self.description = description\n", "\n", " def __enter__(self):\n", " self.timer = d2l.Timer()\n", " return self\n", "\n", " def __exit__(self, *args):\n", " print(f'{self.description}: {self.timer.stop():.4f} sec')\n", "\n", "def split_batch(X, y, devices):\n", " \"\"\"将X和y拆分到多个设备上\n", "\n", " Defined in :numref:`sec_multi_gpu`\"\"\"\n", " assert X.shape[0] == y.shape[0]\n", " return (nn.parallel.scatter(X, devices),\n", " nn.parallel.scatter(y, devices))\n", "\n", "def resnet18(num_classes, in_channels=1):\n", " \"\"\"稍加修改的ResNet-18模型\n", "\n", " Defined in :numref:`sec_multi_gpu_concise`\"\"\"\n", " def resnet_block(in_channels, out_channels, num_residuals,\n", " first_block=False):\n", " blk = []\n", " for i in range(num_residuals):\n", " if i == 0 and not first_block:\n", " blk.append(d2l.Residual(in_channels, out_channels,\n", " use_1x1conv=True, strides=2))\n", " else:\n", " blk.append(d2l.Residual(out_channels, out_channels))\n", " return nn.Sequential(*blk)\n", "\n", " # 该模型使用了更小的卷积核、步长和填充,而且删除了最大汇聚层\n", " net = nn.Sequential(\n", " nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1),\n", " nn.BatchNorm2d(64),\n", " nn.ReLU())\n", " net.add_module(\"resnet_block1\", resnet_block(\n", " 64, 64, 2, first_block=True))\n", " net.add_module(\"resnet_block2\", resnet_block(64, 128, 2))\n", " net.add_module(\"resnet_block3\", resnet_block(128, 256, 2))\n", " net.add_module(\"resnet_block4\", resnet_block(256, 512, 2))\n", " net.add_module(\"global_avg_pool\", nn.AdaptiveAvgPool2d((1,1)))\n", " net.add_module(\"fc\", nn.Sequential(nn.Flatten(),\n", " nn.Linear(512, num_classes)))\n", " return net\n", "\n", 
"def train_batch_ch13(net, X, y, loss, trainer, devices):\n", " \"\"\"用多GPU进行小批量训练\n", "\n", " Defined in :numref:`sec_image_augmentation`\"\"\"\n", " if isinstance(X, list):\n", " # 微调BERT中所需\n", " X = [x.to(devices[0]) for x in X]\n", " else:\n", " X = X.to(devices[0])\n", " y = y.to(devices[0])\n", " net.train()\n", " trainer.zero_grad()\n", " pred = net(X)\n", " l = loss(pred, y)\n", " l.sum().backward()\n", " trainer.step()\n", " train_loss_sum = l.sum()\n", " train_acc_sum = d2l.accuracy(pred, y)\n", " return train_loss_sum, train_acc_sum\n", "\n", "def train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs,\n", " devices=d2l.try_all_gpus()):\n", " \"\"\"用多GPU进行模型训练\n", "\n", " Defined in :numref:`sec_image_augmentation`\"\"\"\n", " timer, num_batches = d2l.Timer(), len(train_iter)\n", " animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0, 1],\n", " legend=['train loss', 'train acc', 'test acc'])\n", " net = nn.DataParallel(net, device_ids=devices).to(devices[0])\n", " for epoch in range(num_epochs):\n", " # 4个维度:储存训练损失,训练准确度,实例数,特点数\n", " metric = d2l.Accumulator(4)\n", " for i, (features, labels) in enumerate(train_iter):\n", " timer.start()\n", " l, acc = train_batch_ch13(\n", " net, features, labels, loss, trainer, devices)\n", " metric.add(l, acc, labels.shape[0], labels.numel())\n", " timer.stop()\n", " if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:\n", " animator.add(epoch + (i + 1) / num_batches,\n", " (metric[0] / metric[2], metric[1] / metric[3],\n", " None))\n", " test_acc = d2l.evaluate_accuracy_gpu(net, test_iter)\n", " animator.add(epoch + 1, (None, None, test_acc))\n", " print(f'loss {metric[0] / metric[2]:.3f}, train acc '\n", " f'{metric[1] / metric[3]:.3f}, test acc {test_acc:.3f}')\n", " print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec on '\n", " f'{str(devices)}')\n", "\n", "d2l.DATA_HUB['hotdog'] = (d2l.DATA_URL + 'hotdog.zip',\n", " 'fba480ffa8aa7e0febbb511d181409f899b9baa5')\n", 
"\n", "def box_corner_to_center(boxes):\n", " \"\"\"从(左上,右下)转换到(中间,宽度,高度)\n", "\n", " Defined in :numref:`sec_bbox`\"\"\"\n", " x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]\n", " cx = (x1 + x2) / 2\n", " cy = (y1 + y2) / 2\n", " w = x2 - x1\n", " h = y2 - y1\n", " boxes = d2l.stack((cx, cy, w, h), axis=-1)\n", " return boxes\n", "\n", "def box_center_to_corner(boxes):\n", " \"\"\"从(中间,宽度,高度)转换到(左上,右下)\n", "\n", " Defined in :numref:`sec_bbox`\"\"\"\n", " cx, cy, w, h = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]\n", " x1 = cx - 0.5 * w\n", " y1 = cy - 0.5 * h\n", " x2 = cx + 0.5 * w\n", " y2 = cy + 0.5 * h\n", " boxes = d2l.stack((x1, y1, x2, y2), axis=-1)\n", " return boxes\n", "\n", "def bbox_to_rect(bbox, color):\n", " \"\"\"Defined in :numref:`sec_bbox`\"\"\"\n", " # 将边界框(左上x,左上y,右下x,右下y)格式转换成matplotlib格式:\n", " # ((左上x,左上y),宽,高)\n", " return d2l.plt.Rectangle(\n", " xy=(bbox[0], bbox[1]), width=bbox[2]-bbox[0], height=bbox[3]-bbox[1],\n", " fill=False, edgecolor=color, linewidth=2)\n", "\n", "def multibox_prior(data, sizes, ratios):\n", " \"\"\"生成以每个像素为中心具有不同形状的锚框\n", "\n", " Defined in :numref:`sec_anchor`\"\"\"\n", " in_height, in_width = data.shape[-2:]\n", " device, num_sizes, num_ratios = data.device, len(sizes), len(ratios)\n", " boxes_per_pixel = (num_sizes + num_ratios - 1)\n", " size_tensor = d2l.tensor(sizes, device=device)\n", " ratio_tensor = d2l.tensor(ratios, device=device)\n", "\n", " # 为了将锚点移动到像素的中心,需要设置偏移量。\n", " # 因为一个像素的高为1且宽为1,我们选择偏移我们的中心0.5\n", " offset_h, offset_w = 0.5, 0.5\n", " steps_h = 1.0 / in_height # 在y轴上缩放步长\n", " steps_w = 1.0 / in_width # 在x轴上缩放步长\n", "\n", " # 生成锚框的所有中心点\n", " center_h = (torch.arange(in_height, device=device) + offset_h) * steps_h\n", " center_w = (torch.arange(in_width, device=device) + offset_w) * steps_w\n", " shift_y, shift_x = torch.meshgrid(center_h, center_w, indexing='ij')\n", " shift_y, shift_x = shift_y.reshape(-1), shift_x.reshape(-1)\n", "\n", " # 
生成“boxes_per_pixel”个高和宽,\n", " # 之后用于创建锚框的四角坐标(xmin,xmax,ymin,ymax)\n", " w = torch.cat((size_tensor * torch.sqrt(ratio_tensor[0]),\n", " sizes[0] * torch.sqrt(ratio_tensor[1:])))\\\n", " * in_height / in_width # 处理矩形输入\n", " h = torch.cat((size_tensor / torch.sqrt(ratio_tensor[0]),\n", " sizes[0] / torch.sqrt(ratio_tensor[1:])))\n", " # 除以2来获得半高和半宽\n", " anchor_manipulations = torch.stack((-w, -h, w, h)).T.repeat(\n", " in_height * in_width, 1) / 2\n", "\n", " # 每个中心点都将有“boxes_per_pixel”个锚框,\n", " # 所以生成含所有锚框中心的网格,重复了“boxes_per_pixel”次\n", " out_grid = torch.stack([shift_x, shift_y, shift_x, shift_y],\n", " dim=1).repeat_interleave(boxes_per_pixel, dim=0)\n", " output = out_grid + anchor_manipulations\n", " return output.unsqueeze(0)\n", "\n", "def show_bboxes(axes, bboxes, labels=None, colors=None):\n", " \"\"\"显示所有边界框\n", "\n", " Defined in :numref:`sec_anchor`\"\"\"\n", " def _make_list(obj, default_values=None):\n", " if obj is None:\n", " obj = default_values\n", " elif not isinstance(obj, (list, tuple)):\n", " obj = [obj]\n", " return obj\n", "\n", " labels = _make_list(labels)\n", " colors = _make_list(colors, ['b', 'g', 'r', 'm', 'c'])\n", " for i, bbox in enumerate(bboxes):\n", " color = colors[i % len(colors)]\n", " rect = d2l.bbox_to_rect(d2l.numpy(bbox), color)\n", " axes.add_patch(rect)\n", " if labels and len(labels) > i:\n", " text_color = 'k' if color == 'w' else 'w'\n", " axes.text(rect.xy[0], rect.xy[1], labels[i],\n", " va='center', ha='center', fontsize=9, color=text_color,\n", " bbox=dict(facecolor=color, lw=0))\n", "\n", "def box_iou(boxes1, boxes2):\n", " \"\"\"计算两个锚框或边界框列表中成对的交并比\n", "\n", " Defined in :numref:`sec_anchor`\"\"\"\n", " box_area = lambda boxes: ((boxes[:, 2] - boxes[:, 0]) *\n", " (boxes[:, 3] - boxes[:, 1]))\n", " # boxes1,boxes2,areas1,areas2的形状:\n", " # boxes1:(boxes1的数量,4),\n", " # boxes2:(boxes2的数量,4),\n", " # areas1:(boxes1的数量,),\n", " # areas2:(boxes2的数量,)\n", " areas1 = box_area(boxes1)\n", " areas2 = 
box_area(boxes2)\n", " # inter_upperlefts,inter_lowerrights,inters的形状:\n", " # (boxes1的数量,boxes2的数量,2)\n", " inter_upperlefts = torch.max(boxes1[:, None, :2], boxes2[:, :2])\n", " inter_lowerrights = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])\n", " inters = (inter_lowerrights - inter_upperlefts).clamp(min=0)\n", " # inter_areasandunion_areas的形状:(boxes1的数量,boxes2的数量)\n", " inter_areas = inters[:, :, 0] * inters[:, :, 1]\n", " union_areas = areas1[:, None] + areas2 - inter_areas\n", " return inter_areas / union_areas\n", "\n", "def assign_anchor_to_bbox(ground_truth, anchors, device, iou_threshold=0.5):\n", " \"\"\"将最接近的真实边界框分配给锚框\n", "\n", " Defined in :numref:`sec_anchor`\"\"\"\n", " num_anchors, num_gt_boxes = anchors.shape[0], ground_truth.shape[0]\n", " # 位于第i行和第j列的元素x_ij是锚框i和真实边界框j的IoU\n", " jaccard = box_iou(anchors, ground_truth)\n", " # 对于每个锚框,分配的真实边界框的张量\n", " anchors_bbox_map = torch.full((num_anchors,), -1, dtype=torch.long,\n", " device=device)\n", " # 根据阈值,决定是否分配真实边界框\n", " max_ious, indices = torch.max(jaccard, dim=1)\n", " anc_i = torch.nonzero(max_ious >= iou_threshold).reshape(-1)\n", " box_j = indices[max_ious >= iou_threshold]\n", " anchors_bbox_map[anc_i] = box_j\n", " col_discard = torch.full((num_anchors,), -1)\n", " row_discard = torch.full((num_gt_boxes,), -1)\n", " for _ in range(num_gt_boxes):\n", " max_idx = torch.argmax(jaccard)\n", " box_idx = (max_idx % num_gt_boxes).long()\n", " anc_idx = (max_idx / num_gt_boxes).long()\n", " anchors_bbox_map[anc_idx] = box_idx\n", " jaccard[:, box_idx] = col_discard\n", " jaccard[anc_idx, :] = row_discard\n", " return anchors_bbox_map\n", "\n", "def offset_boxes(anchors, assigned_bb, eps=1e-6):\n", " \"\"\"对锚框偏移量的转换\n", "\n", " Defined in :numref:`subsec_labeling-anchor-boxes`\"\"\"\n", " c_anc = d2l.box_corner_to_center(anchors)\n", " c_assigned_bb = d2l.box_corner_to_center(assigned_bb)\n", " offset_xy = 10 * (c_assigned_bb[:, :2] - c_anc[:, :2]) / c_anc[:, 2:]\n", " offset_wh = 5 * d2l.log(eps + 
c_assigned_bb[:, 2:] / c_anc[:, 2:])\n", " offset = d2l.concat([offset_xy, offset_wh], axis=1)\n", " return offset\n", "\n", "def multibox_target(anchors, labels):\n", " \"\"\"使用真实边界框标记锚框\n", "\n", " Defined in :numref:`subsec_labeling-anchor-boxes`\"\"\"\n", " batch_size, anchors = labels.shape[0], anchors.squeeze(0)\n", " batch_offset, batch_mask, batch_class_labels = [], [], []\n", " device, num_anchors = anchors.device, anchors.shape[0]\n", " for i in range(batch_size):\n", " label = labels[i, :, :]\n", " anchors_bbox_map = assign_anchor_to_bbox(\n", " label[:, 1:], anchors, device)\n", " bbox_mask = ((anchors_bbox_map >= 0).float().unsqueeze(-1)).repeat(\n", " 1, 4)\n", " # 将类标签和分配的边界框坐标初始化为零\n", " class_labels = torch.zeros(num_anchors, dtype=torch.long,\n", " device=device)\n", " assigned_bb = torch.zeros((num_anchors, 4), dtype=torch.float32,\n", " device=device)\n", " # 使用真实边界框来标记锚框的类别。\n", " # 如果一个锚框没有被分配,标记其为背景(值为零)\n", " indices_true = torch.nonzero(anchors_bbox_map >= 0)\n", " bb_idx = anchors_bbox_map[indices_true]\n", " class_labels[indices_true] = label[bb_idx, 0].long() + 1\n", " assigned_bb[indices_true] = label[bb_idx, 1:]\n", " # 偏移量转换\n", " offset = offset_boxes(anchors, assigned_bb) * bbox_mask\n", " batch_offset.append(offset.reshape(-1))\n", " batch_mask.append(bbox_mask.reshape(-1))\n", " batch_class_labels.append(class_labels)\n", " bbox_offset = torch.stack(batch_offset)\n", " bbox_mask = torch.stack(batch_mask)\n", " class_labels = torch.stack(batch_class_labels)\n", " return (bbox_offset, bbox_mask, class_labels)\n", "\n", "def offset_inverse(anchors, offset_preds):\n", " \"\"\"根据带有预测偏移量的锚框来预测边界框\n", "\n", " Defined in :numref:`subsec_labeling-anchor-boxes`\"\"\"\n", " anc = d2l.box_corner_to_center(anchors)\n", " pred_bbox_xy = (offset_preds[:, :2] * anc[:, 2:] / 10) + anc[:, :2]\n", " pred_bbox_wh = d2l.exp(offset_preds[:, 2:] / 5) * anc[:, 2:]\n", " pred_bbox = d2l.concat((pred_bbox_xy, pred_bbox_wh), axis=1)\n", " predicted_bbox = 
d2l.box_center_to_corner(pred_bbox)\n", " return predicted_bbox\n", "\n", "def nms(boxes, scores, iou_threshold):\n", " \"\"\"对预测边界框的置信度进行排序\n", "\n", " Defined in :numref:`subsec_predicting-bounding-boxes-nms`\"\"\"\n", " B = torch.argsort(scores, dim=-1, descending=True)\n", " keep = [] # 保留预测边界框的指标\n", " while B.numel() > 0:\n", " i = B[0]\n", " keep.append(i)\n", " if B.numel() == 1: break\n", " iou = box_iou(boxes[i, :].reshape(-1, 4),\n", " boxes[B[1:], :].reshape(-1, 4)).reshape(-1)\n", " inds = torch.nonzero(iou <= iou_threshold).reshape(-1)\n", " B = B[inds + 1]\n", " return d2l.tensor(keep, device=boxes.device)\n", "\n", "def multibox_detection(cls_probs, offset_preds, anchors, nms_threshold=0.5,\n", " pos_threshold=0.009999999):\n", " \"\"\"使用非极大值抑制来预测边界框\n", "\n", " Defined in :numref:`subsec_predicting-bounding-boxes-nms`\"\"\"\n", " device, batch_size = cls_probs.device, cls_probs.shape[0]\n", " anchors = anchors.squeeze(0)\n", " num_classes, num_anchors = cls_probs.shape[1], cls_probs.shape[2]\n", " out = []\n", " for i in range(batch_size):\n", " cls_prob, offset_pred = cls_probs[i], offset_preds[i].reshape(-1, 4)\n", " conf, class_id = torch.max(cls_prob[1:], 0)\n", " predicted_bb = offset_inverse(anchors, offset_pred)\n", " keep = nms(predicted_bb, conf, nms_threshold)\n", "\n", " # 找到所有的non_keep索引,并将类设置为背景\n", " all_idx = torch.arange(num_anchors, dtype=torch.long, device=device)\n", " combined = torch.cat((keep, all_idx))\n", " uniques, counts = combined.unique(return_counts=True)\n", " non_keep = uniques[counts == 1]\n", " all_id_sorted = torch.cat((keep, non_keep))\n", " class_id[non_keep] = -1\n", " class_id = class_id[all_id_sorted]\n", " conf, predicted_bb = conf[all_id_sorted], predicted_bb[all_id_sorted]\n", " # pos_threshold是一个用于非背景预测的阈值\n", " below_min_idx = (conf < pos_threshold)\n", " class_id[below_min_idx] = -1\n", " conf[below_min_idx] = 1 - conf[below_min_idx]\n", " pred_info = torch.cat((class_id.unsqueeze(1),\n", " 
conf.unsqueeze(1),\n", " predicted_bb), dim=1)\n", " out.append(pred_info)\n", " return d2l.stack(out)\n", "\n", "d2l.DATA_HUB['banana-detection'] = (\n", " d2l.DATA_URL + 'banana-detection.zip',\n", " '5de26c8fce5ccdea9f91267273464dc968d20d72')\n", "\n", "def read_data_bananas(is_train=True):\n", " \"\"\"读取香蕉检测数据集中的图像和标签\n", "\n", " Defined in :numref:`sec_object-detection-dataset`\"\"\"\n", " data_dir = d2l.download_extract('banana-detection')\n", " csv_fname = os.path.join(data_dir, 'bananas_train' if is_train\n", " else 'bananas_val', 'label.csv')\n", " csv_data = pd.read_csv(csv_fname)\n", " csv_data = csv_data.set_index('img_name')\n", " images, targets = [], []\n", " for img_name, target in csv_data.iterrows():\n", " images.append(torchvision.io.read_image(\n", " os.path.join(data_dir, 'bananas_train' if is_train else\n", " 'bananas_val', 'images', f'{img_name}')))\n", " # 这里的target包含(类别,左上角x,左上角y,右下角x,右下角y),\n", " # 其中所有图像都具有相同的香蕉类(索引为0)\n", " targets.append(list(target))\n", " return images, torch.tensor(targets).unsqueeze(1) / 256\n", "\n", "class BananasDataset(torch.utils.data.Dataset):\n", " \"\"\"一个用于加载香蕉检测数据集的自定义数据集\n", "\n", " Defined in :numref:`sec_object-detection-dataset`\"\"\"\n", " def __init__(self, is_train):\n", " self.features, self.labels = read_data_bananas(is_train)\n", " print('read ' + str(len(self.features)) + (f' training examples' if\n", " is_train else f' validation examples'))\n", "\n", " def __getitem__(self, idx):\n", " return (self.features[idx].float(), self.labels[idx])\n", "\n", " def __len__(self):\n", " return len(self.features)\n", "\n", "def load_data_bananas(batch_size):\n", " \"\"\"加载香蕉检测数据集\n", "\n", " Defined in :numref:`sec_object-detection-dataset`\"\"\"\n", " train_iter = torch.utils.data.DataLoader(BananasDataset(is_train=True),\n", " batch_size, shuffle=True)\n", " val_iter = torch.utils.data.DataLoader(BananasDataset(is_train=False),\n", " batch_size)\n", " return train_iter, val_iter\n", "\n", 
"d2l.DATA_HUB['voc2012'] = (d2l.DATA_URL + 'VOCtrainval_11-May-2012.tar',\n", " '4e443f8a2eca6b1dac8a6c57641b67dd40621a49')\n", "\n", "def read_voc_images(voc_dir, is_train=True):\n", " \"\"\"读取所有VOC图像并标注\n", "\n", " Defined in :numref:`sec_semantic_segmentation`\"\"\"\n", " txt_fname = os.path.join(voc_dir, 'ImageSets', 'Segmentation',\n", " 'train.txt' if is_train else 'val.txt')\n", " mode = torchvision.io.image.ImageReadMode.RGB\n", " with open(txt_fname, 'r') as f:\n", " images = f.read().split()\n", " features, labels = [], []\n", " for i, fname in enumerate(images):\n", " features.append(torchvision.io.read_image(os.path.join(\n", " voc_dir, 'JPEGImages', f'{fname}.jpg')))\n", " labels.append(torchvision.io.read_image(os.path.join(\n", " voc_dir, 'SegmentationClass' ,f'{fname}.png'), mode))\n", " return features, labels\n", "\n", "VOC_COLORMAP = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],\n", " [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],\n", " [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],\n", " [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],\n", " [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],\n", " [0, 64, 128]]\n", "\n", "VOC_CLASSES = ['background', 'aeroplane', 'bicycle', 'bird', 'boat',\n", " 'bottle', 'bus', 'car', 'cat', 'chair', 'cow',\n", " 'diningtable', 'dog', 'horse', 'motorbike', 'person',\n", " 'potted plant', 'sheep', 'sofa', 'train', 'tv/monitor']\n", "\n", "def voc_colormap2label():\n", " \"\"\"构建从RGB到VOC类别索引的映射\n", "\n", " Defined in :numref:`sec_semantic_segmentation`\"\"\"\n", " colormap2label = torch.zeros(256 ** 3, dtype=torch.long)\n", " for i, colormap in enumerate(VOC_COLORMAP):\n", " colormap2label[\n", " (colormap[0] * 256 + colormap[1]) * 256 + colormap[2]] = i\n", " return colormap2label\n", "\n", "def voc_label_indices(colormap, colormap2label):\n", " \"\"\"将VOC标签中的RGB值映射到它们的类别索引\n", "\n", " Defined in :numref:`sec_semantic_segmentation`\"\"\"\n", " colormap = 
colormap.permute(1, 2, 0).numpy().astype('int32')\n", " idx = ((colormap[:, :, 0] * 256 + colormap[:, :, 1]) * 256\n", " + colormap[:, :, 2])\n", " return colormap2label[idx]\n", "\n", "def voc_rand_crop(feature, label, height, width):\n", " \"\"\"随机裁剪特征和标签图像\n", "\n", " Defined in :numref:`sec_semantic_segmentation`\"\"\"\n", " rect = torchvision.transforms.RandomCrop.get_params(\n", " feature, (height, width))\n", " feature = torchvision.transforms.functional.crop(feature, *rect)\n", " label = torchvision.transforms.functional.crop(label, *rect)\n", " return feature, label\n", "\n", "class VOCSegDataset(torch.utils.data.Dataset):\n", " \"\"\"一个用于加载VOC数据集的自定义数据集\n", "\n", " Defined in :numref:`sec_semantic_segmentation`\"\"\"\n", "\n", " def __init__(self, is_train, crop_size, voc_dir):\n", " self.transform = torchvision.transforms.Normalize(\n", " mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n", " self.crop_size = crop_size\n", " features, labels = read_voc_images(voc_dir, is_train=is_train)\n", " self.features = [self.normalize_image(feature)\n", " for feature in self.filter(features)]\n", " self.labels = self.filter(labels)\n", " self.colormap2label = voc_colormap2label()\n", " print('read ' + str(len(self.features)) + ' examples')\n", "\n", " def normalize_image(self, img):\n", " return self.transform(img.float() / 255)\n", "\n", " def filter(self, imgs):\n", " return [img for img in imgs if (\n", " img.shape[1] >= self.crop_size[0] and\n", " img.shape[2] >= self.crop_size[1])]\n", "\n", " def __getitem__(self, idx):\n", " feature, label = voc_rand_crop(self.features[idx], self.labels[idx],\n", " *self.crop_size)\n", " return (feature, voc_label_indices(label, self.colormap2label))\n", "\n", " def __len__(self):\n", " return len(self.features)\n", "\n", "def load_data_voc(batch_size, crop_size):\n", " \"\"\"加载VOC语义分割数据集\n", "\n", " Defined in :numref:`sec_semantic_segmentation`\"\"\"\n", " voc_dir = d2l.download_extract('voc2012', os.path.join(\n", " 
'VOCdevkit', 'VOC2012'))\n", " num_workers = d2l.get_dataloader_workers()\n", " train_iter = torch.utils.data.DataLoader(\n", " VOCSegDataset(True, crop_size, voc_dir), batch_size,\n", " shuffle=True, drop_last=True, num_workers=num_workers)\n", " test_iter = torch.utils.data.DataLoader(\n", " VOCSegDataset(False, crop_size, voc_dir), batch_size,\n", " drop_last=True, num_workers=num_workers)\n", " return train_iter, test_iter\n", "\n", "d2l.DATA_HUB['cifar10_tiny'] = (d2l.DATA_URL + 'kaggle_cifar10_tiny.zip',\n", " '2068874e4b9a9f0fb07ebe0ad2b29754449ccacd')\n", "\n", "def read_csv_labels(fname):\n", " \"\"\"读取fname来给标签字典返回一个文件名\n", "\n", " Defined in :numref:`sec_kaggle_cifar10`\"\"\"\n", " with open(fname, 'r') as f:\n", " # 跳过文件头行(列名)\n", " lines = f.readlines()[1:]\n", " tokens = [l.rstrip().split(',') for l in lines]\n", " return dict(((name, label) for name, label in tokens))\n", "\n", "def copyfile(filename, target_dir):\n", " \"\"\"将文件复制到目标目录\n", "\n", " Defined in :numref:`sec_kaggle_cifar10`\"\"\"\n", " os.makedirs(target_dir, exist_ok=True)\n", " shutil.copy(filename, target_dir)\n", "\n", "def reorg_train_valid(data_dir, labels, valid_ratio):\n", " \"\"\"将验证集从原始的训练集中拆分出来\n", "\n", " Defined in :numref:`sec_kaggle_cifar10`\"\"\"\n", " # 训练数据集中样本最少的类别中的样本数\n", " n = collections.Counter(labels.values()).most_common()[-1][1]\n", " # 验证集中每个类别的样本数\n", " n_valid_per_label = max(1, math.floor(n * valid_ratio))\n", " label_count = {}\n", " for train_file in os.listdir(os.path.join(data_dir, 'train')):\n", " label = labels[train_file.split('.')[0]]\n", " fname = os.path.join(data_dir, 'train', train_file)\n", " copyfile(fname, os.path.join(data_dir, 'train_valid_test',\n", " 'train_valid', label))\n", " if label not in label_count or label_count[label] < n_valid_per_label:\n", " copyfile(fname, os.path.join(data_dir, 'train_valid_test',\n", " 'valid', label))\n", " label_count[label] = label_count.get(label, 0) + 1\n", " else:\n", " copyfile(fname, 
os.path.join(data_dir, 'train_valid_test',\n", " 'train', label))\n", " return n_valid_per_label\n", "\n", "def reorg_test(data_dir):\n", " \"\"\"在预测期间整理测试集,以方便读取\n", "\n", " Defined in :numref:`sec_kaggle_cifar10`\"\"\"\n", " for test_file in os.listdir(os.path.join(data_dir, 'test')):\n", " copyfile(os.path.join(data_dir, 'test', test_file),\n", " os.path.join(data_dir, 'train_valid_test', 'test',\n", " 'unknown'))\n", "\n", "d2l.DATA_HUB['dog_tiny'] = (d2l.DATA_URL + 'kaggle_dog_tiny.zip',\n", " '0cb91d09b814ecdc07b50f31f8dcad3e81d6a86d')\n", "\n", "d2l.DATA_HUB['ptb'] = (d2l.DATA_URL + 'ptb.zip',\n", " '319d85e578af0cdc590547f26231e4e31cdf1e42')\n", "\n", "def read_ptb():\n", " \"\"\"将PTB数据集加载到文本行的列表中\n", "\n", " Defined in :numref:`sec_word2vec_data`\"\"\"\n", " data_dir = d2l.download_extract('ptb')\n", " # Readthetrainingset.\n", " with open(os.path.join(data_dir, 'ptb.train.txt')) as f:\n", " raw_text = f.read()\n", " return [line.split() for line in raw_text.split('\\n')]\n", "\n", "def subsample(sentences, vocab):\n", " \"\"\"下采样高频词\n", "\n", " Defined in :numref:`sec_word2vec_data`\"\"\"\n", " # 排除未知词元''\n", " sentences = [[token for token in line if vocab[token] != vocab.unk]\n", " for line in sentences]\n", " counter = d2l.count_corpus(sentences)\n", " num_tokens = sum(counter.values())\n", "\n", " # 如果在下采样期间保留词元,则返回True\n", " def keep(token):\n", " return(random.uniform(0, 1) <\n", " math.sqrt(1e-4 / counter[token] * num_tokens))\n", "\n", " return ([[token for token in line if keep(token)] for line in sentences],\n", " counter)\n", "\n", "def get_centers_and_contexts(corpus, max_window_size):\n", " \"\"\"返回跳元模型中的中心词和上下文词\n", "\n", " Defined in :numref:`sec_word2vec_data`\"\"\"\n", " centers, contexts = [], []\n", " for line in corpus:\n", " # 要形成“中心词-上下文词”对,每个句子至少需要有2个词\n", " if len(line) < 2:\n", " continue\n", " centers += line\n", " for i in range(len(line)): # 上下文窗口中间i\n", " window_size = random.randint(1, max_window_size)\n", " indices = 
list(range(max(0, i - window_size),\n", " min(len(line), i + 1 + window_size)))\n", " # 从上下文词中排除中心词\n", " indices.remove(i)\n", " contexts.append([line[idx] for idx in indices])\n", " return centers, contexts\n", "\n", "class RandomGenerator:\n", " \"\"\"根据n个采样权重在{1,...,n}中随机抽取\"\"\"\n", " def __init__(self, sampling_weights):\n", " \"\"\"Defined in :numref:`sec_word2vec_data`\"\"\"\n", " # Exclude\n", " self.population = list(range(1, len(sampling_weights) + 1))\n", " self.sampling_weights = sampling_weights\n", " self.candidates = []\n", " self.i = 0\n", "\n", " def draw(self):\n", " if self.i == len(self.candidates):\n", " # 缓存k个随机采样结果\n", " self.candidates = random.choices(\n", " self.population, self.sampling_weights, k=10000)\n", " self.i = 0\n", " self.i += 1\n", " return self.candidates[self.i - 1]\n", "\n", "generator = RandomGenerator([2, 3, 4])\n", "[generator.draw() for _ in range(10)]\n", "\n", "def get_negatives(all_contexts, vocab, counter, K):\n", " \"\"\"返回负采样中的噪声词\n", "\n", " Defined in :numref:`sec_word2vec_data`\"\"\"\n", " # 索引为1、2、...(索引0是词表中排除的未知标记)\n", " sampling_weights = [counter[vocab.to_tokens(i)]**0.75\n", " for i in range(1, len(vocab))]\n", " all_negatives, generator = [], RandomGenerator(sampling_weights)\n", " for contexts in all_contexts:\n", " negatives = []\n", " while len(negatives) < len(contexts) * K:\n", " neg = generator.draw()\n", " # 噪声词不能是上下文词\n", " if neg not in contexts:\n", " negatives.append(neg)\n", " all_negatives.append(negatives)\n", " return all_negatives\n", "\n", "def batchify(data):\n", " \"\"\"返回带有负采样的跳元模型的小批量样本\n", "\n", " Defined in :numref:`sec_word2vec_data`\"\"\"\n", " max_len = max(len(c) + len(n) for _, c, n in data)\n", " centers, contexts_negatives, masks, labels = [], [], [], []\n", " for center, context, negative in data:\n", " cur_len = len(context) + len(negative)\n", " centers += [center]\n", " contexts_negatives += \\\n", " [context + negative + [0] * (max_len - cur_len)]\n", " masks += [[1] * 
cur_len + [0] * (max_len - cur_len)]\n", " labels += [[1] * len(context) + [0] * (max_len - len(context))]\n", " return (d2l.reshape(d2l.tensor(centers), (-1, 1)), d2l.tensor(\n", " contexts_negatives), d2l.tensor(masks), d2l.tensor(labels))\n", "\n", "def load_data_ptb(batch_size, max_window_size, num_noise_words):\n", " \"\"\"下载PTB数据集,然后将其加载到内存中\n", "\n", " Defined in :numref:`subsec_word2vec-minibatch-loading`\"\"\"\n", " num_workers = d2l.get_dataloader_workers()\n", " sentences = read_ptb()\n", " vocab = d2l.Vocab(sentences, min_freq=10)\n", " subsampled, counter = subsample(sentences, vocab)\n", " corpus = [vocab[line] for line in subsampled]\n", " all_centers, all_contexts = get_centers_and_contexts(\n", " corpus, max_window_size)\n", " all_negatives = get_negatives(\n", " all_contexts, vocab, counter, num_noise_words)\n", "\n", " class PTBDataset(torch.utils.data.Dataset):\n", " def __init__(self, centers, contexts, negatives):\n", " assert len(centers) == len(contexts) == len(negatives)\n", " self.centers = centers\n", " self.contexts = contexts\n", " self.negatives = negatives\n", "\n", " def __getitem__(self, index):\n", " return (self.centers[index], self.contexts[index],\n", " self.negatives[index])\n", "\n", " def __len__(self):\n", " return len(self.centers)\n", "\n", " dataset = PTBDataset(all_centers, all_contexts, all_negatives)\n", "\n", " data_iter = torch.utils.data.DataLoader(\n", " dataset, batch_size, shuffle=True,\n", " collate_fn=batchify, num_workers=num_workers)\n", " return data_iter, vocab\n", "\n", "d2l.DATA_HUB['glove.6b.50d'] = (d2l.DATA_URL + 'glove.6B.50d.zip',\n", " '0b8703943ccdb6eb788e6f091b8946e82231bc4d')\n", "\n", "d2l.DATA_HUB['glove.6b.100d'] = (d2l.DATA_URL + 'glove.6B.100d.zip',\n", " 'cd43bfb07e44e6f27cbcc7bc9ae3d80284fdaf5a')\n", "\n", "d2l.DATA_HUB['glove.42b.300d'] = (d2l.DATA_URL + 'glove.42B.300d.zip',\n", " 'b5116e234e9eb9076672cfeabf5469f3eec904fa')\n", "\n", "d2l.DATA_HUB['wiki.en'] = (d2l.DATA_URL + 
'wiki.en.zip',\n", " 'c1816da3821ae9f43899be655002f6c723e91b88')\n", "\n", "class TokenEmbedding:\n", " \"\"\"GloVe嵌入\"\"\"\n", " def __init__(self, embedding_name):\n", " \"\"\"Defined in :numref:`sec_synonyms`\"\"\"\n", " self.idx_to_token, self.idx_to_vec = self._load_embedding(\n", " embedding_name)\n", " self.unknown_idx = 0\n", " self.token_to_idx = {token: idx for idx, token in\n", " enumerate(self.idx_to_token)}\n", "\n", " def _load_embedding(self, embedding_name):\n", " idx_to_token, idx_to_vec = [''], []\n", " data_dir = d2l.download_extract(embedding_name)\n", " # GloVe网站:https://nlp.stanford.edu/projects/glove/\n", " # fastText网站:https://fasttext.cc/\n", " with open(os.path.join(data_dir, 'vec.txt'), 'r') as f:\n", " for line in f:\n", " elems = line.rstrip().split(' ')\n", " token, elems = elems[0], [float(elem) for elem in elems[1:]]\n", " # 跳过标题信息,例如fastText中的首行\n", " if len(elems) > 1:\n", " idx_to_token.append(token)\n", " idx_to_vec.append(elems)\n", " idx_to_vec = [[0] * len(idx_to_vec[0])] + idx_to_vec\n", " return idx_to_token, d2l.tensor(idx_to_vec)\n", "\n", " def __getitem__(self, tokens):\n", " indices = [self.token_to_idx.get(token, self.unknown_idx)\n", " for token in tokens]\n", " vecs = self.idx_to_vec[d2l.tensor(indices)]\n", " return vecs\n", "\n", " def __len__(self):\n", " return len(self.idx_to_token)\n", "\n", "def get_tokens_and_segments(tokens_a, tokens_b=None):\n", " \"\"\"获取输入序列的词元及其片段索引\n", "\n", " Defined in :numref:`sec_bert`\"\"\"\n", " tokens = [''] + tokens_a + ['']\n", " # 0和1分别标记片段A和B\n", " segments = [0] * (len(tokens_a) + 2)\n", " if tokens_b is not None:\n", " tokens += tokens_b + ['']\n", " segments += [1] * (len(tokens_b) + 1)\n", " return tokens, segments\n", "\n", "class BERTEncoder(nn.Module):\n", " \"\"\"BERT编码器\n", "\n", " Defined in :numref:`subsec_bert_input_rep`\"\"\"\n", " def __init__(self, vocab_size, num_hiddens, norm_shape, ffn_num_input,\n", " ffn_num_hiddens, num_heads, num_layers, dropout,\n", " 
max_len=1000, key_size=768, query_size=768, value_size=768,\n", " **kwargs):\n", " super(BERTEncoder, self).__init__(**kwargs)\n", " self.token_embedding = nn.Embedding(vocab_size, num_hiddens)\n", " self.segment_embedding = nn.Embedding(2, num_hiddens)\n", " self.blks = nn.Sequential()\n", " for i in range(num_layers):\n", " self.blks.add_module(f\"{i}\", d2l.EncoderBlock(\n", " key_size, query_size, value_size, num_hiddens, norm_shape,\n", " ffn_num_input, ffn_num_hiddens, num_heads, dropout, True))\n", " # 在BERT中,位置嵌入是可学习的,因此我们创建一个足够长的位置嵌入参数\n", " self.pos_embedding = nn.Parameter(torch.randn(1, max_len,\n", " num_hiddens))\n", "\n", " def forward(self, tokens, segments, valid_lens):\n", " # 在以下代码段中,X的形状保持不变:(批量大小,最大序列长度,num_hiddens)\n", " X = self.token_embedding(tokens) + self.segment_embedding(segments)\n", " # Slice the Parameter itself, not `.data`: `.data` bypasses autograd,\n", " # so the learnable positional embedding would never receive gradients.\n", " X = X + self.pos_embedding[:, :X.shape[1], :]\n", " for blk in self.blks:\n", " X = blk(X, valid_lens)\n", " return X\n", "\n", "class MaskLM(nn.Module):\n", " \"\"\"BERT的掩蔽语言模型任务\n", "\n", " Defined in :numref:`subsec_bert_input_rep`\"\"\"\n", " def __init__(self, vocab_size, num_hiddens, num_inputs=768, **kwargs):\n", " super(MaskLM, self).__init__(**kwargs)\n", " self.mlp = nn.Sequential(nn.Linear(num_inputs, num_hiddens),\n", " nn.ReLU(),\n", " nn.LayerNorm(num_hiddens),\n", " nn.Linear(num_hiddens, vocab_size))\n", "\n", " def forward(self, X, pred_positions):\n", " num_pred_positions = pred_positions.shape[1]\n", " pred_positions = pred_positions.reshape(-1)\n", " batch_size = X.shape[0]\n", " batch_idx = torch.arange(0, batch_size)\n", " # 假设batch_size=2,num_pred_positions=3\n", " # 那么batch_idx是np.array([0,0,0,1,1,1])\n", " batch_idx = torch.repeat_interleave(batch_idx, num_pred_positions)\n", " masked_X = X[batch_idx, pred_positions]\n", " masked_X = masked_X.reshape((batch_size, num_pred_positions, -1))\n", " mlm_Y_hat = self.mlp(masked_X)\n", " return mlm_Y_hat\n", "\n", "class NextSentencePred(nn.Module):\n", " \"\"\"BERT的下一句预测任务\n", 
"\n", " Defined in :numref:`subsec_mlm`\"\"\"\n", " def __init__(self, num_inputs, **kwargs):\n", " super(NextSentencePred, self).__init__(**kwargs)\n", " self.output = nn.Linear(num_inputs, 2)\n", "\n", " def forward(self, X):\n", " # X的形状:(batchsize,num_hiddens)\n", " return self.output(X)\n", "\n", "class BERTModel(nn.Module):\n", " \"\"\"BERT模型\n", "\n", " Defined in :numref:`subsec_nsp`\"\"\"\n", " def __init__(self, vocab_size, num_hiddens, norm_shape, ffn_num_input,\n", " ffn_num_hiddens, num_heads, num_layers, dropout,\n", " max_len=1000, key_size=768, query_size=768, value_size=768,\n", " hid_in_features=768, mlm_in_features=768,\n", " nsp_in_features=768):\n", " super(BERTModel, self).__init__()\n", " self.encoder = BERTEncoder(vocab_size, num_hiddens, norm_shape,\n", " ffn_num_input, ffn_num_hiddens, num_heads, num_layers,\n", " dropout, max_len=max_len, key_size=key_size,\n", " query_size=query_size, value_size=value_size)\n", " self.hidden = nn.Sequential(nn.Linear(hid_in_features, num_hiddens),\n", " nn.Tanh())\n", " self.mlm = MaskLM(vocab_size, num_hiddens, mlm_in_features)\n", " self.nsp = NextSentencePred(nsp_in_features)\n", "\n", " def forward(self, tokens, segments, valid_lens=None,\n", " pred_positions=None):\n", " encoded_X = self.encoder(tokens, segments, valid_lens)\n", " if pred_positions is not None:\n", " mlm_Y_hat = self.mlm(encoded_X, pred_positions)\n", " else:\n", " mlm_Y_hat = None\n", " # 用于下一句预测的多层感知机分类器的隐藏层,0是“”标记的索引\n", " nsp_Y_hat = self.nsp(self.hidden(encoded_X[:, 0, :]))\n", " return encoded_X, mlm_Y_hat, nsp_Y_hat\n", "\n", "d2l.DATA_HUB['wikitext-2'] = (\n", " 'https://s3.amazonaws.com/research.metamind.io/wikitext/'\n", " 'wikitext-2-v1.zip', '3c914d17d80b1459be871a5039ac23e752a53cbe')\n", "\n", "def _read_wiki(data_dir):\n", " \"\"\"Defined in :numref:`sec_bert-dataset`\"\"\"\n", " file_name = os.path.join(data_dir, 'wiki.train.tokens')\n", " with open(file_name, 'r') as f:\n", " lines = f.readlines()\n", " # 
大写字母转换为小写字母\n", " paragraphs = [line.strip().lower().split(' . ')\n", " for line in lines if len(line.split(' . ')) >= 2]\n", " random.shuffle(paragraphs)\n", " return paragraphs\n", "\n", "def _get_next_sentence(sentence, next_sentence, paragraphs):\n", " \"\"\"Defined in :numref:`sec_bert-dataset`\"\"\"\n", " if random.random() < 0.5:\n", " is_next = True\n", " else:\n", " # paragraphs是三重列表的嵌套\n", " next_sentence = random.choice(random.choice(paragraphs))\n", " is_next = False\n", " return sentence, next_sentence, is_next\n", "\n", "def _get_nsp_data_from_paragraph(paragraph, paragraphs, vocab, max_len):\n", " \"\"\"Defined in :numref:`sec_bert-dataset`\"\"\"\n", " nsp_data_from_paragraph = []\n", " for i in range(len(paragraph) - 1):\n", " tokens_a, tokens_b, is_next = _get_next_sentence(\n", " paragraph[i], paragraph[i + 1], paragraphs)\n", " # 考虑1个''词元和2个''词元\n", " if len(tokens_a) + len(tokens_b) + 3 > max_len:\n", " continue\n", " tokens, segments = d2l.get_tokens_and_segments(tokens_a, tokens_b)\n", " nsp_data_from_paragraph.append((tokens, segments, is_next))\n", " return nsp_data_from_paragraph\n", "\n", "def _replace_mlm_tokens(tokens, candidate_pred_positions, num_mlm_preds,\n", " vocab):\n", " \"\"\"Defined in :numref:`sec_bert-dataset`\"\"\"\n", " # 为遮蔽语言模型的输入创建新的词元副本,其中输入可能包含替换的“”或随机词元\n", " mlm_input_tokens = [token for token in tokens]\n", " pred_positions_and_labels = []\n", " # 打乱后用于在遮蔽语言模型任务中获取15%的随机词元进行预测\n", " random.shuffle(candidate_pred_positions)\n", " for mlm_pred_position in candidate_pred_positions:\n", " if len(pred_positions_and_labels) >= num_mlm_preds:\n", " break\n", " masked_token = None\n", " # 80%的时间:将词替换为“”词元\n", " if random.random() < 0.8:\n", " masked_token = ''\n", " else:\n", " # 10%的时间:保持词不变\n", " if random.random() < 0.5:\n", " masked_token = tokens[mlm_pred_position]\n", " # 10%的时间:用随机词替换该词\n", " else:\n", " masked_token = random.choice(vocab.idx_to_token)\n", " mlm_input_tokens[mlm_pred_position] = masked_token\n", " 
pred_positions_and_labels.append(\n", " (mlm_pred_position, tokens[mlm_pred_position]))\n", " return mlm_input_tokens, pred_positions_and_labels\n", "\n", "def _get_mlm_data_from_tokens(tokens, vocab):\n", " \"\"\"Defined in :numref:`subsec_prepare_mlm_data`\"\"\"\n", " candidate_pred_positions = []\n", " # tokens是一个字符串列表\n", " for i, token in enumerate(tokens):\n", " # 在遮蔽语言模型任务中不会预测特殊词元\n", " if token in ['', '']:\n", " continue\n", " candidate_pred_positions.append(i)\n", " # 遮蔽语言模型任务中预测15%的随机词元\n", " num_mlm_preds = max(1, round(len(tokens) * 0.15))\n", " mlm_input_tokens, pred_positions_and_labels = _replace_mlm_tokens(\n", " tokens, candidate_pred_positions, num_mlm_preds, vocab)\n", " pred_positions_and_labels = sorted(pred_positions_and_labels,\n", " key=lambda x: x[0])\n", " pred_positions = [v[0] for v in pred_positions_and_labels]\n", " mlm_pred_labels = [v[1] for v in pred_positions_and_labels]\n", " return vocab[mlm_input_tokens], pred_positions, vocab[mlm_pred_labels]\n", "\n", "def _pad_bert_inputs(examples, max_len, vocab):\n", " \"\"\"Defined in :numref:`subsec_prepare_mlm_data`\"\"\"\n", " max_num_mlm_preds = round(max_len * 0.15)\n", " all_token_ids, all_segments, valid_lens, = [], [], []\n", " all_pred_positions, all_mlm_weights, all_mlm_labels = [], [], []\n", " nsp_labels = []\n", " for (token_ids, pred_positions, mlm_pred_label_ids, segments,\n", " is_next) in examples:\n", " all_token_ids.append(torch.tensor(token_ids + [vocab['']] * (\n", " max_len - len(token_ids)), dtype=torch.long))\n", " all_segments.append(torch.tensor(segments + [0] * (\n", " max_len - len(segments)), dtype=torch.long))\n", " # valid_lens不包括''的计数\n", " valid_lens.append(torch.tensor(len(token_ids), dtype=torch.float32))\n", " all_pred_positions.append(torch.tensor(pred_positions + [0] * (\n", " max_num_mlm_preds - len(pred_positions)), dtype=torch.long))\n", " # 填充词元的预测将通过乘以0权重在损失中过滤掉\n", " all_mlm_weights.append(\n", " torch.tensor([1.0] * len(mlm_pred_label_ids) + 
[0.0] * (\n", " max_num_mlm_preds - len(pred_positions)),\n", " dtype=torch.float32))\n", " all_mlm_labels.append(torch.tensor(mlm_pred_label_ids + [0] * (\n", " max_num_mlm_preds - len(mlm_pred_label_ids)), dtype=torch.long))\n", " nsp_labels.append(torch.tensor(is_next, dtype=torch.long))\n", " return (all_token_ids, all_segments, valid_lens, all_pred_positions,\n", " all_mlm_weights, all_mlm_labels, nsp_labels)\n", "\n", "class _WikiTextDataset(torch.utils.data.Dataset):\n", " \"\"\"Defined in :numref:`subsec_prepare_mlm_data`\"\"\"\n", " def __init__(self, paragraphs, max_len):\n", " # 输入paragraphs[i]是代表段落的句子字符串列表;\n", " # 而输出paragraphs[i]是代表段落的句子列表,其中每个句子都是词元列表\n", " paragraphs = [d2l.tokenize(\n", " paragraph, token='word') for paragraph in paragraphs]\n", " sentences = [sentence for paragraph in paragraphs\n", " for sentence in paragraph]\n", " self.vocab = d2l.Vocab(sentences, min_freq=5, reserved_tokens=[\n", " '', '', '', ''])\n", " # 获取下一句子预测任务的数据\n", " examples = []\n", " for paragraph in paragraphs:\n", " examples.extend(_get_nsp_data_from_paragraph(\n", " paragraph, paragraphs, self.vocab, max_len))\n", " # 获取遮蔽语言模型任务的数据\n", " examples = [(_get_mlm_data_from_tokens(tokens, self.vocab)\n", " + (segments, is_next))\n", " for tokens, segments, is_next in examples]\n", " # 填充输入\n", " (self.all_token_ids, self.all_segments, self.valid_lens,\n", " self.all_pred_positions, self.all_mlm_weights,\n", " self.all_mlm_labels, self.nsp_labels) = _pad_bert_inputs(\n", " examples, max_len, self.vocab)\n", "\n", " def __getitem__(self, idx):\n", " return (self.all_token_ids[idx], self.all_segments[idx],\n", " self.valid_lens[idx], self.all_pred_positions[idx],\n", " self.all_mlm_weights[idx], self.all_mlm_labels[idx],\n", " self.nsp_labels[idx])\n", "\n", " def __len__(self):\n", " return len(self.all_token_ids)\n", "\n", "def load_data_wiki(batch_size, max_len):\n", " \"\"\"加载WikiText-2数据集\n", "\n", " Defined in :numref:`subsec_prepare_mlm_data`\"\"\"\n", " 
num_workers = d2l.get_dataloader_workers()\n", " data_dir = d2l.download_extract('wikitext-2', 'wikitext-2')\n", " paragraphs = _read_wiki(data_dir)\n", " train_set = _WikiTextDataset(paragraphs, max_len)\n", " train_iter = torch.utils.data.DataLoader(train_set, batch_size,\n", " shuffle=True, num_workers=num_workers)\n", " return train_iter, train_set.vocab\n", "\n", "def _get_batch_loss_bert(net, loss, vocab_size, tokens_X,\n", " segments_X, valid_lens_x,\n", " pred_positions_X, mlm_weights_X,\n", " mlm_Y, nsp_y):\n", " \"\"\"Defined in :numref:`sec_bert-pretraining`\"\"\"\n", " # 前向传播\n", " _, mlm_Y_hat, nsp_Y_hat = net(tokens_X, segments_X,\n", " valid_lens_x.reshape(-1),\n", " pred_positions_X)\n", " # 计算遮蔽语言模型损失\n", " # Weight each prediction by mlm_weights_X (0 for padded slots). The weights\n", " # are flattened to 1-D so the product stays elementwise when `loss` uses\n", " # reduction='none'; reshape(-1, 1) would broadcast to an (N, N) matrix.\n", " # For a mean-reduced (scalar) loss the result is unchanged.\n", " mlm_l = loss(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1)) *\\\n", " mlm_weights_X.reshape(-1)\n", " mlm_l = mlm_l.sum() / (mlm_weights_X.sum() + 1e-8)\n", " # 计算下一句子预测任务的损失\n", " # .mean() is a no-op for a scalar loss and averages a per-example loss,\n", " # so the returned total is always a scalar suitable for backward().\n", " nsp_l = loss(nsp_Y_hat, nsp_y).mean()\n", " l = mlm_l + nsp_l\n", " return mlm_l, nsp_l, l\n", "\n", "d2l.DATA_HUB['aclImdb'] = (\n", " 'http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz',\n", " '01ada507287d82875905620988597833ad4e0903')\n", "\n", "def read_imdb(data_dir, is_train):\n", " \"\"\"读取IMDb评论数据集文本序列和标签\n", "\n", " Defined in :numref:`sec_sentiment`\"\"\"\n", " data, labels = [], []\n", " for label in ('pos', 'neg'):\n", " folder_name = os.path.join(data_dir, 'train' if is_train else 'test',\n", " label)\n", " for file in os.listdir(folder_name):\n", " with open(os.path.join(folder_name, file), 'rb') as f:\n", " review = f.read().decode('utf-8').replace('\\n', '')\n", " data.append(review)\n", " labels.append(1 if label == 'pos' else 0)\n", " return data, labels\n", "\n", "def load_data_imdb(batch_size, num_steps=500):\n", " \"\"\"返回数据迭代器和IMDb评论数据集的词表\n", "\n", " Defined in :numref:`sec_sentiment`\"\"\"\n", " data_dir = d2l.download_extract('aclImdb', 'aclImdb')\n", " train_data = read_imdb(data_dir, True)\n", " test_data = 
read_imdb(data_dir, False)\n", " train_tokens = d2l.tokenize(train_data[0], token='word')\n", " test_tokens = d2l.tokenize(test_data[0], token='word')\n", " vocab = d2l.Vocab(train_tokens, min_freq=5)\n", " train_features = torch.tensor([d2l.truncate_pad(\n", " vocab[line], num_steps, vocab['']) for line in train_tokens])\n", " test_features = torch.tensor([d2l.truncate_pad(\n", " vocab[line], num_steps, vocab['']) for line in test_tokens])\n", " train_iter = d2l.load_array((train_features, torch.tensor(train_data[1])),\n", " batch_size)\n", " test_iter = d2l.load_array((test_features, torch.tensor(test_data[1])),\n", " batch_size,\n", " is_train=False)\n", " return train_iter, test_iter, vocab\n", "\n", "def predict_sentiment(net, vocab, sequence):\n", " \"\"\"预测文本序列的情感\n", "\n", " Defined in :numref:`sec_sentiment_rnn`\"\"\"\n", " sequence = torch.tensor(vocab[sequence.split()], device=d2l.try_gpu())\n", " label = torch.argmax(net(sequence.reshape(1, -1)), dim=1)\n", " return 'positive' if label == 1 else 'negative'\n", "\n", "d2l.DATA_HUB['SNLI'] = (\n", " 'https://nlp.stanford.edu/projects/snli/snli_1.0.zip',\n", " '9fcde07509c7e87ec61c640c1b2753d9041758e4')\n", "\n", "def read_snli(data_dir, is_train):\n", " \"\"\"将SNLI数据集解析为前提、假设和标签\n", "\n", " Defined in :numref:`sec_natural-language-inference-and-dataset`\"\"\"\n", " def extract_text(s):\n", " # 删除我们不会使用的信息\n", " s = re.sub('\\\\(', '', s)\n", " s = re.sub('\\\\)', '', s)\n", " # 用一个空格替换两个或多个连续的空格\n", " s = re.sub('\\\\s{2,}', ' ', s)\n", " return s.strip()\n", " label_set = {'entailment': 0, 'contradiction': 1, 'neutral': 2}\n", " file_name = os.path.join(data_dir, 'snli_1.0_train.txt'\n", " if is_train else 'snli_1.0_test.txt')\n", " with open(file_name, 'r') as f:\n", " rows = [row.split('\\t') for row in f.readlines()[1:]]\n", " premises = [extract_text(row[1]) for row in rows if row[0] in label_set]\n", " hypotheses = [extract_text(row[2]) for row in rows if row[0] \\\n", " in label_set]\n", " labels = 
[label_set[row[0]] for row in rows if row[0] in label_set]\n", " return premises, hypotheses, labels\n", "\n", "class SNLIDataset(torch.utils.data.Dataset):\n", " \"\"\"用于加载SNLI数据集的自定义数据集\n", "\n", " Defined in :numref:`sec_natural-language-inference-and-dataset`\"\"\"\n", " def __init__(self, dataset, num_steps, vocab=None):\n", " self.num_steps = num_steps\n", " all_premise_tokens = d2l.tokenize(dataset[0])\n", " all_hypothesis_tokens = d2l.tokenize(dataset[1])\n", " if vocab is None:\n", " self.vocab = d2l.Vocab(all_premise_tokens + \\\n", " all_hypothesis_tokens, min_freq=5, reserved_tokens=[''])\n", " else:\n", " self.vocab = vocab\n", " self.premises = self._pad(all_premise_tokens)\n", " self.hypotheses = self._pad(all_hypothesis_tokens)\n", " self.labels = torch.tensor(dataset[2])\n", " print('read ' + str(len(self.premises)) + ' examples')\n", "\n", " def _pad(self, lines):\n", " return torch.tensor([d2l.truncate_pad(\n", " self.vocab[line], self.num_steps, self.vocab[''])\n", " for line in lines])\n", "\n", " def __getitem__(self, idx):\n", " return (self.premises[idx], self.hypotheses[idx]), self.labels[idx]\n", "\n", " def __len__(self):\n", " return len(self.premises)\n", "\n", "def load_data_snli(batch_size, num_steps=50):\n", " \"\"\"下载SNLI数据集并返回数据迭代器和词表\n", "\n", " Defined in :numref:`sec_natural-language-inference-and-dataset`\"\"\"\n", " num_workers = d2l.get_dataloader_workers()\n", " data_dir = d2l.download_extract('SNLI')\n", " train_data = read_snli(data_dir, True)\n", " test_data = read_snli(data_dir, False)\n", " train_set = SNLIDataset(train_data, num_steps)\n", " test_set = SNLIDataset(test_data, num_steps, train_set.vocab)\n", " train_iter = torch.utils.data.DataLoader(train_set, batch_size,\n", " shuffle=True,\n", " num_workers=num_workers)\n", " test_iter = torch.utils.data.DataLoader(test_set, batch_size,\n", " shuffle=False,\n", " num_workers=num_workers)\n", " return train_iter, test_iter, train_set.vocab\n", "\n", "def 
predict_snli(net, vocab, premise, hypothesis):\n", " \"\"\"预测前提和假设之间的逻辑关系\n", "\n", " Defined in :numref:`sec_natural-language-inference-attention`\"\"\"\n", " net.eval()\n", " premise = torch.tensor(vocab[premise], device=d2l.try_gpu())\n", " hypothesis = torch.tensor(vocab[hypothesis], device=d2l.try_gpu())\n", " label = torch.argmax(net([premise.reshape((1, -1)),\n", " hypothesis.reshape((1, -1))]), dim=1)\n", " return 'entailment' if label == 0 else 'contradiction' if label == 1 \\\n", " else 'neutral'\n", "\n", "\n", "# Alias defined in config.ini\n", "nn_Module = nn.Module\n", "\n", "ones = torch.ones\n", "zeros = torch.zeros\n", "tensor = torch.tensor\n", "arange = torch.arange\n", "meshgrid = torch.meshgrid\n", "sin = torch.sin\n", "sinh = torch.sinh\n", "cos = torch.cos\n", "cosh = torch.cosh\n", "tanh = torch.tanh\n", "linspace = torch.linspace\n", "exp = torch.exp\n", "log = torch.log\n", "normal = torch.normal\n", "rand = torch.rand\n", "randn = torch.randn\n", "matmul = torch.matmul\n", "int32 = torch.int32\n", "float32 = torch.float32\n", "concat = torch.cat\n", "stack = torch.stack\n", "abs = torch.abs\n", "eye = torch.eye\n", "numpy = lambda x, *args, **kwargs: x.detach().numpy(*args, **kwargs)\n", "size = lambda x, *args, **kwargs: x.numel(*args, **kwargs)\n", "reshape = lambda x, *args, **kwargs: x.reshape(*args, **kwargs)\n", "to = lambda x, *args, **kwargs: x.to(*args, **kwargs)\n", "reduce_sum = lambda x, *args, **kwargs: x.sum(*args, **kwargs)\n", "argmax = lambda x, *args, **kwargs: x.argmax(*args, **kwargs)\n", "astype = lambda x, *args, **kwargs: x.type(*args, **kwargs)\n", "transpose = lambda x, *args, **kwargs: x.t(*args, **kwargs)\n", "reduce_mean = lambda x, *args, **kwargs: x.mean(*args, **kwargs)\n" ] }, { "cell_type": "code", "execution_count": 37, "id": "0acca715-0672-4a59-b615-613d3be468c2", "metadata": { "tags": [] }, "outputs": [], "source": [ "batch_size, num_steps = 32, 35\n", "train_iter, vocab = 
load_data_time_machine(batch_size, num_steps)" ] }, { "cell_type": "code", "execution_count": 39, "id": "c289c562-2a91-4ba8-90a9-16cb266d4514", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([32, 35]) torch.Size([32, 35])\n", "torch.Size([32, 35]) torch.Size([32, 35])\n", "torch.Size([32, 35]) torch.Size([32, 35])\n", "torch.Size([32, 35]) torch.Size([32, 35])\n", "torch.Size([32, 35]) torch.Size([32, 35])\n", "torch.Size([32, 35]) torch.Size([32, 35])\n", "torch.Size([32, 35]) torch.Size([32, 35])\n", "torch.Size([32, 35]) torch.Size([32, 35])\n" ] } ], "source": [ "for X, Y in train_iter:\n", " print(X.shape,Y.shape)" ] }, { "cell_type": "code", "execution_count": 23, "id": "a415ad0f-ebe8-4650-9f2f-026118b2c1db", "metadata": { "tags": [] }, "outputs": [], "source": [ "import numpy as np\n", "import sys\n", "import torch.nn as nn\n", "import torch\n", "import warnings\n", "sys.path.append('/home/jovyan/work/d2l_solutions/notebooks/exercises/d2l_utils/')\n", "import d2l\n", "from torchsummary import summary\n", "from torch.nn import functional as F\n", "from sklearn.model_selection import ParameterGrid\n", "warnings.filterwarnings(\"ignore\")\n", "\n", "class Data(d2l.DataModule):\n", " def __init__(self, batch_size=320, T=1000, num_train=600, tau=4, randn=0.2):\n", " self.save_hyperparameters()\n", " # torch.range is deprecated (end-inclusive); arange(1, T + 1) gives the same 1..T\n", " self.time = torch.arange(1, T + 1, dtype=torch.float32)\n", " self.x = torch.sin(0.01*self.time) + torch.randn(T)*randn\n", " \n", " def get_dataloader(self, train):\n", " features = [self.x[i:self.T-self.tau+i] for i in range(self.tau)]\n", " labels = [self.x[i:self.T-self.tau+i] for i in range(1,self.tau+1)]\n", " self.features = torch.stack(features, 1).unsqueeze(dim=-1)#.swapaxes(0, 1)\n", " self.labels = torch.stack(labels, 1).unsqueeze(dim=-1)#.swapaxes(0, 1)\n", " i = slice(0, self.num_train) if train else slice(self.num_train, None)\n", " return self.get_tensorloader([self.features, self.labels], train, 
i)\n", " \n", "class RNN(d2l.Module): #@save\n", " \"\"\"The RNN model implemented with high-level APIs.\"\"\"\n", " def __init__(self, num_inputs, num_hiddens):\n", " super().__init__()\n", " self.save_hyperparameters()\n", " self.rnn = nn.RNN(num_inputs, num_hiddens)\n", "\n", " def forward(self, inputs, H=None):\n", " return self.rnn(inputs, H)\n", " \n", "class RNNAutoRegression(d2l.LinearRegression): #@save\n", " \"\"\"The RNN-based language model implemented with high-level APIs.\"\"\"\n", " def init_params(self):\n", " self.linear = nn.LazyLinear(1)\n", "\n", " # def output_layer(self, hiddens):\n", " # return self.linear(hiddens).swapaxes(0, 1)\n", " \n", " def __init__(self, rnn,lr=0.01, tau=4, plot_flag=True, emb_len=8):\n", " super().__init__(lr=lr)\n", " self.save_hyperparameters()\n", " self.init_params() \n", "\n", " def forward(self, X, state=None):\n", " rnn_outputs, _ = self.rnn(X.swapaxes(0, 1), state)\n", " outputs = [self.linear(H) for H in rnn_outputs]\n", " return torch.stack(outputs, 1)" ] }, { "cell_type": "code", "execution_count": 24, "id": "c0bd1619-6165-486d-889b-5ccff0bbfc1e", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "(0.15221456438302994, 0.1405056193470955)" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-10-15T08:31:02.995493\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.4.0, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "tau=4\n", "data = Data(tau=tau)\n", "rnn = RNN(num_inputs=1, num_hiddens=8)\n", "model = RNNAutoRegression(rnn=rnn, lr=0.01)\n", "trainer = d2l.Trainer(max_epochs=20)\n", "trainer.fit(model, data)" ] }, { "cell_type": "code", "execution_count": 27, "id": "65b9d105-2b6e-49db-9da2-47e7c16381c5", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "torch.Size([996, 4, 1])" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data.features.shape" ] }, { "cell_type": "code", "execution_count": 11, "id": "fb5122c0-af18-4c44-b87e-08c7d88d48c5", "metadata": { "tags": [] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-10-15T08:20:26.160931\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.4.0, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "onestep_preds = model(data.features).detach().numpy()\n", "d2l.plot(data.time[data.tau:], [data.x[data.tau:], onestep_preds[:,-1].reshape(-1)], 'time', 'x',\n", " legend=['labels', '1-step preds'], figsize=(6, 3))" ] }, { "cell_type": "code", "execution_count": 25, "id": "592345c5-2cf5-4a59-9905-e4d2d875ccaf", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "torch.Size([4, 1, 1])" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "temp.shape" ] }, { "cell_type": "code", "execution_count": 35, "id": "3061430a-2353-4a90-83a1-f37a23f8e238", "metadata": { "tags": [] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-10-15T08:38:18.728511\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.4.0, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "multistep_preds = torch.zeros(data.T)\n", "multistep_preds[:] = data.x\n", "for i in range(data.num_train + data.tau, data.T):\n", " temp = model(\n", " multistep_preds[i - data.tau:i].reshape((1,data.tau,1)))\n", " # print(temp.shape,temp[:,-1].item())\n", " multistep_preds[i] = temp[:,-1].item()\n", "multistep_preds = multistep_preds.detach().numpy()\n", "d2l.plot([data.time[data.tau:], data.time[data.num_train+data.tau:]],\n", " [onestep_preds[:,-1].reshape(-1), multistep_preds[data.num_train+data.tau:]], 'time',\n", " 'x', legend=['1-step preds', 'multistep preds'], figsize=(6, 3))" ] } ], "metadata": { "kernelspec": { "display_name": "Python [conda env:d2l]", "language": "python", "name": "conda-env-d2l-py" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.4" } }, "nbformat": 4, "nbformat_minor": 5 }