{ "cells": [ { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "# 多 GPU 训练" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "ExecuteTime": { "end_time": "2019-07-03T22:50:25.762210Z", "start_time": "2019-07-03T22:50:22.940185Z" }, "attributes": { "classes": [], "id": "", "n": "1" } }, "outputs": [], "source": [ "import d2l\n", "from mxnet import autograd, gluon, init, np, npx\n", "from mxnet.gluon import nn\n", "npx.set_np()" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "在多 GPU 上初始化模型参数。\n" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "ExecuteTime": { "end_time": "2019-07-03T22:50:32.537906Z", "start_time": "2019-07-03T22:50:25.764263Z" }, "attributes": { "classes": [], "id": "", "n": "3" } }, "outputs": [], "source": [ "net = d2l.resnet18(10)\n", "ctx = d2l.try_all_gpus()\n", "net.initialize(init=init.Normal(sigma=0.01), ctx=ctx)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "验证：把一个小批量拆分到各个 GPU 上，分别做前向计算。" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "ExecuteTime": { "end_time": "2019-07-03T22:50:32.772457Z", "start_time": "2019-07-03T22:50:32.539859Z" }, "attributes": { "classes": [], "id": "", "n": "4" } }, "outputs": [ { "data": { "text/plain": [ "(array([[ 9.4988764e-07, 4.0808845e-06, -5.1063816e-06, -4.9375967e-06,\n", " 1.1718329e-06, -5.6178824e-06, -4.8232919e-06, 1.9737163e-06,\n", " -7.3709026e-07, 2.2256274e-06],\n", " [ 7.7096996e-07, 4.2829342e-06, -6.1890505e-06, -5.4664861e-06,\n", " 1.2786281e-06, -5.2085825e-06, -4.6386904e-06, 2.0427817e-06,\n", " -1.0129007e-06, 2.0370280e-06]], ctx=gpu(0)),\n", " array([[ 2.4921033e-07, 3.8222056e-06, -5.5915402e-06, -5.4971724e-06,\n", " 1.4587372e-06, -4.5317338e-06, -4.8936981e-06, 2.3227499e-06,\n", " -3.8662023e-07, 1.8324375e-06],\n", " [-8.7117598e-08, 3.6717909e-06, -5.0221552e-06, -5.0705357e-06,\n", " 2.1382066e-06, 
-4.9615883e-06, -4.7389462e-06, 2.3168993e-06,\n", " -4.3993109e-07, 2.1564051e-06]], ctx=gpu(1)))" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x = np.random.uniform(size=(4, 1, 28, 28))\n", "gpu_x = gluon.utils.split_and_load(x, ctx)\n", "# One forward pass per device shard, so this also works with a single device.\n", "tuple(net(x_shard) for x_shard in gpu_x)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "在多 GPU 上计算精度。" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "ExecuteTime": { "end_time": "2019-07-03T22:50:32.779294Z", "start_time": "2019-07-03T22:50:32.774319Z" }, "attributes": { "classes": [], "id": "", "n": "6" } }, "outputs": [], "source": [ "def evaluate_accuracy_gpus(net, data_iter):\n", "    \"\"\"Return the classification accuracy of `net` on `data_iter`.\n", "\n", "    Each minibatch is split across the devices that hold the model's\n", "    parameters, a forward pass is run on every shard, and the per-shard\n", "    counts of correct predictions are summed and divided by the total\n", "    number of examples seen.\n", "    \"\"\"\n", "    # Devices the model lives on, read from its first parameter.\n", "    ctx_list = list(net.collect_params().values())[0].list_ctx()\n", "    metric = d2l.Accumulator(2)  # (num. correct predictions, num. examples)\n", "    for features, labels in data_iter:\n", "        Xs, ys = d2l.split_batch(features, labels, ctx_list)\n", "        pys = [net(X) for X in Xs]  # one forward pass per device, run in parallel\n", "        metric.add(sum(float(d2l.accuracy(py, y)) for py, y in zip(pys, ys)),\n", "                   labels.size)\n", "    return metric[0] / metric[1]" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "训练函数。" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "ExecuteTime": { "end_time": "2019-07-03T22:50:32.788932Z", "start_time": "2019-07-03T22:50:32.780660Z" }, "attributes": { "classes": [], "id": "", "n": "7" } }, "outputs": [], "source": [ "def train(num_gpus, batch_size, lr):\n", " train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)\n", " ctx_list = [d2l.try_gpu(i) for i in range(num_gpus)]\n", " net.initialize(init=init.Normal(sigma=0.01),\n", " ctx=ctx_list, force_reinit=True)\n", " trainer = gluon.Trainer(\n", " net.collect_params(), 'sgd', {'learning_rate': lr})\n", " loss = gluon.loss.SoftmaxCrossEntropyLoss()\n", " timer, num_epochs = d2l.Timer(), 10\n", " animator = d2l.Animator('epoch', 'test acc', xlim=[1, 
num_epochs])\n", " for epoch in range(num_epochs):\n", " timer.start()\n", " for features, labels in train_iter:\n", " Xs, ys = d2l.split_batch(features, labels, ctx_list)\n", " with autograd.record():\n", " ls = [loss(net(X), y) for X, y in zip(Xs, ys)]\n", " for l in ls:\n", " l.backward()\n", " trainer.step(batch_size)\n", " npx.waitall()\n", " timer.stop()\n", " animator.add(epoch+1, (evaluate_accuracy_gpus(net, test_iter),))\n", " print('test acc: %.2f, %.1f sec/epoch on %s' % (\n", " animator.Y[0][-1], timer.avg(), ctx_list))" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "使用一个 GPU。" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "ExecuteTime": { "end_time": "2019-07-03T22:52:56.382631Z", "start_time": "2019-07-03T22:50:32.790237Z" }, "attributes": { "classes": [], "id": "", "n": "8" }, "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "test acc: 0.93, 13.2 sec/epoch on [gpu(0)]\n" ] }, { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "train(num_gpus=1, batch_size=256, lr=0.1)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "使用两个 GPU。" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "ExecuteTime": { "end_time": "2019-07-03T22:54:18.734545Z", "start_time": "2019-07-03T22:52:56.384544Z" }, "attributes": { "classes": [], "id": "", "n": "9" }, "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "test acc: 0.92, 6.8 sec/epoch on [gpu(0), gpu(1)]\n" ] }, { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Relative to the single-GPU run (batch_size=256, lr=0.1), batch size and learning rate are both doubled.\n", "train(num_gpus=2, batch_size=512, lr=0.2)" ] } ], "metadata": { "celltoolbar": "Slideshow", "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.1" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 2 }