{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from fastai2.tabular.all import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Rossmann" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Data preparation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To create the feature-engineered `train_clean` and `test_clean` files from the Kaggle competition data, run `rossman_data_clean.ipynb`. One important time-series step in that notebook expands the `Date` column into calendar features (such as the `Year`, `Month`, `Week` and `Day` columns used below); `drop=False` keeps the original `Date` column, which we need later to build a time-based validation split:\n", "\n", "```python\n", "add_datepart(train, \"Date\", drop=False)\n", "add_datepart(test, \"Date\", drop=False)\n", "```" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "path = Config().data/'rossmann'\n", "train_df = pd.read_pickle(path/'train_clean')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
01234
Store12345
DayOfWeek55555
Date2015-07-31 00:00:002015-07-31 00:00:002015-07-31 00:00:002015-07-31 00:00:002015-07-31 00:00:00
Sales526360648314139954822
Customers5556258211498559
..................
StateHoliday_bw00000
SchoolHoliday_bw55555
Promo_fw51511
StateHoliday_fw00000
SchoolHoliday_fw71511
\n", "

69 rows × 5 columns

\n", "
" ], "text/plain": [ " 0 1 \\\n", "Store 1 2 \n", "DayOfWeek 5 5 \n", "Date 2015-07-31 00:00:00 2015-07-31 00:00:00 \n", "Sales 5263 6064 \n", "Customers 555 625 \n", "... ... ... \n", "StateHoliday_bw 0 0 \n", "SchoolHoliday_bw 5 5 \n", "Promo_fw 5 1 \n", "StateHoliday_fw 0 0 \n", "SchoolHoliday_fw 7 1 \n", "\n", " 2 3 \\\n", "Store 3 4 \n", "DayOfWeek 5 5 \n", "Date 2015-07-31 00:00:00 2015-07-31 00:00:00 \n", "Sales 8314 13995 \n", "Customers 821 1498 \n", "... ... ... \n", "StateHoliday_bw 0 0 \n", "SchoolHoliday_bw 5 5 \n", "Promo_fw 5 1 \n", "StateHoliday_fw 0 0 \n", "SchoolHoliday_fw 5 1 \n", "\n", " 4 \n", "Store 5 \n", "DayOfWeek 5 \n", "Date 2015-07-31 00:00:00 \n", "Sales 4822 \n", "Customers 559 \n", "... ... \n", "StateHoliday_bw 0 \n", "SchoolHoliday_bw 5 \n", "Promo_fw 1 \n", "StateHoliday_fw 0 \n", "SchoolHoliday_fw 1 \n", "\n", "[69 rows x 5 columns]" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_df.head().T" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "844338" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "n = len(train_df); n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Experimenting with a sample" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "idx = np.random.permutation(range(n))[:2000]\n", "idx.sort()\n", "small_df = train_df.iloc[idx]\n", "small_cont_vars = ['CompetitionDistance', 'Mean_Humidity']\n", "small_cat_vars = ['Store', 'DayOfWeek', 'PromoInterval']\n", "small_df = small_df[small_cat_vars + small_cont_vars + ['Sales']].reset_index(drop=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
StoreDayOfWeekPromoIntervalCompetitionDistanceMean_HumiditySales
04715Feb,May,Aug,Nov5300.0509116.0
16565Jan,Apr,Jul,Oct410.0544576.0
211125NaN1880.0619626.0
34594Feb,May,Aug,Nov250.08610847.0
411084NaN540.0517187.0
\n", "
" ], "text/plain": [ " Store DayOfWeek PromoInterval CompetitionDistance Mean_Humidity \\\n", "0 471 5 Feb,May,Aug,Nov 5300.0 50 \n", "1 656 5 Jan,Apr,Jul,Oct 410.0 54 \n", "2 1112 5 NaN 1880.0 61 \n", "3 459 4 Feb,May,Aug,Nov 250.0 86 \n", "4 1108 4 NaN 540.0 51 \n", "\n", " Sales \n", "0 9116.0 \n", "1 4576.0 \n", "2 9626.0 \n", "3 10847.0 \n", "4 7187.0 " ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "small_df.head()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
StoreDayOfWeekPromoIntervalCompetitionDistanceMean_HumiditySales
1000753NaN22440.0684823.0
1001793NaN3320.0683968.0
10023903NaN1600.0719571.0
10034003Jan,Apr,Jul,Oct70.0737629.0
10048253Jan,Apr,Jul,Oct380.0783422.0
\n", "
" ], "text/plain": [ " Store DayOfWeek PromoInterval CompetitionDistance Mean_Humidity \\\n", "1000 75 3 NaN 22440.0 68 \n", "1001 79 3 NaN 3320.0 68 \n", "1002 390 3 NaN 1600.0 71 \n", "1003 400 3 Jan,Apr,Jul,Oct 70.0 73 \n", "1004 825 3 Jan,Apr,Jul,Oct 380.0 78 \n", "\n", " Sales \n", "1000 4823.0 \n", "1001 3968.0 \n", "1002 9571.0 \n", "1003 7629.0 \n", "1004 3422.0 " ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "small_df.iloc[1000:].head()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "splits = [list(range(1000)),list(range(1000,2000))]\n", "to = TabularPandas(small_df.copy(), Categorify, cat_names=small_cat_vars, cont_names=small_cont_vars, splits=splits)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
StoreDayOfWeekPromoIntervalCompetitionDistanceMean_HumiditySales
0283515300.0509116.0
138952410.0544576.0
2653501880.0619626.0
327241250.08610847.0
464940540.0517187.0
\n", "
" ], "text/plain": [ " Store DayOfWeek PromoInterval CompetitionDistance Mean_Humidity \\\n", "0 283 5 1 5300.0 50 \n", "1 389 5 2 410.0 54 \n", "2 653 5 0 1880.0 61 \n", "3 272 4 1 250.0 86 \n", "4 649 4 0 540.0 51 \n", "\n", " Sales \n", "0 9116.0 \n", "1 4576.0 \n", "2 9626.0 \n", "3 10847.0 \n", "4 7187.0 " ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "to.train.items.head()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
StoreDayOfWeekPromoIntervalCompetitionDistanceMean_HumiditySales
1000463022440.0684823.0
100149303320.0683968.0
10020301600.0719571.0
10032363270.0737629.0
100449232380.0783422.0
\n", "
" ], "text/plain": [ " Store DayOfWeek PromoInterval CompetitionDistance Mean_Humidity \\\n", "1000 46 3 0 22440.0 68 \n", "1001 49 3 0 3320.0 68 \n", "1002 0 3 0 1600.0 71 \n", "1003 236 3 2 70.0 73 \n", "1004 492 3 2 380.0 78 \n", "\n", " Sales \n", "1000 4823.0 \n", "1001 3968.0 \n", "1002 9571.0 \n", "1003 7629.0 \n", "1004 3422.0 " ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "to.valid.items.head()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(#8) [#na#,1,2,3,4,5,6,7]" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "to.classes['DayOfWeek']" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "splits = [list(range(1000)),list(range(1000,2000))]\n", "to = TabularPandas(small_df.copy(), FillMissing, cat_names=small_cat_vars, cont_names=small_cont_vars, splits=splits)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
StoreDayOfWeekPromoIntervalCompetitionDistanceMean_HumiditySalesCompetitionDistance_naMean_Humidity_na
5212915NaN2380.0837928.0TrueFalse
\n", "
" ], "text/plain": [ " Store DayOfWeek PromoInterval CompetitionDistance Mean_Humidity \\\n", "521 291 5 NaN 2380.0 83 \n", "\n", " Sales CompetitionDistance_na Mean_Humidity_na \n", "521 7928.0 True False " ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "to.train.items[to.train.items['CompetitionDistance_na'] == True]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Preparing full data set" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_df = pd.read_pickle(path/'train_clean')\n", "test_df = pd.read_pickle(path/'test_clean')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(844338, 41088)" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(train_df),len(test_df)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "procs=[FillMissing, Categorify, Normalize]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dep_var = 'Sales'\n", "cat_names = ['Store', 'DayOfWeek', 'Year', 'Month', 'Day', 'StateHoliday', 'StoreType', 'Assortment', \n", " 'PromoInterval', 'CompetitionOpenSinceYear', 'Promo2SinceYear', 'State', 'Week', 'Events', 'Promo_fw', \n", " 'Promo_bw', 'StateHoliday_fw', 'StateHoliday_bw', 'SchoolHoliday_fw', 'SchoolHoliday_bw']\n", "\n", "cont_names = ['CompetitionDistance', 'Max_TemperatureC', 'Mean_TemperatureC', 'Min_TemperatureC', \n", " 'Max_Humidity', 'Mean_Humidity', 'Min_Humidity', 'Max_Wind_SpeedKm_h', 'Mean_Wind_SpeedKm_h', \n", " 'CloudCover', 'trend', 'trend_DE', 'AfterStateHoliday', 'BeforeStateHoliday', 'Promo', 'SchoolHoliday']" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dep_var = 'Sales'\n", "df = train_df[cat_names + cont_names + [dep_var,'Date']].copy()" ] }, { 
"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(Timestamp('2015-08-01 00:00:00'), Timestamp('2015-09-17 00:00:00'))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "test_df['Date'].min(), test_df['Date'].max()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "41254" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "cut = train_df['Date'][(train_df['Date'] == train_df['Date'][len(test_df)])].index.max()\n", "cut" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "splits = (list(range(cut, len(train_df))),list(range(cut)))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0 5263.0\n", "1 6064.0\n", "2 8314.0\n", "3 13995.0\n", "4 4822.0\n", "Name: Sales, dtype: float64" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_df[dep_var].head()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_df[dep_var] = np.log(train_df[dep_var])\n", "#train_df = train_df.iloc[:100000]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#cut = 20000" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "splits = (list(range(cut, len(train_df))),list(range(cut)))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 3min 57s, sys: 59.2 s, total: 4min 56s\n", "Wall time: 44.8 s\n" ] } ], "source": [ "%time to = TabularPandas(train_df, procs, cat_names, cont_names, dep_var, y_block=TransformBlock(), splits=splits)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, 
"outputs": [], "source": [ "dls = to.dataloaders(bs=512, path=path)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
StoreDayOfWeekYearMonthDayStateHolidayStoreTypeAssortmentPromoIntervalCompetitionOpenSinceYearPromo2SinceYearStateWeekEventsPromo_fwPromo_bwStateHoliday_fwStateHoliday_bwSchoolHoliday_fwSchoolHoliday_bwCompetitionDistance_naMax_TemperatureC_naMean_TemperatureC_naMin_TemperatureC_naMax_Humidity_naMean_Humidity_naMin_Humidity_naMax_Wind_SpeedKm_h_naMean_Wind_SpeedKm_h_naCloudCover_natrend_natrend_DE_naAfterStateHoliday_naBeforeStateHoliday_naPromo_naSchoolHoliday_naCompetitionDistanceMax_TemperatureCMean_TemperatureCMin_TemperatureCMax_HumidityMean_HumidityMin_HumidityMax_Wind_SpeedKm_hMean_Wind_SpeedKm_hCloudCovertrendtrend_DEAfterStateHolidayBeforeStateHolidayPromoSchoolHolidaySales
09062013511Falseaa#na#2007#na#NW19Rain500100FalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalse329.99992616.011.07.000000e+0093.077.048.00000037.016.06.062.00000060.01.999998-9.0000001.436922e-081.656771e-098.963928
185242013314FalsecaJan,Apr,Jul,Oct20042011HE11Snow310000FalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalse940.0001852.0-4.0-1.100000e+0193.078.051.00000021.05.04.070.00000062.071.999998-15.0000001.436922e-081.656771e-098.379310
218932014924Falseda#na#2014#na#RP39Rain220000FalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalse5760.00001316.011.06.000000e+0097.077.059.00000011.05.06.063.00000072.096.999998-9.0000001.436922e-081.656771e-098.744328
36152201434Falseda#na#2007#na#HE10#na#420000FalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalse729.99981413.07.04.409999e-08100.072.028.99999913.06.03.050.00000055.062.000000-45.0000001.000000e+001.656771e-099.527994
45253201326Falsedc#na#2013#na#BE6#na#330033FalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalse1869.9999364.01.0-3.000000e+0093.073.048.00000024.014.04.055.00000051.036.000000-51.0000001.000000e+001.000000e+009.314791
567122013101FalseacJan,Apr,Jul,Oct20082010BY40#na#131000FalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalse2070.00011211.09.08.000000e+0093.084.071.00000013.08.07.057.00000062.047.000000-1.9999991.436922e-081.656771e-098.411611
624342015312FalseaaFeb,May,Aug,Nov#na#2013BY11Snow310000FalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalse310.0000018.04.04.409999e-0887.071.049.00000014.06.06.065.00000074.065.000000-21.9999991.436922e-081.656771e-098.549273
780022013910Falseda#na#2014#na#RP37Rain420000FalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalse2020.00006817.014.01.200000e+0189.064.041.00000024.014.05.037.99999951.0103.000001-23.0000001.000000e+001.656771e-098.782169
825312013610FalseacFeb,May,Aug,Nov#na#2013NW24#na#040000FalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalse250.00022719.014.01.100000e+0182.066.039.00000019.013.04.069.00000067.011.000001-115.0000031.436922e-081.656771e-098.610683
9105332014827Falseaa#na#2015#na#HB,NI35Fog220077FalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalse1710.00007421.013.06.000000e+00100.080.037.00000014.05.03.075.00000077.079.000001-37.0000001.436922e-081.000000e+008.795733
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "dls.show_batch()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Model" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "max_log_y = np.log(1.2) + np.max(train_df['Sales'])\n", "y_range = (0, max_log_y)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dls.c = 1" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = tabular_learner(dls, layers=[1000,500], loss_func=MSELossFlat(),\n", " config=tabular_config(ps=[0.001,0.01], embed_p=0.04, y_range=y_range), \n", " metrics=exp_rmspe)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "TabularModel(\n", " (embeds): ModuleList(\n", " (0): Embedding(1116, 81)\n", " (1): Embedding(8, 5)\n", " (2): Embedding(4, 3)\n", " (3): Embedding(13, 7)\n", " (4): Embedding(32, 11)\n", " (5): Embedding(3, 3)\n", " (6): Embedding(5, 4)\n", " (7): Embedding(4, 3)\n", " (8): Embedding(4, 3)\n", " (9): Embedding(24, 9)\n", " (10): Embedding(8, 5)\n", " (11): Embedding(13, 7)\n", " (12): Embedding(53, 15)\n", " (13): Embedding(22, 9)\n", " (14): Embedding(7, 5)\n", " (15): Embedding(7, 5)\n", " (16): Embedding(4, 3)\n", " (17): Embedding(4, 3)\n", " (18): Embedding(9, 5)\n", " (19): Embedding(9, 5)\n", " (20): Embedding(3, 3)\n", " (21): Embedding(2, 2)\n", " (22): Embedding(2, 2)\n", " (23): Embedding(2, 2)\n", " (24): Embedding(2, 2)\n", " (25): Embedding(2, 2)\n", " (26): Embedding(2, 2)\n", " (27): Embedding(2, 2)\n", " (28): Embedding(2, 2)\n", " (29): Embedding(3, 3)\n", " (30): Embedding(2, 2)\n", " (31): Embedding(2, 2)\n", " (32): Embedding(2, 2)\n", " (33): Embedding(2, 2)\n", " (34): Embedding(2, 2)\n", " (35): Embedding(2, 2)\n", " )\n", " (emb_drop): Dropout(p=0.04, inplace=False)\n", " (bn_cont): BatchNorm1d(16, 
eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (layers): Sequential(\n", " (0): LinBnDrop(\n", " (0): BatchNorm1d(241, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (1): Dropout(p=0.001, inplace=False)\n", " (2): Linear(in_features=241, out_features=1000, bias=False)\n", " (3): ReLU(inplace=True)\n", " )\n", " (1): LinBnDrop(\n", " (0): BatchNorm1d(1000, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (1): Dropout(p=0.01, inplace=False)\n", " (2): Linear(in_features=1000, out_features=500, bias=False)\n", " (3): ReLU(inplace=True)\n", " )\n", " (2): LinBnDrop(\n", " (0): Linear(in_features=500, out_features=1, bias=True)\n", " )\n", " )\n", ")" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learn.model" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "16" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(dls.train_ds.cont_names)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "learn.lr_find()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_loss_exp_rmspetime
00.0274840.0280100.15912301:16
10.0154870.0182400.14121601:16
20.0115810.0157340.12302501:16
30.0084310.0126070.11260901:16
40.0072780.0117240.10859601:16
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn.fit_one_cycle(5, 3e-3, wd=0.2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "(For comparison, 10th place in the competition scored 0.108. Our `_exp_rmspe` above is measured on a held-out validation split of the training data, so this is only a rough benchmark.)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "learn.recorder.plot_loss(skip_start=1000)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Inference on the test set" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test_to = to.new(test_df)\n", "test_to.process()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test_dls = test_to.dataloaders(bs=512, path=path, shuffle_train=False)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.metrics=[]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "tst_preds,_ = learn.get_preds(dl=test_dls.train)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(1, 41088)" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "np.exp(tst_preds.numpy()).T.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test_df[\"Sales\"]=np.exp(tst_preds.numpy()).T[0]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test_df[[\"Id\",\"Sales\"]] = test_df[[\"Id\",\"Sales\"]].astype(\"int\")\n", "test_df[[\"Id\",\"Sales\"]].to_csv(\"rossmann_submission.csv\",index=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This submission scored 3rd on the private leaderboard."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "jupytext": { "split_at_heading": true }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 2 }