{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "91nkq8cAXf95" }, "source": [ "# Checkpoint Format Guide" ] }, { "cell_type": "markdown", "metadata": { "id": "d2CVAjXdaXsw" }, "source": [ "It is important to understand how Orbax structures checkpoints on\n", "disk, particularly if you need to debug at the checkpoint level or\n", "work with specific pieces of a larger checkpoint." ] }, { "cell_type": "markdown", "metadata": { "id": "0-UNQoMmbBy5" }, "source": [ "First, some setup:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "skQfDEX7Xlg5" }, "outputs": [], "source": [ "from etils import epath\n", "import jax\n", "import numpy as np\n", "import orbax.checkpoint as ocp" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "rGoHzSn6XryA" }, "outputs": [], "source": [ "# Shard arrays along a 1D mesh axis named 'model' that spans all devices.\n", "sharding = jax.sharding.NamedSharding(\n", "    jax.sharding.Mesh(jax.devices(), ('model',)),\n", "    jax.sharding.PartitionSpec(\n", "        'model',\n", "    ),\n", ")\n", "create_sharded_array = lambda x: jax.device_put(x, sharding)\n", "state = {\n", "    'a': np.arange(16),\n", "    'b': np.ones(16),\n", "}\n", "state = jax.tree_util.tree_map(create_sharded_array, state)\n", "abstract_state = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, state)\n", "# Add unsharded leaves: a NumPy array, a Python scalar, and a string.\n", "state['c'] = np.arange(4)\n", "state['d'] = 5\n", "state['e'] = 'foo'\n", "state" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "HOz6zJHEZSzZ" }, "outputs": [], "source": [ "def print_directory(directory: epath.PathLike, level: int = 0):\n", "  \"\"\"Prints a directory tree for debugging purposes.\"\"\"\n", "  directory = epath.Path(directory)\n", "  assert directory.is_dir()\n", "  level_str = '..' * level\n", "  if level == 0:\n", "    print(f'Printing directory tree: {directory}/')\n", "  else:\n", "    print(f'{level_str}{directory.name}/')\n", "\n", "  level_str = '..' * (level + 1)\n", "  for p in directory.iterdir():\n", "    if p.is_dir():\n", "      print_directory(p, level=level + 1)\n", "    else:\n", "      print(f'{level_str}{p.name}')" ] }, { "cell_type": "markdown", "metadata": { "id": "3LtbC4p9bEyW" }, "source": [ "We will start by creating a checkpoint for step `0`, consisting of two items:\n", "`state` and `custom_data`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "crmVgq70X786" }, "outputs": [], "source": [ "path = ocp.test_utils.erase_and_create_empty('/tmp/checkpoint')" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0ehjDFEDYH8c" }, "outputs": [], "source": [ "global_metadata = {'global_property': 'foo'}\n", "with ocp.CheckpointManager(\n", "    path, item_names=('state', 'custom_data'), metadata=global_metadata\n", ") as mngr:\n", "  mngr.save(\n", "      0,\n", "      args=ocp.args.Composite(\n", "          state=ocp.args.PyTreeSave(state),\n", "          custom_data=ocp.args.JsonSave({'lang': 'en', 'version': 1.2}),\n", "      ),\n", "  )" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "62Wo_H0-Yhhz" }, "outputs": [], "source": [ "print_directory(path)" ] }, { "cell_type": "markdown", "metadata": { "id": "Orls3G2dbOIB" }, "source": [ "Let's understand each of these pieces separately." ] }, { "cell_type": "markdown", "metadata": { "id": "wPohuIO0bUqG" }, "source": [ "### Root Directory" ] }, { "cell_type": "markdown", "metadata": { "id": "uLhpxF9zblFW" }, "source": [ "The \"root directory\" is understood to be the directory provided when creating a\n", "`CheckpointManager`. It represents the parent directory where all \"sequential\"\n", "checkpoints will reside (see below). In the above example, this corresponds to\n", "`/tmp/checkpoint/`.\n", "\n", "Within the root directory, aside from the sequential checkpoints, there may also\n", "be a `metadata` subdirectory (if `metadata` was provided when configuring the\n", "`CheckpointManager`)." 
] }, { "cell_type": "markdown", "metadata": { "id": "mwKuibLGbz8x" }, "source": [ "### Sequential Checkpoint" ] }, { "cell_type": "markdown", "metadata": { "id": "Lj_7-2d2b1lK" }, "source": [ "The term \"sequential checkpoint\" refers to a checkpoint that represents\n", "a particular step in a longer sequence. Typically, in Orbax, this is simply\n", "denoted by a directory named with an integer value (`0/` in the above example).\n", "However, options are available to\n", "[customize](https://orbax.readthedocs.io/en/latest/api_reference/checkpoint.path.step.html)\n", "the default format.\n", "\n", "The sequential checkpoint has a top-level `_CHECKPOINT_METADATA` file that\n", "stores basic information such as the creation timestamp, among other fields." ] }, { "cell_type": "markdown", "metadata": { "id": "S75QHRWLcPbR" }, "source": [ "### Checkpoint Items" ] }, { "cell_type": "markdown", "metadata": { "id": "uOJJCPSocRPk" }, "source": [ "Within a sequential checkpoint directory, we have subdirectories corresponding\n", "to \"items\". An \"item\" represents a logically distinct unit of a larger\n", "checkpoint, so these are naturally represented in separate subdirectories. In\n", "the above example, the items are `state` and `custom_data`.\n", "\n", "This representation makes composition easier if, for instance, you want to\n", "combine the dataset from one checkpoint with the state from another. It also\n", "prevents collisions if you use the same `CheckpointHandler` to save both state\n", "and embeddings.\n", "\n", "Below this level, the format is no longer universally standard, because each\n", "`CheckpointHandler` customizes its own file format." 
] }, { "cell_type": "markdown", "metadata": { "id": "UueAkpOIiaJ9" }, "source": [ "### PyTree Checkpoints" ] }, { "cell_type": "markdown", "metadata": { "id": "SXFXKlTgifJp" }, "source": [ "Because the `state` item was saved with `ocp.args.PyTreeSave` (the same would\n", "apply if saved with `ocp.args.StandardSave`), it takes the following form:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "bTD6Y2_NisRp" }, "outputs": [], "source": [ "print_directory(path / '0' / 'state')" ] }, { "cell_type": "markdown", "metadata": { "id": "3mr2dzvli5i9" }, "source": [ "The `_METADATA` file provides a complete description of the PyTree structure,\n", "including custom and empty nodes.\n", "\n", "The tree is represented as a flattened dictionary in which each key is a tuple\n", "whose successive elements denote successive levels of nesting. For\n", "example, for the dict `{'a': {'b': [1, 2]}}` the metadata file would contain two\n", "entries with keys `('a', 'b', '0')` and `('a', 'b', '1')`.\n", "\n", "Keys at each level of nesting also encode their type, i.e. whether they\n", "are dict keys or sequence keys.\n", "\n", "Finally, metadata about the value type is stored (e.g. `jax.Array`,\n", "`np.ndarray`, etc.) so that the tree can later be reconstructed without\n", "the object types being provided explicitly." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "nAZ2LDB5jZjN" }, "outputs": [], "source": [ "import json\n", "\n", "json.loads((path / '0' / 'state' / '_METADATA').read_text())" ] }, { "cell_type": "markdown", "metadata": { "id": "0A-A6Qx2m580" }, "source": [ "The `_sharding` file stores information about the shardings originally used when\n", "saving `jax.Array`s in the tree. It is not really human-readable, though; to get\n", "information about shardings, use the `metadata` APIs." 
] }, { "cell_type": "markdown", "metadata": { "id": "5WKt4_A-nL_U" }, "source": [ "Beyond these metadata files, which are directly managed by Orbax, we also have a\n", "`manifest.ocdbt` file managed by the TensorStore library. Actual array data is\n", "stored within the `d/` subdirectory. Since these files are opaque to human\n", "readers, we will not go into detail on their structure." ] }, { "cell_type": "markdown", "metadata": { "id": "F7S1iZMRn_oA" }, "source": [ "Finally, you'll notice the presence of the directory `ocdbt.process_0/`, which\n", "also has a `manifest.ocdbt` and its own `d/` subdirectory. One such folder\n", "exists for every process on which the checkpoint was saved. This exists because\n", "each process first writes its own data independently to its corresponding\n", "subdirectory.\n", "\n", "When all processes have finished, Orbax runs a finalization pass to cheaply\n", "merge the metadata from all per-process subdirectories into a global view (note\n", "that this still references data in the original subdirectories)." ] }, { "cell_type": "markdown", "metadata": { "id": "CL2i8L2hjdUF" }, "source": [ "### Working with TensorStore" ] }, { "cell_type": "markdown", "metadata": { "id": "C676mXpXjjRa" }, "source": [ "Sometimes, it is helpful to work directly with the [TensorStore](https://google.github.io/tensorstore/) API to debug individual parameters in a checkpoint." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "MGWC5PSwmLj_" }, "outputs": [], "source": [ "from etils import epath\n", "import jax\n", "import tensorstore as ts\n", "\n", "ts_context = ts.Context(\n", "    {\n", "        # Provide cache pool for B-tree nodes to avoid repeated reads.\n", "        # 100MB limit.\n", "        'cache_pool#ocdbt': {'total_bytes_limit': 100000000},\n", "    },\n", "    parent=jax.experimental.array_serialization.serialization.TS_CONTEXT,\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "axweQWhpjsSM" }, "source": [ "To read using TensorStore, we need to construct a TensorStore Spec. For this, we can use Orbax APIs. The spec points to a base path, as well as a particular parameter name (`a` in this case). It contains further options related to the checkpoint format." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "SWCCkHnkmNoI" }, "outputs": [], "source": [ "ParamInfo = ocp.type_handlers.ParamInfo\n", "state_dir = path / '0' / 'state'\n", "param_name = 'a'\n", "# The parameter's own location within the item directory.\n", "param_path = state_dir / param_name\n", "info = ParamInfo(\n", "    name=param_name,\n", "    path=param_path,\n", "    parent_dir=state_dir,\n", "    is_ocdbt_checkpoint=True,\n", "    use_zarr3=True,\n", ")\n", "tspec = ocp.type_handlers.get_json_tspec_read(info, use_ocdbt=True)\n", "tspec" ] }, { "cell_type": "markdown", "metadata": { "id": "G68gjMJFj80I" }, "source": [ "We can verify which keys are present in the checkpoint, which matches information we gathered earlier from the Orbax `metadata` API." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "FDFbWablo-QE" }, "outputs": [], "source": [ "ts.KvStore.open({\"driver\": \"ocdbt\", \"base\": \"file:///tmp/checkpoint/0/state/\"}).result().list().result()" ] }, { "cell_type": "markdown", "metadata": { "id": "wRvlthGUkNKz" }, "source": [ "Finally, we can directly restore the array using TensorStore." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "QO6hZh3Kmj0n" }, "outputs": [], "source": [ "tspec = {'driver': 'zarr', 'kvstore': {'driver': 'ocdbt', 'base': 'file:///tmp/checkpoint/0/state/', 'path': 'a'}}\n", "t = ts.open(ts.Spec(tspec), open=True, context=ts_context).result()\n", "result = t.read().result()\n", "result" ] } ], "metadata": { "colab": { "last_runtime": { "build_target": "//learning/grp/tools/ml_python:ml_notebook", "kind": "private" }, "private_outputs": true, "provenance": [ { "file_id": "1y-D8gfz2AVpxFr8UiCJWFdNkMU8Z0DKK", "timestamp": 1717694482327 } ], "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 0 }