{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%reload_ext autoreload\n", "%autoreload 2" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from fastai import *\n", "from fastai.tabular import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Rossmann" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Data preparation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To create the feature-engineered files train_clean and test_clean from the initial data, run the x_009a_rossman_data_clean notebook first." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import pyarrow" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/ubuntu/anaconda3/lib/python3.6/site-packages/pandas/io/feather_format.py:112: FutureWarning: `nthreads` argument is deprecated, pass `use_threads` instead\n", " return feather.read_dataframe(path, nthreads=nthreads)\n" ] } ], "source": [ "path = Path('data/rossmann/')\n", "train_df = pd.read_feather(path/'train_clean')\n", "test_df = pd.read_feather(path/'test_clean')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
01234
index01234
Store12345
DayOfWeek55555
Date2015-07-31 00:00:002015-07-31 00:00:002015-07-31 00:00:002015-07-31 00:00:002015-07-31 00:00:00
Sales526360648314139954822
Customers5556258211498559
Open11111
Promo11111
StateHolidayFalseFalseFalseFalseFalse
SchoolHoliday11111
Year20152015201520152015
Month77777
Week3131313131
Day3131313131
Dayofweek44444
Dayofyear212212212212212
Is_month_endTrueTrueTrueTrueTrue
Is_month_startFalseFalseFalseFalseFalse
Is_quarter_endFalseFalseFalseFalseFalse
Is_quarter_startFalseFalseFalseFalseFalse
Is_year_endFalseFalseFalseFalseFalse
Is_year_startFalseFalseFalseFalseFalse
Elapsed14383008001438300800143830080014383008001438300800
StoreTypecaaca
Assortmentaaaca
CompetitionDistance12705701413062029910
CompetitionOpenSinceMonth9111294
CompetitionOpenSinceYear20082007200620092015
Promo201100
Promo2SinceWeek1131411
..................
Min_Sea_Level_PressurehPa10151017101710141016
Max_VisibilityKm3110311010
Mean_VisibilityKm1510141010
Min_VisibilitykM1010101010
Max_Wind_SpeedKm_h2414142314
Mean_Wind_SpeedKm_h111151611
Max_Gust_SpeedKm_hNaNNaNNaNNaNNaN
Precipitationmm00000
CloudCover14264
EventsFogFogFogNoneNone
WindDirDegrees13309354282290
StateNameHessenThueringenNordrheinWestfalenBerlinSachsen
CompetitionOpenSince2008-09-15 00:00:002007-11-15 00:00:002006-12-15 00:00:002009-09-15 00:00:002015-04-15 00:00:00
CompetitionDaysOpen2510281531502145107
CompetitionMonthsOpen242424243
Promo2Since1900-01-01 00:00:002010-03-29 00:00:002011-04-04 00:00:001900-01-01 00:00:001900-01-01 00:00:00
Promo2Days01950157900
Promo2Weeks0252500
AfterSchoolHoliday00000
BeforeSchoolHoliday00000
AfterStateHoliday5767576757
BeforeStateHoliday00000
AfterPromo00000
BeforePromo00000
SchoolHoliday_bw55555
StateHoliday_bw00000
Promo_bw55555
SchoolHoliday_fw71511
StateHoliday_fw00000
Promo_fw51511
\n", "

93 rows × 5 columns

\n", "
" ], "text/plain": [ " 0 1 \\\n", "index 0 1 \n", "Store 1 2 \n", "DayOfWeek 5 5 \n", "Date 2015-07-31 00:00:00 2015-07-31 00:00:00 \n", "Sales 5263 6064 \n", "Customers 555 625 \n", "Open 1 1 \n", "Promo 1 1 \n", "StateHoliday False False \n", "SchoolHoliday 1 1 \n", "Year 2015 2015 \n", "Month 7 7 \n", "Week 31 31 \n", "Day 31 31 \n", "Dayofweek 4 4 \n", "Dayofyear 212 212 \n", "Is_month_end True True \n", "Is_month_start False False \n", "Is_quarter_end False False \n", "Is_quarter_start False False \n", "Is_year_end False False \n", "Is_year_start False False \n", "Elapsed 1438300800 1438300800 \n", "StoreType c a \n", "Assortment a a \n", "CompetitionDistance 1270 570 \n", "CompetitionOpenSinceMonth 9 11 \n", "CompetitionOpenSinceYear 2008 2007 \n", "Promo2 0 1 \n", "Promo2SinceWeek 1 13 \n", "... ... ... \n", "Min_Sea_Level_PressurehPa 1015 1017 \n", "Max_VisibilityKm 31 10 \n", "Mean_VisibilityKm 15 10 \n", "Min_VisibilitykM 10 10 \n", "Max_Wind_SpeedKm_h 24 14 \n", "Mean_Wind_SpeedKm_h 11 11 \n", "Max_Gust_SpeedKm_h NaN NaN \n", "Precipitationmm 0 0 \n", "CloudCover 1 4 \n", "Events Fog Fog \n", "WindDirDegrees 13 309 \n", "StateName Hessen Thueringen \n", "CompetitionOpenSince 2008-09-15 00:00:00 2007-11-15 00:00:00 \n", "CompetitionDaysOpen 2510 2815 \n", "CompetitionMonthsOpen 24 24 \n", "Promo2Since 1900-01-01 00:00:00 2010-03-29 00:00:00 \n", "Promo2Days 0 1950 \n", "Promo2Weeks 0 25 \n", "AfterSchoolHoliday 0 0 \n", "BeforeSchoolHoliday 0 0 \n", "AfterStateHoliday 57 67 \n", "BeforeStateHoliday 0 0 \n", "AfterPromo 0 0 \n", "BeforePromo 0 0 \n", "SchoolHoliday_bw 5 5 \n", "StateHoliday_bw 0 0 \n", "Promo_bw 5 5 \n", "SchoolHoliday_fw 7 1 \n", "StateHoliday_fw 0 0 \n", "Promo_fw 5 1 \n", "\n", " 2 3 \\\n", "index 2 3 \n", "Store 3 4 \n", "DayOfWeek 5 5 \n", "Date 2015-07-31 00:00:00 2015-07-31 00:00:00 \n", "Sales 8314 13995 \n", "Customers 821 1498 \n", "Open 1 1 \n", "Promo 1 1 \n", "StateHoliday False False \n", "SchoolHoliday 1 1 \n", "Year 2015 
2015 \n", "Month 7 7 \n", "Week 31 31 \n", "Day 31 31 \n", "Dayofweek 4 4 \n", "Dayofyear 212 212 \n", "Is_month_end True True \n", "Is_month_start False False \n", "Is_quarter_end False False \n", "Is_quarter_start False False \n", "Is_year_end False False \n", "Is_year_start False False \n", "Elapsed 1438300800 1438300800 \n", "StoreType a c \n", "Assortment a c \n", "CompetitionDistance 14130 620 \n", "CompetitionOpenSinceMonth 12 9 \n", "CompetitionOpenSinceYear 2006 2009 \n", "Promo2 1 0 \n", "Promo2SinceWeek 14 1 \n", "... ... ... \n", "Min_Sea_Level_PressurehPa 1017 1014 \n", "Max_VisibilityKm 31 10 \n", "Mean_VisibilityKm 14 10 \n", "Min_VisibilitykM 10 10 \n", "Max_Wind_SpeedKm_h 14 23 \n", "Mean_Wind_SpeedKm_h 5 16 \n", "Max_Gust_SpeedKm_h NaN NaN \n", "Precipitationmm 0 0 \n", "CloudCover 2 6 \n", "Events Fog None \n", "WindDirDegrees 354 282 \n", "StateName NordrheinWestfalen Berlin \n", "CompetitionOpenSince 2006-12-15 00:00:00 2009-09-15 00:00:00 \n", "CompetitionDaysOpen 3150 2145 \n", "CompetitionMonthsOpen 24 24 \n", "Promo2Since 2011-04-04 00:00:00 1900-01-01 00:00:00 \n", "Promo2Days 1579 0 \n", "Promo2Weeks 25 0 \n", "AfterSchoolHoliday 0 0 \n", "BeforeSchoolHoliday 0 0 \n", "AfterStateHoliday 57 67 \n", "BeforeStateHoliday 0 0 \n", "AfterPromo 0 0 \n", "BeforePromo 0 0 \n", "SchoolHoliday_bw 5 5 \n", "StateHoliday_bw 0 0 \n", "Promo_bw 5 5 \n", "SchoolHoliday_fw 5 1 \n", "StateHoliday_fw 0 0 \n", "Promo_fw 5 1 \n", "\n", " 4 \n", "index 4 \n", "Store 5 \n", "DayOfWeek 5 \n", "Date 2015-07-31 00:00:00 \n", "Sales 4822 \n", "Customers 559 \n", "Open 1 \n", "Promo 1 \n", "StateHoliday False \n", "SchoolHoliday 1 \n", "Year 2015 \n", "Month 7 \n", "Week 31 \n", "Day 31 \n", "Dayofweek 4 \n", "Dayofyear 212 \n", "Is_month_end True \n", "Is_month_start False \n", "Is_quarter_end False \n", "Is_quarter_start False \n", "Is_year_end False \n", "Is_year_start False \n", "Elapsed 1438300800 \n", "StoreType a \n", "Assortment a \n", "CompetitionDistance 
29910 \n", "CompetitionOpenSinceMonth 4 \n", "CompetitionOpenSinceYear 2015 \n", "Promo2 0 \n", "Promo2SinceWeek 1 \n", "... ... \n", "Min_Sea_Level_PressurehPa 1016 \n", "Max_VisibilityKm 10 \n", "Mean_VisibilityKm 10 \n", "Min_VisibilitykM 10 \n", "Max_Wind_SpeedKm_h 14 \n", "Mean_Wind_SpeedKm_h 11 \n", "Max_Gust_SpeedKm_h NaN \n", "Precipitationmm 0 \n", "CloudCover 4 \n", "Events None \n", "WindDirDegrees 290 \n", "StateName Sachsen \n", "CompetitionOpenSince 2015-04-15 00:00:00 \n", "CompetitionDaysOpen 107 \n", "CompetitionMonthsOpen 3 \n", "Promo2Since 1900-01-01 00:00:00 \n", "Promo2Days 0 \n", "Promo2Weeks 0 \n", "AfterSchoolHoliday 0 \n", "BeforeSchoolHoliday 0 \n", "AfterStateHoliday 57 \n", "BeforeStateHoliday 0 \n", "AfterPromo 0 \n", "BeforePromo 0 \n", "SchoolHoliday_bw 5 \n", "StateHoliday_bw 0 \n", "Promo_bw 5 \n", "SchoolHoliday_fw 1 \n", "StateHoliday_fw 0 \n", "Promo_fw 1 \n", "\n", "[93 rows x 5 columns]" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_df.head().T" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "844338" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "cat_vars = ['Store', 'DayOfWeek', 'Year', 'Month', 'Day', 'StateHoliday', 'CompetitionMonthsOpen',\n", " 'Promo2Weeks', 'StoreType', 'Assortment', 'PromoInterval', 'CompetitionOpenSinceYear', 'Promo2SinceYear',\n", " 'State', 'Week', 'Events', 'Promo_fw', 'Promo_bw', 'StateHoliday_fw', 'StateHoliday_bw',\n", " 'SchoolHoliday_fw', 'SchoolHoliday_bw']\n", "\n", "cont_vars = ['CompetitionDistance', 'Max_TemperatureC', 'Mean_TemperatureC', 'Min_TemperatureC',\n", " 'Max_Humidity', 'Mean_Humidity', 'Min_Humidity', 'Max_Wind_SpeedKm_h', \n", " 'Mean_Wind_SpeedKm_h', 'CloudCover', 'trend', 'trend_DE',\n", " 'AfterStateHoliday', 'BeforeStateHoliday', 'Promo', 'SchoolHoliday']\n", "\n", "n = len(train_df); n" ] }, { "cell_type": 
"code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "idx = np.random.permutation(range(n))[:2000]\n", "idx.sort()\n", "small_train_df = train_df.iloc[idx[:1000]]\n", "small_test_df = train_df.iloc[idx[1000:]]\n", "small_cont_vars = ['CompetitionDistance', 'Mean_Humidity']\n", "small_cat_vars = ['Store', 'DayOfWeek', 'PromoInterval']\n", "small_train_df = small_train_df[small_cat_vars+small_cont_vars + ['Sales']]\n", "small_test_df = small_test_df[small_cat_vars+small_cont_vars + ['Sales']]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
StoreDayOfWeekPromoIntervalCompetitionDistanceMean_HumiditySales
7207225None50.0679349
7617635None32240.0618022
14453344Mar,Jun,Sept,Dec4040.0736050
2302773Jan,Apr,Jul,Oct1090.0547865
24241993Mar,Jun,Sept,Dec6360.0639121
\n", "
" ], "text/plain": [ " Store DayOfWeek PromoInterval CompetitionDistance Mean_Humidity \\\n", "720 722 5 None 50.0 67 \n", "761 763 5 None 32240.0 61 \n", "1445 334 4 Mar,Jun,Sept,Dec 4040.0 73 \n", "2302 77 3 Jan,Apr,Jul,Oct 1090.0 54 \n", "2424 199 3 Mar,Jun,Sept,Dec 6360.0 63 \n", "\n", " Sales \n", "720 9349 \n", "761 8022 \n", "1445 6050 \n", "2302 7865 \n", "2424 9121 " ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "small_train_df.head()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
StoreDayOfWeekPromoIntervalCompetitionDistanceMean_HumiditySales
4188452764Mar,Jun,Sept,Dec2960.0514892
4189984294Jan,Apr,Jul,Oct16350.0675242
4193988304Jan,Apr,Jul,Oct6320.0516087
4200073253Feb,May,Aug,Nov350.0597110
42069210113Feb,May,Aug,Nov490.0599483
\n", "
" ], "text/plain": [ " Store DayOfWeek PromoInterval CompetitionDistance \\\n", "418845 276 4 Mar,Jun,Sept,Dec 2960.0 \n", "418998 429 4 Jan,Apr,Jul,Oct 16350.0 \n", "419398 830 4 Jan,Apr,Jul,Oct 6320.0 \n", "420007 325 3 Feb,May,Aug,Nov 350.0 \n", "420692 1011 3 Feb,May,Aug,Nov 490.0 \n", "\n", " Mean_Humidity Sales \n", "418845 51 4892 \n", "418998 67 5242 \n", "419398 51 6087 \n", "420007 59 7110 \n", "420692 59 9483 " ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "small_test_df.head()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "categorify = Categorify(small_cat_vars, small_cont_vars)\n", "categorify(small_train_df)\n", "categorify(small_test_df, test=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
StoreDayOfWeekPromoIntervalCompetitionDistanceMean_HumiditySales
418845276.04Mar,Jun,Sept,Dec2960.0514892
418998429.04Jan,Apr,Jul,Oct16350.0675242
419398830.04Jan,Apr,Jul,Oct6320.0516087
420007325.03Feb,May,Aug,Nov350.0597110
420692NaN3Feb,May,Aug,Nov490.0599483
\n", "
" ], "text/plain": [ " Store DayOfWeek PromoInterval CompetitionDistance Mean_Humidity \\\n", "418845 276.0 4 Mar,Jun,Sept,Dec 2960.0 51 \n", "418998 429.0 4 Jan,Apr,Jul,Oct 16350.0 67 \n", "419398 830.0 4 Jan,Apr,Jul,Oct 6320.0 51 \n", "420007 325.0 3 Feb,May,Aug,Nov 350.0 59 \n", "420692 NaN 3 Feb,May,Aug,Nov 490.0 59 \n", "\n", " Sales \n", "418845 4892 \n", "418998 5242 \n", "419398 6087 \n", "420007 7110 \n", "420692 9483 " ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "small_test_df.head()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "720 -1\n", "761 -1\n", "1445 2\n", "2302 1\n", "2424 2\n", "dtype: int8" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "small_train_df['PromoInterval'].cat.codes[:5]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "418845 147\n", "418998 234\n", "419398 481\n", "420007 173\n", "420692 -1\n", "dtype: int16" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "small_test_df['Store'].cat.codes[:5]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "fill_missing = FillMissing(small_cat_vars, small_cont_vars)\n", "fill_missing(small_train_df)\n", "fill_missing(small_test_df, test=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
StoreDayOfWeekPromoIntervalCompetitionDistanceMean_HumiditySalesCompetitionDistance_na
181602911NaN2620.08312663True
360832913NaN2620.0775479True
881242911NaN2620.07710660True
3110842913NaN2620.0739244True
3316512915NaN2620.0816994True
\n", "
" ], "text/plain": [ " Store DayOfWeek PromoInterval CompetitionDistance Mean_Humidity \\\n", "18160 291 1 NaN 2620.0 83 \n", "36083 291 3 NaN 2620.0 77 \n", "88124 291 1 NaN 2620.0 77 \n", "311084 291 3 NaN 2620.0 73 \n", "331651 291 5 NaN 2620.0 81 \n", "\n", " Sales CompetitionDistance_na \n", "18160 12663 True \n", "36083 5479 True \n", "88124 10660 True \n", "311084 9244 True \n", "331651 6994 True " ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "small_train_df[small_train_df['CompetitionDistance_na'] == True]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
StoreDayOfWeekPromoIntervalCompetitionDistanceMean_HumiditySalesCompetitionDistance_na
584834NaN2Feb,May,Aug,Nov2620.0964772True
611734NaN1Feb,May,Aug,Nov2620.0756035True
745902NaN3NaN2620.0703654True
760633NaN2Feb,May,Aug,Nov2620.0833179True
815761291.04NaN2620.0667531True
\n", "
" ], "text/plain": [ " Store DayOfWeek PromoInterval CompetitionDistance Mean_Humidity \\\n", "584834 NaN 2 Feb,May,Aug,Nov 2620.0 96 \n", "611734 NaN 1 Feb,May,Aug,Nov 2620.0 75 \n", "745902 NaN 3 NaN 2620.0 70 \n", "760633 NaN 2 Feb,May,Aug,Nov 2620.0 83 \n", "815761 291.0 4 NaN 2620.0 66 \n", "\n", " Sales CompetitionDistance_na \n", "584834 4772 True \n", "611734 6035 True \n", "745902 3654 True \n", "760633 3179 True \n", "815761 7531 True " ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "small_test_df[small_test_df['CompetitionDistance_na'] == True]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "TODO: add something about Normalize" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/ubuntu/anaconda3/lib/python3.6/site-packages/pandas/io/feather_format.py:112: FutureWarning: `nthreads` argument is deprecated, pass `use_threads` instead\n", " return feather.read_dataframe(path, nthreads=nthreads)\n" ] } ], "source": [ "train_df = pd.read_feather(path/'train_clean')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "procs=[FillMissing, Categorify, Normalize]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "cat_names = ['Store', 'DayOfWeek', 'Year', 'Month', 'Day', 'StateHoliday', 'CompetitionMonthsOpen',\n", " 'Promo2Weeks', 'StoreType', 'Assortment', 'PromoInterval', 'CompetitionOpenSinceYear', 'Promo2SinceYear',\n", " 'State', 'Week', 'Events', 'Promo_fw', 'Promo_bw', 'StateHoliday_fw', 'StateHoliday_bw',\n", " 'SchoolHoliday_fw', 'SchoolHoliday_bw']\n", "\n", "cont_names = ['CompetitionDistance', 'Max_TemperatureC', 'Mean_TemperatureC', 'Min_TemperatureC',\n", " 'Max_Humidity', 'Mean_Humidity', 'Min_Humidity', 'Max_Wind_SpeedKm_h', \n", " 'Mean_Wind_SpeedKm_h', 'CloudCover', 'trend', 'trend_DE',\n", " 
'AfterStateHoliday', 'BeforeStateHoliday', 'Promo', 'SchoolHoliday']" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/ubuntu/anaconda3/lib/python3.6/site-packages/pandas/io/feather_format.py:112: FutureWarning: `nthreads` argument is deprecated, pass `use_threads` instead\n", " return feather.read_dataframe(path, nthreads=nthreads)\n" ] } ], "source": [ "dep_var = 'Sales'\n", "train_df = pd.read_feather(path/'train_clean')\n", "df = train_df[cat_vars+cont_vars+[dep_var, 'Date']].copy()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(Timestamp('2015-08-01 00:00:00'), Timestamp('2015-09-17 00:00:00'))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "test_df['Date'].min(), test_df['Date'].max()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "41088" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(test_df)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "41395" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "cut = train_df['Date'][(train_df['Date'] == train_df['Date'][len(test_df)])].index.max()\n", "cut" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "valid_idx = range(cut)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data = (TabularList.from_df(df, path=path, cat_names=cat_names, cont_names=cont_names, procs=procs)\n", " .split_by_idx(valid_idx)\n", " .label_from_df(cols=dep_var, label_cls=FloatList, log=True)\n", " .databunch())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Model" ] }, { "cell_type": "code", 
"execution_count": null, "metadata": {}, "outputs": [], "source": [ "max_log_y = np.log(np.max(train_df['Sales']))\n", "y_range = torch.tensor([0, max_log_y*1.2], device=defaults.device)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "emb_szs = data.get_emb_szs({})" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model = TabularModel(emb_szs, len(cont_vars), 1, [1000,500], [0.001,0.01], emb_drop=0.04, y_range=y_range)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "TabularModel(\n", " (embeds): ModuleList(\n", " (0): Embedding(1116, 50)\n", " (1): Embedding(8, 5)\n", " (2): Embedding(4, 3)\n", " (3): Embedding(13, 7)\n", " (4): Embedding(32, 17)\n", " (5): Embedding(3, 2)\n", " (6): Embedding(26, 14)\n", " (7): Embedding(27, 14)\n", " (8): Embedding(5, 3)\n", " (9): Embedding(4, 3)\n", " (10): Embedding(4, 3)\n", " (11): Embedding(24, 13)\n", " (12): Embedding(9, 5)\n", " (13): Embedding(13, 7)\n", " (14): Embedding(53, 27)\n", " (15): Embedding(22, 12)\n", " (16): Embedding(7, 4)\n", " (17): Embedding(7, 4)\n", " (18): Embedding(4, 3)\n", " (19): Embedding(4, 3)\n", " (20): Embedding(9, 5)\n", " (21): Embedding(9, 5)\n", " (22): Embedding(3, 2)\n", " (23): Embedding(3, 2)\n", " )\n", " (emb_drop): Dropout(p=0.04)\n", " (bn_cont): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (layers): Sequential(\n", " (0): Linear(in_features=229, out_features=1000, bias=True)\n", " (1): ReLU(inplace)\n", " (2): BatchNorm1d(1000, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (3): Dropout(p=0.001)\n", " (4): Linear(in_features=1000, out_features=500, bias=True)\n", " (5): ReLU(inplace)\n", " (6): BatchNorm1d(500, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (7): Dropout(p=0.01)\n", " (8): Linear(in_features=500, out_features=1, 
bias=True)\n", " )\n", ")" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[1115,\n", " 7,\n", " 3,\n", " 12,\n", " 31,\n", " 2,\n", " 25,\n", " 26,\n", " 4,\n", " 3,\n", " 3,\n", " 23,\n", " 8,\n", " 12,\n", " 52,\n", " 21,\n", " 6,\n", " 6,\n", " 3,\n", " 3,\n", " 8,\n", " 8,\n", " 2,\n", " 2]" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "[len(v) for k,v in data.train_ds.classes.items()]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "16" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(data.train_ds.cont_names)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = Learner(data, model)\n", "learn.loss_fn = F.mse_loss\n", "learn.metrics = [exp_rmspe]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.\n" ] } ], "source": [ "learn.lr_find()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAYIAAAEKCAYAAAAfGVI8AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJzt3Xd4HOW9/v/3R70Xq1myLDfcbdxksAmmhBICoYWeEAghkAopJOfk/JITkpMKSU4SDufHiUMgkIRqSggBU0wxAYwtd+Pe1CxZsq1qq+v5/rErI4xsy7Z2Z1d7v65rL+/OzO7cux7tZ2eeeZ4x5xwiIhK5orwOICIi3lIhEBGJcCoEIiIRToVARCTCqRCIiEQ4FQIRkQinQiAiEuFUCEREIpwKgYhIhIvxOkB/ZGdnu5EjR3odQ0QkrCxfvnyPcy7naMuFRSEYOXIkJSUlXscQEQkrZlban+V0aEhEJMKpEIiIRDgVAhGRCKdCICIS4VQIREQinAqBiEiEUyEQEYlwKgQiIiFoa00z//3KZmoaWwO+LhUCEZEQtGT7Xu5ZtIX2ru6Ar0uFQEQkBG2oaiQtIYZhGYkBX1fACoGZPWBmNWa2rte0IWb2iplt8f+bGaj1i4iEs/VVjUwqSMPMAr6uQO4R/Bm44JBp3wMWOefGAov8j0VEpJeubsfGqiYm5acHZX0BKwTOucXAvkMmXwo85L//EHBZoNYvIhKudu7dT0tHFxPzU4OyvmC3EeQ556r896uBvMMtaGa3mlmJmZXU1tYGJ52ISAhYv6sRgEkFaUFZn2eNxc45B7gjzJ/vnCt2zhXn5Bx1OG0RkUFjQ1UjsdHG2NzBuUew28zyAfz/1gR5/SIiIW99VSMn5aYSFxOcr+hgF4LngBv9928E/h7k9YuIhLz1uxqD1j4AgT199FHgXWC8mVWY2c3AL4HzzGwLcK7/sYiI+O1pbqOmqY1J+cFpH4AAXqrSOXfdYWadE6h1ioiEuw1VwW0oBvUsFhEJKQfPGAriHoEKgYhICFlf1ciwjEQykuKCtk4VAhGREBLshmJQIRARCRmtHV1s37M/qIeFQIVARCRkbN7dRFe3C2pDMagQiIiEjA8aioMz2FwPFQIRkRCxvqqRlPgYCjMDfw2C3lQIRERCxIYqX0NxVFTgr0HQmwqBiEgI6O52bKhqYmKQG4pBhUBEJCRU1rfQ3NbJhKEqBCIiEWlTdRMA44cGtw8BqBCIiISETbtVCEREItrG6iYKMxNJiQ/YWKCHpUIgIhICNlU3MsGDvQFQIRAR8Vx7Zzfba/d7clgIVAhERDy3rbaZzm7HeA/OGAIVAhERz/WcMaRDQyIiEWpjdROx0cao7GRP1q9CICLisU3VjYzJSSE22puvZBUCERGPbapu8qyhGFQIREQ81dDSwa6GVhUCEZFItXm3tw3FoEIgIuKpjQfHGPLm1FFQIRAR8dSm6kZSE2IoSE/wLIMKgYiIhzZVNzE+LxWz4F6MpjcVAhERjzjn2OjxGUOgQiAi4pmqhlaaWjs9bSgGFQIREc/0DC0xLk+FQEQkIm08OMaQd2cMgUeFwMy+YWbrzOx9M/umFxlERLy2oaqR/PQE0pNiPc0R9EJgZlOAW4BTgGnAp8zspGDnEBHx2rrKBqYOS/c6hid7BBOB95xzB5xzncCbwKc9yCEi4pnG1g6279nPyYWRWQjWAfPMLMvMkoALgeGHLmRmt5pZiZmV1NbWBj2kiEggratsAGBKJO4ROOc2AHcBLwMLgVVAVx/LzXfOFTvninNycoKcUkQksNZW+ApBpB4awjn3J+fcLOfcGUAdsNmLHCIiXllb2cCwjESyUuK9jkKMFys1s1znXI2ZFeFrH5jjRQ4REa+srWwIifYB8KgQAE+ZWRbQAXzNOVfvUQ4RkaBrONBB6d4DXDP7I82jnvCkEDjn5nmxXhGRULDW31B88rAMj5P4qGexiEiQran0HQQ
JhYZiUCEQEQm6dZUNjMhK8rxHcQ8VAhGRIFtT0RAS/Qd6qBCIiATRvv3tVNS1cLIKgYhIZOppKJ4aIqeOggqBiEhQra3wNRTr0JCISIRaW9nA6Oxk0hJCo6EYVAhERIJqbUVDSB0WAhUCEZGgqW1qY1dDa8j0H+ihQiAiEiSry33tAycXhkaP4h4qBCIiQbKirI6YKAuZweZ6qBCIiATJirI6JhWkkRAb7XWUD1EhEBEJgs6ublaXNzCzKNPrKB+hQiAiEgQbq5to6ehiRlFotQ+ACoGISFCsLKsD0B6BiEikWlFWT05qPIWZiV5H+QgVAhGRIFhRVsfMogzMzOsoH6FCICISYHua2yjdeyAkDwuBCoGISMCtLPN1JJs5QoVARCQi9XQkC7WhJXqoEIiIBNiK0jomh2BHsh4qBCIiAdTZ1c2aigZmhGj7AKgQiIgEVE9HslBtHwAVAhGRgPqgI1no9SjuoUIgIhJAK8rqyU2NZ1hG6HUk66FCICISQMtL65hZlBmSHcl6qBCIiARITWMrZfsOUDwydNsHQIVARCRgSkp97QOzQrihGFQIREQCpmRnHQmxUUwuCM2OZD08KQRm9i0ze9/M1pnZo2aW4EUOEZFAKindx/ThGcTFhPZv7qCnM7NhwO1AsXNuChANXBvsHCIigXSgvZP3dzVSPGKI11GOyqsyFQMkmlkMkATs8iiHiEhArCqrp6vbhXxDMXhQCJxzlcCvgTKgCmhwzr0c7BwiIoG0bGcdZqE74mhvXhwaygQuBUYBBUCymV3fx3K3mlmJmZXU1tYGO6aIyAkpKd3H+LxU0hJivY5yVF4cGjoX2OGcq3XOdQBPA6cdupBzbr5zrtg5V5yTkxP0kCIix6ur27GyrD4sDguBN4WgDJhjZknm62p3DrDBgxwiIgGxsbqR5rZOZo8M/YZi8KaN4D1gAbACWOvPMD/YOUREAqVkZ3h0JOsR48VKnXN3And6sW4RkUArKa0jPz0hpAea6y20ezmIiIShkp37mDUitAea602FQERkAFXWt1DV0Bo27QOgQiAiMqBKdu4Dwqd9APpZCMxsjJnF+++fZWa3m1noXm5HRMQj/1xTRXZKHBOGpnodpd/6u0fwFNBlZifhO8NnOPBIwFKJiISh2qY2XttYwxUzC4mJDp8DLv1N2u2c6wQuB/7HOfddID9wsUREws+zKyvp7HZcVVzodZRj0t9C0GFm1wE3As/7p4V+v2kRkSBxzvF4STkzizI4KTd8DgtB/wvBTcBc4GfOuR1mNgr4S+BiiYiEl5Xl9Wytaeaa2cO9jnLM+tWhzDm3Ht81BHoGjUt1zt0VyGAiIuHkyZJyEmOjuejkAq+jHLP+njX0hpmlmdkQfEND/NHM/juw0UREwsOB9k7+sbqKi07OJyXekwEbTkh/Dw2lO+cagU8DDzvnTsU3iqiISMR7YW01zW2dXF0cfoeFoP+FIMbM8oGr+aCxWEREgCdKyhmVnczsMBl2+lD9LQT/BbwEbHPOLTOz0cCWwMUSEQkPO/fsZ+mOfVw5qzBsxhY6VH8bi58Enuz1eDtwRaBCiYiEiwXLK4gyuGJmePUd6K2/jcWFZvaMmdX4b0+ZWfi+axGRAdDV7XhqRQXzxuYwND3B6zjHrb+Hhh4EnsN3jeEC4B/+aSIiEeudbXuoamgNu57Eh+pvIchxzj3onOv03/4M6ELCIhLRFiyvID0xlnMn5nkd5YT0txDsNbPrzSzaf7se2BvIYCIioayhpYOF66q5ZFoBCbHRXsc5If0tBF/Ad+poNVAFXAl8PkCZRERC3vNrdtHW2R32h4Wgn4XAOVfqnLvEOZfjnMt1zl3GID5rqLapjfve2MbL71fT1NrhdRwRCUELllcwLi+FqcPSvY5ywk6kL/S3gd8NVJBQ4Jzj2VWV/Pgf66k/4CsAMVHGzKJMRuckU93YSlV9K7ubWpk3NocfXDSRvLRjP1Ogpb2L1RX1rCiro6W9i7SEWNI
SY0hPjKUgI5HCzCQyk2LD9pxkkcFua00TK8vq+f6FEwfF3+mJFILwf/d+Le1dlNcd4JcvbuS1jTXMLMrgZ5dPpf5AB29tqWXxllpe3bCboekJFGUlMXlYGs+vqeL1jTV867xx3Dh3xIcuQtHa0cWKsjre2bqXktJ9dHQ5YqKM2Ogomlo7eH9XI53dDgAzcO6jmZLiopmUn8Z1pxRx0cn5YX8MUmQwWbC8kugo47IZw7yOMiDM9fUt1J8nmpU554oGOE+fiouLXUlJyYC93u7GVh54ewevvL+bmqY2mts6AUiMjea7nxjPjaeNJDrqyHWudO9+fvj393lzcy1FQ5LITI6jo7Ob9q5uyvcdoK2zm+goY0pBGqkJsXR0ddPR1U18TDTTizIoHpHJzKJM0hNj2d/eSWNrJ3X729lV30JFXQvldQdYvLmWbbX7yUiK5apZhUwYmgb4ikdSXAxnjsshMU4FQiSYOrq6mfuL15g+PJ37b5ztdZwjMrPlzrnioy13xD0CM2sC+qoUBiQeZzZPOOfYtLuJB/61g2dWVtLV7ThrfC5njs8hNzWBnNR45oweQmFmUr9eb0RWMn++aTYL11XzyNIyAOJjooiLieLMcTmcNiaLU0YNITXh6NfvSU2IJTUhlmEZiUzpdbzROce72/fy1yWlPPD2Trq63SHPi+GKmYV85tQixuWF14UwRMLVog017Glu49rZQfkdHBTHvUcQTMe7R7C8dB9Ltu9jZVkdK8vq2bu/nYTYKK4pHs4X541m+JD+femHgvoD7TS2dOL8dbmyroXHlpWzcF017V3djM9LZcqwdCYXpDF+aCp7mtvYWtPM5t1NHGjv4oyxOZw3KY+R2ckevxOR8Pb5B5eyoaqRt//94yF/XeL+7hEM6kJw04NLeX1TLaNzkpkxPJOZIzL45JR8hiTHBSClN/Y2t/HUigre3rqX93c1sqe57eC86ChjRFYSMVHG5t3NAIzNTeHq4uF8bu4ItTuIHKPK+hZOv+s1vn72Sdxx/niv4xyVCgG+4/jpibFkJA2eL/6jqWlsZfPuZnJS4xmZnUR8jO/LvnzfAV5Zv5sX11WxbGcdQ9MSuP2csVxVXEhsiP+qEQkVv31lM/e8toXF3z07LI4oqBDIYb27bS93v7SRlWX1jMhK4pwJeRSPzKR4RCa5x3E6rEgk6Op2zLvrNcbkpvCXm0/1Ok6/DEhjsQxOc8dk8fRXTmPRhhr+9K8dPLK0lAfe3gFAWkIMsdFRREUZMVHGhKGpvkb1cTlqX5CItnhLLbsaWvn+RZO8jjLggl4IzGw88HivSaOBHzrnBlXntFBnZpw7KY9zJ+XR3tnN+qpGSnbuo3zfAbqco6sb2jq7WFlWz53PvQ9AXlo8yfExxEQZ0VFRZCXHMTYvhXF5qYzLS2FyQbraHWTQemxpGVnJcZw3KbwHmOtL0AuBc24TMB3AzKKBSuCZYOeQD8TFRDF9eAbTh2f0Ob90737e3FzLqrJ62ru66exydHZ3U9PUxmNLy2np6AIgNtqYMiyd2SOHMCo7mdqmNqoaWtnd2EpmUhzTizKYMTyDCUNTQ/5sC5HeappaWbShhi+cPoq4mMG37Xp9aOgcfJe/LPU4hxzBiKxkbpibzA1zPzqvu9tRWd/ChqpGVpTVU7JzH39+eyftXd0AZKfEkZuawJqKep5aUQFAQmwU0wozmDUik1kjMpk+PIOslPhgviWRY/JkSQWd3Y5rZofnxemPxutCcC3wqMcZ5ARERRnDhyQxfEgS508eCviG2NjT3EZOavzBs5acc1TUtbCyvJ4VpXWsLKtj/uLtB4fayEmNZ2J+GhOHpjIiK5mCjAQKMhIZlpFIcrzXm6lEsu5ux6NLy5gzeghjclK8jhMQnv2FmVkccAnwH4eZfytwK0BR0eDpwRcJEmKjP9JD2+yDgnHJtALgg8H31lU2sKGqiY3VjTz49t6DexM9CtITOCkvlXG5KYwbmsqk/DTG5qUcLDI
igbR4Sy0VdS38+wUTvI4SMF7+1PoksMI5t7uvmc65+cB88J0+GsxgEhyJcdHMGZ3FnNFZB6d1dnVT29zGrvoWKutbKd93gC27m9i8u5n3tu+lrdNXJGKijDE5KeSkxpOWGENaQixDkuMYkZXEiKxkRmUnk5saPyhGhhRvPfKer5H4E/493sHIy0JwHTosJIeIiY4iPz2R/PREZo348LyubsfOvftZv6uRDVWNbN7dxL797VQ3ttLQ0kHd/vaDh5oARmUnc1VxIVfOLFT/CDku1Q2tLNpYwy3zRg/KRuIenhQCM0sGzgO+5MX6JTxF+/cCxuSkcLH/8FJvnV3d7KpvZefe/WyvbeaFtdXcvXATv3l5M2ePz+GCKfmcPT5HDdPSb48vK6er23HdKYOzkbiHJ4XAObcfyDrqgiLHICY6iqKsJIqykjhjXA6f/9gottU280RJOc+urOTVDTWYwcyiTE4ZNYTslHiykuPIToln6rB00pOOPlKsRI7Orm4eW1bGvLHZjMga3J0pdTqGDGpjclL4j09O5HsXTGBdZSOvbtjNoo27mb94+4eG9Y6OMopHZHLOxFzOGp/LmJyUo16TQga3NzbVUtXQyp0XD76exIfSWEMSkbq7HY2tHezd387uhlbe2baXVzfsZmN1E+C7tsRJuSmMH5rKjKJMzpmQS0FGWF2CQ07QF/68jHWVDbz9vY+H7cCMGnRO5DhU1rfw7ra9bKpuZGN1Exurm6ht8g3tPSk/jXMn5nLpjGGD9nxy8amoO8C8u18Pm+GmD0eDzokch2EZiVw5q/DgY+ccW2uaWbSxhkUbdnPv61u557WtzB6ZyTWzi7hw6lCS4vRnNNg8vqwcYND2JD6U9ghEjkFNYytPrajkiZJyduzZT2p8DBdPL+Da2cOZOixd/RYGgY6ubj72y9eYXJDGgzed4nWcE6I9ApEAyE1L4CtnjeHLZ45m6Y59PF5SztMrKnjkvTImDE3lE5OHUjzSN35Sf65XLaFn0YYaapra+NmpI46+8CChQiByHMyMU0dnceroLH50yWSeW7WLJ5dXcM9rW3AOzGByQRp3nD+es8fneh1XjsGjS8sYmpbA2eNzvI4SNCoEIicoLSGW6+eM4Po5I2hq7WBVeT3LS+t4bvUubnpwGZdNL+A/PzVJHdnCQPm+AyzeUsttHx8bUUOlR847FQmC1IRY5o3N4ZvnjuPFb8zjG+eM5Z9rqzj3v9/k6RUVhEObXCR7bFkZBlwbIY3EPVQIRAIkPiaab503jn/ePo9R2cl8+4nV3PxQCdUNrV5Hkz50dHXzREkFZ4+PvD4jKgQiATYuL5Unv3wa//mpSbyzbQ/n/fZNnigp195BiHl1/W5qm9r4zKmRN+y9CoFIEERHGTefPoqF3ziDiflp/NuCNVz/p/fYXtvsdTTxe2xZOfnpCZwVgY37KgQiQTQyO5nHbpnDTy6dzJryBi743Vv87tXNtPqv+yzeqGpoYfGWWq6cVRiRY0ypEIgEWVSU8bm5I1l0x5l8YspQfvfqFi685y221mjvwCtPr6jEOT7UqzySqBCIeCQ3LYH/uW4GD3/hFBpbOrjivndYsn2v17EijnOOJ0vKOXXUkEE/3PThqBCIeOyMcTk889WPkZMaz+f+9B5Pr6jwOlJEWbazjp17D3B1cWSdMtqbCoFICBg+JImnvnIas0cO4dtPrOYXL2xQu0GQPFFSTkp8DJ+cOnivSXw0KgQiISI9MZY/33QKnzm1iD8s3s4nfreYt7bUeh1rUGtu6+SFtVV86uT8iB5FVoVAJITExUTx88un8sgtpxJlxuf+tJRvPraSxtYOr6MNSi+sqeJAexdXFUdmI3EPFQKREHTamGxe/MY8bj9nLM+vqeKWh0p0qCgAnlxezuicZGYWZXodxVMqBCIhKiE2mm+fN45fXzWN93bs444nVtPdrd7IA2XHnv0s21nHVbOGR/x1JCL3oJhImLhsxjBqm9r42QsbyEmN586
LJ0X8F9dAWLC8nCiDT88c5nUUz6kQiISBW84YTXVjK3/61w7y/BfHkePX1e14ekUlZ47LIS8twes4ntOhIZEw8f0LJ3LxtALuWriRNzbVeB0nrL29dQ9VDa1cOSty+w70pkIgEiaiooy7rziZCUNT+dbjq9hV3+J1pLC1YHkF6YmxnDsp8gaY64sKgUgYSYyL5n8/O5P2zm5ue3QlHV3dXkcKOw0tHbz0fjWXTi8gPiba6zghQYVAJMyMyUnhl1eczPLSOn710iav44Sd59fsoq2zm6t0WOggFQKRMHTxtAKun1PE/MXbeen9aq/jhJUFyysYn5fKlGFpXkcJGSoEImHqBxdN4uTCdO54YjVbdjd5HScsbK1pYmVZPVfOKtQpuL14UgjMLMPMFpjZRjPbYGZzvcghEs4SYqP5w+dmkRAbzRcfLqH+QLvXkULeguWVREcZl81Q34HevNoj+D2w0Dk3AZgGbPAoh0hYy09P5A+fm0lVfStfe2QFnWo8PqyubsezKys5a1wOOanxXscJKUEvBGaWDpwB/AnAOdfunKsPdg6RwWLWiCH89LIpvL11Lz/9p35THc572/dS3djK5epJ/BFe7BGMAmqBB81spZndb2aReVkgkQFy9ezh3PSxkfz5nZ38fVWl13FC0jMrK0mJj+HciXleRwk5XhSCGGAmcJ9zbgawH/jeoQuZ2a1mVmJmJbW1GpNd5Gi+f+FEZhZl8INn1lG+74DXcUJKa0cXC9dVc8GUoSTEqu/AobwoBBVAhXPuPf/jBfgKw4c45+Y754qdc8U5OTlBDSgSjmKio/j9tTMAuP0xdTbrbdGGGpraOrlsug4L9SXohcA5Vw2Um9l4/6RzgPXBziEyGA0fksTPPz2VlWX1/P7VLV7HCRnPrqokNzWeuWOyvI4Skrw6a+g24G9mtgaYDvzcoxwig87F0wq4uriQ/31jK+9s2+N1HM/VH2jnjU01XDKtgOgo9R3oiyeFwDm3yn/Y52Tn3GXOuTovcogMVj+6ZDKjspK544nVNLd1eh3HU/9cW0VHl1PfgSNQz2KRQSgpLoZfXz2N6sZWfrVwo9dxPPX3lbs4KTeFyQUaUuJwVAhEBqmZRZncOHckDy8pZXlpZO50V9QdYOnOfVw2vUBDShyBCoHIIPadT4wnPy2B7z21hvbOyDuLaMHyCgAu1dlCR6RCIDKIpcTH8NPLp7Clppn73tjmdZygau3o4q9LSjl7fA7DhyR5HSekqRCIDHIfn5DHJdMK+N/Xt0bUKKV/X1XJnuZ2vjhvtNdRQp4KgUgE+OHFk0hJiOGrf1vB/gg4i8g5x/1v7WBifhqnqe/AUakQiESA7JR47r1uBttqm/m3BWtwznkdKaDe3FzLlppmbpk3So3E/aBCIBIhTjspm3+/YAL/XFvFH9/a7nWcgLr/rR3kpcXzqZMLvI4SFlQIRCLIrWeM5sKpQ/nlixt5Z+vg7HW8oaqRf23dw42njSQuRl9x/aFPSSSCmBl3XzmNMTkpfO2RFWwehI3H97+1g8TYaD5zSpHXUcKGCoFIhEmJj+GPNxQTGx3FtfOXsKGq0etIA2ZtRQPPra7kquJCMpLivI4TNlQIRCLQyOxkHv/SXOKio/jMH5fw/q4GryOdsJrGVm55uITc1ARuP2es13HCigqBSIQalZ3M41+aQ1JcDJ/543usqQjfK8a2dnRxy1+W09DSwR9vKCY7RdckPhYqBCIRbERWMo/dOofUhBiu+cMSFq6r8jrSMXPO8b2n1rC6vJ7fXjOdSRpc7pipEIhEuOFDknj6K6cxfmgqX/7rCu59bUtY9TO4/60dPLtqF985fxwXTBnqdZywpEIgIuSmJfDYrXO4bHoBv355M998fBWtHV1exzqqmqZWfvvqZs6dmMfXzj7J6zhhK8brACISGhJio/ntNdMZNzSVX720ia01zfzf9bNCesC237+6hfbObn5w0UT1ID4B2iMQkYPMjK+edRL331BM2b4
DXHzvv1i8udbrWH3aVtvMY8vK+eypRYzMTvY6TlhTIRCRjzhnYh7/+PrpDE1L4MYHl4bkENZ3L9xIQkwUt+lU0ROmQiAifRqZnczTXz2Ni6bmc9fCjTy2tMzrSActL93HS+/v5ktnjtGpogNAbQQiclhJcTH87prpNLZ28oNn1zEyO5k5o70d1tk5x89f2EhOajxfnDfK0yyDhfYIROSIYqKj+J/rZlCUlcRX/rqcsr0HPM3z2sYalpfW8a1zx5EUp9+yA0GFQESOKj0xlgdunE23g5sfWkZTa4dnWf7vzW0My0jkquJCzzIMNioEItIvI7OTue/6mezYs59bHi6hpT34/QyWl9axbGcdN58+ithofX0NFH2SItJvp43J5jdXT2Ppjn3c/NCyoBeD+Yu3kZ4YyzWzhwd1vYOdCoGIHJNLpw/jN1dP493te7nl4ZKg9UDeVtvMy+t3c8PcESTHq21gIOnTFJFjdvmMQrq64bsLVnPjA0s5/aRsoqKM6Chj+vCMgJxZdP9b24mNjuKGuSMH/LUjnQqBiByXK2cV4pzjB8+u470d+w5ON4MfXTyZG08bOWDrqmlq5akVlVw5q5CcVPUbGGgqBCJy3K4qHs4VMwvpco6ubkdbRzffWbCaO597n9qmNu44f9yAjAH00Ds76ejq5pZ5owcgtRzKk0JgZjuBJqAL6HTOFXuRQ0ROXFSUEYURG+0buO6+z87kP/++jntf30ptUxs/u3wKMcdxhk9nVzevb6rlkfdKeWNzLRdMHsoojSkUEF7uEZztnNvj4fpFJABioqP4+eVTyUmJ557XtrK6op4fXTL5mNoNXlxbxY//sZ7qxlZyU+P5+tkn8cXTtTcQKDo0JCIDzsz49vnjmVSQxk+e38C185dw0dR8/uPCCRRmHnlY60UbdvP1R1cyuSCNH186mXMm5B7XHoX0n1eFwAEvm5kD/uCcm+9RDhEJoAum5HPW+Fz+8OZ27ntzKy+uq2JyQTpzRg9h7pgs5ozO+tAwEUu27+Wrf1vB5II0HrllDik6TTQozItL0pnZMOdcpZnlAq8AtznnFh+yzK3ArQBFRUWzSktLg55TRAZOZX0LTywr593te1lVVk97VzfzAACOAAAJAUlEQVSJsdFcMGUon545jNSEWK6//z3y0xN4/EtzGZIc53XksGdmy/vTButJIfhQALMfAc3OuV8fbpni4mJXUlISvFAiElAt7V0sL63jn2t38fyaKppaOwEozExkwZdPY2h6gscJB4f+FoKg73eZWTIQ5Zxr8t8/H/ivYOcQEe8kxkVz+thsTh+bzZ0XT+a1jTW8tWUPXz5ztIqAB7w4AJcHPOM/tzgGeMQ5t9CDHCISAhJio7lwaj4XTs33OkrECnohcM5tB6YFe70iItI3nZMlIhLhVAhERCKcCoGISIRTIRARiXAqBCIiEU6FQEQkwqkQiIhEOM+HmOgPM2sAtvQxKx1o6Ofjvu73/JsNHM+Q2Ieurz/zjzYtFDP3Nb0/n3Vf044ndzAz976v7aP/809k++g9L9S3j1Dbpg+Xs+d+hnMu56hpnHMhfwPm92f6kR73db/XvyUDmetI8482LRQzH+9nfZhpx5w7mJm9/qwjcfs4ZF5Ibx+htk33d/s42i1cDg39o5/Tj/S4r/uHe93+Otrz+5p/tGmhmLmv6f35rA/3Xo5VMDP3vq/to//zT2T7CMfM/Vnv8WQ62vzj3T6OKCwODQWamZW4MLtcZjhmhvDMrczBE465wzHzocJljyDQwvHCOOGYGcIztzIHTzjmDsfMH6I9AhGRCKc9AhGRCDfoCoGZPWBmNWa27jieO8vM1prZVjO7x/wXTfDPu83MNprZ+2Z2d6hnNrMfmVmlma3y3y4M9cy95t9hZs7Msgcu8cHXDsRn/RMzW+P/nF82s4IwyPwr//a8xsyeMbOMMMh8lf/vr9vMBuyY/IlkPczr3WhmW/y3G3tNP+J276njOcUslG/AGcBMYN1
xPHcpMAcw4EXgk/7pZwOvAvH+x7lhkPlHwHfC6XP2zxsOvASUAtnhkBtI67XM7cD/hUHm84EY//27gLvCIPNEYDzwBlDsdVZ/jpGHTBsCbPf/m+m/n3mk9xUKt0G3R+CcWwzs6z3NzMaY2UIzW25mb5nZhEOfZ2b5+P6glzjf/9rDwGX+2V8Bfumca/OvoyYMMgdUADP/Fvg3ICCNV4HI7Zxr7LVo8kBnD1Dml51znf5FlwCFYZB5g3Nu00DmPJGsh/EJ4BXn3D7nXB3wCnCBl3+r/THoCsFhzAduc87NAr4D/P99LDMMqOj1uMI/DWAcMM/M3jOzN81sdkDT+pxoZoCv+3f9HzCzzMBFPeiEMpvZpUClc251oIMe4oQ/azP7mZmVA58FfhjArD0GYvvo8QV8v1ADbSAzB1p/svZlGFDe63FP/lB5X33y4prFQWVmKcBpwJO9DsnFH+PLxODb1ZsDzAaeMLPR/so+4AYo833AT/D9Ov0J8Bt8f/ABcaKZzSwJ+P/wHbIImgH6rHHOfR/4vpn9B/B14M4BC3mIgcrsf63vA53A3wYm3WHXM2CZA+1IWc3sJuAb/mknAS+YWTuwwzl3ebCzDpRBXwjw7fXUO+em955oZtHAcv/D5/B9cfbePS4EKv33K4Cn/V/8S82sG9/4IrWhmtk5t7vX8/4IPB+grD1ONPMYYBSw2v/HVwisMLNTnHPVIZz7UH8DXiCAhYABymxmnwc+BZwTqB81vQz05xxIfWYFcM49CDwIYGZvAJ93zu3stUglcFavx4X42hIq8f59HZ7XjRSBuAEj6dXwA7wDXOW/b8C0wzzv0MacC/3Tvwz8l//+OHy7fhbimfN7LfMt4LFQ/5wPWWYnAWgsDtBnPbbXMrcBC8Ig8wXAeiAnEJ9xILcPBrix+HizcvjG4h34Gooz/feH9He79+rmeYAAbHyPAlVAB75f8jfj+6W5EFjt3/h/eJjnFgPrgG3AvXzQ4S4O+Kt/3grg42GQ+S/AWmANvl9a+aGe+ZBldhKYs4YC8Vk/5Z++Bt/4LsPCIPNWfD9oVvlvA32mUyAyX+5/rTZgN/CSl1npoxD4p3/B//luBW46lu3eq5t6FouIRLhIOWtIREQOQ4VARCTCqRCIiEQ4FQIRkQinQiAiEuFUCCQsmVlzkNd3v5lNGqDX6jLfSKXrzOwfRxv508wyzOyrA7Fukb7o9FEJS2bW7JxLGcDXi3EfDMIWUL2zm9lDwGbn3M+OsPxI4Hnn3JRg5JPIoz0CGTTMLMfMnjKzZf7bx/zTTzGzd81spZm9Y2bj/dM/b2bPmdlrwCIzO8vM3jCzBeYbq/9vPWPG+6cX++83+weZW21mS8wszz99jP/xWjP7aT/3Wt7lg0H3UsxskZmt8L/Gpf5lfgmM8e9F/Mq/7Hf973GNmf14AD9GiUAqBDKY/B74rXNuNnAFcL9/+kZgnnNuBr6RQX/e6zkzgSudc2f6H88AvglMAkYDH+tjPcnAEufcNGAxcEuv9f/eOTeVD4802Sf/ODvn4Ov5DdAKXO6cm4nvGhi/8Rei7wHbnHPTnXPfNbPzgbHAKcB0YJaZnXG09YkcTiQMOieR41xgUq8RI9P8I0mmAw+Z2Vh8o7HG9nrOK8653mPRL3XOVQCY2Sp8Y9D865D1tPPBIH7LgfP89+fywRjzjwC/PkzORP9rDwM24BuzHnxj0Pzc/6Xe7Z+f18fzz/ffVvofp+ArDIsPsz6RI1IhkMEkCpjjnGvtPdHM7gVed85d7j/e/kav2fsPeY22Xve76PtvpMN90Lh2uGWOpMU5N90/9PZLwNeAe/BdyyAHmOWc6zCznUBCH8834BfOuT8c43pF+qRDQzKYvIxv9E8AzKxnGOF0Phjy9/MBXP8SfIekAK492sLOuQP4Lm15h5nF4MtZ4y8CZwMj/Is2Aam9nvoS8AX/3g5mNszMcgfoPUgEUiGQcJVkZhW9bt/G96Va7G9AXY9v+HCAu4F
fmNlKArsX/E3g22a2Bt9FSxqO9gTn3Ep8o5Zeh+9aBsVmtha4AV/bBs65vcDb/tNNf+Wcexnfoad3/csu4MOFQuSY6PRRkQHiP9TT4pxzZnYtcJ1z7tKjPU/Ea2ojEBk4s4B7/Wf61BPAS4OKDCTtEYiIRDi1EYiIRDgVAhGRCKdCICIS4VQIREQinAqBiEiEUyEQEYlw/w8qIsUcqJZiXQAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn.recorder.plot()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Total time: 13:27\n", "epoch train_loss valid_loss exp_rmspe\n", "1 0.021706 0.019131 0.586892 (02:38)\n", "2 0.019761 0.016307 0.631732 (02:42)\n", "3 0.016764 0.016188 0.644211 (02:42)\n", "4 0.012963 0.011598 0.630723 (02:42)\n", "5 0.010889 0.011673 0.613048 (02:42)\n", "\n" ] } ], "source": [ "learn.fit_one_cycle(5, 1e-3, wd=0.2, pct_start=0.2)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# NOTE(review): this continues training the model already fitted in the previous cell\n", "# (5 more epochs with different hyper-parameters); no outputs were recorded for it.\n", "learn.fit_one_cycle(5, 1e-3, wd=0.1, pct_start=0.3)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Exp-RMSPE over the whole validation set: sqrt(mean(((y - y_hat) / y)^2)),\n", "# computed after exponentiating targets and predictions (assumes the model was trained on log sales).\n", "learn.model.eval()  # explicit inference mode instead of relying on state left by earlier cells\n", "with torch.no_grad():\n", "    pct_var, cnt = 0., 0\n", "    for x, y in learn.data.valid_dl:\n", "        # Flatten both to 1-D so the subtraction below is element-wise: a (bs, 1) output\n", "        # minus a (bs,) target would silently broadcast to (bs, bs) and inflate the result.\n", "        out = learn.model(*x).view(-1)\n", "        y = y.view(-1)\n", "        cnt += y.size(0)\n", "        y, out = torch.exp(y), torch.exp(out)\n", "        pct_var += ((y - out)/y).pow(2).sum()\n", "torch.sqrt(pct_var/cnt).item()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 2 }