{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "name": "lesson4-tabular.ipynb", "version": "0.3.2", "provenance": [] }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "accelerator": "GPU" }, "cells": [ { "cell_type": "markdown", "metadata": { "id": "PcZh_7tRk7ke", "colab_type": "text" }, "source": [ "# Tabular models" ] }, { "cell_type": "code", "metadata": { "id": "vffyRZD0k7ki", "colab_type": "code", "colab": {} }, "source": [ "from fastai import *\n", "from fastai.tabular import *\n", "from sklearn.model_selection import train_test_split" ], "execution_count": 0, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "gbM3J3vXk7kq", "colab_type": "text" }, "source": [ "Tabular data should be in a Pandas `DataFrame`." ] }, { "cell_type": "code", "metadata": { "id": "x6d23tFQk7kt", "colab_type": "code", "colab": {} }, "source": [ "path = untar_data(URLs.ADULT_SAMPLE)\n", "df = pd.read_csv(path/'adult.csv')" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "D32QgGP5k7k2", "colab_type": "code", "colab": {} }, "source": [ "dep_var = 'salary'\n", "cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race']\n", "cont_names = ['age', 'fnlwgt', 'education-num']\n", "procs = [FillMissing, Categorify, Normalize]" ], "execution_count": 0, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "JkNDyqKxDYNX", "colab_type": "text" }, "source": [ "# TL;DR:" ] }, { "cell_type": "code", "metadata": { "id": "vE1qhNm5D71H", "colab_type": "code", "colab": {} }, "source": [ "train, test = train_test_split(df, test_size=0.2)" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "x98OSwAYDx-Z", "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 34 }, "outputId": "4f95d827-63c9-46d0-8d5f-e3866fd2f1ce" }, "source": [ "print(len(train), len(test))" ], "execution_count": 12, "outputs": [ { "output_type": 
"stream", "text": [ "26048 6513\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "-3TX3j6YDk5O", "colab_type": "code", "colab": {} }, "source": [ "# Train/validation databunch: built from the `train` rows only, so `test` stays unseen\n", "data = (TabularList.from_df(train, path=path, cat_names=cat_names, cont_names=cont_names, procs=procs)\n", " .split_by_rand_pct(0.2)\n", " .label_from_df(cols=dep_var)\n", " .databunch())" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "pS_baUUCDqBq", "colab_type": "code", "colab": {} }, "source": [ "# Test databunch: held-out rows, pre-processed with the processor fitted on the training data\n", "data_test = (TabularList.from_df(test, path=path, cat_names=cat_names, cont_names=cont_names, procs=procs,\n", " processor=data.processor)\n", " .split_none()\n", " .label_from_df(cols=dep_var))\n", "# Expose the test rows through the validation slot before databunching\n", "data_test.valid = data_test.train\n", "data_test = data_test.databunch()" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "PIaCvkwoETmC", "colab_type": "code", "colab": {} }, "source": [ "data.valid_dl = data_test.valid_dl" ], "execution_count": 0, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "INxLDM2ulRua", "colab_type": "text" }, "source": [ "# Regular Guided Example\n", "\n", "First I will show an example of what *will not* work" ] }, { "cell_type": "code", "metadata": { "id": "LQsXFcYWk7lA", "colab_type": "code", "colab": {} }, "source": [ "test = TabularList.from_df(df.iloc[700:1000].copy(), path=path, cat_names=cat_names, cont_names=cont_names)" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "MYb0GfUUk7lH", "colab_type": "code", "colab": {} }, "source": [ "data = (TabularList.from_df(df, path=path, cat_names=cat_names, cont_names=cont_names, procs=procs)\n", " .split_by_idx(list(range(800,1000)))\n", " .label_from_df(cols=dep_var)\n", " .add_test(test)\n", " .databunch())" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "FcYVsQqIk7lV", "colab_type": "code", "colab": {} }, "source": [ "learn = tabular_learner(data, layers=[200,100], metrics=accuracy)" ], 
"execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "hgLRqgptk7lZ", "colab_type": "code", "outputId": "aaadc38e-7c5e-479a-acba-1240bc8241a5", "colab": { "base_uri": "https://localhost:8080/", "height": 78 } }, "source": [ "learn.fit(1, 1e-2)" ], "execution_count": 0, "outputs": [ { "output_type": "display_data", "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossaccuracytime
00.3664130.3878600.83000000:05
" ], "text/plain": [ "" ] }, "metadata": { "tags": [] } } ] }, { "cell_type": "code", "metadata": { "id": "ICG0rpA4lg_K", "colab_type": "code", "outputId": "8ff585c0-1afd-436a-b504-fb9d3ead396a", "colab": { "base_uri": "https://localhost:8080/", "height": 437 } }, "source": [ "data" ], "execution_count": 0, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "TabularDataBunch;\n", "\n", "Train: LabelList (32361 items)\n", "x: TabularList\n", "workclass Private; education Assoc-acdm; marital-status Married-civ-spouse; occupation #na#; relationship Wife; race White; education-num_na False; age 0.7632; fnlwgt -0.8381; education-num 0.7511; ,workclass Private; education Masters; marital-status Divorced; occupation Exec-managerial; relationship Not-in-family; race White; education-num_na False; age 0.3968; fnlwgt 0.4458; education-num 1.5334; ,workclass Private; education HS-grad; marital-status Divorced; occupation #na#; relationship Unmarried; race Black; education-num_na True; age -0.0430; fnlwgt -0.8868; education-num -0.0312; ,workclass Self-emp-inc; education Prof-school; marital-status Married-civ-spouse; occupation Prof-specialty; relationship Husband; race Asian-Pac-Islander; education-num_na False; age -0.0430; fnlwgt -0.7288; education-num 1.9245; ,workclass Self-emp-not-inc; education 7th-8th; marital-status Married-civ-spouse; occupation Other-service; relationship Wife; race Black; education-num_na True; age 0.2502; fnlwgt -1.0185; education-num -0.0312; \n", "y: CategoryList\n", ">=50k,>=50k,<50k,>=50k,<50k\n", "Path: /root/.fastai/data/adult_sample;\n", "\n", "Valid: LabelList (200 items)\n", "x: TabularList\n", "workclass Private; education Some-college; marital-status Divorced; occupation Handlers-cleaners; relationship Unmarried; race White; education-num_na True; age 0.4701; fnlwgt -0.8793; education-num -0.0312; ,workclass Self-emp-inc; education Prof-school; marital-status Married-civ-spouse; occupation Prof-specialty; 
relationship Husband; race White; education-num_na True; age 0.5434; fnlwgt 0.0290; education-num -0.0312; ,workclass Private; education Assoc-voc; marital-status Divorced; occupation #na#; relationship Not-in-family; race White; education-num_na True; age -0.1896; fnlwgt 1.7704; education-num -0.0312; ,workclass Federal-gov; education Bachelors; marital-status Never-married; occupation Tech-support; relationship Not-in-family; race White; education-num_na True; age -0.9959; fnlwgt -1.3242; education-num -0.0312; ,workclass Private; education Bachelors; marital-status Married-civ-spouse; occupation #na#; relationship Husband; race White; education-num_na True; age -0.1163; fnlwgt -0.2389; education-num -0.0312; \n", "y: CategoryList\n", "<50k,>=50k,<50k,<50k,<50k\n", "Path: /root/.fastai/data/adult_sample;\n", "\n", "Test: LabelList (300 items)\n", "x: TabularList\n", "workclass Private; education HS-grad; marital-status Never-married; occupation #na#; relationship Own-child; race White; education-num_na True; age -0.6294; fnlwgt -1.2432; education-num -0.0312; ,workclass Federal-gov; education HS-grad; marital-status Married-civ-spouse; occupation Farming-fishing; relationship Husband; race White; education-num_na False; age 0.3235; fnlwgt 0.0586; education-num -0.4224; ,workclass Private; education Bachelors; marital-status Married-civ-spouse; occupation Exec-managerial; relationship Husband; race White; education-num_na False; age -0.1896; fnlwgt -1.4639; education-num 1.1422; ,workclass Private; education HS-grad; marital-status Married-civ-spouse; occupation #na#; relationship Husband; race White; education-num_na False; age -0.1163; fnlwgt -0.2014; education-num -0.4224; ,workclass Self-emp-inc; education HS-grad; marital-status Never-married; occupation #na#; relationship Not-in-family; race White; education-num_na True; age -0.2629; fnlwgt -0.4633; education-num -0.0312; \n", "y: EmptyLabelList\n", ",,,,\n", "Path: /root/.fastai/data/adult_sample" ] }, 
"metadata": { "tags": [] }, "execution_count": 19 } ] }, { "cell_type": "code", "metadata": { "id": "M_Gl3-4ElvQ-", "colab_type": "code", "outputId": "cb148cc2-baf3-4c43-9b7d-3c2cc237c821", "colab": { "base_uri": "https://localhost:8080/", "height": 35 } }, "source": [ "learn.validate(dl=learn.data.train_dl)" ], "execution_count": 0, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "[0.35220727, tensor(0.8369)]" ] }, "metadata": { "tags": [] }, "execution_count": 20 } ] }, { "cell_type": "code", "metadata": { "id": "v2qIFYIMlt2E", "colab_type": "code", "outputId": "257695e5-efc3-43c6-d8ea-574d939cd297", "colab": { "base_uri": "https://localhost:8080/", "height": 35 } }, "source": [ "learn.validate()" ], "execution_count": 0, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "[0.3878597, tensor(0.8300)]" ] }, "metadata": { "tags": [] }, "execution_count": 21 } ] }, { "cell_type": "code", "metadata": { "id": "HNQyu7J6lkIN", "colab_type": "code", "outputId": "5fbb39b7-9851-46c8-ddf8-d71c3ee510ce", "colab": { "base_uri": "https://localhost:8080/", "height": 35 } }, "source": [ "learn.validate(dl = learn.data.test_dl)" ], "execution_count": 0, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "[0.2829722, tensor(0.8967)]" ] }, "metadata": { "tags": [] }, "execution_count": 22 } ] }, { "cell_type": "markdown", "metadata": { "id": "GjXxYBRXmMWD", "colab_type": "text" }, "source": [ "This looks very good right? But let's try doing it a different way to be sure... 
as this is above **any** research level results" ] }, { "cell_type": "markdown", "metadata": { "id": "dA3FvBMImUqF", "colab_type": "text" }, "source": [ "# Train/Valid/Test Split - The proper way\n", "\n", "I'm going to first use train_test_split to split our data into a 90/10 split" ] }, { "cell_type": "code", "metadata": { "id": "ohxrLZj_mash", "colab_type": "code", "colab": {} }, "source": [ "train, test = train_test_split(df, test_size=0.1)" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "jqOzg1LpmeGZ", "colab_type": "code", "outputId": "40761769-87aa-4f2a-bf7d-5ee0b50c2266", "colab": { "base_uri": "https://localhost:8080/", "height": 34 } }, "source": [ "len(train), len(test)" ], "execution_count": 17, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "(29304, 3257)" ] }, "metadata": { "tags": [] }, "execution_count": 17 } ] }, { "cell_type": "markdown", "metadata": { "id": "tVdMgbWKmlVL", "colab_type": "text" }, "source": [ "Great, we have a 10% split. 
Now lets make our train and test databunches" ] }, { "cell_type": "code", "metadata": { "id": "etLZkbe-mr2D", "colab_type": "code", "colab": {} }, "source": [ "data = (TabularList.from_df(train, path=path, cat_names=cat_names, cont_names=cont_names, procs=procs)\n", " .split_by_rand_pct(0.2) # So we can get a 20% split into the validation\n", " .label_from_df(cols=dep_var)\n", " #.add_test(test) we are not using this though\n", " .databunch())" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "6tRgg10Im3Cy", "colab_type": "code", "colab": {} }, "source": [ "data_test = (TabularList.from_df(test, path=path, cat_names=cat_names,\n", " cont_names=cont_names, procs=procs, \n", " processor = data.processor) # NOTICE THIS STEP, this is so the procs are all applied the exact same\n", " .split_none() # we only want it\n", " .label_from_df(cols=dep_var)\n", " )" ], "execution_count": 0, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "6bj-HvdFEdK2", "colab_type": "text" }, "source": [ "Here we do not databunch yet. This is because the training dataloader is shuffled, which we don't want, and its last batch is dropped if it is incomplete. How do we fix this? Set the validation set to the training set ***before*** databunching." 
] }, { "cell_type": "code", "metadata": { "id": "RdaBEZzqEmUM", "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 326 }, "outputId": "c32b82c7-10a9-4d10-e9ed-8f0fc901e93c" }, "source": [ "data_test" ], "execution_count": 20, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "LabelLists;\n", "\n", "Train: LabelList (3257 items)\n", "x: TabularList\n", "workclass Private; education HS-grad; marital-status Married-civ-spouse; occupation Transport-moving; relationship Husband; race White; education-num_na False; age -0.4820; fnlwgt 2.1986; education-num -0.4231; ,workclass Private; education Doctorate; marital-status Married-civ-spouse; occupation Prof-specialty; relationship Husband; race White; education-num_na False; age 0.3985; fnlwgt 0.6453; education-num 2.3199; ,workclass Self-emp-not-inc; education Prof-school; marital-status Married-civ-spouse; occupation #na#; relationship Husband; race White; education-num_na False; age 0.3985; fnlwgt 0.0722; education-num 1.9281; ,workclass Private; education Bachelors; marital-status Married-civ-spouse; occupation Exec-managerial; relationship Husband; race White; education-num_na False; age 0.8387; fnlwgt -0.5257; education-num 1.1443; ,workclass Private; education 11th; marital-status Divorced; occupation Sales; relationship Own-child; race White; education-num_na False; age -0.9956; fnlwgt -1.1058; education-num -1.2069; \n", "y: CategoryList\n", "<50k,>=50k,>=50k,>=50k,<50k\n", "Path: /root/.fastai/data/adult_sample;\n", "\n", "Valid: LabelList (0 items)\n", "x: TabularList\n", "\n", "y: CategoryList\n", "\n", "Path: /root/.fastai/data/adult_sample;\n", "\n", "Test: None" ] }, "metadata": { "tags": [] }, "execution_count": 20 } ] }, { "cell_type": "code", "metadata": { "id": "HtdzBH9AEnND", "colab_type": "code", "colab": {} }, "source": [ "data_test.valid = data_test.train\n", "data_test = data_test.databunch()" ], "execution_count": 0, "outputs": [] }, { 
"cell_type": "markdown", "metadata": { "id": "Duda399rnUq9", "colab_type": "text" }, "source": [ "Okay now let's look at the two" ] }, { "cell_type": "code", "metadata": { "id": "LNOsJy8YnWnD", "colab_type": "code", "outputId": "aead6062-c3e8-4d5a-c450-48903a8ab6ff", "colab": { "base_uri": "https://localhost:8080/", "height": 326 } }, "source": [ "data" ], "execution_count": 22, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "TabularDataBunch;\n", "\n", "Train: LabelList (23444 items)\n", "x: TabularList\n", "workclass Local-gov; education Some-college; marital-status Married-civ-spouse; occupation Exec-managerial; relationship Husband; race White; education-num_na False; age -0.7755; fnlwgt 0.8878; education-num -0.0313; ,workclass Private; education HS-grad; marital-status Never-married; occupation Craft-repair; relationship Not-in-family; race White; education-num_na False; age 0.1050; fnlwgt 0.1760; education-num -0.4231; ,workclass Private; education HS-grad; marital-status Married-civ-spouse; occupation Protective-serv; relationship Husband; race White; education-num_na False; age -0.6288; fnlwgt 0.1266; education-num -0.4231; ,workclass Local-gov; education Bachelors; marital-status Never-married; occupation Prof-specialty; relationship Not-in-family; race White; education-num_na False; age -0.9223; fnlwgt -0.2255; education-num 1.1443; ,workclass Private; education HS-grad; marital-status Married-civ-spouse; occupation Exec-managerial; relationship Husband; race White; education-num_na False; age 1.2056; fnlwgt -0.7925; education-num -0.4231; \n", "y: CategoryList\n", "<50k,<50k,>=50k,<50k,>=50k\n", "Path: /root/.fastai/data/adult_sample;\n", "\n", "Valid: LabelList (5860 items)\n", "x: TabularList\n", "workclass Private; education HS-grad; marital-status Married-civ-spouse; occupation Machine-op-inspct; relationship Wife; race White; education-num_na False; age -0.1885; fnlwgt -0.0029; education-num -0.4231; ,workclass State-gov; 
education Masters; marital-status Married-civ-spouse; occupation Exec-managerial; relationship Husband; race White; education-num_na False; age 1.0589; fnlwgt 0.0686; education-num 1.5362; ,workclass Private; education 9th; marital-status Never-married; occupation Other-service; relationship Own-child; race White; education-num_na False; age -0.8489; fnlwgt -1.4557; education-num -1.9906; ,workclass Private; education HS-grad; marital-status Divorced; occupation Other-service; relationship Not-in-family; race White; education-num_na False; age -0.7755; fnlwgt 0.0106; education-num -0.4231; ,workclass Private; education HS-grad; marital-status Never-married; occupation Transport-moving; relationship Unmarried; race White; education-num_na False; age -0.7755; fnlwgt 0.7598; education-num -0.4231; \n", "y: CategoryList\n", "<50k,>=50k,<50k,<50k,<50k\n", "Path: /root/.fastai/data/adult_sample;\n", "\n", "Test: None" ] }, "metadata": { "tags": [] }, "execution_count": 22 } ] }, { "cell_type": "code", "metadata": { "id": "qu4FG_6EnXl2", "colab_type": "code", "outputId": "5ba2e243-a556-4c85-db99-c5e462e06954", "colab": { "base_uri": "https://localhost:8080/", "height": 326 } }, "source": [ "data_test" ], "execution_count": 23, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "TabularDataBunch;\n", "\n", "Train: LabelList (3257 items)\n", "x: TabularList\n", "workclass Private; education HS-grad; marital-status Married-civ-spouse; occupation Transport-moving; relationship Husband; race White; education-num_na False; age -0.4820; fnlwgt 2.1986; education-num -0.4231; ,workclass Private; education Doctorate; marital-status Married-civ-spouse; occupation Prof-specialty; relationship Husband; race White; education-num_na False; age 0.3985; fnlwgt 0.6453; education-num 2.3199; ,workclass Self-emp-not-inc; education Prof-school; marital-status Married-civ-spouse; occupation #na#; relationship Husband; race White; education-num_na False; age 0.3985; 
fnlwgt 0.0722; education-num 1.9281; ,workclass Private; education Bachelors; marital-status Married-civ-spouse; occupation Exec-managerial; relationship Husband; race White; education-num_na False; age 0.8387; fnlwgt -0.5257; education-num 1.1443; ,workclass Private; education 11th; marital-status Divorced; occupation Sales; relationship Own-child; race White; education-num_na False; age -0.9956; fnlwgt -1.1058; education-num -1.2069; \n", "y: CategoryList\n", "<50k,>=50k,>=50k,>=50k,<50k\n", "Path: /root/.fastai/data/adult_sample;\n", "\n", "Valid: LabelList (3257 items)\n", "x: TabularList\n", "workclass Private; education HS-grad; marital-status Married-civ-spouse; occupation Transport-moving; relationship Husband; race White; education-num_na False; age -0.4820; fnlwgt 2.1986; education-num -0.4231; ,workclass Private; education Doctorate; marital-status Married-civ-spouse; occupation Prof-specialty; relationship Husband; race White; education-num_na False; age 0.3985; fnlwgt 0.6453; education-num 2.3199; ,workclass Self-emp-not-inc; education Prof-school; marital-status Married-civ-spouse; occupation #na#; relationship Husband; race White; education-num_na False; age 0.3985; fnlwgt 0.0722; education-num 1.9281; ,workclass Private; education Bachelors; marital-status Married-civ-spouse; occupation Exec-managerial; relationship Husband; race White; education-num_na False; age 0.8387; fnlwgt -0.5257; education-num 1.1443; ,workclass Private; education 11th; marital-status Divorced; occupation Sales; relationship Own-child; race White; education-num_na False; age -0.9956; fnlwgt -1.1058; education-num -1.2069; \n", "y: CategoryList\n", "<50k,>=50k,>=50k,>=50k,<50k\n", "Path: /root/.fastai/data/adult_sample;\n", "\n", "Test: None" ] }, "metadata": { "tags": [] }, "execution_count": 23 } ] }, { "cell_type": "markdown", "metadata": { "id": "utojlXy8ni5m", "colab_type": "text" }, "source": [ "The numbers look right, lets do a quick train and try switching them again" 
] }, { "cell_type": "code", "metadata": { "id": "Bo-vqk1XnnBY", "colab_type": "code", "colab": {} }, "source": [ "learn = tabular_learner(data, layers=[200,100], metrics=accuracy)" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "Ncz8tfdCnrX7", "colab_type": "code", "outputId": "f8fe605f-70b1-4160-bb89-7ed0fb87d7c7", "colab": { "base_uri": "https://localhost:8080/", "height": 80 } }, "source": [ "learn.fit(1, 1e-2)" ], "execution_count": 25, "outputs": [ { "output_type": "display_data", "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossaccuracytime
00.3686150.3596520.83208200:05
" ], "text/plain": [ "" ] }, "metadata": { "tags": [] } } ] }, { "cell_type": "code", "metadata": { "id": "uOR6Wc1inthl", "colab_type": "code", "colab": {} }, "source": [ "learn.data.valid_dl = data_test.valid_dl" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "uE-ngZVXnxkR", "colab_type": "code", "outputId": "9f360435-8c58-40a2-ed75-ef79b1b1b621", "colab": { "base_uri": "https://localhost:8080/", "height": 68 } }, "source": [ "%time learn.validate()" ], "execution_count": 27, "outputs": [ { "output_type": "stream", "text": [ "CPU times: user 238 ms, sys: 100 ms, total: 338 ms\n", "Wall time: 486 ms\n" ], "name": "stdout" }, { "output_type": "execute_result", "data": { "text/plain": [ "[0.35564667, tensor(0.8354)]" ] }, "metadata": { "tags": [] }, "execution_count": 27 } ] }, { "cell_type": "code", "metadata": { "id": "KGKmBGdjoS2N", "colab_type": "code", "outputId": "e021a668-5bb5-41f0-f2a0-9eb7f7d8e27c", "colab": { "base_uri": "https://localhost:8080/", "height": 346 } }, "source": [ "learn.data" ], "execution_count": 0, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "TabularDataBunch;\n", "\n", "Train: LabelList (23444 items)\n", "x: TabularList\n", "workclass ?; education HS-grad; marital-status Never-married; occupation ?; relationship Own-child; race White; education-num_na False; age -1.3688; fnlwgt 1.3052; education-num -0.4249; ,workclass Local-gov; education 7th-8th; marital-status Married-civ-spouse; occupation Craft-repair; relationship Husband; race White; education-num_na False; age 1.4246; fnlwgt 0.7850; education-num -2.3731; ,workclass Private; education Some-college; marital-status Never-married; occupation Sales; relationship Own-child; race White; education-num_na False; age -1.2218; fnlwgt 1.8899; education-num -0.0352; ,workclass Private; education Assoc-acdm; marital-status Married-civ-spouse; occupation Adm-clerical; relationship Husband; race White; education-num_na False; age 
-0.0456; fnlwgt -0.0499; education-num 0.7441; ,workclass Private; education HS-grad; marital-status Divorced; occupation Handlers-cleaners; relationship Not-in-family; race White; education-num_na False; age 0.0279; fnlwgt -0.6546; education-num -0.4249; \n", "y: CategoryList\n", "<50k,<50k,<50k,<50k,<50k\n", "Path: /root/.fastai/data/adult_sample;\n", "\n", "Valid: LabelList (3257 items)\n", "x: TabularList\n", "workclass Local-gov; education HS-grad; marital-status Married-civ-spouse; occupation Transport-moving; relationship Husband; race Black; education-num_na False; age 0.9100; fnlwgt 1.2334; education-num -0.4249; ,workclass Private; education HS-grad; marital-status Never-married; occupation Exec-managerial; relationship Own-child; race White; education-num_na False; age -1.1483; fnlwgt 0.2891; education-num -0.4249; ,workclass State-gov; education HS-grad; marital-status Married-civ-spouse; occupation Adm-clerical; relationship Wife; race White; education-num_na False; age 0.6159; fnlwgt -0.8373; education-num -0.4249; ,workclass Federal-gov; education HS-grad; marital-status Divorced; occupation Adm-clerical; relationship Not-in-family; race White; education-num_na False; age 0.6895; fnlwgt 0.5501; education-num -0.4249; ,workclass Local-gov; education HS-grad; marital-status Never-married; occupation Protective-serv; relationship Own-child; race Black; education-num_na False; age -0.6337; fnlwgt 0.0258; education-num -0.4249; \n", "y: CategoryList\n", "<50k,<50k,>=50k,<50k,<50k\n", "Path: /root/.fastai/data/adult_sample;\n", "\n", "Test: None" ] }, "metadata": { "tags": [] }, "execution_count": 57 } ] }, { "cell_type": "markdown", "metadata": { "id": "JpZ11xNYoU0Q", "colab_type": "text" }, "source": [ "As we can see, we no longer get that **SUPER** high test set accuracy, as it wasn't really validating it for us! Also we can match that the Valid LabelList got replaced with our own, as our test set had 3257 items. 
Also, this is much faster than doing learn.predict(). I'll show an example below for time" ] }, { "cell_type": "markdown", "metadata": { "id": "InKxFJMhE09X", "colab_type": "text" }, "source": [ "# learn.predict() vs learn.validate() - Time Comparison" ] }, { "cell_type": "markdown", "metadata": { "id": "paN5m_kPE6Ce", "colab_type": "text" }, "source": [ "Below is a quick function using learn.predict where we will check to see if our predictions match our actual in the entire dataset" ] }, { "cell_type": "code", "metadata": { "id": "WljQrD3A3-fY", "colab_type": "code", "colab": {} }, "source": [ "def CalculateAccuracy(learner, df, right=0, dep_var='salary'):\n", "    \"\"\"Return the share of rows in `df` whose `dep_var` label equals the predicted label.\n", "\n", "    Calls `learner.predict` once per row and compares `str()` of element 0 of the\n", "    result with `str()` of the row's `dep_var` value.\n", "\n", "    Args:\n", "        learner: object exposing `predict(row)`.\n", "        df: DataFrame containing the feature columns and the `dep_var` column.\n", "        right: starting count of correct predictions (kept for backward compatibility).\n", "        dep_var: name of the label column.\n", "\n", "    Returns:\n", "        (`right` + number of matching rows) / `len(df)`, or 0.0 when `df` is empty.\n", "    \"\"\"\n", "    n_rows = len(df)\n", "    if n_rows == 0:\n", "        return 0.0  # nothing to score; avoids ZeroDivisionError\n", "    for i in range(n_rows):\n", "        row = df.iloc[i]\n", "        if str(row[dep_var]) == str(learner.predict(row)[0]):\n", "            right += 1\n", "    return right / n_rows" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "CZ75GyYu4ZcO", "colab_type": "code", "outputId": "504d93b9-aab2-439c-8316-1286bc0106d0", "colab": { "base_uri": "https://localhost:8080/", "height": 51 } }, "source": [ "%time acc = CalculateAccuracy(learn, test, 0)" ], "execution_count": 29, "outputs": [ { "output_type": "stream", "text": [ "CPU times: user 1min 22s, sys: 200 ms, total: 1min 22s\n", "Wall time: 1min 22s\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "M_83WtU65WGe", "colab_type": "code", "outputId": "a56d5785-cb2d-4ec8-ccac-8b46515d57ed", "colab": { "base_uri": "https://localhost:8080/", "height": 34 } }, "source": [ "acc" ], "execution_count": 0, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "0.8326680994780473" ] }, "metadata": { "tags": [] }, "execution_count": 16 } ] }, { "cell_type": "markdown", "metadata": { "id": "UmcDYozgFDcd", "colab_type": "text" }, "source": [ "Now let's use fastai's `get_preds` function and do a comparison after we switch the above" ] }, { "cell_type": "code", "metadata": { "id": "Myq0AOnuFPXL", "colab_type": "code", "colab": {} }, 
"source": [ "# Score the same held-out `test` rows that CalculateAccuracy was timed on,\n", "# pre-processed with the processor fitted on the training data\n", "data_test = (TabularList.from_df(test, path=path, cat_names=cat_names, cont_names=cont_names, procs=procs,\n", " processor=data.processor)\n", " .split_none()\n", " .label_from_df(cols=dep_var))\n", "data_test.valid = data_test.train\n", "data_test = data_test.databunch()" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "2j_Oc-kRFZuH", "colab_type": "code", "colab": {} }, "source": [ "learn.data.valid_dl = data_test.valid_dl" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "nDkhz6sookL1", "colab_type": "code", "outputId": "fc9068bb-2cbd-44e0-d3a8-b796d02437e4", "colab": { "base_uri": "https://localhost:8080/", "height": 170 } }, "source": [ "%time learn.get_preds(ds_type=DatasetType.Valid)" ], "execution_count": 33, "outputs": [ { "output_type": "stream", "text": [ "CPU times: user 1.85 s, sys: 437 ms, total: 2.29 s\n", "Wall time: 3.45 s\n" ], "name": "stdout" }, { "output_type": "execute_result", "data": { "text/plain": [ "[tensor([[0.4628, 0.5372],\n", " [0.4795, 0.5205],\n", " [0.9468, 0.0532],\n", " ...,\n", " [0.5167, 0.4833],\n", " [0.7234, 0.2766],\n", " [0.8070, 0.1930]]), tensor([1, 1, 0, ..., 1, 0, 0])]" ] }, "metadata": { "tags": [] }, "execution_count": 33 } ] }, { "cell_type": "markdown", "metadata": { "id": "HxG6HryHpLVW", "colab_type": "text" }, "source": [ "Look at that time difference! 1:24 vs 3s. That is **much** faster as we are using the GPU here too." 
] }, { "cell_type": "code", "metadata": { "id": "XU5BoU7NpwwJ", "colab_type": "code", "colab": {} }, "source": [ "valid_dl = data.valid_dl\n", "test_dl = data_test.valid_dl" ], "execution_count": 0, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "Ak9TFOy-p3ei", "colab_type": "text" }, "source": [ "Now we can safely just replace one or the other and keep going.\n", "\n", "To predict on test:" ] }, { "cell_type": "code", "metadata": { "id": "tvbe1DYep6MI", "colab_type": "code", "colab": {} }, "source": [ "learn.data.valid_dl = test_dl" ], "execution_count": 0, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "iDyiCjQFp_eR", "colab_type": "text" }, "source": [ "To revert back to our validation" ] }, { "cell_type": "code", "metadata": { "id": "4ON1Cc00p8oL", "colab_type": "code", "colab": {} }, "source": [ "learn.data.valid_dl = valid_dl" ], "execution_count": 0, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "FGdApxCdqCkZ", "colab_type": "text" }, "source": [ "And now we can flip back and forth!" ] } ] }