{ "cells": [
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    { "name": "stdout", "output_type": "stream", "text": [ "[36.]\n" ] }
   ],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "tf.reset_default_graph()\n",
    "\n",
    "# define a simple graph\n",
    "a = tf.Variable([3.], dtype=tf.float32, name='a')\n",
    "b = tf.placeholder(tf.float32, shape=(), name='input')\n",
    "c = tf.multiply(a, b)\n",
    "d = tf.multiply(c, c, name='output')\n",
    "init = tf.global_variables_initializer()\n",
    "with tf.Session() as sess:\n",
    "    sess.run(init)\n",
    "    print(sess.run(d, feed_dict={b: 2}))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# The following defines the basic graph, uses a session to compute the value, and saves it with tf.train.Saver"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    { "name": "stdout", "output_type": "stream", "text": [ "[36.]\n" ] }
   ],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "tf.reset_default_graph()\n",
    "\n",
    "# case 1: normal save and restore\n",
    "# define a simple graph\n",
    "a = tf.Variable([3.], dtype=tf.float32, name='a')\n",
    "b = tf.placeholder(tf.float32, shape=(), name='input')\n",
    "c = tf.multiply(a, b, name='output_0')\n",
    "d = tf.multiply(c, c, name='output')\n",
    "init = tf.global_variables_initializer()\n",
    "\n",
    "# the session binds to the global default graph\n",
    "saver = tf.train.Saver()\n",
    "with tf.Session() as sess:\n",
    "    sess.run(init)\n",
    "    print(sess.run(d, feed_dict={b: 2}))\n",
    "    saver.save(sess, './tmp/model.ckpt')\n",
    "# afterwards the directory ./tmp contains, among others, these two files:\n",
    "# model.ckpt.meta                : the graph definition\n",
    "# model.ckpt.data-00000-of-00001 : the data (the values of the variables)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# The first way to restore the model (define a new set of variables)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Restoring parameters from ./tmp/model.ckpt\n", "[81.]\n" ] }
   ],
   "source": [
    "# There are two ways to restore the graph.\n",
    "\n",
    "tf.reset_default_graph()\n",
    "\n",
    "# This cell shows the first one: define a graph identical to the original\n",
    "# and load the values from model.ckpt.data back into this new graph.\n",
    "# The names 'a', 'input', and 'output' must match the original graph.\n",
    "j = tf.Variable([3.], dtype=tf.float32, name='a')\n",
    "k = tf.placeholder(tf.float32, shape=(), name='input')\n",
    "l = tf.multiply(j, k)\n",
    "m = tf.multiply(l, l, name='output')\n",
    "init = tf.global_variables_initializer()\n",
    "\n",
    "saver = tf.train.Saver()\n",
    "with tf.Session() as sess:\n",
    "    sess.run(init)\n",
    "    saver.restore(sess, './tmp/model.ckpt')\n",
    "    # now we can feed new samples to the placeholder!\n",
    "    # the result of the following should be (3*3) * (3*3) = 81\n",
    "    print(sess.run(m, feed_dict={k: 3}))\n",
    "\n",
    "# This example may look trivial, but the approach is useful for re-training the\n",
    "# same graph. Since the inference model is usually defined as a function, we do\n",
    "# not need to recreate every node by hand as in this example.\n",
    "\n",
    "# something like (see the sketch below):\n",
    "# logits = inference(X)\n",
    "# sess.run(logits, feed_dict={X: 123})"
   ]
  },
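  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The comment above mentions wrapping the model definition in a function. The next cell is a minimal sketch of that pattern, added for illustration: a hypothetical inference() builder is called once when saving and again when restoring, so the node names stay consistent automatically. The names inference, X, w, logits and the checkpoint path ./tmp/func_model.ckpt are made up for this sketch and are not part of the model saved above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Minimal sketch of the 'model as a function' pattern (illustrative only:\n",
    "# inference(), X, w, logits and ./tmp/func_model.ckpt are made-up names).\n",
    "import tensorflow as tf\n",
    "\n",
    "def inference(x):\n",
    "    # building the model inside a function means calling it again later\n",
    "    # recreates nodes with exactly the same names\n",
    "    w = tf.Variable([3.], dtype=tf.float32, name='w')\n",
    "    return tf.multiply(w, x, name='logits')\n",
    "\n",
    "# build, initialize and save\n",
    "tf.reset_default_graph()\n",
    "X = tf.placeholder(tf.float32, shape=(), name='X')\n",
    "logits = inference(X)\n",
    "saver = tf.train.Saver()\n",
    "with tf.Session() as sess:\n",
    "    sess.run(tf.global_variables_initializer())\n",
    "    saver.save(sess, './tmp/func_model.ckpt')\n",
    "\n",
    "# restore: call the same function again, then load the checkpoint\n",
    "tf.reset_default_graph()\n",
    "X = tf.placeholder(tf.float32, shape=(), name='X')\n",
    "logits = inference(X)\n",
    "saver = tf.train.Saver()\n",
    "with tf.Session() as sess:\n",
    "    saver.restore(sess, './tmp/func_model.ckpt')\n",
    "    print(sess.run(logits, feed_dict={X: 123}))"
   ]
  },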
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# The second way to restore the model (load the graph definition from the meta file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Restoring parameters from ./tmp/model.ckpt\n", "[81.]\n" ] }
   ],
   "source": [
    "tf.reset_default_graph()\n",
    "\n",
    "# This cell introduces another way to restore the model.\n",
    "# Instead of defining a new set of nodes (j, k, l, m), we load the previously\n",
    "# saved graph into the current session and restore the values into its nodes.\n",
    "\n",
    "# import the meta file, which contains the graph definition\n",
    "tf.train.import_meta_graph('./tmp/model.ckpt.meta')\n",
    "saver = tf.train.Saver()\n",
    "with tf.Session() as sess:\n",
    "    # restore all the values into the graph\n",
    "    saver.restore(sess, './tmp/model.ckpt')\n",
    "    # before calling sess.run we need to get the input and output tensors;\n",
    "    # note the ':0' suffix, which TensorFlow adds to refer to the first output of an op\n",
    "    _input = sess.graph.get_tensor_by_name(\"input:0\")\n",
    "    _output = sess.graph.get_tensor_by_name(\"output:0\")\n",
    "    # then we can run!\n",
    "    print(sess.run(_output, feed_dict={_input: 3}))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Freeze the model and strip out the nodes irrelevant to inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": { "scrolled": true },
   "outputs": [
    { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Restoring parameters from ./tmp/model.ckpt\n", "INFO:tensorflow:Froze 1 variables.\n", "Converted 1 variables to const ops.\n" ] }
   ],
   "source": [
    "# The two approaches above are the common flow for training and re-training a model,\n",
    "# but for deployment we usually want to freeze the model (convert the variables\n",
    "# to constants) and strip out all the training-related nodes (optimizer, minimize, ...).\n",
    "\n",
    "import tensorflow as tf\n",
    "\n",
    "tf.reset_default_graph()\n",
    "\n",
    "tf.train.import_meta_graph('./tmp/model.ckpt.meta')\n",
    "\n",
    "saver = tf.train.Saver()\n",
    "with tf.Session() as sess:\n",
    "    # as before, import the meta graph and restore the weights first\n",
    "    saver.restore(sess, './tmp/model.ckpt')\n",
    "    output_graph_def = tf.graph_util.convert_variables_to_constants(\n",
    "        sess, sess.graph_def,\n",
    "        ['output_0']  # output node list; only output_0 and the nodes it depends on are kept\n",
    "    )\n",
    "    # write the frozen graph to pb files\n",
    "    tf.train.write_graph(output_graph_def, \"./tmp\", \"output_text.pb\", True)\n",
    "    tf.train.write_graph(output_graph_def, \"./tmp\", \"output_binary.pb\", False)\n",
    "# now look under ./tmp: there are two new files,\n",
    "# a human-readable text file and a binary file"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Notice that we do not call saver.save(sess, './tmp/model.ckpt') here\n",
    "Since the variables have been converted to constants embedded in the graph itself, there is no need to store their values in model.ckpt.data."
   ]
  },
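  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see exactly what survived the stripping, the next cell is a small verification sketch added here (it assumes ./tmp/output_binary.pb was written by the freezing cell above): it prints the op type and name of every node in the frozen GraphDef. The variable a now shows up as a Const, and the un-requested output node is gone."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Small verification sketch: list the nodes that survived freezing\n",
    "# (assumes ./tmp/output_binary.pb was written by the cell above).\n",
    "import tensorflow as tf\n",
    "from tensorflow.python.platform import gfile\n",
    "\n",
    "graph_def = tf.GraphDef()\n",
    "with gfile.FastGFile('./tmp/output_binary.pb', 'rb') as f:\n",
    "    graph_def.ParseFromString(f.read())\n",
    "\n",
    "# each NodeDef carries its op type and name; 'a' should appear as a Const\n",
    "# and the stripped 'output' node should not appear at all\n",
    "for node in graph_def.node:\n",
    "    print(node.op, node.name)"
   ]
  },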
"execution_count": 50, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[27.]\n" ] } ], "source": [ "import tensorflow as tf\n", "from tensorflow.python.platform import gfile\n", "\n", "with tf.Session() as sess:\n", " # we now want to load the graph_def from the pb file\n", " with gfile.FastGFile('./tmp/output_binary.pb', 'rb') as f:\n", " graph_def = tf.GraphDef()\n", " graph_def.ParseFromString(f.read())\n", " tf.import_graph_def(graph_def, name='')\n", " \n", " # get the input node\n", " _input = sess.graph.get_tensor_by_name('input:0')\n", " _output = sess.graph.get_tensor_by_name('output_0:0')\n", " sess.run(tf.global_variables_initializer())\n", " print (sess.run(_output, feed_dict={_input: 9}))\n", " # if you try to access the output node as before you will get exception\n", " output = sess.graph.get_tensor_by_name('output:0')\n", " sess.run(output, feed)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[-15.0248]]\n", "INFO:tensorflow:Restoring parameters from ./tmp/model.ckpt\n", "INFO:tensorflow:Froze 1 variables.\n", "INFO:tensorflow:Converted 1 variables to const ops.\n" ] }, { "data": { "text/plain": [ "564" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import tensorflow as tf\n", "\n", "tf.reset_default_graph()\n", "\n", "# case1 normal save and restore\n", "# define simple graphs\n", "a = tf.get_variable('a', [1,1], dtype=tf.float32)\n", "b = tf.placeholder(tf.float32, shape=(1,1), name='input')\n", "c = tf.matmul(a, b, name='output_0')\n", "init = tf.global_variables_initializer()\n", "\n", "# session will bind to the global default graph\n", "saver = tf.train.Saver()\n", "with tf.Session() as sess:\n", " sess.run(init)\n", " print(sess.run(c, feed_dict={b:[[9]]}))\n", " saver.save(sess, './tmp/model.ckpt')\n", "\n", "# Frozen the graph \n", "tf.reset_default_graph()\n", "tf.train.import_meta_graph('./tmp/model.ckpt.meta')\n", "\n", "saver = tf.train.Saver()\n", "with tf.Session() as sess:\n", " # like before we import meta graph and restore the weights first\n", " saver.restore(sess, './tmp/model.ckpt')\n", " output_graph_def = tf.graph_util.convert_variables_to_constants(\n", " sess, sess.graph_def,\n", " ['output_0'] # this is the output node list, here we only need the output_0\n", " )\n", " # we can write it to the output pb file\n", " tf.train.write_graph(output_graph_def, \"./tmp\", \"output_text.pb\", True)\n", " tf.train.write_graph(output_graph_def, \"./tmp\", \"output_binary.pb\", False) \n", "\n", "# Transform to TF-Lite\n", "graph_def_file = \"./tmp/output_binary.pb\"\n", "input_arrays = [\"input\"]\n", "output_arrays = [\"output_0\"]\n", "\n", "tf.reset_default_graph()\n", "\n", "converter = tf.contrib.lite.TocoConverter.from_frozen_graph(graph_def_file,\n", " input_arrays,\n", " output_arrays)\n", "tflite_model = converter.convert()\n", "open(\"converterd_model.tflite\", \"wb\").write(tflite_model)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.5" } }, "nbformat": 4, "nbformat_minor": 2 }