{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from fastai import * # Quick access to most common functionality\n", "from fastai.tabular import * # Quick access to tabular functionality" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Tabular example" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Tabular data should be in a Pandas `DataFrame`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
ageworkclassfnlwgteducationeducation-nummarital-statusoccupationrelationshipracesexcapital-gaincapital-losshours-per-weeknative-country>=50k
049Private101320Assoc-acdm12.0Married-civ-spouseNaNWifeWhiteFemale0190240United-States1
144Private236746Masters14.0DivorcedExec-managerialNot-in-familyWhiteMale10520045United-States1
238Private96185HS-gradNaNDivorcedNaNUnmarriedBlackFemale0032United-States0
338Self-emp-inc112847Prof-school15.0Married-civ-spouseProf-specialtyHusbandAsian-Pac-IslanderMale0040United-States1
442Self-emp-not-inc822977th-8thNaNMarried-civ-spouseOther-serviceWifeBlackFemale0050United-States0
\n", "
" ], "text/plain": [ " age workclass fnlwgt education education-num \\\n", "0 49 Private 101320 Assoc-acdm 12.0 \n", "1 44 Private 236746 Masters 14.0 \n", "2 38 Private 96185 HS-grad NaN \n", "3 38 Self-emp-inc 112847 Prof-school 15.0 \n", "4 42 Self-emp-not-inc 82297 7th-8th NaN \n", "\n", " marital-status occupation relationship race \\\n", "0 Married-civ-spouse NaN Wife White \n", "1 Divorced Exec-managerial Not-in-family White \n", "2 Divorced NaN Unmarried Black \n", "3 Married-civ-spouse Prof-specialty Husband Asian-Pac-Islander \n", "4 Married-civ-spouse Other-service Wife Black \n", "\n", " sex capital-gain capital-loss hours-per-week native-country >=50k \n", "0 Female 0 1902 40 United-States 1 \n", "1 Male 10520 0 45 United-States 1 \n", "2 Female 0 0 32 United-States 0 \n", "3 Male 0 0 40 United-States 1 \n", "4 Female 0 0 50 United-States 0 " ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "path = untar_data(URLs.ADULT_SAMPLE)\n", "df = URLs.get_adult()\n", "train_df, valid_df = df[:-2000].copy(),df[-2000:].copy()\n", "train_df.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Convert your `DataFrame` into a `DataBunch` suitable for modeling by calling `TabularDataBunch.from_df`." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dep_var = '>=50k'\n", "cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race', 'sex', 'native-country']\n", "data = TabularDataBunch.from_df(path, train_df, valid_df, dep_var, \n", " tfms=[FillMissing, Categorify], cat_names=cat_names)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now you can create a `Learner` with `get_tabular_learner`, and fit your model." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "VBox(children=(HBox(children=(IntProgress(value=0, max=1), HTML(value='0.00% [0/1 00:00<00:00]'))), HTML(value…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Total time: 00:05\n", "epoch train loss valid loss accuracy\n", "0 0.340980 0.330657 0.847000 (00:05)\n", "\n" ] } ], "source": [ "learn = get_tabular_learner(data, layers=[200,100], metrics=accuracy)\n", "learn.fit(1, 1e-2)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 2 }