{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Tabular training\n",
"\n",
"> How to use the tabular application in fastai"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To illustrate the tabular application, we will use the example of the [Adult dataset](https://archive.ics.uci.edu/ml/datasets/Adult), where we have to predict whether a person earns more or less than $50k per year from general information such as their age, education and occupation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from fastai2.tabular.all import *"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can download a sample of this dataset with the usual command:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(#3) [Path('/home/sgugger/.fastai/data/adult_sample/models'),Path('/home/sgugger/.fastai/data/adult_sample/adult.csv'),Path('/home/sgugger/.fastai/data/adult_sample/export.pkl')]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"path = untar_data(URLs.ADULT_SAMPLE)\n",
"path.ls()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Then we can have a look at how the data is structured:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" age workclass fnlwgt education education-num \\\n",
"0 49 Private 101320 Assoc-acdm 12.0 \n",
"1 44 Private 236746 Masters 14.0 \n",
"2 38 Private 96185 HS-grad NaN \n",
"3 38 Self-emp-inc 112847 Prof-school 15.0 \n",
"4 42 Self-emp-not-inc 82297 7th-8th NaN \n",
"\n",
" marital-status occupation relationship race \\\n",
"0 Married-civ-spouse NaN Wife White \n",
"1 Divorced Exec-managerial Not-in-family White \n",
"2 Divorced NaN Unmarried Black \n",
"3 Married-civ-spouse Prof-specialty Husband Asian-Pac-Islander \n",
"4 Married-civ-spouse Other-service Wife Black \n",
"\n",
" sex capital-gain capital-loss hours-per-week native-country salary \n",
"0 Female 0 1902 40 United-States >=50k \n",
"1 Male 10520 0 45 United-States >=50k \n",
"2 Female 0 0 32 United-States <50k \n",
"3 Male 0 0 40 United-States >=50k \n",
"4 Female 0 0 50 United-States <50k "
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df = pd.read_csv(path/'adult.csv')\n",
"df.head()"
]
},
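{
"cell_type": "markdown",
"metadata": {},
"source": [
"The preview already shows missing values (`NaN`) in `education-num` and `occupation`. Counting them per column is a quick way to see which preprocessing will be needed; on the full dataset this is simply `df.isna().sum()`. So that it runs on its own, the sketch below does the same check on a small frame copied by hand from two columns of the five rows above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"# Two columns of the five rows shown above, copied by hand\n",
"head = pd.DataFrame({\n",
"    'education-num': [12.0, 14.0, np.nan, 15.0, np.nan],\n",
"    'occupation': [np.nan, 'Exec-managerial', np.nan, 'Prof-specialty', 'Other-service'],\n",
"})\n",
"# Number of missing values in each column\n",
"missing = head.isna().sum()\n",
"print(missing)"
]
},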
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Some of the columns are continuous (like `age`) and we will treat them as floating-point numbers that we can feed to our model directly. Others are categorical (like `workclass` or `education`) and we will convert each of their values to a unique integer index that is fed to an embedding layer. We can specify the names of our categorical and continuous columns, as well as the name of the dependent variable, in the `TabularDataLoaders` factory methods:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dls = TabularDataLoaders.from_csv(path/'adult.csv', path=path, y_names=\"salary\",\n",
" cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race'],\n",
" cont_names = ['age', 'fnlwgt', 'education-num'],\n",
" procs = [Categorify, FillMissing, Normalize])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The last part is the list of pre-processors we apply to our data:\n",
"\n",
"- `Categorify` is going to take every categorical variable, build a map between its unique categories and integers, then replace the values by the corresponding index (index 0 is reserved for missing values).\n",
"- `FillMissing` will fill the missing values in the continuous variables with the median of the existing values (you can choose a specific value if you prefer), and adds a boolean column such as `education-num_na` that records which values were missing.\n",
"- `Normalize` will normalize the continuous variables (subtract the mean and divide by the standard deviation).\n",
"\n",
"The `show_batch` method works like for every other application:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"   workclass  education     marital-status      occupation         relationship   race                education-num_na  age        fnlwgt         education-num  salary\n",
"0  ?          Some-college  Never-married       ?                  Own-child      White               False             22.000000  32731.996436   10.0           <50k\n",
"1  Private    7th-8th       Married-civ-spouse  Machine-op-inspct  Husband        White               False             44.000000  99202.998578   4.0            <50k\n",
"2  Private    HS-grad       Divorced            Farming-fishing    Not-in-family  White               False             63.000001  117680.996997  9.0            <50k\n",
"3  Private    HS-grad       Married-civ-spouse  Machine-op-inspct  Husband        White               False             33.000000  194141.000170  9.0            <50k\n",
"4  Private    Assoc-voc     Divorced            Transport-moving   Not-in-family  White               False             35.000000  172570.999732  11.0           <50k\n",
"5  Local-gov  HS-grad       Divorced            Exec-managerial    Unmarried      Amer-Indian-Eskimo  False             43.000000  196308.000036  9.0            <50k\n",
"6  Private    HS-grad       Never-married       Exec-managerial    Not-in-family  White               False             43.000000  336642.996235  9.0            <50k\n",
"7  Private    HS-grad       Never-married       Other-service      Not-in-family  White               False             27.000000  158156.001081  9.0            <50k\n",
"8  ?          Bachelors     Never-married       ?                  Unmarried      White               False             26.000000  130832.001756  13.0           <50k\n",
"9  Private    Assoc-voc     Married-civ-spouse  Tech-support       Husband        White               False             27.000000  62737.003461   11.0           <50k"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"dls.show_batch()"
]
},
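{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the three preprocessors more concrete, here is a rough pandas approximation of what they do to a small training set. It is only a sketch under our own assumptions, not fastai's implementation: `apply_procs` is a hypothetical helper, the statistics are computed on the frame it receives, index 0 is kept for missing categories, and the new `<name>_na` column is treated as one more categorical column."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"def apply_procs(train, cat_names, cont_names):\n",
"    'Hypothetical helper: a rough pandas approximation of FillMissing, Categorify and Normalize.'\n",
"    df = train.copy()\n",
"    cat_names = list(cat_names)\n",
"    # FillMissing: flag missing values in a new boolean column, then fill them with the median\n",
"    for c in cont_names:\n",
"        if df[c].isna().any():\n",
"            df[c + '_na'] = df[c].isna()\n",
"            cat_names.append(c + '_na')\n",
"        df[c] = df[c].fillna(df[c].median())\n",
"    # Categorify: replace each category by an integer index, keeping 0 for missing values\n",
"    for c in cat_names:\n",
"        vocab = sorted(df[c].dropna().unique().tolist())\n",
"        mapping = {v: i + 1 for i, v in enumerate(vocab)}\n",
"        df[c] = df[c].map(mapping).fillna(0).astype('int64')\n",
"    # Normalize: subtract the mean and divide by the (population) standard deviation\n",
"    for c in cont_names:\n",
"        df[c] = (df[c] - df[c].mean()) / df[c].std(ddof=0)\n",
"    return df, cat_names\n",
"\n",
"train = pd.DataFrame({'workclass': ['Private', 'State-gov', None, 'Private'],\n",
"                      'age': [20.0, 30.0, np.nan, 50.0]})\n",
"processed, cats = apply_procs(train, ['workclass'], ['age'])\n",
"print(cats)\n",
"print(processed)"
]
},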
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can define a model using the `tabular_learner` function. When we define our model, `fastai` will try to infer the loss function from the type of the dependent variable we passed as `y_names` earlier.\n",
"\n",
"**Note**: Sometimes with tabular data, your dependent variable may already be encoded as integers (such as 0 and 1). In such a case you should explicitly pass `y_block=CategoryBlock` to the `TabularDataLoaders` factory method so `fastai` won't presume you are doing regression."
]
},
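{
"cell_type": "markdown",
"metadata": {},
"source": [
"The heuristic behind this note can be sketched with pandas. As far as we understand it, when no `y_block` is given, a dependent variable stored as strings is treated as categorical, while a numeric one is treated as a regression target, even if it only contains 0 and 1. The helper `guess_task` is hypothetical and only mirrors that rule; it is not part of fastai."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"def guess_task(df, y_name):\n",
"    'Hypothetical helper: numeric targets look like regression, anything else like classification.'\n",
"    if pd.api.types.is_numeric_dtype(df[y_name]):\n",
"        return 'regression'\n",
"    return 'classification'\n",
"\n",
"targets = pd.DataFrame({'salary': ['<50k', '>=50k', '<50k'],\n",
"                        'salary_code': [0, 1, 0]})\n",
"print(guess_task(targets, 'salary'))       # string labels\n",
"print(guess_task(targets, 'salary_code'))  # integer codes, hence the need for y_block=CategoryBlock"
]
},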
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"learn = tabular_learner(dls, metrics=accuracy)"
]
},
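{
"cell_type": "markdown",
"metadata": {},
"source": [
"`tabular_learner` feeds each categorical column to its own embedding layer, which is simply a lookup table with one learned vector per category index. To our understanding, fastai picks the width of each embedding with the rule of thumb `min(600, round(1.6 * n_cat**0.56))`, where `n_cat` counts the categories including the reserved missing-value index. The NumPy sketch below illustrates that rule and the lookup itself; `embed` is a hypothetical helper and its random weights stand in for trained ones."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def emb_sz_rule(n_cat):\n",
"    # Rule-of-thumb embedding width for a column with n_cat categories\n",
"    return min(600, round(1.6 * n_cat ** 0.56))\n",
"\n",
"def embed(codes, n_cat, seed=0):\n",
"    'Hypothetical helper: look up one row of a (random) embedding table per category index.'\n",
"    rng = np.random.default_rng(seed)\n",
"    table = rng.normal(size=(n_cat, emb_sz_rule(n_cat)))\n",
"    return table[np.asarray(codes)]\n",
"\n",
"codes = [1, 2, 0, 1]  # e.g. the Categorify output for one column\n",
"vectors = embed(codes, n_cat=3)\n",
"print(emb_sz_rule(3), emb_sz_rule(10), vectors.shape)"
]
},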
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And we can train that model with the `fit_one_cycle` method (the `fine_tune` method won't be useful here since we don't have a pretrained model)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"epoch  train_loss  valid_loss  accuracy  time\n",
"0      0.366727    0.351524    0.835842  00:05"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn.fit_one_cycle(1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can then have a look at some predictions:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"   workclass  education  marital-status  occupation  relationship  race  education-num_na        age     fnlwgt  education-num  salary  salary_pred\n",
"0        5.0       12.0             3.0        15.0           1.0    5.0               1.0  -0.333356  -0.900977      -0.419934      1.0         0.0\n",
"1        7.0       12.0             5.0         6.0           5.0    5.0               1.0   0.916167  -1.457755      -0.419934      0.0         0.0\n",
"2        5.0       10.0             3.0         2.0           1.0    5.0               1.0  -0.774364  -0.030944       1.150726      0.0         0.0\n",
"3        5.0       13.0             3.0         5.0           1.0    5.0               1.0  -0.259855  -0.668491       1.543390      0.0         1.0\n",
"4        5.0       13.0             1.0        13.0           2.0    5.0               1.0   0.622161   0.409060       1.543390      1.0         0.0\n",
"5        3.0       16.0             3.0         4.0           1.0    5.0               1.0   0.254654  -0.870132      -0.027269      1.0         0.0\n",
"6        5.0       12.0             5.0        13.0           2.0    5.0               1.0  -0.259855  -0.464552      -0.419934      0.0         0.0\n",
"7        5.0        9.0             3.0         4.0           1.0    5.0               1.0   0.989668  -0.430562       0.365396      1.0         1.0\n",
"8        6.0       16.0             3.0         4.0           1.0    3.0               1.0  -0.627362  -0.110140      -0.027269      0.0         0.0"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn.show_results()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Or use the `predict` method on a single row:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"( workclass education marital-status occupation relationship race \\\n",
" 0 5.0 8.0 3.0 0.0 6.0 5.0 \n",
" \n",
" education-num_na age fnlwgt education-num salary \n",
" 0 1.0 0.769164 -0.835926 0.758061 0.0 ,\n",
" tensor(0),\n",
" tensor([0.5200, 0.4800]))"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"learn.predict(df.iloc[0])"
]
},
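{
"cell_type": "markdown",
"metadata": {},
"source": [
"The last two elements returned by `predict` are the index of the predicted class and the probability of each class. Assuming the classes are ordered alphabetically, as category mappings usually are (so `<50k` comes before `>=50k`), the index can be turned back into a label as sketched below; the probabilities are copied from the output above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Assumed class order: sorted labels, so '<50k' (index 0) comes before '>=50k' (index 1)\n",
"vocab = sorted(['>=50k', '<50k'])\n",
"probs = np.array([0.5200, 0.4800])  # probabilities printed by learn.predict above\n",
"pred_idx = int(probs.argmax())\n",
"print(vocab, pred_idx, vocab[pred_idx])"
]
},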
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To get predictions on a new dataframe, you can use the `test_dl` method of the `DataLoaders`. That dataframe does not need to contain the dependent variable among its columns."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"test_df = df.copy()\n",
"test_df.drop(['salary'], axis=1, inplace=True)\n",
"dl = learn.dls.test_dl(test_df)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Then `Learner.get_preds` will give you the predictions:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(tensor([[0.5200, 0.4800],\n",
" [0.5536, 0.4464],\n",
" [0.9767, 0.0233],\n",
" ...,\n",
" [0.6025, 0.3975],\n",
" [0.7228, 0.2772],\n",
" [0.5157, 0.4843]]), None)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"learn.get_preds(dl=dl)"
]
}
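,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`get_preds` returns a tensor with one row per row of `test_df` and one column per class; the second element is `None` because `test_df` has no labels. Picking the most likely class in each row gives the predicted labels. The NumPy sketch below does this on the first three rows printed above, with the same assumed alphabetical class order."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"vocab = ['<50k', '>=50k']  # assumed class order (sorted labels)\n",
"probs = np.array([[0.5200, 0.4800],\n",
"                  [0.5536, 0.4464],\n",
"                  [0.9767, 0.0233]])  # first rows of the get_preds output above\n",
"pred_idx = probs.argmax(axis=1)  # most likely class for each row\n",
"labels = [vocab[i] for i in pred_idx]\n",
"print(labels)"
]
}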
],
"metadata": {
"jupytext": {
"split_at_heading": true
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}