{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# PyTorch-Ignite playground\n",
    "\n",
    "Below we create a dummy trainer and evaluator for quick experimentation. `print` calls are attached to different events to show how things work. There is no meaningful training logic in the code, so everything runs fast. This setup can easily be modified to test other ideas.\n",
    "\n",
    "\n",
    "Optionally, to install the latest version:\n",
    "```\n",
    "%pip install --upgrade pytorch-ignite\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `%pip` (rather than `!pip`) installs into the environment of the running kernel\n",
    "%pip install --upgrade pytorch-ignite\n",
    "\n",
    "import torch\n",
    "import ignite\n",
    "print(torch.__version__, ignite.__version__)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Basic setup with various event-filtering use cases"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import logging\n",
    "\n",
    "from ignite.engine import Engine, Events\n",
    "from ignite.utils import setup_logger\n",
    "\n",
    "\n",
    "train_data = range(10)\n",
    "eval_data = range(4)\n",
    "max_epochs = 5\n",
    "\n",
    "\n",
    "def train_step(engine, batch):\n",
    "    \"\"\"Dummy process function: only prints the engine counters and the current batch.\"\"\"\n",
    "    print(f\"{engine.state.epoch} / {engine.state.max_epochs} | {engine.state.iteration} - batch: {batch}\", flush=True)\n",
    "\n",
    "\n",
    "trainer = Engine(train_step)\n",
    "\n",
    "# Enable trainer logger for a debug mode\n",
    "# trainer.logger = setup_logger(\"trainer\", level=logging.DEBUG)\n",
    "\n",
    "# The evaluator does nothing per batch; it exists only so that its events fire\n",
    "evaluator = Engine(lambda e, b: None)\n",
    "\n",
    "\n",
    "@trainer.on(Events.EPOCH_COMPLETED(every=2))\n",
    "def run_validation():\n",
    "    \"\"\"Run the evaluator on `eval_data`; registered on every 2nd completed training epoch.\"\"\"\n",
    "    print(f\"{trainer.state.epoch} / {trainer.state.max_epochs} | {trainer.state.iteration} - run validation\", flush=True)\n",
    "    evaluator.run(eval_data)\n",
    "\n",
    "\n",
    "@trainer.on(Events.ITERATION_COMPLETED(every=7))\n",
    "def log_events_filtering__every():\n",
    "    \"\"\"Show the `every` filter: registered on every 7th completed training iteration.\"\"\"\n",
    "    print(f\"{trainer.state.epoch} / {trainer.state.max_epochs} | {trainer.state.iteration} - calling log_events_filtering__every\", flush=True)\n",
    "\n",
    "\n",
    "@trainer.on(Events.EPOCH_COMPLETED(once=3))\n",
    "def log_events_filtering__once():\n",
    "    \"\"\"Show the `once` filter: registered with `once=3` on completed training epochs.\"\"\"\n",
    "    print(f\"{trainer.state.epoch} / {trainer.state.max_epochs} | {trainer.state.iteration} - calling log_events_filtering__once\", flush=True)\n",
    "\n",
    "\n",
    "def custom_event_filter(engine, event):\n",
    "    \"\"\"Return True only when the trainer is in epoch 2 and `event` is 1 or 3.\n",
    "\n",
    "    The filter is attached to the evaluator below, so `engine` is the evaluator and\n",
    "    `event` is the counter value of the filtered event (the evaluator iteration).\n",
    "    The epoch is deliberately read from the global `trainer`, not from `engine`.\n",
    "    \"\"\"\n",
    "    return trainer.state.epoch == 2 and event in (1, 3)\n",
    "\n",
    "\n",
    "@evaluator.on(Events.ITERATION_COMPLETED(event_filter=custom_event_filter))\n",
    "def log_events_filtering__event_filter():\n",
    "    \"\"\"Show the `event_filter` option: called only when `custom_event_filter` returns True.\"\"\"\n",
    "    print(f\"{trainer.state.epoch} / {trainer.state.max_epochs} | {evaluator.state.iteration} - calling log_events_filtering__event_filter\", flush=True)\n",
    "\n",
    "\n",
    "trainer.run(train_data, max_epochs=max_epochs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Distributed run with 2 processes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import logging\n",
    "import time\n",
    "\n",
    "import torch\n",
    "\n",
    "import ignite.distributed as idist\n",
    "from ignite.engine import Engine, Events\n",
    "from ignite.utils import setup_logger\n",
    "\n",
    "\n",
    "def pprint(*args, **kwargs):\n",
    "    \"\"\"`print` prefixed with the rank of the calling process.\n",
    "\n",
    "    Sleeps `rank * 0.1` seconds first, presumably so that output from\n",
    "    different ranks does not interleave.\n",
    "    \"\"\"\n",
    "    rank = idist.get_rank()\n",
    "    time.sleep(rank * 0.1)\n",
    "    print(f\"Rank {rank}:\", end=\" \")\n",
    "    print(*args, **kwargs)\n",
    "\n",
    "\n",
    "def run(local_rank):\n",
    "    \"\"\"Entry point executed in each spawned process.\n",
    "\n",
    "    Args:\n",
    "        local_rank: process index passed by `idist.spawn` as the first argument.\n",
    "    \"\"\"\n",
    "    rank = idist.get_rank()\n",
    "    # Give every process its own seed\n",
    "    torch.manual_seed(12 + rank)\n",
    "\n",
    "    train_data = range(10)\n",
    "    eval_data = range(4)\n",
    "    max_epochs = 5\n",
    "\n",
    "    def train_step(engine, batch):\n",
    "        \"\"\"Dummy process function: only prints the engine counters and the current batch.\"\"\"\n",
    "        pprint(f\"{engine.state.epoch} / {engine.state.max_epochs} | {engine.state.iteration} - batch: {batch}\", flush=True)\n",
    "\n",
    "    trainer = Engine(train_step)\n",
    "\n",
    "    @trainer.on(Events.EPOCH_COMPLETED)\n",
    "    def sync():\n",
    "        \"\"\"Call `idist.barrier()` at the end of each epoch to synchronize the processes.\"\"\"\n",
    "        idist.barrier()\n",
    "\n",
    "    trainer.run(train_data, max_epochs=max_epochs)\n",
    "\n",
    "\n",
    "idist.spawn(\"gloo\", run, (), nproc_per_node=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}