{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tabular models\n",
    "\n",
    "Train a fastai tabular learner on the `ADULT_SAMPLE` dataset to predict the `>=50k` column, then run inference on a single row."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from fastai import *\n",
    "from fastai.tabular import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Tabular data should be in a pandas `DataFrame`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "path = untar_data(URLs.ADULT_SAMPLE)\n",
    "df = pd.read_csv(path/'adult.csv')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Name the target column, the categorical and continuous feature columns, and the preprocessing steps (`procs`) handed to `TabularList` below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dep_var = '>=50k'\n",
    "cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race']\n",
    "cont_names = ['age', 'fnlwgt', 'education-num']\n",
    "procs = [FillMissing, Categorify, Normalize]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The test items are rows 800-999 of `df`, the same positions that are passed to `split_by_idx` in the following cell. The test set here therefore only demonstrates `add_test`; it is not independent held-out data.\n",
    "\n",
    "**TODO:** use rows that are not part of the validation split for a real evaluation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test = TabularList.from_df(df.iloc[800:1000].copy(), path=path, cat_names=cat_names, cont_names=cont_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = (TabularList.from_df(df, path=path, cat_names=cat_names, cont_names=cont_names, procs=procs)\n",
    "        .split_by_idx(list(range(800,1000)))\n",
    "        .label_from_df(cols=dep_var)\n",
    "        .add_test(test, label=0)\n",
    "        .databunch())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table>\n",
       " <tr><th>workclass</th><th>education</th><th>marital-status</th><th>occupation</th><th>relationship</th><th>race</th><th>education-num_na</th><th>age</th><th>fnlwgt</th><th>education-num</th><th>target</th></tr>\n",
       " <tr><td>Private</td><td>Prof-school</td><td>Married-civ-spouse</td><td>Prof-specialty</td><td>Husband</td><td>White</td><td>False</td><td>0.1036</td><td>0.9224</td><td>1.9245</td><td>1</td></tr>\n",
       " <tr><td>Self-emp-inc</td><td>Bachelors</td><td>Married-civ-spouse</td><td>Farming-fishing</td><td>Husband</td><td>White</td><td>False</td><td>1.7161</td><td>-1.2654</td><td>1.1422</td><td>1</td></tr>\n",
       " <tr><td>Private</td><td>HS-grad</td><td>Never-married</td><td>Adm-clerical</td><td>Other-relative</td><td>Black</td><td>False</td><td>-0.7760</td><td>1.1905</td><td>-0.4224</td><td>0</td></tr>\n",
       " <tr><td>Private</td><td>10th</td><td>Married-civ-spouse</td><td>Sales</td><td>Own-child</td><td>White</td><td>False</td><td>-1.5823</td><td>-0.0268</td><td>-1.5958</td><td>0</td></tr>\n",
       " <tr><td>Private</td><td>Some-college</td><td>Never-married</td><td>Handlers-cleaners</td><td>Own-child</td><td>White</td><td>False</td><td>-1.3624</td><td>0.0284</td><td>-0.0312</td><td>0</td></tr>\n",
       " <tr><td>Private</td><td>Some-college</td><td>Married-civ-spouse</td><td>Prof-specialty</td><td>Husband</td><td>White</td><td>False</td><td>0.3968</td><td>0.4367</td><td>-0.0312</td><td>1</td></tr>\n",
       " <tr><td>?</td><td>Some-college</td><td>Never-married</td><td>?</td><td>Own-child</td><td>White</td><td>False</td><td>-1.4357</td><td>-0.7295</td><td>-0.0312</td><td>0</td></tr>\n",
       " <tr><td>Self-emp-not-inc</td><td>5th-6th</td><td>Married-civ-spouse</td><td>Sales</td><td>Husband</td><td>White</td><td>False</td><td>0.6166</td><td>-0.6503</td><td>-2.7692</td><td>1</td></tr>\n",
       " <tr><td>Private</td><td>Some-college</td><td>Married-civ-spouse</td><td>Sales</td><td>Husband</td><td>White</td><td>False</td><td>1.5695</td><td>-0.8876</td><td>-0.0312</td><td>1</td></tr>\n",
       " <tr><td>Local-gov</td><td>Some-college</td><td>Never-married</td><td>Handlers-cleaners</td><td>Own-child</td><td>White</td><td>False</td><td>-0.6294</td><td>-1.5422</td><td>-0.0312</td><td>0</td></tr>\n",
       "</table>\n"
      ],
      "text/plain": [
       ""
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "data.show_batch(rows=10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model and training\n",
    "\n",
    "Build a learner with hidden layers of 200 and 100 units that reports accuracy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = tabular_learner(data, layers=[200,100], metrics=accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "TabularModel(\n",
       " (embeds): ModuleList(\n",
       " (0): Embedding(10, 6)\n",
       " (1): Embedding(17, 9)\n",
       " (2): Embedding(8, 5)\n",
       " (3): Embedding(16, 9)\n",
       " (4): Embedding(7, 4)\n",
       " (5): Embedding(6, 4)\n",
       " (6): Embedding(3, 2)\n",
       " )\n",
       " (emb_drop): Dropout(p=0.0)\n",
       " (bn_cont): BatchNorm1d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       " (layers): Sequential(\n",
       " (0): Linear(in_features=42, out_features=200, bias=True)\n",
       " (1): ReLU(inplace)\n",
       " (2): BatchNorm1d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       " (3): Linear(in_features=200, out_features=100, bias=True)\n",
       " (4): ReLU(inplace)\n",
       " (5): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       " (6): Linear(in_features=100, out_features=2, bias=True)\n",
       " )\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "learn.model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The first linear layer takes 42 inputs: the seven embedding widths (6 + 9 + 5 + 9 + 4 + 4 + 2 = 39) plus the 3 continuous columns. There is one more embedding than there are entries in `cat_names`; it presumably belongs to the `education-num_na` column visible in the batch above (TODO: confirm that it is added by `FillMissing`).\n",
    "\n",
    "Train for one epoch; the second argument to `fit` (`1e-2`) is the learning rate."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total time: 00:03\n",
      "epoch train_loss valid_loss accuracy\n",
      "1 0.362837 0.413169 0.785000 (00:03)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "learn.fit(1, 1e-2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inference\n",
    "\n",
    "Predict on a single row of `df`. The recorded result is a 3-tuple whose last element holds the two class probabilities (they sum to 1). fastai documents the tuple as (predicted category, class index, probabilities); verify this against the installed version."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "row = df.iloc[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1, tensor(0), tensor([0.6365, 0.3635]))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "learn.predict(row)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}