{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Reading and Storing\n", "\n", "So far, we have covered how to process data and how to build, train, and test deep learning models. In practice, however, we sometimes need to deploy a trained model to many different devices. In such cases, we can store the trained model parameters held in memory on disk so that they can be read back later.\n", "\n", "## Reading and Writing `Tensor`s\n", "\n", "We can store and read `Tensor`s directly with the `torch.save` and `torch.load` functions, respectively. The following example creates a `Tensor` variable `x` and stores it in a file that is also named `x`." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import torch\n", "from torch import nn\n", "\n", "# Create a tensor of three ones and save it to a file named 'x'\n", "x = torch.ones(3)\n", "torch.save(x, 'x')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We then read the data from the stored file back into memory." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([1., 1., 1.])" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x2 = torch.load('x')\n", "x2" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can also store a list of `Tensor`s and read it back into memory." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor([1., 1., 1.]), tensor([0., 0., 0., 0.]))" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y = torch.zeros(4)\n", "# Save both tensors to a single file as a list, then unpack them on loading\n", "torch.save([x, y], 'xy')\n", "x2, y2 = torch.load('xy')\n", "(x2, y2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can even store and read a dictionary that maps strings to `Tensor`s." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'x': tensor([1., 1., 1.]), 'y': tensor([0., 0., 0., 0.])}" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "mydict = {'x': x, 'y': y}\n", "torch.save(mydict, 'mydict')\n", "mydict2 = torch.load('mydict')\n", "mydict2" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Reading and Writing Model Parameters\n", "\n", "Besides individual `Tensor`s, we can also read and write the parameters of a model. A model's `state_dict` is a dictionary that maps the name of each parameter (and persistent buffer) to its `Tensor`, so we can save it with `torch.save` just like the dictionary above; to restore it, the `Module` class provides the `load_state_dict` method. For demonstration, we first create a multilayer perceptron, whose parameters are randomly initialized when it is constructed, and compute its output `Y` on a random input `X`." ] }, { "cell_type": "code", "execution_count": 5, 
"metadata": {}, "outputs": [], "source": [ "class MLP(nn.Module):\n", "    def __init__(self, **kwargs):\n", "        super().__init__(**kwargs)\n", "        self.hidden = nn.Linear(20, 256)  # hidden layer: 20 inputs -> 256 units\n", "        self.activation = nn.ReLU()\n", "        self.output = nn.Linear(256, 10)  # output layer: 256 units -> 10 outputs\n", "\n", "    def forward(self, x):\n", "        return self.output(self.activation(self.hidden(x)))\n", "\n", "net = MLP()  # parameters are randomly initialized here\n", "X = torch.rand(2, 20)  # a batch of 2 random examples with 20 features each\n", "Y = net(X)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we save the model's parameters, that is, its `state_dict`, to a file named `mlp.params`." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "filename = 'mlp.params'\n", "torch.save(net.state_dict(), filename)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We then instantiate the multilayer perceptron defined above once more. Instead of keeping its randomly initialized parameters, we overwrite them with the parameters read from the file." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "net2 = MLP()\n", "net2.load_state_dict(torch.load(filename))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Because the two instances have identical model parameters, they produce the same output for the same input `X`. Let us verify this." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[True, True, True, True, True, True, True, True, True, True],\n", "        [True, True, True, True, True, True, True, True, True, True]])" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "Y2 = net2(X)\n", "# Element-wise comparison of the outputs of the two models\n", "Y2 == Y" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Summary\n", "\n", "* The `save` and `load` functions make it easy to read and write `Tensor`s.\n", "* The `load_state_dict` method makes it easy to load a model's parameters from a saved `state_dict`.\n", "\n", "## Exercises\n", "\n", "* Even if a trained model does not need to be deployed to different devices, what other practical benefits does storing model parameters offer?\n", "\n", "## Scan the QR Code to Access the [Discussion Forum](https://discuss.gluon.ai/t/topic/1255)\n", "\n", "![](../img/qr_read-write.svg)" ] } ], "metadata": { "kernelspec": { "display_name": "Python [conda env:pytorch]", "language": "python", "name": 
"conda-env-pytorch-py" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.9" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 4 }