{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 04 MNIST Example with DataLoaders and Convolutions\n", "\n", "In this notebook we will look at how data loaders can be used in `pycox`.\n", "This is particularly useful when working with data sets that are too large to fit in memory, and is an important part of any deep learning framework.\n", "As `pycox` is built on [torchtuples](https://github.com/havakv/torchtuples), the same principles apply as for `torchtuples.Model`.\n", "\n", "For our example, we will consider the [simulation study proposed by Gensheimer and Narasimhan](https://peerj.com/articles/6257/) based on the MNIST data set of handwritten digits. \n", "The basic idea is that each digit represents a survival function, so if we can identify the digit, it is quite straightforward to get good survival estimates.\n", "We will use the `LogisticHazard` method (which [Gensheimer and Narasimhan](https://peerj.com/articles/6257/) refer to as Nnet-survival), with a convolutional network.\n", "\n", "We will, however, consider a slightly different survival function than that of [Gensheimer and Narasimhan](https://peerj.com/articles/6257/), and we will consider all the digits from 0 to 9, while [Gensheimer and Narasimhan](https://peerj.com/articles/6257/) only considered the first 5." 
] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "from torch.utils.data import Dataset, DataLoader \n", "\n", "# MNIST is part of torchvision\n", "from torchvision import datasets, transforms\n", "\n", "\n", "import torchtuples as tt\n", "from pycox.models import LogisticHazard\n", "from pycox.utils import kaplan_meier\n", "from pycox.evaluation import EvalSurv" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# for reproducibility\n", "np.random.seed(1234)\n", "_ = torch.manual_seed(1234)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# The Dataset\n", "\n", "We start by obtaining the MNIST data set with standard preprocessing. The `transform` ensures the data is a `torch.Tensor` and normalizes it with the given mean and standard deviation." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "transform = transforms.Compose(\n", " [transforms.ToTensor(),\n", " transforms.Normalize((0.1307,), (0.3081,))]\n", ")\n", "mnist_train = datasets.MNIST('.', train=True, download=True,\n", " transform=transform)\n", "mnist_test = datasets.MNIST('.', train=False, transform=transform)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAN80lEQVR4nO3df6hcdXrH8c+ncf3DrBpTMYasNhuRWBWbLRqLSl2RrD9QNOqWDVgsBrN/GHChhEr6xyolEuqP0qAsuYu6sWyzLqgYZVkVo6ZFCF5j1JjU1YrdjV6SSozG+KtJnv5xT+Su3vnOzcyZOZP7vF9wmZnzzJnzcLife87Md879OiIEYPL7k6YbANAfhB1IgrADSRB2IAnCDiRxRD83ZpuP/oEeiwiPt7yrI7vtS22/aftt27d281oAesudjrPbniLpd5IWSNou6SVJiyJia2EdjuxAj/XiyD5f0tsR8U5EfCnpV5Ku6uL1APRQN2GfJekPYx5vr5b9EdtLbA/bHu5iWwC61M0HdOOdKnzjND0ihiQNSZzGA03q5si+XdJJYx5/R9L73bUDoFe6CftLkk61/V3bR0r6kaR19bQFoG4dn8ZHxD7bSyU9JWmKpAci4o3aOgNQq46H3jraGO/ZgZ7ryZdqABw+CDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUii4ymbcXiYMmVKsX7sscf2dPtLly5tWTvqqKOK686dO7dYv/nmm4v1u+66q2Vt0aJFxXU///zzYn3lypXF+u23316sN6GrsNt+V9IeSfsl7YuIs+toCkD96jiyXxQRH9TwOgB6iPfsQBLdhj0kPW37ZdtLxnuC7SW2h20Pd7ktAF3o9jT+/Ih43/YJkp6x/V8RsWHsEyJiSNKQJNmOLrcHoENdHdkj4v3qdqekxyTNr6MpAPXrOOy2p9o++uB9ST+QtKWuxgDUq5vT+BmSHrN98HX+PSJ+W0tXk8zJJ59crB955JHF+nnnnVesX3DBBS1r06ZNK6577bXXFutN2r59e7G+atWqYn3hwoUta3v27Cmu++qrrxbrL7zwQrE+iDoOe0S8I+kvauwFQA8x9AYkQdiBJAg7kARhB5Ig7EASjujfl9om6zfo5s2bV6yvX7++WO/1ZaaD6sCBA8X6jTfeWKx/8sknHW97ZGSkWP/www+L9TfffLPjbfdaRHi85RzZgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJxtlrMH369GJ948aNxfqcOXPqbKdW7XrfvXt3sX7RRRe1rH355ZfFdbN+/6BbjLMDyRF2IAnCDiRB2IEkCDuQBGEHkiDsQBJM2VyDXbt2FevLli0r1q+44opi/ZVXXinW2/1L5ZLNmzcX6wsWLCjW9+7dW6yfccYZLWu33HJLcV3UiyM7kARhB5Ig7EAShB1IgrADSRB2IAnCDiTB9ewD4JhjjinW200vvHr16pa1xYsXF9e9/vrri/W1a9cW6xg8HV/PbvsB2zttbxmzbLrtZ2y/Vd0eV2ezAOo3kdP4X0i69GvLbpX0bEScKunZ6jGAAdY27BGxQdLXvw96laQ11f01kq6uuS8ANev0u/EzImJEkiJixPYJrZ5oe4mkJR1uB0BNen4hTEQMSRqS+IAOaFKnQ287bM+UpOp2Z30tAeiFTsO+TtIN1f0bJD1eTzsAeqXtabzttZK+L+l429sl/VTSSkm/tr1Y0u8l/bCXTU52H3/8cVfrf/TRRx2ve9NNNxXrDz/8cLHebo51DI62YY+IRS1KF9fcC4Ae4uuyQBKEHUiCsANJEHYgCcIOJMElrpPA1KlTW9aeeOKJ4roXXnhhsX7ZZZcV608//XSxjv5jymYgOcIOJEHYgSQIO5AEYQeSIOx
AEoQdSIJx9knulFNOKdY3bdpUrO/evbtYf+6554r14eHhlrX77ruvuG4/fzcnE8bZgeQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJxtmTW7hwYbH+4IMPFutHH310x9tevnx5sf7QQw8V6yMjIx1vezJjnB1IjrADSRB2IAnCDiRB2IEkCDuQBGEHkmCcHUVnnnlmsX7PPfcU6xdf3Plkv6tXry7WV6xYUay/9957HW/7cNbxOLvtB2zvtL1lzLLbbL9ne3P1c3mdzQKo30RO438h6dJxlv9LRMyrfn5Tb1sA6tY27BGxQdKuPvQCoIe6+YBuqe3XqtP841o9yfYS28O2W/8zMgA912nYfybpFEnzJI1IurvVEyNiKCLOjoizO9wWgBp0FPaI2BER+yPigKSfS5pfb1sA6tZR2G3PHPNwoaQtrZ4LYDC0HWe3vVbS9yUdL2mHpJ9Wj+dJCknvSvpxRLS9uJhx9sln2rRpxfqVV17ZstbuWnl73OHir6xfv75YX7BgQbE+WbUaZz9iAisuGmfx/V13BKCv+LoskARhB5Ig7EAShB1IgrADSXCJKxrzxRdfFOtHHFEeLNq3b1+xfskll7SsPf/888V1D2f8K2kgOcIOJEHYgSQIO5AEYQeSIOxAEoQdSKLtVW/I7ayzzirWr7vuumL9nHPOaVlrN47eztatW4v1DRs2dPX6kw1HdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgnH2SW7u3LnF+tKlS4v1a665plg/8cQTD7mnidq/f3+xPjJS/u/lBw4cqLOdwx5HdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgnH2w0C7sexFi8abaHdUu3H02bNnd9JSLYaHh4v1FStWFOvr1q2rs51Jr+2R3fZJtp+zvc32G7ZvqZZPt/2M7beq2+N63y6ATk3kNH6fpL+PiD+X9FeSbrZ9uqRbJT0bEadKerZ6DGBAtQ17RIxExKbq/h5J2yTNknSVpDXV09ZIurpXTQLo3iG9Z7c9W9L3JG2UNCMiRqTRPwi2T2ixzhJJS7prE0C3Jhx229+W9Iikn0TEx/a4c8d9Q0QMSRqqXoOJHYGGTGjozfa3NBr0X0bEo9XiHbZnVvWZknb2pkUAdWh7ZPfoIfx+Sdsi4p4xpXWSbpC0srp9vCcdTgIzZswo1k8//fRi/d577y3WTzvttEPuqS4bN24s1u+8886WtccfL//KcIlqvSZyGn++pL+V9LrtzdWy5RoN+a9tL5b0e0k/7E2LAOrQNuwR8Z+SWr1Bv7jedgD0Cl+XBZIg7EAShB1IgrADSRB2IAkucZ2g6dOnt6ytXr26uO68efOK9Tlz5nTUUx1efPHFYv3uu+8u1p966qli/bPPPjvkntAbHNmBJAg7kARhB5Ig7EAShB1IgrADSRB2IIk04+znnntusb5s2bJiff78+S1rs2bN6qinunz66acta6tWrSque8cddxTre/fu7agnDB6O7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQRJpx9oULF3ZV78bWrVuL9SeffLJY37dvX7FeuuZ89+7dxXWRB0d2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUjCEVF+gn2SpIcknSjpgKShiPhX27dJuknS/1ZPXR4Rv2nzWuWNAehaRIw76/JEwj5T0syI2GT7aEkvS7pa0t9I+iQi7ppoE4Qd6L1WYZ/I/Owjkkaq+3tsb5PU7L9mAXDIDuk9u+3Zkr4naWO1aKnt12w/YPu4FusssT1se7irTgF0pe1p/FdPtL8t6QVJKyLiUdszJH0gKST9k0ZP9W9s8xqcxgM91vF7dkmy/S1JT0p6KiLuGac+W9KTEXFmm9ch7ECPtQp729N425Z0v6RtY4NefXB30EJJW7ptEkDvTOTT+Ask/Yek1zU69CZJyyUtkjRPo6fx70r6cfVhXum1OLIDPdbVaXxdCDvQex2fxgOYHAg7kARhB5Ig7EAShB1IgrADSRB
2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJ9HvK5g8k/c+Yx8dXywbRoPY2qH1J9NapOnv7s1aFvl7P/o2N28MRcXZjDRQMam+D2pdEb53qV2+cxgNJEHYgiabDPtTw9ksGtbdB7Uuit071pbdG37MD6J+mj+wA+oSwA0k0Enbbl9p+0/bbtm9toodWbL9r+3Xbm5uen66aQ2+n7S1jlk23/Yztt6rbcefYa6i322y/V+27zbYvb6i3k2w/Z3ub7Tds31Itb3TfFfrqy37r+3t221Mk/U7SAknbJb0kaVFEbO1rIy3YflfS2RHR+BcwbP+1pE8kPXRwai3b/yxpV0SsrP5QHhcR/zAgvd2mQ5zGu0e9tZpm/O/U4L6rc/rzTjRxZJ8v6e2IeCcivpT0K0lXNdDHwIuIDZJ2fW3xVZLWVPfXaPSXpe9a9DYQImIkIjZV9/dIOjjNeKP7rtBXXzQR9lmS/jDm8XYN1nzvIelp2y/bXtJ0M+OYcXCarer2hIb7+bq203j309emGR+YfdfJ9OfdaiLs401NM0jjf+dHxF9KukzSzdXpKibmZ5JO0egcgCOS7m6ymWqa8Uck/SQiPm6yl7HG6asv+62JsG+XdNKYx9+R9H4DfYwrIt6vbndKekyjbzsGyY6DM+hWtzsb7ucrEbEjIvZHxAFJP1eD+66aZvwRSb+MiEerxY3vu/H66td+ayLsL0k61fZ3bR8p6UeS1jXQxzfYnlp9cCLbUyX9QIM3FfU6STdU92+Q9HiDvfyRQZnGu9U042p43zU+/XlE9P1H0uUa/UT+vyX9YxM9tOhrjqRXq583mu5N0lqNntb9n0bPiBZL+lNJz0p6q7qdPkC9/ZtGp/Z+TaPBmtlQbxdo9K3ha5I2Vz+XN73vCn31Zb/xdVkgCb5BByRB2IEkCDuQBGEHkiDsQBKEHUiCsANJ/D+f1mbtgJ8kQQAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "_ = plt.imshow(mnist_train[0][0][0].numpy(), cmap='gray')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Simulation\n", "\n", "Next we need to simulate the responses corresponding to the images.\n", "We draw event times from an exponential distribution with the digit defining the scale parameter\n", "\n", "$$\n", "\\beta(\\text{digit}) = \\frac{365 \\cdot \\exp(-0.6 \\cdot \\text{digit})}{\\log(1.2)},\n", "$$\n", "and we censor all times higher than 700." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "def sim_event_times(mnist, max_time=700):\n", " digits = mnist.targets.numpy()\n", " betas = 365 * np.exp(-0.6 * digits) / np.log(1.2)\n", " event_times = np.random.exponential(betas)\n", " censored = event_times > max_time\n", " event_times[censored] = max_time\n", " return tt.tuplefy(event_times, ~censored)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We simulate a training set and test set, based on the respective MNIST data sets." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "sim_train = sim_event_times(mnist_train)\n", "sim_test = sim_event_times(mnist_test)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(array([ 21.19004682, 700. 
, 104.56743096, ..., 121.80432849,\n", " 2.50843078, 13.8114342 ]),\n", " array([ True, False, True, ..., True, True, True]))" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sim_train" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Visualize\n", "We can visualize the survival curves for the 10 digits by applying the Kaplan-Meier estimator to the collection of event times for each digit" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "for i in range(10):\n", " idx = mnist_train.targets.numpy() == i\n", " kaplan_meier(*sim_train.iloc[idx]).rename(i).plot()\n", "_ = plt.legend()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Our goal will be to estimate these survival functions from the images." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Label transforms\n", "\n", "Our simulated event times are drawn in continuous time, so to apply the `LogisticHazard` method, we need to discretize the observations. This can be done with the `label_transform` attribute, and we here use an equidistant grid with 20 grid points." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "labtrans = LogisticHazard.label_transform(20)\n", "target_train = labtrans.fit_transform(*sim_train)\n", "target_test = labtrans.transform(*sim_test)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The disretization grid is " ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([ 0. , 36.84210526, 73.68421053, 110.52631579,\n", " 147.36842105, 184.21052632, 221.05263158, 257.89473684,\n", " 294.73684211, 331.57894737, 368.42105263, 405.26315789,\n", " 442.10526316, 478.94736842, 515.78947368, 552.63157895,\n", " 589.47368421, 626.31578947, 663.15789474, 700. 
])" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "labtrans.cuts" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "and the discrete targets are" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(array([ 1, 19, 3, ..., 4, 1, 1]),\n", " array([1., 0., 1., ..., 1., 1., 1.], dtype=float32))" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "target_train" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Make DataLoaders\n", "\n", "To make a `DataLoader` we first need to create a `Dataset`. The `Dataset` is responsible for obtaining and transforming the data, while the `DataLoader` contains a `Dataset`, a batch sampler, etc.\n", "\n", "The standard way to create a `Dataset` in PyTorch is by inheriting the `Dataset` class and defining the `__getitem__` method which reads the data for one individual at a time. \n", "This also requires a `collate_fn` for combining multiple individuals into a batch. \n", "\n", "The following is an example of this approach, but we will shortly present an alternative approach that is more in line with `torchtuples`." ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "class MnistSimDatasetSingle(Dataset):\n", " \"\"\"Simulated data from MNIST. Read a single entry at a time.\n", " \"\"\"\n", " def __init__(self, mnist_dataset, time, event):\n", " self.mnist_dataset = mnist_dataset\n", " self.time, self.event = tt.tuplefy(time, event).to_tensor()\n", "\n", " def __len__(self):\n", " return len(self.mnist_dataset)\n", "\n", " def __getitem__(self, index):\n", " if type(index) is not int:\n", " raise ValueError(f\"Need `index` to be `int`. 
Got {type(index)}.\")\n", " img = self.mnist_dataset[index][0]\n", " return img, (self.time[index], self.event[index])" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "dataset_train = MnistSimDatasetSingle(mnist_train, *target_train)\n", "dataset_test = MnistSimDatasetSingle(mnist_test, *target_test)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([1, 28, 28]), (torch.Size([]), torch.Size([])))" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "samp = tt.tuplefy(dataset_train[1])\n", "samp.shapes()" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor(19), tensor(0.))" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "samp[1]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Our dataset gives a nested tuple `(img, (idx_duration, event))`, meaning the default collate in PyTorch does not work. We therefore use `tuplefy` to stack the tensors instead." ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "def collate_fn(batch):\n", " \"\"\"Stacks the entries of a nested tuple\"\"\"\n", " return tt.tuplefy(batch).stack()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### DataLoader\n", "\n", "We can now use the regular PyTorch `DataLoader`.\n", "Note that you can set the argument `num_workers` in the `DataLoader` to use multiple processes for reading data. Depending on the system (macOS/Linux/Windows) this can cause some memory issues, so we here use the default `num_workers = 0`." 
] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "batch_size = 128" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "dl_train = DataLoader(dataset_train, batch_size, shuffle=True, collate_fn=collate_fn)\n", "dl_test = DataLoader(dataset_test, batch_size, shuffle=False, collate_fn=collate_fn)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If we now investigate a batch, we see that we have the same tuple structure `(img, (idx_durations, events))` but in a batch of size 128." ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([128, 1, 28, 28]), (torch.Size([128]), torch.Size([128])))" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "batch = next(iter(dl_train))\n", "batch.shapes()" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.float32, (torch.int64, torch.float32))" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "batch.dtypes()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Dataset with batches (alternative)\n", "\n", "When working with `torchtuples` it is typically simpler to read a batch at a time. This means that we do not need a `collate_fn`, and all the logic is in the `Dataset`.\n", "This approach is optional, and if you prefer the regular PyTorch `DataLoader`, you can skip this and continue at the Convolutional Network section." 
] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "class MnistSimDatasetBatch(Dataset):\n", " def __init__(self, mnist_dataset, time, event):\n", " self.mnist_dataset = mnist_dataset\n", " self.time, self.event = tt.tuplefy(time, event).to_tensor()\n", "\n", " def __len__(self):\n", " return len(self.time)\n", "\n", " def __getitem__(self, index):\n", " if not hasattr(index, '__iter__'):\n", " index = [index]\n", " img = [self.mnist_dataset[i][0] for i in index]\n", " img = torch.stack(img)\n", " return tt.tuplefy(img, (self.time[index], self.event[index]))" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "dataset_train = MnistSimDatasetBatch(mnist_train, *target_train)\n", "dataset_test = MnistSimDatasetBatch(mnist_test, *target_test)" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([3, 1, 28, 28]), (torch.Size([3]), torch.Size([3])))" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "samp = dataset_train[[0, 1, 3]]\n", "samp.shapes()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### DataLoaderBatch\n", "\n", "As we have a `Dataset` that reads a batch at a time, we cannot use the regular PyTorch `DataLoader`.\n", "Instead we have to rely on the `DataLoaderBatch` from `torchtuples`, but note that we don't need the `collate_fn`." 
] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "dl_train = tt.data.DataLoaderBatch(dataset_train, batch_size, shuffle=True)\n", "dl_test = tt.data.DataLoaderBatch(dataset_test, batch_size, shuffle=False)" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([128, 1, 28, 28]), (torch.Size([128]), torch.Size([128])))" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "batch = next(iter(dl_train))\n", "batch.shapes()" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.float32, (torch.int64, torch.float32))" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "batch.dtypes()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We see that the end result is the same as for the `DataLoader` above, so use whichever method you find simplest." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Convolutional Network\n", "\n", "We will use a convolutional network with two convolutional layers (with max pooling in between), global average pooling, and two dense layers. This network is very basic, so better performance would be expected with a more carefully designed network." 
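, "\n", "\n", "As a rough check of the layer sizes, assuming the 5x5 kernels with stride 1 and no padding, the 2x2 max pooling, and the 16 channels used in the `Net` class of this section, the spatial size of a 28x28 MNIST image evolves as\n", "\n", "$$\n", "28 \\xrightarrow{\\text{conv } 5 \\times 5} 28 - 5 + 1 = 24 \\xrightarrow{\\text{max pool } 2 \\times 2} 24 / 2 = 12 \\xrightarrow{\\text{conv } 5 \\times 5} 12 - 5 + 1 = 8 \\xrightarrow{\\text{global avg pool}} 1,\n", "$$\n", "so after flattening, each image is represented by 16 features (one per channel), which is the input size of the first dense layer."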
] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [], "source": [ "class Net(nn.Module):\n", " def __init__(self, out_features):\n", " super().__init__()\n", " self.conv1 = nn.Conv2d(1, 16, 5, 1)\n", " self.max_pool = nn.MaxPool2d(2)\n", " self.conv2 = nn.Conv2d(16, 16, 5, 1)\n", " self.glob_avg_pool = nn.AdaptiveAvgPool2d((1, 1))\n", " self.fc1 = nn.Linear(16, 16)\n", " self.fc2 = nn.Linear(16, out_features)\n", "\n", " def forward(self, x):\n", " x = F.relu(self.conv1(x))\n", " x = self.max_pool(x)\n", " x = F.relu(self.conv2(x))\n", " x = self.glob_avg_pool(x)\n", " x = torch.flatten(x, 1)\n", " x = F.relu(self.fc1(x))\n", " x = self.fc2(x)\n", " return x" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Net(\n", " (conv1): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1))\n", " (max_pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", " (conv2): Conv2d(16, 16, kernel_size=(5, 5), stride=(1, 1))\n", " (glob_avg_pool): AdaptiveAvgPool2d(output_size=(1, 1))\n", " (fc1): Linear(in_features=16, out_features=16, bias=True)\n", " (fc2): Linear(in_features=16, out_features=20, bias=True)\n", ")" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "net = Net(labtrans.out_features)\n", "net" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## The Logistic-Hazard Model\n", "\n", "We use the `LogisticHazard` with the Adam optimizer with a learning rate of 0.01." 
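, "\n", "\n", "As a brief reminder of what the model computes (a sketch in our own notation, which is not used elsewhere in this notebook): writing $\\phi_k(x)$ for the $k$-th output of the network for image $x$, the discrete hazard at grid point $\\tau_k$ is taken to be the sigmoid of that output, and the survival estimate is the running product of one minus the hazards,\n", "\n", "$$\n", "h(\\tau_k \\mid x) = \\frac{1}{1 + \\exp(-\\phi_k(x))}, \\qquad S(\\tau_j \\mid x) = \\prod_{k \\le j} \\bigl(1 - h(\\tau_k \\mid x)\\bigr),\n", "$$\n", "which is why the network needs one output per point of the discretization grid (here 20)."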
] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [], "source": [ "model = LogisticHazard(net, tt.optim.Adam(0.01), duration_index=labtrans.cuts)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To verify that the network works as expected we can use the batch from before" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([128, 20])" ] }, "execution_count": 30, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pred = model.predict(batch[0])\n", "pred.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Training\n", "\n", "We fit the network with `fit_dataloader` and use the `dl_test` to monitor the test performance. It should go without saying that, in practice, we need a validation set separate from the test set when we use early stopping, but this is just an illustrative example." ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0:\t[33s / 33s],\t\ttrain_loss: 2.0743,\tval_loss: 1.9062\n", "1:\t[27s / 1m:1s],\t\ttrain_loss: 1.8318,\tval_loss: 1.7913\n", "2:\t[39s / 1m:40s],\t\ttrain_loss: 1.7777,\tval_loss: 1.7586\n", "3:\t[39s / 2m:19s],\t\ttrain_loss: 1.7623,\tval_loss: 1.7509\n", "4:\t[54s / 3m:14s],\t\ttrain_loss: 1.7507,\tval_loss: 1.7266\n", "5:\t[36s / 3m:51s],\t\ttrain_loss: 1.7427,\tval_loss: 1.7431\n", "6:\t[35s / 4m:27s],\t\ttrain_loss: 1.7330,\tval_loss: 1.7263\n", "7:\t[34s / 5m:1s],\t\ttrain_loss: 1.7266,\tval_loss: 1.7247\n", "8:\t[34s / 5m:36s],\t\ttrain_loss: 1.7265,\tval_loss: 1.7159\n", "9:\t[34s / 6m:11s],\t\ttrain_loss: 1.7179,\tval_loss: 1.7112\n", "10:\t[34s / 6m:46s],\t\ttrain_loss: 1.7146,\tval_loss: 1.7072\n", "11:\t[34s / 7m:20s],\t\ttrain_loss: 1.7136,\tval_loss: 1.7524\n", "12:\t[34s / 7m:55s],\t\ttrain_loss: 1.7107,\tval_loss: 1.7295\n", "13:\t[36s / 8m:32s],\t\ttrain_loss: 1.7080,\tval_loss: 1.7014\n", 
"14:\t[34s / 9m:6s],\t\ttrain_loss: 1.7051,\tval_loss: 1.7121\n", "15:\t[34s / 9m:41s],\t\ttrain_loss: 1.7054,\tval_loss: 1.7022\n", "16:\t[36s / 10m:18s],\t\ttrain_loss: 1.7026,\tval_loss: 1.7134\n", "17:\t[46s / 11m:5s],\t\ttrain_loss: 1.6998,\tval_loss: 1.6986\n", "18:\t[43s / 11m:48s],\t\ttrain_loss: 1.7000,\tval_loss: 1.7048\n", "19:\t[36s / 12m:25s],\t\ttrain_loss: 1.6948,\tval_loss: 1.6906\n", "20:\t[32s / 12m:58s],\t\ttrain_loss: 1.6955,\tval_loss: 1.6941\n", "21:\t[31s / 13m:29s],\t\ttrain_loss: 1.6943,\tval_loss: 1.6925\n", "22:\t[33s / 14m:2s],\t\ttrain_loss: 1.6925,\tval_loss: 1.6953\n", "23:\t[32s / 14m:34s],\t\ttrain_loss: 1.6936,\tval_loss: 1.6934\n", "24:\t[30s / 15m:5s],\t\ttrain_loss: 1.6918,\tval_loss: 1.6893\n", "25:\t[30s / 15m:35s],\t\ttrain_loss: 1.6883,\tval_loss: 1.6930\n", "26:\t[29s / 16m:5s],\t\ttrain_loss: 1.6865,\tval_loss: 1.6927\n", "27:\t[30s / 16m:35s],\t\ttrain_loss: 1.6863,\tval_loss: 1.6880\n", "28:\t[34s / 17m:10s],\t\ttrain_loss: 1.6857,\tval_loss: 1.6880\n", "29:\t[30s / 17m:40s],\t\ttrain_loss: 1.6862,\tval_loss: 1.6871\n", "30:\t[32s / 18m:13s],\t\ttrain_loss: 1.6849,\tval_loss: 1.6926\n", "31:\t[29s / 18m:42s],\t\ttrain_loss: 1.6838,\tval_loss: 1.6952\n", "32:\t[29s / 19m:12s],\t\ttrain_loss: 1.6834,\tval_loss: 1.6954\n", "33:\t[29s / 19m:42s],\t\ttrain_loss: 1.6835,\tval_loss: 1.6891\n", "34:\t[29s / 20m:11s],\t\ttrain_loss: 1.6805,\tval_loss: 1.6897\n" ] } ], "source": [ "callbacks = [tt.cb.EarlyStopping(patience=5)]\n", "epochs = 50\n", "verbose = True\n", "log = model.fit_dataloader(dl_train, epochs, callbacks, verbose, val_dataloader=dl_test)" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "_ = log.plot()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Prediction\n", "\n", "To predict, we need a data loader that only gives the images and not the targets. We therefore need to create a new `Dataset` for this purpose." ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [], "source": [ "class MnistSimInput(Dataset):\n", " def __init__(self, mnist_dataset):\n", " self.mnist_dataset = mnist_dataset\n", "\n", " def __len__(self):\n", " return len(self.mnist_dataset)\n", "\n", " def __getitem__(self, index):\n", " img = self.mnist_dataset[index][0]\n", " return img" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [], "source": [ "dataset_test_x = MnistSimInput(mnist_test)\n", "dl_test_x = DataLoader(dataset_test_x, batch_size, shuffle=False)" ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([128, 1, 28, 28])" ] }, "execution_count": 35, "metadata": {}, "output_type": "execute_result" } ], "source": [ "next(iter(dl_test_x)).shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### (alternative)\n", "Alternatively, if you have used the batch method, we can use the method `dataloader_input_only` to create this `Dataloader` from `dl_test`." 
] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [], "source": [ "dl_test_x = tt.data.dataloader_input_only(dl_test)" ] }, { "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([128, 1, 28, 28])" ] }, "execution_count": 37, "metadata": {}, "output_type": "execute_result" } ], "source": [ "next(iter(dl_test_x)).shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Survival predictions\n", "\n", "We can obtain survival predictions in the regular manner, and one can include interpolation (`model.interpolate`) if wanted." ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [], "source": [ "surv = model.predict_surv_df(dl_test_x)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Results\n", "\n", "We compute the average survival predictions for each digit in the test set" ] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "for i in range(10):\n", " idx = mnist_test.targets.numpy() == i\n", " surv.loc[:, idx].mean(axis=1).rename(i).plot()\n", "_ = plt.legend()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "and find that they are quite similar to the Kaplan-Meier estimates!" ] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "for i in range(10):\n", " idx = mnist_test.targets.numpy() == i\n", " kaplan_meier(*sim_test.iloc[idx]).rename(i).plot()\n", "_ = plt.legend()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Concordance and Brier score" ] }, { "cell_type": "code", "execution_count": 41, "metadata": {}, "outputs": [], "source": [ "surv = model.interpolate(10).predict_surv_df(dl_test_x)" ] }, { "cell_type": "code", "execution_count": 42, "metadata": {}, "outputs": [], "source": [ "ev = EvalSurv(surv, *sim_test, 'km')" ] }, { "cell_type": "code", "execution_count": 43, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.7426348804216191" ] }, "execution_count": 43, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ev.concordance_td()" ] }, { "cell_type": "code", "execution_count": 44, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.10559285952465855" ] }, "execution_count": 44, "metadata": {}, "output_type": "execute_result" } ], "source": [ "time_grid = np.linspace(0, sim_test[0].max())\n", "ev.integrated_brier_score(time_grid)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Next\n", "\n", "You can now look at other examples of survival methods in the [examples folder](https://nbviewer.jupyter.org/github/havakv/pycox/tree/master/examples).\n", "Or, alternatively take a look at\n", "\n", "- the more advanced training procedures in the notebook [02_introduction.ipynb](https://nbviewer.jupyter.org/github/havakv/pycox/blob/master/examples/02_introduction.ipynb).\n", "- other network architectures that combine autoencoders and survival networks in the notebook [03_network_architectures.ipynb](https://nbviewer.jupyter.org/github/havakv/pycox/blob/master/examples/03_network_architectures.ipynb)." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.4" } }, "nbformat": 4, "nbformat_minor": 4 }