{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "UFuWr7-bLyeS" }, "source": [ "# API Overview" ] }, { "cell_type": "markdown", "metadata": { "id": "aw_0020nL1hQ" }, "source": [ "## CheckpointHandler Layer" ] }, { "cell_type": "markdown", "metadata": { "id": "xXFBCzozL4Ml" }, "source": [ "The lowest-level API that users typically interact with in Orbax is the [`CheckpointHandler`](https://orbax.readthedocs.io/en/latest/api_reference/checkpoint.checkpoint_handlers.html). Every `CheckpointHandler` is also paired with one or two [`CheckpointArgs`](https://orbax.readthedocs.io/en/latest/api_reference/checkpoint.args.html) objects, which encapsulate all necessary and optional arguments that a user can provide when saving or restoring. At a high level, `CheckpointHandler` exists to provide the logic required to save or restore a particular object in a checkpoint." ] }, { "cell_type": "markdown", "metadata": { "id": "rI1LM9NvMTbC" }, "source": [ "`CheckpointHandler` allows for synchronous saving. Subclasses of [`AsyncCheckpointHandler`](https://orbax.readthedocs.io/en/latest/api_reference/checkpoint.checkpoint_handlers.html#asynccheckpointhandler) allow for asynchronous saving. (Restoration is always synchronous.)" ] }, { "cell_type": "markdown", "metadata": { "id": "oSFxgHNxMmmq" }, "source": [ "Crucially, a `CheckpointHandler` instance **should not be used in isolation**, but should always be used **in conjunction with a `Checkpointer`** (see below). Otherwise, save operations will not be atomic and async operations cannot be waited upon. This means that in most cases, you will be working with `Checkpointer` APIs rather than `CheckpointHandler` APIs." ] }, { "cell_type": "markdown", "metadata": { "id": "a68ZfF-2NytG" }, "source": [ "However, it is still essential to understand `CheckpointHandler` because you need to know how you want your object to be saved and restored, and what arguments are necessary to make that happen." 
] }, { "cell_type": "markdown", "metadata": { "id": "6y_cukPSONo7" }, "source": [ "Let's consider the example of [`StandardCheckpointHandler`](https://orbax.readthedocs.io/en/latest/api_reference/checkpoint.checkpoint_handlers.html#standardcheckpointhandler). This class is paired with [`StandardSave`](https://orbax.readthedocs.io/en/latest/api_reference/checkpoint.args.html#standardsave) and [`StandardRestore`](https://orbax.readthedocs.io/en/latest/api_reference/checkpoint.args.html#standardsave)." ] }, { "cell_type": "markdown", "metadata": { "id": "gIZP-4SOQUfQ" }, "source": [ "`StandardSave` allows specifying the `item` argument, which is the PyTree to be saved using TensorStore. It also includes `save_args`, which is an optional `PyTree` with a structure matching `item`. Each leaf is an `ocp.type_handlers.SaveArgs` object, which can be used to customize things like the `dtype` of the saved array." ] }, { "cell_type": "markdown", "metadata": { "id": "SjN6_Q52RSy5" }, "source": [ "`StandardRestore` only has one possible argument, the `item`, which is a PyTree of concrete or abstract arrays matching the structure of the checkpoint. This is optional, and the checkpoint will be restored exactly as saved if no argument is provided." ] }, { "cell_type": "markdown", "metadata": { "id": "zoXth95fRhCj" }, "source": [ "In general, other `CheckpointHandler`s may have other arguments, and the contract can be discerned by looking at the corresponding `CheckpointArgs`. Additionally, `CheckpointHandler`s can be [customized](https://orbax.readthedocs.io/en/latest/custom_handlers.html) for specific needs." 
] }, { "cell_type": "markdown", "metadata": { "id": "Oo7R2xhseVme" }, "source": [ "### CompositeCheckpointHandler" ] }, { "cell_type": "markdown", "metadata": { "id": "VUPjUCozeXKj" }, "source": [ "A special case of `CheckpointHandler` is the [`CompositeCheckpointHandler`](https://orbax.readthedocs.io/en/latest/api_reference/checkpoint.checkpoint_handlers.html#compositecheckpointhandler). While `CheckpointHandler`s are typically expected to deal with a single object, `CompositeCheckpointHandler` is explicitly designed for delegating save/restore logic for multiple distinct objects to separate `CheckpointHandler`s." ] }, { "cell_type": "markdown", "metadata": { "id": "kCQvXmjDixaQ" }, "source": [ "At minimum, `CompositeCheckpointHandler` must be initialized with a series of item names, which are used to differentiate distinct items. In many cases, you do not need to manually specify the delegated `CheckpointHandler` instance for a particular item up front. Here's an example:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "kyZ6k_Hfiwqj" }, "outputs": [], "source": [ "import orbax.checkpoint as ocp\n", "from etils import epath\n", "\n", "path = ocp.test_utils.erase_and_create_empty('/tmp/my-checkpoints/')" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "rIbkkjNijOd1" }, "outputs": [], "source": [ "handler = ocp.CompositeCheckpointHandler('state', 'metadata', 'dataset')" ] }, { "cell_type": "markdown", "metadata": { "id": "QeItizMGjSxl" }, "source": [ "Now, it will be possible to use this handler (in conjunction with a `Checkpointer`!) to save and restore 3 distinct objects, named 'state', 'metadata', and 'dataset'." ] }, { "cell_type": "markdown", "metadata": { "id": "52TLxevMjf2v" }, "source": [ "When we call save or restore, it is necessary to specify a `CheckpointArgs` subclass for each item. This is used to infer the desired `CheckpointHandler`. 
For example, if we specify `StandardSave`, the object will get saved using `StandardCheckpointHandler`. Per-item `CheckpointArgs` must be wrapped in the `CheckpointArgs` for `CompositeCheckpointHandler`, which is `ocp.args.Composite`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "OmlDN7CvjwWC" }, "outputs": [], "source": [ "state = {'layer0': {'bias': 0, 'weight': 1}}\n", "metadata = {'version': 1.0}\n", "dataset = {'my_data': 2}\n", "\n", "checkpointer = ocp.Checkpointer(handler)\n", "checkpointer.save(\n", " path / 'composite_checkpoint',\n", " ocp.args.Composite(\n", " state=ocp.args.StandardSave(state),\n", " metadata=ocp.args.JsonSave(metadata),\n", " dataset=ocp.args.JsonSave(dataset),\n", " )\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "V8dqRBf7oJtk" }, "source": [ "When restoring, we can retrieve a subset of the items:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "zjP4twpaoUD2" }, "outputs": [], "source": [ "# Restore all items:\n", "checkpointer.restore(path / 'composite_checkpoint')" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ipu6YWEGobEq" }, "outputs": [], "source": [ "# Restore some items, but not all:\n", "checkpointer.restore(\n", " path / 'composite_checkpoint',\n", " ocp.args.Composite(\n", " state=ocp.args.StandardRestore(),\n", " )\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Y9WLzMROohYq" }, "outputs": [], "source": [ "# Restore some items, and specify optional arguments for restoration:\n", "checkpointer.restore(\n", " path / 'composite_checkpoint',\n", " ocp.args.Composite(\n", " state=ocp.args.StandardRestore(state),\n", " )\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "_z7r53zuJYUO" }, "source": [ "As noted above, every `CheckpointHandler` has one or two `CheckpointArgs` subclasses which represent save and restore arguments. `CompositeCheckpointHandler` is no exception. 
Both save and restore arguments are represented by `args.Composite`, which is a wrapper for other `CheckpointArgs` passed to the sub-handlers.\n", "\n", "Similarly, the return value of `restore` is also `args.Composite`.\n", "\n", "The `args.Composite` class is essentially a key-value store similar to a dictionary." ] }, { "cell_type": "markdown", "metadata": { "id": "ezKPjN1IRy6v" }, "source": [ "## Checkpointer Layer" ] }, { "cell_type": "markdown", "metadata": { "id": "pAzQ5wWPUo96" }, "source": [ "Conceptually, the [`Checkpointer`](https://orbax.readthedocs.io/en/latest/api_reference/checkpoint.checkpointers.html) exists to work with a single checkpoint that exists at a single path. It is no-frills (relative to `CheckpointManager`) but guarantees atomicity and allows for asynchronous saving via `AsyncCheckpointer`." ] }, { "cell_type": "markdown", "metadata": { "id": "ss-thT6SMbJN" }, "source": [ "Async checkpointing provided via `AsyncCheckpointer` can often help to realize significant resource savings and training speedups because writes to disk happen in a background thread. See [here](https://orbax.readthedocs.io/en/latest/async_checkpointing.html) for more details." ] }, { "cell_type": "markdown", "metadata": { "id": "OujTKCY-U_v9" }, "source": [ "As mentioned above, a `Checkpointer` is always combined with a `CheckpointHandler`. You can think of the `CheckpointHandler` as providing a configuration that tells the `Checkpointer` what serialization logic to use to deal with a particular object." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "4n0Y-y5fMG4L" }, "outputs": [], "source": [ "ckptr = ocp.Checkpointer(ocp.JsonCheckpointHandler())" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jXoImFXCMJ6a" }, "outputs": [], "source": [ "ckptr = ocp.AsyncCheckpointer(ocp.StandardCheckpointHandler())" ] }, { "cell_type": "markdown", "metadata": { "id": "t9POe_jwMNI7" }, "source": [ "Orbax provides some shorthand checkpointers, such as `StandardCheckpointer`, which is just `Checkpointer(StandardCheckpointHandler())`." ] }, { "cell_type": "markdown", "metadata": { "id": "s8Sn6VJ3MtsM" }, "source": [ "While most `Checkpointer`/`CheckpointHandler` pairs deal with a single object that is saved and restored, pairing a `Checkpointer` with `CompositeCheckpointHandler` allows dealing with multiple distinct objects at once (see above)." ] }, { "cell_type": "markdown", "metadata": { "id": "ackGl1L5M9w5" }, "source": [ "## CheckpointManager Layer" ] }, { "cell_type": "markdown", "metadata": { "id": "OcusSC4jNBxz" }, "source": [ "The highest-level API layer provided by Orbax is the [`CheckpointManager`](https://orbax.readthedocs.io/en/latest/api_reference/checkpoint.checkpoint_manager.html). This is the API of choice for users dealing with a series of checkpoints denoted as steps in the context of a training run." ] }, { "cell_type": "markdown", "metadata": { "id": "Zhq4_JYANyE_" }, "source": [ "`CheckpointManagerOptions` allows customizing the behavior of the `CheckpointManager` along various dimensions. A partial list of important customization options is given below. See the API reference for a complete list.\n", "\n", "* `save_interval_steps`: An interval at which to save checkpoints.\n", "* `max_to_keep`: Starts to delete checkpoints when more than this number are present. 
Depending on other settings, more checkpoints than this number may be present at any given time.\n", "* `step_format_fixed_length`: Formats the step number with a fixed length of `n` digits, padding with leading zeros. This can make visually examining the checkpoints in sorted order easier.\n", "* `cleanup_tmp_directories`: Automatically cleans up existing temporary/incomplete directories when the `CheckpointManager` is created.\n", "* `read_only`: If True, then checkpoint saves and deletes are skipped. Restore works as usual.\n", "* `enable_async_checkpointing`: True by default. Be wary of turning this off, as save performance may be significantly impacted." ] }, { "cell_type": "markdown", "metadata": { "id": "Lmbp14qjjl9-" }, "source": [ "If dealing with a single checkpointable object, like a train state, `CheckpointManager` can be created as follows:\n", "\n" ] }, { "cell_type": "markdown", "metadata": { "id": "e839OeUkAk2w" }, "source": [ "Note that `CheckpointManager` always saves asynchronously, unless you set `enable_async_checkpointing=False` in `CheckpointManagerOptions`. Make sure to use `wait_until_finished()` if you need to block until a save is complete." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "5XQiz0MXj_Nc" }, "outputs": [], "source": [ "import jax\n", "\n", "directory = ocp.test_utils.erase_and_create_empty('/tmp/checkpoint-manager-single/')" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "MXhFQTo9jrHc" }, "outputs": [], "source": [ "options = ocp.CheckpointManagerOptions(\n", " save_interval_steps=2,\n", " max_to_keep=2,\n", " # other options\n", ")\n", "mngr = ocp.CheckpointManager(\n", " directory,\n", " options=options,\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "BS4NMHwukWpx" }, "outputs": [], "source": [ "num_steps = 5\n", "state = {'layer0': {'bias': 0, 'weight': 1}}\n", "\n", "def train_step(state):\n", " return jax.tree_util.tree_map(lambda x: x + 1, state)\n", "\n", "for step in range(num_steps):\n", " state = train_step(state)\n", " mngr.save(step, args=ocp.args.StandardSave(state))\n", "mngr.wait_until_finished()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "xiuLapZ5kyAo" }, "outputs": [], "source": [ "mngr.latest_step()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Ofu5ML1Ak1du" }, "outputs": [], "source": [ "mngr.all_steps()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "MelC1FBXk2sH" }, "outputs": [], "source": [ "mngr.restore(mngr.latest_step())" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "curPlwXCk8IV" }, "outputs": [], "source": [ "# Restore with additional arguments, like dtype or sharding.\n", "target_state = {'layer0': {'bias': 0.0, 'weight': 0.0}}\n", "mngr.restore(mngr.latest_step(), args=ocp.args.StandardRestore(target_state))" ] }, { "cell_type": "markdown", "metadata": { "id": "htixsGLrmAD-" }, "source": [ "If we're dealing with multiple items, we need to provide `item_names` when configuring the `CheckpointManager`. 
Internally, `CheckpointManager` is using `CompositeCheckpointHandler`, so the information above also applies here." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jOdOthW2mqgC" }, "outputs": [], "source": [ "directory = ocp.test_utils.erase_and_create_empty('/tmp/checkpoint-manager-multiple/')" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "x-6_0oYdmuSe" }, "outputs": [], "source": [ "options = ocp.CheckpointManagerOptions(\n", " save_interval_steps=2,\n", " max_to_keep=2,\n", " # other options\n", ")\n", "mngr = ocp.CheckpointManager(\n", " directory,\n", " options=options,\n", " item_names=('state', 'extra_metadata'),\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Ab8CgvfVm8_c" }, "outputs": [], "source": [ "num_steps = 5\n", "state = {'layer0': {'bias': 0, 'weight': 1}}\n", "extra_metadata = {'version': 1.0, 'step': 0}\n", "\n", "def train_step(step, _state, _extra_metadata):\n", " return jax.tree_util.tree_map(lambda x: x + 1, _state), {**_extra_metadata, **{'step': step}}\n", "\n", "for step in range(num_steps):\n", " state, extra_metadata = train_step(step, state, extra_metadata)\n", " mngr.save(\n", " step,\n", " args=ocp.args.Composite(\n", " state=ocp.args.StandardSave(state),\n", " extra_metadata=ocp.args.JsonSave(extra_metadata),\n", " )\n", " )\n", "mngr.wait_until_finished()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Kh00Mr-fnePY" }, "outputs": [], "source": [ "# Restore exactly as saved\n", "result = mngr.restore(mngr.latest_step())" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mGAeYMEcnzJg" }, "outputs": [], "source": [ "result" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "UhOid4Pcn0Xm" }, "outputs": [], "source": [ "result.state" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "rWcK2-n3n1Dp" }, "outputs": [], "source": [ "result.extra_metadata" ] }, 
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "SueZMwQxnvS3" }, "outputs": [], "source": [ "# Skip `state` when restoring.\n", "# Note that it is possible to provide `extra_metadata=None` because we already\n", "# saved using `JsonSave`. This is internally cached, so we know it uses JSON\n", "# logic to save and restore. If you had called `restore` without first calling\n", "# `save`, however, it would have been necessary to provide\n", "# `ocp.args.JsonRestore`.\n", "mngr.restore(mngr.latest_step(), args=ocp.args.Composite(extra_metadata=None))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mRxTflYUnh27" }, "outputs": [], "source": [ "# Restore with additional arguments, like dtype or sharding.\n", "target_state = {'layer0': {'bias': 0.0, 'weight': 0.0}}\n", "mngr.restore(mngr.latest_step(), args=ocp.args.Composite(\n", " state=ocp.args.StandardRestore(target_state), extra_metadata=None)\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "6Yori7xSdYtw" }, "source": [ "In some scenarios, the mapping between items and their respective `CheckpointHandler`s needs to be provided when the `CheckpointManager` instance is created.\n", "\n", "The `item_handlers` constructor argument of `CheckpointManager` covers those scenarios. Please see [Using the Refactored CheckpointManager API](https://orbax.readthedocs.io/en/latest/api_refactor.html) for the details.\n", "\n" ] } ], "metadata": { "colab": { "last_runtime": { "build_target": "", "kind": "local" }, "private_outputs": true, "provenance": [ { "file_id": "1a4YsNUsbx1IZj-mRvbexRg6sk3dmdHZx", "timestamp": 1703027264216 } ] }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 0 }