{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "0jj2MOXcL9Eh" }, "source": [ "# Using the Refactored CheckpointManager API" ] }, { "cell_type": "markdown", "metadata": { "id": "3hFuQ-97OYLl" }, "source": [ "As of `orbax-checkpoint-0.5.0`, several new APIs have been introduced at multiple levels. The most significant change is to how users interact with `CheckpointManager`. This page shows a side-by-side comparison of the old and new APIs." ] }, { "cell_type": "markdown", "metadata": { "id": "wZKtrmVojffN" }, "source": [ "**The legacy API is deprecated and will stop working after May 1st, 2024. Please ensure you are using the new style by then.**" ] }, { "cell_type": "markdown", "metadata": { "id": "HaKjonlHKLpl" }, "source": [ "**`CheckpointManager.save(...)` is now async by default. Make sure you call `wait_until_finished` if your code depends on a previous save having completed. Alternatively, async saving can be disabled via the\n", "`CheckpointManagerOptions.enable_async_checkpointing` option.**" ] }, { "cell_type": "markdown", "metadata": { "id": "iabvwEt5je5e" }, "source": [ "For further information on how to use the new API, see the [introductory tutorial](https://orbax.readthedocs.io/en/latest/orbax_checkpoint_101.html) and the [API Overview](https://orbax.readthedocs.io/en/latest/orbax_checkpoint_api_overview.html)." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "yOTQ0-9Kw0Yu" }, "outputs": [], "source": [ "import orbax.checkpoint as ocp" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mwgNGsM2Rk90" }, "outputs": [], "source": [ "# Dummy PyTrees for simplicity.\n", "\n", "# In reality, this would be a tree of np.ndarray or jax.Array.\n", "pytree = {'a': 0}\n", "# In reality, this would be a tree of jax.ShapeDtypeStruct (metadata\n", "# for restoration).\n", "abstract_pytree = {'a': 0}\n", "\n", "extra_metadata = {'version': 1.0}" ] }, { "cell_type": "markdown", "metadata": { "id": "GcHS7DCWUprl" }, "source": [ "## Single-Item Checkpointing" ] }, { "cell_type": "markdown", "metadata": { "id": "4BNW1o3mSH0L" }, "source": [ "### Before" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "yW6Jh1G3O8oK" }, "outputs": [], "source": [ "options = ocp.CheckpointManagerOptions()\n", "mngr = ocp.CheckpointManager(\n", " ocp.test_utils.erase_and_create_empty('/tmp/ckpt1/'),\n", " ocp.Checkpointer(ocp.PyTreeCheckpointHandler()),\n", " options=options,\n", ")\n", "\n", "restore_args = ocp.checkpoint_utils.construct_restore_args(abstract_pytree)\n", "mngr.save(0, pytree)\n", "mngr.wait_until_finished()\n", "\n", "mngr.restore(\n", " 0,\n", " items=abstract_pytree,\n", " restore_kwargs={'restore_args': restore_args}\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "VBTMpnuDSt-D" }, "source": [ "### After" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Umu6GgpTSvQB" }, "outputs": [], "source": [ "options = ocp.CheckpointManagerOptions()\n", "with ocp.CheckpointManager(\n", " ocp.test_utils.erase_and_create_empty('/tmp/ckpt2/'),\n", " options=options,\n", ") as mngr:\n", "\n", " mngr.save(0, args=ocp.args.StandardSave(pytree))\n", " mngr.wait_until_finished()\n", "\n", " # After providing `args` during an initial `save` or `restore` call, the\n", " # `CheckpointManager` instance 
records the type so that you do not need to\n", " # specify it again. If the `CheckpointManager` instance has not previously\n", " # been given an `ocp.args.CheckpointArgs` instance for a particular item, that\n", " # item cannot be restored without specifying `args` at restore time.\n", "\n", " # In many cases, you can restore exactly as saved without specifying additional\n", " # arguments.\n", " mngr.restore(0)\n", " # If customization of properties like sharding or dtype is desired, just provide\n", " # the abstract target PyTree, the properties of which will be used to set\n", " # the properties of the restored arrays.\n", " mngr.restore(0, args=ocp.args.StandardRestore(abstract_pytree))" ] }, { "cell_type": "markdown", "metadata": { "id": "e28L6prpTqfl" }, "source": [ "Important notes:\n", "\n", "\n", "* Don't forget to use the keyword `args=...` for save and restore! Otherwise you will get the legacy API. The keyword is only required until the legacy API is removed.\n", "* The value of `args` is an instance of a `CheckpointArgs` subclass, found in the `ocp.args` module. These classes are used to communicate the logic that you wish to use to save and restore your object. For a typical PyTree consisting of arrays, use `StandardSave`/`StandardRestore`." ] }, { "cell_type": "markdown", "metadata": { "id": "ebL-zbpVaH-4" }, "source": [ "Let's explore scenarios in which `restore()` and `item_metadata()` calls raise errors because no CheckpointHandler has been specified for an item name.\n", "\n", "`CheckpointManager(..., item_handlers=...)` can be used to resolve these scenarios." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "DrPYyYZRZwgx" }, "outputs": [], "source": [ "# Unmapped CheckpointHandlers on a new CheckpointManager instance.\n", "new_mngr = ocp.CheckpointManager('/tmp/ckpt2/', options=options)\n", "new_mngr.restore(0) # Raises error due to unmapped CheckpointHandler" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "g_Wo9Q07c0tu" }, "outputs": [], "source": [ "new_mngr.restore(0, args=ocp.args.StandardRestore(abstract_pytree))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "tGXkK0BIXDaz" }, "outputs": [], "source": [ "new_mngr.close()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "zeDU5DWxZzzf" }, "outputs": [], "source": [ "# item_handlers can be used as an alternative to restore(..., args=...).\n", "with ocp.CheckpointManager(\n", " '/tmp/ckpt2/',\n", " options=options,\n", " item_handlers=ocp.StandardCheckpointHandler()\n", ") as new_mngr:\n", " print(new_mngr.restore(0))" ] }, { "cell_type": "markdown", "metadata": { "id": "MZidq2_EKMo_" }, "source": [ "**NOTE:**\n", "`CheckpointManager.item_metadata(step)` does not accept an `args` input the way `restore(..., args=...)` does.\n", "\n", "So `item_handlers` is the only option available for `item_metadata(step)` calls." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-qvpESD4bNlH" }, "outputs": [], "source": [ "# item_handlers becomes even more critical with item_metadata() calls.\n", "new_mngr = ocp.CheckpointManager('/tmp/ckpt2/', options=options)\n", "new_mngr.item_metadata(0) # Raises error due to unmapped CheckpointHandler" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "672mph1JbsDD" }, "outputs": [], "source": [ "new_mngr = ocp.CheckpointManager(\n", " '/tmp/ckpt2/',\n", " options=options,\n", " item_handlers=ocp.StandardCheckpointHandler(),\n", ")\n", "new_mngr.item_metadata(0)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mTZK3rB-LFg8" }, "outputs": [], "source": [ "new_mngr.close()" ] }, { "cell_type": "markdown", "metadata": { "id": "LCt1O-NWUvMK" }, "source": [ "## Multiple-Item Checkpointing" ] }, { "cell_type": "markdown", "metadata": { "id": "Nb_I3d_1UxI7" }, "source": [ "### Before" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "hLVyGC1PTKl0" }, "outputs": [], "source": [ "options = ocp.CheckpointManagerOptions()\n", "mngr = ocp.CheckpointManager(\n", " ocp.test_utils.erase_and_create_empty('/tmp/ckpt3/'),\n", " {\n", " 'state': ocp.Checkpointer(ocp.PyTreeCheckpointHandler()),\n", " 'extra_metadata': ocp.Checkpointer(ocp.JsonCheckpointHandler())\n", " },\n", " options=options,\n", ")\n", "\n", "restore_args = ocp.checkpoint_utils.construct_restore_args(abstract_pytree)\n", "mngr.save(0, {'state': pytree, 'extra_metadata': extra_metadata})\n", "mngr.wait_until_finished()\n", "\n", "mngr.restore(\n", " 0,\n", " items={'state': abstract_pytree, 'extra_metadata': None},\n", " restore_kwargs={\n", " 'state': {'restore_args': restore_args},\n", " 'extra_metadata': None\n", " },\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "SfVUMiPfVluZ" }, "source": [ "### After" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "HaRxV1-FVlMj" 
}, "outputs": [], "source": [ "options = ocp.CheckpointManagerOptions()\n", "mngr = ocp.CheckpointManager(\n", " ocp.test_utils.erase_and_create_empty('/tmp/ckpt4/'),\n", " # `item_names` defines an up-front contract about what items the\n", " # CheckpointManager will be dealing with.\n", " item_names=('state', 'extra_metadata'),\n", " options=options,\n", ")\n", "\n", "mngr.save(0, args=ocp.args.Composite(\n", " state=ocp.args.StandardSave(pytree),\n", " extra_metadata=ocp.args.JsonSave(extra_metadata))\n", ")\n", "mngr.wait_until_finished()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "FNoKE4X0hgfE" }, "outputs": [], "source": [ "# Restore as saved\n", "mngr.restore(0)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "oXn7nYSJhlw2" }, "outputs": [], "source": [ "# Restore with customization. Restore a subset of items.\n", "mngr.restore(0, args=ocp.args.Composite(\n", " state=ocp.args.StandardRestore(abstract_pytree)))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "AsDNbCEtXL-N" }, "outputs": [], "source": [ "mngr.close()" ] }, { "cell_type": "markdown", "metadata": { "id": "mZSZzWkhLSvz" }, "source": [ "Just as in the single-item use case described above, let's explore scenarios in which `restore()` and `item_metadata()` calls raise errors because no CheckpointHandler has been specified for an item name.\n", "\n", "`CheckpointManager(..., item_handlers=...)` can be used to resolve these scenarios." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "M32H2fQ0LmTo" }, "outputs": [], "source": [ "# Unmapped CheckpointHandlers on a new CheckpointManager instance.\n", "new_mngr = ocp.CheckpointManager(\n", " '/tmp/ckpt4/',\n", " options=options,\n", " item_names=('state', 'extra_metadata'),\n", ")\n", "new_mngr.restore(0) # Raises error due to unmapped CheckpointHandlers" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "_VSCqkHAL_NG" }, "outputs": [], "source": [ "new_mngr.restore(\n", " 0,\n", " args=ocp.args.Composite(\n", " state=ocp.args.StandardRestore(abstract_pytree),\n", " extra_metadata=ocp.args.JsonRestore(),\n", " ),\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "wVtGDZS1XQKy" }, "outputs": [], "source": [ "new_mngr.close()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "m8nx3t0jMcJQ" }, "outputs": [], "source": [ "# item_handlers can be used as an alternative to restore(..., args=...).\n", "with ocp.CheckpointManager(\n", " '/tmp/ckpt4/',\n", " options=options,\n", " item_handlers={\n", " 'state': ocp.StandardCheckpointHandler(),\n", " 'extra_metadata': ocp.JsonCheckpointHandler(),\n", " },\n", ") as new_mngr:\n", " print(new_mngr.restore(0))" ] }, { "cell_type": "markdown", "metadata": { "id": "iYjdji-eNHDY" }, "source": [ "**NOTE:**\n", "`CheckpointManager.item_metadata(step)` does not accept an `args` input the way `restore(..., args=...)` does.\n", "\n", "So `item_handlers` is the only option available for `item_metadata(step)` calls." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "5b5v9aEkNJKb" }, "outputs": [], "source": [ "# item_handlers becomes even more critical with item_metadata() calls.\n", "with ocp.CheckpointManager(\n", " '/tmp/ckpt4/',\n", " options=options,\n", " item_names=('state', 'extra_metadata'),\n", ") as new_mngr:\n", " new_mngr.item_metadata(0) # Raises error due to unmapped CheckpointHandlers" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "oFKESnvlNYWP" }, "outputs": [], "source": [ "with ocp.CheckpointManager(\n", " '/tmp/ckpt4/',\n", " options=options,\n", " item_handlers={\n", " 'state': ocp.StandardCheckpointHandler(),\n", " 'extra_metadata': ocp.JsonCheckpointHandler(),\n", " },\n", ") as new_mngr:\n", " print(new_mngr.item_metadata(0))" ] } ], "metadata": { "colab": { "last_runtime": { "build_target": "//learning/grp/tools/ml_python:ml_notebook", "kind": "private" }, "private_outputs": true, "provenance": [ { "file_id": "17zeb2jhSE6p1x3u7r_s15AuKrOPtQ7y3", "timestamp": 1704302873675 } ] }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 0 }