{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "name": "linear.ipynb", "version": "0.3.2", "provenance": [], "private_outputs": true, "collapsed_sections": [ "MWW1TyjaecRh" ], "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "cells": [ { "metadata": { "id": "MWW1TyjaecRh", "colab_type": "text" }, "cell_type": "markdown", "source": [ "##### Copyright 2018 The TensorFlow Authors." ] }, { "metadata": { "id": "mOtR1FzCef-u", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "Zr7KpBhMcYvE", "colab_type": "text" }, "cell_type": "markdown", "source": [ "# Build a linear model with Estimators" ] }, { "metadata": { "id": "uJl4gaPFzxQz", "colab_type": "text" }, "cell_type": "markdown", "source": [ "\n", " \n", " \n", " \n", "
View on TensorFlow.org | Run in Google Colab | View source on GitHub" ] }, { "metadata": { "id": "77aETSYDcdoK", "colab_type": "text" }, "cell_type": "markdown", "source": [ "This tutorial uses the `tf.estimator` API in TensorFlow to solve a benchmark binary classification problem. Estimators are TensorFlow's most scalable and production-oriented model type. For more information, see the [Estimator guide](https://www.tensorflow.org/guide/estimators).\n", "\n", "## Overview\n", "\n", "Using census data which contains data about a person's age, education, marital status, and occupation (the *features*), we will try to predict whether or not the person earns more than 50,000 dollars a year (the target *label*). We will train a *logistic regression* model that, given an individual's information, outputs a number between 0 and 1—this can be interpreted as the probability that the individual has an annual income of over 50,000 dollars.\n", "\n", "Key Point: As a modeler and developer, think about how this data is used and the potential benefits and harm a model's predictions can cause. A model like this could reinforce societal biases and disparities. Is each feature relevant to the problem you want to solve, or will it introduce bias? For more information, read about [ML fairness](https://developers.google.com/machine-learning/fairness-overview/).\n", "\n", "## Setup\n", "\n", "Import TensorFlow, feature column support, and supporting modules:" ] }, { "metadata": { "id": "NQgONe5ecYvE", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "from __future__ import absolute_import, division, print_function, unicode_literals\n", "\n", "import tensorflow as tf\n", "import tensorflow.feature_column as fc\n", "\n", "import os\n", "import sys\n", "\n", "import matplotlib.pyplot as plt\n", "from IPython.display import clear_output" ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "Rpb1JSMj1nqk", "colab_type": "text" }, "cell_type": "markdown", "source": [ "And let's enable [eager execution](https://www.tensorflow.org/guide/eager) to inspect this program as we run it:" ] }, { "metadata": { "id": "tQzxON782Eby", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "tf.enable_eager_execution()" ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "-MPr95UccYvL", "colab_type": "text" }, "cell_type": "markdown", "source": [ "## Download the official implementation\n", "\n", "We'll use the [wide and deep model](https://github.com/tensorflow/models/tree/master/official/wide_deep/) available in TensorFlow's [model repository](https://github.com/tensorflow/models/). Start by cloning the code:" ] }, { "metadata": { "id": "tTwQzWcn8aBu", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "! pip install requests\n", "! 
git clone --depth 1 https://github.com/tensorflow/models" ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "sRpuysc73Eb-", "colab_type": "text" }, "cell_type": "markdown", "source": [ "Add the root directory of the repository to your Python path:" ] }, { "metadata": { "id": "yVvFyhnkcYvL", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "models_path = os.path.join(os.getcwd(), 'models')\n", "\n", "sys.path.append(models_path)" ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "15Ethw-wcYvP", "colab_type": "text" }, "cell_type": "markdown", "source": [ "Download the dataset:" ] }, { "metadata": { "id": "6QilS4-0cYvQ", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "from official.wide_deep import census_dataset\n", "from official.wide_deep import census_main\n", "\n", "census_dataset.download(\"/tmp/census_data/\")" ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "cD5e3ibAcYvS", "colab_type": "text" }, "cell_type": "markdown", "source": [ "### Command line usage\n", "\n", "The repo includes a complete program for experimenting with this type of model.\n", "\n", "To execute the tutorial code from the command line, first add the path to tensorflow/models to your `PYTHONPATH`." ] }, { "metadata": { "id": "DYOkY8boUptJ", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "# From a shell: export PYTHONPATH=${PYTHONPATH}:\"$(pwd)/models\"\n", "# When running from Python, set `os.environ` as well, or the subprocess will not see the directory.\n", "\n", "if \"PYTHONPATH\" in os.environ:\n", "  os.environ['PYTHONPATH'] += os.pathsep + models_path\n", "else:\n", "  os.environ['PYTHONPATH'] = models_path" ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "5r0V9YUMUyoh", "colab_type": "text" }, "cell_type": "markdown", "source": [ "Use `--help` to see what command line options are available:" ] }, { "metadata": { "id": "1_3tBaLW4YM4", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "!python -m official.wide_deep.census_main --help" ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "RrMLazEN6DMj", "colab_type": "text" }, "cell_type": "markdown", "source": [ "Now run the model:\n" ] }, { "metadata": { "id": "py7MarZl5Yh6", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "!python -m official.wide_deep.census_main --model_type=wide --train_epochs=2" ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "AmZ4CpaOcYvV", "colab_type": "text" }, "cell_type": "markdown", "source": [ "## Read the U.S. Census data\n", "\n", "This example uses the [U.S. Census Income Dataset](https://archive.ics.uci.edu/ml/datasets/Census+Income) from 1994 and 1995. We have provided the [census_dataset.py](https://github.com/tensorflow/models/tree/master/official/wide_deep/census_dataset.py) script to download the data and perform a little cleanup.\n", "\n", "Since the task is a *binary classification problem*, we'll construct a label column named \"label\" whose value is 1 if the income is over 50K, and 0 otherwise. 
For reference, see the `input_fn` in [census_main.py](https://github.com/tensorflow/models/tree/master/official/wide_deep/census_main.py).\n", "\n", "Let's look at the data to see which columns we can use to predict the target label:" ] }, { "metadata": { "id": "N6Tgye8bcYvX", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "!ls /tmp/census_data/" ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "6y3mj9zKcYva", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "train_file = \"/tmp/census_data/adult.data\"\n", "test_file = \"/tmp/census_data/adult.test\"" ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "EO_McKgE5il2", "colab_type": "text" }, "cell_type": "markdown", "source": [ "[pandas](https://pandas.pydata.org/) provides some convenient utilities for data analysis. Here's a list of columns available in the Census Income dataset:" ] }, { "metadata": { "id": "vkn1FNmpcYvb", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "import pandas\n", "\n", "train_df = pandas.read_csv(train_file, header=None, names=census_dataset._CSV_COLUMNS)\n", "test_df = pandas.read_csv(test_file, header=None, names=census_dataset._CSV_COLUMNS)\n", "\n", "train_df.head()" ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "QZZtXes4cYvf", "colab_type": "text" }, "cell_type": "markdown", "source": [ "The columns are grouped into two types: *categorical* and *continuous* columns:\n", "\n", "* A column is called *categorical* if its value can only be one of the categories in a finite set. For example, the relationship status of a person (wife, husband, unmarried, etc.) or the education level (high school, college, etc.) are categorical columns.\n", "* A column is called *continuous* if its value can be any numerical value in a continuous range. For example, the capital gain of a person (e.g. $14,084) is a continuous column.\n", "\n", "## Converting Data into Tensors\n", "\n", "When building a `tf.estimator` model, the input data is specified by using an *input function* (or `input_fn`). This builder function returns a `tf.data.Dataset` of batches of `(features-dict, label)` pairs. It is not called until it is passed to `tf.estimator.Estimator` methods such as `train` and `evaluate`.\n", "\n", "The input builder function returns the following pair:\n", "\n", "1. `features`: A dict from feature names to `Tensors` or `SparseTensors` containing batches of features.\n", "2. `labels`: A `Tensor` containing batches of labels.\n", "\n", "The keys of the `features` dict are used to configure the model's input layer.\n", "\n", "Note: The input function is called while constructing the TensorFlow graph, *not* while running the graph. 
It returns a representation of the input data as a sequence of TensorFlow graph operations.\n", "\n", "For small problems like this, it's easy to make a `tf.data.Dataset` by slicing the `pandas.DataFrame`:" ] }, { "metadata": { "id": "N7zNJflKcYvg", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "def easy_input_function(df, label_key, num_epochs, shuffle, batch_size):\n", "  label = df[label_key]\n", "  ds = tf.data.Dataset.from_tensor_slices((dict(df), label))\n", "\n", "  if shuffle:\n", "    ds = ds.shuffle(10000)\n", "\n", "  ds = ds.batch(batch_size).repeat(num_epochs)\n", "\n", "  return ds" ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "WeEgNR9AcYvh", "colab_type": "text" }, "cell_type": "markdown", "source": [ "Since we have eager execution enabled, it's easy to inspect the resulting dataset:" ] }, { "metadata": { "id": "ygaKuikecYvi", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "ds = easy_input_function(train_df, label_key='income_bracket', num_epochs=5, shuffle=True, batch_size=10)\n", "\n", "for feature_batch, label_batch in ds.take(1):\n", "  print('Some feature keys:', list(feature_batch.keys())[:5])\n", "  print()\n", "  print('A batch of Ages :', feature_batch['age'])\n", "  print()\n", "  print('A batch of Labels:', label_batch)" ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "O_KZxQUucYvm", "colab_type": "text" }, "cell_type": "markdown", "source": [ "But this approach has severely limited scalability. Larger datasets should be streamed from disk. The `census_dataset.input_fn` provides an example of how to do this using `tf.decode_csv` and `tf.data.TextLineDataset`:" ] }, { "metadata": { "id": "vUTeXaEUcYvn", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "import inspect\n", "print(inspect.getsource(census_dataset.input_fn))" ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "yyGcv_e-cYvq", "colab_type": "text" }, "cell_type": "markdown", "source": [ "This `input_fn` returns equivalent output:" ] }, { "metadata": { "id": "Mv3as_CEcYvu", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "ds = census_dataset.input_fn(train_file, num_epochs=5, shuffle=True, batch_size=10)\n", "\n", "for feature_batch, label_batch in ds.take(1):\n", "  print('Feature keys:', list(feature_batch.keys())[:5])\n", "  print()\n", "  print('Age batch   :', feature_batch['age'])\n", "  print()\n", "  print('Label batch :', label_batch)" ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "810fnfY5cYvz", "colab_type": "text" }, "cell_type": "markdown", "source": [ "Because `Estimators` expect an `input_fn` that takes no arguments, we typically wrap a configurable input function into an object with the expected signature." ] },
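{ "metadata": { "id": "xLam0InpfMd0", "colab_type": "text" }, "cell_type": "markdown", "source": [ "For example, a zero-argument `lambda` that closes over the arguments has the expected signature. This is just a sketch of the idea (the name `train_inpf_lambda` is illustrative and not used elsewhere); the next cell uses `functools.partial` to the same effect:" ] }, { "metadata": { "id": "xLam0InpfCd0", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "# A zero-argument callable wrapping census_dataset.input_fn.\n", "train_inpf_lambda = lambda: census_dataset.input_fn(\n", "    train_file, num_epochs=2, shuffle=True, batch_size=64)" ], "execution_count": 0, "outputs": [] },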
{ "metadata": { "id": "wNqdEcVcYv0a", "colab_type": "text" }, "cell_type": "markdown", "source": [ "For this notebook, configure the `train_inpf` to iterate over the data twice:" ] }, { "metadata": { "id": "wnQdpEcVcYv0", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "import functools\n", "\n", "train_inpf = functools.partial(census_dataset.input_fn, train_file, num_epochs=2, shuffle=True, batch_size=64)\n", "test_inpf = functools.partial(census_dataset.input_fn, test_file, num_epochs=1, shuffle=False, batch_size=64)" ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "pboNpNWhcYv4", "colab_type": "text" }, "cell_type": "markdown", "source": [ "## Selecting and Engineering Features for the Model\n", "\n", "Estimators use a system called [feature columns](https://www.tensorflow.org/guide/feature_columns) to describe how the model should interpret each of the raw input features. An Estimator expects a vector of numeric inputs, and feature columns describe how the model should convert each feature.\n", "\n", "Selecting and crafting the right set of feature columns is key to learning an effective model. A *feature column* can be either one of the raw inputs in the original features `dict` (a *base feature column*), or any new column created using transformations defined over one or multiple base columns (a *derived feature column*).\n", "\n", "A feature column is an abstract concept of any raw or derived variable that can be used to predict the target label." ] }, { "metadata": { "id": "_hh-cWdU__Lq", "colab_type": "text" }, "cell_type": "markdown", "source": [ "### Base Feature Columns" ] }, { "metadata": { "id": "BKz6LA8_ACI7", "colab_type": "text" }, "cell_type": "markdown", "source": [ "#### Numeric columns\n", "\n", "The simplest `feature_column` is `numeric_column`. This indicates that a feature is a numeric value that should be input to the model directly. For example:" ] }, { "metadata": { "id": "ZX0r2T5OcYv6", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "age = fc.numeric_column('age')" ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "tnLUiaHxcYv-", "colab_type": "text" }, "cell_type": "markdown", "source": [ "The model will use the `feature_column` definitions to build the model input. 
You can inspect the resulting output using the `input_layer` function:" ] }, { "metadata": { "id": "kREtIPfwcYv_", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "fc.input_layer(feature_batch, [age]).numpy()" ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "OPuLduCucYwD", "colab_type": "text" }, "cell_type": "markdown", "source": [ "The following will train and evaluate a model using only the `age` feature:" ] }, { "metadata": { "id": "9R5eSJ1pcYwE", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "classifier = tf.estimator.LinearClassifier(feature_columns=[age])\n", "classifier.train(train_inpf)\n", "result = classifier.evaluate(test_inpf)\n", "\n", "clear_output()  # used for display in notebook\n", "print(result)" ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "YDZGcdTdcYwI", "colab_type": "text" }, "cell_type": "markdown", "source": [ "Similarly, we can define a `NumericColumn` for each continuous feature column that we want to use in the model:" ] }, { "metadata": { "id": "uqPbUqlxcYwJ", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "education_num = tf.feature_column.numeric_column('education_num')\n", "capital_gain = tf.feature_column.numeric_column('capital_gain')\n", "capital_loss = tf.feature_column.numeric_column('capital_loss')\n", "hours_per_week = tf.feature_column.numeric_column('hours_per_week')\n", "\n", "my_numeric_columns = [age, education_num, capital_gain, capital_loss, hours_per_week]\n", "\n", "fc.input_layer(feature_batch, my_numeric_columns).numpy()" ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "cBGDN97IcYwQ", "colab_type": "text" }, "cell_type": "markdown", "source": [ "You could retrain a model on these features by changing the `feature_columns` argument to the constructor:" ] }, { "metadata": { "id": "XN8k5S95cYwR", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "classifier = tf.estimator.LinearClassifier(feature_columns=my_numeric_columns)\n", "classifier.train(train_inpf)\n", "\n", "result = classifier.evaluate(test_inpf)\n", "\n", "clear_output()\n", "\n", "for key, value in sorted(result.items()):\n", "  print('%s: %s' % (key, value))" ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "jBRq9_AzcYwU", "colab_type": "text" }, "cell_type": "markdown", "source": [ "#### Categorical columns\n", "\n", "To define a feature column for a categorical feature, create a `CategoricalColumn` using one of the `tf.feature_column.categorical_column*` functions.\n", "\n", "If you know the set of all possible feature values of a column—and there are only a few of them—use `categorical_column_with_vocabulary_list`. Each key in the list is assigned an auto-incremented ID starting from 0. For example, for the `relationship` column we can assign the feature string `Husband` to an integer ID of 0, `Not-in-family` to 1, and so on." ] }, { "metadata": { "id": "0IjqSi9tcYwV", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "relationship = fc.categorical_column_with_vocabulary_list(\n", "    'relationship',\n", "    ['Husband', 'Not-in-family', 'Wife', 'Own-child', 'Unmarried', 'Other-relative'])" ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "-RjoWv-7cYwW", "colab_type": "text" }, "cell_type": "markdown", "source": [ "This creates a sparse one-hot vector from the raw input feature.\n", "\n", "The `input_layer` function we're using is designed for DNN models and expects dense inputs. 
To demonstrate the categorical column, we must wrap it in a `tf.feature_column.indicator_column` to create the dense one-hot output (linear `Estimators` can often skip this dense step).\n", "\n", "Note: The other sparse-to-dense option is `tf.feature_column.embedding_column`.\n", "\n", "Run the input layer, configured with both the `age` and `relationship` columns:" ] }, { "metadata": { "id": "kI43CYlncYwY", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "fc.input_layer(feature_batch, [age, fc.indicator_column(relationship)])" ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "tTudP7WHcYwb", "colab_type": "text" }, "cell_type": "markdown", "source": [ "If we don't know the set of possible values in advance, use `categorical_column_with_hash_bucket` instead:" ] }, { "metadata": { "id": "8pSBaliCcYwb", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "occupation = tf.feature_column.categorical_column_with_hash_bucket(\n", "    'occupation', hash_bucket_size=1000)" ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "fSAPrqQkcYwd", "colab_type": "text" }, "cell_type": "markdown", "source": [ "Here, each possible value in the feature column `occupation` is hashed to an integer ID as it is encountered during training. The example batch has a few different occupations:" ] }, { "metadata": { "id": "dCvQNv36cYwe", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "for item in feature_batch['occupation'].numpy():\n", "  print(item.decode())" ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "KP5hN2rAcYwh", "colab_type": "text" }, "cell_type": "markdown", "source": [ "If we run `input_layer` with the hashed column, we see that the output shape is `(batch_size, hash_bucket_size)`:" ] }, { "metadata": { "id": "0Y16peWacYwh", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "occupation_result = fc.input_layer(feature_batch, [fc.indicator_column(occupation)])\n", "\n", "occupation_result.numpy().shape" ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "HMW2MzWAcYwk", "colab_type": "text" }, "cell_type": "markdown", "source": [ "It's easier to see the actual results if we take the `tf.argmax` over the `hash_bucket_size` dimension. Notice how any duplicate occupations are mapped to the same pseudo-random index:" ] }, { "metadata": { "id": "q_ryRglmcYwk", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "tf.argmax(occupation_result, axis=1).numpy()" ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "j1e5NfyKcYwn", "colab_type": "text" }, "cell_type": "markdown", "source": [ "Note: Hash collisions are unavoidable, but often have minimal impact on model quality. The effect may be noticeable if the hash buckets are being used to compress the input space. See [this notebook](https://colab.research.google.com/github/tensorflow/models/blob/master/samples/outreach/blogs/housing_prices.ipynb) for a more visual example of the effect of these hash collisions.\n", "\n", "No matter how we choose to define a `SparseColumn`, each feature string is mapped into an integer ID by looking up a fixed mapping or by hashing. Under the hood, the `LinearModel` class is responsible for managing the mapping and creating a `tf.Variable` to store the model parameters (model *weights*) for each feature ID." ] },
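{ "metadata": { "id": "vQvMapRel0a1", "colab_type": "text" }, "cell_type": "markdown", "source": [ "To make the fixed-vocabulary mapping concrete, here is a quick sketch, reusing the `feature_batch` from above, that prints each raw `relationship` string next to its vocabulary ID (`Husband` maps to 0, `Not-in-family` to 1, and so on):" ] }, { "metadata": { "id": "vQvMapRel0a2", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "relationship_result = fc.input_layer(\n", "    feature_batch, [fc.indicator_column(relationship)])\n", "\n", "# Each string is looked up in the fixed vocabulary list defined above.\n", "for raw, idx in zip(feature_batch['relationship'].numpy(),\n", "                    tf.argmax(relationship_result, axis=1).numpy()):\n", "  print(raw.decode(), '->', idx)" ], "execution_count": 0, "outputs": [] },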
{ "metadata": { "id": "j2e6NfyLcYwo", "colab_type": "text" }, "cell_type": "markdown", "source": [ "The model parameters are learned through the model training process described later.\n", "\n", "Let's use the same trick to define the other categorical features:" ] }, { "metadata": { "id": "0Z5eUrd_cYwo", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "education = tf.feature_column.categorical_column_with_vocabulary_list(\n", "    'education', [\n", "        'Bachelors', 'HS-grad', '11th', 'Masters', '9th', 'Some-college',\n", "        'Assoc-acdm', 'Assoc-voc', '7th-8th', 'Doctorate', 'Prof-school',\n", "        '5th-6th', '10th', '1st-4th', 'Preschool', '12th'])\n", "\n", "marital_status = tf.feature_column.categorical_column_with_vocabulary_list(\n", "    'marital_status', [\n", "        'Married-civ-spouse', 'Divorced', 'Married-spouse-absent',\n", "        'Never-married', 'Separated', 'Married-AF-spouse', 'Widowed'])\n", "\n", "workclass = tf.feature_column.categorical_column_with_vocabulary_list(\n", "    'workclass', [\n", "        'Self-emp-not-inc', 'Private', 'State-gov', 'Federal-gov',\n", "        'Local-gov', '?', 'Self-emp-inc', 'Without-pay', 'Never-worked'])\n", "\n", "\n", "my_categorical_columns = [relationship, occupation, education, marital_status, workclass]" ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "ASQJM1pEcYwr", "colab_type": "text" }, "cell_type": "markdown", "source": [ "It's easy to use both sets of columns to configure a model that uses all these features:" ] }, { "metadata": { "id": "_i_MLoo9cYws", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "classifier = tf.estimator.LinearClassifier(feature_columns=my_numeric_columns+my_categorical_columns)\n", "classifier.train(train_inpf)\n", "result = classifier.evaluate(test_inpf)\n", "\n", "clear_output()\n", "\n", "for key, value in sorted(result.items()):\n", "  print('%s: %s' % (key, value))" ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "zdKEqF6xcYwv", "colab_type": "text" }, "cell_type": "markdown", "source": [ "### Derived feature columns" ] }, { "metadata": { "id": "RgYaf_48FSU2", "colab_type": "text" }, "cell_type": "markdown", "source": [ "#### Make Continuous Features Categorical through Bucketization\n", "\n", "Sometimes the relationship between a continuous feature and the label is not linear. For example, *age* and *income*—a person's income may grow in the early stage of their career, then the growth may slow at some point, and finally, the income decreases after retirement. In this scenario, using the raw `age` as a real-valued feature column might not be a good choice because the model can only learn one of three cases:\n", "\n", "1. Income always increases at some rate as age grows (positive correlation),\n", "2. Income always decreases at some rate as age grows (negative correlation), or\n", "3. Income stays the same regardless of age (no correlation).\n", "\n", "If we want to learn the fine-grained correlation between income and each age group separately, we can leverage *bucketization*. Bucketization is a process of dividing the entire range of a continuous feature into a set of consecutive buckets, and then converting the original numerical feature into a bucket ID (as a categorical feature) depending on which bucket that value falls into. 
So, we can define a `bucketized_column` over `age` as:" ] }, { "metadata": { "id": "KT4pjD9AcYww", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "age_buckets = tf.feature_column.bucketized_column(\n", "    age, boundaries=[18, 25, 30, 35, 40, 45, 50, 55, 60, 65])" ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "S-XOscrEcYwx", "colab_type": "text" }, "cell_type": "markdown", "source": [ "`boundaries` is a list of bucket boundaries. In this case, there are 10 boundaries, resulting in 11 age group buckets (from age 17 and below, 18-24, 25-29, ..., to 65 and over).\n", "\n", "With bucketing, the model sees each bucket as a one-hot feature:" ] }, { "metadata": { "id": "Lr40vm3qcYwy", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "fc.input_layer(feature_batch, [age, age_buckets]).numpy()" ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "Z_tQI9j8cYw1", "colab_type": "text" }, "cell_type": "markdown", "source": [ "#### Learn complex relationships with crossed columns\n", "\n", "Using each base feature column separately may not be enough to explain the data. For example, the correlation between education and the label (earning > 50,000 dollars) may be different for different occupations. Therefore, if we only learn a single model weight for `education=\"Bachelors\"` and `education=\"Masters\"`, we won't capture every education-occupation combination (e.g. distinguishing between `education=\"Bachelors\" AND occupation=\"Exec-managerial\"` and `education=\"Bachelors\" AND occupation=\"Craft-repair\"`).\n", "\n", "To learn the differences between different feature combinations, we can add *crossed feature columns* to the model:" ] }, { "metadata": { "id": "IAPhPzXscYw1", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "education_x_occupation = tf.feature_column.crossed_column(\n", "    ['education', 'occupation'], hash_bucket_size=1000)" ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "UeTxMunbcYw5", "colab_type": "text" }, "cell_type": "markdown", "source": [ "We can also create a `crossed_column` over more than two columns. Each constituent column can be either a base feature column that is categorical (`SparseColumn`), a bucketized real-valued feature column, or even another `CrossedColumn`. For example:" ] }, { "metadata": { "id": "y8UaBld9cYw7", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "age_buckets_x_education_x_occupation = tf.feature_column.crossed_column(\n", "    [age_buckets, 'education', 'occupation'], hash_bucket_size=1000)" ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "HvKmW6U5cYw8", "colab_type": "text" }, "cell_type": "markdown", "source": [ "These crossed columns always use hash buckets to avoid the exponential explosion in the number of categories, and put control over the number of model weights in the hands of the user.\n", "\n", "For a visual example of the effect of hash buckets with crossed columns, see [this notebook](https://colab.research.google.com/github/tensorflow/models/blob/master/samples/outreach/blogs/housing_prices.ipynb).\n" ] },
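{ "metadata": { "id": "HvCrossDemo1", "colab_type": "text" }, "cell_type": "markdown", "source": [ "As with the hashed `occupation` column, we can peek at where a crossed column lands each example. Here is a quick sketch reusing `feature_batch`; each `education`-`occupation` pair is hashed into one of the 1000 buckets:" ] }, { "metadata": { "id": "HvCrossDemo2", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "education_x_occupation_result = fc.input_layer(\n", "    feature_batch, [fc.indicator_column(education_x_occupation)])\n", "\n", "# One pseudo-random bucket index per example in the batch.\n", "tf.argmax(education_x_occupation_result, axis=1).numpy()" ], "execution_count": 0, "outputs": [] },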
{ "metadata": { "id": "HtjpheB6cYw9", "colab_type": "text" }, "cell_type": "markdown", "source": [ "## Define the logistic regression model\n", "\n", "After processing the input data and defining all the feature columns, we can put them together and build a *logistic regression* model. The previous section showed several types of base and derived feature columns, including:\n", "\n", "* `CategoricalColumn`\n", "* `NumericColumn`\n", "* `BucketizedColumn`\n", "* `CrossedColumn`\n", "\n", "All of these are subclasses of the abstract `FeatureColumn` class and can be added to the `feature_columns` field of a model:" ] }, { "metadata": { "id": "Klmf3OxpcYw-", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "import tempfile\n", "\n", "base_columns = [\n", "    education, marital_status, relationship, workclass, occupation,\n", "    age_buckets,\n", "]\n", "\n", "crossed_columns = [\n", "    tf.feature_column.crossed_column(\n", "        ['education', 'occupation'], hash_bucket_size=1000),\n", "    tf.feature_column.crossed_column(\n", "        [age_buckets, 'education', 'occupation'], hash_bucket_size=1000),\n", "]\n", "\n", "model = tf.estimator.LinearClassifier(\n", "    model_dir=tempfile.mkdtemp(),\n", "    feature_columns=base_columns + crossed_columns,\n", "    optimizer=tf.train.FtrlOptimizer(learning_rate=0.1))" ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "jRhnPxUucYxC", "colab_type": "text" }, "cell_type": "markdown", "source": [ "The model automatically learns a bias term, which controls the prediction made without observing any features. The learned model files are stored in `model_dir`.\n", "\n", "## Train and evaluate the model\n", "\n", "With all the features in place, let's train the model. Training a model is just a single command using the `tf.estimator` API:" ] }, { "metadata": { "id": "ZlrIBuoecYxD", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "train_inpf = functools.partial(census_dataset.input_fn, train_file,\n", "                               num_epochs=40, shuffle=True, batch_size=64)\n", "\n", "model.train(train_inpf)\n", "\n", "clear_output()  # used for notebook display" ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "IvY3a9pzcYxH", "colab_type": "text" }, "cell_type": "markdown", "source": [ "After the model is trained, evaluate the accuracy of the model by predicting the labels of the holdout data:" ] }, { "metadata": { "id": "L9nVJEO8cYxI", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "results = model.evaluate(test_inpf)\n", "\n", "clear_output()\n", "\n", "for key, value in sorted(results.items()):\n", "  print('%s: %0.2f' % (key, value))" ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "E0fAibNDcYxL", "colab_type": "text" }, "cell_type": "markdown", "source": [ "The first line of the output should display something like `accuracy: 0.84`, which means the accuracy is 84%." ] },
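{ "metadata": { "id": "E1biasPeek01", "colab_type": "text" }, "cell_type": "markdown", "source": [ "The trained `Estimator` also exposes its learned variables. As a quick sketch, we can list a few of them and read back the bias term. The variable name below assumes the canned `LinearClassifier`'s default naming; check `get_variable_names()` if it differs:" ] }, { "metadata": { "id": "E1biasPeek02", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "print(model.get_variable_names()[:5])\n", "\n", "# The learned bias term (assumes the default variable naming).\n", "print(model.get_variable_value('linear/linear_model/bias_weights'))" ], "execution_count": 0, "outputs": [] },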
{ "metadata": { "id": "E2fAibNDcYxM", "colab_type": "text" }, "cell_type": "markdown", "source": [ "You can try using more features and transformations to see if you can do better!\n", "\n", "After the model is evaluated, we can use it to predict whether an individual has an annual income of over 50,000 dollars given that individual's information.\n", "\n", "Let's look in more detail at how the model performed:" ] }, { "metadata": { "id": "8R5bz5CxcYxL", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "import numpy as np\n", "\n", "predict_df = test_df[:20].copy()\n", "\n", "pred_iter = model.predict(\n", "    lambda: easy_input_function(predict_df, label_key='income_bracket',\n", "                                num_epochs=1, shuffle=False, batch_size=10))\n", "\n", "classes = np.array(['<=50K', '>50K'])\n", "pred_class_id = []\n", "\n", "for pred_dict in pred_iter:\n", "  pred_class_id.append(pred_dict['class_ids'])\n", "\n", "predict_df['predicted_class'] = classes[np.array(pred_class_id)]\n", "predict_df['correct'] = predict_df['predicted_class'] == predict_df['income_bracket']\n", "\n", "clear_output()\n", "\n", "predict_df[['income_bracket', 'predicted_class', 'correct']]" ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "N_uCpFTicYxN", "colab_type": "text" }, "cell_type": "markdown", "source": [ "For a working end-to-end example, download our [example code](https://github.com/tensorflow/models/tree/master/official/wide_deep/census_main.py) and set the `model_type` flag to `wide`." ] }, { "metadata": { "id": "oyKy1lM_3gkL", "colab_type": "text" }, "cell_type": "markdown", "source": [ "## Adding Regularization to Prevent Overfitting\n", "\n", "Regularization is a technique used to avoid overfitting. Overfitting happens when a model performs well on the data it is trained on, but worse on test data that the model has not seen before. Overfitting can occur when a model is excessively complex, such as having too many parameters relative to the amount of training data observed. Regularization allows you to control the model's complexity and make the model more generalizable to unseen data.\n", "\n", "You can add L1 and L2 regularization to the model with the following code:" ] }, { "metadata": { "id": "lzMUSBQ03hHx", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "model_l1 = tf.estimator.LinearClassifier(\n", "    feature_columns=base_columns + crossed_columns,\n", "    optimizer=tf.train.FtrlOptimizer(\n", "        learning_rate=0.1,\n", "        l1_regularization_strength=10.0,\n", "        l2_regularization_strength=0.0))\n", "\n", "model_l1.train(train_inpf)\n", "\n", "results = model_l1.evaluate(test_inpf)\n", "clear_output()\n", "for key in sorted(results):\n", "  print('%s: %0.2f' % (key, results[key]))" ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "ofmPL212JIy2", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "model_l2 = tf.estimator.LinearClassifier(\n", "    feature_columns=base_columns + crossed_columns,\n", "    optimizer=tf.train.FtrlOptimizer(\n", "        learning_rate=0.1,\n", "        l1_regularization_strength=0.0,\n", "        l2_regularization_strength=10.0))\n", "\n", "model_l2.train(train_inpf)\n", "\n", "results = model_l2.evaluate(test_inpf)\n", "clear_output()\n", "for key in sorted(results):\n", "  print('%s: %0.2f' % (key, results[key]))" ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "Lp1Rfy_k4e7w", "colab_type": "text" }, "cell_type": "markdown", "source": [ "These regularized models don't perform much better than the base model." ] },
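{ "metadata": { "id": "Lp2RegCompm1", "colab_type": "text" }, "cell_type": "markdown", "source": [ "To quantify that claim, here is a quick side-by-side sketch that simply re-evaluates each of the three trained models on the test set (the `model_accuracy` dict is illustrative only):" ] }, { "metadata": { "id": "Lp2RegCompm2", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "model_accuracy = {name: m.evaluate(test_inpf)['accuracy']\n", "                  for name, m in [('base', model),\n", "                                  ('l1', model_l1),\n", "                                  ('l2', model_l2)]}\n", "\n", "clear_output()\n", "\n", "for name, accuracy in sorted(model_accuracy.items()):\n", "  print('%s: %0.4f' % (name, accuracy))" ], "execution_count": 0, "outputs": [] },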
{ "metadata": { "id": "Lp3Rfy_k4e7x", "colab_type": "text" }, "cell_type": "markdown", "source": [ "Let's look at the models' weight distributions to better see the effect of the regularization:" ] }, { "metadata": { "id": "Wb6093N04XlS", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "def get_flat_weights(model):\n", "  weight_names = [\n", "      name for name in model.get_variable_names()\n", "      if \"linear_model\" in name and \"Ftrl\" not in name]\n", "\n", "  weight_values = [model.get_variable_value(name) for name in weight_names]\n", "\n", "  weights_flat = np.concatenate([item.flatten() for item in weight_values], axis=0)\n", "\n", "  return weights_flat\n", "\n", "weights_flat = get_flat_weights(model)\n", "weights_flat_l1 = get_flat_weights(model_l1)\n", "weights_flat_l2 = get_flat_weights(model_l2)" ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "GskJmtfmL0p-", "colab_type": "text" }, "cell_type": "markdown", "source": [ "The models have many zero-valued weights caused by unused hash buckets (there are many more hash buckets than categories in some columns). We can mask these weights when viewing the weight distributions:" ] }, { "metadata": { "id": "rM3agZe3MT3D", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "weight_mask = weights_flat != 0\n", "\n", "weights_base = weights_flat[weight_mask]\n", "weights_l1 = weights_flat_l1[weight_mask]\n", "weights_l2 = weights_flat_l2[weight_mask]" ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "NqBpxLLQNEBE", "colab_type": "text" }, "cell_type": "markdown", "source": [ "Now plot the distributions:" ] }, { "metadata": { "id": "IdFK7wWa5_0K", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "plt.figure()\n", "_ = plt.hist(weights_base, bins=np.linspace(-3, 3, 30))\n", "plt.title('Base Model')\n", "plt.ylim([0, 500])\n", "\n", "plt.figure()\n", "_ = plt.hist(weights_l1, bins=np.linspace(-3, 3, 30))\n", "plt.title('L1 - Regularization')\n", "plt.ylim([0, 500])\n", "\n", "plt.figure()\n", "_ = plt.hist(weights_l2, bins=np.linspace(-3, 3, 30))\n", "plt.title('L2 - Regularization')\n", "_ = plt.ylim([0, 500])" ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "Mv6knhFa5-iJ", "colab_type": "text" }, "cell_type": "markdown", "source": [ "Both types of regularization squeeze the distribution of weights towards zero. L2 regularization has a greater effect in the tails of the distribution, eliminating extreme weights. L1 regularization produces more exactly-zero values; in this case, it sets ~200 weights to zero." ] },
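{ "metadata": { "id": "Mv7ZeroCnt01", "colab_type": "text" }, "cell_type": "markdown", "source": [ "We can check that last claim directly by counting the exact zeros among the weights that were non-zero in the base model, using the masked arrays from above:" ] }, { "metadata": { "id": "Mv7ZeroCnt02", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "# Number of base-model weights driven exactly to zero by each penalty.\n", "print('l1 zeros:', np.sum(weights_l1 == 0))\n", "print('l2 zeros:', np.sum(weights_l2 == 0))" ], "execution_count": 0, "outputs": [] } ] }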