{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "GYCcRRZas1PS" }, "source": [ "# Checkpointing with Orbax" ] }, { "cell_type": "markdown", "metadata": { "id": "1Ik2ARq4JaL3" }, "source": [ "This page gives a brief overview of common tasks that you may wish to accomplish with Orbax. For more in-depth documentation of the APIs, see [API Overview](https://orbax.readthedocs.io/en/latest/orbax_checkpoint_api_overview.html)." ] }, { "cell_type": "markdown", "metadata": { "id": "_-dvMEjPw5fl" }, "source": [ "## Saving and Restoring" ] }, { "cell_type": "markdown", "metadata": { "id": "VyfEdKvwswys" }, "source": [ "The following example shows how you can synchronously save and restore a PyTree. See [Checkpointing PyTrees](https://orbax.readthedocs.io/en/latest/checkpointing_pytrees.html) for more detail.\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "NDPLLgWQtLSL" }, "outputs": [], "source": [ "import numpy as np\n", "import orbax.checkpoint as ocp\n", "import jax" ] }, { "cell_type": "markdown", "metadata": { "id": "TUmW2-ynt8Uh" }, "source": [ "Ensure that the top-level directory already exists before saving." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "dC3IOhlEt7ne" }, "outputs": [], "source": [ "path = ocp.test_utils.erase_and_create_empty('/tmp/my-checkpoints/')" ] }, { "cell_type": "markdown", "metadata": { "id": "LAFJDHPLuN7V" }, "source": [ "Create a basic [PyTree](https://jax.readthedocs.io/en/latest/pytrees.html). This is simply a nested tree-like structure, which can include dicts, lists, or more complicated objects. For the leaves of the tree, Orbax is capable of handling many different types. 
For our purposes, we will simply use a nested dict of some simple arrays.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "d4CzPwbiuEKq" }, "outputs": [], "source": [ "my_tree = {\n", " 'a': np.arange(8),\n", " 'b': {\n", " 'c': 42,\n", " 'd': np.arange(16),\n", " },\n", "}\n", "abstract_my_tree = jax.tree_util.tree_map(\n", " ocp.utils.to_shape_dtype_struct, my_tree)" ] }, { "cell_type": "markdown", "metadata": { "id": "ZcjGuIwdvZzF" }, "source": [ "To save and restore, we create a `Checkpointer` object. The `Checkpointer` must be constructed with a `CheckpointHandler`, which essentially acts as a configuration that provides the `Checkpointer` with the logic needed to save and restore your object.\n", "\n", "For PyTrees, the most common checkpointable object, we can use the convenient shorthand of `StandardCheckpointer`, which is the same as `Checkpointer(StandardCheckpointHandler())` (see [docs](https://orbax.readthedocs.io/en/latest/api_reference/checkpoint.checkpoint_handlers.html#standardcheckpointhandler) for more info)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "UmE--WLDskhL" }, "outputs": [], "source": [ "checkpointer = ocp.StandardCheckpointer()\n", "# 'checkpoint_name' must not already exist.\n", "checkpointer.save(path / 'checkpoint_name', my_tree)\n", "checkpointer.restore(\n", " path / 'checkpoint_name/',\n", " args=ocp.args.StandardRestore(abstract_my_tree)\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "SWhla0U_D7iR" }, "source": [ "Metadata about the checkpoint can be retrieved with the `metadata` function, making it easy to gather information about an arbitrary checkpoint, or to manually inspect certain properties." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "iGXn9jLFEIOb" }, "outputs": [], "source": [ "checkpointer.metadata(path / 'checkpoint_name')" ] }, { "cell_type": "markdown", "metadata": { "id": "i8SzoiOLEQKK" }, "source": [ "## Multiple Objects" ] }, { "cell_type": "markdown", "metadata": { "id": "LI-3bCyMEUCC" }, "source": [ "You will often need to handle multiple distinct checkpointable objects at once, frequently of different types. A `Checkpointer` combined with a `CompositeCheckpointHandler` ([docs](https://orbax.readthedocs.io/en/latest/api_reference/checkpoint.checkpoint_handlers.html#compositecheckpointhandler)) can be used to represent a single checkpoint consisting of multiple sub-items, each represented by a sub-directory within the checkpoint." ] }, { "cell_type": "markdown", "metadata": { "id": "dmjd-ywaGETl" }, "source": [ "However, when you have a particular object that you're saving, Orbax needs to know **how** you want to save it. After all, if you provide a nested dict to save, there's no way to tell whether it should be saved in a simple JSON representation sufficient for basic metadata, or whether it requires more advanced logic suitable for sharded `jax.Array`s. This information can be provided via the `orbax.checkpoint.args` module." ] }, { "cell_type": "markdown", "metadata": { "id": "lvc66T0dGhbn" }, "source": [ "In the example below, we can imagine that `state` is a PyTree consisting of large sharded arrays. In contrast, `metadata` contains only a few strings and numbers, and can easily be saved using JSON." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "dj56Y5kRG6z6" }, "outputs": [], "source": [ "metadata = {\n", " 'version': 1.0,\n", " 'lang': 'en',\n", "}" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "s3Nu3Lo-Fshf" }, "outputs": [], "source": [ "checkpointer = ocp.Checkpointer(\n", " ocp.CompositeCheckpointHandler('state', 'metadata')\n", ")\n", "checkpointer.save(\n", " path / 'composite_checkpoint',\n", " args=ocp.args.Composite(\n", " state=ocp.args.StandardSave(my_tree),\n", " metadata=ocp.args.JsonSave(metadata),\n", " ),\n", ")\n", "restored = checkpointer.restore(path / 'composite_checkpoint')" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "eA1P7PynHl7Z" }, "outputs": [], "source": [ "restored.state" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "EZAXy8bEHnVs" }, "outputs": [], "source": [ "restored.metadata" ] }, { "cell_type": "markdown", "metadata": { "id": "3coO4_6dHldR" }, "source": [ "Inspecting the checkpoint directory, we can see that it has sub-directories for `state` and `metadata`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "JWchUf1vHhZs" }, "outputs": [], "source": [ "list((path / 'composite_checkpoint').iterdir())" ] }, { "cell_type": "markdown", "metadata": { "id": "RK2PmRAow8Ty" }, "source": [ "## Managing Checkpoints" ] }, { "cell_type": "markdown", "metadata": { "id": "sYxNt2Zrw_Yt" }, "source": [ "In the context of training a model, it is helpful to deal with a series of steps. The `CheckpointManager` allows you to save steps sequentially, according to a given period, cleaning up after a certain number are stored, and many other [functionalities](https://orbax.readthedocs.io/en/latest/api_reference/checkpoint.checkpoint_manager.html#checkpointmanageroptions).\n", "\n", "**Beware: `CheckpointManager.save(...)` happens in a background thread by default. 
See [Asynchronous Checkpointing](https://orbax.readthedocs.io/en/latest/async_checkpointing.html) for more details.**\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "d3DSYn5_ynhi" }, "outputs": [], "source": [ "path = ocp.test_utils.erase_and_create_empty('/tmp/checkpoint_manager')\n", "state = {\n", " 'a': np.arange(8),\n", " 'b': np.arange(16),\n", "}\n", "extra_params = [42, 43]" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "kWKXe6vCx3O9" }, "outputs": [], "source": [ "# Keeps a maximum of 3 checkpoints, and only saves every other step.\n", "options = ocp.CheckpointManagerOptions(max_to_keep=3, save_interval_steps=2)\n", "mngr = ocp.CheckpointManager(\n", " path, options=options, item_names=('state', 'extra_params')\n", ")\n", "\n", "for step in range(11): # [0, 1, ..., 10]\n", " mngr.save(\n", " step,\n", " args=ocp.args.Composite(\n", " state=ocp.args.StandardSave(state),\n", " extra_params=ocp.args.JsonSave(extra_params),\n", " ),\n", " )\n", "mngr.wait_until_finished()\n", "restored = mngr.restore(10)\n", "restored_state, restored_extra_params = restored.state, restored.extra_params" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "q5NkveVgysr9" }, "outputs": [], "source": [ "mngr.all_steps()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "dPTf7D2pyva4" }, "outputs": [], "source": [ "mngr.latest_step()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "vU7nsnC-ywDs" }, "outputs": [], "source": [ "mngr.should_save(11)" ] }, { "cell_type": "markdown", "metadata": { "id": "t3trISVvzasj" }, "source": [ "## A Standard Recipe" ] }, { "cell_type": "markdown", "metadata": { "id": "6kIBdZZgzc0R" }, "source": [ "In most cases, users will wish to save and restore a PyTree representing a model state over the course of many training steps. Many users will also wish to do this in a multi-host, multi-device environment." 
] }, { "cell_type": "markdown", "metadata": { "id": "CiYSn5nEzly-" }, "source": [ "First, we will create a PyTree state with sharded `jax.Array` as leaves." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "kbdURa01zvVy" }, "outputs": [], "source": [ "import jax\n", "\n", "path = ocp.test_utils.erase_and_create_empty('/tmp/checkpoint_manager_sharded')\n", "\n", "sharding = jax.sharding.NamedSharding(\n", " jax.sharding.Mesh(jax.devices(), ('model',)),\n", " jax.sharding.PartitionSpec(\n", " 'model',\n", " ),\n", ")\n", "create_sharded_array = lambda x: jax.device_put(x, sharding)\n", "train_state = {\n", " 'a': np.arange(16),\n", " 'b': np.ones(16),\n", "}\n", "train_state = jax.tree_util.tree_map(create_sharded_array, train_state)\n", "jax.tree_util.tree_map(lambda x: x.sharding, train_state)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "qwhFWzvLzg5b" }, "outputs": [], "source": [ "num_steps = 10\n", "options = ocp.CheckpointManagerOptions(max_to_keep=3, save_interval_steps=2)\n", "mngr = ocp.CheckpointManager(path, options=options)\n", "\n", "\n", "@jax.jit\n", "def train_fn(state):\n", " return jax.tree_util.tree_map(lambda x: x + 1, state)\n", "\n", "\n", "for step in range(num_steps):\n", " train_state = train_fn(train_state)\n", " mngr.save(step, args=ocp.args.StandardSave(train_state))\n", "mngr.wait_until_finished()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sZOLdSHj3DqI" }, "outputs": [], "source": [ "mngr.restore(mngr.latest_step())" ] }, { "cell_type": "markdown", "metadata": { "id": "Q5bT61h-2LiQ" }, "source": [ "Let's imagine now that we are starting a new training run, and would like to restore the checkpoint previously saved. In this case, we only know the tree structure of the checkpoint, and not the actual array values. We would also like to load the arrays with different sharding constraints than how they were originally saved." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "WEo_AE_k06S1" }, "outputs": [], "source": [ "train_state = jax.tree_util.tree_map(np.zeros_like, train_state)\n", "sharding = jax.sharding.NamedSharding(\n", " jax.sharding.Mesh(jax.devices(), ('model',)),\n", " jax.sharding.PartitionSpec(\n", " None,\n", " ),\n", ")\n", "create_sharded_array = lambda x: jax.device_put(x, sharding)\n", "train_state = jax.tree_util.tree_map(create_sharded_array, train_state)\n", "abstract_train_state = jax.tree_util.tree_map(\n", " ocp.utils.to_shape_dtype_struct, train_state\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "XVmkVuMK7GuG" }, "source": [ "Construct arguments needed for restoration." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "695TeDGW4OQi" }, "outputs": [], "source": [ "restored = mngr.restore(\n", " mngr.latest_step(),\n", " args=ocp.args.StandardRestore(abstract_train_state),\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "axUVKyvz4tJn" }, "outputs": [], "source": [ "restored" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "MpWx0PG24t28" }, "outputs": [], "source": [ "jax.tree_util.tree_map(lambda x: x.sharding, restored)" ] } ], "metadata": { "colab": { "last_runtime": { "build_target": "//learning/deepmind/dm_python:dm_notebook3_tpu", "kind": "private" }, "private_outputs": true, "provenance": [ { "file_id": "1QNxBBBN16Br9Xj-a7LvtJzJWjOBhjFps", "timestamp": 1686159333109 } ], "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 0 }