In [1]:
from zntrack import Node, dvc, config, zn

In [2]:
config.nb_name = "06_named_nodes.ipynb"

In [3]:
from zntrack.utils import cwd_temp_dir

temp_dir = cwd_temp_dir()

In [4]:
!git init
!dvc init

Initialized empty Git repository in /tmp/tmpcn8yts4z/.git/
Initialized DVC repository.

You can now commit the changes to git.


# Named Nodes
Named Nodes allow us to use the same Node class multiple times in a single graph, e.g. at different steps. To do so, we pass a `name` argument to the `__init__` of our Node.

<blockquote>Note that this is one of the very few scenarios in which we pass an argument directly to `__init__`.</blockquote>
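As a rough mental model, here is a plain-Python sketch (not ZnTrack's actual implementation) of a Node name that falls back to the class name when no `name` is given. `MiniNode` and `MiniHelloWorld` are hypothetical stand-ins. Like `HelloWorld.__init__` below, the subclass forwards `**kwargs` to the base class, so it never has to declare `name` itself.

```python
class MiniNode:
    """Hypothetical stand-in for a ZnTrack Node: it only models naming."""

    def __init__(self, name=None):
        # Without an explicit name, fall back to the class name,
        # mirroring the unnamed "HelloWorld" stage in the DAG.
        self.node_name = name if name is not None else type(self).__name__


class MiniHelloWorld(MiniNode):
    def __init__(self, inputs=None, **kwargs):
        # `name` (if given) travels through **kwargs to the base class.
        super().__init__(**kwargs)
        self.inputs = inputs


nodes = [
    MiniHelloWorld(inputs=3),
    MiniHelloWorld(name="Test01", inputs=17),
    MiniHelloWorld(name="Test02", inputs=42),
]
# A graph keyed by node name: three entries built from one class.
graph = {node.node_name: node.inputs for node in nodes}
print(graph)  # {'MiniHelloWorld': 3, 'Test01': 17, 'Test02': 42}
```

In this sketch, distinct names are what keep the three entries apart. Two instances sharing a name would collapse into a single key.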

In [5]:
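class HelloWorld(Node):
    inputs = zn.params()  # parameter tracked for this Node
    outputs = zn.outs()  # output stored when the Node runs

    def __init__(self, inputs=None, **kwargs):
        # forward `name` (and any other Node arguments) to the base class
        super().__init__(**kwargs)
        self.inputs = inputs

    def run(self):
        self.outputs = self.inputs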

In [6]:
HelloWorld(inputs=3).write_graph(no_exec=False)
HelloWorld(name="Test01", inputs=17).write_graph(no_exec=False)
HelloWorld(name="Test02", inputs=42).write_graph(no_exec=False)

Submit issues to https://github.com/zincware/ZnTrack.


In [7]:
!dvc dag

+------------+ 
| HelloWorld | 
+------------+ 
+--------+ 
| Test01 | 
+--------+ 
+--------+ 
| Test02 | 
+--------+ 

We can now also build a Node that depends on several instances of the same Node class.

In [8]:
class FindMaximum(Node):
    # depend on all three HelloWorld instances, each loaded by its name
    deps = zn.deps(
        [
            HelloWorld.load(),
            HelloWorld.load(name="Test01"),
            HelloWorld.load(name="Test02"),
        ]
    )
    maximum = zn.outs()

    def run(self):
        self.maximum = 0
        for node in self.deps:
            if node.outputs > self.maximum:
                self.maximum = node.outputs
                print(f"New maximum found {node.outputs}.")

In [9]:
FindMaximum().write_graph(run=True)



In [10]:
!dvc dag

+------------+          +--------+          +--------+ 
| HelloWorld |          | Test01 |          | Test02 | 
+------------+**        +--------+       ***+--------+ 
                ***          *        ***              
                   ****     *     ****                 
                       **   *   **                     
                    +-------------+                    
                    | FindMaximum |                    
                    +-------------+                    
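The DAG above shows that `FindMaximum` has to wait for all three `HelloWorld` stages. The sketch below is an illustration only. DVC resolves stage order itself, and this is not its code. It writes the same dependency structure as a mapping from each node name to its predecessors, then orders it with the standard-library `graphlib.TopologicalSorter` (Python 3.9+).

```python
from graphlib import TopologicalSorter

# Each key lists the names of the nodes it depends on,
# mirroring the edges drawn by `dvc dag` above.
dependencies = {
    "HelloWorld": set(),
    "Test01": set(),
    "Test02": set(),
    "FindMaximum": {"HelloWorld", "Test01", "Test02"},
}

# static_order() yields every node only after all of its predecessors.
order = list(TopologicalSorter(dependencies).static_order())
print(order)
```

The relative order of the three independent `HelloWorld` stages is unspecified. `FindMaximum` always comes last, which is the guarantee a pipeline runner needs.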

With this combined Node we can, for example, find the maximum of the generated values.

In [11]:
FindMaximum.load().maximum

42

In [12]:
# Running it manually to highlight the print statements
FindMaximum.load().run()

New maximum found 3.
New maximum found 17.
New maximum found 42.
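The update rule in `FindMaximum.run` is a running maximum. It keeps the largest value seen so far and reports each improvement, which is why three lines are printed when the inputs arrive in increasing order. The standalone sketch below uses a hypothetical function name, `running_maximum`. It also makes one assumption of the original loop visible: starting from `0` only gives the true maximum if at least one output is non-negative.

```python
def running_maximum(values, start=0):
    """Return the largest value, printing each new maximum like FindMaximum.run."""
    maximum = start
    for value in values:
        if value > maximum:
            maximum = value
            print(f"New maximum found {value}.")
    return maximum


print(running_maximum([3, 17, 42]))  # three "New maximum found" lines, then 42
# With only negative inputs, a start of 0 is never beaten:
print(running_maximum([-5, -2]))  # 0
# Starting from negative infinity avoids that edge case:
print(running_maximum([-5, -2], start=float("-inf")))  # -2
```

For the non-negative outputs in this notebook, both starting values agree with the builtin `max()`.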


In addition to the classmethod `Node.load(name="nodename")` introduced above, it is also possible to use `Node["nodename"]`. Note that this only works on the class, `Node["nodename"]`, and not on an instance, `Node()["nodename"]`. Using this syntax, we can also write the following:

In [13]:
print(HelloWorld["Test01"].outputs)
print(HelloWorld["Test01"].node_name)

17
Test01


This is equivalent to calling the classmethod `load()`. It is also possible to pass a dictionary, whose entries are forwarded as keyword arguments to `load(**kwargs)`.

In [14]:
print(HelloWorld.load("Test02").outputs)
print(HelloWorld.load("Test02").node_name)
print(HelloWorld[{"name": "Test02"}].outputs)

42
Test02
42
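Why does `HelloWorld["Test01"]` work while `HelloWorld()["Test01"]` does not? One plain-Python way to get this behavior is to define `__getitem__` on a metaclass. Methods of a metaclass apply to the class object itself, but they are not looked up on its instances. The sketch below assumes this mechanism for illustration only. `LoadByKey`, `Demo` and the in-memory `_store` are hypothetical and do not reproduce ZnTrack's real `load()`. A string key is treated as a name, and a dictionary is unpacked into keyword arguments, matching the two forms used above.

```python
class LoadByKey(type):
    """Hypothetical metaclass enabling `Cls[key]` as a shortcut for `Cls.load(...)`."""

    def __getitem__(cls, item):
        # A dict is unpacked as keyword arguments: Cls[{"name": "x"}] -> Cls.load(name="x")
        if isinstance(item, dict):
            return cls.load(**item)
        # Anything else is taken to be the node name.
        return cls.load(name=item)


class Demo(metaclass=LoadByKey):
    # Hypothetical stand-in for stored node outputs, keyed by node name.
    _store = {"Demo": 3, "Test01": 17, "Test02": 42}

    def __init__(self, name=None):
        self.node_name = name or type(self).__name__
        self.outputs = self._store.get(self.node_name)

    @classmethod
    def load(cls, name=None):
        return cls(name=name)


print(Demo["Test01"].outputs)  # 17
print(Demo[{"name": "Test02"}].outputs)  # 42

try:
    Demo()["Test01"]  # instances do not see the metaclass __getitem__
except TypeError as err:
    print(f"TypeError: {err}")
```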


In [15]:
temp_dir.cleanup()