{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Practical Deep Learning for Coders, v3"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Lesson4_tabular"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Tabular models\n",
"# Tabular(表格)模型"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from fastai.tabular import *"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Tabular data should be in a Pandas `DataFrame`.\n",
"\n",
"Tabular数据应该放在Pandas的`DataFrame`中。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"path = untar_data(URLs.ADULT_SAMPLE)\n",
"df = pd.read_csv(path/'adult.csv')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dep_var = 'salary'\n",
"cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race']\n",
"cont_names = ['age', 'fnlwgt', 'education-num']\n",
"procs = [FillMissing, Categorify, Normalize]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"test = TabularList.from_df(df.iloc[800:1000].copy(), path=path, cat_names=cat_names, cont_names=cont_names)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"data = (TabularList.from_df(df, path=path, cat_names=cat_names, cont_names=cont_names, procs=procs)\n",
" .split_by_idx(list(range(800,1000)))\n",
" .label_from_df(cols=dep_var)\n",
" .add_test(test)\n",
" .databunch())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table>\n",
"  <tr>\n",
"    <th>workclass</th>\n",
"    <th>education</th>\n",
"    <th>marital-status</th>\n",
"    <th>occupation</th>\n",
"    <th>relationship</th>\n",
"    <th>race</th>\n",
"    <th>education-num_na</th>\n",
"    <th>age</th>\n",
"    <th>fnlwgt</th>\n",
"    <th>education-num</th>\n",
"    <th>target</th>\n",
"  </tr>\n",
"  <tr>\n",
"    <td>Private</td>\n",
"    <td>HS-grad</td>\n",
"    <td>Never-married</td>\n",
"    <td>Sales</td>\n",
"    <td>Not-in-family</td>\n",
"    <td>White</td>\n",
"    <td>False</td>\n",
"    <td>-1.2158</td>\n",
"    <td>1.1004</td>\n",
"    <td>-0.4224</td>\n",
"    <td>&lt;50k</td>\n",
"  </tr>\n",
"  <tr>\n",
"    <td>?</td>\n",
"    <td>HS-grad</td>\n",
"    <td>Widowed</td>\n",
"    <td>?</td>\n",
"    <td>Not-in-family</td>\n",
"    <td>White</td>\n",
"    <td>False</td>\n",
"    <td>1.8627</td>\n",
"    <td>0.0976</td>\n",
"    <td>-0.4224</td>\n",
"    <td>&lt;50k</td>\n",
"  </tr>\n",
"  <tr>\n",
"    <td>Self-emp-not-inc</td>\n",
"    <td>HS-grad</td>\n",
"    <td>Never-married</td>\n",
"    <td>Craft-repair</td>\n",
"    <td>Own-child</td>\n",
"    <td>Black</td>\n",
"    <td>False</td>\n",
"    <td>0.0303</td>\n",
"    <td>0.2092</td>\n",
"    <td>-0.4224</td>\n",
"    <td>&lt;50k</td>\n",
"  </tr>\n",
"  <tr>\n",
"    <td>Private</td>\n",
"    <td>HS-grad</td>\n",
"    <td>Married-civ-spouse</td>\n",
"    <td>Protective-serv</td>\n",
"    <td>Husband</td>\n",
"    <td>White</td>\n",
"    <td>False</td>\n",
"    <td>1.5695</td>\n",
"    <td>-0.5938</td>\n",
"    <td>-0.4224</td>\n",
"    <td>&lt;50k</td>\n",
"  </tr>\n",
"  <tr>\n",
"    <td>Private</td>\n",
"    <td>HS-grad</td>\n",
"    <td>Married-civ-spouse</td>\n",
"    <td>Handlers-cleaners</td>\n",
"    <td>Husband</td>\n",
"    <td>White</td>\n",
"    <td>False</td>\n",
"    <td>-0.9959</td>\n",
"    <td>-0.0318</td>\n",
"    <td>-0.4224</td>\n",
"    <td>&lt;50k</td>\n",
"  </tr>\n",
"  <tr>\n",
"    <td>Private</td>\n",
"    <td>10th</td>\n",
"    <td>Married-civ-spouse</td>\n",
"    <td>Farming-fishing</td>\n",
"    <td>Wife</td>\n",
"    <td>White</td>\n",
"    <td>False</td>\n",
"    <td>-0.7027</td>\n",
"    <td>0.6071</td>\n",
"    <td>-1.5958</td>\n",
"    <td>&lt;50k</td>\n",
"  </tr>\n",
"  <tr>\n",
"    <td>Private</td>\n",
"    <td>HS-grad</td>\n",
"    <td>Married-civ-spouse</td>\n",
"    <td>Machine-op-inspct</td>\n",
"    <td>Husband</td>\n",
"    <td>White</td>\n",
"    <td>False</td>\n",
"    <td>0.1036</td>\n",
"    <td>-0.0968</td>\n",
"    <td>-0.4224</td>\n",
"    <td>&lt;50k</td>\n",
"  </tr>\n",
"  <tr>\n",
"    <td>Private</td>\n",
"    <td>Some-college</td>\n",
"    <td>Married-civ-spouse</td>\n",
"    <td>Exec-managerial</td>\n",
"    <td>Own-child</td>\n",
"    <td>White</td>\n",
"    <td>False</td>\n",
"    <td>-0.7760</td>\n",
"    <td>-0.6653</td>\n",
"    <td>-0.0312</td>\n",
"    <td>&gt;=50k</td>\n",
"  </tr>\n",
"  <tr>\n",
"    <td>State-gov</td>\n",
"    <td>Some-college</td>\n",
"    <td>Never-married</td>\n",
"    <td>Tech-support</td>\n",
"    <td>Own-child</td>\n",
"    <td>White</td>\n",
"    <td>False</td>\n",
"    <td>-0.8493</td>\n",
"    <td>-1.4959</td>\n",
"    <td>-0.0312</td>\n",
"    <td>&lt;50k</td>\n",
"  </tr>\n",
"  <tr>\n",
"    <td>Private</td>\n",
"    <td>11th</td>\n",
"    <td>Never-married</td>\n",
"    <td>Machine-op-inspct</td>\n",
"    <td>Not-in-family</td>\n",
"    <td>White</td>\n",
"    <td>False</td>\n",
"    <td>-1.0692</td>\n",
"    <td>-0.9516</td>\n",
"    <td>-1.2046</td>\n",
"    <td>&lt;50k</td>\n",
"  </tr>\n",
"</table>\n"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"data.show_batch(rows=10)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"learn = tabular_learner(data, layers=[200,100], metrics=accuracy)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
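"text/html": [
"Total time: 00:03 \n",
"<table>\n",
"  <tr>\n",
"    <th>epoch</th>\n",
"    <th>train_loss</th>\n",
"    <th>valid_loss</th>\n",
"    <th>accuracy</th>\n",
"  </tr>\n",
"  <tr>\n",
"    <td>1</td>\n",
"    <td>0.354604</td>\n",
"    <td>0.378520</td>\n",
"    <td>0.820000</td>\n",
"  </tr>\n",
"</table>\n"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]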
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn.fit(1, 1e-2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Inference 预测"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"row = df.iloc[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(Category >=50k, tensor(1), tensor([0.4402, 0.5598]))"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"learn.predict(row)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}