{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "6bYaCABobL5q" }, "source": [ "##### Copyright 2018 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "FlUw7tSKbtg4" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "xc1srSc51n_4" }, "source": [ "# Using the SavedModel format" ] }, { "cell_type": "markdown", "metadata": { "id": "-nBUqG2rchGH" }, "source": [ "
<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
    "  <td>View on TensorFlow.org</td>\n",
    "  <td>Run in Google Colab</td>\n",
    "  <td>View source on GitHub</td>\n",
    "  <td>Download notebook</td>\n",
    "</table>"
  ]
},
{
  "cell_type": "markdown",
  "metadata": {},
  "source": [
    "Calling the imported function with inputs that don't match a traced signature raises an error:\n",
    "\n",
    "```\n",
"ValueError: Could not find matching function to call for canonicalized inputs ((,), {}). Only existing signatures are [((TensorSpec(shape=(), dtype=tf.float32, name=u'x'),), {})].\n",
" "
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4Vsva3UZ-2sf"
},
"source": [
"### Basic fine-tuning\n",
"\n",
"Variable objects are available, and you can backprop through imported functions. That is enough to fine-tune (i.e. retrain) a SavedModel in simple cases."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "PEkQNarJ-7nT"
},
"outputs": [],
"source": [
"optimizer = tf.keras.optimizers.SGD(0.05)\n",
"\n",
"def train_step():\n",
" with tf.GradientTape() as tape:\n",
" loss = (10. - imported(tf.constant(2.))) ** 2\n",
" variables = tape.watched_variables()\n",
" grads = tape.gradient(loss, variables)\n",
" optimizer.apply_gradients(zip(grads, variables))\n",
" return loss"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "p41NM6fF---3"
},
"outputs": [],
"source": [
"for _ in range(10):\n",
" # \"v\" approaches 5, \"loss\" approaches 0\n",
" print(\"loss={:.2f} v={:.2f}\".format(train_step(), imported.v.numpy()))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XuXtkHSD_KSW"
},
"source": [
"### General fine-tuning\n",
"\n",
"A SavedModel from Keras provides [more details](https://github.com/tensorflow/community/blob/master/rfcs/20190509-keras-saved-model.md#serialization-details) than a plain `__call__` to address more advanced cases of fine-tuning. TensorFlow Hub recommends to provide the following of those, if applicable, in SavedModels shared for the purpose of fine-tuning:\n",
"\n",
" * If the model uses dropout or another technique in which the forward pass differs between training and inference (like batch normalization), the `__call__` method takes an optional, Python-valued `training=` argument that defaults to `False` but can be set to `True`.\n",
" * Next to the `__call__` attribute, there are `.variable` and `.trainable_variable` attributes with the corresponding lists of variables. A variable that was originally trainable but is meant to be frozen during fine-tuning is omitted from `.trainable_variables`.\n",
" * For the sake of frameworks like Keras that represent weight regularizers as attributes of layers or sub-models, there can also be a `.regularization_losses` attribute. It holds a list of zero-argument functions whose values are meant for addition to the total loss.\n",
"\n",
"Going back to the initial MobileNet example, you can see some of those in action:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Y6EUFdY8_PRD"
},
"outputs": [],
"source": [
"loaded = tf.saved_model.load(mobilenet_save_path)\n",
"print(\"MobileNet has {} trainable variables: {}, ...\".format(\n",
" len(loaded.trainable_variables),\n",
" \", \".join([v.name for v in loaded.trainable_variables[:5]])))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "B-mQJ8iP_R0h"
},
"outputs": [],
"source": [
"trainable_variable_ids = {id(v) for v in loaded.trainable_variables}\n",
"non_trainable_variables = [v for v in loaded.variables\n",
" if id(v) not in trainable_variable_ids]\n",
"print(\"MobileNet also has {} non-trainable variables: {}, ...\".format(\n",
" len(non_trainable_variables),\n",
" \", \".join([v.name for v in non_trainable_variables[:3]])))"
]
},
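{
  "cell_type": "markdown",
  "metadata": {},
  "source": [
    "The attributes above can be combined into a fine-tuning step. The following is a minimal sketch, not part of the original example: it assumes `images` and `labels` tensors are defined elsewhere and already match the loaded model's expected preprocessing, that the loaded `__call__` accepts a `training=` argument, and that the model's outputs fit the chosen loss."
  ]
},
{
  "cell_type": "code",
  "execution_count": null,
  "metadata": {},
  "outputs": [],
  "source": [
    "# Illustrative sketch only: `images` and `labels` are assumed to be defined\n",
    "# elsewhere and to match the loaded model's expected inputs and outputs.\n",
    "loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n",
    "fine_tune_optimizer = tf.keras.optimizers.SGD(1e-3)\n",
    "\n",
    "def fine_tune_step(images, labels):\n",
    "  with tf.GradientTape() as tape:\n",
    "    predictions = loaded(images, training=True)  # Forward pass in training mode.\n",
    "    loss = loss_fn(labels, predictions)\n",
    "    # Add the exported weight regularization terms, if any.\n",
    "    loss += tf.add_n([fn() for fn in loaded.regularization_losses] + [tf.constant(0.)])\n",
    "  # Only variables exposed as trainable are updated; frozen variables are left alone.\n",
    "  grads = tape.gradient(loss, loaded.trainable_variables)\n",
    "  fine_tune_optimizer.apply_gradients(zip(grads, loaded.trainable_variables))\n",
    "  return loss"
  ]
},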
{
"cell_type": "markdown",
"metadata": {
"id": "qGlHlbd3_eyO"
},
"source": [
"## Specifying signatures during export\n",
"\n",
"Tools like TensorFlow Serving and `saved_model_cli` can interact with SavedModels. To help these tools determine which ConcreteFunctions to use, you need to specify serving signatures. `tf.keras.Model`s automatically specify serving signatures, but you'll have to explicitly declare a serving signature for our custom modules.\n",
"\n",
"IMPORTANT: Unless you need to export your model to an environment other than TensorFlow 2.x with Python, you probably don't need to export signatures explicitly. If you're looking for a way of enforcing an input signature for a specific function, see the [`input_signature`](https://www.tensorflow.org/api_docs/python/tf/function#args_1) argument to `tf.function`.\n",
"\n",
"By default, no signatures are declared in a custom `tf.Module`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "h-IB5Xa0NxLa"
},
"outputs": [],
"source": [
"assert len(imported.signatures) == 0"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "BiNtaMZSI8Tb"
},
"source": [
"To declare a serving signature, specify a ConcreteFunction using the `signatures` kwarg. When specifying a single signature, its signature key will be `'serving_default'`, which is saved as the constant `tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "_pAdgIORR2yH"
},
"outputs": [],
"source": [
"module_with_signature_path = os.path.join(tmpdir, 'module_with_signature')\n",
"call = module.__call__.get_concrete_function(tf.TensorSpec(None, tf.float32))\n",
"tf.saved_model.save(module, module_with_signature_path, signatures=call)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "nAzRHR0UT4hv"
},
"outputs": [],
"source": [
"imported_with_signatures = tf.saved_model.load(module_with_signature_path)\n",
"list(imported_with_signatures.signatures.keys()) # [\"serving_default\"]"
]
},
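{
  "cell_type": "markdown",
  "metadata": {},
  "source": [
    "Signature functions can also be called directly. They take keyword arguments keyed by input name and return a dictionary of outputs. The cell below is an illustrative sketch that assumes the module's `__call__` takes a single tensor argument named `x`."
  ]
},
{
  "cell_type": "code",
  "execution_count": null,
  "metadata": {},
  "outputs": [],
  "source": [
    "# Illustrative only: assumes the traced `__call__` has one input named `x`.\n",
    "serving_fn = imported_with_signatures.signatures['serving_default']\n",
    "print(serving_fn(x=tf.constant(3.)))  # A dict of outputs, e.g. {'output_0': <tf.Tensor ...>}"
  ]
},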
{
"cell_type": "markdown",
"metadata": {
"id": "_gH91j1IR4tq"
},
"source": [
"To export multiple signatures, pass a dictionary of signature keys to ConcreteFunctions. Each signature key corresponds to one ConcreteFunction."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "6VYAiQmLUiox"
},
"outputs": [],
"source": [
"module_multiple_signatures_path = os.path.join(tmpdir, 'module_with_multiple_signatures')\n",
"signatures = {\"serving_default\": call,\n",
" \"array_input\": module.__call__.get_concrete_function(tf.TensorSpec([None], tf.float32))}\n",
"\n",
"tf.saved_model.save(module, module_multiple_signatures_path, signatures=signatures)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "8IPx_0RWEx07"
},
"outputs": [],
"source": [
"imported_with_multiple_signatures = tf.saved_model.load(\n",
" module_multiple_signatures_path\n",
")\n",
"list(\n",
" imported_with_multiple_signatures.signatures.keys()\n",
") # [\"serving_default\", \"array_input\"]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "43_Qv2W_DJZZ"
},
"source": [
"By default, the output tensor names are fairly generic, like `output_0`. To control the names of outputs, modify your `tf.function` to return a dictionary that maps output names to outputs. The names of inputs are derived from the Python function arg names."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ACKPl1X8G1gw"
},
"outputs": [],
"source": [
"class CustomModuleWithOutputName(tf.Module):\n",
" def __init__(self):\n",
" super(CustomModuleWithOutputName, self).__init__()\n",
" self.v = tf.Variable(1.)\n",
"\n",
" @tf.function(input_signature=[tf.TensorSpec(None, tf.float32)])\n",
" def __call__(self, x):\n",
" return {'custom_output_name': x * self.v}\n",
"\n",
"module_output = CustomModuleWithOutputName()\n",
"call_output = module_output.__call__.get_concrete_function(tf.TensorSpec(None, tf.float32))\n",
"module_output_path = os.path.join(tmpdir, 'module_with_output_name')\n",
"tf.saved_model.save(module_output, module_output_path,\n",
" signatures={'serving_default': call_output})"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1yGVy4MuH-V0"
},
"outputs": [],
"source": [
"imported_with_output_name = tf.saved_model.load(module_output_path)\n",
"imported_with_output_name.signatures[\n",
" 'serving_default'\n",
"].structured_outputs # {'custom_output_name': TensorSpec(shape=\n",
"$ saved_model_cli show --dir /tmp/saved_model_dir --tag_set serve,gpu\n",
"\n",
"\n",
"To show all inputs and outputs TensorInfo for a specific `SignatureDef`, pass in\n",
"the `SignatureDef` key to `signature_def` option. This is very useful when you\n",
"want to know the tensor key value, dtype and shape of the input tensors for\n",
"executing the computation graph later. For example:\n",
"\n",
"```\n",
"$ saved_model_cli show --dir \\\n",
"/tmp/saved_model_dir --tag_set serve --signature_def serving_default\n",
"The given SavedModel SignatureDef contains the following input(s):\n",
" inputs['x'] tensor_info:\n",
" dtype: DT_FLOAT\n",
" shape: (-1, 1)\n",
" name: x:0\n",
"The given SavedModel SignatureDef contains the following output(s):\n",
" outputs['y'] tensor_info:\n",
" dtype: DT_FLOAT\n",
" shape: (-1, 1)\n",
" name: y:0\n",
"Method name is: tensorflow/serving/predict\n",
"```\n",
"\n",
"To show all available information in the SavedModel, use the `--all` option.\n",
"For example:\n",
"\n",
"\n",
"$ saved_model_cli show --dir /tmp/saved_model_dir --all\n",
"MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:\n",
"\n",
"signature_def['classify_x2_to_y3']:\n",
" The given SavedModel SignatureDef contains the following input(s):\n",
" inputs['inputs'] tensor_info:\n",
" dtype: DT_FLOAT\n",
" shape: (-1, 1)\n",
" name: x2:0\n",
" The given SavedModel SignatureDef contains the following output(s):\n",
" outputs['scores'] tensor_info:\n",
" dtype: DT_FLOAT\n",
" shape: (-1, 1)\n",
" name: y3:0\n",
" Method name is: tensorflow/serving/classify\n",
"\n",
"...\n",
"\n",
"signature_def['serving_default']:\n",
" The given SavedModel SignatureDef contains the following input(s):\n",
" inputs['x'] tensor_info:\n",
" dtype: DT_FLOAT\n",
" shape: (-1, 1)\n",
" name: x:0\n",
" The given SavedModel SignatureDef contains the following output(s):\n",
" outputs['y'] tensor_info:\n",
" dtype: DT_FLOAT\n",
" shape: (-1, 1)\n",
" name: y:0\n",
" Method name is: tensorflow/serving/predict\n",
"\n",
"\n",
"\n",
"### `run` command\n",
"\n",
"Invoke the `run` command to run a graph computation, passing\n",
"inputs and then displaying (and optionally saving) the outputs.\n",
"Here's the syntax:\n",
"\n",
"```\n",
"usage: saved_model_cli run [-h] --dir DIR --tag_set TAG_SET --signature_def\n",
" SIGNATURE_DEF_KEY [--inputs INPUTS]\n",
" [--input_exprs INPUT_EXPRS]\n",
" [--input_examples INPUT_EXAMPLES] [--outdir OUTDIR]\n",
" [--overwrite] [--tf_debug]\n",
"```\n",
"\n",
"The `run` command provides the following three ways to pass inputs to the model:\n",
"\n",
"* `--inputs` option enables you to pass numpy ndarray in files.\n",
"* `--input_exprs` option enables you to pass Python expressions.\n",
"* `--input_examples` option enables you to pass `tf.train.Example`.\n",
"\n",
"#### `--inputs`\n",
"\n",
"To pass input data in files, specify the `--inputs` option, which takes the\n",
"following general format:\n",
"\n",
"```bsh\n",
"--inputs