Custom ZnTrackOptions

ZnTrack lets you define a custom ZnTrackOption that behaves like zn.outs. The zn.outs option handles a number of standard types automatically, but a custom option is useful when your data needs its own file format. The following example uses the Atomic Simulation Environment (ASE) to store and load Atoms objects in a custom database file.
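
To illustrate why a custom option can be necessary, here is a minimal sketch using Python's built-in json module. This is an assumption chosen for illustration only and says nothing about how zn.outs works internally. The Molecule class and the try_dump helper are hypothetical names. Plain containers serialize directly, while an arbitrary object does not unless you tell the serializer how to handle it.

```python
import json


class Molecule:
    """Hypothetical domain object, standing in for something like ase.Atoms."""

    def __init__(self, symbols, positions):
        self.symbols = symbols
        self.positions = positions


def try_dump(obj):
    """Return the JSON string, or None if the object cannot be serialized."""
    try:
        return json.dumps(obj)
    except TypeError:
        # json.dumps raises TypeError for types it does not know.
        return None


# A dict of strings and nested lists is supported out of the box.
plain = try_dump({"symbols": "N2", "positions": [[0, 0, -1], [0, 0, 1]]})

# A custom class is not supported without extra serialization logic.
custom = try_dump(Molecule("N2", [[0, 0, -1], [0, 0, 1]]))
```

A custom option closes this gap by defining explicitly how the value is written to and read from a file.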

[1]:
from zntrack import config

# When using ZnTrack we can write our code inside a Jupyter notebook.
# We can make use of this functionality by setting the `nb_name` config as follows:
config.nb_name = "08_custom_zntrackoptions.ipynb"

We build the new custom option by subclassing zntrack.Field.

[4]:
import zntrack
import ase.db
import ase.io
import tqdm
[5]:
class Atoms(zntrack.Field):
    # the file is registered with DVC as an output (dvc run --outs)
    dvc_option = "outs"
    group = zntrack.FieldGroup.RESULT  # you can choose from RESULT or PARAMETER

    def get_files(self, instance) -> list:
        """Define the filename that is passed to dvc (used if tracked=True)"""
        # self.name is the name of the class attribute we use for this database
        return [instance.nwd / f"{self.name}.db"]

    def save(self, instance):
        """Save the values to file"""
        # we gather the actual values using __get__
        atoms = getattr(instance, self.name)
        # get the file name
        file = self.get_files(instance)[0]
        # save the data to the file
        with ase.db.connect(file) as db:
            for atom in tqdm.tqdm(atoms, ncols=70, desc=f"Writing atoms to {file}"):
                db.write(atom, group=instance.name)

    def get_data(self, instance):
        """Load data with ase.db.connect from file"""
        # get the file name
        file = self.get_files(instance)[0]
        # load the data
        atoms = []
        with ase.db.connect(file) as db:
            for row in tqdm.tqdm(
                db.select(), ncols=70, desc=f"Loading atoms from {file}"
            ):
                atoms.append(row.toatoms())
        # return the data so it can be saved in __dict__
        return atoms
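
The three methods of the Atoms class above follow a simple pattern: get_files names the file, save writes the current value to it, and get_data reads it back. The following minimal sketch reproduces that pattern without ZnTrack or ASE so that it can be run in isolation. JsonListField, FakeNode and round_trip are hypothetical names, not part of ZnTrack, and JSON stands in for the ASE database.

```python
import json
import pathlib
import tempfile


class JsonListField:
    """Hypothetical stand-in for a custom field; not a ZnTrack class."""

    def __set_name__(self, owner, name):
        # Python calls this when the owning class is created, so the field
        # learns the attribute name it was assigned to.
        self.name = name

    def get_files(self, instance) -> list:
        # One file per field, placed in the node's working directory.
        return [instance.nwd / f"{self.name}.json"]

    def save(self, instance):
        # Read the current value from the instance and write it to disk.
        values = getattr(instance, self.name)
        file = self.get_files(instance)[0]
        file.parent.mkdir(parents=True, exist_ok=True)
        file.write_text(json.dumps(values))

    def get_data(self, instance):
        # Read the stored value back from disk.
        file = self.get_files(instance)[0]
        return json.loads(file.read_text())


class FakeNode:
    """Hypothetical minimal node that only provides `nwd` and `name`."""

    positions = JsonListField()

    def __init__(self, nwd: pathlib.Path):
        self.nwd = nwd
        self.name = "FakeNode"


def round_trip(values):
    """Save `values` through the field and load them again."""
    with tempfile.TemporaryDirectory() as tmp:
        node = FakeNode(pathlib.Path(tmp) / "nodes" / "FakeNode")
        node.positions = values  # stored in the instance __dict__
        field = FakeNode.__dict__["positions"]
        field.save(node)
        return field.get_data(node)
```

In this sketch the field is looked up on the class and called by hand. In the ZnTrack example, the framework is responsible for calling save and get_data at the appropriate time.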

Now that we have defined our custom ZnTrackOption, we can use it in a Node as follows.

[6]:
class AtomsClass(zntrack.Node):
    atoms = Atoms()

    def run(self):
        self.atoms = [ase.Atoms("N2", positions=[[0, 0, -1], [0, 0, 1]])]
[7]:
with zntrack.Project() as project:
    node = AtomsClass()
project.run(repro=False)
Running DVC command: 'stage add --name AtomsClass --force ...'
Creating 'dvc.yaml'
Adding stage 'AtomsClass' in 'dvc.yaml'

To track the changes with git, run:

        git add dvc.yaml nodes/AtomsClass/.gitignore

To enable auto staging, run:

        dvc config core.autostage true
Jupyter support is an experimental feature! Please save your notebook before running this command!
Submit issues to https://github.com/zincware/ZnTrack.
[NbConvertApp] Converting notebook 08_custom_zntrackoptions.ipynb to script
[NbConvertApp] Writing 2881 bytes to 08_custom_zntrackoptions.py
[8]:
!dvc repro
Running stage 'AtomsClass':
> zntrack run src.AtomsClass.AtomsClass --name AtomsClass
Loading atoms from nodes/AtomsClass/atoms.db: 0it [00:00, ?it/s]
Writing atoms to nodes/AtomsClass/atoms.db: 100%|█| 1/1 [00:00<00:00,
Generating lock file 'dvc.lock'
Updating lock file 'dvc.lock'

To track the changes with git, run:

        git add dvc.lock

To enable auto staging, run:

        dvc config core.autostage true
Use `dvc push` to send your updates to remote storage.
[9]:
node.load()
print(node.atoms)
# or
AtomsClass.from_rev().atoms
Loading atoms from nodes/AtomsClass/atoms.db: 1it [00:00, 1151.02it/s]
[Atoms(symbols='N2', pbc=False)]
Loading atoms from nodes/AtomsClass/atoms.db: 1it [00:00, 3231.36it/s]
[9]:
[Atoms(symbols='N2', pbc=False)]
[10]:
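# temp_dir is not defined in the cells shown here; it is assumed to be a
# temporary working directory created in a setup cell that is not displayed.
temp_dir.cleanup()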