{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Example: training a Neural ODE on synthetic data" ] }, { "cell_type": "code", "execution_count": 71, "metadata": {}, "outputs": [], "source": [ "import os\n", "import argparse\n", "import logging\n", "import tqdm\n", "import numpy as np\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Install the library from the authors' repository: https://github.com/rtqichen/torchdiffeq (it is also published on PyPI as `torchdiffeq`)." ] }, { "cell_type": "code", "execution_count": 72, "metadata": {}, "outputs": [], "source": [ "from torchdiffeq import odeint" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We create the ODEfunc class, which implements the trainable module $f(\\cdot)$ from the equation $\\frac{dz}{dt} = f(z(t), t, \\theta)$.\n", "Here it is a simple three-layer fully connected network: the first layer maps the input to a higher-dimensional space,\n", "the second holds most of the parameters, and the third projects the hidden representation back to the low-dimensional space.\n", "\n", "Since $f$ depends on $t$ as well as on $z$, the current time $t$ is appended as an extra input feature to the second and third layers." 
] }, { "cell_type": "code", "execution_count": 73, "metadata": {}, "outputs": [], "source": [ "class ODEfunc(nn.Module):\n", "\n", "    def __init__(self, dim, hidden_dim):\n", "        super(ODEfunc, self).__init__()\n", "        self.first = nn.Linear(dim, hidden_dim)\n", "        # +1 input feature: the current time t is appended to the hidden state\n", "        self.second = nn.Linear(hidden_dim + 1, hidden_dim)\n", "        self.third = nn.Linear(hidden_dim + 1, dim)\n", "\n", "    def forward(self, t, x):\n", "        # a single time column of shape (batch, 1), whatever the value of dim\n", "        times = torch.ones_like(x[:, :1]) * t\n", "        out = F.relu(self.first(x))\n", "        out = F.relu(self.second(torch.cat((out, times), dim=1)))\n", "        out = self.third(torch.cat((out, times), dim=1))\n", "        return out\n", "\n", "\n", "class ODEBlock(nn.Module):\n", "    # integrates odefunc from t=0 to t=1 and returns the state at t=1\n", "\n", "    def __init__(self, odefunc):\n", "        super(ODEBlock, self).__init__()\n", "        self.odefunc = odefunc\n", "        self.integration_time = torch.tensor([0.0, 1.0])\n", "\n", "    def forward(self, x, tol=1e-3):\n", "        t = self.integration_time.type_as(x)\n", "        out = odeint(self.odefunc, x, t, rtol=tol, atol=tol)\n", "        # odeint returns the solution at every point of t; keep the one at t=1\n", "        return out[1]" ] }, { "cell_type": "code", "execution_count": 74, "metadata": {}, "outputs": [], "source": [ "def count_parameters(model):\n", "    return sum(p.numel() for p in model.parameters() if p.requires_grad)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We wrap ODEfunc in an ODEBlock, which integrates it over $t \\in [0, 1]$, and count the trainable parameters:" ] }, { "cell_type": "code", "execution_count": 75, "metadata": {}, "outputs": [], "source": [ "network = ODEBlock(ODEfunc(dim=1, hidden_dim=500))" ] }, { "cell_type": "code", "execution_count": 76, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "252502" ] }, "execution_count": 76, "metadata": {}, "output_type": "execute_result" } ], "source": [ "count_parameters(network)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Training.\n", "As synthetic data we take random numbers from $[0, 2)$ and learn to predict their squares.\n", "Note that " ] }, { "cell_type": "code", "execution_count": 77, "metadata": {}, "outputs": [], "source": [ "optimizer = torch.optim.Adam(params=network.parameters(), lr=0.000001)" ] }, { "cell_type": "code", "execution_count": 79, "metadata": {}, "outputs": [], "source": [ "for i in range(20000):\n", "    # generate a fresh batch: inputs uniform on [0, 2), targets are their squares\n", "    batch = np.random.sample(size=(400, 1)) * 2\n", "    values = batch ** 2\n", "    batch = torch.tensor(batch, dtype=torch.float32)\n", "    values = torch.tensor(values, dtype=torch.float32)\n", "\n", "    # forward pass (an ODE solve) and the loss\n", "    predictions = network(batch, tol=1e-3)\n", "    loss = F.mse_loss(input=predictions, target=values)\n", "\n", "    # clear the previous iteration's gradients, backpropagate, and step against the gradient\n", "    optimizer.zero_grad()\n", "    loss.backward()\n", "    optimizer.step()\n", "\n", "    if i % 100 == 0:\n", "        print(\"MSE Loss (iter {}): {:.3f}\".format(i, float(loss)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Visualization" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "x = np.linspace(0, 2, 40)\n", "with torch.no_grad():\n", "    out = network(torch.tensor(x.reshape(-1, 1), dtype=torch.float32))" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "%matplotlib inline\n", "\n", "plt.style.use('ggplot')" ] }, { "cell_type": "code", "execution_count": 63, "metadata": {}, "outputs": [], "source": [ "plt.figure(figsize=(6, 6))\n", "plt.plot(x, out.detach().numpy(), label='NeuralODE')\n", "plt.plot(x, x**2, label='golden')\n", "plt.legend()\n", "plt.show()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.5.2" } }, "nbformat": 4, "nbformat_minor": 2 }