{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "yufmhxydT045" }, "source": [ "# Debugging Guide" ] }, { "cell_type": "markdown", "metadata": { "id": "n9YETCDlT6q-" }, "source": [ "## Setup" ] }, { "cell_type": "markdown", "metadata": { "id": "7y9nJBKiVbmP" }, "source": [ "### Imports" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "QaD4otJQg6nJ" }, "outputs": [], "source": [ "import jax\n", "import numpy as np\n", "from etils import epath\n", "import orbax.checkpoint as ocp\n", "import tensorstore as ts\n", "import collections\n", "import operator\n", "import asyncio" ] }, { "cell_type": "markdown", "metadata": { "id": "OmtbJzLO9Z8z" }, "source": [ "### Create Sample Checkpoint" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "vfm1K_z19ejF" }, "outputs": [], "source": [ "state = {\n", " 'a': {\n", " 'x': np.arange(2 ** 24),\n", " 'y': np.arange(1024),\n", " },\n", " 'b': np.ones(8),\n", " 'c': 42,\n", "}\n", "\n", "default_param_name = 'a.x'\n", "default_path = epath.Path('/tmp/checkpoint')\n", "if default_path.exists():\n", " default_path.rmtree()\n", "with ocp.StandardCheckpointer() as ckptr:\n", " ckptr.save(default_path, state)" ] }, { "cell_type": "markdown", "metadata": { "id": "k67ycpnbV1Qx" }, "source": [ "## Checkpoint Size" ] }, { "cell_type": "markdown", "metadata": { "id": "72Qb7Emvye7z" }, "source": [ "### Actual Size on Disk" ] }, { "cell_type": "markdown", "metadata": { "id": "2yZQqxlRyo9P" }, "source": [ "This is the actual size of the checkpoint on disk." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "TZItnibuV727" }, "outputs": [], "source": [ "path = \"\" # @param {type:\"string\"}\n", "path = default_path or epath.Path(path)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "x_cmsf2A2a0-" }, "outputs": [], "source": [ "async def disk_usage(path: epath.Path) -\u003e int:\n", "  \"\"\"Returns the size of the checkpoint on disk.\n", "\n", "  Note: this uses recursion because Orbax checkpoint directories are never\n", "  more than a few levels deep.\n", "\n", "  Args:\n", "    path: The path to the checkpoint.\n", "  Returns:\n", "    The size of the checkpoint on disk, in bytes.\n", "  \"\"\"\n", "\n", "  async def helper(p):\n", "    if p.is_dir():\n", "      return await disk_usage(p)\n", "    else:\n", "      # Stat the entry itself (`p`), not the parent directory (`path`).\n", "      stat = await ocp.path.async_utils.async_stat(p)\n", "      return stat.length\n", "\n", "  futures = []\n", "  for p in path.iterdir():\n", "    futures.append(helper(p))\n", "  return sum(await asyncio.gather(*futures))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "QpxBYU1z2fE3" }, "outputs": [], "source": [ "print('{0:0.3f} GB'.format(float(asyncio.run(disk_usage(path))) / 1e9))" ] }, { "cell_type": "markdown", "metadata": { "id": "i_1s-wLpykXf" }, "source": [ "### Implied Size from Checkpoint Metadata" ] }, { "cell_type": "markdown", "metadata": { "id": "DMNc1l7hyuEW" }, "source": [ "Users sometimes find that the checkpoint size on disk seems much larger or smaller than expected based on the model itself. Determining the implied size of the checkpoint from the checkpoint's own metadata and cross-referencing it against the actual on-disk size can provide some insight.\n", "\n", "The actual size on disk is typically somewhat smaller than the implied size because of compression."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "KNm7aLDCheiN" }, "outputs": [], "source": [ "path = \"\" # @param {type:\"string\"}\n", "path = default_path or epath.Path(path)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "hXpl9tHgV3W-" }, "outputs": [], "source": [ "metadata = ocp.StandardCheckpointer().metadata(path)\n", "size_counts = collections.defaultdict(int)\n", "\n", "def get_arr_bytes(meta):\n", "  dtype = meta.dtype\n", "  shape = meta.shape\n", "  size_counts[dtype] += 1\n", "  return np.prod(shape) * np.dtype(dtype).itemsize\n", "\n", "total_bytes = jax.tree.reduce(operator.add, jax.tree.map(get_arr_bytes, metadata))\n", "print('{0:0.3f} GB'.format(float(total_bytes) / 1e9))\n", "print()\n", "print('leaf dtype counts:')\n", "for dtype, count in size_counts.items():\n", "  print(f'{dtype}: {count}')" ] }, { "cell_type": "markdown", "metadata": { "id": "w0Heru4RSIkV" }, "source": [ "## Tree Metadata" ] }, { "cell_type": "markdown", "metadata": { "id": "dzOi1xNvSNpX" }, "source": [ "Inspecting the tree structure of the checkpoint is crucial: it lets you verify that the expected parameters are present in the checkpoint and that the array metadata (such as shape and dtype) associated with each parameter is correct." ] }, { "cell_type": "markdown", "metadata": { "id": "kuFzHLStbG6K" }, "source": [ "The following can be useful when debugging errors where the loading code was searching for a particular parameter that was not found. A few things could be going wrong here:\n", "* The parameter is missing from the checkpoint. Ensure the checkpoint is what you think it is, and that it has the correct parameters.\n", "* If running model surgery, the transformations may be misconfigured. See below."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "_efiuykVSzPd" }, "outputs": [], "source": [ "path = \"\" # @param {type:\"string\"}\n", "path = default_path or epath.Path(path)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Il_7lTqnWhnK" }, "outputs": [], "source": [ "metadata = ocp.StandardCheckpointer().metadata(path)\n", "metadata_contents = ['.'.join(k) for k in ocp.tree.to_flat_dict(metadata)]" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "PFkiZfkNY9Ml" }, "outputs": [], "source": [ "# Here are the parameters present in the checkpoint tree.\n", "for p in metadata_contents:\n", " print(p)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1acr4zM2S5-B" }, "outputs": [], "source": [ "# Note: instead of \"file\", use:\n", "# - \"gfile\" on Google-internal filesystems.\n", "# - \"gs\" on GCS (do not repeat the \"gs://\" prefix)\n", "ts_contents = ts.KvStore.open({\"driver\": \"ocdbt\", \"base\": f\"file://{path.as_posix()}\"}).result().list().result()\n", "ts_contents = [p.decode(\"utf-8\") for p in ts_contents]\n", "ts_contents = [p.replace('.zarray', '')[:-1] for p in ts_contents if '.zarray' in p]\n", "\n", "# We can assert that the parameters tracked by the metadata file are\n", "# the same as those tracked by Tensorstore. 
If there is a discrepancy, there may\n", "# be a deeper underlying problem.\n", "\n", "assert len(metadata_contents) == len(ts_contents) and sorted(metadata_contents) == sorted(ts_contents)" ] }, { "cell_type": "markdown", "metadata": { "id": "110FmWk00pDW" }, "source": [ "## Individual Parameters" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mKSaIEAWgD8g" }, "outputs": [], "source": [ "path = \"\" # @param {type:\"string\"}\n", "# The `param_name` can be obtained by inspecting tree metadata (see above).\n", "param_name = \"\" # @param {type:\"string\"}\n", "path = default_path or epath.Path(path)\n", "param_name = default_param_name or param_name" ] }, { "cell_type": "markdown", "metadata": { "id": "Fpu1uKS9f_sq" }, "source": [ "### Value Metadata" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "42-RRPsOgKIa" }, "outputs": [], "source": [ "metadata = ocp.StandardCheckpointer().metadata(path)\n", "value_metadata = {'.'.join(k): v for k, v in ocp.tree.to_flat_dict(metadata).items()}[param_name]" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "nrnEPjIrhin0" }, "outputs": [], "source": [ "print(f'shape: {value_metadata.shape}')\n", "print(f'dtype: {value_metadata.dtype}')" ] }, { "cell_type": "markdown", "metadata": { "id": "YlLSc1vpf-Th" }, "source": [ "### Array Value" ] }, { "cell_type": "markdown", "metadata": { "id": "cXgmdL_s0zbM" }, "source": [ "It is often helpful to check the raw value of a particular parameter as saved in the checkpoint. If the saved value is correct, you can rule out an incorrect save or a corrupted checkpoint for that parameter, which confines debugging to the restoration path." ] }, { "cell_type": "markdown", "metadata": { "id": "-lV94I_jnHdL" }, "source": [ "CAUTION: The read below loads the entire array into memory. 
For very large arrays, this could result in OOM. To load a smaller slice of the array, simply index into the `TensorStore` object (`t`), like this: `t[:2, :4].read().result()`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "945T7n5nWNLT" }, "outputs": [], "source": [ "ParamInfo = ocp.type_handlers.ParamInfo\n", "ts_context = ts.Context({\n", " 'file_io_concurrency': {'limit': 128},\n", " 'cache_pool#ocdbt': {'total_bytes_limit': 100000000},\n", "})\n", "\n", "info = ParamInfo(name=param_name, path=path / param_name, parent_dir=path, is_ocdbt_checkpoint=True, use_zarr3=False)\n", "tspec = ocp.type_handlers.get_json_tspec_read(info, use_ocdbt=True)\n", "\n", "t = ts.open(ts.Spec(tspec), open=True, context=ts_context).result()\n", "arr = t.read().result()\n", "print(arr)" ] } ], "metadata": { "colab": { "last_runtime": { "build_target": "//learning/grp/tools/ml_python:ml_notebook", "kind": "private" }, "name": "debug_guide.ipynb", "private_outputs": true, "provenance": [ { "file_id": "1eoUI8u1JTufrxCRl5XRm3qxonwWAXXG3", "timestamp": 1723222392258 } ], "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 0 }