{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%reload_ext autoreload\n", "%autoreload 2" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from fastai.tabular import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Rossmann" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Data preparation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To create the feature-engineered `train_clean` and `test_clean` files from the Kaggle competition data, run `rossman_data_clean.ipynb`. One important step that deals with the time-series nature of the data is the call to `add_datepart`, which expands `Date` into calendar features such as `Year`, `Week` and `Dayofweek` (`drop=False` keeps the original `Date` column):\n", "\n", "```python\n", "add_datepart(train, \"Date\", drop=False)\n", "add_datepart(test, \"Date\", drop=False)\n", "```" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "path = Config().data_path()/'rossmann'\n", "train_df = pd.read_pickle(path/'train_clean')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"<thead><tr style=\"text-align: right;\"><th></th><th>0</th><th>1</th><th>2</th><th>3</th><th>4</th></tr></thead>\n",
"<tbody>\n",
"<tr><th>index</th><td>0</td><td>1</td><td>2</td><td>3</td><td>4</td></tr>\n",
"<tr><th>Store</th><td>1</td><td>2</td><td>3</td><td>4</td><td>5</td></tr>\n",
"<tr><th>DayOfWeek</th><td>5</td><td>5</td><td>5</td><td>5</td><td>5</td></tr>\n",
"<tr><th>Date</th><td>2015-07-31</td><td>2015-07-31</td><td>2015-07-31</td><td>2015-07-31</td><td>2015-07-31</td></tr>\n",
"<tr><th>Sales</th><td>5263</td><td>6064</td><td>8314</td><td>13995</td><td>4822</td></tr>\n",
"<tr><th>Customers</th><td>555</td><td>625</td><td>821</td><td>1498</td><td>559</td></tr>\n",
"<tr><th>Open</th><td>1</td><td>1</td><td>1</td><td>1</td><td>1</td></tr>\n",
"<tr><th>Promo</th><td>1</td><td>1</td><td>1</td><td>1</td><td>1</td></tr>\n",
"<tr><th>StateHoliday</th><td>False</td><td>False</td><td>False</td><td>False</td><td>False</td></tr>\n",
"<tr><th>SchoolHoliday</th><td>1</td><td>1</td><td>1</td><td>1</td><td>1</td></tr>\n",
"<tr><th>Year</th><td>2015</td><td>2015</td><td>2015</td><td>2015</td><td>2015</td></tr>\n",
"<tr><th>Month</th><td>7</td><td>7</td><td>7</td><td>7</td><td>7</td></tr>\n",
"<tr><th>Week</th><td>31</td><td>31</td><td>31</td><td>31</td><td>31</td></tr>\n",
"<tr><th>Day</th><td>31</td><td>31</td><td>31</td><td>31</td><td>31</td></tr>\n",
"<tr><th>Dayofweek</th><td>4</td><td>4</td><td>4</td><td>4</td><td>4</td></tr>\n",
"<tr><th>Dayofyear</th><td>212</td><td>212</td><td>212</td><td>212</td><td>212</td></tr>\n",
"<tr><th>Is_month_end</th><td>True</td><td>True</td><td>True</td><td>True</td><td>True</td></tr>\n",
"<tr><th>Is_month_start</th><td>False</td><td>False</td><td>False</td><td>False</td><td>False</td></tr>\n",
"<tr><th>Is_quarter_end</th><td>False</td><td>False</td><td>False</td><td>False</td><td>False</td></tr>\n",
"<tr><th>Is_quarter_start</th><td>False</td><td>False</td><td>False</td><td>False</td><td>False</td></tr>\n",
"<tr><th>Is_year_end</th><td>False</td><td>False</td><td>False</td><td>False</td><td>False</td></tr>\n",
"<tr><th>Is_year_start</th><td>False</td><td>False</td><td>False</td><td>False</td><td>False</td></tr>\n",
"<tr><th>Elapsed</th><td>1438300800</td><td>1438300800</td><td>1438300800</td><td>1438300800</td><td>1438300800</td></tr>\n",
"<tr><th>StoreType</th><td>c</td><td>a</td><td>a</td><td>c</td><td>a</td></tr>\n",
"<tr><th>Assortment</th><td>a</td><td>a</td><td>a</td><td>c</td><td>a</td></tr>\n",
"<tr><th>CompetitionDistance</th><td>1270</td><td>570</td><td>14130</td><td>620</td><td>29910</td></tr>\n",
"<tr><th>CompetitionOpenSinceMonth</th><td>9</td><td>11</td><td>12</td><td>9</td><td>4</td></tr>\n",
"<tr><th>CompetitionOpenSinceYear</th><td>2008</td><td>2007</td><td>2006</td><td>2009</td><td>2015</td></tr>\n",
"<tr><th>Promo2</th><td>0</td><td>1</td><td>1</td><td>0</td><td>0</td></tr>\n",
"<tr><th>Promo2SinceWeek</th><td>1</td><td>13</td><td>14</td><td>1</td><td>1</td></tr>\n",
"<tr><th>...</th><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td></tr>\n",
"<tr><th>Min_Sea_Level_PressurehPa</th><td>1015</td><td>1017</td><td>1017</td><td>1014</td><td>1016</td></tr>\n",
"<tr><th>Max_VisibilityKm</th><td>31</td><td>10</td><td>31</td><td>10</td><td>10</td></tr>\n",
"<tr><th>Mean_VisibilityKm</th><td>15</td><td>10</td><td>14</td><td>10</td><td>10</td></tr>\n",
"<tr><th>Min_VisibilitykM</th><td>10</td><td>10</td><td>10</td><td>10</td><td>10</td></tr>\n",
"<tr><th>Max_Wind_SpeedKm_h</th><td>24</td><td>14</td><td>14</td><td>23</td><td>14</td></tr>\n",
"<tr><th>Mean_Wind_SpeedKm_h</th><td>11</td><td>11</td><td>5</td><td>16</td><td>11</td></tr>\n",
"<tr><th>Max_Gust_SpeedKm_h</th><td>NaN</td><td>NaN</td><td>NaN</td><td>NaN</td><td>NaN</td></tr>\n",
"<tr><th>Precipitationmm</th><td>0</td><td>0</td><td>0</td><td>0</td><td>0</td></tr>\n",
"<tr><th>CloudCover</th><td>1</td><td>4</td><td>2</td><td>6</td><td>4</td></tr>\n",
"<tr><th>Events</th><td>Fog</td><td>Fog</td><td>Fog</td><td>NaN</td><td>NaN</td></tr>\n",
"<tr><th>WindDirDegrees</th><td>13</td><td>309</td><td>354</td><td>282</td><td>290</td></tr>\n",
"<tr><th>StateName</th><td>Hessen</td><td>Thueringen</td><td>NordrheinWestfalen</td><td>Berlin</td><td>Sachsen</td></tr>\n",
"<tr><th>CompetitionOpenSince</th><td>2008-09-15</td><td>2007-11-15</td><td>2006-12-15</td><td>2009-09-15</td><td>2015-04-15</td></tr>\n",
"<tr><th>CompetitionDaysOpen</th><td>2510</td><td>2815</td><td>3150</td><td>2145</td><td>107</td></tr>\n",
"<tr><th>CompetitionMonthsOpen</th><td>24</td><td>24</td><td>24</td><td>24</td><td>3</td></tr>\n",
"<tr><th>Promo2Since</th><td>1900-01-01</td><td>2010-03-29</td><td>2011-04-04</td><td>1900-01-01</td><td>1900-01-01</td></tr>\n",
"<tr><th>Promo2Days</th><td>0</td><td>1950</td><td>1579</td><td>0</td><td>0</td></tr>\n",
"<tr><th>Promo2Weeks</th><td>0</td><td>25</td><td>25</td><td>0</td><td>0</td></tr>\n",
"<tr><th>AfterSchoolHoliday</th><td>0</td><td>0</td><td>0</td><td>0</td><td>0</td></tr>\n",
"<tr><th>BeforeSchoolHoliday</th><td>0</td><td>0</td><td>0</td><td>0</td><td>0</td></tr>\n",
"<tr><th>AfterStateHoliday</th><td>57</td><td>67</td><td>57</td><td>67</td><td>57</td></tr>\n",
"<tr><th>BeforeStateHoliday</th><td>0</td><td>0</td><td>0</td><td>0</td><td>0</td></tr>\n",
"<tr><th>AfterPromo</th><td>0</td><td>0</td><td>0</td><td>0</td><td>0</td></tr>\n",
"<tr><th>BeforePromo</th><td>0</td><td>0</td><td>0</td><td>0</td><td>0</td></tr>\n",
"<tr><th>SchoolHoliday_bw</th><td>5</td><td>5</td><td>5</td><td>5</td><td>5</td></tr>\n",
"<tr><th>StateHoliday_bw</th><td>0</td><td>0</td><td>0</td><td>0</td><td>0</td></tr>\n",
"<tr><th>Promo_bw</th><td>5</td><td>5</td><td>5</td><td>5</td><td>5</td></tr>\n",
"<tr><th>SchoolHoliday_fw</th><td>7</td><td>1</td><td>5</td><td>1</td><td>1</td></tr>\n",
"<tr><th>StateHoliday_fw</th><td>0</td><td>0</td><td>0</td><td>0</td><td>0</td></tr>\n",
"<tr><th>Promo_fw</th><td>5</td><td>1</td><td>5</td><td>1</td><td>1</td></tr>\n",
"</tbody>\n",
"</table>\n",
"<p>93 rows × 5 columns</p>\n",
"</div>
" ], "text/plain": [ " 0 1 2 \\\n", "index 0 1 2 \n", "Store 1 2 3 \n", "DayOfWeek 5 5 5 \n", "Date 2015-07-31 2015-07-31 2015-07-31 \n", "Sales 5263 6064 8314 \n", "Customers 555 625 821 \n", "Open 1 1 1 \n", "Promo 1 1 1 \n", "StateHoliday False False False \n", "SchoolHoliday 1 1 1 \n", "Year 2015 2015 2015 \n", "Month 7 7 7 \n", "Week 31 31 31 \n", "Day 31 31 31 \n", "Dayofweek 4 4 4 \n", "Dayofyear 212 212 212 \n", "Is_month_end True True True \n", "Is_month_start False False False \n", "Is_quarter_end False False False \n", "Is_quarter_start False False False \n", "Is_year_end False False False \n", "Is_year_start False False False \n", "Elapsed 1438300800 1438300800 1438300800 \n", "StoreType c a a \n", "Assortment a a a \n", "CompetitionDistance 1270 570 14130 \n", "CompetitionOpenSinceMonth 9 11 12 \n", "CompetitionOpenSinceYear 2008 2007 2006 \n", "Promo2 0 1 1 \n", "Promo2SinceWeek 1 13 14 \n", "... ... ... ... \n", "Min_Sea_Level_PressurehPa 1015 1017 1017 \n", "Max_VisibilityKm 31 10 31 \n", "Mean_VisibilityKm 15 10 14 \n", "Min_VisibilitykM 10 10 10 \n", "Max_Wind_SpeedKm_h 24 14 14 \n", "Mean_Wind_SpeedKm_h 11 11 5 \n", "Max_Gust_SpeedKm_h NaN NaN NaN \n", "Precipitationmm 0 0 0 \n", "CloudCover 1 4 2 \n", "Events Fog Fog Fog \n", "WindDirDegrees 13 309 354 \n", "StateName Hessen Thueringen NordrheinWestfalen \n", "CompetitionOpenSince 2008-09-15 2007-11-15 2006-12-15 \n", "CompetitionDaysOpen 2510 2815 3150 \n", "CompetitionMonthsOpen 24 24 24 \n", "Promo2Since 1900-01-01 2010-03-29 2011-04-04 \n", "Promo2Days 0 1950 1579 \n", "Promo2Weeks 0 25 25 \n", "AfterSchoolHoliday 0 0 0 \n", "BeforeSchoolHoliday 0 0 0 \n", "AfterStateHoliday 57 67 57 \n", "BeforeStateHoliday 0 0 0 \n", "AfterPromo 0 0 0 \n", "BeforePromo 0 0 0 \n", "SchoolHoliday_bw 5 5 5 \n", "StateHoliday_bw 0 0 0 \n", "Promo_bw 5 5 5 \n", "SchoolHoliday_fw 7 1 5 \n", "StateHoliday_fw 0 0 0 \n", "Promo_fw 5 1 5 \n", "\n", " 3 4 \n", "index 3 4 \n", "Store 4 5 \n", "DayOfWeek 5 5 \n", "Date 
2015-07-31 2015-07-31 \n", "Sales 13995 4822 \n", "Customers 1498 559 \n", "Open 1 1 \n", "Promo 1 1 \n", "StateHoliday False False \n", "SchoolHoliday 1 1 \n", "Year 2015 2015 \n", "Month 7 7 \n", "Week 31 31 \n", "Day 31 31 \n", "Dayofweek 4 4 \n", "Dayofyear 212 212 \n", "Is_month_end True True \n", "Is_month_start False False \n", "Is_quarter_end False False \n", "Is_quarter_start False False \n", "Is_year_end False False \n", "Is_year_start False False \n", "Elapsed 1438300800 1438300800 \n", "StoreType c a \n", "Assortment c a \n", "CompetitionDistance 620 29910 \n", "CompetitionOpenSinceMonth 9 4 \n", "CompetitionOpenSinceYear 2009 2015 \n", "Promo2 0 0 \n", "Promo2SinceWeek 1 1 \n", "... ... ... \n", "Min_Sea_Level_PressurehPa 1014 1016 \n", "Max_VisibilityKm 10 10 \n", "Mean_VisibilityKm 10 10 \n", "Min_VisibilitykM 10 10 \n", "Max_Wind_SpeedKm_h 23 14 \n", "Mean_Wind_SpeedKm_h 16 11 \n", "Max_Gust_SpeedKm_h NaN NaN \n", "Precipitationmm 0 0 \n", "CloudCover 6 4 \n", "Events NaN NaN \n", "WindDirDegrees 282 290 \n", "StateName Berlin Sachsen \n", "CompetitionOpenSince 2009-09-15 2015-04-15 \n", "CompetitionDaysOpen 2145 107 \n", "CompetitionMonthsOpen 24 3 \n", "Promo2Since 1900-01-01 1900-01-01 \n", "Promo2Days 0 0 \n", "Promo2Weeks 0 0 \n", "AfterSchoolHoliday 0 0 \n", "BeforeSchoolHoliday 0 0 \n", "AfterStateHoliday 67 57 \n", "BeforeStateHoliday 0 0 \n", "AfterPromo 0 0 \n", "BeforePromo 0 0 \n", "SchoolHoliday_bw 5 5 \n", "StateHoliday_bw 0 0 \n", "Promo_bw 5 5 \n", "SchoolHoliday_fw 1 1 \n", "StateHoliday_fw 0 0 \n", "Promo_fw 1 1 \n", "\n", "[93 rows x 5 columns]" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_df.head().T" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "844338" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "n = len(train_df); n" ] }, { "cell_type": "markdown", 
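"metadata": {}, "source": [ "The calendar columns in the output above (`Year`, `Month`, `Week`, `Day`, `Dayofweek`, `Dayofyear`, the `Is_*` flags and `Elapsed`) are what `add_datepart` derives from `Date`. The next cell is a rough, pandas-only sketch of that expansion for illustration. The helper name `sketch_datepart` is hypothetical, the cell assumes pandas 1.1 or later for `Series.dt.isocalendar`, and fastai's own implementation may differ in details such as dtypes. `Dayofweek` counts Monday as 0, whereas the Kaggle `DayOfWeek` column counts Monday as 1, which is why the two differ by one in the table above." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"import pandas as pd\n",
"\n",
"def sketch_datepart(df, field='Date'):\n",
"    # Hypothetical helper: approximates the calendar features add_datepart adds\n",
"    d = pd.to_datetime(df[field])\n",
"    out = df.copy()\n",
"    out['Year'] = d.dt.year\n",
"    out['Month'] = d.dt.month\n",
"    out['Week'] = d.dt.isocalendar().week.astype('int64')  # ISO week number\n",
"    out['Day'] = d.dt.day\n",
"    out['Dayofweek'] = d.dt.dayofweek  # Monday=0 ... Sunday=6\n",
"    out['Dayofyear'] = d.dt.dayofyear\n",
"    for flag in ['is_month_end', 'is_month_start', 'is_quarter_end',\n",
"                 'is_quarter_start', 'is_year_end', 'is_year_start']:\n",
"        out[flag.capitalize()] = getattr(d.dt, flag)  # e.g. 'Is_month_end'\n",
"    # Whole seconds since the Unix epoch: a single, steadily increasing time feature\n",
"    out['Elapsed'] = (d - pd.Timestamp('1970-01-01')) // pd.Timedelta(seconds=1)\n",
"    return out\n",
"\n",
"# 2015-07-31 should reproduce the Week, Dayofweek, Dayofyear and Elapsed values above\n",
"sketch_datepart(pd.DataFrame({'Date': ['2015-07-31', '2015-08-01']})).T"
] }, { "cell_type": "markdown", 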
"metadata": {}, "source": [ "### Experimenting with a sample" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "idx = np.random.permutation(range(n))[:2000]\n", "idx.sort()\n", "small_train_df = train_df.iloc[idx[:1000]]\n", "small_test_df = train_df.iloc[idx[1000:]]\n", "small_cont_vars = ['CompetitionDistance', 'Mean_Humidity']\n", "small_cat_vars = ['Store', 'DayOfWeek', 'PromoInterval']\n", "small_train_df = small_train_df[small_cat_vars + small_cont_vars + ['Sales']]\n", "small_test_df = small_test_df[small_cat_vars + small_cont_vars + ['Sales']]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"<thead><tr style=\"text-align: right;\"><th></th><th>Store</th><th>DayOfWeek</th><th>PromoInterval</th><th>CompetitionDistance</th><th>Mean_Humidity</th><th>Sales</th></tr></thead>\n",
"<tbody>\n",
"<tr><th>267</th><td>268</td><td>5</td><td>NaN</td><td>4520.0</td><td>67</td><td>7492</td></tr>\n",
"<tr><th>604</th><td>606</td><td>5</td><td>NaN</td><td>2260.0</td><td>61</td><td>7187</td></tr>\n",
"<tr><th>983</th><td>986</td><td>5</td><td>Feb,May,Aug,Nov</td><td>620.0</td><td>61</td><td>7051</td></tr>\n",
"<tr><th>1636</th><td>525</td><td>4</td><td>NaN</td><td>1870.0</td><td>55</td><td>9673</td></tr>\n",
"<tr><th>2348</th><td>123</td><td>3</td><td>NaN</td><td>16760.0</td><td>50</td><td>10007</td></tr>\n",
"</tbody>\n",
"</table>\n",
"</div>
" ], "text/plain": [ " Store DayOfWeek PromoInterval CompetitionDistance Mean_Humidity \\\n", "267 268 5 NaN 4520.0 67 \n", "604 606 5 NaN 2260.0 61 \n", "983 986 5 Feb,May,Aug,Nov 620.0 61 \n", "1636 525 4 NaN 1870.0 55 \n", "2348 123 3 NaN 16760.0 50 \n", "\n", " Sales \n", "267 7492 \n", "604 7187 \n", "983 7051 \n", "1636 9673 \n", "2348 10007 " ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "small_train_df.head()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"<thead><tr style=\"text-align: right;\"><th></th><th>Store</th><th>DayOfWeek</th><th>PromoInterval</th><th>CompetitionDistance</th><th>Mean_Humidity</th><th>Sales</th></tr></thead>\n",
"<tbody>\n",
"<tr><th>420510</th><td>829</td><td>3</td><td>NaN</td><td>110.0</td><td>55</td><td>6802</td></tr>\n",
"<tr><th>420654</th><td>973</td><td>3</td><td>Jan,Apr,Jul,Oct</td><td>330.0</td><td>59</td><td>6644</td></tr>\n",
"<tr><th>420990</th><td>194</td><td>2</td><td>Feb,May,Aug,Nov</td><td>16970.0</td><td>55</td><td>4720</td></tr>\n",
"<tr><th>421308</th><td>512</td><td>2</td><td>Mar,Jun,Sept,Dec</td><td>590.0</td><td>72</td><td>6248</td></tr>\n",
"<tr><th>421824</th><td>1029</td><td>2</td><td>NaN</td><td>1590.0</td><td>64</td><td>8004</td></tr>\n",
"</tbody>\n",
"</table>\n",
"</div>
" ], "text/plain": [ " Store DayOfWeek PromoInterval CompetitionDistance \\\n", "420510 829 3 NaN 110.0 \n", "420654 973 3 Jan,Apr,Jul,Oct 330.0 \n", "420990 194 2 Feb,May,Aug,Nov 16970.0 \n", "421308 512 2 Mar,Jun,Sept,Dec 590.0 \n", "421824 1029 2 NaN 1590.0 \n", "\n", " Mean_Humidity Sales \n", "420510 55 6802 \n", "420654 59 6644 \n", "420990 55 4720 \n", "421308 72 6248 \n", "421824 64 8004 " ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "small_test_df.head()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "categorify = Categorify(small_cat_vars, small_cont_vars)\n", "categorify(small_train_df)\n", "categorify(small_test_df, test=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"<thead><tr style=\"text-align: right;\"><th></th><th>Store</th><th>DayOfWeek</th><th>PromoInterval</th><th>CompetitionDistance</th><th>Mean_Humidity</th><th>Sales</th></tr></thead>\n",
"<tbody>\n",
"<tr><th>420510</th><td>NaN</td><td>3</td><td>NaN</td><td>110.0</td><td>55</td><td>6802</td></tr>\n",
"<tr><th>420654</th><td>973.0</td><td>3</td><td>Jan,Apr,Jul,Oct</td><td>330.0</td><td>59</td><td>6644</td></tr>\n",
"<tr><th>420990</th><td>NaN</td><td>2</td><td>Feb,May,Aug,Nov</td><td>16970.0</td><td>55</td><td>4720</td></tr>\n",
"<tr><th>421308</th><td>512.0</td><td>2</td><td>Mar,Jun,Sept,Dec</td><td>590.0</td><td>72</td><td>6248</td></tr>\n",
"<tr><th>421824</th><td>1029.0</td><td>2</td><td>NaN</td><td>1590.0</td><td>64</td><td>8004</td></tr>\n",
"</tbody>\n",
"</table>\n",
"</div>
" ], "text/plain": [ " Store DayOfWeek PromoInterval CompetitionDistance \\\n", "420510 NaN 3 NaN 110.0 \n", "420654 973.0 3 Jan,Apr,Jul,Oct 330.0 \n", "420990 NaN 2 Feb,May,Aug,Nov 16970.0 \n", "421308 512.0 2 Mar,Jun,Sept,Dec 590.0 \n", "421824 1029.0 2 NaN 1590.0 \n", "\n", " Mean_Humidity Sales \n", "420510 55 6802 \n", "420654 59 6644 \n", "420990 55 4720 \n", "421308 72 6248 \n", "421824 64 8004 " ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "small_test_df.head()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Index(['Feb,May,Aug,Nov', 'Jan,Apr,Jul,Oct', 'Mar,Jun,Sept,Dec'], dtype='object')" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "small_train_df.PromoInterval.cat.categories" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "267 -1\n", "604 -1\n", "983 0\n", "1636 -1\n", "2348 -1\n", "dtype: int8" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "small_train_df['PromoInterval'].cat.codes[:5]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "fill_missing = FillMissing(small_cat_vars, small_cont_vars)\n", "fill_missing(small_train_df)\n", "fill_missing(small_test_df, test=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"<thead><tr style=\"text-align: right;\"><th></th><th>Store</th><th>DayOfWeek</th><th>PromoInterval</th><th>CompetitionDistance</th><th>Mean_Humidity</th><th>Sales</th><th>CompetitionDistance_na</th></tr></thead>\n",
"<tbody>\n",
"<tr><th>185749</th><td>622</td><td>2</td><td>NaN</td><td>2300.0</td><td>93</td><td>4508</td><td>True</td></tr>\n",
"</tbody>\n",
"</table>\n",
"</div>
" ], "text/plain": [ " Store DayOfWeek PromoInterval CompetitionDistance Mean_Humidity \\\n", "185749 622 2 NaN 2300.0 93 \n", "\n", " Sales CompetitionDistance_na \n", "185749 4508 True " ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "small_train_df[small_train_df['CompetitionDistance_na'] == True]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Preparing full data set" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_df = pd.read_pickle(path/'train_clean')\n", "test_df = pd.read_pickle(path/'test_clean')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(844338, 41088)" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(train_df),len(test_df)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "procs=[FillMissing, Categorify, Normalize]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "cat_vars = ['Store', 'DayOfWeek', 'Year', 'Month', 'Day', 'StateHoliday', 'CompetitionMonthsOpen',\n", " 'Promo2Weeks', 'StoreType', 'Assortment', 'PromoInterval', 'CompetitionOpenSinceYear', 'Promo2SinceYear',\n", " 'State', 'Week', 'Events', 'Promo_fw', 'Promo_bw', 'StateHoliday_fw', 'StateHoliday_bw',\n", " 'SchoolHoliday_fw', 'SchoolHoliday_bw']\n", "\n", "cont_vars = ['CompetitionDistance', 'Max_TemperatureC', 'Mean_TemperatureC', 'Min_TemperatureC',\n", " 'Max_Humidity', 'Mean_Humidity', 'Min_Humidity', 'Max_Wind_SpeedKm_h', \n", " 'Mean_Wind_SpeedKm_h', 'CloudCover', 'trend', 'trend_DE',\n", " 'AfterStateHoliday', 'BeforeStateHoliday', 'Promo', 'SchoolHoliday']" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dep_var = 'Sales'\n", "df = train_df[cat_vars + cont_vars + [dep_var,'Date']].copy()" ] }, { 
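"cell_type": "markdown", "metadata": {}, "source": [ "The next cells choose a validation set of about the same size as the test set, taken from the most recent training dates. This relies on `train_clean` being sorted by `Date` in descending order, which the head above suggests (its first rows are dated 2015-07-31, the last training day). `cut` is the largest index whose date equals the date at position `len(test_df)`, i.e. the last row of that boundary day. The toy cell below uses made-up dates and a hypothetical test size to sketch that computation. It also shows that `range(cut)` stops one row short of `cut`, so the last row of the boundary day still lands in the training set; `range(cut + 1)` would keep the whole day in validation." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"import pandas as pd\n",
"\n",
"# Made-up toy data: 3 stores x 5 days, newest first like train_clean\n",
"days = pd.date_range('2015-07-27', '2015-07-31')[::-1]\n",
"toy_train = pd.DataFrame({'Date': [d for d in days for _ in range(3)]})\n",
"toy_test_len = 4  # hypothetical test-set size\n",
"\n",
"# Same rule as the cut cell: last row sharing the date at position toy_test_len\n",
"boundary_date = toy_train['Date'][toy_test_len]\n",
"cut = toy_train['Date'][toy_train['Date'] == boundary_date].index.max()\n",
"print(boundary_date.date(), cut)  # 2015-07-30 5\n",
"print(toy_train['Date'][list(range(cut))].dt.day.tolist())  # [31, 31, 31, 30, 30]"
] }, { 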
"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "('2015-08-01', '2015-09-17')" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "test_df['Date'].min(), test_df['Date'].max()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "41395" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "cut = train_df['Date'][(train_df['Date'] == train_df['Date'][len(test_df)])].index.max()\n", "cut" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "valid_idx = range(cut)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0 5263\n", "1 6064\n", "2 8314\n", "3 13995\n", "4 4822\n", "Name: Sales, dtype: int64" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df[dep_var].head()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data = (TabularList.from_df(df, path=path, cat_names=cat_vars, cont_names=cont_vars, procs=procs,)\n", " .split_by_idx(valid_idx)\n", " .label_from_df(cols=dep_var, label_cls=FloatList, log=True)\n", " .add_test(TabularList.from_df(test_df, path=path, cat_names=cat_vars, cont_names=cont_vars))\n", " .databunch())" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "doc(FloatList)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Model" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "max_log_y = np.log(np.max(train_df['Sales'])*1.2)\n", "y_range = torch.tensor([0, max_log_y], device=defaults.device)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = tabular_learner(data, layers=[1000,500], 
ps=[0.001,0.01], emb_drop=0.04, \n", " y_range=y_range, metrics=exp_rmspe)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "TabularModel(\n", " (embeds): ModuleList(\n", " (0): Embedding(1116, 81)\n", " (1): Embedding(8, 5)\n", " (2): Embedding(4, 3)\n", " (3): Embedding(13, 7)\n", " (4): Embedding(32, 11)\n", " (5): Embedding(3, 3)\n", " (6): Embedding(26, 10)\n", " (7): Embedding(27, 10)\n", " (8): Embedding(5, 4)\n", " (9): Embedding(4, 3)\n", " (10): Embedding(4, 3)\n", " (11): Embedding(24, 9)\n", " (12): Embedding(9, 5)\n", " (13): Embedding(13, 7)\n", " (14): Embedding(53, 15)\n", " (15): Embedding(22, 9)\n", " (16): Embedding(7, 5)\n", " (17): Embedding(7, 5)\n", " (18): Embedding(4, 3)\n", " (19): Embedding(4, 3)\n", " (20): Embedding(9, 5)\n", " (21): Embedding(9, 5)\n", " (22): Embedding(3, 3)\n", " (23): Embedding(3, 3)\n", " )\n", " (emb_drop): Dropout(p=0.04)\n", " (bn_cont): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (layers): Sequential(\n", " (0): Linear(in_features=233, out_features=1000, bias=True)\n", " (1): ReLU(inplace)\n", " (2): BatchNorm1d(1000, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (3): Dropout(p=0.001)\n", " (4): Linear(in_features=1000, out_features=500, bias=True)\n", " (5): ReLU(inplace)\n", " (6): BatchNorm1d(500, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (7): Dropout(p=0.01)\n", " (8): Linear(in_features=500, out_features=1, bias=True)\n", " )\n", ")" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learn.model" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "16" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(data.train_ds.cont_names)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, 
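"outputs": [], "source": [
"# Sketch for illustration, not fastai's source. The labels were log-transformed\n",
"# (label_from_df(..., log=True)), so an RMSPE-style metric such as exp_rmspe must\n",
"# map predictions and targets back with exp before comparing them. The second\n",
"# helper shows one way a y_range can bound outputs: a sigmoid scaled into (lo, hi),\n",
"# which is an assumption about the mechanism. Both function names are hypothetical.\n",
"import numpy as np\n",
"\n",
"def rmspe_from_logs(log_pred, log_targ):\n",
"    # Undo the log transform, then take the root mean squared percentage error\n",
"    pred, targ = np.exp(log_pred), np.exp(log_targ)\n",
"    pct_err = (targ - pred) / targ\n",
"    return np.sqrt(np.mean(pct_err ** 2))\n",
"\n",
"def squash(x, lo, hi):\n",
"    # Map any real activation into the open interval (lo, hi)\n",
"    return lo + (hi - lo) / (1 + np.exp(-x))\n",
"\n",
"log_targ = np.log(np.array([5263.0, 6064.0, 8314.0]))\n",
"print(rmspe_from_logs(log_targ + np.log(1.1), log_targ))  # about 0.1: each prediction is 10% high\n",
"print(squash(np.array([-10.0, 0.0, 10.0]), 0.0, 11.0))  # all values lie between 0 and 11"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, 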
"outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.\n" ] } ], "source": [ "learn.lr_find()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAELCAYAAADURYGZAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4wLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvqOYd8AAAIABJREFUeJzt3Xd8XNWd9/HPTxr1alnFsuXeC67CVJtiCBAglAAJgSwJbEiyG9gku5vdZ/Nsks1usklITzbFQIDNJmSpT2gxGAi2ARtbNrjj3uSiaku2ujTn+WPGIBxZzTNzp3zfr9e8PHPnju5X45F+uuece4455xARkcSV5HUAERHxlgqBiEiCUyEQEUlwKgQiIglOhUBEJMGpEIiIJLiwFQIz+42ZVZvZpm7bbjazzWbmN7PycB1bRET6L5xnBA8DV56ybRNwI7A8jMcVEZEB8IXrCzvnlpvZmFO2bQUws3AdVkREBkh9BCIiCS5sZwRnyszuBu4GyMrKmjdlyhSPE4mIxJa1a9fWOueK+tovaguBc24xsBigvLzcVVRUeJxIRCS2mNm+/uynpiERkQQXzuGjjwIrgclmVmlmd5nZDWZWCZwHPG9mL4br+CIi0j/hHDV062meejpcxxQRkYFT05CISIJTIRARSXAqBCIiCU6FQEQkwakQiIhEoeb2Tv7t2c3sq2sK+7FUCEREotDzGw7z0Bt7qWpsC/uxVAhERKLQYxUHGFeYxdljhoT9WCoEIiJRZlfNCdbsPcotZ4+MyGzNKgQiIlHmsTUHSE4ybpw7IiLHUyEQEYkiHV1+nlxXyaVTiinOSY/IMVUIRESiyKvvVlN7op2Pnz0yYsdUIRARiSL/u+YAxTlpXDSpz2UEQkaFQEQkShxpaOW1bdXcNK8MX3Lkfj2rEIiIRIkn1h7A7+CW8sg1C4EKgYhIVHDO8cTaSs4dV8CYwqyIHluFQEQkCmw62Mjeumaunx2ZIaPdqRCIiESB5zYcwpdkXDljWMSPrUIgIuIx5xzPbTjMgomF5GemRvz4KgQiIh5bt/8YB4+1cM3M4Z4cX4VARMRjz204RGpyEpdPL/Hk+CoEIiIe6vI7nt9wmIsmF5GbnuJJBhUCEREPrdlbT/XxNq6d5U2zEKgQiIh46rkNh0hPSWLRlGLPMqgQiIh4pLPLz582HmHRlBKy0nye5QhbITCz35hZtZlt6ratwMyWmtmO4L/hX3pHRCRKrdxdR11TO9fMLPU0RzjPCB4Grjxl2z8DrzjnJgKvBB+LiCSkl7dUkZ6SxCUeNgtBGAuBc245UH/K5uuAR4L3HwGuD9fxRUSi3YodtZw7bijpKcme5oh0H0GJc+4wQPBfb8ugiIhHDtQ3s7u2iQUTI7fuwOlEbWexmd1tZhVmVlFTU+N1HBGRkHp9Zy0ACycWepwk8oWgysxKAYL/Vp9uR+fcYudcuXOuvKjI+4opIhJKK3bUMCw3nQnF2V5HiXgheAa4I3j/DuCPET6+iIjnuvyO13fUsmBiIWbmdZywDh99FFgJTDazSjO7C/gOcLmZ7QAuDz4WEUkoGyqP0djayYIIrkvcm7BdweCcu/U0Ty0K1zFFRGLBih21mMGFE7zvH4Ao7iwWEYlXK3bUMGN4HgVZkV97oCcqBCIiEXS8tYN1+4+xIApGC52kQiAiEkErd9XR5XdRcf
3ASSoEIiIRtGJHLZmpycwdne91lPeoEIiIRNCKHTWcO24oaT5vp5XoToVARCRCdlQdZ29dMxdFybDRk1QIREQi5LGKA/iSzPNpp0+lQiAiEgHtnX6eWneQy6aWMDQ7zes4H6BCICISAa++W0VdUzsfO3uk11H+ggqBiEgE/O+aAwzLTWdhlPUPgAqBiEjYHWloZdn2Gm6aV0ZykveTzJ1KhUBEJMyeWHsAv4NbyqOvWQhUCEREwsrvdzxWUcl544Yyamim13F6pEIgIhJGq/bUsb++OSo7iU9SIRARCaP/XXOAnHQfV84Y5nWU01IhEBEJk/qmdv608Qg3zBlBekr0TClxKhUCEZEwebziAO1dfm4/d7TXUXqlQiAiEgZ+v+P3q/czf0wBk0pyvI7TKxUCEZEweH1nLfvqmrnt3FFeR+mTCoGISBj87q19FGSlRnUn8UkqBCIiIXakoZWXt1Zzc3lZVK07cDoqBCIiIfaHNfvp8jtumx/dncQnqRCIiIRQZ5efP6w+wMJJRVF7JfGpVAhERELoz9tqONLYym3nRH8n8UmeFAIz+zsz22Rmm83si15kEBEJh2fWH2JIZgqXTin2Okq/RbwQmNkM4DPAfGAWcI2ZTYx0DhGRUGtp7+KVrVVcOaOUlOTYaXDxIulUYJVzrtk51wksA27wIIeISEi9+m41ze1dXBtlaxL3xYtCsAlYaGZDzSwT+DAQvdPyiYj003MbDlGYncY544Z6HWVAfJE+oHNuq5l9F1gKnADWA52n7mdmdwN3A4waFTudLiKSmE60dfLqu9V87OyRUbkKWW88acRyzj3onJvrnFsI1AM7ethnsXOu3DlXXlQUfWt8ioh098rWKto6/Vwzc7jXUQYs4mcEAGZW7JyrNrNRwI3AeV7kEBEJlWfXH2ZYbjrlo4d4HWXAPCkEwJNmNhToAP7WOXfUoxwiImesoaWD5dtr+OR5o0mKsWYh8KgQOOcWeHFcEZFwWLqlivYuP9fE2Gihk2JnoKuISJR6bsMhyoZkMHtkvtdRBkWFQETkDNSdaOP1HbVcPbMUs9hrFgIVAhGRM/LM+kN0+h03zinzOsqgqRCIiJyBp9YdZMaIXCYPi+7lKHujQiAiMkjbq46z8WBDTJ8NgAqBiMigPbmuEl+S8ZHZsXcRWXcqBCIig9Dld/y/tw9y8eQiCrPTvI5zRlQIREQG4c1dtVQ1tnHj3NhuFgIVAhGRQXlq3UFy030smho7C9CcjgqBiMgAnWjrZMmmI1w7azhpvmSv45wxFQIRkQH608bDtHR0xUWzEKgQiIgM2DPrDzF6aCZzR8XmlBKnUiEQERmAo03tvLmrjg+fFbtTSpxKhUBEZACWbqmiy++4+qzYnGm0JyoEIiID8MKmw4wsyGD68Fyvo4SMCoGISD81NHfwxs7auGoWAhUCEZF+W7q1io4ux4dnxE+zEKgQiIj02wsbDzMiP4OZZXleRwkpFQIRkX5obO1gxY4aPnzWsLhqFgIVAhGRfnl5S6BZ6Ko4Gi10kgqBiEg/vLDxCMPz0pkTo+sS90aFQESkD8dbO1i+o4YrZ8TXaKGTVAhERPqwclcd7Z1+PjS9xOsoYaFCICLSh1W760nzJTEnTuYWOpUnhcDMvmRmm81sk5k9ambpXuQQEemPt/bUMXfUkLiYcronES8EZjYCuBcod87NAJKBj0c6h4hIfzS0dLDlcCPnjCvwOkrYeNU05AMyzMwHZAKHPMohItKrNXvqcQ7OHTfU6yhhE/FC4Jw7CHwf2A8cBhqccy9FOoeISH+8taeOVF8Ss+Nw2OhJXjQNDQGuA8YCw4EsM7u9h/3uNrMKM6uoqamJdEwRESDQUTx7ZD7pKfHZPwD9LARmNt7M0oL3Lzaze81ssOXxMmCPc67GOdcBPAWcf+pOzrnFzrly51x5UVHRIA8lIjJ4ja0dbD7UENfNQtD/M4IngS4zmwA8SOCv+d8P8pj7gXPNLNMCV2YsArYO8m
uJiIRNxd56/A7OHRu/HcXQ/0Lgd851AjcAP3bOfQkY1IQbzrm3gCeAdcDGYIbFg/laIiLh9NbuelKTk5gzaojXUcLK18/9OszsVuAO4NrgtpTBHtQ593Xg64N9vYhIJKzaU8+skXlkpMZv/wD0/4zg08B5wLecc3vMbCzwP+GLJSLirRNtnWw62MA5Y+O7fwD6eUbgnNtC4CKwk6N+cpxz3wlnMBERL1XsrafL7+K+oxj6P2roNTPLNbMCYD3wkJn9MLzRRES8s2p3Pb4kY+7o+L1+4KT+Ng3lOecagRuBh5xz8wgMAxURiUurdtcxsyyPzNT+dqXGrv4WAp+ZlQK3AM+FMY+IiOcamjvYUHmMCyYUeh0lIvpbCL4JvAjscs6tMbNxwI7wxRIR8c7rO2vxO7hoUmJczNrfzuLHgce7Pd4NfDRcoUREvLRsezU56b64nl+ou/52FpeZ2dNmVm1mVWb2pJmVhTuciEikOedYvr2WCycU4ktOjLW7+vtdPgQ8Q2CSuBHAs8FtIiJxZXvVCY40tiZMsxD0vxAUOececs51Bm8PA4nzLolIwli+PTDb8UIVgr9Qa2a3m1ly8HY7UBfOYCIiXli2vYaJxdkMz8/wOkrE9LcQ3Elg6OgRAovJ3ERg2gkRkbjR3N7J6j31CdUsBP0sBM65/c65jzjnipxzxc656wlcXCYiEjfe2l1Pe5c/oZqF4MxWKPtyyFKIiESBZdtrSE9JYn6crz9wqjMpBBayFCIiUWD59hrOGTs0rpel7MmZFAIXshQiIh47UN/M7tqmhOsfgD6uLDaz4/T8C9+AxOlSF5G4tywBh42e1GshcM7lRCqIiIiXlm2vYUR+BuOLsryOEnGJcf20iEgv2jv9rNxVx0WTizBLvO5PFQIRSXjr9h/lRFtnQvYPgAqBiAjLttfgSzLOHx//y1L2RIVARBLesm01zBs9hJz0FK+jeEKFQEQSWnVjK1sON3LR5MRsFgIVAhFJcMt31AKJsxpZTyJeCMxsspm90+3WaGZfjHQOEREI9A8UZqcxdViu11E806+lKkPJObcNmA1gZsnAQeDpSOcQEenyO17fUcMlU4pJSkq8YaMned00tAjY5Zzb53EOEUlAGw82cLS5I6GbhcD7QvBx4FGPM4hIglq2rQYzWDBRhcATZpYKfAR4/DTP321mFWZWUVNTE9lwIpIQlm2vZmZZPgVZqV5H8ZSXZwRXAeucc1U9PemcW+ycK3fOlRcVJXa1FpHQO9bczjsHjnHRxEKvo3jOy0JwK2oWEhGPLNteg9/BJVOKvY7iOU8KgZllApcDT3lxfBGRl7dWU5idyqyyfK+jeC7iw0cBnHPNQGJO6iEinuvo8rNsWzVXTB+W0MNGT/J61JCISMSt3XeUxtZOFk1VsxCoEIhIAnplaxWpyUlcmODDRk9SIRCRhPPKu9WcM66A7DRPWsejjgqBiCSUPbVN7K5pYpFGC71HhUBEEsorWwOXLi2aWuJxkuihQiAiCeXVd6uZVJLNyIJMr6NEjYRqIOvyOxpaOjjW3M6xlg4aWjpobuuiqb2T5rZOhmansXBiEXmZiblKkUi8a2ztYPWeev56wTivo0SVhCgER5vaefjNvTyyci/Hmjt63Tc5ySgfPYSLJhfh9zsqj7ZQebQFh+P2c0aHZNyxc46jzR1UHm2m5ngbDS0dNLZ0cKKtk7yMFIpz0ynJTSc9JYn9dc3srw/cstN8zBiRx4zheYwsyMBM459FBmL59ho6/U7DRk8R14Xg0LEWHlixh0dX76elo4vLphZz4YRC8jNTyctMIS8jhaxUH5mpyWSmJrO3rolXtlbz6rvVfG/JNgAKs1MZMSSTo03tfP536xhXlMXnLhrPuWOH0tzRSVNbF01tnRxpbOVIQyuHG1qpbmylrqmdo83t1De1Y0Bmqo+M1GSSDA43tNLc3jWg7yUnzUdLRxedfhd4nO5jQnE24wqzGV+cxbjCLEYPzWL00EwyUz
/43+qcU9EQAV7ZWk1+ZgpzRw3xOkpUietC8MOl23n67YNcN3s4n7toPJNKcnrdf2h2GvNGF/CVK6dQd6LtvV/eEGhWemHjYX7x2i6+8sSG036NwuxUinLSKcxOZfTQTIZkBmY1bG7vpLm9iy6/46JJxZQNyaBsSAYluenkZaSQm5FCdpqPhpYOqhpbqT7eSlNbFyMLMhldkEl+ZgptnX62Vx1n08FGNh9qYHdNEyt21PDkuspTMqThSzKa2ztp6ejCOZhUksNZI/KYUZbHzBF5TCnNIc2X/IHX+f2Olo4usjSkTuJQR5efV7ZWcdm0EpJ1NfEHmHPO6wx9Ki8vdxUVFQN+3cFjLTjnKBsSuk4h5xxv7KzjSGPre2cSWWk+SnLSKc5NIz0lue8vEmLHWzvYW9vMvvom9tU1c6C+GecgIzWZjNRkuvyOrYcb2RRchAMgJdmYWprLtNJcjjV3sKe2ib11TbR1+slJ81Gan05pXgZjC7OYWprD1NJcJpXkePL9iYTCih01fPLB1dz/V+VcPi0xRgyZ2VrnXHlf+8X1n34j8jNC/jXNjAujbNranPQUzirL46yyvF73c85x8FgLGysbWF/ZwPoDx1iy+QgFWamMK8xi4aRChmSlUt3YxqFjLRxqaGHN3vr3mrGSk4xJJTnMGZXP7JH5TC7JIT/YxJaTnqK/siSqLdl0hMzUZBZE2c9vNIjrQiAfZGaUDcmkbEgmV51V2q/X+P2O/fXNbD3cyOZDjayvPMaz6w/x+7f2/8W+WcGzo6w0H9lpvveKRH5mCiPyM5k+PJdpw3MpzE4L9bcm0qsuv+PFzVVcMrlYZ7U9UCGQXiUlGWMKsxhTmPVe8fD7HbtrT7CntpmG4DDchpYOTrR20tzeSVN7F8dbA9sOHm3haHP7e01SACW5aUwrzWVKaS5TS3MZX5RFQVYq+RmppKckqWNbQu7t/UepPdHGFTOGeR0lKqkQyIAlJRkTinOYUNx753t3Dc0dbD7cwJZDgTOLrYcbWbGj9r1RUCel+pIozUtnzNAsxhYGRkENy02nKCeN4pxAx7ov2fAlGylJSZpCWPplyaYjpCYncclkTTLXExUCiYi8zBTOH1/I+ePfb59t7/Szs/oE++qaONbSwbHmwMV+lcda2FvbRMXeepr6GGZbnJPGuKIsxhVlM64wi4klOUwqyWZYbrrOLAQI9I0t2XyECycWkpOui0V7okIgnkn1JTEt2G/QE+ccdU3tVDe2UXOijerGVhpbO+ns8tPpd7R1+jl4tIU9tSd4YePhD1wsmJ3mY1ppLvPGDKF89BDmjR5CfmZiL1CeqLYcbqTyaAv3XjrR6yhRS4VAopaZUZid1u/O5boTbeyoPhG4VR1nfWUD9y/fzS+DzU+jh2YyY3ge00fkBv4dnstQdVzHvRc3HSHJ0NXEvVAhkLgxNDuNodlpnDvu/VVQW9q7WF95jLX7jrLpYAMbDh7j+Y2H33u+NC+d6cMD10iMLQw0MY0vytLZQxxZsvkI88cWqOj3QoVA4lpGajLnjhv6geLQ0NzB5kMNbD4UuEJ706FGXttW84GO62mluVw8uYiLJxczd1Q+vmRN1BuLdtecYHvVCb5x7TSvo0Q1FQJJOHmZKZw/oZDzJ7zfcd3R5acy2N+w9fBxlm2v4dfLd/OL13aRkZLM9OG5nFWWx6yyfC6YUEhRjv66jAUvB9ceuHy6ho32Jq6nmBA5E42tHbyxo5bVe+vZWNnApkMNtHb4STKYP7aAq88q5YoZwyjOSfc6qpzGJ+5fRd2Jdl780kKvo3hCU0yInKHc9BSuOqv0vQvpOrv8vHvkOC9tqeKFjYf51z9u5mvPbGZWWT6XTyth0dRiJpfkaNhqlDjR1smavfXceeFYr6NEPRUCkX7yJScF1oMYkceXL5/E9qrjLNl0hFe2VnHfi9u478VtlA3J4LKpJXxoWglnjy0gRX0LnnlzZy0dXY6LJ2m0UF88KQRmlg88AMwAHH
Cnc26lF1lEBmtSSQ6TSnK4d9FEqhpbeWVrNS9vreL3q/fz8Jt7yctI4db5o7jrwrHqU/DAn7fVkJ3mo3yM1h7oi1dnBD8BljjnbjKzVECLh0pMK8lN5xPnjOIT54yiub2T5dtreXb9IRYv38VDb+zhY2eP5O6F40I6JbqcnnOOZduquWDCUJ2V9UPEC4GZ5QILgU8BOOfagfZI5xAJl8xUH1fOGMaVM4axp7aJX722i0dX7+fR1fu547wxfOHSCbpOIcx2VJ/gUEMr9y7S1cT94UWpHAfUAA+Z2dtm9oCZZXmQQyTsxhZm8d2bZrLsHy/hhjkjePCNPVx032s8sGI37Z1+r+PFrT+/Ww3ARZpkrl+8KAQ+YC7wS+fcHKAJ+OdTdzKzu82swswqampqIp1RJKSG52fwvZtm8cK9C5hZlsd/PL+VK3+8nGXb9dkOh9e21TBlWA6leaFfnCoeeVEIKoFK59xbwcdPECgMH+CcW+ycK3fOlRcVqapLfJhamstv7zqHhz59Ng644zerufu/KzhQ3+x1tLhxvLWDin31OhsYgIgXAufcEeCAmU0ObloEbIl0DhEvXTK5mCVfXMA/XTmF13fWctkPl/Fff96p5qIQeGNnHR1djksma9hof3nVnX4P8Dsz2wDMBr7tUQ4Rz6T5kvn8xeN55e8vYtHUYu57cRtX/3QFa/bWex0tpi3bXk1Omo95ozVstL88KQTOuXeCzT4znXPXO+eOepFDJBqU5mXwi9vm8eAd5TS3d3Hzr1byL09v5ERbp9fRYo7f7/jzuzVcMKFQw0YHQO+USJRYNLWEpV9eyGcWjOXR1fu56ifLWb1HZwcDsWpPHUcaW7nqLE0yNxAqBCJRJDPVx1evnsbjnz2PJDM+tngl335hK60dvS/ZKQFPrj1ITpqPD01TIRgIFQKRKFQ+poAX7l3AbeeMYvHy3Xzywbdo6LYUp/ylprZO/rTpMFfPLCUjNdnrODFFhUAkSmWl+fiP68/i55+Yw/oDDdz86zc5dKzF61hRa8mmIzS3d/HReWVeR4k5KgQiUe6amcN5+M6zOXyslRt/8Sbbq457HSkqPbG2ktFDMynXaKEBUyEQiQHnjy/ksc+dh985bvrlm1RoiOkHVB5tZuXuOj46t0zrQQyCCoFIjJhamstTf3M+Q7PTuP3Bt3htW7XXkaLG0+sOAnDDnBEeJ4lNKgQiMaRsSCaPf+48xhVm89ePVPDM+kNeR/Kcc44n11Vy3rihjCzQNN+DoUIgEmMKs9P4w2fPZe6oIfzdH97md2/t8zqSp9buO8reumZ1Ep8BFQKRGJSbnsIjd87nksnFfPXpTTywYrfXkTzzhzUHyExN5qoZunZgsFQIRGJURmoyv7p9HlfNGMZ/PL+Vn7+6w+tIEVfd2Mof3znIzfPKyErTEuyDpUIgEsNSfUn87NY53DBnBN9/aTv3vfguzjmvY0XMIyv30ul33HnhWK+jxDSVUJEY50tO4gc3zyI9JYn/+vMu0nzJCbFEY3N7J/+zaj9XTBvG6KFa5PBMqBCIxIGkJOPbN5xFW4efHy7dzqiCTK6P86GUT6ytpKGlg88s1NnAmVIhEIkTZsZ3PjqTQw0tfOWJDQzLS+fccUO9jhUWXX7HAyv2MGdUPvNGF3gdJ+apj0AkjqT6kvj17eWMLMjgs79dy66aE15HCoulW46wv76ZzywY53WUuKBCIBJn8jJTePjT80lJNj710GqqG1u9jhRy96/Yw8iCDK6YriGjoaBCIBKHRhZk8uAdZ1N3op07HlpDQ0v8TGH99v6jrN13lDsvGEtykuYVCgUVApE4NWtkPr/+5Dx2Vh/nM/9dETeL2zzw+h5y0n3cUj7S6yhxQ4VAJI4tmFjED2+ZzZq99dzz6Nt0dvm9jnRGDh5rYcmmI9w6f5QuIAshFQKROHftrOF849rpLN1Sxd8/vp6OGC4Gj7y5F4A7zh/jaY54o5IqkgDuOH8MTe2dfG/JNprbu/
jZrXNIT4mt5Ryb2jp5dPV+rpwxjBH5GV7HiSs6IxBJEH9z8QS+eV3gzOCuR9bQ1NbpdaQBebziAMdbO7lL00mEnAqBSAL5q/PG8MNbZrFqdz23P/hWzAwt7fI7fvPGXuaMymfuKC1FGWqeFAIz22tmG83sHTOr8CKDSKK6cW4Zv7htLlsPN3LFj5fzp42HvY7Up5e3VrG/vpm/vlAXkIWDl2cElzjnZjvnyj3MIJKQrpg+jOfuWcDIgkw+/7t1fPmxd2hsjc5rDRqaO/jR0u2MyM/giuklXseJS2oaEklQE4qzefLz53Pvoon88Z1DXPaDZTy25gBd/uiZxrr6eCsfW7yS3TVNfPO66fiS9SsrHLx6Vx3wkpmtNbO7PcogkvBSkpP48uWTeOrz5zNiSAZfeXIDV/90Bcu213gdjQP1zdzyq5Xsq2vmwU+Vs2iqzgbCxbxYxMLMhjvnDplZMbAUuMc5t/yUfe4G7gYYNWrUvH37EntdVpFwc87xwsYjfHfJu+yvb+bqmaV8+/qzyMtMiXiWzYcauOvhCprbO3no0/OZN1odxINhZmv70/zuSSH4QACzbwAnnHPfP90+5eXlrqJCfcoikdDW2cX9y3fz45d3UJyTxo8+NptzIjSddWB66d18/6VtDMlM5ZE75zO1NDcix45H/S0EEW8aMrMsM8s5eR/4ELAp0jlEpGdpvmS+cOlEnvz8+aT6kvj4/av4wUvb8Ie57+DgsRZue2AV//mnd1k0pYQlX1yoIhAhXlxZXAI8bWYnj/9759wSD3KISC9mjczn+XsX8PVnNvOzV3dS19TOt66fQfBnN2TaOrt45M29/OyVnfid43s3zeTmeWUhP46cXsQLgXNuNzAr0scVkYHLSvNx300zKcxO41fLduFLMv7tI9ND8kvaOcfSLVV864Wt7Ktr5tIpxXzj2umMGpoZguQyEJprSER6ZWb805WT6fL7uX/FHpKTjK9dM+2MisHBYy38y1MbWba9hgnF2Txy53wumlQUwtQyECoEItInM+NfPjyVLj/85o09dPkdX7tm2oDH9TvneHT1Ab79wlb8zvH1a6fxyXNH6/oAj6kQiEi/mBn/es1UfMnG4uW72V3TxM8/MYf8zNQ+X9vU1skbO2t5ZOVe3thZx/njh/Ldj85kZIGagaKBCoGI9NvJM4MJxdn836c3cd1/vcH9f1XOpJKc9/ZxzlFzvI0thxvZcriRN3fWsXpPPe1dfnLSfXzrhhl8Yv4odQZHERUCERmwW8pHMr4om8/+di3X/fwNSvPTA084ONbSQX1T+3v7TizO5lMXjOHiyUWUjy4g1admoGijQiAigzJv9BCevecCfvLyDo63doKBAdlpPqYMy2FKaS5Th+V6cmWyDIwKgYgMWmleBt/56EyvY8gZ0jkR7OdCAAAIO0lEQVSaiEiCUyEQEUlwKgQiIglOhUBEJMGpEIiIJDgVAhGRBKdCICKS4FQIREQSnOdLVfaHmdUAx4CGU57K62NbX/dP/lsI1A4iWk/HH0i+vjL3lLX784PJPZjMveXq6XFPWc/kvY5k5u73o/3zES2Ze9quz0ffIvH5yHfO9T2/t3MuJm7A4oFu6+t+t38rQpUplJlPk7X7vgPOPZjMveXqz/t7pu91JDPH0ucjWjLr8xH9n4++brHUNPTsILb1db+n159ppr6eH0jm7o+9zNzT9t4e95T1THJHMnP3+9H++YiWzD1t1+ejb5H8fPQqJpqGws3MKpxz5V7nGKhYzK3MkRGLmSE2c8di5lPF0hlBOC32OsAgxWJuZY6MWMwMsZk7FjN/gM4IREQSnM4IREQSXNwVAjP7jZlVm9mmQbx2npltNLOdZvZT67aWnpndY2bbzGyzmX0v2jOb2TfM7KCZvRO8fTiUmcOVu9vz/2BmzswKQ5c4bO/1v5vZhuD7/JKZDY+BzPeZ2bvB3E+bWX4MZL45+PPnN7OQtcmfSdbTfL07zGxH8HZHt+29fuY9NZ
hhT9F8AxYCc4FNg3jtauA8Agst/Qm4Krj9EuBlIC34uDgGMn8D+IdYe6+Dz40EXgT2AYXRnhnI7bbPvcCvYiDzhwBf8P53ge/GQOapwGTgNaDc66zBHGNO2VYA7A7+OyR4f0hv31c03OLujMA5txyo777NzMab2RIzW2tmK8xsyqmvM7NSAj/QK13gf+2/geuDT38e+I5zri14jOoYyBx2Ycz9I+ArQMg7sMKR2TnX2G3XrFDnDlPml5xzncFdVwFlMZB5q3NuWyhznknW07gCWOqcq3fOHQWWAld6/bPal7grBKexGLjHOTcP+AfgFz3sMwKo7Pa4MrgNYBKwwMzeMrNlZnZ2WNMGnGlmgC8ET/1/Y2ZDwhf1A84ot5l9BDjonFsf7qDdnPF7bWbfMrMDwG3A18KY9aRQfD5OupPAX6jhFsrM4dafrD0ZARzo9vhk/mj5vnoU92sWm1k2cD7weLcmubSedu1h28m/7HwETvPOBc4GHjOzccHKHnIhyvxL4N+Dj/8d+AGBH/iwOdPcZpYJfJVAs0VEhOi9xjn3VeCrZvZ/gC8AXw9x1PeDhChz8Gt9FegEfhfKjH8RJISZw623rGb2aeDvgtsmAC+YWTuwxzl3A6fP7/n31Zu4LwQEznqOOedmd99oZsnA2uDDZwj84ux+elwGHArerwSeCv7iX21mfgLzi9REa2bnXFW3190PPBemrN2dae7xwFhgffAHsAxYZ2bznXNHojTzqX4PPE8YCwEhyhzsyLwGWBSuP2q6CfX7HE49ZgVwzj0EPARgZq8Bn3LO7e22SyVwcbfHZQT6Eirx/vs6Pa87KcJxA8bQreMHeBO4OXjfgFmned0aAn/1n+zM+XBw++eAbwbvTyJw6mdRnrm02z5fAv4QC+/1KfvsJcSdxWF6ryd22+ce4IkYyHwlsAUoCsfnIpyfDULcWTzYrJy+s3gPgRaEIcH7Bf39zHt18zxAGD58jwKHgQ4CVfguAn9lLgHWBz/8XzvNa8uBTcAu4Oe8f8FdKvA/wefWAZfGQObfAhuBDQT+0ioNZeZw5T5ln72EftRQON7rJ4PbNxCY32VEDGTeSeAPmneCt1CPdApH5huCX6sNqAJe9DIrPRSC4PY7g+/vTuDTA/nMe3XTlcUiIgkuUUYNiYjIaagQiIgkOBUCEZEEp0IgIpLgVAhERBKcCoHEJDM7EeHjPWBm00L0tbosMFPpJjN7tq+ZP80s38z+JhTHFumJho9KTDKzE8657BB+PZ97fxK2sOqe3cweAbY7577Vy/5jgOecczMikU8Sj84IJG6YWZGZPWlma4K3C4Lb55vZm2b2dvDfycHtnzKzx83sWeAlM7vYzF4zsycsMFf/707OGR/cXh68fyI4ydx6M1tlZiXB7eODj9eY2Tf7edaykvcn3Ms2s1fMbJ0F5q2/LrjPd4DxwbOI+4L7/mPwOBvM7N9C+DZKAlIhkHjyE+BHzrmzgY8CDwS3vwssdM7NITAz6Le7veY84A7n3KXBx3OALwLTgHHABT0cJwtY5ZybBSwHPtPt+D8JHr/PeWSC8+wsInDlN0ArcINzbi6BNTB+ECxE/wzscs7Nds79o5l9CJgIzAdmA/PMbGFfxxM5nUSYdE4Sx2XAtG4zRuaaWQ6QBzxiZhMJzPiY0u01S51z3eeiX+2cqwQws3cIzEHz+inHaef9SfzWApcH75/H+3PM/x74/mlyZnT72msJzFkPgTlovh38pe4ncKZQ0sPrPxS8vR18nE2gMCw/zfFEeqVCIPEkCTjPOdfSfaOZ/Qz4s3PuhmB7+2vdnm465Wu0dbvfRc8/Ix3u/c610+3Tmxbn3GwzyyNQUP4W+CmBtQyKgHnOuQ4z2wuk9/B6A/7TOffrAR5XpEdqGpJ48hKBtQAAMLOT0wjnAQeD9z8VxuOvItAkBfDxvnZ2zjUQWNryH8wshUDO6mARuAQYHdz1OJDT7aUvAncG583HzEaYWXGIvgdJQCoEEq
syzayy2+3LBH6plgc7ULcQmD4c4HvAf5rZG0ByGDN9Efiyma0GSoGGvl7gnHubwAyXHyewOEy5mVUQODt4N7hPHfBGcLjpfc65lwg0Pa00s43AE3ywUIgMiIaPioRIcIW1FuecM7OPA7c6567r63UiXlMfgUjozAN+Hhzpc4wwLw0qEio6IxARSXDqIxARSXAqBCIiCU6FQEQkwakQiIgkOBUCEZEEp0IgIpLg/j+Al2B53j4b5wAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "learn.recorder.plot()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "Total time: 11:27

\n", "
epoch  train_loss  valid_loss  exp_rmspe
1      0.023587    0.020941    0.140551
2      0.017678    0.023431    0.132211
3      0.017453    0.016929    0.120169
4      0.012608    0.016296    0.109245
5      0.010222    0.011238    0.105433
\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn.fit_one_cycle(5, 1e-3, wd=0.2)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.save('1')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "

" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "learn.recorder.plot_losses(skip_start=10000)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.load('1');" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "Total time: 11:32

\n", "
epoch  train_loss  valid_loss  exp_rmspe
1      0.012223    0.014312    0.116988
2      0.012001    0.017789    0.117619
3      0.011402    0.035596    0.114396
4      0.010067    0.015125    0.113652
5      0.009148    0.031326    0.116344
\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn.fit_one_cycle(5, 3e-4)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "Total time: 11:31

\n", "
epoch  train_loss  valid_loss  exp_rmspe
1      0.011840    0.013236    0.110483
2      0.010765    0.057664    0.129586
3      0.010101    0.042744    0.111584
4      0.008820    0.116893    0.135458
5      0.009144    0.017969    0.126323
\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn.fit_one_cycle(5, 3e-4)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "(For comparison, the 10th-place score in the competition was an RMSPE of 0.108.)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Predict on the test set; get_preds returns a [predictions, targets] pair\n", "test_preds=learn.get_preds(DatasetType.Test)\n", "# The model predicts log(Sales), so exponentiate and flatten the (n, 1) output to 1-D\n", "test_df[\"Sales\"]=np.exp(test_preds[0].data).numpy().T[0]\n", "# Cast Id and Sales to integers and write the submission file\n", "test_df[[\"Id\",\"Sales\"]]=test_df[[\"Id\",\"Sales\"]].astype(\"int\")\n", "test_df[[\"Id\",\"Sales\"]].to_csv(\"rossmann_submission.csv\",index=False)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 2 }