More about Nodes

This section describes some special cases for Node definitions.

On and Off Graph Nodes

The Node instances we have seen so far are all placed onto the graph. In other words, they are defined within the context of the Project and will have a run method that is executed when the Project runs.

Note

Each of these Node instances is represented by an individual stage in the DVC graph.

In some cases, a Node should provide additional methods but will only be used within other Node instances. Such a Node is called “off-graph” and can be represented by a Python dataclass. They are often used to define an interchangeable model, as illustrated in the example on Scikit-learn Classifier Comparison. Another use case for off-graph Node instances is reusing a Node from another project. If you load a Node via zntrack.from_rev, you can also use it as an off-graph Node.

In other words, off-graph Node instances do not produce any output files when the graph is executed. They are only used as dependencies for on-graph Node instances, which are responsible for creating output files.

Note

Just like on-graph Node definitions, dataclass-based Nodes must be importable. Therefore, it is recommended to place them alongside the on-graph Node definitions, e.g., in the same module. If you define them inside main.py, construct the Project inside an if __name__ == "__main__": block so that importing the Node does not execute the script.
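The effect of the guard can be illustrated with the standard library alone. The sketch below writes two throwaway modules, one that builds a project at import time and one that guards it, and then imports both. The module names and the build_project helper are hypothetical and not part of ZnTrack.

```python
import importlib
import sys
import tempfile
from pathlib import Path

# Does its work at import time, like an unguarded main.py.
UNGUARDED = """
calls = []

def build_project():
    calls.append("built")

build_project()  # executed on every import
"""

# Only does its work when executed as a script.
GUARDED = """
calls = []

def build_project():
    calls.append("built")

if __name__ == "__main__":
    build_project()  # skipped when the module is imported
"""


def import_from_source(name: str, source: str, directory: Path):
    """Write `source` to `<directory>/<name>.py` and import it as a module."""
    (directory / f"{name}.py").write_text(source)
    importlib.invalidate_caches()  # make the new file visible to the import system
    sys.path.insert(0, str(directory))
    try:
        return importlib.import_module(name)
    finally:
        sys.path.remove(str(directory))


with tempfile.TemporaryDirectory() as tmp:
    unguarded = import_from_source("demo_unguarded_main", UNGUARDED, Path(tmp))
    guarded = import_from_source("demo_guarded_main", GUARDED, Path(tmp))

print(unguarded.calls)  # ['built']
print(guarded.calls)    # []
```

When the Node is imported from main.py, only the guarded variant avoids constructing the Project as a side effect.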

from dataclasses import dataclass
import zntrack

@dataclass
class Shift:
    shift: float

    def compute(self, input: float) -> float:
        return input + self.shift

@dataclass
class Scale:
    scale: float

    def compute(self, input: float) -> float:
        return input * self.scale

class ManipulateNumber(zntrack.Node):
    number: float = zntrack.params()
    method: Shift | Scale = zntrack.deps()

    result: float = zntrack.outs()

    def run(self) -> None:
        self.result = self.method.compute(self.number)

if __name__ == "__main__":
    project = zntrack.Project()

    # You can define these Nodes anywhere, but
    # to avoid confusion, they should be placed outside the Project context
    shift = Shift(shift=1.0)
    scale = Scale(scale=2.0)

    with project:
        shifted_number = ManipulateNumber(number=1.0, method=shift)
        scaled_number = ManipulateNumber(number=1.0, method=scale)
    project.repro()
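Stripped of ZnTrack, this pattern is plain dependency injection: the on-graph Node only relies on the off-graph object having a compute method. The sketch below makes that contract explicit with typing.Protocol. The NumberMethod protocol and the manipulate function are hypothetical stand-ins for the type hint and for ManipulateNumber.run, so the snippet runs without a DVC repository.

```python
from dataclasses import dataclass
from typing import Protocol


class NumberMethod(Protocol):
    """Any object with a compute(float) -> float method can be plugged in."""

    def compute(self, input: float) -> float: ...


@dataclass
class Shift:
    shift: float

    def compute(self, input: float) -> float:
        return input + self.shift


@dataclass
class Scale:
    scale: float

    def compute(self, input: float) -> float:
        return input * self.scale


def manipulate(number: float, method: NumberMethod) -> float:
    """Mirror of ManipulateNumber.run: delegate the work to the injected method."""
    return method.compute(number)


print(manipulate(1.0, Shift(shift=1.0)))  # 2.0
print(manipulate(3.0, Scale(scale=2.0)))  # 6.0
```

Adding another model in this sketch only requires another dataclass with a compute method; manipulate itself does not change.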

Off-graph Node instances can use zntrack.params_path() and zntrack.deps_path() to declare parameter files and file dependencies; these are attached to the on-graph Node that uses the off-graph instance. This is useful, e.g., for a method that reads a parameter file or requires a specific file dependency but has no run method of its own and therefore is not a Node itself.

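from dataclasses import dataclass
import yaml
import zntrack

def func(model: str, **kwargs):
    # hypothetical placeholder: replace with the function that builds your calculator
    raise NotImplementedError

@dataclass
class Calculator:
    config_file: str = zntrack.params_path()
    model_path: str = zntrack.deps_path()

    def get_calculator(self):
        # Load the calculator's keyword arguments from the parameter file
        with open(self.config_file, "r") as f:
            config = yaml.safe_load(f)
        return func(model=self.model_path, **config)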

Warning

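Reading files inside dataclasses with the built-in open() instead of the DVCFileSystem leads to errors when using zntrack.from_rev() with a rev or remote argument, because open() reads from the local working copy rather than from the requested revision.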
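One way to respect this warning is to read through a filesystem object instead of the built-in open(), so the caller decides where the bytes come from. The fs field of NodeStatus is typed as fsspec.spec.AbstractFileSystem, which provides an open method. The sketch below assumes only that method; the FileSystem protocol, LocalFS, MemoryFS and the fs field on Calculator are hypothetical, and JSON replaces YAML so that it runs with the standard library only.

```python
import io
import json
from dataclasses import dataclass, field
from typing import IO, Protocol


class FileSystem(Protocol):
    """The only part of an fsspec-style interface this sketch relies on."""

    def open(self, path: str, mode: str = "r") -> IO[str]: ...


class LocalFS:
    """Reads from the current working copy, like the built-in open()."""

    def open(self, path: str, mode: str = "r") -> IO[str]:
        return open(path, mode)  # resolves to the built-in open


class MemoryFS:
    """Hypothetical in-memory filesystem standing in for another revision."""

    def __init__(self, files: dict[str, str]):
        self.files = files

    def open(self, path: str, mode: str = "r") -> IO[str]:
        if mode != "r":
            raise ValueError("MemoryFS only supports reading text")
        return io.StringIO(self.files[path])


@dataclass
class Calculator:
    config_file: str
    model_path: str
    fs: FileSystem = field(default_factory=LocalFS)  # hypothetical: where reads go

    def get_config(self) -> dict:
        # Read through self.fs instead of the built-in open().
        with self.fs.open(self.config_file, "r") as f:
            return json.load(f)


fs = MemoryFS({"config.json": '{"cutoff": 5.0}'})
calc = Calculator(config_file="config.json", model_path="model.pt", fs=fs)
print(calc.get_config())  # {'cutoff': 5.0}
```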

Always Changed

In some cases, you may want a Node to always run, even if the inputs have not changed. This can be useful when debugging a new Node. In such cases, you can set always_changed=True.

import zntrack.examples

project = zntrack.Project()

with project:
    node = zntrack.examples.ParamsToOuts(params=42, always_changed=True)

project.repro()
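The following is a simplified model of the decision, not DVC's actual implementation: a stage reruns when a hash of its inputs differs from the hash recorded on its last run, and always_changed=True skips that comparison. The input_hash and needs_run helpers and the in-memory lock dictionary are hypothetical.

```python
import hashlib
import json


def input_hash(params: dict) -> str:
    """Stable hash of a Node's parameters (keys sorted for determinism)."""
    payload = json.dumps(params, sort_keys=True).encode()
    return hashlib.sha256(payload).hexdigest()


def needs_run(name: str, params: dict, lock: dict, always_changed: bool = False) -> bool:
    """Return True if stage `name` must be executed, recording its inputs if so."""
    current = input_hash(params)
    if always_changed or lock.get(name) != current:
        lock[name] = current  # remember the inputs of this run
        return True
    return False


lock: dict[str, str] = {}
print(needs_run("ParamsToOuts", {"params": 42}, lock))  # True  (first run)
print(needs_run("ParamsToOuts", {"params": 42}, lock))  # False (inputs unchanged)
print(needs_run("ParamsToOuts", {"params": 42}, lock, always_changed=True))  # True
```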

Node State

Each Node provides a state attribute that gives access to metadata and to the DVCFileSystem. The zntrack.state.NodeStatus object is frozen and read-only.

class zntrack.state.NodeStatus(remote: str | None = None, rev: str | None = None, run_count: int = 0, state: zntrack.config.NodeStatusEnum = NodeStatusEnum.CREATED, lazy_evaluation: bool = True, tmp_path: pathlib.Path | None = None, node: Node | None = None, plugins: dict[str, zntrack.plugins.base.ZnTrackPlugin] = <factory>, group: zntrack.group.Group | None = None, run_time: datetime.timedelta | None = None, path: pathlib.Path = <factory>, lockfile: dict | None = None, fs: fsspec.spec.AbstractFileSystem | None = <factory>)

Node status object.

Parameters:
  • remote (str, optional) – The repository remote, e.g. the URL of the git repository.

  • rev (str, optional) – The revision of the repository, e.g. the git commit hash.

  • run_count (int) – How often this Node has been run. Only incremented when the Node is restarted.

  • state (NodeStatusEnum) – The state of the Node.

  • lazy_evaluation (bool) – Whether to load fields lazily.

  • tmp_path (pathlib.Path, optional) – The temporary path when using ‘use_tmp_path’.

  • node (Node, optional) – The Node object.

  • plugins (dict[str, ZnTrackPlugin], optional) – Active plugins. In addition to the default plugins, MLflow or Aim plugins are added here.

  • group (Group, optional) – The group of the Node.

  • run_time (datetime.timedelta, optional) – The total run time of the Node.

  • name (str) – The name of the Node.

  • nwd (pathlib.Path) – The node working directory.

  • restarted (bool) – Whether the Node was restarted and has been run at least once before.

  • path (pathlib.Path) – The path to the directory where the zntrack.json file is located.

use_tmp_path(path: Path | None = None) → Iterator[Path]

Load the data for *_path into a temporary directory.

If you cannot use node.state.fs.open, you can use this method as an alternative. It loads the data into a temporary directory and deletes the directory afterwards. Inside the context manager, the respective node.*_path attributes are replaced automatically.

The temporary path is only set if remote or rev is set. Otherwise, the data is loaded from the current directory.

Examples

>>> import zntrack
>>> from pathlib import Path
>>>
>>> class MyNode(zntrack.Node):
...     outs_path: Path = zntrack.outs_path(zntrack.nwd / "file.txt")
...
...     def run(self):
...         self.outs_path.parent.mkdir(parents=True, exist_ok=True)
...         self.outs_path.write_text("Hello World!")
...
...     @property
...     def data(self):
...         with self.state.use_tmp_path():
...             with open(self.outs_path) as f:
...                 return f.read()
...
>>> # build and run the graph and make multiple commits.
>>> # the `use_tmp_path` ensures that the correct version
>>> # of the file is loaded in the temporary directory and
>>> # the `self.outs_path` is updated accordingly.
>>>
>>> zntrack.from_rev("MyNode", rev="HEAD").data
>>> zntrack.from_rev("MyNode", rev="HEAD~1").data
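The mechanism described for use_tmp_path can be sketched with the standard library: copy the file into a temporary directory, point the attribute at the copy for the duration of a with block, then restore the attribute and delete the directory. The use_tmp_copy function and the Holder class are hypothetical; where the real method obtains the data for the requested revision, this sketch simply copies a local file.

```python
import contextlib
import shutil
import tempfile
from collections.abc import Iterator
from dataclasses import dataclass
from pathlib import Path


@dataclass
class Holder:
    """Hypothetical object with a *_path attribute, like a Node's outs_path."""

    outs_path: Path


@contextlib.contextmanager
def use_tmp_copy(obj: Holder) -> Iterator[Path]:
    """Temporarily redirect obj.outs_path to a copy inside a temporary directory."""
    original = obj.outs_path
    with tempfile.TemporaryDirectory() as tmp:
        copy = Path(tmp) / original.name
        shutil.copy(original, copy)  # stand-in for fetching another revision
        obj.outs_path = copy
        try:
            yield Path(tmp)
        finally:
            obj.outs_path = original  # restore even if the block raises
    # leaving TemporaryDirectory has deleted the copy at this point


with tempfile.TemporaryDirectory() as workdir:
    source = Path(workdir) / "file.txt"
    source.write_text("Hello World!")
    holder = Holder(outs_path=source)

    with use_tmp_copy(holder) as tmp_dir:
        print(holder.outs_path.parent == tmp_dir)  # True
        print(holder.outs_path.read_text())        # Hello World!

    print(holder.outs_path == source)  # True
    print(tmp_dir.exists())            # False
```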

Custom Run Methods

By default, a Node will execute the run method. Sometimes, it is useful to define multiple methods for a single Node with slightly different behavior. This can be achieved by using zntrack.apply().

zntrack.apply(obj: o, method: str) → o

Update the default run method of zntrack.Node.

Parameters:
  • obj (zntrack.Node) – The Node class to copy; the returned copy runs a different method.

  • method (str) – The name of the method to run instead of the default run method.

Returns:

A new class which uses the new method instead of the default run method.

Return type:

zntrack.Node

Examples

>>> import zntrack
>>> class MyNode(zntrack.Node):
...     outs: str = zntrack.outs()
...
...     def run(self):
...         self.outs = "Hello, World!"
...
...     def my_run(self):
...         self.outs = "Hello, Custom World!"
...
>>> OtherMyNode = zntrack.apply(MyNode, "my_run")
>>> with zntrack.Project() as proj:
...     a = MyNode()
...     b = OtherMyNode()
>>> proj.repro()
>>> a.outs
'Hello, World!'
>>> b.outs
'Hello, Custom World!'
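One way such a function could work is to build a subclass on the fly whose run attribute points at the chosen method. The sketch below shows that idea with plain classes; apply_method is a hypothetical helper, not ZnTrack's implementation, and it ignores everything a real Node needs (fields, DVC stage names).

```python
def apply_method(cls: type, method: str) -> type:
    """Return a subclass of `cls` whose run() calls `method` instead."""
    if not callable(getattr(cls, method, None)):
        raise AttributeError(f"{cls.__name__} has no method {method!r}")
    # type(name, bases, namespace) creates the subclass; run now aliases `method`.
    return type(f"{cls.__name__}_{method}", (cls,), {"run": getattr(cls, method)})


class MyNode:
    def __init__(self):
        self.outs = None

    def run(self):
        self.outs = "Hello, World!"

    def my_run(self):
        self.outs = "Hello, Custom World!"


OtherMyNode = apply_method(MyNode, "my_run")

a, b = MyNode(), OtherMyNode()
a.run()
b.run()
print(a.outs)  # Hello, World!
print(b.outs)  # Hello, Custom World!
print(issubclass(OtherMyNode, MyNode))  # True
```

Because a new class is returned, MyNode keeps its original run method, so both classes can be used side by side.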

Entry Points

If you are developing a package based on ZnTrack, you can expose your Node definitions to other packages.

You can define one or more groups of Nodes and register them using a function like this:

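import zntrack
import mypackage  # your package that defines the Node classes

def nodes() -> list[type[zntrack.Node]]:
    """Return the Node classes exposed by this entry point."""
    return [
        mypackage.MyNode1,
        mypackage.MyNode2,
    ]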

This function should be registered as an entry point in your pyproject.toml:

[project.entry-points."zntrack.nodes"]
mypackage = "mypackage.nodes:nodes"

Each entry represents a group of Nodes. If your Nodes are organized into different categories, you can define multiple entry points accordingly.