{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "0jj2MOXcL9Eh" }, "source": [ "# Using the Refactored CheckpointManager API" ] }, { "cell_type": "markdown", "metadata": { "id": "3hFuQ-97OYLl" }, "source": [ "As of `orbax-checkpoint-0.5.0`, several new APIs have been introduced at multiple levels. The most significant change is to how users interact with `CheckpointManager`. This page shows a side-by-side comparison of the old and new APIs." ] }, { "cell_type": "markdown", "metadata": { "id": "wZKtrmVojffN" }, "source": [ "**The legacy API is deprecated and will stop working soon. Please ensure you are using the new style as soon as possible.**" ] }, { "cell_type": "markdown", "metadata": { "id": "HaKjonlHKLpl" }, "source": [ "**`CheckpointManager.save(...)` is now async by default. Make sure you call `wait_until_finished` if you depend on a previous save being completed. Alternatively, the async behavior can be disabled via the\n", "`CheckpointManagerOptions.enable_async_checkpointing` option.**" ] }, { "cell_type": "markdown", "metadata": { "id": "iabvwEt5je5e" }, "source": [ "For further information on how to use the new API, see the [introductory tutorial](https://orbax.readthedocs.io/en/latest/guides/checkpoint/orbax_checkpoint_101.html) and the [API Overview](https://orbax.readthedocs.io/en/latest/guides/checkpoint/orbax_checkpoint_api_overview.html)."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "yOTQ0-9Kw0Yu" }, "outputs": [], "source": [ "import orbax.checkpoint as ocp\n", "from etils import epath" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mwgNGsM2Rk90" }, "outputs": [], "source": [ "# Dummy PyTrees for simplicity.\n", "\n", "# In reality, this would be a tree of np.ndarray or jax.Array.\n", "pytree = {'a': 0}\n", "# In reality, this would be a tree of jax.ShapeDtypeStruct (metadata\n", "# for restoration).\n", "abstract_pytree = {'a': 0}\n", "\n", "extra_metadata = {'version': 1.0}" ] }, { "cell_type": "markdown", "metadata": { "id": "GcHS7DCWUprl" }, "source": [ "## Single-Item Checkpointing" ] }, { "cell_type": "markdown", "metadata": { "id": "4BNW1o3mSH0L" }, "source": [ "### Before" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "yW6Jh1G3O8oK" }, "outputs": [], "source": [ "options = ocp.CheckpointManagerOptions()\n", "mngr = ocp.CheckpointManager(\n", " ocp.test_utils.erase_and_create_empty('/tmp/ckpt1/'),\n", " ocp.Checkpointer(ocp.PyTreeCheckpointHandler()),\n", " options=options,\n", ")\n", "\n", "restore_args = ocp.checkpoint_utils.construct_restore_args(abstract_pytree)\n", "mngr.save(0, pytree)\n", "mngr.wait_until_finished()\n", "\n", "mngr.restore(\n", " 0,\n", " items=abstract_pytree,\n", " restore_kwargs={'restore_args': restore_args}\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "VBTMpnuDSt-D" }, "source": [ "### After" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Umu6GgpTSvQB" }, "outputs": [], "source": [ "options = ocp.CheckpointManagerOptions()\n", "with ocp.CheckpointManager(\n", " ocp.test_utils.erase_and_create_empty('/tmp/ckpt2/'),\n", " options=options,\n", ") as mngr:\n", " mngr.save(0, args=ocp.args.StandardSave(pytree))\n", "\n", " # The `CheckpointManager` already knows that the object is saved and restored\n", " # using \"standard\" pytree logic. 
In many cases, you can restore exactly as\n", " # saved without specifying additional arguments.\n", " mngr.restore(0)\n", " # If customization of properties like sharding or dtype is desired, just provide\n", " # the abstract target PyTree, the properties of which will be used to set\n", " # the properties of the restored arrays.\n", " mngr.restore(0, args=ocp.args.StandardRestore(abstract_pytree))" ] }, { "cell_type": "markdown", "metadata": { "id": "e28L6prpTqfl" }, "source": [ "Important notes:\n", "\n", "\n", "* Don't forget to use the keyword `args=...` for save and restore! Otherwise you will get the legacy API. This will not be necessary forever, but only until the legacy API is removed.\n", "* The value of `args` is a subclass of `CheckpointArgs`, present in the `ocp.args` module. These classes are used to communicate the logic that you wish to use to save and restore your object. For a typical PyTree consisting of arrays, use `StandardSave`/`StandardRestore`." ] }, { "cell_type": "markdown", "metadata": { "id": "ebL-zbpVaH-4" }, "source": [ "Let's explore scenarios when `restore()` and `item_metadata()` calls raise errors due to unspecified CheckpointHandlers for item names." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "DrPYyYZRZwgx" }, "outputs": [], "source": [ "# Unmapped CheckpointHandlers on a new CheckpointManager instance.\n", "new_mngr = ocp.CheckpointManager('/tmp/ckpt2/', options=options)\n", "try:\n", " new_mngr.restore(0) # Raises error due to unmapped CheckpointHandler\n", "except BaseException as e:\n", " print(e)" ] }, { "cell_type": "markdown", "metadata": { "id": "o6royiq9WNIg" }, "source": [ "To fix this, use one of the following options:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "g_Wo9Q07c0tu" }, "outputs": [], "source": [ "new_mngr.restore(0, args=ocp.args.StandardRestore(abstract_pytree))\n", "new_mngr.close()" ] }, { "cell_type": "markdown", "metadata": { "id": "2wZzpye8WYdi" }, "source": [ "We can also configure the `CheckpointManager` to know how to restore the object in advance." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3Bdto6tPAfQy" }, "outputs": [], "source": [ "# The item name is \"default\".\n", "list(epath.Path('/tmp/ckpt2/0').iterdir())" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "tATQ5rmCAqey" }, "outputs": [], "source": [ "registry = ocp.handlers.DefaultCheckpointHandlerRegistry()\n", "registry.add('default', ocp.args.StandardRestore, ocp.StandardCheckpointHandler)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "zeDU5DWxZzzf" }, "outputs": [], "source": [ "# handler_registry can be used as an alternative to restore(..., args=...).\n", "with ocp.CheckpointManager(\n", " '/tmp/ckpt2/',\n", " options=options,\n", " handler_registry=registry,\n", ") as new_mngr:\n", " print(new_mngr.restore(0))" ] }, { "cell_type": "markdown", "metadata": { "id": "MZidq2_EKMo_" }, "source": [ "**NOTE:**\n", "`CheckpointManager.item_metadata(step)` doesn't support any input like `args` in `restore(..., args=...)`.\n", "\n", "So, `handler_registry` is currently required when 
calling `item_metadata(step)` before calling restore or save." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-qvpESD4bNlH" }, "outputs": [], "source": [ "# handler_registry becomes even more critical with item_metadata() calls.\n", "new_mngr = ocp.CheckpointManager('/tmp/ckpt2/', options=options)\n", "try:\n", " new_mngr.item_metadata(0) # Raises error due to unmapped CheckpointHandler\n", "except BaseException as e:\n", " print(e)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "672mph1JbsDD" }, "outputs": [], "source": [ "with ocp.CheckpointManager(\n", " '/tmp/ckpt2/',\n", " options=options,\n", " handler_registry=registry,\n", ") as new_mngr:\n", " new_mngr.item_metadata(0)" ] }, { "cell_type": "markdown", "metadata": { "id": "LCt1O-NWUvMK" }, "source": [ "## Multiple-Item Checkpointing" ] }, { "cell_type": "markdown", "metadata": { "id": "Nb_I3d_1UxI7" }, "source": [ "### Before" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "hLVyGC1PTKl0" }, "outputs": [], "source": [ "options = ocp.CheckpointManagerOptions()\n", "mngr = ocp.CheckpointManager(\n", " ocp.test_utils.erase_and_create_empty('/tmp/ckpt3/'),\n", " {\n", " 'state': ocp.Checkpointer(ocp.PyTreeCheckpointHandler()),\n", " 'extra_metadata': ocp.Checkpointer(ocp.JsonCheckpointHandler())\n", " },\n", " options=options,\n", ")\n", "\n", "restore_args = ocp.checkpoint_utils.construct_restore_args(abstract_pytree)\n", "mngr.save(0, {'state': pytree, 'extra_metadata': extra_metadata})\n", "mngr.wait_until_finished()\n", "\n", "mngr.restore(\n", " 0,\n", " items={'state': abstract_pytree, 'extra_metadata': None},\n", " restore_kwargs={\n", " 'state': {'restore_args': restore_args},\n", " 'extra_metadata': None\n", " },\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "SfVUMiPfVluZ" }, "source": [ "### After" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "HaRxV1-FVlMj" }, "outputs": [], "source": 
[ "options = ocp.CheckpointManagerOptions()\n", "mngr = ocp.CheckpointManager(\n", " ocp.test_utils.erase_and_create_empty('/tmp/ckpt4/'),\n", " # `item_names` defines an up-front contract about what items the\n", " # CheckpointManager will be dealing with.\n", " item_names=('state', 'extra_metadata'),\n", " options=options,\n", ")\n", "\n", "mngr.save(0, args=ocp.args.Composite(\n", " state=ocp.args.StandardSave(pytree),\n", " extra_metadata=ocp.args.JsonSave(extra_metadata))\n", ")\n", "mngr.wait_until_finished()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "FNoKE4X0hgfE" }, "outputs": [], "source": [ "# Restore as saved\n", "mngr.restore(0)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "oXn7nYSJhlw2" }, "outputs": [], "source": [ "# Restore with customization. Restore a subset of items.\n", "mngr.restore(0, args=ocp.args.Composite(\n", " state=ocp.args.StandardRestore(abstract_pytree)))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "AsDNbCEtXL-N" }, "outputs": [], "source": [ "mngr.close()" ] }, { "cell_type": "markdown", "metadata": { "id": "mZSZzWkhLSvz" }, "source": [ "Just like the single-item use case described above, let's explore scenarios in which `restore()` and `item_metadata()` calls raise errors due to unspecified CheckpointHandlers for item names."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "M32H2fQ0LmTo" }, "outputs": [], "source": [ "# Unmapped CheckpointHandlers on a new CheckpointManager instance.\n", "new_mngr = ocp.CheckpointManager(\n", " '/tmp/ckpt4/',\n", " options=options,\n", " item_names=('state', 'extra_metadata'),\n", ")\n", "try:\n", " new_mngr.restore(0) # Raises error due to unmapped CheckpointHandlers\n", "except BaseException as e:\n", " print(e)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "_VSCqkHAL_NG" }, "outputs": [], "source": [ "new_mngr.restore(\n", " 0,\n", " args=ocp.args.Composite(\n", " state=ocp.args.StandardRestore(abstract_pytree),\n", " extra_metadata=ocp.args.JsonRestore(),\n", " ),\n", ")\n", "new_mngr.close()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "nG12gM6l_y6C" }, "outputs": [], "source": [ "registry = ocp.handlers.DefaultCheckpointHandlerRegistry()\n", "registry.add('state', ocp.args.StandardRestore, ocp.StandardCheckpointHandler)\n", "registry.add('extra_metadata', ocp.args.JsonRestore, ocp.JsonCheckpointHandler)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "m8nx3t0jMcJQ" }, "outputs": [], "source": [ "# handler_registry can be used as an alternative to restore(..., args=...).\n", "with ocp.CheckpointManager(\n", " '/tmp/ckpt4/',\n", " options=options,\n", " handler_registry=registry,\n", ") as new_mngr:\n", " print(new_mngr.restore(0))" ] }, { "cell_type": "markdown", "metadata": { "id": "iYjdji-eNHDY" }, "source": [ "**NOTE:**\n", "`CheckpointManager.item_metadata(step)` doesn't support any input like `args` in `restore(..., args=...)`.\n", "\n", "So, `handler_registry` is currently required with `item_metadata(step)` calls."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "5b5v9aEkNJKb" }, "outputs": [], "source": [ "# handler_registry becomes even more critical with item_metadata() calls.\n", "with ocp.CheckpointManager(\n", " '/tmp/ckpt4/',\n", " options=options,\n", " item_names=('state', 'extra_metadata'),\n", ") as new_mngr:\n", " try:\n", " new_mngr.item_metadata(0) # Raises error due to unmapped CheckpointHandlers\n", " except BaseException as e:\n", " print(e)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "oFKESnvlNYWP" }, "outputs": [], "source": [ "with ocp.CheckpointManager(\n", " '/tmp/ckpt4/',\n", " options=options,\n", " handler_registry=registry,\n", ") as new_mngr:\n", " print(new_mngr.item_metadata(0))" ] } ], "metadata": { "colab": { "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 0 }