{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "Данный ноутбук содержит код эксперимента по исследованию возможности модели нейродифференциального уравнения явным образом контролировать trade-off между численной точностью и вычислительными затратами." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import os\n", "import time\n", "import numpy as np\n", "import torch\n", "import torch.nn as nn\n", "from torch.utils.data import DataLoader\n", "import torchvision.datasets as datasets\n", "import torchvision.transforms as transforms\n", "\n", "from torchdiffeq import odeint_adjoint as odeint" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from utils import norm, Flatten, get_mnist_loaders, one_hot, ConcatConv2d" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "downsampling_layers = [\n", " nn.Conv2d(1, 64, 3, 1),\n", " norm(64),\n", " nn.ReLU(inplace=True),\n", " nn.Conv2d(64, 64, 4, 2, 1),\n", " norm(64),\n", " nn.ReLU(inplace=True),\n", " nn.Conv2d(64, 64, 4, 2, 1),\n", "]\n", "fc_layers = [norm(64), nn.ReLU(inplace=True), nn.AdaptiveAvgPool2d((1, 1)), Flatten(), nn.Linear(64, 10)]" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "class ODEfunc(nn.Module):\n", "\n", " def __init__(self, dim):\n", " super(ODEfunc, self).__init__()\n", " self.norm1 = norm(dim)\n", " self.relu = nn.ReLU(inplace=True)\n", " self.conv1 = ConcatConv2d(dim, dim, 3, 1, 1)\n", " self.norm2 = norm(dim)\n", " self.conv2 = ConcatConv2d(dim, dim, 3, 1, 1)\n", " self.norm3 = norm(dim)\n", " self.nfe = 0\n", "\n", " def forward(self, t, x):\n", " self.nfe += 1\n", " out = self.norm1(x)\n", " out = self.relu(out)\n", " out = self.conv1(t, out)\n", " out = self.norm2(out)\n", " out = self.relu(out)\n", " out = self.conv2(t, out)\n", " out = self.norm3(out)\n", " return out\n", " \n", "class ODEBlock(nn.Module):\n", "\n", " def __init__(self, odefunc, tol=1e-3, method=None):\n", " super(ODEBlock, self).__init__()\n", " self.odefunc = odefunc\n", " self.integration_time = torch.tensor([0, 1]).float()\n", " self.tol = tol\n", " self.method = method\n", "\n", " def forward(self, x):\n", " self.integration_time = self.integration_time.type_as(x)\n", " out = odeint(self.odefunc, x, self.integration_time, rtol=self.tol, \n", " atol=self.tol, method=self.method)\n", " return out[1]\n", "\n", " @property\n", " def nfe(self):\n", " return self.odefunc.nfe\n", "\n", " @nfe.setter\n", " def nfe(self, value):\n", " self.odefunc.nfe = value" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Загрузим веса заранее обученной модели ODE-Net (во время обучения использовалась максимально допустимая абсолютная ошибка численного метода $tol=10^{-3}$):" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "checkpoint = torch.load('ODEnet_mnist.pth')\n", "batch_size = test_batch_size = 1000\n", "data_aug = False" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "all_acc = []\n", "all_times = []\n", "nfes = []\n", "tols = [1e-4, 1e-3, 1e-2, 1e-1, 1]" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Tol=0.0001, 
 { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [
  "Tol=0.0001, accuracy=0.9961, mean_time=4.60, nfe=320\n",
  "Tol=0.001, accuracy=0.9961, mean_time=3.76, nfe=260\n",
  "Tol=0.01, accuracy=0.996, mean_time=2.93, nfe=200\n",
  "Tol=0.1, accuracy=0.9961, mean_time=2.13, nfe=140\n",
  "Tol=1, accuracy=0.9956, mean_time=2.13, nfe=140\n" ] } ], "source": [
  "for tol in tols:\n",
  "    # rebuild the model with the requested solver tolerance; the loaded weights are identical for every tol\n",
  "    feature_layers = [ODEBlock(ODEfunc(64), tol=tol)]\n",
  "    model = nn.Sequential(*downsampling_layers, *feature_layers, *fc_layers).to(device)\n",
  "    model.load_state_dict(checkpoint['state_dict'])\n",
  "    model.eval()\n",
  "    train_loader, test_loader, train_eval_loader = get_mnist_loaders(data_aug, batch_size, test_batch_size)\n",
  "\n",
  "    with torch.no_grad():\n",
  "        times = []  # per-batch forward-pass times in seconds (batch_size=1000)\n",
  "        total_correct = 0\n",
  "        for x, y in test_loader:\n",
  "            x = x.to(device)\n",
  "            y = one_hot(np.array(y.numpy()), 10)\n",
  "\n",
  "            target_class = np.argmax(y, axis=1)\n",
  "            start = time.time()\n",
  "            preds = model(x)\n",
  "            times.append(time.time() - start)\n",
  "            predicted_class = np.argmax(preds.cpu().detach().numpy(), axis=1)\n",
  "            total_correct += np.sum(predicted_class == target_class)\n",
  "    accuracy = total_correct / len(test_loader.dataset)\n",
  "    nfe = feature_layers[0].nfe  # total NFE accumulated over the whole test set at this tolerance\n",
  "    all_acc.append(accuracy)\n",
  "    nfes.append(nfe)\n",
  "    all_times.append(times)\n",
  "    print('Tol={0}, accuracy={1}, mean_time={2:0.2f}, nfe={3}'.format(tol, accuracy, np.mean(times), nfe))" ] }
 ],
 "metadata": {
  "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" },
  "language_info": {
   "codemirror_mode": { "name": "ipython", "version": 3 },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}