{ "cells": [ { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" }, "tags": [] }, "source": [ "# Custom ZnTrackOptions\n", "\n", "ZnTrack allows you to create a custom ZnTrackOption similar to `zn.outs`.\n", "ZnTrack tries to handle some standard types automatically within the `zn.outs` option, but it can be useful to write custom ones.\n", "In the following example we use the [Atomic Simulation Environment](https://wiki.fysik.dtu.dk/ase/index.html) (ASE) to store atoms objects in, and load them back from, a custom database file." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from zntrack import config\n", "\n", "# When using ZnTrack we can write our code inside a Jupyter notebook.\n", "# We can make use of this functionality by setting the `nb_name` config as follows:\n", "config.nb_name = \"08_custom_zntrackoptions.ipynb\"" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "nbsphinx": "hidden", "tags": [] }, "outputs": [], "source": [ "from zntrack.utils import cwd_temp_dir\n", "\n", "temp_dir = cwd_temp_dir()" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "nbsphinx": "hidden", "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Initialized empty Git repository in /tmp/tmp5twr4kp_/.git/\n", "Initialized DVC repository.\n", "\n", "You can now commit the changes to git.\n", "\n", "+---------------------------------------------------------------------+\n", "| |\n", "| DVC has enabled anonymous aggregate usage analytics. 
|\n", "| Read the analytics documentation (and how to opt-out) here: |\n", "| |\n", "| |\n", "+---------------------------------------------------------------------+\n", "\n", "What's next?\n", "------------\n", "- Check out the documentation: \n", "- Get help and share ideas: \n", "- Star us on GitHub: \n" ] } ], "source": [ "!git init\n", "!dvc init" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We will use the `ZnTrackOption` to build our new custom options." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "import zntrack\n", "import ase.db\n", "import ase.io\n", "import tqdm" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [
"class Atoms(zntrack.Field):\n",
"    \"\"\"Field that writes a sequence of atoms to an ASE database and reads it back.\"\"\"\n",
"\n",
"    # the file is handed to DVC like `dvc run --outs`\n",
"    dvc_option = \"outs\"\n",
"    # you can choose between RESULT and PARAMETER\n",
"    group = zntrack.FieldGroup.RESULT\n",
"\n",
"    def get_files(self, instance) -> list:\n",
"        \"\"\"Return the file names that are passed to dvc (used if tracked=True).\"\"\"\n",
"        # self.name is the name of the class attribute this field is assigned to\n",
"        db_path = instance.nwd / f\"{self.name}.db\"\n",
"        return [db_path]\n",
"\n",
"    def save(self, instance):\n",
"        \"\"\"Write the current value of the field to the database file.\"\"\"\n",
"        # read the actual values through __get__\n",
"        structures = getattr(instance, self.name)\n",
"        db_path = self.get_files(instance)[0]\n",
"        with ase.db.connect(db_path) as db:\n",
"            progress = tqdm.tqdm(\n",
"                structures, ncols=70, desc=f\"Writing atoms to {db_path}\"\n",
"            )\n",
"            for structure in progress:\n",
"                db.write(structure, group=instance.name)\n",
"\n",
"    def get_data(self, instance):\n",
"        \"\"\"Load all rows of the database file with ase.db.connect.\"\"\"\n",
"        db_path = self.get_files(instance)[0]\n",
"        with ase.db.connect(db_path) as db:\n",
"            rows = tqdm.tqdm(\n",
"                db.select(), ncols=70, desc=f\"Loading atoms from {db_path}\"\n",
"            )\n",
"            loaded = [row.toatoms() for row in rows]\n",
"        # hand the data back so it can be saved in __dict__\n",
"        return loaded"
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now that we have defined our custom ZnTrackOption we can use it as follows." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [
"class AtomsClass(zntrack.Node):\n",
"    \"\"\"Node that stores its result through the custom Atoms field.\"\"\"\n",
"\n",
"    atoms = Atoms()\n",
"\n",
"    def run(self):\n",
"        \"\"\"Assign a list holding a single N2 molecule to the field.\"\"\"\n",
"        nitrogen = ase.Atoms(\"N2\", positions=[[0, 0, -1], [0, 0, 1]])\n",
"        self.atoms = [nitrogen]"
] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "pycharm": { "name": "#%%\n" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Running DVC command: 'stage add --name AtomsClass --force ...'\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Creating 'dvc.yaml'\n", "Adding stage 'AtomsClass' in 'dvc.yaml'\n", "\n", "To track the changes with git, run:\n", "\n", "\tgit add dvc.yaml nodes/AtomsClass/.gitignore\n", "\n", "To enable auto staging, run:\n", "\n", "\tdvc config core.autostage true\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Jupyter support is an experimental feature! 
Please save your notebook before running this command!\n", "Submit issues to https://github.com/zincware/ZnTrack.\n", "[NbConvertApp] Converting notebook 08_custom_zntrackoptions.ipynb to script\n", "[NbConvertApp] Writing 2881 bytes to 08_custom_zntrackoptions.py\n" ] } ], "source": [
"# repro=False only adds the stage to dvc.yaml; the stage itself is executed by `dvc repro` in the next cell\n",
"with zntrack.Project() as project:\n",
"    node = AtomsClass()\n",
"project.run(repro=False)"
] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Running stage 'AtomsClass': \n", "> zntrack run src.AtomsClass.AtomsClass --name AtomsClass\n", "Loading atoms from nodes/AtomsClass/atoms.db: 0it [00:00, ?it/s]\n", "Writing atoms to nodes/AtomsClass/atoms.db: 100%|█| 1/1 [00:00<00:00, \n", "Generating lock file 'dvc.lock' \n", "Updating lock file 'dvc.lock'\n", "\n", "To track the changes with git, run:\n", "\n", "\tgit add dvc.lock\n", "\n", "To enable auto staging, run:\n", "\n", "\tdvc config core.autostage true\n", "Use `dvc push` to send your updates to remote storage.\n" ] } ], "source": [ "!dvc repro" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Loading atoms from nodes/AtomsClass/atoms.db: 1it [00:00, 1151.02it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[Atoms(symbols='N2', pbc=False)]\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Loading atoms from nodes/AtomsClass/atoms.db: 1it [00:00, 3231.36it/s]\n" ] }, { "data": { "text/plain": [ "[Atoms(symbols='N2', pbc=False)]" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [
"# load the stored results into the node instance created above\n",
"node.load()\n",
"print(node.atoms)\n",
"# or get a new instance with the results via `from_rev`\n",
"AtomsClass.from_rev().atoms"
] }, { "cell_type": "code", "execution_count": 10, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [
"# clean up the temporary directory created by `cwd_temp_dir()` at the top of the notebook\n",
"temp_dir.cleanup()"
] } ], "metadata": { 
"kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.9" } }, "nbformat": 4, "nbformat_minor": 4 }