{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "fEE8DF5O7V7E" }, "source": [ "# Exporting with Orbax\n", "\n", "Orbax Export is a library for exporting JAX models to TensorFlow\n", "[SavedModel](https://www.tensorflow.org/guide/saved_model) format.\n", "\n", "\n" ] }, { "cell_type": "markdown", "metadata": { "id": "PyYjiO59wq3g" }, "source": [ "## Exporting\n", "\n", "### API Overview\n", "\n", "Orbax Export provides three classes:\n", "\n", "- `JaxModule`\n", "  wraps a JAX function and its parameters into an exportable and callable\n", "  closure.\n", "- `ServingConfig`\n", "  defines a serving configuration for a `JaxModule`, including\n", "  [a signature key and an input signature][1], and optionally pre- and\n", "  post-processing functions and extra [TrackableResources][2].\n", "- `ExportManager`\n", "  builds the actual [serving signatures][1] based on a `JaxModule` and a list\n", "  of `ServingConfig`s, and saves them to the SavedModel format. It targets CPU.\n", "  Users can subclass `ExportManager` to create their own export manager\n", "  for other hardware.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "9Qflhwff7lk3" }, "source": [ "### Simple Example Usage\n", "\n", "#### Setup\n", "\n", "\n", "Before exporting, users should have the JAX model function and its parameters ready, along with any optional pre- and post-processing functions."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "eEyYvKHm7z2S" }, "outputs": [], "source": [ "# Import Orbax Export classes.\n", "from orbax.export import ExportManager\n", "from orbax.export import JaxModule\n", "from orbax.export import ServingConfig\n", "import numpy as np\n", "import jax\n", "import jax.numpy as jnp\n", "import tensorflow as tf\n", "\n", "# Prepare the parameters and model function to export.\n", "example1_params = {'a': np.array(5.0), 'b': np.array(1.1), 'c': np.array(0.55)}  # A pytree of the JAX model parameters.\n", "\n", "# Model: f(x) = a * sin(x) + b * x + c, where (a, b, c) are the model parameters.\n", "def example1_model_fn(params, inputs):  # The JAX model function to export.\n", "  a, b, c = params['a'], params['b'], params['c']\n", "  return a * jnp.sin(inputs) + b * inputs + c\n", "\n", "def example1_preprocess(inputs):  # Optional: preprocessor in TF.\n", "  norm_inputs = tf.nest.map_structure(lambda x: x / tf.math.reduce_max(x), inputs)\n", "  return norm_inputs\n", "\n", "def example1_postprocess(model_fn_outputs):  # Optional: post-processor in TF.\n", "  return {'outputs': model_fn_outputs}" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "RMv6wB0VstTk" }, "outputs": [], "source": [ "inputs = tf.random.normal([16], dtype=tf.float32)\n", "\n", "model_outputs = example1_postprocess(example1_model_fn(example1_params, np.array(example1_preprocess(inputs))))\n", "print(\"model output: \", model_outputs)" ] }, { "cell_type": "markdown", "metadata": { "id": "9xEkmQyLpQPv" }, "source": [ "Exporting a JAX model to a CPU SavedModel" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "KQpKRCIbpThw" }, "outputs": [], "source": [ "import tensorflow as tf\n", "\n", "# Construct a JaxModule where JAX-\u003eTF conversion happens.\n", "jax_module = JaxModule(example1_params, example1_model_fn)\n", "# Export the JaxModule along with one or more serving configs.\n", "export_mgr = 
ExportManager(\n", "  jax_module, [\n", "    ServingConfig(\n", "      'serving_default',\n", "      input_signature=[tf.TensorSpec(shape=[16], dtype=tf.float32)],\n", "      tf_preprocessor=example1_preprocess,\n", "      tf_postprocessor=example1_postprocess\n", "    ),\n", "])\n", "output_dir='/tmp/example1_output_dir'\n", "export_mgr.save(output_dir)" ] }, { "cell_type": "markdown", "metadata": { "id": "1GYvXqJ_tMM5" }, "source": [ "Load the TF SavedModel back and run it" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "oyoaEPE9tQE8" }, "outputs": [], "source": [ "loaded_model = tf.saved_model.load(output_dir)\n", "loaded_model_outputs = loaded_model(inputs)\n", "print(\"loaded model output: \", loaded_model_outputs)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "R3Bf4K7fu6To" }, "outputs": [], "source": [ "np.testing.assert_allclose(model_outputs['outputs'], loaded_model_outputs['outputs'], atol=1e-5, rtol=1e-5)" ] }, { "cell_type": "markdown", "metadata": { "id": "qr8nAnPAvYQH" }, "source": [ "### Limitation\n" ] }, { "cell_type": "markdown", "metadata": { "id": "bY0ScOH2vdde" }, "source": [ "#### \"JaxModule only take single arg as the input\".\n", "\n", "This error message means that the JAX function `model_fn` can take only a single argument as its input.\n", "Orbax expects the JAX model to be a callable that takes the\n", "parameters as one PyTree and the model inputs as another PyTree. If your JAX function\n", "takes multiple inputs, you must pack them into a single JAX PyTree. Otherwise,\n", "you will encounter this error message.\n", "\n", "To solve this problem, you can update the `ServingConfig.tf_preprocessor`\n", "function to pack the inputs into a single JAX PyTree. For example, suppose the model\n", "takes two inputs `x` and `y`. 
You can define the `ServingConfig.tf_preprocessor`\n", "to pack them into a list `[x, y]`.\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ucjcPZ7Dvkd9" }, "outputs": [], "source": [ "example2_params = {}  # A pytree of the JAX model parameters.\n", "\n", "def example2_model_fn(params, inputs):\n", "  x, y = inputs\n", "  return x + y\n", "\n", "def example2_preprocessor(x, y):\n", "  # Put the usual tf_preprocessor code here.\n", "  return [x, y]  # Pack the inputs into a single list for the JAX model function.\n", "\n", "jax_module = JaxModule(example2_params, example2_model_fn)\n", "export_mgr = ExportManager(\n", "  jax_module,\n", "  [\n", "    ServingConfig(\n", "      'serving_default',\n", "      input_signature=[tf.TensorSpec([16]), tf.TensorSpec([16])],\n", "      tf_preprocessor=example2_preprocessor,\n", "    )\n", "  ],\n", ")\n", "output_dir='/tmp/example2_output_dir'\n", "export_mgr.save(output_dir)\n", "\n", "loaded_model = tf.saved_model.load(output_dir)\n", "loaded_model_outputs = loaded_model(tf.random.normal([16]), tf.random.normal([16]))\n", "print(\"loaded model output: \", loaded_model_outputs)" ] }, { "cell_type": "markdown", "metadata": { "id": "uZJ9ep0O3S0m" }, "source": [ "## Validating\n", "\n", "### API Overview\n", "\n", "Orbax.export.validate is a library for validating a JAX model and\n", "its export in the TF [SavedModel](https://www.tensorflow.org/guide/saved_model)\n", "format.\n", "\n", "Users must finish exporting the JAX model first. 
Users can export the model with\n", "orbax.export or manually.\n", "\n", "Orbax.export.validate provides the following classes:\n", "\n", "* `ValidationJob`\n", "  takes the model and data as input, then outputs the result.\n", "* `ValidationReport`\n", "  compares the JAX model and TF SavedModel results, then generates a formatted\n", "  report.\n", "* `ValidationManager`\n", "  takes a `JaxModule` as input and wraps the end-to-end validation flow.\n", "\n" ] }, { "cell_type": "markdown", "metadata": { "id": "tFoHbN6wzfHY" }, "source": [ "### Simple Example Usage\n", "\n", "Here we use the same example as in the `ExportManager` section.\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ceBvaaMozkW4" }, "outputs": [], "source": [ "from orbax.export.validate import ValidationManager\n", "from orbax.export import JaxModule\n", "from orbax.export import ServingConfig\n", "\n", "jax_module = JaxModule(example1_params, example1_model_fn)\n", "batch_inputs = [inputs] * 16\n", "\n", "serving_configs = [\n", "  ServingConfig(\n", "    'serving_default',\n", "    input_signature=[tf.TensorSpec(shape=[16], dtype=tf.float32)],\n", "    tf_preprocessor=example1_preprocess,\n", "    tf_postprocessor=example1_postprocess\n", "  ),\n", "]\n", "# Provide computation method for the baseline.\n", "validation_mgr = ValidationManager(jax_module, serving_configs, batch_inputs)\n", "\n", "tf_saved_model_path = \"/tmp/example1_output_dir\"\n", "loaded_model = tf.saved_model.load(tf_saved_model_path)\n", "\n", "# Provide the computation method for the candidate.\n", "validation_reports = validation_mgr.validate(loaded_model)\n", "\n", "# `validation_reports` is a Python dict keyed by the TF SavedModel serving key.\n", "for key in validation_reports:\n", "  assert validation_reports[key].status.name == 'Pass'\n", "  # Users can also save the JSON report to a file.\n", "  print(validation_reports[key].to_json(indent=2))" ] }, { "cell_type": "markdown", "metadata": { "id": "mUolCMdY7J70" }, "source": [ "\n", "### 
Limitation\n", "\n", "Here we list the limitations of the Orbax.export.validate module.\n", "\n", "* The object returned by a TF SavedModel is always a map. If the JAX\n", "  model output is a sequence, the TF SavedModel converts it to a map. The tensor\n", "  names are fairly generic, like `output_0`. To let the `ValidationReport` module\n", "  do an apples-to-apples comparison between the JAX model and TF model results, we\n", "  suggest that users make the model output a dictionary.\n", "\n" ] }, { "cell_type": "markdown", "metadata": { "id": "T_0M1ndhCBi4" }, "source": [ "## Examples\n", "\n", "Check out the\n", "[examples](https://github.com/google/orbax/tree/main/export/orbax/export/examples)\n", "directory for a number of examples using Orbax Export.\n", "\n", "[1]: https://www.tensorflow.org/guide/saved_model#specifying_signatures_during_export\n", "[2]: https://www.tensorflow.org/api_docs/python/tf/saved_model/experimental/TrackableResource" ] } ], "metadata": { "colab": { "last_runtime": { "build_target": "//third_party/py/jaxonnxruntime:jort_ml_notebook", "kind": "private" }, "provenance": [ { "file_id": "1QNxBBBN16Br9Xj-a7LvtJzJWjOBhjFps", "timestamp": 1686159333109 } ] }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 0 }