{ "cells": [ { "cell_type": "markdown", "id": "55c7fada", "metadata": {}, "source": [ "\n", "# How to work with data iterators\n", "\n", "When the data provider for training or validation is an iterator\n", "(infinite, or finite with known or unknown size), here are some basic\n", "examples of how to set up a trainer or evaluator.\n", "\n", "" ] }, { "cell_type": "markdown", "id": "e97045a5", "metadata": {}, "source": [ "## Infinite iterator for training\n", "\n", "Let’s use an infinite data iterator as the training dataflow." ] }, { "cell_type": "code", "execution_count": 1, "id": "5b5f175a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1/3 : 1 - 63.862\n", "1/3 : 2 - 64.042\n", "1/3 : 3 - 63.936\n", "1/3 : 4 - 64.141\n", "1/3 : 5 - 64.767\n", "2/3 : 6 - 63.791\n", "2/3 : 7 - 64.565\n", "2/3 : 8 - 63.602\n", "2/3 : 9 - 63.995\n", "2/3 : 10 - 63.943\n", "3/3 : 11 - 63.831\n", "3/3 : 12 - 64.276\n", "3/3 : 13 - 64.148\n", "3/3 : 14 - 63.920\n", "3/3 : 15 - 64.226\n" ] }, { "data": { "text/plain": [ "State:\n", "\titeration: 15\n", "\tepoch: 3\n", "\tepoch_length: 5\n", "\tmax_epochs: 3\n", "\toutput: \n", "\tbatch: \n", "\tmetrics: \n", "\tdataloader: \n", "\tseed: \n", "\ttimes: " ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import torch\n", "from ignite.engine import Engine, Events\n", "\n", "torch.manual_seed(12)\n", "\n", "def infinite_iterator(batch_size):\n", "    while True:\n", "        batch = torch.rand(batch_size, 3, 32, 32)\n", "        yield batch\n", "\n", "def train_step(trainer, batch):\n", "    # ...\n", "    s = trainer.state\n", "    print(\n", "        f\"{s.epoch}/{s.max_epochs} : {s.iteration} - {batch.norm():.3f}\"\n", "    )\n", "\n", "trainer = Engine(train_step)\n", "\n", "# An infinite iterator never signals the end of an epoch,\n", "# so we need to specify epoch_length to define the epoch\n", "trainer.run(infinite_iterator(4), epoch_length=5, max_epochs=3)" ] }, { "cell_type": "markdown", "id": "a755b048", "metadata": {}, "source": [ "If we do not specify 
**epoch_length**, we can stop the training explicitly by calling [`terminate()`](https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html#ignite.engine.engine.Engine). In this case, only a single epoch is defined." ] }, { "cell_type": "code", "execution_count": 2, "id": "d48531dd", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1/1 : 1 - 63.862\n", "1/1 : 2 - 64.042\n", "1/1 : 3 - 63.936\n", "1/1 : 4 - 64.141\n", "1/1 : 5 - 64.767\n", "1/1 : 6 - 63.791\n", "1/1 : 7 - 64.565\n", "1/1 : 8 - 63.602\n", "1/1 : 9 - 63.995\n", "1/1 : 10 - 63.943\n", "1/1 : 11 - 63.831\n", "1/1 : 12 - 64.276\n", "1/1 : 13 - 64.148\n", "1/1 : 14 - 63.920\n", "1/1 : 15 - 64.226\n" ] }, { "data": { "text/plain": [ "State:\n", "\titeration: 15\n", "\tepoch: 1\n", "\tepoch_length: \n", "\tmax_epochs: 1\n", "\toutput: \n", "\tbatch: \n", "\tmetrics: \n", "\tdataloader: \n", "\tseed: \n", "\ttimes: " ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import torch\n", "from ignite.engine import Engine, Events\n", "\n", "torch.manual_seed(12)\n", "\n", "def infinite_iterator(batch_size):\n", "    while True:\n", "        batch = torch.rand(batch_size, 3, 32, 32)\n", "        yield batch\n", "\n", "def train_step(trainer, batch):\n", "    # ...\n", "    s = trainer.state\n", "    print(\n", "        f\"{s.epoch}/{s.max_epochs} : {s.iteration} - {batch.norm():.3f}\"\n", "    )\n", "\n", "trainer = Engine(train_step)\n", "\n", "# Stop the run once the 15th iteration has completed\n", "@trainer.on(Events.ITERATION_COMPLETED(once=15))\n", "def stop_training():\n", "    trainer.terminate()\n", "\n", "trainer.run(infinite_iterator(4))" ] }, { "cell_type": "markdown", "id": "30d63d14", "metadata": {}, "source": [ "The same code can be used for validating models." ] }, { "cell_type": "markdown", "id": "37190708", "metadata": {}, "source": [ "## Finite iterator with unknown length\n", "\n", "Let's use a finite data iterator whose length is unknown to the user. 
In\n", "the case of training, we would like to perform several passes over the\n", "dataflow, so we need to restart the data iterator when it is\n", "exhausted. In the code, we do not specify `epoch_length`; it is determined\n", "automatically." ] }, { "cell_type": "code", "execution_count": 3, "id": "199087b1", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1/5 : 1 - 0.000\n", "1/5 : 2 - 1.000\n", "1/5 : 3 - 2.000\n", "1/5 : 4 - 3.000\n", "1/5 : 5 - 4.000\n", "1/5 : 6 - 5.000\n", "1/5 : 7 - 6.000\n", "1/5 : 8 - 7.000\n", "1/5 : 9 - 8.000\n", "1/5 : 10 - 9.000\n", "1/5 : 11 - 10.000\n", "2/5 : 12 - 0.000\n", "2/5 : 13 - 1.000\n", "2/5 : 14 - 2.000\n", "2/5 : 15 - 3.000\n", "2/5 : 16 - 4.000\n", "2/5 : 17 - 5.000\n", "2/5 : 18 - 6.000\n", "2/5 : 19 - 7.000\n", "2/5 : 20 - 8.000\n", "2/5 : 21 - 9.000\n", "2/5 : 22 - 10.000\n", "3/5 : 23 - 0.000\n", "3/5 : 24 - 1.000\n", "3/5 : 25 - 2.000\n", "3/5 : 26 - 3.000\n", "3/5 : 27 - 4.000\n", "3/5 : 28 - 5.000\n", "3/5 : 29 - 6.000\n", "3/5 : 30 - 7.000\n", "3/5 : 31 - 8.000\n", "3/5 : 32 - 9.000\n", "3/5 : 33 - 10.000\n", "4/5 : 34 - 0.000\n", "4/5 : 35 - 1.000\n", "4/5 : 36 - 2.000\n", "4/5 : 37 - 3.000\n", "4/5 : 38 - 4.000\n", "4/5 : 39 - 5.000\n", "4/5 : 40 - 6.000\n", "4/5 : 41 - 7.000\n", "4/5 : 42 - 8.000\n", "4/5 : 43 - 9.000\n", "4/5 : 44 - 10.000\n", "5/5 : 45 - 0.000\n", "5/5 : 46 - 1.000\n", "5/5 : 47 - 2.000\n", "5/5 : 48 - 3.000\n", "5/5 : 49 - 4.000\n", "5/5 : 50 - 5.000\n", "5/5 : 51 - 6.000\n", "5/5 : 52 - 7.000\n", "5/5 : 53 - 8.000\n", "5/5 : 54 - 9.000\n", "5/5 : 55 - 10.000\n" ] }, { "data": { "text/plain": [ "State:\n", "\titeration: 55\n", "\tepoch: 5\n", "\tepoch_length: 11\n", "\tmax_epochs: 5\n", "\toutput: \n", "\tbatch: 10\n", "\tmetrics: \n", "\tdataloader: \n", "\tseed: \n", "\ttimes: " ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import torch\n", "from ignite.engine import Engine, Events\n", "\n", 
"torch.manual_seed(12)\n", "\n", "def finite_unk_size_data_iter():\n", "    for i in range(11):\n", "        yield i\n", "\n", "def train_step(trainer, batch):\n", "    # ...\n", "    s = trainer.state\n", "    print(\n", "        f\"{s.epoch}/{s.max_epochs} : {s.iteration} - {batch:.3f}\"\n", "    )\n", "\n", "trainer = Engine(train_step)\n", "\n", "# Recreate the exhausted iterator so that the next epoch has data\n", "@trainer.on(Events.DATALOADER_STOP_ITERATION)\n", "def restart_iter():\n", "    trainer.state.dataloader = finite_unk_size_data_iter()\n", "\n", "data_iter = finite_unk_size_data_iter()\n", "trainer.run(data_iter, max_epochs=5)" ] }, { "cell_type": "markdown", "id": "ee068ac8", "metadata": {}, "source": [ "For validation, a single pass is enough, so the code is simply:" ] }, { "cell_type": "code", "execution_count": 4, "id": "beae6490", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1/1 : 1 - 0.000\n", "1/1 : 2 - 1.000\n", "1/1 : 3 - 2.000\n", "1/1 : 4 - 3.000\n", "1/1 : 5 - 4.000\n", "1/1 : 6 - 5.000\n", "1/1 : 7 - 6.000\n", "1/1 : 8 - 7.000\n", "1/1 : 9 - 8.000\n", "1/1 : 10 - 9.000\n", "1/1 : 11 - 10.000\n" ] }, { "data": { "text/plain": [ "State:\n", "\titeration: 11\n", "\tepoch: 1\n", "\tepoch_length: 11\n", "\tmax_epochs: 1\n", "\toutput: \n", "\tbatch: \n", "\tmetrics: \n", "\tdataloader: \n", "\tseed: \n", "\ttimes: " ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import torch\n", "from ignite.engine import Engine, Events\n", "\n", "torch.manual_seed(12)\n", "\n", "def finite_unk_size_data_iter():\n", "    for i in range(11):\n", "        yield i\n", "\n", "def val_step(evaluator, batch):\n", "    # ...\n", "    s = evaluator.state\n", "    print(\n", "        f\"{s.epoch}/{s.max_epochs} : {s.iteration} - {batch:.3f}\"\n", "    )\n", "\n", "evaluator = Engine(val_step)\n", "\n", "data_iter = finite_unk_size_data_iter()\n", "evaluator.run(data_iter)" ] }, { "cell_type": "markdown", "id": "5d1abaa7", "metadata": {}, "source": [ "## Finite iterator with known length\n", "\n", "Let's use a finite data iterator 
with known size for training or validation. If we need to restart the data iterator, we can attach the restart handler to `Events.DATALOADER_STOP_ITERATION`, as in the unknown-size case. Here, however, we restart it explicitly after every `size` iterations:" ] }, { "cell_type": "code", "execution_count": 5, "id": "a7f519ac", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1/5 : 1 - 0.000\n", "1/5 : 2 - 1.000\n", "1/5 : 3 - 2.000\n", "1/5 : 4 - 3.000\n", "1/5 : 5 - 4.000\n", "1/5 : 6 - 5.000\n", "1/5 : 7 - 6.000\n", "1/5 : 8 - 7.000\n", "1/5 : 9 - 8.000\n", "1/5 : 10 - 9.000\n", "1/5 : 11 - 10.000\n", "2/5 : 12 - 0.000\n", "2/5 : 13 - 1.000\n", "2/5 : 14 - 2.000\n", "2/5 : 15 - 3.000\n", "2/5 : 16 - 4.000\n", "2/5 : 17 - 5.000\n", "2/5 : 18 - 6.000\n", "2/5 : 19 - 7.000\n", "2/5 : 20 - 8.000\n", "2/5 : 21 - 9.000\n", "2/5 : 22 - 10.000\n", "3/5 : 23 - 0.000\n", "3/5 : 24 - 1.000\n", "3/5 : 25 - 2.000\n", "3/5 : 26 - 3.000\n", "3/5 : 27 - 4.000\n", "3/5 : 28 - 5.000\n", "3/5 : 29 - 6.000\n", "3/5 : 30 - 7.000\n", "3/5 : 31 - 8.000\n", "3/5 : 32 - 9.000\n", "3/5 : 33 - 10.000\n", "4/5 : 34 - 0.000\n", "4/5 : 35 - 1.000\n", "4/5 : 36 - 2.000\n", "4/5 : 37 - 3.000\n", "4/5 : 38 - 4.000\n", "4/5 : 39 - 5.000\n", "4/5 : 40 - 6.000\n", "4/5 : 41 - 7.000\n", "4/5 : 42 - 8.000\n", "4/5 : 43 - 9.000\n", "4/5 : 44 - 10.000\n", "5/5 : 45 - 0.000\n", "5/5 : 46 - 1.000\n", "5/5 : 47 - 2.000\n", "5/5 : 48 - 3.000\n", "5/5 : 49 - 4.000\n", "5/5 : 50 - 5.000\n", "5/5 : 51 - 6.000\n", "5/5 : 52 - 7.000\n", "5/5 : 53 - 8.000\n", "5/5 : 54 - 9.000\n", "5/5 : 55 - 10.000\n" ] }, { "data": { "text/plain": [ "State:\n", "\titeration: 55\n", "\tepoch: 5\n", "\tepoch_length: 11\n", "\tmax_epochs: 5\n", "\toutput: \n", "\tbatch: 10\n", "\tmetrics: \n", "\tdataloader: \n", "\tseed: \n", "\ttimes: " ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import torch\n", "from ignite.engine import Engine, 
Events\n", "\n", "torch.manual_seed(12)\n", "\n", "size = 11\n", "\n", "def finite_size_data_iter(size):\n", "    for i in range(size):\n", "        yield i\n", "\n", "def train_step(trainer, batch):\n", "    # ...\n", "    s = trainer.state\n", "    print(\n", "        f\"{s.epoch}/{s.max_epochs} : {s.iteration} - {batch:.3f}\"\n", "    )\n", "\n", "trainer = Engine(train_step)\n", "\n", "# The size is known, so restart the iterator after every `size` iterations\n", "@trainer.on(Events.ITERATION_COMPLETED(every=size))\n", "def restart_iter():\n", "    trainer.state.dataloader = finite_size_data_iter(size)\n", "\n", "data_iter = finite_size_data_iter(size)\n", "trainer.run(data_iter, max_epochs=5)" ] }, { "cell_type": "markdown", "id": "a518b014", "metadata": {}, "source": [ "For validation, a single pass is enough, so the code is simply:" ] }, { "cell_type": "code", "execution_count": 6, "id": "d1402c18", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1/1 : 1 - 0.000\n", "1/1 : 2 - 1.000\n", "1/1 : 3 - 2.000\n", "1/1 : 4 - 3.000\n", "1/1 : 5 - 4.000\n", "1/1 : 6 - 5.000\n", "1/1 : 7 - 6.000\n", "1/1 : 8 - 7.000\n", "1/1 : 9 - 8.000\n", "1/1 : 10 - 9.000\n", "1/1 : 11 - 10.000\n" ] }, { "data": { "text/plain": [ "State:\n", "\titeration: 11\n", "\tepoch: 1\n", "\tepoch_length: 11\n", "\tmax_epochs: 1\n", "\toutput: \n", "\tbatch: \n", "\tmetrics: \n", "\tdataloader: \n", "\tseed: \n", "\ttimes: " ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import torch\n", "from ignite.engine import Engine, Events\n", "\n", "torch.manual_seed(12)\n", "\n", "size = 11\n", "\n", "def finite_size_data_iter(size):\n", "    for i in range(size):\n", "        yield i\n", "\n", "def val_step(evaluator, batch):\n", "    # ...\n", "    s = evaluator.state\n", "    print(\n", "        f\"{s.epoch}/{s.max_epochs} : {s.iteration} - {batch:.3f}\"\n", "    )\n", "\n", "evaluator = Engine(val_step)\n", "\n", "data_iter = finite_size_data_iter(size)\n", "evaluator.run(data_iter)" ] } ], "metadata": { "interpreter": { "hash": 
"668c1b3fdfcad7da09e9c177fb24f18a657bbc5f55005750960a78843b3807f7" }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.10" } }, "nbformat": 4, "nbformat_minor": 5 }