{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "uAttKaKmT435" }, "source": [ "\u003cdiv class=\"devsite-table-wrapper\"\u003e\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n", "\u003ctd\u003e\u003ca target=\"_blank\" href=\"https://www.tensorflow.org/tfx/tutorials/transform/census\"\u003e\n", "\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\u003c/td\u003e\n", "\u003ctd\u003e\u003ca target=\"_blank\" href=\"https://colab.sandbox.google.com/github/tensorflow/tfx/blob/master/docs/tutorials/transform/census.ipynb\"\u003e\n", "\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\"\u003eRun in Google Colab\u003c/a\u003e\u003c/td\u003e\n", "\u003ctd\u003e\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tfx/blob/master/docs/tutorials/transform/census.ipynb\"\u003e\n", "\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\"\u003eView source on GitHub\u003c/a\u003e\u003c/td\u003e\n", "\u003ctd\u003e\u003ca target=\"_blank\" href=\"https://storage.googleapis.com/tensorflow_docs/tfx/docs/tutorials/transform/census.ipynb\"\u003e\n", "\u003cimg width=32px src=\"https://www.tensorflow.org/images/download_logo_32px.png\"\u003eDownload notebook\u003c/a\u003e\u003c/td\u003e\n", "\u003c/table\u003e\u003c/div\u003e" ] }, { "cell_type": "markdown", "metadata": { "id": "tghWegsjhpkt" }, "source": [ "##### Copyright 2020 The TensorFlow Authors." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "rSGJWC5biBiG" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "mPt5BHTwy_0F" }, "source": [ "# Preprocessing data with TensorFlow Transform\n", "***The Feature Engineering Component of TensorFlow Extended (TFX)***\n", "\n", "This example colab notebook provides a somewhat more advanced example of how \u003ca target='_blank' href='https://www.tensorflow.org/tfx/transform/'\u003eTensorFlow Transform\u003c/a\u003e (`tf.Transform`) can be used to preprocess data using exactly the same code for both training a model and serving inferences in production.\n", "\n", "TensorFlow Transform is a library for preprocessing input data for TensorFlow, including creating features that require a full pass over the training dataset. For example, using TensorFlow Transform you could:\n", "\n", "* Normalize an input value by using the mean and standard deviation\n", "* Convert strings to integers by generating a vocabulary over all of the input values\n", "* Convert floats to integers by assigning them to buckets, based on the observed data distribution\n", "\n", "TensorFlow has built-in support for manipulations on a single example or a batch of examples. 
`tf.Transform` extends these capabilities to support full passes over the entire training dataset.\n", "\n", "The output of `tf.Transform` is exported as a TensorFlow graph which you can use for both training and serving. Using the same graph for both training and serving can prevent skew, since the same transformations are applied in both stages.\n", "\n", "Key Point: In order to understand `tf.Transform` and how it works with Apache Beam, you'll need to know a little bit about Apache Beam itself. The \u003ca target='_blank' href='https://beam.apache.org/documentation/programming-guide/'\u003eBeam Programming Guide\u003c/a\u003e is a great place to start." ] }, { "cell_type": "markdown", "metadata": { "id": "_tQUubddMvnP" }, "source": [ "## What we're doing in this example\n", "\n", "In this example we'll be processing a \u003ca target='_blank' href='https://archive.ics.uci.edu/ml/machine-learning-databases/adult'\u003ewidely used dataset containing census data\u003c/a\u003e, and training a model to do classification. Along the way we'll be transforming the data using `tf.Transform`.\n", "\n", "Key Point: As a modeler and developer, think about how this data is used and the potential benefits and harm a model's predictions can cause. A model like this could reinforce societal biases and disparities. Is a feature relevant to the problem you want to solve or will it introduce bias? For more information, read about \u003ca target='_blank' href='https://developers.google.com/machine-learning/fairness-overview/'\u003eML fairness\u003c/a\u003e.\n", "\n", "Note: \u003ca target='_blank' href='https://www.tensorflow.org/tfx/model_analysis'\u003eTensorFlow Model Analysis\u003c/a\u003e is a powerful tool for understanding how well your model predicts for various segments of your data, including understanding how your model may reinforce societal biases and disparities." 
] }, { "cell_type": "markdown", "metadata": { "id": "OeonII4omTr1" }, "source": [ "### Install TensorFlow Transform\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "9Ak6XDO5mT3m" }, "outputs": [], "source": [ "!pip install tensorflow-transform" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "R0mXLOJR_-dv" }, "outputs": [], "source": [ "# This cell is only necessary because packages were installed while python was\n", "# running. It avoids the need to restart the runtime when running in Colab.\n", "import pkg_resources\n", "import importlib\n", "\n", "importlib.reload(pkg_resources)" ] }, { "cell_type": "markdown", "metadata": { "id": "RptgLn2RYuK3" }, "source": [ "## Imports and globals\n", "\n", "First import the stuff we need." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "K4QXVIM7iglN" }, "outputs": [], "source": [ "import math\n", "import os\n", "import pprint\n", "\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "\n", "import tensorflow as tf\n", "print('TF: {}'.format(tf.__version__))\n", "\n", "import apache_beam as beam\n", "print('Beam: {}'.format(beam.__version__))\n", "\n", "import tensorflow_transform as tft\n", "import tensorflow_transform.beam as tft_beam\n", "print('Transform: {}'.format(tft.__version__))\n", "\n", "from tfx_bsl.public import tfxio\n", "from tfx_bsl.coders.example_coder import RecordBatchToExamples" ] }, { "cell_type": "markdown", "metadata": { "id": "sutRmRNSGT5p" }, "source": [ "Next download the data files:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mKEYRl2g_vzl" }, "outputs": [], "source": [ "!wget https://storage.googleapis.com/artifacts.tfx-oss-public.appspot.com/datasets/census/adult.data\n", "!wget https://storage.googleapis.com/artifacts.tfx-oss-public.appspot.com/datasets/census/adult.test\n", "\n", "train_path = './adult.data'\n", "test_path = './adult.test'" ] }, { "cell_type": "markdown", 
"metadata": { "id": "CxOxaaOYRfl7" }, "source": [ "### Name our columns\n", "We'll create some handy lists for referencing the columns in our dataset." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-bsr1nLHqyg_" }, "outputs": [], "source": [ "CATEGORICAL_FEATURE_KEYS = [\n", " 'workclass',\n", " 'education',\n", " 'marital-status',\n", " 'occupation',\n", " 'relationship',\n", " 'race',\n", " 'sex',\n", " 'native-country',\n", "]\n", "\n", "NUMERIC_FEATURE_KEYS = [\n", " 'age',\n", " 'capital-gain',\n", " 'capital-loss',\n", " 'hours-per-week',\n", " 'education-num'\n", "]\n", "\n", "ORDERED_CSV_COLUMNS = [\n", " 'age', 'workclass', 'fnlwgt', 'education', 'education-num',\n", " 'marital-status', 'occupation', 'relationship', 'race', 'sex',\n", " 'capital-gain', 'capital-loss', 'hours-per-week', 'native-country', 'label'\n", "]\n", "\n", "LABEL_KEY = 'label'" ] }, { "cell_type": "markdown", "metadata": { "id": "R52dXlw0G0CN" }, "source": [ "Here's a quick preview of the data:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "312cQ5vwGjOu" }, "outputs": [], "source": [ "pandas_train = pd.read_csv(train_path, header=None, names=ORDERED_CSV_COLUMNS)\n", "\n", "pandas_train.head(5)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "zzjzjR3351j0" }, "outputs": [], "source": [ "one_row = dict(pandas_train.loc[0])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "zk2b8IPd4uPr" }, "outputs": [], "source": [ "COLUMN_DEFAULTS = [\n", " '' if isinstance(v, str) else 0.0\n", " for v in dict(pandas_train.loc[1]).values()]" ] }, { "cell_type": "markdown", "metadata": { "id": "LefAguV5ICMc" }, "source": [ "The test data has 1 header line that needs to be skipped, and a trailing \".\" at the end of each line." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "RasgDIUKHCpV" }, "outputs": [], "source": [ "pandas_test = pd.read_csv(test_path, skiprows=1, header=None, names=ORDERED_CSV_COLUMNS)\n", "\n", "pandas_test.head(5)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "s9aH5ZnDdD_z" }, "outputs": [], "source": [ "testing = os.getenv(\"WEB_TEST_BROWSER\", False)\n", "if testing:\n", "  pandas_train = pandas_train.loc[:1]\n", "  pandas_test = pandas_test.loc[:1]" ] }, { "cell_type": "markdown", "metadata": { "id": "qtTn4at8rurk" }, "source": [ "### Define our features and schema\n", "Let's define a schema based on the types of the columns in our input. Among other things, this will help with importing them correctly." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "5oS2RfyCrzMr" }, "outputs": [], "source": [ "RAW_DATA_FEATURE_SPEC = dict(\n", "    [(name, tf.io.FixedLenFeature([], tf.string))\n", "     for name in CATEGORICAL_FEATURE_KEYS] +\n", "    [(name, tf.io.FixedLenFeature([], tf.float32))\n", "     for name in NUMERIC_FEATURE_KEYS] +\n", "    [(LABEL_KEY, tf.io.FixedLenFeature([], tf.string))]\n", ")\n", "\n", "SCHEMA = tft.tf_metadata.dataset_metadata.DatasetMetadata(\n", "    tft.tf_metadata.schema_utils.schema_from_feature_spec(RAW_DATA_FEATURE_SPEC)).schema" ] }, { "cell_type": "markdown", "metadata": { "id": "_j6M7ObpaLHi" }, "source": [ "### [Optional] Encode and decode tf.train.Example protos" ] }, { "cell_type": "markdown", "metadata": { "id": "rgGO9-GkZ5Kv" }, "source": [ "This tutorial needs to convert examples from the dataset to and from `tf.train.Example` protos in a few places.\n", "\n", "The hidden `encode_example` function below converts a dictionary of features from the dataset to a `tf.train.Example`." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Wbhndy7uWqYp" }, "outputs": [], "source": [ "#@title\n", "def encode_example(input_features):\n", "  input_features = dict(input_features)\n", "  output_features = {}\n", "\n", "  for key in CATEGORICAL_FEATURE_KEYS:\n", "    value = input_features[key]\n", "    feature = tf.train.Feature(\n", "        bytes_list=tf.train.BytesList(value=[value.strip().encode()]))\n", "    output_features[key] = feature\n", "\n", "  for key in NUMERIC_FEATURE_KEYS:\n", "    value = input_features[key]\n", "    feature = tf.train.Feature(\n", "        float_list=tf.train.FloatList(value=[float(value)]))\n", "    output_features[key] = feature\n", "\n", "  label_value = input_features.get(LABEL_KEY, None)\n", "  if label_value is not None:\n", "    output_features[LABEL_KEY] = tf.train.Feature(\n", "        bytes_list=tf.train.BytesList(value=[label_value.strip().encode()]))\n", "\n", "  example = tf.train.Example(\n", "      features=tf.train.Features(feature=output_features)\n", "  )\n", "  return example" ] }, { "cell_type": "markdown", "metadata": { "id": "4qx7fSVmmwIQ" }, "source": [ "Now you can convert dataset examples into `Example` protos:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sWd95yxJceXy" }, "outputs": [], "source": [ "tf_example = encode_example(pandas_train.loc[0])\n", "tf_example.features.feature['age']" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "EutF2aPXbAUd" }, "outputs": [], "source": [ "serialized_example_batch = tf.constant([\n", "    encode_example(pandas_train.loc[i]).SerializeToString()\n", "    for i in range(3)\n", "])\n", "\n", "serialized_example_batch" ] }, { "cell_type": "markdown", "metadata": { "id": "DTqlJcI_m6az" }, "source": [ "You can also convert batches of serialized Example protos back into a dictionary of tensors:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jXlrur1vc4n_" }, 
"outputs": [], "source": [ "decoded_tensors = 
tf.io.parse_example(\n", "    serialized_example_batch,\n", "    features=RAW_DATA_FEATURE_SPEC\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "eUAcdCrEdDe3" }, "source": [ "In some cases the label will not be passed in, so the encode function is written so that the label is optional:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "EEt3nPr_o59f" }, "outputs": [], "source": [ "features_dict = dict(pandas_train.loc[0])\n", "features_dict.pop(LABEL_KEY)\n", "\n", "LABEL_KEY in features_dict" ] }, { "cell_type": "markdown", "metadata": { "id": "O0yqvsHtpDdX" }, "source": [ "The resulting `Example` proto then simply does not contain the label key." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "7N5FMXO7dRzM" }, "outputs": [], "source": [ "no_label_example = encode_example(features_dict)\n", "\n", "LABEL_KEY in no_label_example.features.feature.keys()" ] }, { "cell_type": "markdown", "metadata": { "id": "zdXy9lo4t45d" }, "source": [ "### Setting hyperparameters and basic housekeeping\n", "\n", "Constants and hyperparameters used for training." 
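] }, { "cell_type": "markdown", "metadata": { "id": "stepsArithmeticMd" }, "source": [ "The next cell splits each pass over the training data into `EPOCH_SPLITS` shorter Keras epochs of `STEPS_PER_TRAIN_EPOCH` batches each. As a result, `TRAIN_NUM_EPOCHS = 2*EPOCH_SPLITS` amounts to roughly two passes over the data. The worked example below assumes that `adult.data` has 32,561 rows, which is the size usually reported for the UCI Adult training split. The variable names in this sketch are hypothetical, chosen so that they do not overwrite the real constants:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "stepsArithmeticSketch" }, "outputs": [], "source": [ "import math\n", "\n", "example_num_train = 32561  # assumed number of rows in adult.data\n", "example_batch_size = 128\n", "example_epoch_splits = 10\n", "\n", "# Batches per (shortened) Keras epoch, rounded up as in the next cell.\n", "steps_per_split = math.ceil(\n", "    example_num_train / example_batch_size / example_epoch_splits)\n", "\n", "# Examples seen over 2 * EPOCH_SPLITS epochs, and the equivalent number of\n", "# full passes over the training data.\n", "examples_seen = steps_per_split * example_batch_size * 2 * example_epoch_splits\n", "full_passes = examples_seen / example_num_train\n", "\n", "steps_per_split, examples_seen, round(full_passes, 2)" 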
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8WHyOkC9uL71" }, "outputs": [], "source": [ "NUM_OOV_BUCKETS = 1\n", "\n", "EPOCH_SPLITS = 10\n", "TRAIN_NUM_EPOCHS = 2*EPOCH_SPLITS\n", "NUM_TRAIN_INSTANCES = len(pandas_train)\n", "NUM_TEST_INSTANCES = len(pandas_test)\n", "\n", "BATCH_SIZE = 128\n", "\n", "STEPS_PER_TRAIN_EPOCH = math.ceil(NUM_TRAIN_INSTANCES/BATCH_SIZE/EPOCH_SPLITS)\n", "EVALUATION_STEPS = math.ceil(NUM_TEST_INSTANCES/BATCH_SIZE)\n", "\n", "# Names of temp files\n", "TRANSFORMED_TRAIN_DATA_FILEBASE = 'train_transformed'\n", "TRANSFORMED_TEST_DATA_FILEBASE = 'test_transformed'\n", "EXPORTED_MODEL_DIR = 'exported_model_dir'" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "lG2uO-88c6R9" }, "outputs": [], "source": [ "if testing:\n", "  TRAIN_NUM_EPOCHS = 1" ] }, { "cell_type": "markdown", "metadata": { "id": "0a1ns5KswDb2" }, "source": [ "## Preprocessing with `tf.Transform`" ] }, { "cell_type": "markdown", "metadata": { "id": "KKd3mCLNVYmg" }, "source": [ "### Create a `tf.Transform` preprocessing_fn\n", "The _preprocessing function_ is the most important concept of tf.Transform. A preprocessing function is where the transformation of the dataset really happens. It accepts and returns a dictionary of tensors, where a tensor means a [`Tensor`](https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/Tensor) or [`SparseTensor`](https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/SparseTensor). There are two main groups of API calls that typically form the heart of a preprocessing function:\n", "\n", "1. **TensorFlow Ops:** Any function that accepts and returns tensors, which usually means TensorFlow ops. These add TensorFlow operations to the graph that transforms raw data into transformed data one feature vector at a time. These will run for every example, during both training and serving.\n", "2. 
**TensorFlow Transform Analyzers/Mappers:** Any of the analyzers/mappers provided by tf.Transform. These also accept and return tensors, and they typically combine TensorFlow ops with Beam computation. Unlike plain TensorFlow ops, their Beam computation runs only once, in the Beam pipeline during analysis and prior to training. During that run it typically makes a full pass over the entire training dataset. They create `tf.constant` tensors, which are added to your graph. For example, `tft.min` computes the minimum of a tensor over the training dataset.\n", "\n", "Caution: When you apply your preprocessing function to serving inferences, the constants that were created by analyzers during training do not change. If your data has trend or seasonality components, plan accordingly.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "DZopPfpaH4sB" }, "source": [ "Here is a `preprocessing_fn` for this dataset. It does several things:\n", "\n", "1. Using `tft.scale_to_0_1`, it scales the numeric features to the `[0,1]` range.\n", "2. Using `tft.compute_and_apply_vocabulary`, it computes a vocabulary for each of the categorical features, and returns the integer IDs for each input as a `tf.int64`. This applies to both string and integer categorical inputs.\n", "3. It applies some manual transformations to the data using standard TensorFlow operations. Here these operations are applied to the label but could transform the features as well. The TensorFlow operations do several things:\n", "   * They build a lookup table for the label (the `tf.init_scope` ensures that the table is only created the first time the function is called).\n", "   * They normalize the text of the label.\n", "   * They convert the label to a one-hot. 
\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "LDrzuYH0WFc2" }, "outputs": [], "source": [ "def preprocessing_fn(inputs):\n", "  \"\"\"Preprocess input columns into transformed columns.\"\"\"\n", "  # Since we are modifying some features and leaving others unchanged, we\n", "  # start by setting `outputs` to a copy of `inputs`.\n", "  outputs = inputs.copy()\n", "\n", "  # Scale numeric columns to have range [0, 1].\n", "  for key in NUMERIC_FEATURE_KEYS:\n", "    outputs[key] = tft.scale_to_0_1(inputs[key])\n", "\n", "  # For each categorical feature (the label is handled separately below),\n", "  # compute a vocabulary over the training data and replace each stripped\n", "  # string with its integer id. Values not seen during analysis map to one of\n", "  # the NUM_OOV_BUCKETS out-of-vocabulary ids.\n", "  for key in CATEGORICAL_FEATURE_KEYS:\n", "    outputs[key] = tft.compute_and_apply_vocabulary(\n", "        tf.strings.strip(inputs[key]),\n", "        num_oov_buckets=NUM_OOV_BUCKETS,\n", "        vocab_filename=key)\n", "\n", "  # For the label column we provide the mapping from string to index.\n", "  table_keys = ['\u003e50K', '\u003c=50K']\n", "  with tf.init_scope():\n", "    initializer = tf.lookup.KeyValueTensorInitializer(\n", "        keys=table_keys,\n", "        values=tf.cast(tf.range(len(table_keys)), tf.int64),\n", "        key_dtype=tf.string,\n", "        value_dtype=tf.int64)\n", "    table = tf.lookup.StaticHashTable(initializer, default_value=-1)\n", "\n", "  # Labels in the test file end with a trailing period (e.g. ' \u003c=50K.'), so\n", "  # remove it, then strip surrounding whitespace before the table lookup.\n", "  label_str = inputs[LABEL_KEY]\n", "  label_str = tf.strings.regex_replace(label_str, r'\\.$', '')\n", "  label_str = tf.strings.strip(label_str)\n", "  data_labels = table.lookup(label_str)\n", "  transformed_label = tf.one_hot(\n", "      indices=data_labels, depth=len(table_keys), on_value=1.0, off_value=0.0)\n", "  
outputs[LABEL_KEY] = tf.reshape(transformed_label, [-1, len(table_keys)])\n", "\n", "  return outputs" 
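] }, { "cell_type": "markdown", "metadata": { "id": "analyzerSketchMd" }, "source": [ "The next cell is a plain-Python sketch that makes the split between analysis and transformation concrete. It is not the tf.Transform implementation. It shows what `tft.scale_to_0_1` and `tft.compute_and_apply_vocabulary` conceptually do: one full pass over the training rows computes constants (a min/max pair per numeric feature and a vocabulary per categorical feature), and the per-example transform then reads only those constants. The function names, the toy rows, and the alphabetical tie-breaking are assumptions made for illustration. tf.Transform orders vocabularies by decreasing frequency, but its tie-breaking and its handling of a feature with a single distinct value may differ from this sketch." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "analyzerSketch" }, "outputs": [], "source": [ "from collections import Counter\n", "\n", "def sketch_analyze(rows, numeric_keys, categorical_keys):\n", "  # Full pass over the training rows: compute the constants once.\n", "  min_max = {}\n", "  for key in numeric_keys:\n", "    values = [row[key] for row in rows]\n", "    min_max[key] = (min(values), max(values))\n", "  vocabs = {}\n", "  for key in categorical_keys:\n", "    counts = Counter(row[key].strip() for row in rows)\n", "    # Most frequent token first; ties broken alphabetically (an assumption).\n", "    ordered = sorted(counts, key=lambda token: (-counts[token], token))\n", "    vocabs[key] = {token: index for index, token in enumerate(ordered)}\n", "  return min_max, vocabs\n", "\n", "def sketch_transform(row, min_max, vocabs):\n", "  # Per-example step: uses only the constants computed by sketch_analyze.\n", "  out = {}\n", "  for key, (lo, hi) in min_max.items():\n", "    out[key] = (row[key] - lo) / (hi - lo) if hi \u003e lo else 0.0\n", "  for key, vocab in vocabs.items():\n", "    # With a single OOV bucket, every unseen token maps to id len(vocab).\n", 
"    out[key] = vocab.get(row[key].strip(), len(vocab))\n", "  return out\n", "\n", "toy_rows = [\n", "    {'age': 20.0, 'workclass': ' Private'},\n", "    {'age': 40.0, 'workclass': ' Private'},\n", "    {'age': 60.0, 'workclass': ' State-gov'},\n", "]\n", "toy_min_max, toy_vocabs = sketch_analyze(toy_rows, ['age'], ['workclass'])\n", "\n", "# 'age' becomes (30 - 20) / (60 - 20); the unseen workclass gets the OOV id.\n", "sketch_transform({'age': 30.0, 'workclass': ' Unseen-value'}, toy_min_max, toy_vocabs)" 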
] }, { "cell_type": "markdown", "metadata": { "id": "sA1Eg2JXFzzZ" }, "source": [ "## Syntax\n", "\n", "You're almost ready to put everything together and use \u003ca target='_blank' href='https://beam.apache.org/'\u003eApache Beam\u003c/a\u003e to run it.\n", "\n", "Apache Beam uses a \u003ca target='_blank' href='https://beam.apache.org/documentation/programming-guide/#applying-transforms'\u003especial syntax to define and invoke transforms\u003c/a\u003e. For example, in this line:\n", "\n", "```\n", "result = pass_this | 'name this step' \u003e\u003e to_this_call\n", "```\n", "\n", "The method `to_this_call` is being invoked and passed the object called `pass_this`, and \u003ca target='_blank' href='https://stackoverflow.com/questions/50519662/what-does-the-redirection-mean-in-apache-beam-python'\u003ethis operation will be referred to as `name this step` in a stack trace\u003c/a\u003e. The result of the call to `to_this_call` is returned in `result`. You will often see stages of a pipeline chained together like this:\n", "\n", "```\n", "result = apache_beam.Pipeline() | 'first step' \u003e\u003e do_this_first() | 'second step' \u003e\u003e do_this_last()\n", "```\n", "\n", "and since that started with a new pipeline, you can continue like this:\n", "\n", "```\n", "next_result = result | 'doing more stuff' \u003e\u003e another_function()\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "rgAGOAdFWRn2" }, "source": [ "### Transform the data\n", "\n", "Now we're ready to start transforming our data in an Apache Beam pipeline.\n", "\n", "1. Read in the data using the `tfxio.CsvTFXIO` CSV reader (to process lines of text in a pipeline use `tfxio.BeamRecordCsvTFXIO` instead).\n", "1. Analyse and transform the data using the `preprocessing_fn` defined above.\n", "1. 
Write out the result as a `TFRecord` of `Example` protos, which we will use for training a model later.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "PCeYucVoRRfo" }, "outputs": [], "source": [ "import tempfile\n", "\n", "def transform_data(train_data_file, test_data_file, working_dir):\n", "  \"\"\"Transform the data and write out as a TFRecord of Example protos.\n", "\n", "  Read in the data using the CSV reader, and transform it using a\n", "  preprocessing pipeline that scales numeric data and converts categorical data\n", "  from strings to int64 indices, by creating a vocabulary for each\n", "  category.\n", "\n", "  Args:\n", "    train_data_file: File containing training data\n", "    test_data_file: File containing test data\n", "    working_dir: Directory to write transformed data and metadata to\n", "  \"\"\"\n", "\n", "  # The \"with\" block will create a pipeline, and run that pipeline at the exit\n", "  # of the block.\n", "  with beam.Pipeline() as pipeline:\n", "    with tft_beam.Context(temp_dir=tempfile.mkdtemp()):\n", "      # Create a TFXIO to read the census data with the schema. To do this we\n", "      # need to list all columns in order since the schema doesn't specify the\n", "      # order of columns in the csv.\n", "      # tfxio.CsvTFXIO both reads the CSV files and parses them into TFT\n", "      # inputs. If the raw lines needed patching first, you could instead read\n", "      # them as text and parse the resulting PCollection[bytes] with\n", "      # tfxio.BeamRecordCsvTFXIO.\n", "      train_csv_tfxio = tfxio.CsvTFXIO(\n", "          file_pattern=train_data_file,\n", "          telemetry_descriptors=[],\n", "          column_names=ORDERED_CSV_COLUMNS,\n", "          schema=SCHEMA)\n", "\n", "      # Read in raw data and convert using CSV TFXIO.\n", "      raw_data = (\n", "          pipeline |\n", "          'ReadTrainCsv' \u003e\u003e train_csv_tfxio.BeamSource())\n", "\n", "      # Combine data and schema into a dataset tuple. 
Note that we already used\n", "      # the schema to read the CSV data, but we also need it to interpret\n", "      # raw_data.\n", "      cfg = train_csv_tfxio.TensorAdapterConfig()\n", "      raw_dataset = (raw_data, cfg)\n", "\n", "      # The TFXIO output format is chosen for improved performance.\n", "      transformed_dataset, transform_fn = (\n", "          raw_dataset | tft_beam.AnalyzeAndTransformDataset(\n", "              preprocessing_fn, output_record_batches=True))\n", "\n", "      # Transformed metadata is not necessary for encoding.\n", "      transformed_data, _ = transformed_dataset\n", "\n", "      # Extract transformed RecordBatches, encode and write them to the given\n", "      # directory.\n", "      # TODO(b/223384488): Switch to `RecordBatchToExamplesEncoder`.\n", "      _ = (\n", "          transformed_data\n", "          | 'EncodeTrainData' \u003e\u003e\n", "          beam.FlatMapTuple(lambda batch, _: RecordBatchToExamples(batch))\n", "          | 'WriteTrainData' \u003e\u003e beam.io.WriteToTFRecord(\n", "              os.path.join(working_dir, TRANSFORMED_TRAIN_DATA_FILEBASE)))\n", "\n", "      # Now apply transform function to test data. 
Here we also skip the header\n", "      # line that is present in the test data file. The trailing period on each\n", "      # test label is removed by preprocessing_fn.\n", "      test_csv_tfxio = tfxio.CsvTFXIO(\n", "          file_pattern=test_data_file,\n", "          skip_header_lines=1,\n", "          telemetry_descriptors=[],\n", "          column_names=ORDERED_CSV_COLUMNS,\n", "          schema=SCHEMA)\n", "      raw_test_data = (\n", "          pipeline\n", "          | 'ReadTestCsv' \u003e\u003e test_csv_tfxio.BeamSource())\n", "\n", "      raw_test_dataset = (raw_test_data, test_csv_tfxio.TensorAdapterConfig())\n", "\n", "      # The TFXIO output format is chosen for improved performance.\n", "      transformed_test_dataset = (\n", "          (raw_test_dataset, transform_fn)\n", "          | tft_beam.TransformDataset(output_record_batches=True))\n", "\n", "      # Transformed metadata is not necessary for encoding.\n", "      transformed_test_data, _ = transformed_test_dataset\n", "\n", "      # Extract transformed RecordBatches, encode and write them to the given\n", "      # directory.\n", "      _ = (\n", "          transformed_test_data\n", "          | 'EncodeTestData' \u003e\u003e\n", "          beam.FlatMapTuple(lambda batch, _: RecordBatchToExamples(batch))\n", "          | 'WriteTestData' \u003e\u003e beam.io.WriteToTFRecord(\n", "              os.path.join(working_dir, TRANSFORMED_TEST_DATA_FILEBASE)))\n", "\n", "      # Will write a SavedModel and metadata to working_dir, which can then\n", "      # be read by the tft.TFTransformOutput class.\n", "      _ = (\n", "          transform_fn\n", "          | 'WriteTransformFn' \u003e\u003e tft_beam.WriteTransformFn(working_dir))" ] }, { "cell_type": "markdown", "metadata": { "id": "huaj5EgCVRD9" }, "source": [ "Run the pipeline:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "pjC7eDWFyA8K" }, "outputs": [], "source": [ "import tempfile\n", "import pathlib\n", "\n", "output_dir = os.path.join(tempfile.mkdtemp(), 'keras')\n", "\n", 
"transform_data(train_path, test_path, output_dir)" ] }, { "cell_type": "markdown", "metadata": { "id": "iqln2AClsA0z" }, "source": [ "Wrap up the output 
directory as a `tft.TFTransformOutput`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "FXd4Mgj6sAGB" }, "outputs": [], "source": [ "tf_transform_output = tft.TFTransformOutput(output_dir)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "59hNe7oY9vqG" }, "outputs": [], "source": [ "tf_transform_output.transformed_feature_spec()" ] }, { "cell_type": "markdown", "metadata": { "id": "oBBlL2EIVVF8" }, "source": [ "If you look in the directory you'll see it contains three things:\n", "\n", "1. The `train_transformed` and `test_transformed` data files\n", "2. The `transform_fn` directory (a `tf.saved_model`)\n", "3. The `transformed_metadata` directory\n", "\n", "The following sections show how to use these artifacts to train a model." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "NG6nrHEP2L65" }, "outputs": [], "source": [ "!ls -l {output_dir}" ] }, { "cell_type": "markdown", "metadata": { "id": "TnaMyRMJ03bR" }, "source": [ "## Using our preprocessed data to train a model using tf.keras\n", "\n", "To show how `tf.Transform` lets us use the same code for both training and serving, and thus prevent skew, we're going to train a model. To train our model and prepare it for production, we need to create input functions. The main difference between our training input function and our serving input function is that training data contains the labels, and production data does not. The arguments and returns are also somewhat different." 
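] }, { "cell_type": "markdown", "metadata": { "id": "trainServeSketchMd" }, "source": [ "The next cell is a minimal sketch of that difference. It uses a hypothetical two-row toy batch instead of the real transformed files. A training dataset yields `(features, label)` pairs, which is the structure that `label_key=LABEL_KEY` produces in the input function below. Serving requests carry only the features. The names `toy_features` and `split_label` are illustrative only:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "trainServeSketch" }, "outputs": [], "source": [ "toy_features = {\n", "    'age': tf.constant([0.1, 0.5]),\n", "    LABEL_KEY: tf.constant([[1.0, 0.0], [0.0, 1.0]]),\n", "}\n", "\n", "def split_label(features):\n", "  # Copy the dict so the input is untouched, then move the label out of it.\n", "  features = dict(features)\n", "  label = features.pop(LABEL_KEY)\n", "  return features, label\n", "\n", "# Training: batches of (features, label) pairs.\n", "toy_train_ds = tf.data.Dataset.from_tensor_slices(toy_features).batch(2).map(split_label)\n", "\n", "# Serving: batches of features only, with no label.\n", "toy_serving_ds = tf.data.Dataset.from_tensor_slices({'age': toy_features['age']}).batch(2)\n", "\n", "for toy_batch_features, toy_batch_label in toy_train_ds.take(1):\n", "  print(sorted(toy_batch_features), toy_batch_label.shape)\n", "for toy_batch_features in toy_serving_ds.take(1):\n", "  print(sorted(toy_batch_features))" 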
] }, { "cell_type": "markdown", "metadata": { "id": "M8xCZKNc2wAS" }, "source": [ "### Create an input function for training" ] }, { "cell_type": "markdown", "metadata": { "id": "StezlX-Uv0ae" }, "source": [ "Running the pipeline in the previous section created `TFRecord` files containing the transformed data.\n", "\n", "The following code uses `tf.data.experimental.make_batched_features_dataset` and `tft.TFTransformOutput.transformed_feature_spec` to read these data files as a `tf.data.Dataset`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "775Y7BTpHBmb" }, "outputs": [], "source": [ "def _make_training_input_fn(tf_transform_output, train_file_pattern,\n", "                            batch_size):\n", "  \"\"\"An input function reading from transformed data, converting to model input.\n", "\n", "  Args:\n", "    tf_transform_output: Wrapper around output of tf.Transform.\n", "    train_file_pattern: File pattern matching the transformed TFRecord files.\n", "    batch_size: Batch size.\n", "\n", "  Returns:\n", "    A function that returns a `tf.data.Dataset` of `(features, label)` batches\n", "    for training or eval. The dataset repeats indefinitely.\n", "  \"\"\"\n", "  def input_fn():\n", "    return tf.data.experimental.make_batched_features_dataset(\n", "        file_pattern=train_file_pattern,\n", "        batch_size=batch_size,\n", "        features=tf_transform_output.transformed_feature_spec(),\n", "        reader=tf.data.TFRecordDataset,\n", "        label_key=LABEL_KEY,\n", "        shuffle=True)\n", "\n", "  return input_fn" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-b8BgvBvkCnX" }, "outputs": [], "source": [ "train_file_pattern = pathlib.Path(output_dir)/f'{TRANSFORMED_TRAIN_DATA_FILEBASE}*'\n", "\n", "input_fn = _make_training_input_fn(\n", "    tf_transform_output=tf_transform_output,\n", "    train_file_pattern=str(train_file_pattern),\n", "    batch_size=10\n", ")" ] }, { "cell_type": "markdown", "metadata": { 
"id": "Q0PwPLBqxsg2" }, "source": [ "Below you can see a transformed sample of the data. 
Note how the numeric columns like `education-num` and `hours-per-week` are converted to floats with a range of [0,1], and the string columns have been converted to IDs:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "SpiS26IWlD-1" }, "outputs": [], "source": [ "for example, label in input_fn().take(1):\n", "  break\n", "\n", "pd.DataFrame(example)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "yaMzMnij88_v" }, "outputs": [], "source": [ "label" ] }, { "cell_type": "markdown", "metadata": { "id": "LyNTX7CO8AAz" }, "source": [ "### Train and evaluate the model" ] }, { "cell_type": "markdown", "metadata": { "id": "hdg9jXuLWuyK" }, "source": [ "Build the model" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "uK4brUuDTAJ4" }, "outputs": [], "source": [ "def build_keras_model(working_dir):\n", "  inputs = build_keras_inputs(working_dir)\n", "\n", "  encoded_inputs = encode_inputs(inputs)\n", "\n", "  stacked_inputs = tf.concat(tf.nest.flatten(encoded_inputs), axis=1)\n", "  output = tf.keras.layers.Dense(100, activation='relu')(stacked_inputs)\n", "  output = tf.keras.layers.Dense(50, activation='relu')(output)\n", "  output = tf.keras.layers.Dense(2)(output)\n", "  model = tf.keras.Model(inputs=inputs, outputs=output)\n", "\n", "  return model" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "6fJwIbdCRFER" }, "outputs": [], "source": [ "def build_keras_inputs(working_dir):\n", "  tf_transform_output = tft.TFTransformOutput(working_dir)\n", "\n", "  feature_spec = tf_transform_output.transformed_feature_spec().copy()\n", "  feature_spec.pop(LABEL_KEY)\n", "\n", "  # Build the `keras.Input` objects.\n", "  inputs = {}\n", "  for key, spec in feature_spec.items():\n", "    if isinstance(spec, tf.io.VarLenFeature):\n", "      inputs[key] = tf.keras.layers.Input(\n", "          shape=[None], name=key, dtype=spec.dtype, sparse=True)\n", "    elif isinstance(spec, tf.io.FixedLenFeature):\n", "      inputs[key] = 
tf.keras.layers.Input(\n", "          shape=spec.shape, name=key, dtype=spec.dtype)\n", "    else:\n", "      raise ValueError(f'Spec type is not supported: {key}: {spec}')\n", "\n", "  return inputs" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "9dHD5SoqRqOh" }, "outputs": [], "source": [ "def encode_inputs(inputs):\n", "  encoded_inputs = {}\n", "  for key in inputs:\n", "    feature = tf.expand_dims(inputs[key], -1)\n", "    if key in CATEGORICAL_FEATURE_KEYS:\n", "      # Encode the integer IDs as one-hot vectors, sized using the global\n", "      # `tf_transform_output` loaded earlier.\n", "      num_buckets = tf_transform_output.num_buckets_for_transformed_feature(key)\n", "      encoding_layer = (\n", "          tf.keras.layers.CategoryEncoding(\n", "              num_tokens=num_buckets, output_mode='binary', sparse=False))\n", "      encoded_inputs[key] = encoding_layer(feature)\n", "    else:\n", "      encoded_inputs[key] = feature\n", "\n", "  return encoded_inputs" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "5xNhSq8lTTx3" }, "outputs": [], "source": [ "model = build_keras_model(output_dir)\n", "\n", "tf.keras.utils.plot_model(model, rankdir='LR', show_shapes=True)" ] }, { "cell_type": "markdown", "metadata": { "id": "kQSpw_XzXVn1" }, "source": [ "Build the datasets" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "afi3NOC0OMUa" }, "outputs": [], "source": [ "def get_dataset(working_dir, filebase):\n", "  tf_transform_output = tft.TFTransformOutput(working_dir)\n", "\n", "  data_path_pattern = os.path.join(\n", "      working_dir,\n", "      filebase + '*')\n", "\n", "  input_fn = _make_training_input_fn(\n", "      tf_transform_output,\n", "      data_path_pattern,\n", "      batch_size=BATCH_SIZE)\n", "\n", "  dataset = input_fn()\n", "\n", "  return dataset" ] }, { "cell_type": "markdown", "metadata": { "id": "-fE_3jyzX_h2" }, "source": [ "Train and evaluate the model:" ] }, { 
"cell_type": "code", "execution_count": null, "metadata": { "id": "6i_lhWH8IZrk" }, "outputs": [], "source": [ "def train_and_evaluate(\n", "    model,\n", "    working_dir):\n", "  \"\"\"Train the model on training data and 
evaluate on test data.\n", "\n", "  Args:\n", "    model: The Keras model to train, for example from `build_keras_model`.\n", "    working_dir: The location of the Transform output.\n", "\n", "  Returns:\n", "    A `(model, history, metric_values)` tuple: the trained model, the Keras\n", "    `History` object, and a dict of metrics from `model.evaluate`.\n", "  \"\"\"\n", "  train_dataset = get_dataset(working_dir, TRANSFORMED_TRAIN_DATA_FILEBASE)\n", "  validation_dataset = get_dataset(working_dir, TRANSFORMED_TEST_DATA_FILEBASE)\n", "\n", "  history = train_model(model, train_dataset, validation_dataset)\n", "\n", "  metric_values = model.evaluate(validation_dataset,\n", "                                 steps=EVALUATION_STEPS,\n", "                                 return_dict=True)\n", "  return model, history, metric_values" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "rcVsByIsViRy" }, "outputs": [], "source": [ "def train_model(model, train_dataset, validation_dataset):\n", "  model.compile(optimizer='adam',\n", "                loss=tf.losses.CategoricalCrossentropy(from_logits=True),\n", "                metrics=['accuracy'])\n", "\n", "  history = model.fit(train_dataset, validation_data=validation_dataset,\n", "                      epochs=TRAIN_NUM_EPOCHS,\n", "                      steps_per_epoch=STEPS_PER_TRAIN_EPOCH,\n", "                      validation_steps=EVALUATION_STEPS)\n", "  return history" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "f5xoioogYTle" }, "outputs": [], "source": [ "model, history, metric_values = train_and_evaluate(model, output_dir)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "gQCbdPIQeXeZ" }, "outputs": [], "source": [ "plt.plot(history.history['loss'], label='Train')\n", "plt.plot(history.history['val_loss'], label='Eval')\n", "plt.ylim(0, max(plt.ylim()))\n", "plt.legend()\n", "plt.title('Loss');" ] }, { "cell_type": "markdown", "metadata": { "id": "nYeuthrs27vl" }, "source": [ "### Transform new 
data\n", "\n", "In the previous section, the training process used the materialized copies of the transformed data that were 
generated by `tft_beam.AnalyzeAndTransformDataset` in the `transform_data` function.\n", "\n", "To operate on new data, you'll need to load the final version of the `preprocessing_fn` that was saved by `tft_beam.WriteTransformFn`.\n", "\n", "The `TFTransformOutput.transform_features_layer` method loads the `preprocessing_fn` SavedModel from the output directory." ] }, { "cell_type": "markdown", "metadata": { "id": "zxi9aS106CLd" }, "source": [ "Here's a function to load new, unprocessed batches from a source file:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "tMHDZhp82ZjM" }, "outputs": [], "source": [ "def read_csv(file_name, batch_size):\n", "  return tf.data.experimental.make_csv_dataset(\n", "      file_pattern=file_name,\n", "      batch_size=batch_size,\n", "      column_names=ORDERED_CSV_COLUMNS,\n", "      column_defaults=COLUMN_DEFAULTS,\n", "      prefetch_buffer_size=0,\n", "      ignore_errors=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "AradAjmW2vyd" }, "outputs": [], "source": [ "for ex in read_csv(test_path, batch_size=5):\n", "  break\n", "\n", "pd.DataFrame(ex)" ] }, { "cell_type": "markdown", "metadata": { "id": "OX1f6SgM6LZc" }, "source": [ "Load the `tft.TransformFeaturesLayer` to transform this data with the `preprocessing_fn`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "nma2Bzi--11x" }, "outputs": [], "source": [ "ex2 = ex.copy()\n", "ex2.pop('fnlwgt')\n", "\n", "tft_layer = tf_transform_output.transform_features_layer()\n", "t_ex = tft_layer(ex2)\n", "\n", "label = t_ex.pop(LABEL_KEY)\n", "pd.DataFrame(t_ex)" ] }, { "cell_type": "markdown", "metadata": { "id": "P43ixyQNz1zq" }, "source": [ "The `tft_layer` can still execute the transformation if only a subset of features is passed in. 
For example, if you only pass in two features, you'll get just the transformed versions of those features back:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "swEPuZsR0Y5S" }, "outputs": [], "source": [ "ex2 = pd.DataFrame(ex)[['education', 'hours-per-week']]\n", "ex2" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "_s4SxutV1DTI" }, "outputs": [], "source": [ "pd.DataFrame(tft_layer(dict(ex2)))" ] }, { "cell_type": "markdown", "metadata": { "id": "x5wo3dN-vhFL" }, "source": [ "Here's a more robust version that drops features that are not in the feature spec, and returns a `(features, label)` pair if the label is in the provided features:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "hdMKDnafJh64" }, "outputs": [], "source": [ "class Transform(tf.Module):\n", "  def __init__(self, working_dir):\n", "    self.working_dir = working_dir\n", "    self.tf_transform_output = tft.TFTransformOutput(working_dir)\n", "    self.tft_layer = self.tf_transform_output.transform_features_layer()\n", "\n", "  @tf.function\n", "  def __call__(self, features):\n", "    raw_features = {}\n", "\n", "    for key, val in features.items():\n", "      # Skip unused keys\n", "      if key not in RAW_DATA_FEATURE_SPEC:\n", "        continue\n", "\n", "      raw_features[key] = val\n", "\n", "    # Apply the `preprocessing_fn`.\n", "    transformed_features = self.tft_layer(raw_features)\n", "\n", "    if LABEL_KEY in transformed_features:\n", "      # Pop the label and return a (features, labels) pair.\n", "      data_labels = transformed_features.pop(LABEL_KEY)\n", "      return (transformed_features, data_labels)\n", "    else:\n", "      return transformed_features\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mm5HI578Ku1B" }, "outputs": [], "source": [ "transform = Transform(output_dir)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "4jeenwN_3ZRj" }, 
"outputs": [], "source": [ "t_ex, t_label = transform(ex)" ] }, { "cell_type": 
"code", "execution_count": null, "metadata": { "id": "yIavZAqALO8H" }, "outputs": [], "source": [ "pd.DataFrame(t_ex)" ] }, { "cell_type": "markdown", "metadata": { "id": "LVQead0fwVuy" }, "source": [ "Now you can use `Dataset.map` to apply that transformation on the fly to new data:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "VN3IO6u1Mk83" }, "outputs": [], "source": [ "model.evaluate(\n", "    read_csv(test_path, batch_size=5).map(transform),\n", "    steps=EVALUATION_STEPS,\n", "    return_dict=True\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "Ymlco3hfU_-E" }, "source": [ "### Export the model\n", "\n", "You now have a trained model and a method to apply the `preprocessing_fn` to new data. Assemble them into a new model that accepts serialized `tf.train.Example` protos as input." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "AZ2WICuwEwqC" }, "outputs": [], "source": [ "class ServingModel(tf.Module):\n", "  def __init__(self, model, working_dir):\n", "    self.model = model\n", "    self.working_dir = working_dir\n", "    self.transform = Transform(working_dir)\n", "\n", "  @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string)])\n", "  def __call__(self, serialized_tf_examples):\n", "    # Parse the serialized `tf.train.Example` protos.\n", "    feature_spec = RAW_DATA_FEATURE_SPEC.copy()\n", "    feature_spec.pop(LABEL_KEY)\n", "    parsed_features = tf.io.parse_example(serialized_tf_examples, feature_spec)\n", "    # Apply the `preprocessing_fn`.\n", "    transformed_features = self.transform(parsed_features)\n", "    # Run the model.\n", "    outputs = self.model(transformed_features)\n", "    # Format the output. The 'scores' are the model's raw logits.\n", "    classes_names = tf.constant([['0', '1']])\n", "    classes = tf.tile(classes_names, [tf.shape(outputs)[0], 1])\n", "    return {'classes': classes, 'scores': outputs}\n", "\n", "  def export(self, 
output_dir):\n", "    # Increment the directory number. 
This is required in order to make this\n", "    # model servable with model_server.\n", "    save_model_dir = pathlib.Path(output_dir)/'model'\n", "    number_dirs = [int(p.name) for p in save_model_dir.glob('*')\n", "                   if p.name.isdigit()]\n", "    version = max([0] + number_dirs) + 1\n", "    save_model_dir = save_model_dir/str(version)\n", "\n", "    # Set the signature to make it visible for serving.\n", "    concrete_serving_fn = self.__call__.get_concrete_function()\n", "    signatures = {'serving_default': concrete_serving_fn}\n", "\n", "    # Export the model.\n", "    tf.saved_model.save(\n", "        self,\n", "        str(save_model_dir),\n", "        signatures=signatures)\n", "\n", "    return save_model_dir" ] }, { "cell_type": "markdown", "metadata": { "id": "M8TZf2di24L2" }, "source": [ "Build the model and test-run it on the batch of serialized examples:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "u2mSC1UMGAwJ" }, "outputs": [], "source": [ "serving_model = ServingModel(model, output_dir)\n", "\n", "serving_model(serialized_example_batch)" ] }, { "cell_type": "markdown", "metadata": { "id": "BWhighof3AK8" }, "source": [ "Export the model as a SavedModel:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "kodDWTJIEr77" }, "outputs": [], "source": [ "saved_model_dir = serving_model.export(output_dir)\n", "saved_model_dir" ] }, { "cell_type": "markdown", "metadata": { "id": "ohbWxp3-3aQu" }, "source": [ "Reload the model and test it on the same batch of examples:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "nShh6GqcEr78" }, "outputs": [], "source": [ "reloaded = tf.saved_model.load(str(saved_model_dir))\n", "run_model = reloaded.signatures['serving_default']" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "UiYJhQySEr78" }, "outputs": [], "source": [ "run_model(serialized_example_batch)" ] }, { "cell_type": "markdown", "metadata": { "id": 
"ICqetCnSjwp1" }, "source": [ "## What we did\n", "In this example 
we used `tf.Transform` to preprocess a dataset of census data and trained a model with the cleaned and transformed data. We also built a serving model that applies the same transformation to raw inputs, so the trained model can perform inference when deployed in a production environment. Because the same preprocessing code runs during both training and inference, we avoid training/serving skew. Along the way we learned about creating an Apache Beam transform to perform the transformations needed to clean the data, and saw how to use the transformed data to train a model with `tf.keras`. This is just a small piece of what TensorFlow Transform can do! We encourage you to dive into `tf.Transform` and discover what it can do for you." ] }, { "cell_type": "markdown", "metadata": { "id": "APEUSA9boKgT" }, "source": [ "## [Optional] Using our preprocessed data to train a model using tf.estimator\n", "\n", "\u003e Warning: Estimators are not recommended for new code. Estimators run\n", "\u003ca href=\"https://www.tensorflow.org/api_docs/python/tf/compat/v1/Session\"\u003e\u003ccode\u003ev1.Session\u003c/code\u003e\u003c/a\u003e-style code which is more difficult to write correctly, and\n", "can behave unexpectedly, especially when combined with TF 2 code. Estimators\n", "do fall under our\n", "[compatibility guarantees](https://tensorflow.org/guide/versions), but will\n", "receive no fixes other than security vulnerabilities. 
See the\n", "[migration guide](https://tensorflow.org/guide/migrate) for details.\n", "\n", " \u003c!-- \u003cdiv class=\"tfo-display-only-on-site\"\u003e\u003cdevsite-expandable\u003e\n", " \u003cbutton type=\"button\" class=\"button-red button expand-control\"\u003eShow Section\u003c/button\u003e --\u003e\n" ] }, { "cell_type": "markdown", "metadata": { "id": "QcBWjr3ioZbl" }, "source": [ "### Create an input function for training" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "kFO0MeWQ228a" }, "outputs": [], "source": [ "def _make_training_input_fn(tf_transform_output, transformed_examples,\n", "                            batch_size):\n", "  \"\"\"Creates an input function reading from transformed data.\n", "\n", "  Args:\n", "    tf_transform_output: Wrapper around output of tf.Transform.\n", "    transformed_examples: File pattern matching the transformed examples.\n", "    batch_size: Batch size.\n", "\n", "  Returns:\n", "    The input function for training or eval.\n", "  \"\"\"\n", "  def input_fn():\n", "    \"\"\"Input function for training and eval.\"\"\"\n", "    dataset = tf.data.experimental.make_batched_features_dataset(\n", "        file_pattern=transformed_examples,\n", "        batch_size=batch_size,\n", "        features=tf_transform_output.transformed_feature_spec(),\n", "        reader=tf.data.TFRecordDataset,\n", "        shuffle=True)\n", "\n", "    transformed_features = tf.compat.v1.data.make_one_shot_iterator(\n", "        dataset).get_next()\n", "\n", "    # Extract features and label from the transformed tensors. 
The label is\n", "    # one-hot encoded, so `tf.where` returns the [row, column] index of each 1\n", "    # and column 1 of that result is the class ID.\n", "    transformed_labels = tf.where(\n", "        tf.equal(transformed_features.pop(LABEL_KEY), 1))\n", "\n", "    return transformed_features, transformed_labels[:, 1]\n", "\n", "  return input_fn" ] }, { "cell_type": "markdown", "metadata": { "id": "22XOsZ-noez-" }, "source": [ "### Create an input function for serving\n", "\n", "Let's create an input function that we could use in production, and prepare our trained model for serving." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "code", "id": "NN5FVg343Jea" }, "outputs": [], "source": [ "def _make_serving_input_fn(tf_transform_output):\n", "  \"\"\"Creates an input function reading from raw data.\n", "\n", "  Args:\n", "    tf_transform_output: Wrapper around output of tf.Transform.\n", "\n", "  Returns:\n", "    The serving input function.\n", "  \"\"\"\n", "  raw_feature_spec = RAW_DATA_FEATURE_SPEC.copy()\n", "  # Remove label since it is not available during serving.\n", "  raw_feature_spec.pop(LABEL_KEY)\n", "\n", "  def serving_input_fn():\n", "    \"\"\"Input function for serving.\"\"\"\n", "    # Get raw features by generating the basic serving input_fn and calling it.\n", "    # Here we generate an input_fn that expects serialized tf.Example protos to\n", "    # be fed to the model at serving time. See also\n", "    # tf.estimator.export.build_raw_serving_input_receiver_fn.\n", "    raw_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(\n", "        raw_feature_spec, default_batch_size=None)\n", "    serving_input_receiver = raw_input_fn()\n", "\n", "    # Apply the transform function that was used to generate the materialized\n", "    # data.\n", "    raw_features = serving_input_receiver.features\n", "    transformed_features = tf_transform_output.transform_raw_features(\n", "        raw_features)\n", "\n", "    return tf.estimator.export.ServingInputReceiver(\n", "        transformed_features, serving_input_receiver.receiver_tensors)\n", "\n", "  return serving_input_fn" ] }, { "cell_type": "markdown", "metadata": { "id": "Vc9Edp8A7dsI" }, "source": [ "### Wrap our input data in FeatureColumns\n", "Our model expects its input data as TensorFlow FeatureColumns." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "6qOFOvBk7oJX" }, "outputs": [], "source": [ "def get_feature_columns(tf_transform_output):\n", "  \"\"\"Returns the FeatureColumns for the model.\n", "\n", "  Args:\n", "    tf_transform_output: A `TFTransformOutput` object.\n", "\n", "  Returns:\n", "    A list of FeatureColumns.\n", "  \"\"\"\n", "  # Wrap scalars as real valued columns.\n", "  real_valued_columns = [tf.feature_column.numeric_column(key, shape=())\n", "                         for key in NUMERIC_FEATURE_KEYS]\n", "\n", "  # Wrap categorical columns.\n", "  one_hot_columns = [\n", "      tf.feature_column.indicator_column(\n", "          tf.feature_column.categorical_column_with_identity(\n", "              key=key,\n", "              num_buckets=(NUM_OOV_BUCKETS +\n", "                           tf_transform_output.vocabulary_size_by_name(\n", "                               vocab_filename=key))))\n", "      for key in CATEGORICAL_FEATURE_KEYS]\n", "\n", "  return real_valued_columns + one_hot_columns" ] }, { "cell_type": "markdown", "metadata": { "id": "f6FyMzMcpOgT" }, "source": [ "### Train, Evaluate, and Export our model" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8iGQ0jeq8IWr" }, "outputs": [], "source": [ "def train_and_evaluate(working_dir, num_train_instances=NUM_TRAIN_INSTANCES,\n", "                       num_test_instances=NUM_TEST_INSTANCES):\n", "  \"\"\"Train the model on training data and evaluate on test data.\n", "\n", "  Args:\n", "    working_dir: Directory to read transformed data and metadata from and to\n", "        write exported model to.\n", "    num_train_instances: Number of instances in train set\n", "    num_test_instances: Number of instances in test set\n", "\n", "  Returns:\n", "    The results from the estimator's 'evaluate' method\n", "  \"\"\"\n", "  tf_transform_output = tft.TFTransformOutput(working_dir)\n", "\n", "  run_config = tf.estimator.RunConfig()\n", "\n", "  estimator = tf.estimator.LinearClassifier(\n", "      
feature_columns=get_feature_columns(tf_transform_output),\n", "      config=run_config,\n", "      
loss_reduction=tf.losses.Reduction.SUM)\n", "\n", "  # Fit the model using the default optimizer.\n", "  train_input_fn = _make_training_input_fn(\n", "      tf_transform_output,\n", "      os.path.join(working_dir, TRANSFORMED_TRAIN_DATA_FILEBASE + '*'),\n", "      batch_size=BATCH_SIZE)\n", "  estimator.train(\n", "      input_fn=train_input_fn,\n", "      max_steps=TRAIN_NUM_EPOCHS * num_train_instances // BATCH_SIZE)\n", "\n", "  # Build the input function for evaluating on the test dataset.\n", "  eval_input_fn = _make_training_input_fn(\n", "      tf_transform_output,\n", "      os.path.join(working_dir, TRANSFORMED_TEST_DATA_FILEBASE + '*'),\n", "      batch_size=1)\n", "\n", "  # Export the model.\n", "  serving_input_fn = _make_serving_input_fn(tf_transform_output)\n", "  exported_model_dir = os.path.join(working_dir, EXPORTED_MODEL_DIR)\n", "  estimator.export_saved_model(exported_model_dir, serving_input_fn)\n", "\n", "  # Evaluate the model on the test dataset.\n", "  return estimator.evaluate(input_fn=eval_input_fn, steps=num_test_instances)" ] }, { "cell_type": "markdown", "metadata": { "id": "5k8LrDPZpZsK" }, "source": [ "### Put it all together\n", "We've created everything we need to preprocess our census data, train a model, and prepare it for serving. So far we've just been getting things ready. It's time to start running!\n", "\n", "Note: Scroll the output from this cell to see the whole process. The results will be at the bottom." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "P_1_2dB6pdc2" }, "outputs": [], "source": [ "import tempfile\n", "temp = os.path.join(tempfile.mkdtemp(), 'estimator')\n", "\n", "transform_data(train_path, test_path, temp)\n", "results = train_and_evaluate(temp)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "O_IqGL90GCIq" }, "outputs": [], "source": [ "pprint.pprint(results)" ] }, { "cell_type": "markdown", "metadata": { "id": "Z6T3aHoRsjgR" }, "source": [ " \u003c/devsite-expandable\u003e\u003c/div\u003e\n" ] } ], "metadata": { "colab": { "collapsed_sections": [ "APEUSA9boKgT" ], "name": "census.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }