{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Practical Deep Learning for Coders, v3"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Lesson4_tabular"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tabular models\n",
    "# Tabular(表格)模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from fastai.tabular import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Tabular data should be in a Pandas `DataFrame`.\n",
    "\n",
    "Tabular数据应该是Pandas里的`DataFrame`。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "path = untar_data(URLs.ADULT_SAMPLE)\n",
    "df = pd.read_csv(path/'adult.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dep_var = 'salary'\n",
    "cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race']\n",
    "cont_names = ['age', 'fnlwgt', 'education-num']\n",
    "procs = [FillMissing, Categorify, Normalize]\n",
    "# Row positions held out for validation; defined once so the split and the test set below stay in sync.\n",
    "valid_idx = list(range(800, 1000))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Demo only: the test set reuses the validation rows, so it is not an independent hold-out.\n",
    "test = TabularList.from_df(df.iloc[valid_idx].copy(), path=path, cat_names=cat_names, cont_names=cont_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = (TabularList.from_df(df, path=path, cat_names=cat_names, cont_names=cont_names, procs=procs)\n",
    "                           .split_by_idx(valid_idx)\n",
    "                           .label_from_df(cols=dep_var)\n",
    "                           .add_test(test)\n",
    "                           .databunch())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border='1' class='dataframe'>\n",
       "  <thead>\n",
       "    <tr><th>workclass</th><th>education</th><th>marital-status</th><th>occupation</th><th>relationship</th><th>race</th><th>education-num_na</th><th>age</th><th>fnlwgt</th><th>education-num</th><th>target</th></tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr><td> Private</td><td> HS-grad</td><td> Never-married</td><td> Sales</td><td> Not-in-family</td><td> White</td><td>False</td><td>-1.2158</td><td>1.1004</td><td>-0.4224</td><td>&lt;50k</td></tr>\n",
       "    <tr><td> ?</td><td> HS-grad</td><td> Widowed</td><td> ?</td><td> Not-in-family</td><td> White</td><td>False</td><td>1.8627</td><td>0.0976</td><td>-0.4224</td><td>&lt;50k</td></tr>\n",
       "    <tr><td> Self-emp-not-inc</td><td> HS-grad</td><td> Never-married</td><td> Craft-repair</td><td> Own-child</td><td> Black</td><td>False</td><td>0.0303</td><td>0.2092</td><td>-0.4224</td><td>&lt;50k</td></tr>\n",
       "    <tr><td> Private</td><td> HS-grad</td><td> Married-civ-spouse</td><td> Protective-serv</td><td> Husband</td><td> White</td><td>False</td><td>1.5695</td><td>-0.5938</td><td>-0.4224</td><td>&lt;50k</td></tr>\n",
       "    <tr><td> Private</td><td> HS-grad</td><td> Married-civ-spouse</td><td> Handlers-cleaners</td><td> Husband</td><td> White</td><td>False</td><td>-0.9959</td><td>-0.0318</td><td>-0.4224</td><td>&lt;50k</td></tr>\n",
       "    <tr><td> Private</td><td> 10th</td><td> Married-civ-spouse</td><td> Farming-fishing</td><td> Wife</td><td> White</td><td>False</td><td>-0.7027</td><td>0.6071</td><td>-1.5958</td><td>&lt;50k</td></tr>\n",
       "    <tr><td> Private</td><td> HS-grad</td><td> Married-civ-spouse</td><td> Machine-op-inspct</td><td> Husband</td><td> White</td><td>False</td><td>0.1036</td><td>-0.0968</td><td>-0.4224</td><td>&lt;50k</td></tr>\n",
       "    <tr><td> Private</td><td> Some-college</td><td> Married-civ-spouse</td><td> Exec-managerial</td><td> Own-child</td><td> White</td><td>False</td><td>-0.7760</td><td>-0.6653</td><td>-0.0312</td><td>&gt;=50k</td></tr>\n",
       "    <tr><td> State-gov</td><td> Some-college</td><td> Never-married</td><td> Tech-support</td><td> Own-child</td><td> White</td><td>False</td><td>-0.8493</td><td>-1.4959</td><td>-0.0312</td><td>&lt;50k</td></tr>\n",
       "    <tr><td> Private</td><td> 11th</td><td> Never-married</td><td> Machine-op-inspct</td><td> Not-in-family</td><td> White</td><td>False</td><td>-1.0692</td><td>-0.9516</td><td>-1.2046</td><td>&lt;50k</td></tr>\n",
       "  </tbody>\n",
       "</table>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "data.show_batch(rows=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = tabular_learner(data, layers=[200,100], metrics=accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "Total time: 00:03\n",
       "<table border='1' class='dataframe'>\n",
       "  <thead>\n",
       "    <tr><th>epoch</th><th>train_loss</th><th>valid_loss</th><th>accuracy</th></tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr><td>1</td><td>0.354604</td><td>0.378520</td><td>0.820000</td></tr>\n",
       "  </tbody>\n",
       "</table>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "learn.fit(1, 1e-2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inference 预测"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "row = df.iloc[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(Category >=50k, tensor(1), tensor([0.4402, 0.5598]))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "learn.predict(row)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}