{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Tabular models"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from fastai2.tabular.all import *"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Tabular data should be in a Pandas `DataFrame`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"path = untar_data(URLs.ADULT_SAMPLE)\n",
"df = pd.read_csv(path/'adult.csv')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Target column to predict\n",
"dep_var = 'salary'\n",
"# Categorical feature columns\n",
"cat_names = [\n",
"    'workclass', 'education', 'marital-status',\n",
"    'occupation', 'relationship', 'race',\n",
"]\n",
"# Continuous feature columns\n",
"cont_names = ['age', 'fnlwgt', 'education-num']\n",
"# Preprocessing steps handed to TabularPandas\n",
"procs = [Categorify, FillMissing, Normalize]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#test = TabularList.from_df(df.iloc[800:1000].copy(), path=path, cat_names=cat_names, cont_names=cont_names)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"splits = IndexSplitter(list(range(800,1000)))(range_of(df))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#splits = (L(splits[0], use_list=True), L(splits[1], use_list=True))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Build the processed tabular dataset; reuse `dep_var` so the target column is named in one place only\n",
"to = TabularPandas(df, procs, cat_names, cont_names, y_names=dep_var, splits=splits)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dls = to.dataloaders(bs=64)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\">\n",
"      <th></th>\n",
"      <th>workclass</th>\n",
"      <th>education</th>\n",
"      <th>marital-status</th>\n",
"      <th>occupation</th>\n",
"      <th>relationship</th>\n",
"      <th>race</th>\n",
"      <th>age_na</th>\n",
"      <th>fnlwgt_na</th>\n",
"      <th>education-num_na</th>\n",
"      <th>age</th>\n",
"      <th>fnlwgt</th>\n",
"      <th>education-num</th>\n",
"      <th>salary</th>\n",
"    </tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr>\n",
"      <th>0</th>\n",
"      <td>Private</td>\n",
"      <td>Bachelors</td>\n",
"      <td>Never-married</td>\n",
"      <td>Machine-op-inspct</td>\n",
"      <td>Not-in-family</td>\n",
"      <td>Asian-Pac-Islander</td>\n",
"      <td>False</td>\n",
"      <td>False</td>\n",
"      <td>False</td>\n",
"      <td>27.0</td>\n",
"      <td>104457.001298</td>\n",
"      <td>13.0</td>\n",
"      <td>&lt;50k</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>1</th>\n",
"      <td>Self-emp-not-inc</td>\n",
"      <td>HS-grad</td>\n",
"      <td>Never-married</td>\n",
"      <td>Farming-fishing</td>\n",
"      <td>Own-child</td>\n",
"      <td>White</td>\n",
"      <td>False</td>\n",
"      <td>False</td>\n",
"      <td>False</td>\n",
"      <td>20.0</td>\n",
"      <td>306709.997905</td>\n",
"      <td>9.0</td>\n",
"      <td>&lt;50k</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>2</th>\n",
"      <td>Private</td>\n",
"      <td>Bachelors</td>\n",
"      <td>Married-civ-spouse</td>\n",
"      <td>Prof-specialty</td>\n",
"      <td>Husband</td>\n",
"      <td>White</td>\n",
"      <td>False</td>\n",
"      <td>False</td>\n",
"      <td>False</td>\n",
"      <td>40.0</td>\n",
"      <td>209547.000700</td>\n",
"      <td>13.0</td>\n",
"      <td>&gt;=50k</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>3</th>\n",
"      <td>Private</td>\n",
"      <td>Bachelors</td>\n",
"      <td>Never-married</td>\n",
"      <td>Prof-specialty</td>\n",
"      <td>Not-in-family</td>\n",
"      <td>White</td>\n",
"      <td>False</td>\n",
"      <td>False</td>\n",
"      <td>False</td>\n",
"      <td>26.0</td>\n",
"      <td>184120.000065</td>\n",
"      <td>13.0</td>\n",
"      <td>&lt;50k</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>4</th>\n",
"      <td>Private</td>\n",
"      <td>HS-grad</td>\n",
"      <td>Married-civ-spouse</td>\n",
"      <td>Adm-clerical</td>\n",
"      <td>Husband</td>\n",
"      <td>White</td>\n",
"      <td>False</td>\n",
"      <td>False</td>\n",
"      <td>False</td>\n",
"      <td>38.0</td>\n",
"      <td>248886.000709</td>\n",
"      <td>9.0</td>\n",
"      <td>&lt;50k</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>5</th>\n",
"      <td>Private</td>\n",
"      <td>HS-grad</td>\n",
"      <td>Never-married</td>\n",
"      <td>Machine-op-inspct</td>\n",
"      <td>Not-in-family</td>\n",
"      <td>Asian-Pac-Islander</td>\n",
"      <td>False</td>\n",
"      <td>False</td>\n",
"      <td>False</td>\n",
"      <td>28.0</td>\n",
"      <td>149769.001037</td>\n",
"      <td>9.0</td>\n",
"      <td>&lt;50k</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>6</th>\n",
"      <td>Private</td>\n",
"      <td>Bachelors</td>\n",
"      <td>Married-civ-spouse</td>\n",
"      <td>Exec-managerial</td>\n",
"      <td>Wife</td>\n",
"      <td>White</td>\n",
"      <td>False</td>\n",
"      <td>False</td>\n",
"      <td>False</td>\n",
"      <td>40.0</td>\n",
"      <td>225659.999761</td>\n",
"      <td>13.0</td>\n",
"      <td>&gt;=50k</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>7</th>\n",
"      <td>Private</td>\n",
"      <td>Some-college</td>\n",
"      <td>Married-civ-spouse</td>\n",
"      <td>Craft-repair</td>\n",
"      <td>Husband</td>\n",
"      <td>Asian-Pac-Islander</td>\n",
"      <td>False</td>\n",
"      <td>False</td>\n",
"      <td>False</td>\n",
"      <td>27.0</td>\n",
"      <td>100668.997583</td>\n",
"      <td>10.0</td>\n",
"      <td>&gt;=50k</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>8</th>\n",
"      <td>Private</td>\n",
"      <td>Masters</td>\n",
"      <td>Married-civ-spouse</td>\n",
"      <td>Exec-managerial</td>\n",
"      <td>Husband</td>\n",
"      <td>White</td>\n",
"      <td>False</td>\n",
"      <td>False</td>\n",
"      <td>False</td>\n",
"      <td>46.0</td>\n",
"      <td>55720.003421</td>\n",
"      <td>14.0</td>\n",
"      <td>&gt;=50k</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>9</th>\n",
"      <td>?</td>\n",
"      <td>Assoc-acdm</td>\n",
"      <td>Married-civ-spouse</td>\n",
"      <td>?</td>\n",
"      <td>Wife</td>\n",
"      <td>White</td>\n",
"      <td>False</td>\n",
"      <td>False</td>\n",
"      <td>False</td>\n",
"      <td>35.0</td>\n",
"      <td>144172.001567</td>\n",
"      <td>12.0</td>\n",
"      <td>&lt;50k</td>\n",
"    </tr>\n",
"  </tbody>\n",
"</table>"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"dls.show_batch()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"learn = tabular_learner(dls, layers=[200,100], metrics=accuracy)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
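"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: left;\">\n",
"      <th>epoch</th>\n",
"      <th>train_loss</th>\n",
"      <th>valid_loss</th>\n",
"      <th>accuracy</th>\n",
"      <th>time</th>\n",
"    </tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr>\n",
"      <td>0</td>\n",
"      <td>0.372055</td>\n",
"      <td>0.369126</td>\n",
"      <td>0.840000</td>\n",
"      <td>00:10</td>\n",
"    </tr>\n",
"  </tbody>\n",
"</table>"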
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn.fit(1, 1e-2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Inference (TODO)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"row = df.iloc[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"learn.predict(row)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"jupytext": {
"split_at_heading": true
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}