{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "bgBmt4wXccHa" }, "source": [ "# Customizing Checkpointing Logic\n" ] }, { "cell_type": "markdown", "metadata": { "id": "IDfsplyYcUVR" }, "source": [ "This page is relevant if your model state contains custom leaves in a [PyTree](https://jax.readthedocs.io/en/latest/pytrees.html), or doesn't use PyTree at all.\n", "\n", "If your model uses PyTree but has custom leaves, read the **`TypeHandler`** section to see how to register the custom type with `PyTreeCheckpointHandler`.\n", "\n", "If your model doesn't use PyTree or if you want to implement different serialization/deserialization logic, skip to the **`CheckpointHandler`** section." ] }, { "cell_type": "markdown", "metadata": { "id": "v1I0P_aKcgkw" }, "source": [ "## Setup\n", "\n", "If you're running this guide in a notebook, make sure to run this cell first." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "DLL5q3xCpB91" }, "outputs": [], "source": [ "import asyncio\n", "from concurrent import futures\n", "from dataclasses import dataclass\n", "import functools\n", "import os\n", "import time\n", "from typing import Any, List, Optional, Sequence\n", "\n", "from etils import epath\n", "import numpy as np\n", "import orbax.checkpoint as ocp\n", "\n", "ParamInfo = ocp.type_handlers.ParamInfo\n", "Metadata = ocp.metadata.Metadata" ] }, { "cell_type": "markdown", "metadata": { "id": "1tRDU9NYcyLG" }, "source": [ "## TypeHandler\n", "\n", "`PyTreeCheckpointHandler` walks through the input PyTree and uses registered `TypeHandlers` to serialize/deserialize the leaves. If your custom model state is stored within the leaves of a PyTree, implement a `TypeHandler` and use it with `PyTreeCheckpointHandler`."
] }, { "cell_type": "markdown", "metadata": { "id": "HGFk3Jo77SFs" }, "source": [ "**Standard TypeHandlers**\n", "\n", "Orbax includes pre-defined `TypeHandlers` for saving certain types:\n", "\n", "* `ArrayHandler`: jax.Array\n", "* `NumpyHandler`: np.ndarray\n", "* `ScalarHandler`: int, float\n", "* `StringHandler`: str\n", "\n", "These default implementations all use TensorStore to serialize and deserialize data, except for `StringHandler`, which serializes to JSON.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "nQvutZiWc2JD" }, "source": [ "### Custom serialization / deserialization\n", "\n" ] }, { "cell_type": "markdown", "metadata": { "id": "6U0wQpZhc6BW" }, "source": [ "To implement a custom `TypeHandler`, we must define the async `serialize` and `deserialize` methods (the section \"Async vs Non-Async\" lists reasons why these methods should be asynchronous). The new `TypeHandler` is then registered so that `PyTreeCheckpointHandler` knows to use this handler when there is a `MyState` leaf in the PyTree.\n", "\n", "The inputs to the `TypeHandler` are batched to allow for performance optimizations in certain cases. `PyTreeCheckpointHandler` groups all leaves of the same type and dispatches each group as a single batch." ] }, { "cell_type": "markdown", "metadata": { "id": "WIKDcwCbkm2H" }, "source": [ "The example below defines a `TypeHandler` for a custom dataclass that stores multiple NumPy arrays." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "gdqC12kGq3lO" }, "outputs": [], "source": [ "@dataclass\n", "class MyState:\n", "  a: np.ndarray\n", "  b: np.ndarray\n", "\n", "\n", "# Make sure to only run this cell once; otherwise a new `MyState` dataclass will\n", "# be created, which could break Python issubclass/isinstance checks."
] }, { "cell_type": "markdown", "metadata": { "id": "yvPiuDRzkpeR" }, "source": [ "Here is a possible `TypeHandler` implementation for `MyState`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ftwzKJ9Thg_b" }, "outputs": [], "source": [ "class MyStateHandler(ocp.type_handlers.TypeHandler):\n", "  \"\"\"Serializes MyState to the numpy npz format.\"\"\"\n", "\n", "  def __init__(self):\n", "    self._executor = futures.ThreadPoolExecutor(max_workers=1)\n", "\n", "  def typestr(self) -\u003e str:\n", "    return 'MyState'\n", "\n", "  async def serialize(\n", "      self,\n", "      values: Sequence[MyState],\n", "      infos: Sequence[ParamInfo],\n", "      args: Optional[Sequence[ocp.SaveArgs]],\n", "  ) -\u003e List[futures.Future]:\n", "    del args  # Unused in this example.\n", "    # Named `save_futures` to avoid shadowing the `futures` module.\n", "    save_futures = []\n", "    for value, info in zip(values, infos):\n", "      # Make sure the per-key directory is present, as OCDBT doesn't create one.\n", "      info.path.mkdir(exist_ok=True)\n", "      save_futures.append(\n", "          self._executor.submit(\n", "              functools.partial(_write_state, value, info.path)\n", "          )\n", "      )\n", "    return save_futures\n", "\n", "  async def deserialize(\n", "      self,\n", "      infos: Sequence[ParamInfo],\n", "      args: Optional[Sequence[ocp.RestoreArgs]] = None,\n", "  ) -\u003e Sequence[MyState]:\n", "    del args  # Unused in this example.\n", "    # `_from_state` is a coroutine function, so each executor call returns a\n", "    # coroutine; `asyncio.gather` below awaits them all.\n", "    coros = []\n", "    for info in infos:\n", "      coros.append(\n", "          await asyncio.get_event_loop().run_in_executor(\n", "              self._executor, functools.partial(_from_state, info.path)\n", "          )\n", "      )\n", "    return await asyncio.gather(*coros)\n", "\n", "  async def metadata(self, infos: Sequence[ParamInfo]) -\u003e Sequence[Metadata]:\n", "    # This method is explained in a separate section.\n", "    return [Metadata(name=info.name, directory=info.path) for info in infos]\n", "\n", "\n", "def _write_state(state: MyState, path: epath.Path) -\u003e epath.Path:\n", "  path = path / 'my_state.npz'\n", "  np.savez(path, a=state.a, b=state.b)\n", "  return path\n", "\n", "\n", "async def 
_from_state(path: epath.Path) -\u003e MyState:\n", "  data = np.load(path / 'my_state.npz')\n", "  return MyState(a=data['a'], b=data['b'])\n", "\n", "\n", "ocp.type_handlers.register_type_handler(\n", "    MyState, MyStateHandler(), override=True\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "zdLIa1QshW3S" }, "outputs": [], "source": [ "assert ocp.type_handlers.has_type_handler(MyState)" ] }, { "cell_type": "markdown", "metadata": { "id": "vxkjUUprdGiy" }, "source": [ "Here is `MyStateHandler` in action:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ynqShvJYpwOq" }, "outputs": [], "source": [ "my_tree = {\n", "    'state': {'a': np.array([1, 2, 3]), 'b': np.array([4, 5, 6])},\n", "    'my_state': MyState(a=np.array([10, 20, 30]), b=np.array([40, 50, 60])),\n", "}\n", "\n", "\n", "checkpointer = ocp.Checkpointer(\n", "    ocp.PyTreeCheckpointHandler()\n", ")\n", "path = epath.Path('/tmp/my_checkpoints/')\n", "\n", "# Clear older checkpoints from directory.\n", "# Checkpointer.save will fail if path already exists, unless `force=True`.\n", "if path.exists():\n", "  path.rmtree()\n", "path.mkdir()\n", "\n", "checkpointer.save(path / 'my_tree', my_tree)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "BrMXn148pnPw" }, "outputs": [], "source": [ "!echo \"Files in path:\" $(ls /tmp/my_checkpoints)\n", "!echo \"Files in 'my_tree':\" $(ls /tmp/my_checkpoints/my_tree)\n", "!echo \"Files in 'my_tree/my_state':\" $(ls /tmp/my_checkpoints/my_tree/my_state)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "uojtpS2GtdAv" }, "outputs": [], "source": [ "checkpointer.restore(path / 'my_tree')" ] }, { "cell_type": "markdown", "metadata": { "id": "VUOiT3hej_nA" }, "source": [ "### Metadata\n", "\n", "The `metadata()` method is used for inspecting existing checkpoints and is generally implemented to be less costly than a full restore. 
Some example use cases are determining whether the restored values can fit in the available memory, getting the checkpointed PyTree structure to extract specific subtrees, or validating whether the shapes and dtypes of the values match your model data.\n", "\n", "In the previous example, `MyStateHandler` returned the default `Metadata()` object since the `TypeHandler` interface requires it. However, we recommend completing this implementation, especially if the custom type targets general users." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "DybgM_VHt2wD" }, "outputs": [], "source": [ "# 'my_state' returns a default Metadata object.\n", "checkpointer.metadata(path / 'my_tree')" ] }, { "cell_type": "markdown", "metadata": { "id": "bWcb3dFEu_27" }, "source": [ "Example implementation of `MyStateHandler.metadata`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "pgybXKy2u_S6" }, "outputs": [], "source": [ "# Define a metadata class.\n", "class MyStateMetadata(Metadata):\n", "\n", "  def __init__(\n", "      self,\n", "      a_shape: Sequence[int],\n", "      b_shape: Sequence[int],\n", "      **kwargs,\n", "  ):\n", "    super().__init__(**kwargs)\n", "    self.a_shape = a_shape\n", "    self.b_shape = b_shape\n", "\n", "\n", "class MyStateHandlerWithMetdata(MyStateHandler):\n", "\n", "  async def metadata(\n", "      self, infos: Sequence[ParamInfo]\n", "  ) -\u003e Sequence[MyStateMetadata]:\n", "    metadata = []\n", "    for info in infos:\n", "      metadata.append(\n", "          await asyncio.get_event_loop().run_in_executor(\n", "              self._executor, functools.partial(_read_metadata, info)\n", "          )\n", "      )\n", "    return await asyncio.gather(*metadata)\n", "\n", "\n", "async def _read_metadata(info: ParamInfo) -\u003e MyStateMetadata:\n", "  # This function reads the entire state, but can be more optimally defined\n", "  # by reading the header from the npz file. 
Another option is collectively\n", "  # gathering all of the metadata info during serialization, and writing it to\n", "  # a file. Since metadata is generally pretty small, it's better to write\n", "  # to a single file rather than one for each value.\n", "  result = await _from_state(info.path)\n", "  return MyStateMetadata(\n", "      a_shape=result.a.shape,\n", "      b_shape=result.b.shape,\n", "      name='my_state',\n", "      directory=info.path,\n", "  )\n", "\n", "\n", "ocp.type_handlers.register_type_handler(\n", "    MyState, MyStateHandlerWithMetdata(), override=True\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "0NXBhj7e5xP0" }, "source": [ "Now check the metadata; the PyTree should contain `MyStateMetadata`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "nrwP--Sdvh6o" }, "outputs": [], "source": [ "checkpointer = ocp.PyTreeCheckpointer()\n", "checkpointer.metadata(path / 'my_tree')" ] }, { "cell_type": "markdown", "metadata": { "id": "UN1nwFaI_KE5" }, "source": [ "In this example, we didn't need to re-save the checkpoint using the newly registered `MyStateHandlerWithMetdata` TypeHandler, because the class doesn't write new files into the checkpoint." ] }, { "cell_type": "markdown", "metadata": { "id": "035JCngndVGt" }, "source": [ "## CheckpointHandler\n", "\n", "If your state is not stored within a PyTree, or if you'd like to customize more aspects of checkpointing, implement `CheckpointHandler`. `CheckpointHandlers` operate on the entire object, so you have a lot of flexibility in how to save and restore the object.\n", "\n", "As of `orbax-checkpoint-0.5.0`, the `CheckpointHandler` API has changed. This page shows a side-by-side comparison of the old and new APIs.\n", "\n", "**The legacy APIs are deprecated. 
Please ensure you are using the new style.**\n", "\n", "**Example**\n", "\n", "Serializing the same dataclass used in the `TypeHandler` example:\n", "\n", "```python\n", "@dataclass\n", "class MyState:\n", "  a: np.ndarray\n", "  b: np.ndarray\n", "```\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "lTNIJHFfLmsZ" }, "outputs": [], "source": [ "state = MyState(a=np.array([1.0, 1.5]), b=np.array([3, 4, 5]))" ] }, { "cell_type": "markdown", "metadata": { "id": "IXcHLt4jK7t5" }, "source": [ "### Before" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "CFxuofMrRhWK" }, "outputs": [], "source": [ "import glob\n", "import json\n", "\n", "\n", "class LegacyMyStateCheckpointHandler(ocp.CheckpointHandler):\n", "\n", "  def save(\n", "      self,\n", "      directory: epath.Path,\n", "      item: MyState,\n", "      # You can define any argument here:\n", "      use_npz=True,\n", "      **kwargs,\n", "  ):\n", "    if use_npz:\n", "      np.savez(directory / 'my_state.npz', a=item.a, b=item.b)\n", "    else:\n", "      with open(os.path.join(directory, 'my_state.json'), 'w') as f:\n", "        f.write(json.dumps(dict(a=item.a.tolist(), b=item.b.tolist())))\n", "\n", "  def restore(\n", "      self,\n", "      directory: epath.Path,\n", "      item: Optional[Any] = None,\n", "      # You can define any argument here as well.\n", "      restore_as_dict=False,\n", "      **kwargs,\n", "  ) -\u003e Any:\n", "    state_file = glob.glob(os.fspath(directory / '*.*'))[0]\n", "    # `glob` returns full paths, so compare only the file name.\n", "    if os.path.basename(state_file) == 'my_state.npz':\n", "      data = np.load(directory / 'my_state.npz')\n", "    else:\n", "      with open(state_file, 'r') as f:\n", "        data = json.load(f)\n", "      data['a'] = np.array(data['a'])\n", "      data['b'] = np.array(data['b'])\n", "    if restore_as_dict:\n", "      return dict(a=data['a'], b=data['b'])\n", "    return MyState(a=data['a'], b=data['b'])\n", "\n", "  def metadata(self, directory: epath.Path) -\u003e Optional[Any]:\n", "    \"\"\"Returns metadata about the saved item.\"\"\"\n", "    # In this example, the State is restored entirely, but this can 
be\n", "    # optimized. For example, by writing a `metadata` file in `self.save()`,\n", "    # and reading the file in this method.\n", "    result = self.restore(directory)\n", "    return MyStateMetadata(\n", "        a_shape=result.a.shape,\n", "        b_shape=result.b.shape,\n", "        name='my_state',\n", "        directory=directory / 'my_state',\n", "    )" ] }, { "cell_type": "markdown", "metadata": { "id": "sqKWffqRMOXu" }, "source": [ "### After" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "BJ9-9pdxMmaW" }, "outputs": [], "source": [ "import glob\n", "import json\n", "\n", "\n", "class MyStateCheckpointHandler(ocp.CheckpointHandler):\n", "\n", "  def save(\n", "      self,\n", "      directory: epath.Path,\n", "      args: 'MyStateSave',\n", "  ):\n", "    if args.use_npz:\n", "      np.savez(directory / 'my_state.npz', a=args.item.a, b=args.item.b)\n", "    else:\n", "      with open(os.path.join(directory, 'my_state.json'), 'w') as f:\n", "        f.write(\n", "            json.dumps(dict(a=args.item.a.tolist(), b=args.item.b.tolist()))\n", "        )\n", "\n", "  def restore(\n", "      self,\n", "      directory: epath.Path,\n", "      args: 'MyStateRestore',\n", "  ) -\u003e Any:\n", "    state_file = glob.glob(os.fspath(directory / '*.*'))[0]\n", "    # `glob` returns full paths, so compare only the file name.\n", "    if os.path.basename(state_file) == 'my_state.npz':\n", "      data = np.load(directory / 'my_state.npz')\n", "    else:\n", "      with open(state_file, 'r') as f:\n", "        data = json.load(f)\n", "      data['a'] = np.array(data['a'])\n", "      data['b'] = np.array(data['b'])\n", "    if args.restore_as_dict:\n", "      return dict(a=data['a'], b=data['b'])\n", "    return MyState(a=data['a'], b=data['b'])\n", "\n", "  def metadata(self, directory: epath.Path) -\u003e Optional[Any]:\n", "    \"\"\"Returns metadata about the saved item.\"\"\"\n", "    # In this example, the State is restored entirely, but this can be\n", "    # optimized. 
For example, by writing a `metadata` file in `self.save()`,\n", "    # and reading the file in this method.\n", "    result = self.restore(directory, args=MyStateRestore())\n", "    return MyStateMetadata(\n", "        a_shape=result.a.shape,\n", "        b_shape=result.b.shape,\n", "        name='my_state',\n", "        directory=directory / 'my_state',\n", "    )\n", "\n", "\n", "@ocp.args.register_with_handler(MyStateCheckpointHandler, for_save=True)\n", "@dataclass\n", "class MyStateSave(ocp.args.CheckpointArgs):\n", "  item: MyState\n", "  use_npz: bool = True\n", "\n", "\n", "@ocp.args.register_with_handler(MyStateCheckpointHandler, for_restore=True)\n", "@dataclass\n", "class MyStateRestore(ocp.args.CheckpointArgs):\n", "  restore_as_dict: bool = False" ] }, { "cell_type": "markdown", "metadata": { "id": "XIonT22wdf_5" }, "source": [ "Either handler can be passed to a new `Checkpointer`, which can then be used to save or restore a checkpoint." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "_wglMdEnDh9b" }, "outputs": [], "source": [ "legacy_path2 = epath.Path('/tmp/legacy-checkpoint-handler-example/')\n", "legacy_checkpointer = ocp.Checkpointer(LegacyMyStateCheckpointHandler())\n", "\n", "if legacy_path2.exists():\n", "  legacy_path2.rmtree()\n", "legacy_path2.mkdir()\n", "\n", "legacy_checkpointer.save(legacy_path2 / 'state', state, use_npz=False)\n", "!echo \"Files in legacy checkpoint path:\" $(ls /tmp/legacy-checkpoint-handler-example/)\n", "!echo \"Files in legacy 'state' directory:\" $(ls /tmp/legacy-checkpoint-handler-example/state)\n", "\n", "print('restored state: ', legacy_checkpointer.restore(legacy_path2 / 'state'))\n", "print('restored state as dict: ', legacy_checkpointer.restore(legacy_path2 / 'state', restore_as_dict=True))\n", "print('metadata:', legacy_checkpointer.metadata(legacy_path2 / 'state'))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "tjyECSh6XYiL" }, "outputs": [], "source": [ "path2 = 
epath.Path('/tmp/checkpoint-handler-example/')\n", "checkpointer = ocp.Checkpointer(MyStateCheckpointHandler())\n", "\n", "if path2.exists():\n", "  path2.rmtree()\n", "path2.mkdir()\n", "\n", "checkpointer.save(path2 / 'state', args=MyStateSave(item=state, use_npz=False))\n", "!echo \"Files in checkpoint path:\" $(ls /tmp/checkpoint-handler-example/)\n", "!echo \"Files in 'state' directory:\" $(ls /tmp/checkpoint-handler-example/state)\n", "\n", "print('restored state: ', checkpointer.restore(path2 / 'state', args=MyStateRestore()))\n", "print('restored state as dict: ', checkpointer.restore(path2 / 'state', args=MyStateRestore(restore_as_dict=True)))\n", "print('metadata:', checkpointer.metadata(path2 / 'state'))" ] }, { "cell_type": "markdown", "metadata": { "id": "5kI9hF4xdjUI" }, "source": [ "## Async vs Non-Async\n", "Asynchronous checkpointing allows training to proceed during I/O, which keeps expensive computational resources from stalling during CPU writes. When possible, we highly recommend implementing async handlers.\n", "\n", "Async saving can be implemented by copying data to the corresponding worker CPU (if necessary), then parallelizing the writing tasks (e.g. by using the `await` keyword).\n", "\n", "`TypeHandler` deserialization should be defined using async to allow multiple objects to be deserialized at a time." ] }, { "cell_type": "markdown", "metadata": { "id": "gGLtI3Gk2zHS" }, "source": [ "### AsyncCheckpointHandler\n", "The `AsyncCheckpointHandler` interface adds a new `async_save` abstract method, and should be used with `AsyncCheckpointer` to write checkpoints asynchronously.\n", "\n", "Note that in the new style, `AsyncCheckpointHandler`'s `save()` and `async_save()` methods work on `args` instead of the legacy `item` and related arguments. 
Also, the `args` type needs to be registered against the concrete `AsyncCheckpointHandler` class.\n", "\n", "\n", "**Example**" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "2_3KeiUVYJYj" }, "outputs": [], "source": [ "class MyStateAsyncCheckpointHandler(ocp.AsyncCheckpointHandler, MyStateCheckpointHandler):\n", "  def __init__(self):\n", "    self._executor = futures.ThreadPoolExecutor(max_workers=1)\n", "\n", "  def save(self, directory: epath.Path, args: MyStateSave):\n", "    time.sleep(0.5)  # Artificially inflate the time spent in this method.\n", "    super().save(directory, args)\n", "\n", "  async def async_save(self, directory: epath.Path, args: MyStateSave):\n", "    return [self._executor.submit(functools.partial(\n", "        self.save, directory, args))]\n", "\n", "  def close(self):\n", "    self._executor.shutdown()\n", "\n", "# Register MyStateAsyncCheckpointHandler for MyStateSave and MyStateRestore.\n", "# NOTE: This registration will overwrite the previous one with MyStateCheckpointHandler.\n", "# It is just for illustrating this example and should be avoided in real world systems.\n", "ocp.args.register_with_handler(MyStateAsyncCheckpointHandler, for_save=True)(MyStateSave)\n", "ocp.args.register_with_handler(MyStateAsyncCheckpointHandler, for_restore=True)(MyStateRestore)\n", "\n", "path3 = epath.Path('/tmp/checkpoint-handler-async/')\n", "if path3.exists():\n", "  path3.rmtree()\n", "path3.mkdir()\n", "\n", "async_checkpointer = ocp.AsyncCheckpointer(MyStateAsyncCheckpointHandler())\n", "async_checkpointer.save(path3 / 'async-state', args=MyStateSave(item=state))\n", "!echo \"directory contents: \"; ls /tmp/checkpoint-handler-async/" ] }, { "cell_type": "markdown", "metadata": { "id": "LjdunHROaSmJ" }, "source": [ "After the write is complete, the temporary directory is renamed to just `async-state`."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "vmEnM95gaXew" }, "outputs": [], "source": [ "async_checkpointer.wait_until_finished()\n", "async_checkpointer.close()\n", "\n", "!ls /tmp/checkpoint-handler-async/\n", "!ls /tmp/checkpoint-handler-async/async-state" ] } ], "metadata": { "colab": { "private_outputs": true, "provenance": [], "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 0 }