{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "ac2ff6f8-c61e-4dc7-9e9f-d0ba103540ee", "metadata": {}, "outputs": [], "source": [ "from zntrack import Node, dvc, config, zn" ] }, { "cell_type": "code", "execution_count": 2, "id": "c7c9d484-81b6-4269-b506-c25fd5efe471", "metadata": {}, "outputs": [], "source": [ "config.nb_name = \"06_named_nodes.ipynb\"" ] }, { "cell_type": "code", "execution_count": 3, "id": "d883b41b-61f2-448f-b935-52dec6baaa87", "metadata": { "nbsphinx": "hidden", "tags": [] }, "outputs": [], "source": [ "from zntrack.utils import cwd_temp_dir\n", "\n", "temp_dir = cwd_temp_dir()" ] }, { "cell_type": "code", "execution_count": 4, "id": "02b1bed1-bf12-47bb-8d60-2b40575cdb63", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Initialized empty Git repository in /tmp/tmpcn8yts4z/.git/\r\n", "Initialized DVC repository.\r\n", "\r\n", "You can now commit the changes to git.\r\n", "\r\n", "\u001B[31m+---------------------------------------------------------------------+\r\n", "\u001B[0m\u001B[31m|\u001B[0m \u001B[31m|\u001B[0m\r\n", "\u001B[31m|\u001B[0m DVC has enabled anonymous aggregate usage analytics. 
\u001B[31m|\u001B[0m\r\n", "\u001B[31m|\u001B[0m Read the analytics documentation (and how to opt-out) here: \u001B[31m|\u001B[0m\r\n", "\u001B[31m|\u001B[0m <\u001B[36mhttps://dvc.org/doc/user-guide/analytics\u001B[39m> \u001B[31m|\u001B[0m\r\n", "\u001B[31m|\u001B[0m \u001B[31m|\u001B[0m\r\n", "\u001B[31m+---------------------------------------------------------------------+\r\n", "\u001B[0m\r\n", "\u001B[33mWhat's next?\u001B[39m\r\n", "\u001B[33m------------\u001B[39m\r\n", "- Check out the documentation: <\u001B[36mhttps://dvc.org/doc\u001B[39m>\r\n", "- Get help and share ideas: <\u001B[36mhttps://dvc.org/chat\u001B[39m>\r\n", "- Star us on GitHub: <\u001B[36mhttps://github.com/iterative/dvc\u001B[39m>\r\n", "\u001B[0m" ] } ], "source": [ "!git init\n", "!dvc init" ] }, { "cell_type": "markdown", "id": "d20eda47-4757-4de5-b920-1db19ee0ef37", "metadata": {}, "source": [ "# Named Nodes\n", "Named Nodes allow us to use the same Node multiple times in a single graph, e.g. at different steps. To do so, we pass a `name` argument to the `__init__` of our Node.\n", "\n", "
Notice that this is one of the few scenarios in which we pass an argument directly to `__init__`." ] }, { "cell_type": "code", "execution_count": 5, "id": "6a704798-b67a-4863-b96c-60cdeb65d559", "metadata": {}, "outputs": [], "source": [ "class HelloWorld(Node):\n", "    inputs = zn.params()\n", "    outputs = zn.outs()\n", "\n", "    def __init__(self, inputs=None, **kwargs):\n", "        super().__init__(**kwargs)\n", "        self.inputs = inputs\n", "\n", "    def run(self):\n", "        self.outputs = self.inputs" ] }, { "cell_type": "code", "execution_count": 6, "id": "2a04b805-bbda-49ca-97ef-efb761f67165", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2022-03-16 13:28:59,293 (WARNING): Jupyter support is an experimental feature! Please save your notebook before running this command!\n", "Submit issues to https://github.com/zincware/ZnTrack.\n", "2022-03-16 13:29:03,070 (WARNING): Running DVC command: 'dvc run -n HelloWorld ...'\n", "2022-03-16 13:29:10,857 (WARNING): Running DVC command: 'dvc run -n Test01 ...'\n", "2022-03-16 13:29:18,410 (WARNING): Running DVC command: 'dvc run -n Test02 ...'\n" ] } ], "source": [ "HelloWorld(inputs=3).write_graph(no_exec=False)\n", "HelloWorld(name=\"Test01\", inputs=17).write_graph(no_exec=False)\n", "HelloWorld(name=\"Test02\", inputs=42).write_graph(no_exec=False)" ] }, { "cell_type": "code", "execution_count": 7, "id": "2b99fa5a-bc72-4834-8f75-98d79be58e4c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+------------+ \r\n", "| HelloWorld | \r\n", "+------------+ \r\n", "+--------+ \r\n", "| Test01 | \r\n", "+--------+ \r\n", "+--------+ \r\n", "| Test02 | \r\n", "+--------+ \r\n", "\u001B[0m" ] } ], "source": [ "!dvc dag" ] }, { "cell_type": "markdown", "id": "a8cbb47f-8a5b-4f56-9408-4b365e487ba8", "metadata": {}, "source": [ "We can now also build a Node that depends on several instances of the same Node." ] }, { "cell_type": "code", "execution_count": 8, "id": 
"dc84f265-bffb-4e32-abf3-de7483a5d51d", "metadata": {}, "outputs": [], "source": [ "class FindMaximum(Node):\n", "    deps = zn.deps(\n", "        [\n", "            HelloWorld.load(),\n", "            HelloWorld.load(name=\"Test01\"),\n", "            HelloWorld.load(name=\"Test02\"),\n", "        ]\n", "    )\n", "    maximum = zn.outs()\n", "\n", "    def run(self):\n", "        self.maximum = 0\n", "        for node in self.deps:\n", "            if node.outputs > self.maximum:\n", "                self.maximum = node.outputs\n", "                print(f\"New maximum found {node.outputs}.\")" ] }, { "cell_type": "code", "execution_count": 9, "id": "7612adfd-b00b-4026-a586-def7bf37f91f", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2022-03-16 13:29:28,486 (WARNING): Running DVC command: 'dvc run -n FindMaximum ...'\n" ] } ], "source": [ "FindMaximum().write_graph(run=True)" ] }, { "cell_type": "code", "execution_count": 10, "id": "73562b4c-8240-4e2d-807f-b328aa315d6f", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+------------+ +--------+ +--------+ \r\n", "| HelloWorld | | Test01 | | Test02 | \r\n", "+------------+** +--------+ ***+--------+ \r\n", " *** * *** \r\n", " **** * **** \r\n", " ** * ** \r\n", " +-------------+ \r\n", " | FindMaximum | \r\n", " +-------------+ \r\n", "\u001B[0m" ] } ], "source": [ "!dvc dag" ] }, { "cell_type": "markdown", "id": "97067bfd-a085-4125-8adf-29e544c7f18a", "metadata": {}, "source": [ "Using this combined Node, we can, for example, find the maximum of the generated values." 
] }, { "cell_type": "code", "execution_count": 11, "id": "a1455d7b-d53f-4d36-a539-cb128e316e95", "metadata": {}, "outputs": [ { "data": { "text/plain": "42" }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "FindMaximum.load().maximum" ] }, { "cell_type": "code", "execution_count": 12, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "New maximum found 3.\n", "New maximum found 17.\n", "New maximum found 42.\n" ] } ], "source": [ "# Running it manually to highlight the print statements\n", "FindMaximum.load().run()" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "markdown", "source": [ "In addition to the classmethod `Node.load(name=\"nodename\")` introduced above, it is also possible to use `Node[\"nodename\"]`. Note that this works only on the class itself, i.e. `Node[\"nodename\"]`, and not on an instance, i.e. `Node()[\"nodename\"]`. Using this, we can also write the following:" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } } }, { "cell_type": "code", "execution_count": 13, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "17\n", "Test01\n" ] } ], "source": [ "print(HelloWorld[\"Test01\"].outputs)\n", "print(HelloWorld[\"Test01\"].node_name)" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "markdown", "source": [ "This is equivalent to the classmethod `load()`. It is also possible to pass a dictionary, whose entries are passed on as keyword arguments to `load(**kwargs)`." 
], "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } } }, { "cell_type": "code", "execution_count": 14, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "42\n", "Test02\n", "42\n" ] } ], "source": [ "print(HelloWorld.load(\"Test02\").outputs)\n", "print(HelloWorld.load(\"Test02\").node_name)\n", "print(HelloWorld[{\"name\": \"Test02\"}].outputs)" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "code", "execution_count": 15, "id": "7b59179c-0b15-45a6-97b1-b2d6cc1ebce3", "metadata": { "nbsphinx": "hidden", "tags": [] }, "outputs": [], "source": [ "temp_dir.cleanup()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.7" } }, "nbformat": 4, "nbformat_minor": 5 }