{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "tKGlDfvNJM8R" }, "source": [ "# Optimized Checkpointing with Tensorstore" ] }, { "cell_type": "markdown", "metadata": { "id": "AIgEkzUjJRUt" }, "source": [ "Orbax relies on [Tensorstore](https://google.github.io/tensorstore/) to store\n", "individual arrays in a checkpoint. Tensorstore provides an efficient, scalable library for reading and writing arrays.\n", "\n", "Until recently, however, our use of Tensorstore came with a few drawbacks. Chief among them was the fact that every parameter in a training state would be saved in a separate directory. This approach can be quite performant, even for models with hundreds of billions of parameters, *provided that model layers are stacked*. Otherwise, hundreds or thousands of directories may be created in the checkpoint.\n", "\n", "Having this many directories can lead to very slow restore times. That is undesirable in itself, and is particularly painful for jobs that are preempted frequently and need to restart.\n", "\n", "This is slightly less of a concern at save time, since writes to disk can happen asynchronously, but the synchronous portion of the save can still be slow because many directories must be created.\n", "\n", "Additionally, if individual parameters are small, storage may be wasted on filesystems with minimum file sizes." ] }, { "cell_type": "markdown", "metadata": { "id": "0zou_vidLNMd" }, "source": [ "## Towards an Improved Checkpoint Format\n", "\n", "The new, optimized checkpoint format provided by Orbax is backed by Tensorstore's [OCDBT](https://google.github.io/tensorstore/kvstore/ocdbt/index.html) driver.\n", "\n", "For practical purposes, this means that we no longer store one parameter per directory, but instead aggregate many parameters into a smaller set of large files.\n", "\n", "Empirically, we have observed substantial speed-ups in both save and restore when using the new format." 
] }, { "cell_type": "markdown", "metadata": { "id": "uA_OTs-GEmty" }, "source": [ "### Save Performance (sec)" ] }, { "cell_type": "markdown", "metadata": { "id": "MUs2AbpcaG8U" }, "source": [ "\u003cimg src=https://orbax.readthedocs.io/en/latest/_static/checkpoint_benchmarks_save_ocdbt.png\u003e" ] }, { "cell_type": "markdown", "metadata": { "id": "FhyHhgwqEgmR" }, "source": [ "### Restore Performance (sec)" ] }, { "cell_type": "markdown", "metadata": { "id": "5Yrd8RQ3cZer" }, "source": [ "\u003cimg src=https://orbax.readthedocs.io/en/latest/_static/checkpoint_benchmarks_restore_ocdbt.png\u003e" ] }, { "cell_type": "markdown", "metadata": { "id": "kfO6-6ZENhEG" }, "source": [ "## Checkpoint Format\n", "\n", "Concretely, what does the new checkpoint format look like in comparison to the old?" ] }, { "cell_type": "markdown", "metadata": { "id": "j7veWqRzQ7Jb" }, "source": [ "### Old Format" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "QOHzh4hAPUJQ" }, "outputs": [], "source": [ "f = \"\"\"\n", "path/to/my/checkpoint/dir/\n", " 0/\n", " state/\n", " layer0.param0/\n", " .zarray\n", " 0.0\n", " 0.1\n", " 1.0\n", " 1.1\n", " layer1.param0/\n", " .zarray\n", " 0.0\n", " ...\n", " \u003canother_item\u003e/\n", " ...\n", " 1/\n", " ...\n", " 2/\n", " ...\n", "\n", "Note: in this case, `0.0`, `0.1`, etc. 
provide an indication of how the array\n", "was sharded when originally saved.\n", "\"\"\"" ] }, { "cell_type": "markdown", "metadata": { "id": "-0J3DhoFQ-x1" }, "source": [ "### New Format" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "K5WqzSTpRBp3" }, "outputs": [], "source": [ "f = \"\"\"\n", "path/to/my/checkpoint/dir/\n", " 0/\n", " state/\n", " checkpoint # legacy msgpack file, stores tree structure\n", " tree_metadata # (maybe) new proto file, stores tree structure\n", " d/ # array data stored here\n", " 012b2c6e5c9d2a16c240a59d5f0f35c0\n", " 056e0816bdc5496a86251e58a0ec202b\n", " ...\n", " manifest.0000000000000001\n", " ...\n", " manifest.ocdbt\n", " \u003canother_item\u003e/\n", " ...\n", " 1/\n", " ...\n", " 2/\n", " ...\n", "\"\"\"" ] }, { "cell_type": "markdown", "metadata": { "id": "hOf2vscWRF5u" }, "source": [ "## Enabling the New Format" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ADRsxIkFRPZR" }, "outputs": [], "source": [ "import jax\n", "import tempfile\n", "import subprocess\n", "import os\n", "from etils import epath\n", "\n", "import orbax.checkpoint as ocp" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "gXXzqbco_UgX" }, "outputs": [], "source": [ "# Initialize PyTreeCheckpointHandler with `use_ocdbt=True`.\n", "# This option already defaults to True, so it's optional to pass it in.\n", "ckptr = ocp.Checkpointer(ocp.PyTreeCheckpointHandler(use_ocdbt=True))" ] }, { "cell_type": "markdown", "metadata": { "id": "HxLoTiZnAOvw" }, "source": [ "## Additional Notes\n", "\n", "All checkpoints previously produced by Orbax in the old format will still be\n", "readable when the new format is enabled. However, a checkpoint produced in the new format cannot be read when `use_ocdbt` is disabled." 
] }, { "cell_type": "markdown", "metadata": { "id": "CLtrEkdPVXCD" }, "source": [ "## Custom Chunk Sizes\n", "When [Zarr3](https://zarr-specs.readthedocs.io/en/latest/specs.html#specifications), a multidimensional array storage format, is enabled, Orbax offers customizable chunk sizes in bytes for better memory management. The default chunk size, which corresponds one-to-one with the array shard size, can cause out-of-memory errors when reading on hosts with different sharding layouts. For example, this often arises when arrays are saved with a fully-sharded sharding but loaded with a fully-replicated sharding. To prevent this, set `chunk_byte_size` smaller than or equal to the anticipated read size. A chunk size above 1MB generally won't impact performance. Consider the following example:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "rTEh-cIuG7WS" }, "outputs": [], "source": [ "# setup checkpoint data\n", "array_len = 8 * 1024\n", "key = jax.random.PRNGKey(0)\n", "key, subkey = jax.random.split(key)\n", "pytree = {\n", " 'a': jax.random.normal(subkey, (array_len, ), dtype=jax.numpy.float32), # 32KB\n", " 'b': jax.random.normal(subkey, (array_len * 2, ), dtype=jax.numpy.float32), # 64KB\n", "}\n", "\n", "# create save_args to customize the chunk_byte_size\n", "save_args = jax.tree_util.tree_map(\n", " lambda x: ocp.SaveArgs(\n", " chunk_byte_size=\n", " 1024, # 1KB\n", " ),\n", " pytree,\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3Z1nBQlTWQAf" }, "outputs": [], "source": [ "temp_dir = tempfile.TemporaryDirectory()\n", "mgr = ocp.CheckpointManager(epath.Path(temp_dir.name),\n", " item_handlers=ocp.PyTreeCheckpointHandler(use_zarr3=True)) # make sure zarr3 is enabled\n", "\n", "mgr.save(\n", " 0,\n", " args=ocp.args.PyTreeSave(\n", " pytree,\n", " save_args=save_args,\n", " ),\n", ")\n", "\n", "mgr.close()" ] }, { "cell_type": "markdown", "metadata": { "id": "7pVdmRy1FwO1" }, "source": [ "## Customizing Data File 
Size\n", "To improve file I/O parallelism when working with large files on remote storage such as GCS, use the `PyTreeSaveArgs.ocdbt_target_data_file_size` parameter to control the size of output files." ] }, { "cell_type": "markdown", "metadata": { "id": "qcFX1u4DR9UO" }, "source": [ "### BEFORE" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0_P5jdDvR576" }, "outputs": [], "source": [ "def print_directory_file_size(directory: epath.Path) -\u003e None:\n", " print(f\"dir={directory}:\")\n", " for f in directory.iterdir():\n", " if f.is_file():\n", " print(f\"file={f.name}, size={f.stat().length}\")\n", "\n", "# continue from above example, examine the data file sizes\n", "data_dir = epath.Path(temp_dir.name) / '0'/ 'default'/ 'ocdbt.process_0'/ 'd'\n", "print_directory_file_size(data_dir)" ] }, { "cell_type": "markdown", "metadata": { "id": "T5f6o8T4SJW2" }, "source": [ "### AFTER" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "VclDRyryHQ7M" }, "outputs": [], "source": [ "temp_dir = tempfile.TemporaryDirectory()\n", "mgr = ocp.CheckpointManager(temp_dir.name,\n", " item_handlers=ocp.PyTreeCheckpointHandler(use_zarr3=True))\n", "\n", "mgr.save(\n", " 0,\n", " args=ocp.args.PyTreeSave(\n", " pytree,\n", " save_args=save_args,\n", " ocdbt_target_data_file_size=10 * 1024, # 10 KB; should be much larger than chunk_byte_size\n", " ),\n", ")\n", "\n", "mgr.close()\n", "\n", "data_dir = epath.Path(temp_dir.name) / '0'/ 'default'/ 'ocdbt.process_0'/ 'd'\n", "print_directory_file_size(data_dir)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Ntl88VZedCM1" }, "outputs": [], "source": [] } ], "metadata": { "colab": { "last_runtime": { "build_target": "//learning/grp/tools/ml_python:ml_notebook", "kind": "private" }, "private_outputs": true, "provenance": [ { "file_id": "1bRC6p0AstPPAAW0AUoxHaOFEWpaW_GjI", "timestamp": 1688077923387 } ], "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": 
"python3" }, "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 0 }