{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# General Information\n", "\n", "This notebook demonstrates how the `fastai_sparse` library can be used in semantic segmentation tasks, using as an example the [ShapeNet Core55](https://shapenet.cs.stanford.edu/iccv17/) 3D semantic segmentation solution presented in the [SparseConvNet example](https://github.com/facebookresearch/SparseConvNet/tree/master/examples/3d_segmentation). \n", "\n", "\n", "<img src=\"https://camo.githubusercontent.com/a94ad53ba6adc857323bd9ba3050805fa16d8aab/687474703a2f2f6d73617676612e6769746875622e696f2f66696c65732f73686170656e65742e706e67\" width=\"480\" />\n", "\n", "The initial data is a subset of ShapeNetCore containing about 17,000 models from 16 shape categories. Each category is annotated with 2 to 6 parts, and there are 50 different parts annotated in total. 3D shapes are represented as point clouds uniformly sampled from 3D surfaces.\n", "\n", "Evaluation metric: weighted IoU (see [https://arxiv.org/pdf/1711.10275.pdf](https://arxiv.org/pdf/1711.10275.pdf))\n", "\n", "First, you need to download and prepare the initial data. 
See [examples/shapenet_iccv17](examples/shapenet_iccv17)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Imports" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import os\n", "os.environ['CUDA_VISIBLE_DEVICES'] = '0'" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "import torch.optim as optim\n", "import torch.nn.functional as F\n", "import sparseconvnet as scn\n", "import time\n", "import os, sys\n", "import math\n", "import numpy as np\n", "import pandas as pd\n", "import datetime\n", "import glob\n", "from IPython.display import display, HTML, FileLink\n", "from os.path import join, exists, basename, splitext\n", "from pathlib import Path\n", "from matplotlib import pyplot as plt\n", "from matplotlib import cm\n", "#from tensorboardX import SummaryWriter\n", "from joblib import cpu_count\n", "from tqdm import tqdm\n", "#from tqdm import tqdm_notebook as tqdm\n", "\n", "#import fastai\n", "\n", "# autoreload python modules on the fly when its source is changed\n", "%load_ext autoreload\n", "\n", "%autoreload 2" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "from fastai_sparse import utils\n", "from fastai_sparse.utils import log, log_dict, print_random_states\n", "from fastai_sparse.datasets import find_files\n", "from fastai_sparse.datasets import PointsDataset\n", "#, SparseDataBunch\n", "from fastai_sparse import visualize\n", "\n", "from datasets import DataSourceConfig, reader_fn\n", "import transform as T\n", "#from data import merge_fn" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "assert torch.cuda.is_available()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Experiment environment / system metrics" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ 
"experiment_name = 'unet_24_detailed'" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "virtualenv: (fastai_sparse) \n", "python: 3.6.8\n", "nvidia driver: b'384.130'\n", "nvidia cuda: 9.0, V9.0.176\n", "cudnn: 7.1.4\n", "torch: 1.0.1.post2\n", "fastai: 1.0.48\n", "fastai_sparse: 0.0.3.dev0\n" ] } ], "source": [ "utils.watermark()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Notebook options" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<style>.container { width:70% !important; }</style>" ], "text/plain": [ "<IPython.core.display.HTML object>" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "utils.wide_notebook()\n", "# set the condition below to False to keep interactive mode; when True, visualizations are saved as screenshots instead:\n", "# To render screenshots, run this command in a terminal: `chromium-browser --remote-debugging-port=9222`\n", "if True:\n", " visualize.options.interactive = False\n", " visualize.options.save_images = True\n", " visualize.options.verbose = True\n", " visualize.options.filename_pattern_image = Path('images', experiment_name, 'fig_{fig_number}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Source" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Create DataFrames" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "data\n", "data/train_val\n" ] } ], "source": [ "SOURCE_DIR = Path('data').expanduser()\n", "#SOURCE_DIR = Path('/home/ssd/shapenet_data').expanduser()\n", "assert SOURCE_DIR.exists()\n", "\n", "DIR_TRAIN_VAL = SOURCE_DIR / 'train_val'\n", "assert DIR_TRAIN_VAL.exists(), \"Hint: run `download_and_split_data.sh` then `convert_to_numpy.ipynb`\"\n", "\n", "print(SOURCE_DIR)\n", "print(DIR_TRAIN_VAL)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": 
{}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Number of categories: 16\n" ] } ], "source": [ "categories = [\n", " \"02691156\", \"02773838\", \"02954340\", \"02958343\", \"03001627\", \"03261776\",\n", " \"03467517\", \"03624134\", \"03636649\", \"03642806\", \"03790512\", \"03797390\",\n", " \"03948459\", \"04099429\", \"04225987\", \"04379243\"\n", "]\n", "\n", "classes = [\n", " 'Airplane', 'Bag', 'Cap', 'Car', 'Chair', 'Earphone', 'Guitar', 'Knife',\n", " 'Lamp', 'Laptop', 'Motorbike', 'Mug', 'Pistol', 'Rocket', 'Skateboard',\n", " 'Table'\n", "]\n", "\n", "num_classes_by_category = [4, 2, 2, 4, 4, 3, 3, 2, 4, 2, 6, 2, 3, 3, 3, 3]\n", "assert len(categories) == len(classes)\n", "\n", "print(\"Number of categories:\", len(categories))" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "df_train = find_files(path=SOURCE_DIR / 'npy' / 'train', ext='.points.npy', ext_labels='.labels.npy', categories=categories)\n", "df_valid = find_files(path=SOURCE_DIR / 'npy' / 'valid', ext='.points.npy', ext_labels='.labels.npy', categories=categories)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "6955\n" ] }, { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>example_id</th>\n", " <th>subdir</th>\n", " <th>categ_idx</th>\n", " <th>ext</th>\n", " <th>ext_labels</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>000908</td>\n", " <td>02691156</td>\n", " <td>0</td>\n", " <td>.points.npy</td>\n", " <td>.labels.npy</td>\n", " 
</tr>\n", " <tr>\n", " <th>1</th>\n", " <td>010886</td>\n", " <td>02691156</td>\n", " <td>0</td>\n", " <td>.points.npy</td>\n", " <td>.labels.npy</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td>013973</td>\n", " <td>02691156</td>\n", " <td>0</td>\n", " <td>.points.npy</td>\n", " <td>.labels.npy</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>007190</td>\n", " <td>02691156</td>\n", " <td>0</td>\n", " <td>.points.npy</td>\n", " <td>.labels.npy</td>\n", " </tr>\n", " <tr>\n", " <th>4</th>\n", " <td>010360</td>\n", " <td>02691156</td>\n", " <td>0</td>\n", " <td>.points.npy</td>\n", " <td>.labels.npy</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " example_id subdir categ_idx ext ext_labels\n", "0 000908 02691156 0 .points.npy .labels.npy\n", "1 010886 02691156 0 .points.npy .labels.npy\n", "2 013973 02691156 0 .points.npy .labels.npy\n", "3 007190 02691156 0 .points.npy .labels.npy\n", "4 010360 02691156 0 .points.npy .labels.npy" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "print(len(df_train))\n", "df_train.head()" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "7052\n" ] }, { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>example_id</th>\n", " <th>subdir</th>\n", " <th>categ_idx</th>\n", " <th>ext</th>\n", " <th>ext_labels</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>005663</td>\n", " <td>02691156</td>\n", " <td>0</td>\n", " <td>.points.npy</td>\n", " <td>.labels.npy</td>\n", " </tr>\n", " 
<tr>\n", " <th>1</th>\n", " <td>011957</td>\n", " <td>02691156</td>\n", " <td>0</td>\n", " <td>.points.npy</td>\n", " <td>.labels.npy</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td>009038</td>\n", " <td>02691156</td>\n", " <td>0</td>\n", " <td>.points.npy</td>\n", " <td>.labels.npy</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>009906</td>\n", " <td>02691156</td>\n", " <td>0</td>\n", " <td>.points.npy</td>\n", " <td>.labels.npy</td>\n", " </tr>\n", " <tr>\n", " <th>4</th>\n", " <td>004778</td>\n", " <td>02691156</td>\n", " <td>0</td>\n", " <td>.points.npy</td>\n", " <td>.labels.npy</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " example_id subdir categ_idx ext ext_labels\n", "0 005663 02691156 0 .points.npy .labels.npy\n", "1 011957 02691156 0 .points.npy .labels.npy\n", "2 009038 02691156 0 .points.npy .labels.npy\n", "3 009906 02691156 0 .points.npy .labels.npy\n", "4 004778 02691156 0 .points.npy .labels.npy" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "print(len(df_valid))\n", "df_valid.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# DataSets config\n", "You can create PointsDataset using the configuration." 
] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DataSourceConfig;\n", " root_dir: data/npy/train\n", " batch_size: 16\n", " num_workers: 12\n", " init_numpy_random_seed: True\n", " num_classes: 50\n", " num_classes_by_category: [4, 2, 2, 4, 4, 3, 3, 2, 4, 2, 6, 2, 3, 3, 3, 3]\n", " class_offsets: [ 0 4 6 8 12 16 19 22 24 28 30 36 38 41 44 47 50]\n", " Items count: 6955" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_source_config = DataSourceConfig( \n", " root_dir=SOURCE_DIR / 'npy' / 'train', \n", " df=df_train,\n", " batch_size=16,\n", " num_workers=12,\n", " num_classes=50,\n", " num_classes_by_category=num_classes_by_category,\n", " )\n", "train_source_config" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DataSourceConfig;\n", " root_dir: data/npy/valid\n", " batch_size: 16\n", " num_workers: 12\n", " init_numpy_random_seed: False\n", " num_classes: 50\n", " num_classes_by_category: [4, 2, 2, 4, 4, 3, 3, 2, 4, 2, 6, 2, 3, 3, 3, 3]\n", " class_offsets: [ 0 4 6 8 12 16 19 22 24 28 30 36 38 41 44 47 50]\n", " Items count: 7052" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "valid_source_config = DataSourceConfig( \n", " root_dir=SOURCE_DIR / 'npy' / 'valid',\n", " df=df_valid,\n", " batch_size=16,\n", " num_workers=12,\n", " num_classes=50,\n", " num_classes_by_category=num_classes_by_category,\n", " init_numpy_random_seed=False,\n", " )\n", "valid_source_config" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "train_source_config.check_accordance(valid_source_config)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Datasets" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Load file names: 
100%|██████████| 6955/6955 [00:01<00:00, 6525.49it/s]\n", "Load file names: 100%|██████████| 7052/7052 [00:01<00:00, 6516.53it/s]\n" ] } ], "source": [ "train_items = PointsDataset.from_source_config(train_source_config, reader_fn=reader_fn)\n", "valid_items = PointsDataset.from_source_config(valid_source_config, reader_fn=reader_fn)" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Check files exist: 100%|██████████| 6955/6955 [00:00<00:00, 188045.98it/s]\n", "Check files exist: 100%|██████████| 7052/7052 [00:00<00:00, 184664.28it/s]\n" ] } ], "source": [ "train_items.check()\n", "valid_items.check()\n", "#train_items.check_num_classes(max_num_examples=100)" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "PointsDataset (6955 items)\n", "('000908', n: 2460),('010886', n: 2407),('013973', n: 2463),('007190', n: 2690),('010360', n: 2438)\n", "Path: data/npy/train" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_items" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "PointsItem ('002337', n: 2712)" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "o = train_items.get(5)\n", "o" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's see what we've done with one example. 
" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "PointsItem (002337)\n", "points shape: (2712, 3) dtype: float32 min: -0.35228, max: 0.35657, mean: -0.01345\n", "labels shape: (2712,) dtype: int64 min: 1, max: 4, mean: 1.74410\n" ] }, { "data": { "text/html": [ "Saved to file: <a href='images/unet_24_detailed/fig_1.png' target='_blank'>images/unet_24_detailed/fig_1.png</a>" ], "text/plain": [ "images/unet_24_detailed/fig_1.png" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "<IPython.core.display.Image object>" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "o.describe()\n", "o.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Transforms" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Define transforms\n", "\n", "In order to reproduce the [example of SparseConvNet](https://github.com/facebookresearch/SparseConvNet/tree/master/examples/3d_segmentation), the same transformations have been reimplemented in the style of fast.ai transforms.\n", "\n", "The following cells define the transformations: preprocessing (PRE_TFMS); augmentation (AUGS_TRAIN and AUGS_VALID); and the transformation that converts the point cloud to a sparse representation (SPARSE_TFMS). 
Sparse representation is the input format for the SparseConvNet model and contains a list of voxels and their features." ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "resolution = 24\n", "\n", "PRE_TFMS = [\n", " T.fit_to_sphere(center=False),\n", " T.shift_labels(offset=-1)\n", " ]\n", "\n", "AUGS_TRAIN = [\n", " T.rotate(),\n", " T.flip_x(p=0.5),\n", "]\n", "\n", "AUGS_VALID = [\n", " T.rotate(),\n", " T.flip_x(p=0.5),\n", "]\n", "\n", "SPARSE_TFMS = [\n", " T.translate(offset=2), # segment [-1, 1] ---> segment [1, 3]\n", " T.scale(scale=resolution),\n", " T.merge_features(ones=True),\n", " T.to_sparse_voxels(),\n", "]\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's look at the last example of the training set before and after the train transformations:\n", "\n", "initial representation:" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "PointsItem (012843)\n", "points shape: (2640, 3) dtype: float32 min: -0.45134, max: 0.44987, mean: 0.01532\n", "labels shape: (2640,) dtype: int64 min: 48, max: 49, mean: 48.21932\n" ] }, { "data": { "text/html": [ "Saved to file: <a href='images/unet_24_detailed/fig_2.png' target='_blank'>images/unet_24_detailed/fig_2.png</a>" ], "text/plain": [ "images/unet_24_detailed/fig_2.png" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "<IPython.core.display.Image object>" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "o = train_items[-1]\n", "o.describe()\n", "o.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "transformed:" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "id: 012843\n", "coords shape: (2640, 3) dtype: int64 min: 27, max: 70, mean: 47.64508\n", "features shape: (2640, 
1) dtype: float32 min: 1.00000, max: 1.00000, mean: 1.00000\n", "x shape: (2640,) dtype: int64 min: 27, max: 70, mean: 47.19735\n", "y shape: (2640,) dtype: int64 min: 34, max: 59, mean: 44.41023\n", "z shape: (2640,) dtype: int64 min: 40, max: 59, mean: 51.32765\n", "labels shape: (2640,) dtype: int64 min: 47, max: 48, mean: 47.21932\n", "voxels: 1011\n", "points / voxels: 2.6112759643916914\n" ] }, { "data": { "text/html": [ "Saved to file: <a href='images/unet_24_detailed/fig_3.png' target='_blank'>images/unet_24_detailed/fig_3.png</a>" ], "text/plain": [ "images/unet_24_detailed/fig_3.png" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "<IPython.core.display.Image object>" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "np.random.seed(42)\n", "b = o.copy().apply_tfms(PRE_TFMS + AUGS_TRAIN + SPARSE_TFMS)\n", "b.describe()\n", "b.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Apply transforms to datasets" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "tfms = (\n", " PRE_TFMS + AUGS_TRAIN + SPARSE_TFMS,\n", " PRE_TFMS + AUGS_VALID + SPARSE_TFMS,\n", ")\n", "\n", "train_items.transform(tfms[0])\n", "pass\n", "\n", "valid_items.transform(tfms[1])\n", "pass" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# DataBunch" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In fast.ai, data is represented by a DataBunch, which contains train, valid and (optionally) test data loaders." 
] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [], "source": [ "from fastai_sparse.data import SparseDataBunch\n", "from data import merge_fn" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Train: 6955, shuffle: True, batch_size: 16, num_workers: 12, num_batches: 434, drop_last: True\n", "Valid: 7052, shuffle: False, batch_size: 16, num_workers: 12, num_batches: 441, drop_last: False\n" ] } ], "source": [ "data = SparseDataBunch.create(train_ds=train_items,\n", " valid_ds=valid_items,\n", " collate_fn=merge_fn,)\n", "\n", "data.describe()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Dataloader idle-run speed measurement" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "num_cpus: 16\n", "Model name: AMD Ryzen 7 1700 Eight-Core Processor\n", "\n" ] } ], "source": [ "from fastai_sparse.core import num_cpus\n", "print(\"num_cpus:\", num_cpus())\n", "!lscpu | grep \"Model\"\n", "print()\n", "\n", "#data.describe()" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 434/434 [00:04<00:00, 106.94it/s]\n" ] } ], "source": [ "# train\n", "t = tqdm(enumerate(data.train_dl), total=len(data.train_dl))\n", "for i, batch in t:\n", " pass" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 441/441 [00:02<00:00, 184.60it/s]\n" ] } ], "source": [ "# valid\n", "t = tqdm(enumerate(data.valid_dl), total=len(data.valid_dl))\n", "for i, batch in t:\n", " pass" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "U-Net SparseConvNet implementation 
([link](https://github.com/facebookresearch/SparseConvNet/blob/master/examples/3d_segmentation/unet.py)): " ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "SparseModelConfig;\n", " spatial_size: 192\n", " dimension: 3\n", " block_reps: 1\n", " m: 32\n", " num_planes: [32, 64, 96, 128, 160]\n", " residual_blocks: False\n", " num_classes: 50\n", " num_input_features: 1\n", " mode: 3\n", " downsample: [2, 2]\n", " bias: False" ] }, "execution_count": 30, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from fastai_sparse.learner import SparseModelConfig\n", "\n", "\n", "model_config = SparseModelConfig(spatial_size=24 * 8, num_input_features=1)\n", "model_config.check_accordance(data.train_ds.source_config, sparse_item=data.train_ds[0])\n", "model_config" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [], "source": [ "class Model(nn.Module):\n", " def __init__(self, cfg):\n", " nn.Module.__init__(self)\n", " \n", " spatial_size = torch.LongTensor([cfg.spatial_size]*3)\n", " \n", " self.sparseModel = scn.Sequential(\n", " scn.InputLayer(cfg.dimension, spatial_size, mode=cfg.mode),\n", " scn.SubmanifoldConvolution(cfg.dimension, nIn=cfg.num_input_features, nOut=cfg.m, filter_size=3, bias=cfg.bias),\n", " scn.UNet(cfg.dimension, cfg.block_reps, cfg.num_planes, residual_blocks=cfg.residual_blocks, downsample=cfg.downsample),\n", " scn.BatchNormReLU(cfg.m),\n", " scn.OutputLayer(cfg.dimension),\n", " )\n", " self.linear = nn.Linear(cfg.m, cfg.num_classes)\n", "\n", " def forward(self, xb):\n", " coords = xb['coords']\n", " features = xb['features']\n", " x = [coords, features]\n", "\n", " x = self.sparseModel(x)\n", " x = self.linear(x)\n", " return x\n", "\n", "model = Model(model_config)" ] }, { "cell_type": "code", "execution_count": 32, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "Model(\n", " (sparseModel): Sequential(\n", " (0): 
InputLayer()\n", " (1): SubmanifoldConvolution 1->32 C3\n", " (2): Sequential(\n", " (0): Sequential(\n", " (0): BatchNormLeakyReLU(32,eps=0.0001,momentum=0.9,affine=True,leakiness=0)\n", " (1): SubmanifoldConvolution 32->32 C3\n", " )\n", " (1): ConcatTable(\n", " (0): Identity()\n", " (1): Sequential(\n", " (0): BatchNormLeakyReLU(32,eps=0.0001,momentum=0.9,affine=True,leakiness=0)\n", " (1): Convolution 32->64 C2/2\n", " (2): Sequential(\n", " (0): Sequential(\n", " (0): BatchNormLeakyReLU(64,eps=0.0001,momentum=0.9,affine=True,leakiness=0)\n", " (1): SubmanifoldConvolution 64->64 C3\n", " )\n", " (1): ConcatTable(\n", " (0): Identity()\n", " (1): Sequential(\n", " (0): BatchNormLeakyReLU(64,eps=0.0001,momentum=0.9,affine=True,leakiness=0)\n", " (1): Convolution 64->96 C2/2\n", " (2): Sequential(\n", " (0): Sequential(\n", " (0): BatchNormLeakyReLU(96,eps=0.0001,momentum=0.9,affine=True,leakiness=0)\n", " (1): SubmanifoldConvolution 96->96 C3\n", " )\n", " (1): ConcatTable(\n", " (0): Identity()\n", " (1): Sequential(\n", " (0): BatchNormLeakyReLU(96,eps=0.0001,momentum=0.9,affine=True,leakiness=0)\n", " (1): Convolution 96->128 C2/2\n", " (2): Sequential(\n", " (0): Sequential(\n", " (0): BatchNormLeakyReLU(128,eps=0.0001,momentum=0.9,affine=True,leakiness=0)\n", " (1): SubmanifoldConvolution 128->128 C3\n", " )\n", " (1): ConcatTable(\n", " (0): Identity()\n", " (1): Sequential(\n", " (0): BatchNormLeakyReLU(128,eps=0.0001,momentum=0.9,affine=True,leakiness=0)\n", " (1): Convolution 128->160 C2/2\n", " (2): Sequential(\n", " (0): Sequential(\n", " (0): BatchNormLeakyReLU(160,eps=0.0001,momentum=0.9,affine=True,leakiness=0)\n", " (1): SubmanifoldConvolution 160->160 C3\n", " )\n", " )\n", " (3): BatchNormLeakyReLU(160,eps=0.0001,momentum=0.9,affine=True,leakiness=0)\n", " (4): Deconvolution 160->128 C2/2\n", " )\n", " )\n", " (2): JoinTable()\n", " (3): Sequential(\n", " (0): BatchNormLeakyReLU(256,eps=0.0001,momentum=0.9,affine=True,leakiness=0)\n", " (1): 
SubmanifoldConvolution 256->128 C3\n", " )\n", " )\n", " (3): BatchNormLeakyReLU(128,eps=0.0001,momentum=0.9,affine=True,leakiness=0)\n", " (4): Deconvolution 128->96 C2/2\n", " )\n", " )\n", " (2): JoinTable()\n", " (3): Sequential(\n", " (0): BatchNormLeakyReLU(192,eps=0.0001,momentum=0.9,affine=True,leakiness=0)\n", " (1): SubmanifoldConvolution 192->96 C3\n", " )\n", " )\n", " (3): BatchNormLeakyReLU(96,eps=0.0001,momentum=0.9,affine=True,leakiness=0)\n", " (4): Deconvolution 96->64 C2/2\n", " )\n", " )\n", " (2): JoinTable()\n", " (3): Sequential(\n", " (0): BatchNormLeakyReLU(128,eps=0.0001,momentum=0.9,affine=True,leakiness=0)\n", " (1): SubmanifoldConvolution 128->64 C3\n", " )\n", " )\n", " (3): BatchNormLeakyReLU(64,eps=0.0001,momentum=0.9,affine=True,leakiness=0)\n", " (4): Deconvolution 64->32 C2/2\n", " )\n", " )\n", " (2): JoinTable()\n", " (3): Sequential(\n", " (0): BatchNormLeakyReLU(64,eps=0.0001,momentum=0.9,affine=True,leakiness=0)\n", " (1): SubmanifoldConvolution 64->32 C3\n", " )\n", " )\n", " (3): BatchNormReLU(32,eps=0.0001,momentum=0.9,affine=True)\n", " (4): OutputLayer()\n", " )\n", " (linear): Linear(in_features=32, out_features=50, bias=True)\n", ")" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model" ] }, { "cell_type": "code", "execution_count": 33, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Total: 3,841,234\n" ] }, { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>name</th>\n", " <th>number</th>\n", " <th>shape</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " 
<tr>\n", " <th>0</th>\n", " <td>sparseModel.1.weight</td>\n", " <td>864</td>\n", " <td>27 x 1 x 32</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>sparseModel.2.0.0.weight</td>\n", " <td>32</td>\n", " <td>32</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td>sparseModel.2.0.0.bias</td>\n", " <td>32</td>\n", " <td>32</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>sparseModel.2.0.1.weight</td>\n", " <td>27648</td>\n", " <td>27 x 32 x 32</td>\n", " </tr>\n", " <tr>\n", " <th>4</th>\n", " <td>sparseModel.2.1.1.0.weight</td>\n", " <td>32</td>\n", " <td>32</td>\n", " </tr>\n", " <tr>\n", " <th>5</th>\n", " <td>sparseModel.2.1.1.0.bias</td>\n", " <td>32</td>\n", " <td>32</td>\n", " </tr>\n", " <tr>\n", " <th>6</th>\n", " <td>sparseModel.2.1.1.1.weight</td>\n", " <td>16384</td>\n", " <td>8 x 32 x 64</td>\n", " </tr>\n", " <tr>\n", " <th>7</th>\n", " <td>sparseModel.2.1.1.2.0.0.weight</td>\n", " <td>64</td>\n", " <td>64</td>\n", " </tr>\n", " <tr>\n", " <th>8</th>\n", " <td>sparseModel.2.1.1.2.0.0.bias</td>\n", " <td>64</td>\n", " <td>64</td>\n", " </tr>\n", " <tr>\n", " <th>9</th>\n", " <td>sparseModel.2.1.1.2.0.1.weight</td>\n", " <td>110592</td>\n", " <td>27 x 64 x 64</td>\n", " </tr>\n", " <tr>\n", " <th>10</th>\n", " <td>sparseModel.2.1.1.2.1.1.0.weight</td>\n", " <td>64</td>\n", " <td>64</td>\n", " </tr>\n", " <tr>\n", " <th>11</th>\n", " <td>sparseModel.2.1.1.2.1.1.0.bias</td>\n", " <td>64</td>\n", " <td>64</td>\n", " </tr>\n", " <tr>\n", " <th>12</th>\n", " <td>sparseModel.2.1.1.2.1.1.1.weight</td>\n", " <td>49152</td>\n", " <td>8 x 64 x 96</td>\n", " </tr>\n", " <tr>\n", " <th>13</th>\n", " <td>sparseModel.2.1.1.2.1.1.2.0.0.weight</td>\n", " <td>96</td>\n", " <td>96</td>\n", " </tr>\n", " <tr>\n", " <th>14</th>\n", " <td>sparseModel.2.1.1.2.1.1.2.0.0.bias</td>\n", " <td>96</td>\n", " <td>96</td>\n", " </tr>\n", " <tr>\n", " <th>15</th>\n", " <td>sparseModel.2.1.1.2.1.1.2.0.1.weight</td>\n", " <td>248832</td>\n", " <td>27 x 96 x 96</td>\n", " 
</tr>\n", " <tr>\n", " <th>16</th>\n", " <td>sparseModel.2.1.1.2.1.1.2.1.1.0.weight</td>\n", " <td>96</td>\n", " <td>96</td>\n", " </tr>\n", " <tr>\n", " <th>17</th>\n", " <td>sparseModel.2.1.1.2.1.1.2.1.1.0.bias</td>\n", " <td>96</td>\n", " <td>96</td>\n", " </tr>\n", " <tr>\n", " <th>18</th>\n", " <td>sparseModel.2.1.1.2.1.1.2.1.1.1.weight</td>\n", " <td>98304</td>\n", " <td>8 x 96 x 128</td>\n", " </tr>\n", " <tr>\n", " <th>19</th>\n", " <td>sparseModel.2.1.1.2.1.1.2.1.1.2.0.0.weight</td>\n", " <td>128</td>\n", " <td>128</td>\n", " </tr>\n", " <tr>\n", " <th>20</th>\n", " <td>sparseModel.2.1.1.2.1.1.2.1.1.2.0.0.bias</td>\n", " <td>128</td>\n", " <td>128</td>\n", " </tr>\n", " <tr>\n", " <th>21</th>\n", " <td>sparseModel.2.1.1.2.1.1.2.1.1.2.0.1.weight</td>\n", " <td>442368</td>\n", " <td>27 x 128 x 128</td>\n", " </tr>\n", " <tr>\n", " <th>22</th>\n", " <td>sparseModel.2.1.1.2.1.1.2.1.1.2.1.1.0.weight</td>\n", " <td>128</td>\n", " <td>128</td>\n", " </tr>\n", " <tr>\n", " <th>23</th>\n", " <td>sparseModel.2.1.1.2.1.1.2.1.1.2.1.1.0.bias</td>\n", " <td>128</td>\n", " <td>128</td>\n", " </tr>\n", " <tr>\n", " <th>24</th>\n", " <td>sparseModel.2.1.1.2.1.1.2.1.1.2.1.1.1.weight</td>\n", " <td>163840</td>\n", " <td>8 x 128 x 160</td>\n", " </tr>\n", " <tr>\n", " <th>25</th>\n", " <td>sparseModel.2.1.1.2.1.1.2.1.1.2.1.1.2.0.0.weight</td>\n", " <td>160</td>\n", " <td>160</td>\n", " </tr>\n", " <tr>\n", " <th>26</th>\n", " <td>sparseModel.2.1.1.2.1.1.2.1.1.2.1.1.2.0.0.bias</td>\n", " <td>160</td>\n", " <td>160</td>\n", " </tr>\n", " <tr>\n", " <th>27</th>\n", " <td>sparseModel.2.1.1.2.1.1.2.1.1.2.1.1.2.0.1.weight</td>\n", " <td>691200</td>\n", " <td>27 x 160 x 160</td>\n", " </tr>\n", " <tr>\n", " <th>28</th>\n", " <td>sparseModel.2.1.1.2.1.1.2.1.1.2.1.1.3.weight</td>\n", " <td>160</td>\n", " <td>160</td>\n", " </tr>\n", " <tr>\n", " <th>29</th>\n", " <td>sparseModel.2.1.1.2.1.1.2.1.1.2.1.1.3.bias</td>\n", " <td>160</td>\n", " <td>160</td>\n", " </tr>\n", " <tr>\n", " 
<th>30</th>\n", " <td>sparseModel.2.1.1.2.1.1.2.1.1.2.1.1.4.weight</td>\n", " <td>163840</td>\n", " <td>8 x 160 x 128</td>\n", " </tr>\n", " <tr>\n", " <th>31</th>\n", " <td>sparseModel.2.1.1.2.1.1.2.1.1.2.3.0.weight</td>\n", " <td>256</td>\n", " <td>256</td>\n", " </tr>\n", " <tr>\n", " <th>32</th>\n", " <td>sparseModel.2.1.1.2.1.1.2.1.1.2.3.0.bias</td>\n", " <td>256</td>\n", " <td>256</td>\n", " </tr>\n", " <tr>\n", " <th>33</th>\n", " <td>sparseModel.2.1.1.2.1.1.2.1.1.2.3.1.weight</td>\n", " <td>884736</td>\n", " <td>27 x 256 x 128</td>\n", " </tr>\n", " <tr>\n", " <th>34</th>\n", " <td>sparseModel.2.1.1.2.1.1.2.1.1.3.weight</td>\n", " <td>128</td>\n", " <td>128</td>\n", " </tr>\n", " <tr>\n", " <th>35</th>\n", " <td>sparseModel.2.1.1.2.1.1.2.1.1.3.bias</td>\n", " <td>128</td>\n", " <td>128</td>\n", " </tr>\n", " <tr>\n", " <th>36</th>\n", " <td>sparseModel.2.1.1.2.1.1.2.1.1.4.weight</td>\n", " <td>98304</td>\n", " <td>8 x 128 x 96</td>\n", " </tr>\n", " <tr>\n", " <th>37</th>\n", " <td>sparseModel.2.1.1.2.1.1.2.3.0.weight</td>\n", " <td>192</td>\n", " <td>192</td>\n", " </tr>\n", " <tr>\n", " <th>38</th>\n", " <td>sparseModel.2.1.1.2.1.1.2.3.0.bias</td>\n", " <td>192</td>\n", " <td>192</td>\n", " </tr>\n", " <tr>\n", " <th>39</th>\n", " <td>sparseModel.2.1.1.2.1.1.2.3.1.weight</td>\n", " <td>497664</td>\n", " <td>27 x 192 x 96</td>\n", " </tr>\n", " <tr>\n", " <th>40</th>\n", " <td>sparseModel.2.1.1.2.1.1.3.weight</td>\n", " <td>96</td>\n", " <td>96</td>\n", " </tr>\n", " <tr>\n", " <th>41</th>\n", " <td>sparseModel.2.1.1.2.1.1.3.bias</td>\n", " <td>96</td>\n", " <td>96</td>\n", " </tr>\n", " <tr>\n", " <th>42</th>\n", " <td>sparseModel.2.1.1.2.1.1.4.weight</td>\n", " <td>49152</td>\n", " <td>8 x 96 x 64</td>\n", " </tr>\n", " <tr>\n", " <th>43</th>\n", " <td>sparseModel.2.1.1.2.3.0.weight</td>\n", " <td>128</td>\n", " <td>128</td>\n", " </tr>\n", " <tr>\n", " <th>44</th>\n", " <td>sparseModel.2.1.1.2.3.0.bias</td>\n", " <td>128</td>\n", " <td>128</td>\n", " 
</tr>\n", " <tr>\n", " <th>45</th>\n", " <td>sparseModel.2.1.1.2.3.1.weight</td>\n", " <td>221184</td>\n", " <td>27 x 128 x 64</td>\n", " </tr>\n", " <tr>\n", " <th>46</th>\n", " <td>sparseModel.2.1.1.3.weight</td>\n", " <td>64</td>\n", " <td>64</td>\n", " </tr>\n", " <tr>\n", " <th>47</th>\n", " <td>sparseModel.2.1.1.3.bias</td>\n", " <td>64</td>\n", " <td>64</td>\n", " </tr>\n", " <tr>\n", " <th>48</th>\n", " <td>sparseModel.2.1.1.4.weight</td>\n", " <td>16384</td>\n", " <td>8 x 64 x 32</td>\n", " </tr>\n", " <tr>\n", " <th>49</th>\n", " <td>sparseModel.2.3.0.weight</td>\n", " <td>64</td>\n", " <td>64</td>\n", " </tr>\n", " <tr>\n", " <th>50</th>\n", " <td>sparseModel.2.3.0.bias</td>\n", " <td>64</td>\n", " <td>64</td>\n", " </tr>\n", " <tr>\n", " <th>51</th>\n", " <td>sparseModel.2.3.1.weight</td>\n", " <td>55296</td>\n", " <td>27 x 64 x 32</td>\n", " </tr>\n", " <tr>\n", " <th>52</th>\n", " <td>sparseModel.3.weight</td>\n", " <td>32</td>\n", " <td>32</td>\n", " </tr>\n", " <tr>\n", " <th>53</th>\n", " <td>sparseModel.3.bias</td>\n", " <td>32</td>\n", " <td>32</td>\n", " </tr>\n", " <tr>\n", " <th>54</th>\n", " <td>linear.weight</td>\n", " <td>1600</td>\n", " <td>50 x 32</td>\n", " </tr>\n", " <tr>\n", " <th>55</th>\n", " <td>linear.bias</td>\n", " <td>50</td>\n", " <td>50</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " name number shape\n", "0 sparseModel.1.weight 864 27 x 1 x 32\n", "1 sparseModel.2.0.0.weight 32 32\n", "2 sparseModel.2.0.0.bias 32 32\n", "3 sparseModel.2.0.1.weight 27648 27 x 32 x 32\n", "4 sparseModel.2.1.1.0.weight 32 32\n", "5 sparseModel.2.1.1.0.bias 32 32\n", "6 sparseModel.2.1.1.1.weight 16384 8 x 32 x 64\n", "7 sparseModel.2.1.1.2.0.0.weight 64 64\n", "8 sparseModel.2.1.1.2.0.0.bias 64 64\n", "9 sparseModel.2.1.1.2.0.1.weight 110592 27 x 64 x 64\n", "10 sparseModel.2.1.1.2.1.1.0.weight 64 64\n", "11 sparseModel.2.1.1.2.1.1.0.bias 64 64\n", "12 sparseModel.2.1.1.2.1.1.1.weight 49152 8 x 64 x 96\n", "13 
sparseModel.2.1.1.2.1.1.2.0.0.weight 96 96\n", "14 sparseModel.2.1.1.2.1.1.2.0.0.bias 96 96\n", "15 sparseModel.2.1.1.2.1.1.2.0.1.weight 248832 27 x 96 x 96\n", "16 sparseModel.2.1.1.2.1.1.2.1.1.0.weight 96 96\n", "17 sparseModel.2.1.1.2.1.1.2.1.1.0.bias 96 96\n", "18 sparseModel.2.1.1.2.1.1.2.1.1.1.weight 98304 8 x 96 x 128\n", "19 sparseModel.2.1.1.2.1.1.2.1.1.2.0.0.weight 128 128\n", "20 sparseModel.2.1.1.2.1.1.2.1.1.2.0.0.bias 128 128\n", "21 sparseModel.2.1.1.2.1.1.2.1.1.2.0.1.weight 442368 27 x 128 x 128\n", "22 sparseModel.2.1.1.2.1.1.2.1.1.2.1.1.0.weight 128 128\n", "23 sparseModel.2.1.1.2.1.1.2.1.1.2.1.1.0.bias 128 128\n", "24 sparseModel.2.1.1.2.1.1.2.1.1.2.1.1.1.weight 163840 8 x 128 x 160\n", "25 sparseModel.2.1.1.2.1.1.2.1.1.2.1.1.2.0.0.weight 160 160\n", "26 sparseModel.2.1.1.2.1.1.2.1.1.2.1.1.2.0.0.bias 160 160\n", "27 sparseModel.2.1.1.2.1.1.2.1.1.2.1.1.2.0.1.weight 691200 27 x 160 x 160\n", "28 sparseModel.2.1.1.2.1.1.2.1.1.2.1.1.3.weight 160 160\n", "29 sparseModel.2.1.1.2.1.1.2.1.1.2.1.1.3.bias 160 160\n", "30 sparseModel.2.1.1.2.1.1.2.1.1.2.1.1.4.weight 163840 8 x 160 x 128\n", "31 sparseModel.2.1.1.2.1.1.2.1.1.2.3.0.weight 256 256\n", "32 sparseModel.2.1.1.2.1.1.2.1.1.2.3.0.bias 256 256\n", "33 sparseModel.2.1.1.2.1.1.2.1.1.2.3.1.weight 884736 27 x 256 x 128\n", "34 sparseModel.2.1.1.2.1.1.2.1.1.3.weight 128 128\n", "35 sparseModel.2.1.1.2.1.1.2.1.1.3.bias 128 128\n", "36 sparseModel.2.1.1.2.1.1.2.1.1.4.weight 98304 8 x 128 x 96\n", "37 sparseModel.2.1.1.2.1.1.2.3.0.weight 192 192\n", "38 sparseModel.2.1.1.2.1.1.2.3.0.bias 192 192\n", "39 sparseModel.2.1.1.2.1.1.2.3.1.weight 497664 27 x 192 x 96\n", "40 sparseModel.2.1.1.2.1.1.3.weight 96 96\n", "41 sparseModel.2.1.1.2.1.1.3.bias 96 96\n", "42 sparseModel.2.1.1.2.1.1.4.weight 49152 8 x 96 x 64\n", "43 sparseModel.2.1.1.2.3.0.weight 128 128\n", "44 sparseModel.2.1.1.2.3.0.bias 128 128\n", "45 sparseModel.2.1.1.2.3.1.weight 221184 27 x 128 x 64\n", "46 sparseModel.2.1.1.3.weight 64 64\n", "47 
sparseModel.2.1.1.3.bias 64 64\n", "48 sparseModel.2.1.1.4.weight 16384 8 x 64 x 32\n", "49 sparseModel.2.3.0.weight 64 64\n", "50 sparseModel.2.3.0.bias 64 64\n", "51 sparseModel.2.3.1.weight 55296 27 x 64 x 32\n", "52 sparseModel.3.weight 32 32\n", "53 sparseModel.3.bias 32 32\n", "54 linear.weight 1600 50 x 32\n", "55 linear.bias 50 50" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "utils.print_trainable_parameters(model)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Learner creation\n", "`Learner` is the core fast.ai class (imported here from `fastai_sparse.learner`): it bundles the model, the DataBunch and the optimizer settings, and implements the training loop and prediction." ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [], "source": [ "from torch import optim\n", "from functools import partial\n", "from fastai.callbacks.general_sched import TrainingPhase, GeneralScheduler\n", "#from fastai.callbacks.csv_logger import CSVLogger\n", "#from fastai.callbacks.tracker import SaveModelCallback\n", "from fastai.callback import annealing_exp\n", "\n", "\n", "from fastai_sparse.learner import Learner\n", "from fastai_sparse.callbacks import TimeLogger, SaveModelCallback, CSVLogger, CSVLoggerIouByCategory\n", "\n", "from metrics import IouByCategories\n" ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [], "source": [ "learn = Learner(data, model,\n", " opt_func=partial(optim.SGD, momentum=0.9),\n", " wd=1e-4,\n", " true_wd=False,\n", " path=str(Path('results', experiment_name)))\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Learning Rate finder\n", "We use the learning rate finder provided by the fast.ai library to pick a suitable learning rate: it sweeps the learning rate from `start_lr` to `end_lr` during a short trial run and records the loss, which `learn.recorder.plot()` then plots against the learning rate." ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [ { "data": { "text/html": [], "text/plain": [ "<IPython.core.display.HTML object>" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", 
"text": [ "LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.\n", "CPU times: user 27.9 s, sys: 6.98 s, total: 34.9 s\n", "Wall time: 28.3 s\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "<Figure size 432x288 with 1 Axes>" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "%%time\n", "learn.lr_find(start_lr=1e-5, end_lr=100)\n", "learn.recorder.plot()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Train\n", "To visualize the learning process, we specify some additional callbacks " ] }, { "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [], "source": [ "learn.callbacks = []\n", "cb_iou = IouByCategories(learn, len(categories))\n", "learn.callbacks.append(cb_iou)\n", "\n", "learn.callbacks.append(TimeLogger(learn))\n", "learn.callbacks.append(CSVLogger(learn))\n", "learn.callbacks.append(CSVLoggerIouByCategory(learn, cb_iou, categories_names=classes))\n", "\n", "learn.callbacks.append(SaveModelCallback(learn, every='epoch', name='weights', overwrite=True))" ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [ { "data": { "text/html": [ "Total time: 09:16 <p><table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: left;\">\n", " <th>epoch</th>\n", " <th>train_loss</th>\n", " <th>valid_loss</th>\n", " <th>train_time</th>\n", " <th>valid_time</th>\n", " <th>train_waiou</th>\n", " <th>valid_waiou</th>\n", " <th>time</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <td>0</td>\n", " <td>2.066369</td>\n", " <td>1.948893</td>\n", " <td>121.718252</td>\n", " <td>63.441886</td>\n", " <td>0.451465</td>\n", " <td>0.518886</td>\n", " <td>03:05</td>\n", " </tr>\n", " <tr>\n", " <td>1</td>\n", " <td>1.672141</td>\n", " <td>1.583488</td>\n", " <td>121.557564</td>\n", " <td>63.407442</td>\n", " <td>0.548420</td>\n", " <td>0.600601</td>\n", " <td>03:05</td>\n", " </tr>\n", " <tr>\n", " 
<td>2</td>\n", " <td>1.318663</td>\n", " <td>1.290712</td>\n", " <td>122.697799</td>\n", " <td>63.085566</td>\n", " <td>0.616172</td>\n", " <td>0.590881</td>\n", " <td>03:05</td>\n", " </tr>\n", " </tbody>\n", "</table>" ], "text/plain": [ "<IPython.core.display.HTML object>" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn.fit_annealing_exp(3, lr=0.1, lr_decay=4e-2, momentum=0.9)" ] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [], "source": [ "# learn.fit_annealing_exp(100, lr=0.1, lr_decay=4e-2, momentum=0.9)" ] }, { "cell_type": "code", "execution_count": 41, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>epoch</th>\n", " <th>datatype</th>\n", " <th>average</th>\n", " <th>Airplane</th>\n", " <th>Bag</th>\n", " <th>Cap</th>\n", " <th>Car</th>\n", " <th>Chair</th>\n", " <th>Earphone</th>\n", " <th>Guitar</th>\n", " <th>Knife</th>\n", " <th>Lamp</th>\n", " <th>Laptop</th>\n", " <th>Motorbike</th>\n", " <th>Mug</th>\n", " <th>Pistol</th>\n", " <th>Rocket</th>\n", " <th>Skateboard</th>\n", " <th>Table</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>1</th>\n", " <td>0</td>\n", " <td>valid</td>\n", " <td>0.518886</td>\n", " <td>0.379600</td>\n", " <td>0.492501</td>\n", " <td>0.414970</td>\n", " <td>0.227367</td>\n", " <td>0.509858</td>\n", " <td>0.531944</td>\n", " <td>0.484199</td>\n", " <td>0.445833</td>\n", " <td>0.539656</td>\n", " <td>0.344094</td>\n", " <td>0.231674</td>\n", " <td>0.473616</td>\n", " <td>0.243760</td>\n", " <td>0.283587</td>\n", " <td>0.391957</td>\n", " 
<td>0.704618</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td>1</td>\n", " <td>train</td>\n", " <td>0.548420</td>\n", " <td>0.416585</td>\n", " <td>0.472786</td>\n", " <td>0.366650</td>\n", " <td>0.240662</td>\n", " <td>0.555655</td>\n", " <td>0.513805</td>\n", " <td>0.504931</td>\n", " <td>0.574740</td>\n", " <td>0.525350</td>\n", " <td>0.424994</td>\n", " <td>0.261024</td>\n", " <td>0.484206</td>\n", " <td>0.273017</td>\n", " <td>0.287875</td>\n", " <td>0.414793</td>\n", " <td>0.719739</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>1</td>\n", " <td>valid</td>\n", " <td>0.600601</td>\n", " <td>0.516312</td>\n", " <td>0.492734</td>\n", " <td>0.398441</td>\n", " <td>0.268659</td>\n", " <td>0.637381</td>\n", " <td>0.516604</td>\n", " <td>0.608734</td>\n", " <td>0.617883</td>\n", " <td>0.553338</td>\n", " <td>0.431588</td>\n", " <td>0.263243</td>\n", " <td>0.545524</td>\n", " <td>0.256332</td>\n", " <td>0.275717</td>\n", " <td>0.462712</td>\n", " <td>0.747984</td>\n", " </tr>\n", " <tr>\n", " <th>4</th>\n", " <td>2</td>\n", " <td>train</td>\n", " <td>0.616172</td>\n", " <td>0.543162</td>\n", " <td>0.476210</td>\n", " <td>0.397841</td>\n", " <td>0.248176</td>\n", " <td>0.669318</td>\n", " <td>0.524780</td>\n", " <td>0.638997</td>\n", " <td>0.636003</td>\n", " <td>0.541066</td>\n", " <td>0.411982</td>\n", " <td>0.266059</td>\n", " <td>0.573489</td>\n", " <td>0.281751</td>\n", " <td>0.291917</td>\n", " <td>0.482149</td>\n", " <td>0.755185</td>\n", " </tr>\n", " <tr>\n", " <th>5</th>\n", " <td>2</td>\n", " <td>valid</td>\n", " <td>0.590881</td>\n", " <td>0.523460</td>\n", " <td>0.523732</td>\n", " <td>0.392183</td>\n", " <td>0.228041</td>\n", " <td>0.613275</td>\n", " <td>0.552095</td>\n", " <td>0.596156</td>\n", " <td>0.624740</td>\n", " <td>0.534418</td>\n", " <td>0.383885</td>\n", " <td>0.270513</td>\n", " <td>0.589842</td>\n", " <td>0.294795</td>\n", " <td>0.262375</td>\n", " <td>0.492768</td>\n", " <td>0.742103</td>\n", " </tr>\n", " </tbody>\n", 
"</table>\n", "</div>" ], "text/plain": [ " epoch datatype average Airplane Bag Cap Car Chair \\\n", "1 0 valid 0.518886 0.379600 0.492501 0.414970 0.227367 0.509858 \n", "2 1 train 0.548420 0.416585 0.472786 0.366650 0.240662 0.555655 \n", "3 1 valid 0.600601 0.516312 0.492734 0.398441 0.268659 0.637381 \n", "4 2 train 0.616172 0.543162 0.476210 0.397841 0.248176 0.669318 \n", "5 2 valid 0.590881 0.523460 0.523732 0.392183 0.228041 0.613275 \n", "\n", " Earphone Guitar Knife Lamp Laptop Motorbike Mug \\\n", "1 0.531944 0.484199 0.445833 0.539656 0.344094 0.231674 0.473616 \n", "2 0.513805 0.504931 0.574740 0.525350 0.424994 0.261024 0.484206 \n", "3 0.516604 0.608734 0.617883 0.553338 0.431588 0.263243 0.545524 \n", "4 0.524780 0.638997 0.636003 0.541066 0.411982 0.266059 0.573489 \n", "5 0.552095 0.596156 0.624740 0.534418 0.383885 0.270513 0.589842 \n", "\n", " Pistol Rocket Skateboard Table \n", "1 0.243760 0.283587 0.391957 0.704618 \n", "2 0.273017 0.287875 0.414793 0.719739 \n", "3 0.256332 0.275717 0.462712 0.747984 \n", "4 0.281751 0.291917 0.482149 0.755185 \n", "5 0.294795 0.262375 0.492768 0.742103 " ] }, "execution_count": 41, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learn.find_callback('CSVLoggerIouByCategory').read_logged_file().tail()" ] }, { "cell_type": "code", "execution_count": 42, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>epoch</th>\n", " <th>train_loss</th>\n", " <th>valid_loss</th>\n", " <th>train_time</th>\n", " <th>valid_time</th>\n", " <th>train_waiou</th>\n", " <th>valid_waiou</th>\n", " <th>time</th>\n", " </tr>\n", " 
</thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>0</td>\n", " <td>2.066369</td>\n", " <td>1.948893</td>\n", " <td>121.718252</td>\n", " <td>63.441886</td>\n", " <td>0.451465</td>\n", " <td>0.518886</td>\n", " <td>03:05</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>1</td>\n", " <td>1.672141</td>\n", " <td>1.583488</td>\n", " <td>121.557564</td>\n", " <td>63.407442</td>\n", " <td>0.548420</td>\n", " <td>0.600601</td>\n", " <td>03:05</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td>2</td>\n", " <td>1.318663</td>\n", " <td>1.290712</td>\n", " <td>122.697799</td>\n", " <td>63.085566</td>\n", " <td>0.616172</td>\n", " <td>0.590881</td>\n", " <td>03:05</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " epoch train_loss valid_loss train_time valid_time train_waiou \\\n", "0 0 2.066369 1.948893 121.718252 63.441886 0.451465 \n", "1 1 1.672141 1.583488 121.557564 63.407442 0.548420 \n", "2 2 1.318663 1.290712 122.697799 63.085566 0.616172 \n", "\n", " valid_waiou time \n", "0 0.518886 03:05 \n", "1 0.600601 03:05 \n", "2 0.590881 03:05 " ] }, "execution_count": 42, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learn.find_callback('CSVLogger').read_logged_file().tail()" ] }, { "cell_type": "code", "execution_count": 43, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "<Figure size 432x288 with 1 Axes>" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "learn.recorder.plot_losses()" ] }, { "cell_type": "code", "execution_count": 44, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "<Figure size 432x288 with 1 Axes>" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "learn.recorder.plot_lr()" ] }, { "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "<Figure size 432x1152 with 4 Axes>" ] 
}, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "learn.recorder.plot_metrics()" ] }, { "cell_type": "code", "execution_count": 46, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "<Figure size 432x288 with 1 Axes>" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "learn.recorder.plot()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.8" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": { "height": "1124px", "left": "1506px", "top": "370.133px", "width": "243px" }, "toc_section_display": true, "toc_window_display": true } }, "nbformat": 4, "nbformat_minor": 2 }