{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "bgBmt4wXccHa" }, "source": [ "# Transformations\n" ] }, { "cell_type": "markdown", "metadata": { "id": "5Fmv8ozNKF8q" }, "source": [ "## Overview\n", "\n", "The `transform_utils` library provides functions to perform structural PyTree transformations, which can facilitate model surgery for finetuning, migrations between different checkpoint versions, etc.\n", "\n", "The API consists of a `Transform` class and an `apply_transformations` function." ] }, { "cell_type": "markdown", "metadata": { "id": "nPShn0w53Pjs" }, "source": [ "### `apply_transformations`\n", "\n", "The `apply_transformations` function accepts an original PyTree, a PyTree of `Transform` objects, and the desired structure of the returned PyTree. The function returns a newly generated PyTree.\n", "\n", "```py\n", "def apply_transformations(\n", " original_tree: PyTree,\n", " transformations: PyTree,\n", " new_tree: PyTree,\n", " default_to_original: Optional[bool] = True) -\u003e PyTree:\n", "```\n", "\n" ] }, { "cell_type": "markdown", "metadata": { "id": "4mReT741Abvm" }, "source": [ "\n", "### `Transform`\n", "\n", "`Transform` consists of the following elements:\n", "\n", "* `original_key`: Denotes the original name of the key. Represented as a\n", " string with '/' denoting successive levels of nesting. If the key\n", " corresponding to this `Transform` is a regex, backreferences (such as \\1) will\n", " be replaced with the appropriate matched group in the regex. Note: not\n", " needed if `multi_value_fn` is provided.\n", "* `use_fallback`: If True, takes the value from the fallback tree. If\n", " `default_to_original=True` in `apply_transformations`, the fallback tree is\n", " `new_tree`. If `default_to_original=False` in `apply_transformations`, the\n", " fallback tree is `original_tree`.\n", "* `value_fn`: A function accepting a single value and returning a single\n", " value. 
The value provided as an argument is the value of the transformation\n", " key in the original PyTree.\n", "* `multi_value_fn`: A function accepting a key and a PyTree and returning any\n", " value. As in the examples in this document, the function is called with a\n", " key and the original PyTree, and it should return the value of the key in\n", " the new PyTree." ] }, { "cell_type": "markdown", "metadata": { "id": "Rb2TSojmKsxJ" }, "source": [ "### Fallbacks\n", "\n", "Note that there is an additional option for `apply_transformations`, which is\n", "`default_to_original` (True by default). This means that the values of keys\n", "unspecified in `transformations` but present in *both* trees will be taken from\n", "the *original* tree. If False, such values will be taken from the *new* tree.\n", "\n", "Remember that if a key is present in the new tree, but not in the old, the value\n", "will simply be taken from the new tree. If a key is present in the original tree\n", "but not in the new, it will be dropped in the result.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "7Jf6Zibp-xEu" }, "source": [ "## Examples" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "5BRDxffo-9-p" }, "outputs": [], "source": [ "# Setup\n", "import orbax.checkpoint as ocp\n", "import numpy as np" ] }, { "cell_type": "markdown", "metadata": { "id": "UheTp5oG-0Cd" }, "source": [ "\n", "### Renaming keys\n", "\n", "Key renames are common when reusing existing checkpointed state between different models or the same model at different versions." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "oXo7R_aq_e7Q" }, "outputs": [], "source": [ "# Example: Migrate original tree into the new_tree, which has the same\n", "# nested structure but different keys.\n", "original_tree = {\n", " 'a': 1,\n", " 'b': 2\n", "}\n", "\n", "transformations = {\n", " 'a2': ocp.Transform(original_key='a'),\n", " 'b2': ocp.Transform(original_key='b')\n", "}\n", "\n", "new_tree = {\n", " 'a2': ...,\n", " 'b2': ...\n", "}\n", "\n", "ocp.apply_transformations(original_tree, transformations, new_tree)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "zds6j4ujC0Rf" }, "outputs": [], "source": [ "# Example 2: Renaming with regex\n", "\n", "original_tree = {\n", " 'a1': 1,\n", " 'b5': 2\n", "}\n", "\n", "transformations = {\n", " r'([a-z])_([0-9])': ocp.Transform(original_key=r'\\1\\2'),\n", "}\n", "\n", "new_tree = {\n", " 'a_1': ...,\n", " 'b_5': ...\n", "}\n", "\n", "ocp.apply_transformations(original_tree, transformations, new_tree)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "_SZpZcc-DzNx" }, "outputs": [], "source": [ "# Example 3: Renaming nested trees\n", "\n", "original_tree = {\n", " 'a': 1,\n", " 'dense_1': {'kernel': 2, 'bias': 3},\n", " 'dense_2': {'kernel': 4, 'bias': 5},\n", "}\n", "\n", "# Nested keys can be represented by a single string by separating each level\n", "# with '/'.\n", "transformations = {\n", " r'([a-z]+)_NEW': ocp.Transform(original_key=r'\\1'),\n", " r'([a-z]+)_([0-9])_NEW/([a-z]+)_1': ocp.Transform(original_key=r'\\1_\\2/\\3'),\n", "}\n", "\n", "# This is equivalent to:\n", "transformations = {\n", " r'([a-z]+)_NEW': ocp.Transform(original_key=r'\\1'),\n", " r'([a-z]+)_([0-9])_NEW': {\n", " '([a-z]+)_1': ocp.Transform(original_key=r'\\1_\\2/\\3'),}\n", "}\n", "\n", "new_tree = {\n", " 'a_NEW': ...,\n", " 'dense_1_NEW': {'kernel_1': ..., 'bias_1': ...},\n", " 'dense_2_NEW': {'kernel_1': ..., 'bias_1': ...},\n", "}\n", "\n", 
"ocp.apply_transformations(original_tree, transformations, new_tree)" ] }, { "cell_type": "markdown", "metadata": { "id": "MaZUe1wq--ql" }, "source": [ "### Updating the value\n", "\n", "To change a leaf node in the PyTree, define a `Transform` with a `value_fn`. This transformation could be used for quantization, modifying hyperparameters, etc." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "WWou-NU8-zqK" }, "outputs": [], "source": [ "# Example: Transform the values in a tree.\n", "original_tree = {\n", " 'a': 1,\n", " 'b': 2\n", "}\n", "\n", "transformations = {\n", " 'a': ocp.Transform(value_fn=lambda v: v * 2),\n", " 'b2': ocp.Transform(value_fn=lambda v: v * 3, original_key='b')\n", "}\n", "\n", "new_tree = {\n", " 'a': ...,\n", " 'b2': ... # Output different key\n", "}\n", "\n", "ocp.apply_transformations(original_tree, transformations, new_tree)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "CmulDyUMCx-K" }, "outputs": [], "source": [ "# Example 2: Transform values in a tree with regex (multiply the values of all\n", "# 'a' keys by 2 and the values of all 'b' keys by 3).\n", "original_tree = {\n", " 'a1': 1,\n", " 'a2': 2,\n", " 'b': 3\n", "}\n", "\n", "transformations = {\n", " r'a([0-9]?)\\*2': ocp.Transform(value_fn=lambda v: v * 2,\n", " original_key=r'a\\1'),\n", " r'b([0-9]?)\\*3': ocp.Transform(value_fn=lambda v: v * 3,\n", " original_key=r'b\\1')\n", "}\n", "\n", "new_tree = {\n", " 'a1*2': ...,\n", " 'a2*2': ...,\n", " 'b*3': ...\n", "}\n", "\n", "ocp.apply_transformations(original_tree, transformations, new_tree)" ] }, { "cell_type": "markdown", "metadata": { "id": "pfNyW6jaCIct" }, "source": [ "### Restructuring PyTrees" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "MZlbVecL-wIS" }, "outputs": [], "source": [ "# Example: Flatten nested structure\n", "original_tree = {\n", " 'a': 1,\n", " 'dense_1': {'kernel': 2, 'bias': 3},\n", " 'dense_2': {'kernel': 4, 'bias': 5},\n", "}\n", "\n", 
"transformations = {\n", " r'([a-z]+)': ocp.Transform(original_key=r'\\1'),\n", " r'([a-z]+)_([0-9])_([a-z]+)': ocp.Transform(original_key=r'\\1_\\2/\\3'),\n", "}\n", "\n", "\n", "new_tree = {\n", " 'a': ...,\n", " 'dense_1_kernel': ...,\n", " 'dense_1_bias': ...,\n", " 'dense_2_kernel': ...,\n", " 'dense_2_bias': ...,\n", "}\n", "\n", "ocp.apply_transformations(original_tree, transformations, new_tree)" ] }, { "cell_type": "markdown", "metadata": { "id": "G6G2f-5A40p-" }, "source": [ "### Multi-value transformations\n", "\n", "Multi-value transformations can be used to combine multiple values from the original tree into the new tree." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "NZkophc16edh" }, "outputs": [], "source": [ "# Example: various multi_value_fn usage\n", "original_tree = {\n", " 'a': np.array([1, 2, 3, 4]),\n", " 'b': {'c': np.array([5, 6, 7, 8])},\n", "}\n", "\n", "transformations = {\n", " 'a': ocp.Transform(multi_value_fn=lambda _, kv: kv['a'][-1]),\n", " 'b': {\n", " 'c': ocp.Transform(multi_value_fn=lambda _, kv: kv['a'] + kv['b']['c'])},\n", "}\n", "\n", "\n", "new_tree = {\n", " 'a': ...,\n", " 'b': {'c': ...}\n", "}\n", "\n", "ocp.apply_transformations(original_tree, transformations, new_tree)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "B9wbMbze5lnI" }, "outputs": [], "source": [ "# Example: Average the weights\n", "original_tree = {\n", " 'a': {'a_1': 1, 'a_2': 2},\n", " 'b': {'b_1': 3, 'b_2': 4, 'b_3': 5},\n", "\n", "}\n", "\n", "transformations = {\n", " r'([a-z]+)': ocp.Transform(\n", " multi_value_fn=lambda k, kv: sum(kv[k].values()) / len(kv[k])),\n", "}\n", "\n", "\n", "new_tree = {\n", " 'a': ...,\n", " 'b': ...,\n", "}\n", "\n", "ocp.apply_transformations(original_tree, transformations, new_tree)" ] }, { "cell_type": "markdown", "metadata": { "id": "-a90aX9FAlQt" }, "source": [ "\n", "### Real world example\n", "\n", "Let's consider a real-world example. 
In this scenario, we have a saved\n", "checkpoint with parameters `Dense_0` and `Dense_1`. We want to restore this\n", "checkpoint, with modifications, into a model for training with layers `Dense_0`,\n", "`Dense_1`, `Dense_2`, `Dense_3`.\n", "\n", "In this example, we will map original layers 0 and 1 onto the new layers 1 and\n", "2, respectively. We want the new layers 0 and 3 to be initialized randomly, or\n", "with some new values.\n", "\n", "The new model may be initialized as a Flax\n", "[TrainState](https://flax.readthedocs.io/en/latest/flax.training.html#train-state),\n", "for example.\n", "\n", "```py\n", "params = model.init(\n", " jax.random.PRNGKey(0), jnp.ones([1, model.input_size]))\n", "new_state = TrainState.create(\n", " apply_fn=model.apply, params=params, tx=optimizer)\n", "# Restore original state.\n", "original_state = manager.restore(step)\n", "```\n", "\n", "```py\n", "transformations = {\n", "  # NewModel layer 0 is a newly inserted layer, thus use_fallback=True.\n", "  r'(.*)Dense_0(.*)': ocp.Transform(use_fallback=True),\n", "  # OriginalModel layer 0 maps to NewModel layer 1.\n", "  r'(.*)Dense_1(.*)': ocp.Transform(original_key=r'\\1Dense_0\\2'),\n", "  # OriginalModel layer 1 maps to NewModel layer 2.\n", "  r'(.*)Dense_2(.*)': ocp.Transform(original_key=r'\\1Dense_1\\2'),\n", "}  # Note: NewModel layer 3 is newly added.\n", "restored_state = ocp.apply_transformations(\n", "    original_state, transformations, new_state)\n", "```\n", "\n", "Let's unpack what's happening with these transformations.\n", "\n", "For layer 0, we want to instruct the function to ignore what's in\n", "`original_state`, and to instead use the value from `new_state`. For this, we\n", "set `use_fallback=True`.\n", "\n", "For `Dense_1` and `Dense_2`, we simply provide a regex matching the new key and\n", "point it at the original name of the key (`Dense_0` and `Dense_1`, respectively)\n", "using the `original_key` field. 
Note that we can use a regex to match any key containing\n", "the desired pattern, since a PyTree checkpoint will typically represent a single\n", "layer with multiple different arrays, each containing the pattern.\n", "\n", "Finally, we can omit `Dense_3` from `transformations`. Because `Dense_3` is\n", "present as a key in `new_state` but not in `original_state`, the function simply\n", "takes the value from `new_state` and puts it in the result.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "n6nVne1vDJCH" }, "source": [ "### Restoring a Checkpoint" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "OIZxk3qSC6Fm" }, "outputs": [], "source": [ "import flax.struct\n", "\n", "@flax.struct.dataclass\n", "class Small:\n", " key1: int\n", "\n", "@flax.struct.dataclass\n", "class Big:\n", " key1: int\n", " key2: int\n", "\n", "to_save = Big(key1=10, key2=100)\n", "to_restore = Small(key1=0)\n", "\n", "path = '/tmp/my-checkpoints/'\n", "ckptr = ocp.PyTreeCheckpointer()\n", "ckptr.save(path, to_save)\n", "\n", "# Empty transforms: restore into the smaller structure. Saved keys that are\n", "# absent from `to_restore` (here `key2`) are dropped.\n", "restored1 = ckptr.restore(\n", " path, args=ocp.args.PyTreeRestore(\n", " to_restore,\n", " restore_args=ocp.checkpoint_utils.construct_restore_args(to_restore),\n", " transforms={}\n", " )\n", ")\n", "# Restore `key1` from the value saved under `key2`.\n", "restored2 = ckptr.restore(\n", " path, args=ocp.args.PyTreeRestore(\n", " to_restore,\n", " restore_args=ocp.checkpoint_utils.construct_restore_args(to_restore),\n", " transforms={\n", " r'(.*)key1(.*)': ocp.Transform(original_key=r'\\1key2\\2')\n", " }\n", " )\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "WmgLQZbIIcmd" }, "outputs": [], "source": [ "restored1" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0lMe9uWZIj6S" }, "outputs": [], "source": [ "restored2" ] }, { "cell_type": "markdown", "metadata": { "id": "n61hVFbQI7F4" }, "source": [ "## Tips and Tricks" ] }, { "cell_type": "markdown", "metadata": { "id": "jd32KbYtJAXZ" }, "source": [ "\n", "### Regex group names\n", "\n", "If 
your regex is getting complicated, you can name its groups using `(?P\u003cname\u003e...)`. A named group can then be referenced either with the standard numeric backreference `\\N`, where N is the group's index, or with the named backreference `\\g\u003cname\u003e`.\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "agtpy-DeJB1Y" }, "outputs": [], "source": [ "# Example:\n", "original_tree = {\n", " 'dense_1': {'kernel': 2, 'bias': 3},\n", "}\n", "\n", "transformations = {\n", " r'(?P\u003clayer\u003e[a-z]+)_(?P\u003cnum\u003e[0-9])_(?P\u003cweight\u003e[a-z]+)': ocp.Transform(\n", " original_key=r'\\g\u003clayer\u003e_\\g\u003cnum\u003e/\\g\u003cweight\u003e'),\n", "}\n", "\n", "\n", "new_tree = {\n", " 'dense_1_kernel': ...,\n", " 'dense_1_bias': ...,\n", "}\n", "\n", "ocp.apply_transformations(original_tree, transformations, new_tree)" ] } ], "metadata": { "colab": { "last_runtime": { "build_target": "//quality/ranklab/experimental/notebook:rl_colab", "kind": "private" }, "private_outputs": true, "provenance": [ { "file_id": "1wS2W3PHxK-wRr5W91LmfIMxNuVfL44rA", "timestamp": 1690833642146 }, { "file_id": "1f-qFS7PStOYkUg-wIOAwBgfzwhkzKQTN", "timestamp": 1690101090208 } ], "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 0 }