{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from fastai2.tabular.all import *"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Rossmann"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Data preparation"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
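"To create the feature-engineered `train_clean` and `test_clean` files from the Kaggle competition data, run `rossman_data_clean.ipynb`. One important step that handles the time-series nature of the data is expanding the `Date` column into date-part features while keeping the original column (`drop=False`):\n",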
"\n",
"```python\n",
"add_datepart(train, \"Date\", drop=False)\n",
"add_datepart(test, \"Date\", drop=False)\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"path = Config().data/'rossmann'\n",
"train_df = pd.read_pickle(path/'train_clean')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"
\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" 0 | \n",
" 1 | \n",
" 2 | \n",
" 3 | \n",
" 4 | \n",
"
\n",
" \n",
" \n",
" \n",
" | Store | \n",
" 1 | \n",
" 2 | \n",
" 3 | \n",
" 4 | \n",
" 5 | \n",
"
\n",
" \n",
" | DayOfWeek | \n",
" 5 | \n",
" 5 | \n",
" 5 | \n",
" 5 | \n",
" 5 | \n",
"
\n",
" \n",
" | Date | \n",
" 2015-07-31 00:00:00 | \n",
" 2015-07-31 00:00:00 | \n",
" 2015-07-31 00:00:00 | \n",
" 2015-07-31 00:00:00 | \n",
" 2015-07-31 00:00:00 | \n",
"
\n",
" \n",
" | Sales | \n",
" 5263 | \n",
" 6064 | \n",
" 8314 | \n",
" 13995 | \n",
" 4822 | \n",
"
\n",
" \n",
" | Customers | \n",
" 555 | \n",
" 625 | \n",
" 821 | \n",
" 1498 | \n",
" 559 | \n",
"
\n",
" \n",
" | ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
"
\n",
" \n",
" | StateHoliday_bw | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
"
\n",
" \n",
" | SchoolHoliday_bw | \n",
" 5 | \n",
" 5 | \n",
" 5 | \n",
" 5 | \n",
" 5 | \n",
"
\n",
" \n",
" | Promo_fw | \n",
" 5 | \n",
" 1 | \n",
" 5 | \n",
" 1 | \n",
" 1 | \n",
"
\n",
" \n",
" | StateHoliday_fw | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
"
\n",
" \n",
" | SchoolHoliday_fw | \n",
" 7 | \n",
" 1 | \n",
" 5 | \n",
" 1 | \n",
" 1 | \n",
"
\n",
" \n",
"
\n",
"
69 rows × 5 columns
\n",
"
"
],
"text/plain": [
" 0 1 \\\n",
"Store 1 2 \n",
"DayOfWeek 5 5 \n",
"Date 2015-07-31 00:00:00 2015-07-31 00:00:00 \n",
"Sales 5263 6064 \n",
"Customers 555 625 \n",
"... ... ... \n",
"StateHoliday_bw 0 0 \n",
"SchoolHoliday_bw 5 5 \n",
"Promo_fw 5 1 \n",
"StateHoliday_fw 0 0 \n",
"SchoolHoliday_fw 7 1 \n",
"\n",
" 2 3 \\\n",
"Store 3 4 \n",
"DayOfWeek 5 5 \n",
"Date 2015-07-31 00:00:00 2015-07-31 00:00:00 \n",
"Sales 8314 13995 \n",
"Customers 821 1498 \n",
"... ... ... \n",
"StateHoliday_bw 0 0 \n",
"SchoolHoliday_bw 5 5 \n",
"Promo_fw 5 1 \n",
"StateHoliday_fw 0 0 \n",
"SchoolHoliday_fw 5 1 \n",
"\n",
" 4 \n",
"Store 5 \n",
"DayOfWeek 5 \n",
"Date 2015-07-31 00:00:00 \n",
"Sales 4822 \n",
"Customers 559 \n",
"... ... \n",
"StateHoliday_bw 0 \n",
"SchoolHoliday_bw 5 \n",
"Promo_fw 1 \n",
"StateHoliday_fw 0 \n",
"SchoolHoliday_fw 1 \n",
"\n",
"[69 rows x 5 columns]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_df.head().T"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"844338"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"n = len(train_df); n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Experimenting with a sample"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Draw a reproducible random sample of 2000 rows, kept in their original row order.\n",
"rng = np.random.default_rng(42)\n",
"idx = np.sort(rng.permutation(n)[:2000])\n",
"small_cont_vars = ['CompetitionDistance', 'Mean_Humidity']\n",
"small_cat_vars = ['Store', 'DayOfWeek', 'PromoInterval']\n",
"small_df = train_df.iloc[idx][small_cat_vars + small_cont_vars + ['Sales']].reset_index(drop=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" Store | \n",
" DayOfWeek | \n",
" PromoInterval | \n",
" CompetitionDistance | \n",
" Mean_Humidity | \n",
" Sales | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" 471 | \n",
" 5 | \n",
" Feb,May,Aug,Nov | \n",
" 5300.0 | \n",
" 50 | \n",
" 9116.0 | \n",
"
\n",
" \n",
" | 1 | \n",
" 656 | \n",
" 5 | \n",
" Jan,Apr,Jul,Oct | \n",
" 410.0 | \n",
" 54 | \n",
" 4576.0 | \n",
"
\n",
" \n",
" | 2 | \n",
" 1112 | \n",
" 5 | \n",
" NaN | \n",
" 1880.0 | \n",
" 61 | \n",
" 9626.0 | \n",
"
\n",
" \n",
" | 3 | \n",
" 459 | \n",
" 4 | \n",
" Feb,May,Aug,Nov | \n",
" 250.0 | \n",
" 86 | \n",
" 10847.0 | \n",
"
\n",
" \n",
" | 4 | \n",
" 1108 | \n",
" 4 | \n",
" NaN | \n",
" 540.0 | \n",
" 51 | \n",
" 7187.0 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" Store DayOfWeek PromoInterval CompetitionDistance Mean_Humidity \\\n",
"0 471 5 Feb,May,Aug,Nov 5300.0 50 \n",
"1 656 5 Jan,Apr,Jul,Oct 410.0 54 \n",
"2 1112 5 NaN 1880.0 61 \n",
"3 459 4 Feb,May,Aug,Nov 250.0 86 \n",
"4 1108 4 NaN 540.0 51 \n",
"\n",
" Sales \n",
"0 9116.0 \n",
"1 4576.0 \n",
"2 9626.0 \n",
"3 10847.0 \n",
"4 7187.0 "
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"small_df.head()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" Store | \n",
" DayOfWeek | \n",
" PromoInterval | \n",
" CompetitionDistance | \n",
" Mean_Humidity | \n",
" Sales | \n",
"
\n",
" \n",
" \n",
" \n",
" | 1000 | \n",
" 75 | \n",
" 3 | \n",
" NaN | \n",
" 22440.0 | \n",
" 68 | \n",
" 4823.0 | \n",
"
\n",
" \n",
" | 1001 | \n",
" 79 | \n",
" 3 | \n",
" NaN | \n",
" 3320.0 | \n",
" 68 | \n",
" 3968.0 | \n",
"
\n",
" \n",
" | 1002 | \n",
" 390 | \n",
" 3 | \n",
" NaN | \n",
" 1600.0 | \n",
" 71 | \n",
" 9571.0 | \n",
"
\n",
" \n",
" | 1003 | \n",
" 400 | \n",
" 3 | \n",
" Jan,Apr,Jul,Oct | \n",
" 70.0 | \n",
" 73 | \n",
" 7629.0 | \n",
"
\n",
" \n",
" | 1004 | \n",
" 825 | \n",
" 3 | \n",
" Jan,Apr,Jul,Oct | \n",
" 380.0 | \n",
" 78 | \n",
" 3422.0 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" Store DayOfWeek PromoInterval CompetitionDistance Mean_Humidity \\\n",
"1000 75 3 NaN 22440.0 68 \n",
"1001 79 3 NaN 3320.0 68 \n",
"1002 390 3 NaN 1600.0 71 \n",
"1003 400 3 Jan,Apr,Jul,Oct 70.0 73 \n",
"1004 825 3 Jan,Apr,Jul,Oct 380.0 78 \n",
"\n",
" Sales \n",
"1000 4823.0 \n",
"1001 3968.0 \n",
"1002 9571.0 \n",
"1003 7629.0 \n",
"1004 3422.0 "
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"small_df.iloc[1000:].head()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"splits = [list(range(1000)),list(range(1000,2000))]\n",
"to = TabularPandas(small_df.copy(), Categorify, cat_names=small_cat_vars, cont_names=small_cont_vars, splits=splits)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" Store | \n",
" DayOfWeek | \n",
" PromoInterval | \n",
" CompetitionDistance | \n",
" Mean_Humidity | \n",
" Sales | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" 283 | \n",
" 5 | \n",
" 1 | \n",
" 5300.0 | \n",
" 50 | \n",
" 9116.0 | \n",
"
\n",
" \n",
" | 1 | \n",
" 389 | \n",
" 5 | \n",
" 2 | \n",
" 410.0 | \n",
" 54 | \n",
" 4576.0 | \n",
"
\n",
" \n",
" | 2 | \n",
" 653 | \n",
" 5 | \n",
" 0 | \n",
" 1880.0 | \n",
" 61 | \n",
" 9626.0 | \n",
"
\n",
" \n",
" | 3 | \n",
" 272 | \n",
" 4 | \n",
" 1 | \n",
" 250.0 | \n",
" 86 | \n",
" 10847.0 | \n",
"
\n",
" \n",
" | 4 | \n",
" 649 | \n",
" 4 | \n",
" 0 | \n",
" 540.0 | \n",
" 51 | \n",
" 7187.0 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" Store DayOfWeek PromoInterval CompetitionDistance Mean_Humidity \\\n",
"0 283 5 1 5300.0 50 \n",
"1 389 5 2 410.0 54 \n",
"2 653 5 0 1880.0 61 \n",
"3 272 4 1 250.0 86 \n",
"4 649 4 0 540.0 51 \n",
"\n",
" Sales \n",
"0 9116.0 \n",
"1 4576.0 \n",
"2 9626.0 \n",
"3 10847.0 \n",
"4 7187.0 "
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"to.train.items.head()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" Store | \n",
" DayOfWeek | \n",
" PromoInterval | \n",
" CompetitionDistance | \n",
" Mean_Humidity | \n",
" Sales | \n",
"
\n",
" \n",
" \n",
" \n",
" | 1000 | \n",
" 46 | \n",
" 3 | \n",
" 0 | \n",
" 22440.0 | \n",
" 68 | \n",
" 4823.0 | \n",
"
\n",
" \n",
" | 1001 | \n",
" 49 | \n",
" 3 | \n",
" 0 | \n",
" 3320.0 | \n",
" 68 | \n",
" 3968.0 | \n",
"
\n",
" \n",
" | 1002 | \n",
" 0 | \n",
" 3 | \n",
" 0 | \n",
" 1600.0 | \n",
" 71 | \n",
" 9571.0 | \n",
"
\n",
" \n",
" | 1003 | \n",
" 236 | \n",
" 3 | \n",
" 2 | \n",
" 70.0 | \n",
" 73 | \n",
" 7629.0 | \n",
"
\n",
" \n",
" | 1004 | \n",
" 492 | \n",
" 3 | \n",
" 2 | \n",
" 380.0 | \n",
" 78 | \n",
" 3422.0 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" Store DayOfWeek PromoInterval CompetitionDistance Mean_Humidity \\\n",
"1000 46 3 0 22440.0 68 \n",
"1001 49 3 0 3320.0 68 \n",
"1002 0 3 0 1600.0 71 \n",
"1003 236 3 2 70.0 73 \n",
"1004 492 3 2 380.0 78 \n",
"\n",
" Sales \n",
"1000 4823.0 \n",
"1001 3968.0 \n",
"1002 9571.0 \n",
"1003 7629.0 \n",
"1004 3422.0 "
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"to.valid.items.head()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(#8) [#na#,1,2,3,4,5,6,7]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"to.classes['DayOfWeek']"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"splits = [list(range(1000)),list(range(1000,2000))]\n",
"to = TabularPandas(small_df.copy(), FillMissing, cat_names=small_cat_vars, cont_names=small_cont_vars, splits=splits)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" Store | \n",
" DayOfWeek | \n",
" PromoInterval | \n",
" CompetitionDistance | \n",
" Mean_Humidity | \n",
" Sales | \n",
" CompetitionDistance_na | \n",
" Mean_Humidity_na | \n",
"
\n",
" \n",
" \n",
" \n",
" | 521 | \n",
" 291 | \n",
" 5 | \n",
" NaN | \n",
" 2380.0 | \n",
" 83 | \n",
" 7928.0 | \n",
" True | \n",
" False | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" Store DayOfWeek PromoInterval CompetitionDistance Mean_Humidity \\\n",
"521 291 5 NaN 2380.0 83 \n",
"\n",
" Sales CompetitionDistance_na Mean_Humidity_na \n",
"521 7928.0 True False "
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"to.train.items[to.train.items['CompetitionDistance_na'] == True]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Preparing full data set"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train_df = pd.read_pickle(path/'train_clean')\n",
"test_df = pd.read_pickle(path/'test_clean')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(844338, 41088)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(train_df),len(test_df)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"procs=[FillMissing, Categorify, Normalize]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dep_var = 'Sales'\n",
"cat_names = ['Store', 'DayOfWeek', 'Year', 'Month', 'Day', 'StateHoliday', 'StoreType', 'Assortment', \n",
" 'PromoInterval', 'CompetitionOpenSinceYear', 'Promo2SinceYear', 'State', 'Week', 'Events', 'Promo_fw', \n",
" 'Promo_bw', 'StateHoliday_fw', 'StateHoliday_bw', 'SchoolHoliday_fw', 'SchoolHoliday_bw']\n",
"\n",
"cont_names = ['CompetitionDistance', 'Max_TemperatureC', 'Mean_TemperatureC', 'Min_TemperatureC', \n",
" 'Max_Humidity', 'Mean_Humidity', 'Min_Humidity', 'Max_Wind_SpeedKm_h', 'Mean_Wind_SpeedKm_h', \n",
" 'CloudCover', 'trend', 'trend_DE', 'AfterStateHoliday', 'BeforeStateHoliday', 'Promo', 'SchoolHoliday']"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dep_var = 'Sales'\n",
"df = train_df[cat_names + cont_names + [dep_var,'Date']].copy()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(Timestamp('2015-08-01 00:00:00'), Timestamp('2015-09-17 00:00:00'))"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"test_df['Date'].min(), test_df['Date'].max()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"41255"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Validate on the most recent rows (train_df is ordered newest-first): the first len(test_df)\n",
"# rows, extended so that *every* row sharing the boundary date lands in the validation set.\n",
"# `cut` is a positional, exclusive end (hence +1), matching how `range(cut)` is used in the splits.\n",
"cut_date = train_df['Date'].iloc[len(test_df)]\n",
"cut = np.flatnonzero((train_df['Date'] == cut_date).to_numpy()).max() + 1\n",
"cut"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"splits = (list(range(cut, len(train_df))),list(range(cut)))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0 5263.0\n",
"1 6064.0\n",
"2 8314.0\n",
"3 13995.0\n",
"4 4822.0\n",
"Name: Sales, dtype: float64"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_df[dep_var].head()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Model log(Sales): the metric (RMSPE) measures relative error, which MSE on log-sales approximates.\n",
"# Guard against re-running this cell: a second np.log would silently squash the target again\n",
"# (raw Sales are in the thousands; log-Sales are roughly 8-10).\n",
"assert train_df[dep_var].max() > 100, 'Sales look already log-transformed; reload train_df before re-running'\n",
"assert (train_df[dep_var] > 0).all(), 'np.log requires strictly positive Sales'\n",
"train_df[dep_var] = np.log(train_df[dep_var])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#cut = 20000"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"splits = (list(range(cut, len(train_df))),list(range(cut)))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 3min 57s, sys: 59.2 s, total: 4min 56s\n",
"Wall time: 44.8 s\n"
]
}
],
"source": [
"%time to = TabularPandas(train_df, procs, cat_names, cont_names, dep_var, y_block=TransformBlock(), splits=splits)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dls = to.dataloaders(bs=512, path=path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
" | \n",
" Store | \n",
" DayOfWeek | \n",
" Year | \n",
" Month | \n",
" Day | \n",
" StateHoliday | \n",
" StoreType | \n",
" Assortment | \n",
" PromoInterval | \n",
" CompetitionOpenSinceYear | \n",
" Promo2SinceYear | \n",
" State | \n",
" Week | \n",
" Events | \n",
" Promo_fw | \n",
" Promo_bw | \n",
" StateHoliday_fw | \n",
" StateHoliday_bw | \n",
" SchoolHoliday_fw | \n",
" SchoolHoliday_bw | \n",
" CompetitionDistance_na | \n",
" Max_TemperatureC_na | \n",
" Mean_TemperatureC_na | \n",
" Min_TemperatureC_na | \n",
" Max_Humidity_na | \n",
" Mean_Humidity_na | \n",
" Min_Humidity_na | \n",
" Max_Wind_SpeedKm_h_na | \n",
" Mean_Wind_SpeedKm_h_na | \n",
" CloudCover_na | \n",
" trend_na | \n",
" trend_DE_na | \n",
" AfterStateHoliday_na | \n",
" BeforeStateHoliday_na | \n",
" Promo_na | \n",
" SchoolHoliday_na | \n",
" CompetitionDistance | \n",
" Max_TemperatureC | \n",
" Mean_TemperatureC | \n",
" Min_TemperatureC | \n",
" Max_Humidity | \n",
" Mean_Humidity | \n",
" Min_Humidity | \n",
" Max_Wind_SpeedKm_h | \n",
" Mean_Wind_SpeedKm_h | \n",
" CloudCover | \n",
" trend | \n",
" trend_DE | \n",
" AfterStateHoliday | \n",
" BeforeStateHoliday | \n",
" Promo | \n",
" SchoolHoliday | \n",
" Sales | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" 90 | \n",
" 6 | \n",
" 2013 | \n",
" 5 | \n",
" 11 | \n",
" False | \n",
" a | \n",
" a | \n",
" #na# | \n",
" 2007 | \n",
" #na# | \n",
" NW | \n",
" 19 | \n",
" Rain | \n",
" 5 | \n",
" 0 | \n",
" 0 | \n",
" 1 | \n",
" 0 | \n",
" 0 | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" 329.999926 | \n",
" 16.0 | \n",
" 11.0 | \n",
" 7.000000e+00 | \n",
" 93.0 | \n",
" 77.0 | \n",
" 48.000000 | \n",
" 37.0 | \n",
" 16.0 | \n",
" 6.0 | \n",
" 62.000000 | \n",
" 60.0 | \n",
" 1.999998 | \n",
" -9.000000 | \n",
" 1.436922e-08 | \n",
" 1.656771e-09 | \n",
" 8.963928 | \n",
"
\n",
" \n",
" | 1 | \n",
" 852 | \n",
" 4 | \n",
" 2013 | \n",
" 3 | \n",
" 14 | \n",
" False | \n",
" c | \n",
" a | \n",
" Jan,Apr,Jul,Oct | \n",
" 2004 | \n",
" 2011 | \n",
" HE | \n",
" 11 | \n",
" Snow | \n",
" 3 | \n",
" 1 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" 940.000185 | \n",
" 2.0 | \n",
" -4.0 | \n",
" -1.100000e+01 | \n",
" 93.0 | \n",
" 78.0 | \n",
" 51.000000 | \n",
" 21.0 | \n",
" 5.0 | \n",
" 4.0 | \n",
" 70.000000 | \n",
" 62.0 | \n",
" 71.999998 | \n",
" -15.000000 | \n",
" 1.436922e-08 | \n",
" 1.656771e-09 | \n",
" 8.379310 | \n",
"
\n",
" \n",
" | 2 | \n",
" 189 | \n",
" 3 | \n",
" 2014 | \n",
" 9 | \n",
" 24 | \n",
" False | \n",
" d | \n",
" a | \n",
" #na# | \n",
" 2014 | \n",
" #na# | \n",
" RP | \n",
" 39 | \n",
" Rain | \n",
" 2 | \n",
" 2 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" 5760.000013 | \n",
" 16.0 | \n",
" 11.0 | \n",
" 6.000000e+00 | \n",
" 97.0 | \n",
" 77.0 | \n",
" 59.000000 | \n",
" 11.0 | \n",
" 5.0 | \n",
" 6.0 | \n",
" 63.000000 | \n",
" 72.0 | \n",
" 96.999998 | \n",
" -9.000000 | \n",
" 1.436922e-08 | \n",
" 1.656771e-09 | \n",
" 8.744328 | \n",
"
\n",
" \n",
" | 3 | \n",
" 615 | \n",
" 2 | \n",
" 2014 | \n",
" 3 | \n",
" 4 | \n",
" False | \n",
" d | \n",
" a | \n",
" #na# | \n",
" 2007 | \n",
" #na# | \n",
" HE | \n",
" 10 | \n",
" #na# | \n",
" 4 | \n",
" 2 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" 729.999814 | \n",
" 13.0 | \n",
" 7.0 | \n",
" 4.409999e-08 | \n",
" 100.0 | \n",
" 72.0 | \n",
" 28.999999 | \n",
" 13.0 | \n",
" 6.0 | \n",
" 3.0 | \n",
" 50.000000 | \n",
" 55.0 | \n",
" 62.000000 | \n",
" -45.000000 | \n",
" 1.000000e+00 | \n",
" 1.656771e-09 | \n",
" 9.527994 | \n",
"
\n",
" \n",
" | 4 | \n",
" 525 | \n",
" 3 | \n",
" 2013 | \n",
" 2 | \n",
" 6 | \n",
" False | \n",
" d | \n",
" c | \n",
" #na# | \n",
" 2013 | \n",
" #na# | \n",
" BE | \n",
" 6 | \n",
" #na# | \n",
" 3 | \n",
" 3 | \n",
" 0 | \n",
" 0 | \n",
" 3 | \n",
" 3 | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" 1869.999936 | \n",
" 4.0 | \n",
" 1.0 | \n",
" -3.000000e+00 | \n",
" 93.0 | \n",
" 73.0 | \n",
" 48.000000 | \n",
" 24.0 | \n",
" 14.0 | \n",
" 4.0 | \n",
" 55.000000 | \n",
" 51.0 | \n",
" 36.000000 | \n",
" -51.000000 | \n",
" 1.000000e+00 | \n",
" 1.000000e+00 | \n",
" 9.314791 | \n",
"
\n",
" \n",
" | 5 | \n",
" 671 | \n",
" 2 | \n",
" 2013 | \n",
" 10 | \n",
" 1 | \n",
" False | \n",
" a | \n",
" c | \n",
" Jan,Apr,Jul,Oct | \n",
" 2008 | \n",
" 2010 | \n",
" BY | \n",
" 40 | \n",
" #na# | \n",
" 1 | \n",
" 3 | \n",
" 1 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" 2070.000112 | \n",
" 11.0 | \n",
" 9.0 | \n",
" 8.000000e+00 | \n",
" 93.0 | \n",
" 84.0 | \n",
" 71.000000 | \n",
" 13.0 | \n",
" 8.0 | \n",
" 7.0 | \n",
" 57.000000 | \n",
" 62.0 | \n",
" 47.000000 | \n",
" -1.999999 | \n",
" 1.436922e-08 | \n",
" 1.656771e-09 | \n",
" 8.411611 | \n",
"
\n",
" \n",
" | 6 | \n",
" 243 | \n",
" 4 | \n",
" 2015 | \n",
" 3 | \n",
" 12 | \n",
" False | \n",
" a | \n",
" a | \n",
" Feb,May,Aug,Nov | \n",
" #na# | \n",
" 2013 | \n",
" BY | \n",
" 11 | \n",
" Snow | \n",
" 3 | \n",
" 1 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" 310.000001 | \n",
" 8.0 | \n",
" 4.0 | \n",
" 4.409999e-08 | \n",
" 87.0 | \n",
" 71.0 | \n",
" 49.000000 | \n",
" 14.0 | \n",
" 6.0 | \n",
" 6.0 | \n",
" 65.000000 | \n",
" 74.0 | \n",
" 65.000000 | \n",
" -21.999999 | \n",
" 1.436922e-08 | \n",
" 1.656771e-09 | \n",
" 8.549273 | \n",
"
\n",
" \n",
" | 7 | \n",
" 800 | \n",
" 2 | \n",
" 2013 | \n",
" 9 | \n",
" 10 | \n",
" False | \n",
" d | \n",
" a | \n",
" #na# | \n",
" 2014 | \n",
" #na# | \n",
" RP | \n",
" 37 | \n",
" Rain | \n",
" 4 | \n",
" 2 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" 2020.000068 | \n",
" 17.0 | \n",
" 14.0 | \n",
" 1.200000e+01 | \n",
" 89.0 | \n",
" 64.0 | \n",
" 41.000000 | \n",
" 24.0 | \n",
" 14.0 | \n",
" 5.0 | \n",
" 37.999999 | \n",
" 51.0 | \n",
" 103.000001 | \n",
" -23.000000 | \n",
" 1.000000e+00 | \n",
" 1.656771e-09 | \n",
" 8.782169 | \n",
"
\n",
" \n",
" | 8 | \n",
" 253 | \n",
" 1 | \n",
" 2013 | \n",
" 6 | \n",
" 10 | \n",
" False | \n",
" a | \n",
" c | \n",
" Feb,May,Aug,Nov | \n",
" #na# | \n",
" 2013 | \n",
" NW | \n",
" 24 | \n",
" #na# | \n",
" 0 | \n",
" 4 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" 250.000227 | \n",
" 19.0 | \n",
" 14.0 | \n",
" 1.100000e+01 | \n",
" 82.0 | \n",
" 66.0 | \n",
" 39.000000 | \n",
" 19.0 | \n",
" 13.0 | \n",
" 4.0 | \n",
" 69.000000 | \n",
" 67.0 | \n",
" 11.000001 | \n",
" -115.000003 | \n",
" 1.436922e-08 | \n",
" 1.656771e-09 | \n",
" 8.610683 | \n",
"
\n",
" \n",
" | 9 | \n",
" 1053 | \n",
" 3 | \n",
" 2014 | \n",
" 8 | \n",
" 27 | \n",
" False | \n",
" a | \n",
" a | \n",
" #na# | \n",
" 2015 | \n",
" #na# | \n",
" HB,NI | \n",
" 35 | \n",
" Fog | \n",
" 2 | \n",
" 2 | \n",
" 0 | \n",
" 0 | \n",
" 7 | \n",
" 7 | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" 1710.000074 | \n",
" 21.0 | \n",
" 13.0 | \n",
" 6.000000e+00 | \n",
" 100.0 | \n",
" 80.0 | \n",
" 37.000000 | \n",
" 14.0 | \n",
" 5.0 | \n",
" 3.0 | \n",
" 75.000000 | \n",
" 77.0 | \n",
" 79.000001 | \n",
" -37.000000 | \n",
" 1.436922e-08 | \n",
" 1.000000e+00 | \n",
" 8.795733 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"dls.show_batch()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"max_log_y = np.log(1.2) + np.max(train_df['Sales'])\n",
"y_range = (0, max_log_y)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dls.c = 1"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"learn = tabular_learner(dls, layers=[1000,500], loss_func=MSELossFlat(),\n",
" config=tabular_config(ps=[0.001,0.01], embed_p=0.04, y_range=y_range), \n",
" metrics=exp_rmspe)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"TabularModel(\n",
" (embeds): ModuleList(\n",
" (0): Embedding(1116, 81)\n",
" (1): Embedding(8, 5)\n",
" (2): Embedding(4, 3)\n",
" (3): Embedding(13, 7)\n",
" (4): Embedding(32, 11)\n",
" (5): Embedding(3, 3)\n",
" (6): Embedding(5, 4)\n",
" (7): Embedding(4, 3)\n",
" (8): Embedding(4, 3)\n",
" (9): Embedding(24, 9)\n",
" (10): Embedding(8, 5)\n",
" (11): Embedding(13, 7)\n",
" (12): Embedding(53, 15)\n",
" (13): Embedding(22, 9)\n",
" (14): Embedding(7, 5)\n",
" (15): Embedding(7, 5)\n",
" (16): Embedding(4, 3)\n",
" (17): Embedding(4, 3)\n",
" (18): Embedding(9, 5)\n",
" (19): Embedding(9, 5)\n",
" (20): Embedding(3, 3)\n",
" (21): Embedding(2, 2)\n",
" (22): Embedding(2, 2)\n",
" (23): Embedding(2, 2)\n",
" (24): Embedding(2, 2)\n",
" (25): Embedding(2, 2)\n",
" (26): Embedding(2, 2)\n",
" (27): Embedding(2, 2)\n",
" (28): Embedding(2, 2)\n",
" (29): Embedding(3, 3)\n",
" (30): Embedding(2, 2)\n",
" (31): Embedding(2, 2)\n",
" (32): Embedding(2, 2)\n",
" (33): Embedding(2, 2)\n",
" (34): Embedding(2, 2)\n",
" (35): Embedding(2, 2)\n",
" )\n",
" (emb_drop): Dropout(p=0.04, inplace=False)\n",
" (bn_cont): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (layers): Sequential(\n",
" (0): LinBnDrop(\n",
" (0): BatchNorm1d(241, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (1): Dropout(p=0.001, inplace=False)\n",
" (2): Linear(in_features=241, out_features=1000, bias=False)\n",
" (3): ReLU(inplace=True)\n",
" )\n",
" (1): LinBnDrop(\n",
" (0): BatchNorm1d(1000, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (1): Dropout(p=0.01, inplace=False)\n",
" (2): Linear(in_features=1000, out_features=500, bias=False)\n",
" (3): ReLU(inplace=True)\n",
" )\n",
" (2): LinBnDrop(\n",
" (0): Linear(in_features=500, out_features=1, bias=True)\n",
" )\n",
" )\n",
")"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"learn.model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"16"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(dls.train_ds.cont_names)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "\n",
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"learn.lr_find()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
" | epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" _exp_rmspe | \n",
" time | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" 0.027484 | \n",
" 0.028010 | \n",
" 0.159123 | \n",
" 01:16 | \n",
"
\n",
" \n",
" | 1 | \n",
" 0.015487 | \n",
" 0.018240 | \n",
" 0.141216 | \n",
" 01:16 | \n",
"
\n",
" \n",
" | 2 | \n",
" 0.011581 | \n",
" 0.015734 | \n",
" 0.123025 | \n",
" 01:16 | \n",
"
\n",
" \n",
" | 3 | \n",
" 0.008431 | \n",
" 0.012607 | \n",
" 0.112609 | \n",
" 01:16 | \n",
"
\n",
" \n",
" | 4 | \n",
" 0.007278 | \n",
" 0.011724 | \n",
" 0.108596 | \n",
" 01:16 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn.fit_one_cycle(5, 3e-3, wd=0.2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
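"(For reference, 10th place in the competition scored 0.108 RMSPE. The final validation `exp_rmspe` above is about 0.109, although a validation score is only a rough proxy for a leaderboard score.)"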
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"learn.recorder.plot_loss(skip_start=1000)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"(10th place in the competition was 0.108)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Inference on the test set"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"test_to = to.new(test_df)\n",
"test_to.process()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"test_dls = test_to.dataloaders(bs=512, path=path, shuffle_train=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"learn.metrics=[]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"tst_preds,_ = learn.get_preds(dl=test_dls.train)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(1, 41088)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.exp(tst_preds.numpy()).T.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"test_df[\"Sales\"]=np.exp(tst_preds.numpy()).T[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Round (rather than truncate) predicted sales to whole units and write the Kaggle submission\n",
"# without mutating test_df.\n",
"submission = pd.DataFrame({\"Id\": test_df[\"Id\"].astype(\"int\"),\n",
"                           \"Sales\": test_df[\"Sales\"].round().astype(\"int\")})\n",
"submission.to_csv(\"rossmann_submission.csv\", index=False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This submission scored 3rd on the private leaderboard."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"jupytext": {
"split_at_heading": true
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}