{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "GYCcRRZas1PS" }, "source": [ "# Checkpointing PyTrees of Arrays" ] }, { "cell_type": "markdown", "metadata": { "id": "iJwUDQVA7GIV" }, "source": [ "A [PyTree](https://jax.readthedocs.io/en/latest/pytrees.html) is the most common way of representing a training state in JAX. While Orbax is designed to be as generic as possible, and provides customization options for all manner of checkpointable objects, PyTrees naturally have pride of place. Furthermore, the standard object used to represent large, sharded arrays is the `jax.Array`. This, too, has extensive first-class support." ] }, { "cell_type": "markdown", "metadata": { "id": "aL1Xscnl7h0l" }, "source": [ "## `CheckpointHandler` Support" ] }, { "cell_type": "markdown", "metadata": { "id": "hP-MHyde77OI" }, "source": [ "There are essentially two options provided by Orbax for working with PyTrees.\n", "\n", "\n", "* [`StandardCheckpointHandler`](https://orbax.readthedocs.io/en/latest/api_reference/checkpoint.checkpoint_handlers.html#orbax.checkpoint.StandardCheckpointHandler) - applicable in the **majority** of use cases.\n", "* [`PyTreeCheckpointHandler`](https://orbax.readthedocs.io/en/latest/api_reference/checkpoint.checkpoint_handlers.html#orbax.checkpoint.PyTreeCheckpointHandler) - useful when advanced customization is desired.\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "6Mn7oFpJ8kq2" }, "outputs": [], "source": [ "import numpy as np\n", "import orbax.checkpoint as ocp\n", "import jax" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "FXlVdKu285XZ" }, "outputs": [], "source": [ "sharding = jax.sharding.NamedSharding(\n", " jax.sharding.Mesh(jax.devices(), ('model',)),\n", " jax.sharding.PartitionSpec(\n", " 'model',\n", " ),\n", ")\n", "create_sharded_array = lambda x: jax.device_put(x, sharding)\n", "state = {\n", " 'a': np.arange(16),\n", " 'b': np.ones(16),\n", "}\n", "state = 
jax.tree_util.tree_map(create_sharded_array, state)\n", "abstract_state = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, state)" ] }, { "cell_type": "markdown", "metadata": { "id": "XEOf07p-8VgS" }, "source": [ "## Basic Checkpointing" ] }, { "cell_type": "markdown", "metadata": { "id": "cQDXplyP9wOq" }, "source": [ "Let's use `StandardCheckpointHandler` to work with PyTrees of `jax.Array`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "DyaiVrUo7n80" }, "outputs": [], "source": [ "path = ocp.test_utils.erase_and_create_empty('/tmp/basic/')\n", "# Make sure to use async for improved performance!\n", "ckptr = ocp.AsyncCheckpointer(ocp.StandardCheckpointHandler())" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "f1OFlu8L9E0z" }, "outputs": [], "source": [ "ckptr.save(path / '1', args=ocp.args.StandardSave(state))\n", "ckptr.wait_until_finished()" ] }, { "cell_type": "markdown", "metadata": { "id": "WYxThD3g92Tx" }, "source": [ "We specify the `abstract_state` in order to restore with the given dtypes, shapes, and shardings for each leaf." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "MSJa92RF9WZL" }, "outputs": [], "source": [ "restored = ckptr.restore(path / '1', args=ocp.args.StandardRestore(abstract_state))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "kyFRWThb9qgu" }, "outputs": [], "source": [ "restored" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8bRma2Tq9rIE" }, "outputs": [], "source": [ "restored['a'].sharding" ] }, { "cell_type": "markdown", "metadata": { "id": "tT7s1_G9ruBf" }, "source": [ "You can do the exact same with a \"concrete\" target rather than an \"abstract\" target. However, this requires that you fully initialize the target train state\n", "before restoring from the checkpoint, which is inefficient. 
It is better practice to initialize only the metadata (either by manually creating `jax.ShapeDtypeStruct`s or using `jax.eval_shape`)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Z6ibyHvysBQH" }, "outputs": [], "source": [ "ckptr.restore(path / '1', args=ocp.args.StandardRestore(state))" ] }, { "cell_type": "markdown", "metadata": { "id": "0XwT7OEq-CRj" }, "source": [ "### Customizing Restored Properties" ] }, { "cell_type": "markdown", "metadata": { "id": "tqvsrw-V_I-H" }, "source": [ "#### Array dtype" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "7ZIclcqI-ElS" }, "outputs": [], "source": [ "def set_restore_dtype(x: jax.ShapeDtypeStruct) -\u003e jax.ShapeDtypeStruct:\n", "  # Build a new struct instead of mutating `x`, so `abstract_state` is left unchanged.\n", "  return jax.ShapeDtypeStruct(shape=x.shape, dtype=np.int16, sharding=x.sharding)\n", "\n", "cast_dtype_abstract_state = jax.tree_util.tree_map(\n", " set_restore_dtype, abstract_state)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "EdWjOdOM-S5c" }, "outputs": [], "source": [ "ckptr.restore(\n", " path / '1',\n", " args=ocp.args.StandardRestore(cast_dtype_abstract_state),\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "WFuhjDD2_NWS" }, "source": [ "#### Pad / truncate shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "W2WR28Ss_Pix" }, "outputs": [], "source": [ "different_shape_abstract_state = {\n", " 'a': jax.ShapeDtypeStruct(\n", " shape=(8,),\n", " dtype=abstract_state['a'].dtype,\n", " sharding=abstract_state['a'].sharding\n", " ),\n", " 'b': jax.ShapeDtypeStruct(\n", " shape=(32,),\n", " dtype=abstract_state['b'].dtype,\n", " sharding=abstract_state['b'].sharding\n", " ),\n", "}" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "CbIzTTi9_v2p" }, "outputs": [], "source": [ "ckptr.restore(\n", " path / '1',\n", " args=ocp.args.StandardRestore(different_shape_abstract_state),\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "_DCeo0N-_zbe" }, "source": [ "#### Change sharding" ] }, 
{ "cell_type": "markdown", "metadata": { "id": "ahq2NLdR_3Y3" }, "source": [ "**NOTE: This can often be a particularly sharp edge.**\n", "\n", "Sharding commonly needs to be changed when loading a checkpoint saved on one topology to a different topology.\n", "\n", "**If changing topologies, you MUST specify sharding when restoring.**\n", "\n", "Unless you are loading on the exact same topology, Orbax does not make any decisions about shardings on your behalf. If you have the exact same topology,\n", "however, it is possible to avoid specifying the sharding when loading. This is demonstrated below:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ebaAxl8W_2th" }, "outputs": [], "source": [ "restored = ckptr.restore(path / '1')" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "o-EzInXPARHA" }, "outputs": [], "source": [ "restored['a'].sharding" ] }, { "cell_type": "markdown", "metadata": { "id": "0r0Sd6B-A6gY" }, "source": [ "In the example below, we alter the sharding while loading." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "v_VFeNz5BCYs" }, "outputs": [], "source": [ "sharding = jax.sharding.NamedSharding(\n", " jax.sharding.Mesh(jax.devices(), ('x',)),\n", " jax.sharding.PartitionSpec(),\n", ")\n", "def set_sharding(x: jax.ShapeDtypeStruct) -\u003e jax.ShapeDtypeStruct:\n", "  # Build a new struct instead of mutating `x`, so `abstract_state` is left unchanged.\n", "  return jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=sharding)\n", "\n", "change_sharding_abstract_state = jax.tree_util.tree_map(\n", " set_sharding, abstract_state)\n", "restored = ckptr.restore(\n", " path / '1',\n", " args=ocp.args.StandardRestore(change_sharding_abstract_state),\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "AA4i5SZTBNOz" }, "outputs": [], "source": [ "restored['a'].sharding" ] }, { "cell_type": "markdown", "metadata": { "id": "VkVCqDTNBYOg" }, "source": [ "## Advanced Options" ] }, { "cell_type": "markdown", "metadata": { "id": "e1f-Csv3BaRF" }, "source": [ "There are some advanced options that `StandardCheckpointHandler` does not provide. Additional options can be specified using `PyTreeCheckpointHandler`\n", "instead." ] }, { "cell_type": "markdown", "metadata": { "id": "-sk2M8BsBlgZ" }, "source": [ "### Saving" ] }, { "cell_type": "markdown", "metadata": { "id": "hvNsP_LwBmq9" }, "source": [ "For example, `PyTreeCheckpointHandler` can be used to customize the on-disk type used to save individual arrays. First, let's save and restore as normal." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "KDeRAyZGBxeK" }, "outputs": [], "source": [ "path = ocp.test_utils.erase_and_create_empty('/tmp/advanced/')\n", "# Make sure to use async for improved performance!\n", "ckptr = ocp.AsyncCheckpointer(ocp.PyTreeCheckpointHandler())" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "YO8DaA1-BtFv" }, "outputs": [], "source": [ "ckptr.save(path / '1', args=ocp.args.PyTreeSave(state))\n", "ckptr.wait_until_finished()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "R5D_ZRDUB4Qt" }, "outputs": [], "source": [ "restored = ckptr.restore(path / '1')" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "XZWZcJHeB8w9" }, "outputs": [], "source": [ "restored['a'].dtype" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "l8D2EEELCj2k" }, "outputs": [], "source": [ "restored['b'].dtype" ] }, { "cell_type": "markdown", "metadata": { "id": "acU57c1iCSiC" }, "source": [ "Now, let's set the dtype of the array when saving." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "rSTkx2T4CU-8" }, "outputs": [], "source": [ "ckptr.save(\n", " path / '2',\n", " args=ocp.args.PyTreeSave(\n", " state,\n", " save_args={\n", " # We must set one ocp.SaveArgs per leaf.\n", " 'a': ocp.SaveArgs(dtype=np.int16),\n", " 'b': ocp.SaveArgs()\n", " }\n", " ),\n", ")\n", "ckptr.wait_until_finished()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "oZ7-PJ9HCmH8" }, "outputs": [], "source": [ "restored = ckptr.restore(path / '2')" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "P8U2qeTMCnMg" }, "outputs": [], "source": [ "restored['a'].dtype" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "NZmQA5L6Cn2v" }, "outputs": [], "source": [ "restored['b'].dtype" ] }, { "cell_type": "markdown", "metadata": { "id": "NS0acS2QDIzi" }, "source": [ "### Restoring" ] }, { "cell_type": "markdown", "metadata": { "id": "PfUHk-EyDJs1" }, "source": [ "Options similar to the above are available, where we can customize shape, dtype, and sharding when restoring." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "gauAkQ4yDPsy" }, "outputs": [], "source": [ "ckptr.restore(\n", " path / '2',\n", " args=ocp.args.PyTreeRestore(\n", " restore_args={\n", " # RestoreArgs is the parent class for ArrayRestoreArgs.\n", " # We must set one RestoreArgs per leaf.\n", " 'a': ocp.RestoreArgs(restore_type=np.ndarray),\n", " 'b': ocp.ArrayRestoreArgs(dtype=np.int16, sharding=sharding)\n", " }\n", " ),\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "edXlWrqTDpnf" }, "source": [ "Note that \"a\" was restored as `np.ndarray` rather than `jax.Array`." ] }, { "cell_type": "markdown", "metadata": { "id": "MMPpjjErDzTT" }, "source": [ "`PyTreeCheckpointHandler` also provides options to perform transformations when restoring. This is useful when your target tree has a different structure than your checkpoint tree. 
This allows you to avoid loading some keys or rename other keys. Full details are available at the [Transformations](https://orbax.readthedocs.io/en/latest/transformations.html) page." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "TYq5YTtBDy5w" }, "outputs": [], "source": [ "ckptr.restore(\n", " path / '2',\n", " args=ocp.args.PyTreeRestore(\n", " # `item` serves as a guide to what the result tree structure should look\n", " # like.\n", " item={\n", " # Value doesn't really matter here, as long as it's not None.\n", " 'c': ...,\n", " # Can add in extra keys.\n", " 'd': np.arange(8)\n", " },\n", " # `restore_args` must be relative to the result tree, not the\n", " # checkpoint.\n", " restore_args={\n", " 'c': ocp.RestoreArgs(restore_type=np.ndarray),\n", " },\n", " transforms={\n", " 'c': ocp.Transform(original_key='a')\n", " },\n", " ),\n", ")" ] } ], "metadata": { "colab": { "last_runtime": { "build_target": "//learning/deepmind/dm_python:dm_notebook3_tpu", "kind": "private" }, "private_outputs": true, "provenance": [ { "file_id": "1QNxBBBN16Br9Xj-a7LvtJzJWjOBhjFps", "timestamp": 1686159333109 } ], "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 0 }