{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Important: This notebook will only work with fastai-0.7.x. Do not try to run any fastai-1.x code from this path in the repository because it will load fastai-0.7.x**"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Structured and time series data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This notebook contains an implementation of the third place result in the Rossman Kaggle competition as detailed in Guo/Berkhahn's [Entity Embeddings of Categorical Variables](https://arxiv.org/abs/1604.06737).\n",
"\n",
"The motivation behind exploring this architecture is it's relevance to real-world application. Most data used for decision making day-to-day in industry is structured and/or time-series data. Here we explore the end-to-end process of using neural networks with practical structured data problems."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"%reload_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from fastai.structured import *\n",
"from fastai.column_data import *\n",
"np.set_printoptions(threshold=50, edgeitems=20)\n",
"\n",
"PATH='data/rossmann/'"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Create datasets"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In addition to the provided data, we will be using external datasets put together by participants in the Kaggle competition. You can download all of them [here](http://files.fast.ai/part2/lesson14/rossmann.tgz).\n",
"\n",
"For completeness, the implementation used to put them together is included below."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def concat_csvs(dirname):\n",
" path = f'{PATH}{dirname}'\n",
" filenames=glob(f\"{PATH}/*.csv\")\n",
"\n",
" wrote_header = False\n",
" with open(f\"{path}.csv\",\"w\") as outputfile:\n",
" for filename in filenames:\n",
" name = filename.split(\".\")[0]\n",
" with open(filename) as f:\n",
" line = f.readline()\n",
" if not wrote_header:\n",
" wrote_header = True\n",
" outputfile.write(\"file,\"+line)\n",
" for line in f:\n",
" outputfile.write(name + \",\" + line)\n",
" outputfile.write(\"\\n\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# concat_csvs('googletrend')\n",
"# concat_csvs('weather')"
]
},
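{
"cell_type": "markdown",
"metadata": {},
"source": [
"The same concatenation can also be sketched with pandas instead of handling lines by hand. This is an alternative, not the method used to build the provided files: `concat_csvs_pd` is a hypothetical helper that takes the full directory path (e.g. `f'{PATH}googletrend'`), assumes every CSV in that directory shares the same header, and is demonstrated here on a temporary directory with made-up values."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import tempfile\n",
"from glob import glob\n",
"\n",
"import pandas as pd\n",
"\n",
"def concat_csvs_pd(dirpath):\n",
"    # Hypothetical pandas variant of concat_csvs: read every CSV in dirpath,\n",
"    # tag each row with its source file name, and write the result to dirpath.csv\n",
"    frames = []\n",
"    for filename in sorted(glob(os.path.join(dirpath, '*.csv'))):\n",
"        df = pd.read_csv(filename)\n",
"        # 'Rossmann_DE_SN.csv' -> 'Rossmann_DE_SN'\n",
"        df.insert(0, 'file', os.path.splitext(os.path.basename(filename))[0])\n",
"        frames.append(df)\n",
"    combined = pd.concat(frames, ignore_index=True)\n",
"    combined.to_csv(f'{dirpath}.csv', index=False)\n",
"    return combined\n",
"\n",
"# Demo on a temporary directory with two tiny, made-up trend files\n",
"demo_dir = os.path.join(tempfile.mkdtemp(), 'googletrend')\n",
"os.makedirs(demo_dir)\n",
"for state, trend in [('SN', 96), ('BY', 80)]:\n",
"    pd.DataFrame({'week': ['2012-12-02 - 2012-12-08'], 'trend': [trend]}).to_csv(\n",
"        os.path.join(demo_dir, f'Rossmann_DE_{state}.csv'), index=False)\n",
"\n",
"combined = concat_csvs_pd(demo_dir)\n",
"combined"
]
},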
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Feature Space:\n",
"* train: Training set provided by competition\n",
"* store: List of stores\n",
"* store_states: mapping of store to the German state they are in\n",
"* List of German state names\n",
"* googletrend: trend of certain google keywords over time, found by users to correlate well w/ given data\n",
"* weather: weather\n",
"* test: testing set"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"table_names = ['train', 'store', 'store_states', 'state_names', \n",
" 'googletrend', 'weather', 'test']"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We'll be using the popular data manipulation framework `pandas`. Among other things, pandas allows you to manipulate tables/data frames in python as one would in a database.\n",
"\n",
"We're going to go ahead and load all of our csv's as dataframes into the list `tables`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tables = [pd.read_csv(f'{PATH}{fname}.csv', low_memory=False) for fname in table_names]"
]
},
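{
"cell_type": "markdown",
"metadata": {},
"source": [
"With the tables loaded, a per-column count of missing values is a useful first check. This is a minimal sketch using only standard pandas calls; `missing_summary` is a hypothetical helper (in this notebook it could be applied as `for t in tables: display(missing_summary(t))`), and the toy frame uses made-up values."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"def missing_summary(df):\n",
"    # Hypothetical helper: number and percentage of missing values per column\n",
"    missing = df.isnull().sum()\n",
"    return pd.DataFrame({'missing': missing,\n",
"                         'missing_perc': (missing / len(df) * 100).round(2)})\n",
"\n",
"# Toy frame with made-up values, only to show the shape of the output\n",
"toy = pd.DataFrame({\n",
"    'Store': [1, 2, 3, 4],\n",
"    'CompetitionDistance': [1270.0, np.nan, 14130.0, 620.0],\n",
"    'PromoInterval': [np.nan, 'Jan,Apr,Jul,Oct', 'Jan,Apr,Jul,Oct', np.nan],\n",
"})\n",
"summary = missing_summary(toy)\n",
"summary"
]
},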
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from IPython.display import HTML, display"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can use `head()` to get a quick look at the contents of each table:\n",
"* train: Contains store information on a daily basis, tracks things like sales, customers, whether that day was a holdiay, etc.\n",
"* store: general info about the store including competition, etc.\n",
"* store_states: maps store to state it is in\n",
"* state_names: Maps state abbreviations to names\n",
"* googletrend: trend data for particular week/state\n",
"* weather: weather conditions for each state\n",
"* test: Same as training table, w/o sales and customers\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"
\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" Store | \n",
" DayOfWeek | \n",
" Date | \n",
" Sales | \n",
" Customers | \n",
" Open | \n",
" Promo | \n",
" StateHoliday | \n",
" SchoolHoliday | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 1 | \n",
" 5 | \n",
" 2015-07-31 | \n",
" 5263 | \n",
" 555 | \n",
" 1 | \n",
" 1 | \n",
" 0 | \n",
" 1 | \n",
"
\n",
" \n",
" 1 | \n",
" 2 | \n",
" 5 | \n",
" 2015-07-31 | \n",
" 6064 | \n",
" 625 | \n",
" 1 | \n",
" 1 | \n",
" 0 | \n",
" 1 | \n",
"
\n",
" \n",
" 2 | \n",
" 3 | \n",
" 5 | \n",
" 2015-07-31 | \n",
" 8314 | \n",
" 821 | \n",
" 1 | \n",
" 1 | \n",
" 0 | \n",
" 1 | \n",
"
\n",
" \n",
" 3 | \n",
" 4 | \n",
" 5 | \n",
" 2015-07-31 | \n",
" 13995 | \n",
" 1498 | \n",
" 1 | \n",
" 1 | \n",
" 0 | \n",
" 1 | \n",
"
\n",
" \n",
" 4 | \n",
" 5 | \n",
" 5 | \n",
" 2015-07-31 | \n",
" 4822 | \n",
" 559 | \n",
" 1 | \n",
" 1 | \n",
" 0 | \n",
" 1 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" Store DayOfWeek Date Sales Customers Open Promo StateHoliday \\\n",
"0 1 5 2015-07-31 5263 555 1 1 0 \n",
"1 2 5 2015-07-31 6064 625 1 1 0 \n",
"2 3 5 2015-07-31 8314 821 1 1 0 \n",
"3 4 5 2015-07-31 13995 1498 1 1 0 \n",
"4 5 5 2015-07-31 4822 559 1 1 0 \n",
"\n",
" SchoolHoliday \n",
"0 1 \n",
"1 1 \n",
"2 1 \n",
"3 1 \n",
"4 1 "
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" Store | \n",
" StoreType | \n",
" Assortment | \n",
" CompetitionDistance | \n",
" CompetitionOpenSinceMonth | \n",
" CompetitionOpenSinceYear | \n",
" Promo2 | \n",
" Promo2SinceWeek | \n",
" Promo2SinceYear | \n",
" PromoInterval | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 1 | \n",
" c | \n",
" a | \n",
" 1270.0 | \n",
" 9.0 | \n",
" 2008.0 | \n",
" 0 | \n",
" NaN | \n",
" NaN | \n",
" NaN | \n",
"
\n",
" \n",
" 1 | \n",
" 2 | \n",
" a | \n",
" a | \n",
" 570.0 | \n",
" 11.0 | \n",
" 2007.0 | \n",
" 1 | \n",
" 13.0 | \n",
" 2010.0 | \n",
" Jan,Apr,Jul,Oct | \n",
"
\n",
" \n",
" 2 | \n",
" 3 | \n",
" a | \n",
" a | \n",
" 14130.0 | \n",
" 12.0 | \n",
" 2006.0 | \n",
" 1 | \n",
" 14.0 | \n",
" 2011.0 | \n",
" Jan,Apr,Jul,Oct | \n",
"
\n",
" \n",
" 3 | \n",
" 4 | \n",
" c | \n",
" c | \n",
" 620.0 | \n",
" 9.0 | \n",
" 2009.0 | \n",
" 0 | \n",
" NaN | \n",
" NaN | \n",
" NaN | \n",
"
\n",
" \n",
" 4 | \n",
" 5 | \n",
" a | \n",
" a | \n",
" 29910.0 | \n",
" 4.0 | \n",
" 2015.0 | \n",
" 0 | \n",
" NaN | \n",
" NaN | \n",
" NaN | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" Store StoreType Assortment CompetitionDistance CompetitionOpenSinceMonth \\\n",
"0 1 c a 1270.0 9.0 \n",
"1 2 a a 570.0 11.0 \n",
"2 3 a a 14130.0 12.0 \n",
"3 4 c c 620.0 9.0 \n",
"4 5 a a 29910.0 4.0 \n",
"\n",
" CompetitionOpenSinceYear Promo2 Promo2SinceWeek Promo2SinceYear \\\n",
"0 2008.0 0 NaN NaN \n",
"1 2007.0 1 13.0 2010.0 \n",
"2 2006.0 1 14.0 2011.0 \n",
"3 2009.0 0 NaN NaN \n",
"4 2015.0 0 NaN NaN \n",
"\n",
" PromoInterval \n",
"0 NaN \n",
"1 Jan,Apr,Jul,Oct \n",
"2 Jan,Apr,Jul,Oct \n",
"3 NaN \n",
"4 NaN "
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" Store | \n",
" State | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 1 | \n",
" HE | \n",
"
\n",
" \n",
" 1 | \n",
" 2 | \n",
" TH | \n",
"
\n",
" \n",
" 2 | \n",
" 3 | \n",
" NW | \n",
"
\n",
" \n",
" 3 | \n",
" 4 | \n",
" BE | \n",
"
\n",
" \n",
" 4 | \n",
" 5 | \n",
" SN | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" Store State\n",
"0 1 HE\n",
"1 2 TH\n",
"2 3 NW\n",
"3 4 BE\n",
"4 5 SN"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" StateName | \n",
" State | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" BadenWuerttemberg | \n",
" BW | \n",
"
\n",
" \n",
" 1 | \n",
" Bayern | \n",
" BY | \n",
"
\n",
" \n",
" 2 | \n",
" Berlin | \n",
" BE | \n",
"
\n",
" \n",
" 3 | \n",
" Brandenburg | \n",
" BB | \n",
"
\n",
" \n",
" 4 | \n",
" Bremen | \n",
" HB | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" StateName State\n",
"0 BadenWuerttemberg BW\n",
"1 Bayern BY\n",
"2 Berlin BE\n",
"3 Brandenburg BB\n",
"4 Bremen HB"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" file | \n",
" week | \n",
" trend | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" Rossmann_DE_SN | \n",
" 2012-12-02 - 2012-12-08 | \n",
" 96 | \n",
"
\n",
" \n",
" 1 | \n",
" Rossmann_DE_SN | \n",
" 2012-12-09 - 2012-12-15 | \n",
" 95 | \n",
"
\n",
" \n",
" 2 | \n",
" Rossmann_DE_SN | \n",
" 2012-12-16 - 2012-12-22 | \n",
" 91 | \n",
"
\n",
" \n",
" 3 | \n",
" Rossmann_DE_SN | \n",
" 2012-12-23 - 2012-12-29 | \n",
" 48 | \n",
"
\n",
" \n",
" 4 | \n",
" Rossmann_DE_SN | \n",
" 2012-12-30 - 2013-01-05 | \n",
" 67 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" file week trend\n",
"0 Rossmann_DE_SN 2012-12-02 - 2012-12-08 96\n",
"1 Rossmann_DE_SN 2012-12-09 - 2012-12-15 95\n",
"2 Rossmann_DE_SN 2012-12-16 - 2012-12-22 91\n",
"3 Rossmann_DE_SN 2012-12-23 - 2012-12-29 48\n",
"4 Rossmann_DE_SN 2012-12-30 - 2013-01-05 67"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" file | \n",
" Date | \n",
" Max_TemperatureC | \n",
" Mean_TemperatureC | \n",
" Min_TemperatureC | \n",
" Dew_PointC | \n",
" MeanDew_PointC | \n",
" Min_DewpointC | \n",
" Max_Humidity | \n",
" Mean_Humidity | \n",
" ... | \n",
" Max_VisibilityKm | \n",
" Mean_VisibilityKm | \n",
" Min_VisibilitykM | \n",
" Max_Wind_SpeedKm_h | \n",
" Mean_Wind_SpeedKm_h | \n",
" Max_Gust_SpeedKm_h | \n",
" Precipitationmm | \n",
" CloudCover | \n",
" Events | \n",
" WindDirDegrees | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" NordrheinWestfalen | \n",
" 2013-01-01 | \n",
" 8 | \n",
" 4 | \n",
" 2 | \n",
" 7 | \n",
" 5 | \n",
" 1 | \n",
" 94 | \n",
" 87 | \n",
" ... | \n",
" 31.0 | \n",
" 12.0 | \n",
" 4.0 | \n",
" 39 | \n",
" 26 | \n",
" 58.0 | \n",
" 5.08 | \n",
" 6.0 | \n",
" Rain | \n",
" 215 | \n",
"
\n",
" \n",
" 1 | \n",
" NordrheinWestfalen | \n",
" 2013-01-02 | \n",
" 7 | \n",
" 4 | \n",
" 1 | \n",
" 5 | \n",
" 3 | \n",
" 2 | \n",
" 93 | \n",
" 85 | \n",
" ... | \n",
" 31.0 | \n",
" 14.0 | \n",
" 10.0 | \n",
" 24 | \n",
" 16 | \n",
" NaN | \n",
" 0.00 | \n",
" 6.0 | \n",
" Rain | \n",
" 225 | \n",
"
\n",
" \n",
" 2 | \n",
" NordrheinWestfalen | \n",
" 2013-01-03 | \n",
" 11 | \n",
" 8 | \n",
" 6 | \n",
" 10 | \n",
" 8 | \n",
" 4 | \n",
" 100 | \n",
" 93 | \n",
" ... | \n",
" 31.0 | \n",
" 8.0 | \n",
" 2.0 | \n",
" 26 | \n",
" 21 | \n",
" NaN | \n",
" 1.02 | \n",
" 7.0 | \n",
" Rain | \n",
" 240 | \n",
"
\n",
" \n",
" 3 | \n",
" NordrheinWestfalen | \n",
" 2013-01-04 | \n",
" 9 | \n",
" 9 | \n",
" 8 | \n",
" 9 | \n",
" 9 | \n",
" 8 | \n",
" 100 | \n",
" 94 | \n",
" ... | \n",
" 11.0 | \n",
" 5.0 | \n",
" 2.0 | \n",
" 23 | \n",
" 14 | \n",
" NaN | \n",
" 0.25 | \n",
" 7.0 | \n",
" Rain | \n",
" 263 | \n",
"
\n",
" \n",
" 4 | \n",
" NordrheinWestfalen | \n",
" 2013-01-05 | \n",
" 8 | \n",
" 8 | \n",
" 7 | \n",
" 8 | \n",
" 7 | \n",
" 6 | \n",
" 100 | \n",
" 94 | \n",
" ... | \n",
" 10.0 | \n",
" 6.0 | \n",
" 3.0 | \n",
" 16 | \n",
" 10 | \n",
" NaN | \n",
" 0.00 | \n",
" 7.0 | \n",
" Rain | \n",
" 268 | \n",
"
\n",
" \n",
"
\n",
"
5 rows × 24 columns
\n",
"
"
],
"text/plain": [
" file Date Max_TemperatureC Mean_TemperatureC \\\n",
"0 NordrheinWestfalen 2013-01-01 8 4 \n",
"1 NordrheinWestfalen 2013-01-02 7 4 \n",
"2 NordrheinWestfalen 2013-01-03 11 8 \n",
"3 NordrheinWestfalen 2013-01-04 9 9 \n",
"4 NordrheinWestfalen 2013-01-05 8 8 \n",
"\n",
" Min_TemperatureC Dew_PointC MeanDew_PointC Min_DewpointC Max_Humidity \\\n",
"0 2 7 5 1 94 \n",
"1 1 5 3 2 93 \n",
"2 6 10 8 4 100 \n",
"3 8 9 9 8 100 \n",
"4 7 8 7 6 100 \n",
"\n",
" Mean_Humidity ... Max_VisibilityKm Mean_VisibilityKm \\\n",
"0 87 ... 31.0 12.0 \n",
"1 85 ... 31.0 14.0 \n",
"2 93 ... 31.0 8.0 \n",
"3 94 ... 11.0 5.0 \n",
"4 94 ... 10.0 6.0 \n",
"\n",
" Min_VisibilitykM Max_Wind_SpeedKm_h Mean_Wind_SpeedKm_h \\\n",
"0 4.0 39 26 \n",
"1 10.0 24 16 \n",
"2 2.0 26 21 \n",
"3 2.0 23 14 \n",
"4 3.0 16 10 \n",
"\n",
" Max_Gust_SpeedKm_h Precipitationmm CloudCover Events WindDirDegrees \n",
"0 58.0 5.08 6.0 Rain 215 \n",
"1 NaN 0.00 6.0 Rain 225 \n",
"2 NaN 1.02 7.0 Rain 240 \n",
"3 NaN 0.25 7.0 Rain 263 \n",
"4 NaN 0.00 7.0 Rain 268 \n",
"\n",
"[5 rows x 24 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" Id | \n",
" Store | \n",
" DayOfWeek | \n",
" Date | \n",
" Open | \n",
" Promo | \n",
" StateHoliday | \n",
" SchoolHoliday | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 1 | \n",
" 1 | \n",
" 4 | \n",
" 2015-09-17 | \n",
" 1.0 | \n",
" 1 | \n",
" 0 | \n",
" 0 | \n",
"
\n",
" \n",
" 1 | \n",
" 2 | \n",
" 3 | \n",
" 4 | \n",
" 2015-09-17 | \n",
" 1.0 | \n",
" 1 | \n",
" 0 | \n",
" 0 | \n",
"
\n",
" \n",
" 2 | \n",
" 3 | \n",
" 7 | \n",
" 4 | \n",
" 2015-09-17 | \n",
" 1.0 | \n",
" 1 | \n",
" 0 | \n",
" 0 | \n",
"
\n",
" \n",
" 3 | \n",
" 4 | \n",
" 8 | \n",
" 4 | \n",
" 2015-09-17 | \n",
" 1.0 | \n",
" 1 | \n",
" 0 | \n",
" 0 | \n",
"
\n",
" \n",
" 4 | \n",
" 5 | \n",
" 9 | \n",
" 4 | \n",
" 2015-09-17 | \n",
" 1.0 | \n",
" 1 | \n",
" 0 | \n",
" 0 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" Id Store DayOfWeek Date Open Promo StateHoliday SchoolHoliday\n",
"0 1 1 4 2015-09-17 1.0 1 0 0\n",
"1 2 3 4 2015-09-17 1.0 1 0 0\n",
"2 3 7 4 2015-09-17 1.0 1 0 0\n",
"3 4 8 4 2015-09-17 1.0 1 0 0\n",
"4 5 9 4 2015-09-17 1.0 1 0 0"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"for t in tables: display(t.head())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This is very representative of a typical industry dataset.\n",
"\n",
"The following returns summarized aggregate information to each table accross each field."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" Store | \n",
" DayOfWeek | \n",
" Date | \n",
" Sales | \n",
" Customers | \n",
" Open | \n",
" Promo | \n",
" StateHoliday | \n",
" SchoolHoliday | \n",
"
\n",
" \n",
" \n",
" \n",
" count | \n",
" 1.01721e+06 | \n",
" 1.01721e+06 | \n",
" NaN | \n",
" 1.01721e+06 | \n",
" 1.01721e+06 | \n",
" 1.01721e+06 | \n",
" 1.01721e+06 | \n",
" NaN | \n",
" 1.01721e+06 | \n",
"
\n",
" \n",
" mean | \n",
" 558.43 | \n",
" 3.99834 | \n",
" NaN | \n",
" 5773.82 | \n",
" 633.146 | \n",
" 0.830107 | \n",
" 0.381515 | \n",
" NaN | \n",
" 0.178647 | \n",
"
\n",
" \n",
" std | \n",
" 321.909 | \n",
" 1.99739 | \n",
" NaN | \n",
" 3849.93 | \n",
" 464.412 | \n",
" 0.375539 | \n",
" 0.485759 | \n",
" NaN | \n",
" 0.383056 | \n",
"
\n",
" \n",
" min | \n",
" 1 | \n",
" 1 | \n",
" NaN | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" NaN | \n",
" 0 | \n",
"
\n",
" \n",
" 25% | \n",
" 280 | \n",
" 2 | \n",
" NaN | \n",
" 3727 | \n",
" 405 | \n",
" 1 | \n",
" 0 | \n",
" NaN | \n",
" 0 | \n",
"
\n",
" \n",
" 50% | \n",
" 558 | \n",
" 4 | \n",
" NaN | \n",
" 5744 | \n",
" 609 | \n",
" 1 | \n",
" 0 | \n",
" NaN | \n",
" 0 | \n",
"
\n",
" \n",
" 75% | \n",
" 838 | \n",
" 6 | \n",
" NaN | \n",
" 7856 | \n",
" 837 | \n",
" 1 | \n",
" 1 | \n",
" NaN | \n",
" 0 | \n",
"
\n",
" \n",
" max | \n",
" 1115 | \n",
" 7 | \n",
" NaN | \n",
" 41551 | \n",
" 7388 | \n",
" 1 | \n",
" 1 | \n",
" NaN | \n",
" 1 | \n",
"
\n",
" \n",
" counts | \n",
" 1017209 | \n",
" 1017209 | \n",
" 1017209 | \n",
" 1017209 | \n",
" 1017209 | \n",
" 1017209 | \n",
" 1017209 | \n",
" 1017209 | \n",
" 1017209 | \n",
"
\n",
" \n",
" uniques | \n",
" 1115 | \n",
" 7 | \n",
" 942 | \n",
" 21734 | \n",
" 4086 | \n",
" 2 | \n",
" 2 | \n",
" 4 | \n",
" 2 | \n",
"
\n",
" \n",
" missing | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
"
\n",
" \n",
" missing_perc | \n",
" 0% | \n",
" 0% | \n",
" 0% | \n",
" 0% | \n",
" 0% | \n",
" 0% | \n",
" 0% | \n",
" 0% | \n",
" 0% | \n",
"
\n",
" \n",
" types | \n",
" numeric | \n",
" numeric | \n",
" categorical | \n",
" numeric | \n",
" numeric | \n",
" bool | \n",
" bool | \n",
" categorical | \n",
" bool | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" Store DayOfWeek Date Sales Customers \\\n",
"count 1.01721e+06 1.01721e+06 NaN 1.01721e+06 1.01721e+06 \n",
"mean 558.43 3.99834 NaN 5773.82 633.146 \n",
"std 321.909 1.99739 NaN 3849.93 464.412 \n",
"min 1 1 NaN 0 0 \n",
"25% 280 2 NaN 3727 405 \n",
"50% 558 4 NaN 5744 609 \n",
"75% 838 6 NaN 7856 837 \n",
"max 1115 7 NaN 41551 7388 \n",
"counts 1017209 1017209 1017209 1017209 1017209 \n",
"uniques 1115 7 942 21734 4086 \n",
"missing 0 0 0 0 0 \n",
"missing_perc 0% 0% 0% 0% 0% \n",
"types numeric numeric categorical numeric numeric \n",
"\n",
" Open Promo StateHoliday SchoolHoliday \n",
"count 1.01721e+06 1.01721e+06 NaN 1.01721e+06 \n",
"mean 0.830107 0.381515 NaN 0.178647 \n",
"std 0.375539 0.485759 NaN 0.383056 \n",
"min 0 0 NaN 0 \n",
"25% 1 0 NaN 0 \n",
"50% 1 0 NaN 0 \n",
"75% 1 1 NaN 0 \n",
"max 1 1 NaN 1 \n",
"counts 1017209 1017209 1017209 1017209 \n",
"uniques 2 2 4 2 \n",
"missing 0 0 0 0 \n",
"missing_perc 0% 0% 0% 0% \n",
"types bool bool categorical bool "
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" Store | \n",
" StoreType | \n",
" Assortment | \n",
" CompetitionDistance | \n",
" CompetitionOpenSinceMonth | \n",
" CompetitionOpenSinceYear | \n",
" Promo2 | \n",
" Promo2SinceWeek | \n",
" Promo2SinceYear | \n",
" PromoInterval | \n",
"
\n",
" \n",
" \n",
" \n",
" count | \n",
" 1115 | \n",
" NaN | \n",
" NaN | \n",
" 1112 | \n",
" 761 | \n",
" 761 | \n",
" 1115 | \n",
" 571 | \n",
" 571 | \n",
" NaN | \n",
"
\n",
" \n",
" mean | \n",
" 558 | \n",
" NaN | \n",
" NaN | \n",
" 5404.9 | \n",
" 7.2247 | \n",
" 2008.67 | \n",
" 0.512108 | \n",
" 23.5954 | \n",
" 2011.76 | \n",
" NaN | \n",
"
\n",
" \n",
" std | \n",
" 322.017 | \n",
" NaN | \n",
" NaN | \n",
" 7663.17 | \n",
" 3.21235 | \n",
" 6.19598 | \n",
" 0.500078 | \n",
" 14.142 | \n",
" 1.67494 | \n",
" NaN | \n",
"
\n",
" \n",
" min | \n",
" 1 | \n",
" NaN | \n",
" NaN | \n",
" 20 | \n",
" 1 | \n",
" 1900 | \n",
" 0 | \n",
" 1 | \n",
" 2009 | \n",
" NaN | \n",
"
\n",
" \n",
" 25% | \n",
" 279.5 | \n",
" NaN | \n",
" NaN | \n",
" 717.5 | \n",
" 4 | \n",
" 2006 | \n",
" 0 | \n",
" 13 | \n",
" 2011 | \n",
" NaN | \n",
"
\n",
" \n",
" 50% | \n",
" 558 | \n",
" NaN | \n",
" NaN | \n",
" 2325 | \n",
" 8 | \n",
" 2010 | \n",
" 1 | \n",
" 22 | \n",
" 2012 | \n",
" NaN | \n",
"
\n",
" \n",
" 75% | \n",
" 836.5 | \n",
" NaN | \n",
" NaN | \n",
" 6882.5 | \n",
" 10 | \n",
" 2013 | \n",
" 1 | \n",
" 37 | \n",
" 2013 | \n",
" NaN | \n",
"
\n",
" \n",
" max | \n",
" 1115 | \n",
" NaN | \n",
" NaN | \n",
" 75860 | \n",
" 12 | \n",
" 2015 | \n",
" 1 | \n",
" 50 | \n",
" 2015 | \n",
" NaN | \n",
"
\n",
" \n",
" counts | \n",
" 1115 | \n",
" 1115 | \n",
" 1115 | \n",
" 1112 | \n",
" 761 | \n",
" 761 | \n",
" 1115 | \n",
" 571 | \n",
" 571 | \n",
" 571 | \n",
"
\n",
" \n",
" uniques | \n",
" 1115 | \n",
" 4 | \n",
" 3 | \n",
" 654 | \n",
" 12 | \n",
" 23 | \n",
" 2 | \n",
" 24 | \n",
" 7 | \n",
" 3 | \n",
"
\n",
" \n",
" missing | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 3 | \n",
" 354 | \n",
" 354 | \n",
" 0 | \n",
" 544 | \n",
" 544 | \n",
" 544 | \n",
"
\n",
" \n",
" missing_perc | \n",
" 0% | \n",
" 0% | \n",
" 0% | \n",
" 0.27% | \n",
" 31.75% | \n",
" 31.75% | \n",
" 0% | \n",
" 48.79% | \n",
" 48.79% | \n",
" 48.79% | \n",
"
\n",
" \n",
" types | \n",
" numeric | \n",
" categorical | \n",
" categorical | \n",
" numeric | \n",
" numeric | \n",
" numeric | \n",
" bool | \n",
" numeric | \n",
" numeric | \n",
" categorical | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" Store StoreType Assortment CompetitionDistance \\\n",
"count 1115 NaN NaN 1112 \n",
"mean 558 NaN NaN 5404.9 \n",
"std 322.017 NaN NaN 7663.17 \n",
"min 1 NaN NaN 20 \n",
"25% 279.5 NaN NaN 717.5 \n",
"50% 558 NaN NaN 2325 \n",
"75% 836.5 NaN NaN 6882.5 \n",
"max 1115 NaN NaN 75860 \n",
"counts 1115 1115 1115 1112 \n",
"uniques 1115 4 3 654 \n",
"missing 0 0 0 3 \n",
"missing_perc 0% 0% 0% 0.27% \n",
"types numeric categorical categorical numeric \n",
"\n",
" CompetitionOpenSinceMonth CompetitionOpenSinceYear Promo2 \\\n",
"count 761 761 1115 \n",
"mean 7.2247 2008.67 0.512108 \n",
"std 3.21235 6.19598 0.500078 \n",
"min 1 1900 0 \n",
"25% 4 2006 0 \n",
"50% 8 2010 1 \n",
"75% 10 2013 1 \n",
"max 12 2015 1 \n",
"counts 761 761 1115 \n",
"uniques 12 23 2 \n",
"missing 354 354 0 \n",
"missing_perc 31.75% 31.75% 0% \n",
"types numeric numeric bool \n",
"\n",
" Promo2SinceWeek Promo2SinceYear PromoInterval \n",
"count 571 571 NaN \n",
"mean 23.5954 2011.76 NaN \n",
"std 14.142 1.67494 NaN \n",
"min 1 2009 NaN \n",
"25% 13 2011 NaN \n",
"50% 22 2012 NaN \n",
"75% 37 2013 NaN \n",
"max 50 2015 NaN \n",
"counts 571 571 571 \n",
"uniques 24 7 3 \n",
"missing 544 544 544 \n",
"missing_perc 48.79% 48.79% 48.79% \n",
"types numeric numeric categorical "
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" Store | \n",
" State | \n",
"
\n",
" \n",
" \n",
" \n",
" count | \n",
" 1115 | \n",
" NaN | \n",
"
\n",
" \n",
" mean | \n",
" 558 | \n",
" NaN | \n",
"
\n",
" \n",
" std | \n",
" 322.017 | \n",
" NaN | \n",
"
\n",
" \n",
" min | \n",
" 1 | \n",
" NaN | \n",
"
\n",
" \n",
" 25% | \n",
" 279.5 | \n",
" NaN | \n",
"
\n",
" \n",
" 50% | \n",
" 558 | \n",
" NaN | \n",
"
\n",
" \n",
" 75% | \n",
" 836.5 | \n",
" NaN | \n",
"
\n",
" \n",
" max | \n",
" 1115 | \n",
" NaN | \n",
"
\n",
" \n",
" counts | \n",
" 1115 | \n",
" 1115 | \n",
"
\n",
" \n",
" uniques | \n",
" 1115 | \n",
" 12 | \n",
"
\n",
" \n",
" missing | \n",
" 0 | \n",
" 0 | \n",
"
\n",
" \n",
" missing_perc | \n",
" 0% | \n",
" 0% | \n",
"
\n",
" \n",
" types | \n",
" numeric | \n",
" categorical | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" Store State\n",
"count 1115 NaN\n",
"mean 558 NaN\n",
"std 322.017 NaN\n",
"min 1 NaN\n",
"25% 279.5 NaN\n",
"50% 558 NaN\n",
"75% 836.5 NaN\n",
"max 1115 NaN\n",
"counts 1115 1115\n",
"uniques 1115 12\n",
"missing 0 0\n",
"missing_perc 0% 0%\n",
"types numeric categorical"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" StateName | \n",
" State | \n",
"
\n",
" \n",
" \n",
" \n",
" count | \n",
" 16 | \n",
" 16 | \n",
"
\n",
" \n",
" unique | \n",
" 16 | \n",
" 16 | \n",
"
\n",
" \n",
" top | \n",
" Thueringen | \n",
" HB | \n",
"
\n",
" \n",
" freq | \n",
" 1 | \n",
" 1 | \n",
"
\n",
" \n",
" counts | \n",
" 16 | \n",
" 16 | \n",
"
\n",
" \n",
" uniques | \n",
" 16 | \n",
" 16 | \n",
"
\n",
" \n",
" missing | \n",
" 0 | \n",
" 0 | \n",
"
\n",
" \n",
" missing_perc | \n",
" 0% | \n",
" 0% | \n",
"
\n",
" \n",
" types | \n",
" unique | \n",
" unique | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" StateName State\n",
"count 16 16\n",
"unique 16 16\n",
"top Thueringen HB\n",
"freq 1 1\n",
"counts 16 16\n",
"uniques 16 16\n",
"missing 0 0\n",
"missing_perc 0% 0%\n",
"types unique unique"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" file | \n",
" week | \n",
" trend | \n",
"
\n",
" \n",
" \n",
" \n",
" count | \n",
" NaN | \n",
" NaN | \n",
" 2072 | \n",
"
\n",
" \n",
" mean | \n",
" NaN | \n",
" NaN | \n",
" 63.8142 | \n",
"
\n",
" \n",
" std | \n",
" NaN | \n",
" NaN | \n",
" 12.6502 | \n",
"
\n",
" \n",
" min | \n",
" NaN | \n",
" NaN | \n",
" 0 | \n",
"
\n",
" \n",
" 25% | \n",
" NaN | \n",
" NaN | \n",
" 55 | \n",
"
\n",
" \n",
" 50% | \n",
" NaN | \n",
" NaN | \n",
" 64 | \n",
"
\n",
" \n",
" 75% | \n",
" NaN | \n",
" NaN | \n",
" 72 | \n",
"
\n",
" \n",
" max | \n",
" NaN | \n",
" NaN | \n",
" 100 | \n",
"
\n",
" \n",
" counts | \n",
" 2072 | \n",
" 2072 | \n",
" 2072 | \n",
"
\n",
" \n",
" uniques | \n",
" 14 | \n",
" 148 | \n",
" 68 | \n",
"
\n",
" \n",
" missing | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
"
\n",
" \n",
" missing_perc | \n",
" 0% | \n",
" 0% | \n",
" 0% | \n",
"
\n",
" \n",
" types | \n",
" categorical | \n",
" categorical | \n",
" numeric | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" file week trend\n",
"count NaN NaN 2072\n",
"mean NaN NaN 63.8142\n",
"std NaN NaN 12.6502\n",
"min NaN NaN 0\n",
"25% NaN NaN 55\n",
"50% NaN NaN 64\n",
"75% NaN NaN 72\n",
"max NaN NaN 100\n",
"counts 2072 2072 2072\n",
"uniques 14 148 68\n",
"missing 0 0 0\n",
"missing_perc 0% 0% 0%\n",
"types categorical categorical numeric"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" file | \n",
" Date | \n",
" Max_TemperatureC | \n",
" Mean_TemperatureC | \n",
" Min_TemperatureC | \n",
" Dew_PointC | \n",
" MeanDew_PointC | \n",
" Min_DewpointC | \n",
" Max_Humidity | \n",
" Mean_Humidity | \n",
" ... | \n",
" Max_VisibilityKm | \n",
" Mean_VisibilityKm | \n",
" Min_VisibilitykM | \n",
" Max_Wind_SpeedKm_h | \n",
" Mean_Wind_SpeedKm_h | \n",
" Max_Gust_SpeedKm_h | \n",
" Precipitationmm | \n",
" CloudCover | \n",
" Events | \n",
" WindDirDegrees | \n",
"
\n",
" \n",
" \n",
" \n",
" count | \n",
" NaN | \n",
" NaN | \n",
" 15840 | \n",
" 15840 | \n",
" 15840 | \n",
" 15840 | \n",
" 15840 | \n",
" 15840 | \n",
" 15840 | \n",
" 15840 | \n",
" ... | \n",
" 15459 | \n",
" 15459 | \n",
" 15459 | \n",
" 15840 | \n",
" 15840 | \n",
" 3604 | \n",
" 15840 | \n",
" 14667 | \n",
" NaN | \n",
" 15840 | \n",
"
\n",
" \n",
" mean | \n",
" NaN | \n",
" NaN | \n",
" 14.6441 | \n",
" 10.389 | \n",
" 6.19899 | \n",
" 8.58782 | \n",
" 6.20581 | \n",
" 3.62614 | \n",
" 93.6596 | \n",
" 74.2829 | \n",
" ... | \n",
" 24.0576 | \n",
" 12.2398 | \n",
" 7.02516 | \n",
" 22.7666 | \n",
" 11.9722 | \n",
" 48.8643 | \n",
" 0.831718 | \n",
" 5.55131 | \n",
" NaN | \n",
" 175.897 | \n",
"
\n",
" \n",
" std | \n",
" NaN | \n",
" NaN | \n",
" 8.64601 | \n",
" 7.37926 | \n",
" 6.52639 | \n",
" 6.24478 | \n",
" 6.08677 | \n",
" 6.12839 | \n",
" 7.67853 | \n",
" 13.4866 | \n",
" ... | \n",
" 8.9768 | \n",
" 5.06794 | \n",
" 4.9806 | \n",
" 8.98862 | \n",
" 5.87284 | \n",
" 13.027 | \n",
" 2.51351 | \n",
" 1.68771 | \n",
" NaN | \n",
" 101.589 | \n",
"
\n",
" \n",
" min | \n",
" NaN | \n",
" NaN | \n",
" -11 | \n",
" -13 | \n",
" -15 | \n",
" -14 | \n",
" -15 | \n",
" -73 | \n",
" 44 | \n",
" 30 | \n",
" ... | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 3 | \n",
" 2 | \n",
" 21 | \n",
" 0 | \n",
" 0 | \n",
" NaN | \n",
" -1 | \n",
"
\n",
" \n",
" 25% | \n",
" NaN | \n",
" NaN | \n",
" 8 | \n",
" 4 | \n",
" 1 | \n",
" 4 | \n",
" 2 | \n",
" -1 | \n",
" 90.75 | \n",
" 65 | \n",
" ... | \n",
" 14 | \n",
" 10 | \n",
" 3 | \n",
" 16 | \n",
" 8 | \n",
" 39 | \n",
" 0 | \n",
" 5 | \n",
" NaN | \n",
" 80 | \n",
"
\n",
" \n",
" 50% | \n",
" NaN | \n",
" NaN | \n",
" 15 | \n",
" 11 | \n",
" 7 | \n",
" 9 | \n",
" 7 | \n",
" 4 | \n",
" 94 | \n",
" 76 | \n",
" ... | \n",
" 31 | \n",
" 11 | \n",
" 7 | \n",
" 21 | \n",
" 11 | \n",
" 48 | \n",
" 0 | \n",
" 6 | \n",
" NaN | \n",
" 202 | \n",
"
\n",
" \n",
" 75% | \n",
" NaN | \n",
" NaN | \n",
" 21 | \n",
" 16 | \n",
" 11 | \n",
" 13 | \n",
" 11 | \n",
" 8 | \n",
" 100 | \n",
" 85 | \n",
" ... | \n",
" 31 | \n",
" 14 | \n",
" 10 | \n",
" 27 | \n",
" 14 | \n",
" 55 | \n",
" 0.25 | \n",
" 7 | \n",
" NaN | \n",
" 256 | \n",
"
\n",
" \n",
" max | \n",
" NaN | \n",
" NaN | \n",
" 39 | \n",
" 31 | \n",
" 24 | \n",
" 25 | \n",
" 20 | \n",
" 19 | \n",
" 100 | \n",
" 100 | \n",
" ... | \n",
" 31 | \n",
" 31 | \n",
" 31 | \n",
" 101 | \n",
" 53 | \n",
" 111 | \n",
" 58.93 | \n",
" 8 | \n",
" NaN | \n",
" 360 | \n",
"
\n",
" \n",
" counts | \n",
" 15840 | \n",
" 15840 | \n",
" 15840 | \n",
" 15840 | \n",
" 15840 | \n",
" 15840 | \n",
" 15840 | \n",
" 15840 | \n",
" 15840 | \n",
" 15840 | \n",
" ... | \n",
" 15459 | \n",
" 15459 | \n",
" 15459 | \n",
" 15840 | \n",
" 15840 | \n",
" 3604 | \n",
" 15840 | \n",
" 14667 | \n",
" 11889 | \n",
" 15840 | \n",
"
\n",
" \n",
" uniques | \n",
" 16 | \n",
" 990 | \n",
" 51 | \n",
" 45 | \n",
" 40 | \n",
" 40 | \n",
" 36 | \n",
" 40 | \n",
" 53 | \n",
" 71 | \n",
" ... | \n",
" 24 | \n",
" 32 | \n",
" 24 | \n",
" 44 | \n",
" 29 | \n",
" 47 | \n",
" 41 | \n",
" 9 | \n",
" 21 | \n",
" 362 | \n",
"
\n",
" \n",
" missing | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" ... | \n",
" 381 | \n",
" 381 | \n",
" 381 | \n",
" 0 | \n",
" 0 | \n",
" 12236 | \n",
" 0 | \n",
" 1173 | \n",
" 3951 | \n",
" 0 | \n",
"
\n",
" \n",
" missing_perc | \n",
" 0% | \n",
" 0% | \n",
" 0% | \n",
" 0% | \n",
" 0% | \n",
" 0% | \n",
" 0% | \n",
" 0% | \n",
" 0% | \n",
" 0% | \n",
" ... | \n",
" 2.41% | \n",
" 2.41% | \n",
" 2.41% | \n",
" 0% | \n",
" 0% | \n",
" 77.25% | \n",
" 0% | \n",
" 7.41% | \n",
" 24.94% | \n",
" 0% | \n",
"
\n",
" \n",
" types | \n",
" categorical | \n",
" categorical | \n",
" numeric | \n",
" numeric | \n",
" numeric | \n",
" numeric | \n",
" numeric | \n",
" numeric | \n",
" numeric | \n",
" numeric | \n",
" ... | \n",
" numeric | \n",
" numeric | \n",
" numeric | \n",
" numeric | \n",
" numeric | \n",
" numeric | \n",
" numeric | \n",
" numeric | \n",
" categorical | \n",
" numeric | \n",
"
\n",
" \n",
"
\n",
"
13 rows × 24 columns
\n",
"
"
],
"text/plain": [
" file Date Max_TemperatureC Mean_TemperatureC \\\n",
"count NaN NaN 15840 15840 \n",
"mean NaN NaN 14.6441 10.389 \n",
"std NaN NaN 8.64601 7.37926 \n",
"min NaN NaN -11 -13 \n",
"25% NaN NaN 8 4 \n",
"50% NaN NaN 15 11 \n",
"75% NaN NaN 21 16 \n",
"max NaN NaN 39 31 \n",
"counts 15840 15840 15840 15840 \n",
"uniques 16 990 51 45 \n",
"missing 0 0 0 0 \n",
"missing_perc 0% 0% 0% 0% \n",
"types categorical categorical numeric numeric \n",
"\n",
" Min_TemperatureC Dew_PointC MeanDew_PointC Min_DewpointC \\\n",
"count 15840 15840 15840 15840 \n",
"mean 6.19899 8.58782 6.20581 3.62614 \n",
"std 6.52639 6.24478 6.08677 6.12839 \n",
"min -15 -14 -15 -73 \n",
"25% 1 4 2 -1 \n",
"50% 7 9 7 4 \n",
"75% 11 13 11 8 \n",
"max 24 25 20 19 \n",
"counts 15840 15840 15840 15840 \n",
"uniques 40 40 36 40 \n",
"missing 0 0 0 0 \n",
"missing_perc 0% 0% 0% 0% \n",
"types numeric numeric numeric numeric \n",
"\n",
" Max_Humidity Mean_Humidity ... Max_VisibilityKm \\\n",
"count 15840 15840 ... 15459 \n",
"mean 93.6596 74.2829 ... 24.0576 \n",
"std 7.67853 13.4866 ... 8.9768 \n",
"min 44 30 ... 0 \n",
"25% 90.75 65 ... 14 \n",
"50% 94 76 ... 31 \n",
"75% 100 85 ... 31 \n",
"max 100 100 ... 31 \n",
"counts 15840 15840 ... 15459 \n",
"uniques 53 71 ... 24 \n",
"missing 0 0 ... 381 \n",
"missing_perc 0% 0% ... 2.41% \n",
"types numeric numeric ... numeric \n",
"\n",
" Mean_VisibilityKm Min_VisibilitykM Max_Wind_SpeedKm_h \\\n",
"count 15459 15459 15840 \n",
"mean 12.2398 7.02516 22.7666 \n",
"std 5.06794 4.9806 8.98862 \n",
"min 0 0 3 \n",
"25% 10 3 16 \n",
"50% 11 7 21 \n",
"75% 14 10 27 \n",
"max 31 31 101 \n",
"counts 15459 15459 15840 \n",
"uniques 32 24 44 \n",
"missing 381 381 0 \n",
"missing_perc 2.41% 2.41% 0% \n",
"types numeric numeric numeric \n",
"\n",
" Mean_Wind_SpeedKm_h Max_Gust_SpeedKm_h Precipitationmm \\\n",
"count 15840 3604 15840 \n",
"mean 11.9722 48.8643 0.831718 \n",
"std 5.87284 13.027 2.51351 \n",
"min 2 21 0 \n",
"25% 8 39 0 \n",
"50% 11 48 0 \n",
"75% 14 55 0.25 \n",
"max 53 111 58.93 \n",
"counts 15840 3604 15840 \n",
"uniques 29 47 41 \n",
"missing 0 12236 0 \n",
"missing_perc 0% 77.25% 0% \n",
"types numeric numeric numeric \n",
"\n",
" CloudCover Events WindDirDegrees \n",
"count 14667 NaN 15840 \n",
"mean 5.55131 NaN 175.897 \n",
"std 1.68771 NaN 101.589 \n",
"min 0 NaN -1 \n",
"25% 5 NaN 80 \n",
"50% 6 NaN 202 \n",
"75% 7 NaN 256 \n",
"max 8 NaN 360 \n",
"counts 14667 11889 15840 \n",
"uniques 9 21 362 \n",
"missing 1173 3951 0 \n",
"missing_perc 7.41% 24.94% 0% \n",
"types numeric categorical numeric \n",
"\n",
"[13 rows x 24 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" Id | \n",
" Store | \n",
" DayOfWeek | \n",
" Date | \n",
" Open | \n",
" Promo | \n",
" StateHoliday | \n",
" SchoolHoliday | \n",
"
\n",
" \n",
" \n",
" \n",
" count | \n",
" 41088 | \n",
" 41088 | \n",
" 41088 | \n",
" NaN | \n",
" 41077 | \n",
" 41088 | \n",
" NaN | \n",
" 41088 | \n",
"
\n",
" \n",
" mean | \n",
" 20544.5 | \n",
" 555.9 | \n",
" 3.97917 | \n",
" NaN | \n",
" 0.854322 | \n",
" 0.395833 | \n",
" NaN | \n",
" 0.443487 | \n",
"
\n",
" \n",
" std | \n",
" 11861.2 | \n",
" 320.274 | \n",
" 2.01548 | \n",
" NaN | \n",
" 0.352787 | \n",
" 0.489035 | \n",
" NaN | \n",
" 0.496802 | \n",
"
\n",
" \n",
" min | \n",
" 1 | \n",
" 1 | \n",
" 1 | \n",
" NaN | \n",
" 0 | \n",
" 0 | \n",
" NaN | \n",
" 0 | \n",
"
\n",
" \n",
" 25% | \n",
" 10272.8 | \n",
" 279.75 | \n",
" 2 | \n",
" NaN | \n",
" 1 | \n",
" 0 | \n",
" NaN | \n",
" 0 | \n",
"
\n",
" \n",
" 50% | \n",
" 20544.5 | \n",
" 553.5 | \n",
" 4 | \n",
" NaN | \n",
" 1 | \n",
" 0 | \n",
" NaN | \n",
" 0 | \n",
"
\n",
" \n",
" 75% | \n",
" 30816.2 | \n",
" 832.25 | \n",
" 6 | \n",
" NaN | \n",
" 1 | \n",
" 1 | \n",
" NaN | \n",
" 1 | \n",
"
\n",
" \n",
" max | \n",
" 41088 | \n",
" 1115 | \n",
" 7 | \n",
" NaN | \n",
" 1 | \n",
" 1 | \n",
" NaN | \n",
" 1 | \n",
"
\n",
" \n",
" counts | \n",
" 41088 | \n",
" 41088 | \n",
" 41088 | \n",
" 41088 | \n",
" 41077 | \n",
" 41088 | \n",
" 41088 | \n",
" 41088 | \n",
"
\n",
" \n",
" uniques | \n",
" 41088 | \n",
" 856 | \n",
" 7 | \n",
" 48 | \n",
" 2 | \n",
" 2 | \n",
" 2 | \n",
" 2 | \n",
"
\n",
" \n",
" missing | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 11 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
"
\n",
" \n",
" missing_perc | \n",
" 0% | \n",
" 0% | \n",
" 0% | \n",
" 0% | \n",
" 0.03% | \n",
" 0% | \n",
" 0% | \n",
" 0% | \n",
"
\n",
" \n",
" types | \n",
" numeric | \n",
" numeric | \n",
" numeric | \n",
" categorical | \n",
" bool | \n",
" bool | \n",
" bool | \n",
" bool | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" Id Store DayOfWeek Date Open Promo \\\n",
"count 41088 41088 41088 NaN 41077 41088 \n",
"mean 20544.5 555.9 3.97917 NaN 0.854322 0.395833 \n",
"std 11861.2 320.274 2.01548 NaN 0.352787 0.489035 \n",
"min 1 1 1 NaN 0 0 \n",
"25% 10272.8 279.75 2 NaN 1 0 \n",
"50% 20544.5 553.5 4 NaN 1 0 \n",
"75% 30816.2 832.25 6 NaN 1 1 \n",
"max 41088 1115 7 NaN 1 1 \n",
"counts 41088 41088 41088 41088 41077 41088 \n",
"uniques 41088 856 7 48 2 2 \n",
"missing 0 0 0 0 11 0 \n",
"missing_perc 0% 0% 0% 0% 0.03% 0% \n",
"types numeric numeric numeric categorical bool bool \n",
"\n",
" StateHoliday SchoolHoliday \n",
"count NaN 41088 \n",
"mean NaN 0.443487 \n",
"std NaN 0.496802 \n",
"min NaN 0 \n",
"25% NaN 0 \n",
"50% NaN 0 \n",
"75% NaN 1 \n",
"max NaN 1 \n",
"counts 41088 41088 \n",
"uniques 2 2 \n",
"missing 0 0 \n",
"missing_perc 0% 0% \n",
"types bool bool "
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"for t in tables: display(DataFrameSummary(t).summary())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Data Cleaning / Feature Engineering"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a structured data problem, we necessarily have to go through all the cleaning and feature engineering, even though we're using a neural network."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train, store, store_states, state_names, googletrend, weather, test = tables"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(1017209, 41088)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(train),len(test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We turn state Holidays to booleans, to make them more convenient for modeling. We can do calculations on pandas fields using notation very similar (often identical) to numpy."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train.StateHoliday = train.StateHoliday!='0'\n",
"test.StateHoliday = test.StateHoliday!='0'"
]
},
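{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick illustration (a sketch on made-up values, assuming the raw column stores its holiday codes as strings), the `!=` comparison above works elementwise on a pandas Series exactly as it does on a NumPy array:\n",
"\n",
"```python\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"# Toy stand-in for the raw StateHoliday codes (values are made up)\n",
"codes = pd.Series(['0', 'a', '0', 'c', 'b'])\n",
"\n",
"# Elementwise comparison returns a boolean Series\n",
"is_holiday = codes != '0'\n",
"\n",
"# The same expression on a NumPy array gives the same values\n",
"is_holiday_np = np.array(['0', 'a', '0', 'c', 'b']) != '0'\n",
"\n",
"print(is_holiday.tolist())  # [False, True, False, True, True]\n",
"```"
]
},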
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`join_df` is a function for joining tables on specific fields. By default, we'll be doing a left outer join of `right` on the `left` argument using the given fields for each table.\n",
"\n",
"Pandas does joins using the `merge` method. The `suffixes` argument describes the naming convention for duplicate fields. We've elected to leave the duplicate field names on the left untouched, and append a \"\\_y\" to those on the right."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def join_df(left, right, left_on, right_on=None, suffix='_y'):\n",
" if right_on is None: right_on = left_on\n",
" return left.merge(right, how='left', left_on=left_on, right_on=right_on, \n",
" suffixes=(\"\", suffix))"
]
},
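{
"cell_type": "markdown",
"metadata": {},
"source": [
"A small sketch on made-up toy tables (the frames `left` and `right` are illustrative only) shows the two behaviours described above: the right-hand copy of a duplicated column gets the `_y` suffix, and a left row with no match keeps its place, with NaN in the right-hand fields.\n",
"\n",
"```python\n",
"import pandas as pd\n",
"\n",
"def join_df(left, right, left_on, right_on=None, suffix='_y'):\n",
"    # Same helper as above: left join, duplicate right-hand columns get `suffix`\n",
"    if right_on is None: right_on = left_on\n",
"    return left.merge(right, how='left', left_on=left_on, right_on=right_on,\n",
"                      suffixes=('', suffix))\n",
"\n",
"# Made-up toy tables; both have a 'Name' column\n",
"left = pd.DataFrame({'Store': [1, 2, 3], 'Name': ['a', 'b', 'c']})\n",
"right = pd.DataFrame({'Store': [1, 2], 'Name': ['x', 'y'], 'Size': [10, 20]})\n",
"\n",
"joined_toy = join_df(left, right, 'Store')\n",
"# The left 'Name' is untouched and the right one becomes 'Name_y';\n",
"# store 3 has no match, so its 'Name_y' and 'Size' are NaN\n",
"```"
]
},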
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Join weather/state names."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"weather = join_df(weather, state_names, \"file\", \"StateName\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In pandas you can add new columns to a dataframe by simply defining it. We'll do this for googletrends by extracting dates and state names from the given data and adding those columns.\n",
"\n",
"We're also going to replace all instances of state name 'NI' to match the usage in the rest of the data: 'HB,NI'. This is a good opportunity to highlight pandas indexing. We can use `.loc[rows, cols]` to select a list of rows and a list of columns from the dataframe. In this case, we're selecting rows w/ statename 'NI' by using a boolean list `googletrend.State=='NI'` and selecting \"State\"."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"googletrend['Date'] = googletrend.week.str.split(' - ', expand=True)[0]\n",
"googletrend['State'] = googletrend.file.str.split('_', expand=True)[2]\n",
"googletrend.loc[googletrend.State=='NI', \"State\"] = 'HB,NI'"
]
},
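{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here is the same pattern on two made-up rows shaped like the Google Trends data (the `file` and `week` values are illustrative): `str.split(..., expand=True)` returns one column per piece, and `.loc` with a boolean row mask selects and assigns in a single step.\n",
"\n",
"```python\n",
"import pandas as pd\n",
"\n",
"# Toy rows shaped like the Google Trends data (values are made up)\n",
"gt = pd.DataFrame({\n",
"    'file': ['Rossmann_DE_BE', 'Rossmann_DE_NI'],\n",
"    'week': ['2015-08-02 - 2015-08-08', '2015-08-02 - 2015-08-08'],\n",
"})\n",
"\n",
"# expand=True gives a DataFrame with one column per piece\n",
"gt['Date'] = gt.week.str.split(' - ', expand=True)[0]\n",
"gt['State'] = gt.file.str.split('_', expand=True)[2]\n",
"\n",
"# Boolean row mask plus column label: select and assign in one step\n",
"gt.loc[gt.State == 'NI', 'State'] = 'HB,NI'\n",
"\n",
"print(gt.State.tolist())  # ['BE', 'HB,NI']\n",
"```"
]
},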
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The following extracts particular date fields from a complete datetime for the purpose of constructing categoricals.\n",
"\n",
"You should *always* consider this feature extraction step when working with date-time. Without expanding your date-time into these additional fields, you can't capture any trend/cyclical behavior as a function of time at any of these granularities. We'll add to every table with a date field."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"add_datepart(weather, \"Date\", drop=False)\n",
"add_datepart(googletrend, \"Date\", drop=False)\n",
"add_datepart(train, \"Date\", drop=False)\n",
"add_datepart(test, \"Date\", drop=False)"
]
},
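{
"cell_type": "markdown",
"metadata": {},
"source": [
"fastai's `add_datepart` adds more fields than this, and its exact behaviour is defined by the library. The sketch below is our own simplified, hypothetical `add_date_fields`, using the pandas `.dt` accessor to show how a few of these fields can be derived (it assumes pandas 1.1 or later for `isocalendar()`):\n",
"\n",
"```python\n",
"import pandas as pd\n",
"\n",
"def add_date_fields(df, fldname):\n",
"    # Hypothetical, simplified stand-in for fastai's add_datepart (drop=False)\n",
"    fld = pd.to_datetime(df[fldname])\n",
"    df['Year'] = fld.dt.year\n",
"    df['Month'] = fld.dt.month\n",
"    df['Week'] = fld.dt.isocalendar().week  # ISO week number\n",
"    df['Day'] = fld.dt.day\n",
"    df['Dayofweek'] = fld.dt.dayofweek       # Monday=0 ... Sunday=6\n",
"    df['Dayofyear'] = fld.dt.dayofyear\n",
"    df['Is_month_end'] = fld.dt.is_month_end\n",
"    # Seconds since the Unix epoch: one continuous time feature\n",
"    df['Elapsed'] = (fld - pd.Timestamp('1970-01-01')) // pd.Timedelta(seconds=1)\n",
"\n",
"toy = pd.DataFrame({'Date': ['2015-07-31']})\n",
"add_date_fields(toy, 'Date')\n",
"```\n",
"\n",
"For 2015-07-31 this gives Week 31, Dayofweek 4, Dayofyear 212 and Elapsed 1438300800, the same values that appear for that date in the joined table further down."
]
},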
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The Google trends data has a special category for the whole of Germany - we'll pull that out so we can use it explicitly."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"trend_de = googletrend[googletrend.file == 'Rossmann_DE']"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we can outer join all of our data into a single dataframe. Recall that in outer joins everytime a value in the joining field on the left table does not have a corresponding value on the right table, the corresponding row in the new table has Null values for all right table fields. One way to check that all records are consistent and complete is to check for Null values post-join, as we do here.\n",
"\n",
"*Aside*: Why note just do an inner join?\n",
"If you are assuming that all records are complete and match on the field you desire, an inner join will do the same thing as an outer join. However, in the event you are wrong or a mistake is made, an outer join followed by a null-check will catch it. (Comparing before/after # of rows for inner join is equivalent, but requires keeping track of before/after row #'s. Outer join is easier.)"
]
},
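{
"cell_type": "markdown",
"metadata": {},
"source": [
"The aside can be checked on a pair of made-up toy tables (the names `sales` and `lookup` are illustrative): the inner join silently drops the unmatched row, while the left join keeps it and the null check counts it.\n",
"\n",
"```python\n",
"import pandas as pd\n",
"\n",
"# Made-up tables: store 3 is missing from the lookup table\n",
"sales = pd.DataFrame({'Store': [1, 2, 3], 'Sales': [100, 200, 300]})\n",
"lookup = pd.DataFrame({'Store': [1, 2], 'State': ['HE', 'TH']})\n",
"\n",
"inner = sales.merge(lookup, how='inner', on='Store')\n",
"left = sales.merge(lookup, how='left', on='Store')\n",
"\n",
"# The inner join loses the unmatched row without any warning...\n",
"n_dropped = len(sales) - len(inner)\n",
"# ...while the left join keeps it, and the null check flags it\n",
"n_missing = len(left[left.State.isnull()])\n",
"\n",
"print(n_dropped, n_missing)  # 1 1\n",
"```"
]
},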
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"store = join_df(store, store_states, \"Store\")\n",
"len(store[store.State.isnull()])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(0, 0)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"joined = join_df(train, store, \"Store\")\n",
"joined_test = join_df(test, store, \"Store\")\n",
"len(joined[joined.StoreType.isnull()]),len(joined_test[joined_test.StoreType.isnull()])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(0, 0)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"joined = join_df(joined, googletrend, [\"State\",\"Year\", \"Week\"])\n",
"joined_test = join_df(joined_test, googletrend, [\"State\",\"Year\", \"Week\"])\n",
"len(joined[joined.trend.isnull()]),len(joined_test[joined_test.trend.isnull()])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(0, 0)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"joined = joined.merge(trend_de, 'left', [\"Year\", \"Week\"], suffixes=('', '_DE'))\n",
"joined_test = joined_test.merge(trend_de, 'left', [\"Year\", \"Week\"], suffixes=('', '_DE'))\n",
"len(joined[joined.trend_DE.isnull()]),len(joined_test[joined_test.trend_DE.isnull()])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(0, 0)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"joined = join_df(joined, weather, [\"State\",\"Date\"])\n",
"joined_test = join_df(joined_test, weather, [\"State\",\"Date\"])\n",
"len(joined[joined.Mean_TemperatureC.isnull()]),len(joined_test[joined_test.Mean_TemperatureC.isnull()])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for df in (joined, joined_test):\n",
" for c in df.columns:\n",
" if c.endswith('_y'):\n",
" if c in df.columns: df.drop(c, inplace=True, axis=1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next we'll fill in missing values to avoid complications with `NA`'s. `NA` (not available) is how Pandas indicates missing values; many models have problems when missing values are present, so it's always important to think about how to deal with them. In these cases, we are picking an arbitrary *signal value* that doesn't otherwise appear in the data."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for df in (joined,joined_test):\n",
" df['CompetitionOpenSinceYear'] = df.CompetitionOpenSinceYear.fillna(1900).astype(np.int32)\n",
" df['CompetitionOpenSinceMonth'] = df.CompetitionOpenSinceMonth.fillna(1).astype(np.int32)\n",
" df['Promo2SinceYear'] = df.Promo2SinceYear.fillna(1900).astype(np.int32)\n",
" df['Promo2SinceWeek'] = df.Promo2SinceWeek.fillna(1).astype(np.int32)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next we'll extract features \"CompetitionOpenSince\" and \"CompetitionDaysOpen\". Note the use of `apply()` in mapping a function across dataframe values."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for df in (joined,joined_test):\n",
" df[\"CompetitionOpenSince\"] = pd.to_datetime(dict(year=df.CompetitionOpenSinceYear, \n",
" month=df.CompetitionOpenSinceMonth, day=15))\n",
" df[\"CompetitionDaysOpen\"] = df.Date.subtract(df.CompetitionOpenSince).dt.days"
]
},
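{
"cell_type": "markdown",
"metadata": {},
"source": [
"A toy sketch of the same two steps on made-up rows; the second row uses the 1900 placeholder filled in earlier:\n",
"\n",
"```python\n",
"import pandas as pd\n",
"\n",
"toy = pd.DataFrame({\n",
"    'Date': pd.to_datetime(['2015-07-31', '2015-07-31']),\n",
"    'CompetitionOpenSinceYear': [2008, 1900],  # 1900 = fill value used above\n",
"    'CompetitionOpenSinceMonth': [9, 1],\n",
"})\n",
"\n",
"# Assemble dates from year/month columns; the scalar day=15 is broadcast\n",
"toy['CompetitionOpenSince'] = pd.to_datetime(dict(\n",
"    year=toy.CompetitionOpenSinceYear,\n",
"    month=toy.CompetitionOpenSinceMonth, day=15))\n",
"\n",
"# Timedelta column -> whole number of days\n",
"toy['CompetitionDaysOpen'] = toy.Date.subtract(toy.CompetitionOpenSince).dt.days\n",
"```\n",
"\n",
"The second row ends up with a duration of more than a century, which is why the next cell resets durations for opening years before 1990."
]
},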
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We'll replace some erroneous / outlying data."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for df in (joined,joined_test):\n",
" df.loc[df.CompetitionDaysOpen<0, \"CompetitionDaysOpen\"] = 0\n",
" df.loc[df.CompetitionOpenSinceYear<1990, \"CompetitionDaysOpen\"] = 0"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We add \"CompetitionMonthsOpen\" field, limiting the maximum to 2 years to limit number of unique categories."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([24, 3, 19, 9, 0, 16, 17, 7, 15, 22, 11, 13, 2, 23, 12, 4, 10, 1, 14, 20, 8, 18, 6, 21, 5])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"for df in (joined,joined_test):\n",
" df[\"CompetitionMonthsOpen\"] = df[\"CompetitionDaysOpen\"]//30\n",
" df.loc[df.CompetitionMonthsOpen>24, \"CompetitionMonthsOpen\"] = 24\n",
"joined.CompetitionMonthsOpen.unique()"
]
},
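{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `//30` plus `.loc` capping above can equivalently be written with `Series.clip`; a short sketch on made-up day counts:\n",
"\n",
"```python\n",
"import pandas as pd\n",
"\n",
"days_open = pd.Series([0, 45, 400, 2510])  # made-up day counts\n",
"months_open = days_open // 30              # approximate 30-day months\n",
"\n",
"# Same effect as df.loc[df.CompetitionMonthsOpen>24, ...] = 24\n",
"months_capped = months_open.clip(upper=24)\n",
"\n",
"print(months_capped.tolist())  # [0, 1, 13, 24]\n",
"```"
]
},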
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Same process for Promo dates."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for df in (joined,joined_test):\n",
" df[\"Promo2Since\"] = pd.to_datetime(df.apply(lambda x: Week(\n",
" x.Promo2SinceYear, x.Promo2SinceWeek).monday(), axis=1).astype(pd.datetime))\n",
" df[\"Promo2Days\"] = df.Date.subtract(df[\"Promo2Since\"]).dt.days"
]
},
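{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here `Week` is presumably the `isoweek` package's class, and `Week(year, week).monday()` returns the Monday of that ISO week. If that package isn't available, Python 3.8+'s standard library can compute the same date; a sketch, with `datetime.date.fromisocalendar` swapped in and a helper name (`iso_week_monday`) of our own:\n",
"\n",
"```python\n",
"import datetime\n",
"\n",
"def iso_week_monday(year, week):\n",
"    # Hypothetical helper: Monday (ISO weekday 1) of the given ISO year/week\n",
"    return datetime.date.fromisocalendar(year, week, 1)\n",
"\n",
"monday = iso_week_monday(2010, 13)\n",
"print(monday)  # 2010-03-29\n",
"```"
]
},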
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for df in (joined,joined_test):\n",
" df.loc[df.Promo2Days<0, \"Promo2Days\"] = 0\n",
" df.loc[df.Promo2SinceYear<1990, \"Promo2Days\"] = 0\n",
" df[\"Promo2Weeks\"] = df[\"Promo2Days\"]//7\n",
" df.loc[df.Promo2Weeks<0, \"Promo2Weeks\"] = 0\n",
" df.loc[df.Promo2Weeks>25, \"Promo2Weeks\"] = 25\n",
" df.Promo2Weeks.unique()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"joined.to_feather(f'{PATH}joined')\n",
"joined_test.to_feather(f'{PATH}joined_test')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Durations"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"It is common when working with time series data to extract data that explains relationships across rows as opposed to columns, e.g.:\n",
"* Running averages\n",
"* Time until next event\n",
"* Time since last event\n",
"\n",
"This is often difficult to do with most table manipulation frameworks, since they are designed to work with relationships across columns. As such, we've created a class to handle this type of data.\n",
"\n",
"We'll define a function `get_elapsed` for cumulative counting across a sorted dataframe. Given a particular field `fld` to monitor, this function will start tracking time since the last occurrence of that field. When the field is seen again, the counter is set to zero.\n",
"\n",
"Upon initialization, this will result in datetime na's until the field is encountered. This is reset every time a new store is seen. We'll see how to use this shortly."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def get_elapsed(fld, pre):\n",
" day1 = np.timedelta64(1, 'D')\n",
" last_date = np.datetime64()\n",
" last_store = 0\n",
" res = []\n",
"\n",
" for s,v,d in zip(df.Store.values,df[fld].values, df.Date.values):\n",
" if s != last_store:\n",
" last_date = np.datetime64()\n",
" last_store = s\n",
" if v: last_date = d\n",
" res.append(((d-last_date).astype('timedelta64[D]') / day1))\n",
" df[pre+fld] = res"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We'll be applying this to a subset of columns:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"columns = [\"Date\", \"Store\", \"Promo\", \"StateHoliday\", \"SchoolHoliday\"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#df = train[columns]\n",
"df = train[columns].append(test[columns])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's walk through an example.\n",
"\n",
"Say we're looking at School Holiday. We'll first sort by Store, then Date, and then call `add_elapsed('SchoolHoliday', 'After')`:\n",
"This will apply to each row with School Holiday:\n",
"* A applied to every row of the dataframe in order of store and date\n",
"* Will add to the dataframe the days since seeing a School Holiday\n",
"* If we sort in the other direction, this will count the days until another holiday."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fld = 'SchoolHoliday'\n",
"df = df.sort_values(['Store', 'Date'])\n",
"get_elapsed(fld, 'After')\n",
"df = df.sort_values(['Store', 'Date'], ascending=[True, False])\n",
"get_elapsed(fld, 'Before')"
]
},
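{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see what `get_elapsed` produces, here is a self-contained sketch on a made-up toy frame. `elapsed_days` is a hypothetical variant of `get_elapsed` that takes the dataframe as an argument and returns the values instead of writing to the global `df`; it also skips the intermediate cast to whole days, which makes no difference for date-only values:\n",
"\n",
"```python\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"def elapsed_days(df, fld):\n",
"    # Hypothetical variant of get_elapsed: takes df explicitly and returns\n",
"    # a list; assumes df is sorted by Store, then Date (either direction)\n",
"    day1 = np.timedelta64(1, 'D')\n",
"    last_date, last_store, res = np.datetime64('NaT'), None, []\n",
"    for s, v, d in zip(df.Store.values, df[fld].values, df.Date.values):\n",
"        if s != last_store:        # new store: forget the last event\n",
"            last_date, last_store = np.datetime64('NaT'), s\n",
"        if v:                      # event on this date: restart at 0\n",
"            last_date = d\n",
"        res.append((d - last_date) / day1)  # NaN until the first event\n",
"    return res\n",
"\n",
"toy = pd.DataFrame({\n",
"    'Store': [1, 1, 1, 1, 2, 2],\n",
"    'Date': pd.to_datetime(['2015-01-01', '2015-01-02', '2015-01-03',\n",
"                            '2015-01-04', '2015-01-01', '2015-01-02']),\n",
"    'SchoolHoliday': [0, 1, 0, 0, 0, 0],\n",
"})\n",
"toy = toy.sort_values(['Store', 'Date'])\n",
"toy['AfterSchoolHoliday'] = elapsed_days(toy, 'SchoolHoliday')\n",
"toy = toy.sort_values(['Store', 'Date'], ascending=[True, False])\n",
"toy['BeforeSchoolHoliday'] = elapsed_days(toy, 'SchoolHoliday')\n",
"```\n",
"\n",
"For store 1, with a school holiday on 2015-01-02, the `After` column reads NaN, 0, 1, 2 in date order, and the `Before` column is -1 on 2015-01-01 and 0 on the holiday itself; every other value is NaN because no later holiday exists for that store, and store 2 has no holiday at all."
]
},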
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We'll do this for two more fields."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fld = 'StateHoliday'\n",
"df = df.sort_values(['Store', 'Date'])\n",
"get_elapsed(fld, 'After')\n",
"df = df.sort_values(['Store', 'Date'], ascending=[True, False])\n",
"get_elapsed(fld, 'Before')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fld = 'Promo'\n",
"df = df.sort_values(['Store', 'Date'])\n",
"get_elapsed(fld, 'After')\n",
"df = df.sort_values(['Store', 'Date'], ascending=[True, False])\n",
"get_elapsed(fld, 'Before')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We're going to set the active index to Date."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"df = df.set_index(\"Date\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Then set null values from elapsed field calculations to 0."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"columns = ['SchoolHoliday', 'StateHoliday', 'Promo']"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for o in ['Before', 'After']:\n",
" for p in columns:\n",
" a = o+p\n",
" df[a] = df[a].fillna(0).astype(int)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next we'll demonstrate window functions in pandas to calculate rolling quantities.\n",
"\n",
"Here we're sorting by date (`sort_index()`) and counting the number of events of interest (`sum()`) defined in `columns` in the following week (`rolling()`), grouped by Store (`groupby()`). We do the same in the opposite direction."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"bwd = df[['Store']+columns].sort_index().groupby(\"Store\").rolling(7, min_periods=1).sum()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fwd = df[['Store']+columns].sort_index(ascending=False\n",
" ).groupby(\"Store\").rolling(7, min_periods=1).sum()"
]
},
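{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of the same grouped window on made-up data, using a 3-row window instead of 7 so the effect is visible on a few rows. Selecting the `Promo` column after `groupby` (our choice for the sketch) keeps the grouping key out of the summed columns:\n",
"\n",
"```python\n",
"import pandas as pd\n",
"\n",
"# Made-up daily rows for two stores, indexed by date like df above\n",
"toy = pd.DataFrame(\n",
"    {'Store': [1, 1, 1, 2, 2], 'Promo': [1, 0, 1, 1, 1]},\n",
"    index=pd.to_datetime(['2015-01-01', '2015-01-02', '2015-01-03',\n",
"                          '2015-01-01', '2015-01-02']))\n",
"\n",
"# Per store: sum of Promo over the current row and up to 2 previous rows\n",
"bwd_toy = (toy.sort_index()\n",
"           .groupby('Store')['Promo']\n",
"           .rolling(3, min_periods=1)\n",
"           .sum())\n",
"\n",
"print(bwd_toy.loc[1].tolist())  # [1.0, 1.0, 2.0]\n",
"print(bwd_toy.loc[2].tolist())  # [1.0, 2.0]\n",
"```\n",
"\n",
"Each store's window restarts at its own first row, and `min_periods=1` returns partial sums there rather than NaN."
]
},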
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next we want to drop the Store indices grouped together in the window function.\n",
"\n",
"Often in pandas, there is an option to do this in place. This is time and memory efficient when working with large datasets."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"bwd.drop('Store',1,inplace=True)\n",
"bwd.reset_index(inplace=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fwd.drop('Store',1,inplace=True)\n",
"fwd.reset_index(inplace=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"df.reset_index(inplace=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we'll merge these values onto the df."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"df = df.merge(bwd, 'left', ['Date', 'Store'], suffixes=['', '_bw'])\n",
"df = df.merge(fwd, 'left', ['Date', 'Store'], suffixes=['', '_fw'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"df.drop(columns,1,inplace=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" Date | \n",
" Store | \n",
" AfterSchoolHoliday | \n",
" BeforeSchoolHoliday | \n",
" AfterStateHoliday | \n",
" BeforeStateHoliday | \n",
" AfterPromo | \n",
" BeforePromo | \n",
" SchoolHoliday_bw | \n",
" StateHoliday_bw | \n",
" Promo_bw | \n",
" SchoolHoliday_fw | \n",
" StateHoliday_fw | \n",
" Promo_fw | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 2015-09-17 | \n",
" 1 | \n",
" 13 | \n",
" -9223372036854775808 | \n",
" -9223372036854775808 | \n",
" -9223372036854775808 | \n",
" 0 | \n",
" 0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 4.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 1.0 | \n",
"
\n",
" \n",
" 1 | \n",
" 2015-09-16 | \n",
" 1 | \n",
" 12 | \n",
" -9223372036854775808 | \n",
" -9223372036854775808 | \n",
" -9223372036854775808 | \n",
" 0 | \n",
" 0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 3.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 2.0 | \n",
"
\n",
" \n",
" 2 | \n",
" 2015-09-15 | \n",
" 1 | \n",
" 11 | \n",
" -9223372036854775808 | \n",
" -9223372036854775808 | \n",
" -9223372036854775808 | \n",
" 0 | \n",
" 0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 2.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 3.0 | \n",
"
\n",
" \n",
" 3 | \n",
" 2015-09-14 | \n",
" 1 | \n",
" 10 | \n",
" -9223372036854775808 | \n",
" -9223372036854775808 | \n",
" -9223372036854775808 | \n",
" 0 | \n",
" 0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 1.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 4.0 | \n",
"
\n",
" \n",
" 4 | \n",
" 2015-09-13 | \n",
" 1 | \n",
" 9 | \n",
" -9223372036854775808 | \n",
" -9223372036854775808 | \n",
" -9223372036854775808 | \n",
" 9 | \n",
" -1 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 4.0 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" Date Store AfterSchoolHoliday BeforeSchoolHoliday \\\n",
"0 2015-09-17 1 13 -9223372036854775808 \n",
"1 2015-09-16 1 12 -9223372036854775808 \n",
"2 2015-09-15 1 11 -9223372036854775808 \n",
"3 2015-09-14 1 10 -9223372036854775808 \n",
"4 2015-09-13 1 9 -9223372036854775808 \n",
"\n",
" AfterStateHoliday BeforeStateHoliday AfterPromo BeforePromo \\\n",
"0 -9223372036854775808 -9223372036854775808 0 0 \n",
"1 -9223372036854775808 -9223372036854775808 0 0 \n",
"2 -9223372036854775808 -9223372036854775808 0 0 \n",
"3 -9223372036854775808 -9223372036854775808 0 0 \n",
"4 -9223372036854775808 -9223372036854775808 9 -1 \n",
"\n",
" SchoolHoliday_bw StateHoliday_bw Promo_bw SchoolHoliday_fw \\\n",
"0 0.0 0.0 4.0 0.0 \n",
"1 0.0 0.0 3.0 0.0 \n",
"2 0.0 0.0 2.0 0.0 \n",
"3 0.0 0.0 1.0 0.0 \n",
"4 0.0 0.0 0.0 0.0 \n",
"\n",
" StateHoliday_fw Promo_fw \n",
"0 0.0 1.0 \n",
"1 0.0 2.0 \n",
"2 0.0 3.0 \n",
"3 0.0 4.0 \n",
"4 0.0 4.0 "
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"It's usually a good idea to back up large tables of extracted / wrangled features before you join them onto another one, that way you can go back to it easily if you need to make changes to it."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"df.to_feather(f'{PATH}df')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/jhoward/anaconda3/lib/python3.6/site-packages/numpy/lib/arraysetops.py:463: FutureWarning: elementwise comparison failed; returning scalar instead, but in the future will perform elementwise comparison\n",
" mask |= (ar1 == a)\n"
]
}
],
"source": [
"df = pd.read_feather(f'{PATH}df')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"df[\"Date\"] = pd.to_datetime(df.Date)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Index(['Date', 'Store', 'AfterSchoolHoliday', 'BeforeSchoolHoliday',\n",
" 'AfterStateHoliday', 'BeforeStateHoliday', 'AfterPromo', 'BeforePromo',\n",
" 'SchoolHoliday_bw', 'StateHoliday_bw', 'Promo_bw', 'SchoolHoliday_fw',\n",
" 'StateHoliday_fw', 'Promo_fw'],\n",
" dtype='object')"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.columns"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"joined = join_df(joined, df, ['Store', 'Date'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"joined_test = join_df(joined_test, df, ['Store', 'Date'])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The authors also removed all instances where the store had zero sale / was closed. We speculate that this may have cost them a higher standing in the competition. One reason this may be the case is that a little exploratory data analysis reveals that there are often periods where stores are closed, typically for refurbishment. Before and after these periods, there are naturally spikes in sales that one might expect. By ommitting this data from their training, the authors gave up the ability to leverage information about these periods to predict this otherwise volatile behavior."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"joined = joined[joined.Sales!=0]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We'll back this up as well."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"joined.reset_index(inplace=True)\n",
"joined_test.reset_index(inplace=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"joined.to_feather(f'{PATH}joined')\n",
"joined_test.to_feather(f'{PATH}joined_test')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We now have our final set of engineered features.\n",
"\n",
"While these steps were explicitly outlined in the paper, these are all fairly typical feature engineering steps for dealing with time series data and are practical in any similar setting."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Create features"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"joined = pd.read_feather(f'{PATH}joined')\n",
"joined_test = pd.read_feather(f'{PATH}joined_test')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" 0 | \n",
" 1 | \n",
" 2 | \n",
" 3 | \n",
" 4 | \n",
"
\n",
" \n",
" \n",
" \n",
" index | \n",
" 0 | \n",
" 1 | \n",
" 2 | \n",
" 3 | \n",
" 4 | \n",
"
\n",
" \n",
" Store | \n",
" 1 | \n",
" 2 | \n",
" 3 | \n",
" 4 | \n",
" 5 | \n",
"
\n",
" \n",
" DayOfWeek | \n",
" 5 | \n",
" 5 | \n",
" 5 | \n",
" 5 | \n",
" 5 | \n",
"
\n",
" \n",
" Date | \n",
" 2015-07-31 00:00:00 | \n",
" 2015-07-31 00:00:00 | \n",
" 2015-07-31 00:00:00 | \n",
" 2015-07-31 00:00:00 | \n",
" 2015-07-31 00:00:00 | \n",
"
\n",
" \n",
" Sales | \n",
" 5263 | \n",
" 6064 | \n",
" 8314 | \n",
" 13995 | \n",
" 4822 | \n",
"
\n",
" \n",
" Customers | \n",
" 555 | \n",
" 625 | \n",
" 821 | \n",
" 1498 | \n",
" 559 | \n",
"
\n",
" \n",
" Open | \n",
" 1 | \n",
" 1 | \n",
" 1 | \n",
" 1 | \n",
" 1 | \n",
"
\n",
" \n",
" Promo | \n",
" 1 | \n",
" 1 | \n",
" 1 | \n",
" 1 | \n",
" 1 | \n",
"
\n",
" \n",
" StateHoliday | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
"
\n",
" \n",
" SchoolHoliday | \n",
" 1 | \n",
" 1 | \n",
" 1 | \n",
" 1 | \n",
" 1 | \n",
"
\n",
" \n",
" Year | \n",
" 2015 | \n",
" 2015 | \n",
" 2015 | \n",
" 2015 | \n",
" 2015 | \n",
"
\n",
" \n",
" Month | \n",
" 7 | \n",
" 7 | \n",
" 7 | \n",
" 7 | \n",
" 7 | \n",
"
\n",
" \n",
" Week | \n",
" 31 | \n",
" 31 | \n",
" 31 | \n",
" 31 | \n",
" 31 | \n",
"
\n",
" \n",
" Day | \n",
" 31 | \n",
" 31 | \n",
" 31 | \n",
" 31 | \n",
" 31 | \n",
"
\n",
" \n",
" Dayofweek | \n",
" 4 | \n",
" 4 | \n",
" 4 | \n",
" 4 | \n",
" 4 | \n",
"
\n",
" \n",
" Dayofyear | \n",
" 212 | \n",
" 212 | \n",
" 212 | \n",
" 212 | \n",
" 212 | \n",
"
\n",
" \n",
" Is_month_end | \n",
" True | \n",
" True | \n",
" True | \n",
" True | \n",
" True | \n",
"
\n",
" \n",
" Is_month_start | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
"
\n",
" \n",
" Is_quarter_end | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
"
\n",
" \n",
" Is_quarter_start | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
"
\n",
" \n",
" Is_year_end | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
"
\n",
" \n",
" Is_year_start | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
" False | \n",
"
\n",
" \n",
" Elapsed | \n",
" 1438300800 | \n",
" 1438300800 | \n",
" 1438300800 | \n",
" 1438300800 | \n",
" 1438300800 | \n",
"
\n",
" \n",
" StoreType | \n",
" c | \n",
" a | \n",
" a | \n",
" c | \n",
" a | \n",
"
\n",
" \n",
" Assortment | \n",
" a | \n",
" a | \n",
" a | \n",
" c | \n",
" a | \n",
"
\n",
" \n",
" CompetitionDistance | \n",
" 1270 | \n",
" 570 | \n",
" 14130 | \n",
" 620 | \n",
" 29910 | \n",
"
\n",
" \n",
" CompetitionOpenSinceMonth | \n",
" 9 | \n",
" 11 | \n",
" 12 | \n",
" 9 | \n",
" 4 | \n",
"
\n",
" \n",
" CompetitionOpenSinceYear | \n",
" 2008 | \n",
" 2007 | \n",
" 2006 | \n",
" 2009 | \n",
" 2015 | \n",
"
\n",
" \n",
" Promo2 | \n",
" 0 | \n",
" 1 | \n",
" 1 | \n",
" 0 | \n",
" 0 | \n",
"
\n",
" \n",
" Promo2SinceWeek | \n",
" 1 | \n",
" 13 | \n",
" 14 | \n",
" 1 | \n",
" 1 | \n",
"
\n",
" \n",
" Promo2SinceYear | \n",
" 1900 | \n",
" 2010 | \n",
" 2011 | \n",
" 1900 | \n",
" 1900 | \n",
"
\n",
" \n",
" PromoInterval | \n",
" NaN | \n",
" Jan,Apr,Jul,Oct | \n",
" Jan,Apr,Jul,Oct | \n",
" NaN | \n",
" NaN | \n",
"
\n",
" \n",
" State | \n",
" HE | \n",
" TH | \n",
" NW | \n",
" BE | \n",
" SN | \n",
"
\n",
" \n",
" file | \n",
" Rossmann_DE_HE | \n",
" Rossmann_DE_TH | \n",
" Rossmann_DE_NW | \n",
" Rossmann_DE_BE | \n",
" Rossmann_DE_SN | \n",
"
\n",
" \n",
" week | \n",
" 2015-08-02 - 2015-08-08 | \n",
" 2015-08-02 - 2015-08-08 | \n",
" 2015-08-02 - 2015-08-08 | \n",
" 2015-08-02 - 2015-08-08 | \n",
" 2015-08-02 - 2015-08-08 | \n",
"
\n",
" \n",
" trend | \n",
" 85 | \n",
" 80 | \n",
" 86 | \n",
" 74 | \n",
" 82 | \n",
"
\n",
" \n",
" file_DE | \n",
" Rossmann_DE | \n",
" Rossmann_DE | \n",
" Rossmann_DE | \n",
" Rossmann_DE | \n",
" Rossmann_DE | \n",
"
\n",
" \n",
" week_DE | \n",
" 2015-08-02 - 2015-08-08 | \n",
" 2015-08-02 - 2015-08-08 | \n",
" 2015-08-02 - 2015-08-08 | \n",
" 2015-08-02 - 2015-08-08 | \n",
" 2015-08-02 - 2015-08-08 | \n",
"
\n",
" \n",
" trend_DE | \n",
" 83 | \n",
" 83 | \n",
" 83 | \n",
" 83 | \n",
" 83 | \n",
"
\n",
" \n",
" Date_DE | \n",
" 2015-08-02 00:00:00 | \n",
" 2015-08-02 00:00:00 | \n",
" 2015-08-02 00:00:00 | \n",
" 2015-08-02 00:00:00 | \n",
" 2015-08-02 00:00:00 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" 0 1 \\\n",
"index 0 1 \n",
"Store 1 2 \n",
"DayOfWeek 5 5 \n",
"Date 2015-07-31 00:00:00 2015-07-31 00:00:00 \n",
"Sales 5263 6064 \n",
"Customers 555 625 \n",
"Open 1 1 \n",
"Promo 1 1 \n",
"StateHoliday False False \n",
"SchoolHoliday 1 1 \n",
"Year 2015 2015 \n",
"Month 7 7 \n",
"Week 31 31 \n",
"Day 31 31 \n",
"Dayofweek 4 4 \n",
"Dayofyear 212 212 \n",
"Is_month_end True True \n",
"Is_month_start False False \n",
"Is_quarter_end False False \n",
"Is_quarter_start False False \n",
"Is_year_end False False \n",
"Is_year_start False False \n",
"Elapsed 1438300800 1438300800 \n",
"StoreType c a \n",
"Assortment a a \n",
"CompetitionDistance 1270 570 \n",
"CompetitionOpenSinceMonth 9 11 \n",
"CompetitionOpenSinceYear 2008 2007 \n",
"Promo2 0 1 \n",
"Promo2SinceWeek 1 13 \n",
"Promo2SinceYear 1900 2010 \n",
"PromoInterval NaN Jan,Apr,Jul,Oct \n",
"State HE TH \n",
"file Rossmann_DE_HE Rossmann_DE_TH \n",
"week 2015-08-02 - 2015-08-08 2015-08-02 - 2015-08-08 \n",
"trend 85 80 \n",
"file_DE Rossmann_DE Rossmann_DE \n",
"week_DE 2015-08-02 - 2015-08-08 2015-08-02 - 2015-08-08 \n",
"trend_DE 83 83 \n",
"Date_DE 2015-08-02 00:00:00 2015-08-02 00:00:00 \n",
"\n",
" 2 3 \\\n",
"index 2 3 \n",
"Store 3 4 \n",
"DayOfWeek 5 5 \n",
"Date 2015-07-31 00:00:00 2015-07-31 00:00:00 \n",
"Sales 8314 13995 \n",
"Customers 821 1498 \n",
"Open 1 1 \n",
"Promo 1 1 \n",
"StateHoliday False False \n",
"SchoolHoliday 1 1 \n",
"Year 2015 2015 \n",
"Month 7 7 \n",
"Week 31 31 \n",
"Day 31 31 \n",
"Dayofweek 4 4 \n",
"Dayofyear 212 212 \n",
"Is_month_end True True \n",
"Is_month_start False False \n",
"Is_quarter_end False False \n",
"Is_quarter_start False False \n",
"Is_year_end False False \n",
"Is_year_start False False \n",
"Elapsed 1438300800 1438300800 \n",
"StoreType a c \n",
"Assortment a c \n",
"CompetitionDistance 14130 620 \n",
"CompetitionOpenSinceMonth 12 9 \n",
"CompetitionOpenSinceYear 2006 2009 \n",
"Promo2 1 0 \n",
"Promo2SinceWeek 14 1 \n",
"Promo2SinceYear 2011 1900 \n",
"PromoInterval Jan,Apr,Jul,Oct NaN \n",
"State NW BE \n",
"file Rossmann_DE_NW Rossmann_DE_BE \n",
"week 2015-08-02 - 2015-08-08 2015-08-02 - 2015-08-08 \n",
"trend 86 74 \n",
"file_DE Rossmann_DE Rossmann_DE \n",
"week_DE 2015-08-02 - 2015-08-08 2015-08-02 - 2015-08-08 \n",
"trend_DE 83 83 \n",
"Date_DE 2015-08-02 00:00:00 2015-08-02 00:00:00 \n",
"\n",
" 4 \n",
"index 4 \n",
"Store 5 \n",
"DayOfWeek 5 \n",
"Date 2015-07-31 00:00:00 \n",
"Sales 4822 \n",
"Customers 559 \n",
"Open 1 \n",
"Promo 1 \n",
"StateHoliday False \n",
"SchoolHoliday 1 \n",
"Year 2015 \n",
"Month 7 \n",
"Week 31 \n",
"Day 31 \n",
"Dayofweek 4 \n",
"Dayofyear 212 \n",
"Is_month_end True \n",
"Is_month_start False \n",
"Is_quarter_end False \n",
"Is_quarter_start False \n",
"Is_year_end False \n",
"Is_year_start False \n",
"Elapsed 1438300800 \n",
"StoreType a \n",
"Assortment a \n",
"CompetitionDistance 29910 \n",
"CompetitionOpenSinceMonth 4 \n",
"CompetitionOpenSinceYear 2015 \n",
"Promo2 0 \n",
"Promo2SinceWeek 1 \n",
"Promo2SinceYear 1900 \n",
"PromoInterval NaN \n",
"State SN \n",
"file Rossmann_DE_SN \n",
"week 2015-08-02 - 2015-08-08 \n",
"trend 82 \n",
"file_DE Rossmann_DE \n",
"week_DE 2015-08-02 - 2015-08-08 \n",
"trend_DE 83 \n",
"Date_DE 2015-08-02 00:00:00 "
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"joined.head().T.head(40)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now that we've engineered all our features, we need to convert them into inputs compatible with a neural network.\n",
"\n",
"This includes converting categorical variables into contiguous integers or one-hot encodings, normalizing continuous features to zero mean and unit variance, and so on."
]
},
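{
"cell_type": "markdown",
"metadata": {},
"source": [
"Below is a minimal pandas sketch of both conversions on a made-up toy frame. The helper name, columns and values are hypothetical. It builds the category list from the training data only and reuses it for the test data, so both frames get the same codes. Missing or unseen values map to code 0, and continuous columns are standardized with training-set statistics. This illustrates the idea behind `apply_cats` and `proc_df` used below; it is not their actual implementation:\n",
"\n",
"```python\n",
"import pandas as pd\n",
"\n",
"def encode_and_scale(train, test, cat_cols, cont_cols):\n",
"    # Hypothetical helper: integer-code categoricals, standardize continuous columns\n",
"    train, test = train.copy(), test.copy()\n",
"    for c in cat_cols:\n",
"        cats = pd.Categorical(train[c]).categories  # levels seen in training\n",
"        # pandas gives missing/unseen values code -1; adding 1 moves them to 0\n",
"        train[c] = pd.Categorical(train[c], categories=cats).codes + 1\n",
"        test[c] = pd.Categorical(test[c], categories=cats).codes + 1\n",
"    for c in cont_cols:\n",
"        # Training mean/std are applied to both frames so they share one scale\n",
"        mean, std = train[c].mean(), train[c].std()\n",
"        train[c] = (train[c] - mean) / std\n",
"        test[c] = (test[c] - mean) / std\n",
"    return train, test\n",
"\n",
"train = pd.DataFrame({'StoreType': ['a', 'c', 'a', 'b'], 'trend': [45.0, 85.0, 80.0, 74.0]})\n",
"test = pd.DataFrame({'StoreType': ['c', 'd'], 'trend': [82.0, 86.0]})\n",
"tr, te = encode_and_scale(train, test, ['StoreType'], ['trend'])\n",
"print(tr['StoreType'].tolist(), te['StoreType'].tolist())  # [1, 3, 1, 2] [3, 0]\n",
"```"
]
},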
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"844338"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"cat_vars = ['Store', 'DayOfWeek', 'Year', 'Month', 'Day', 'StateHoliday', 'CompetitionMonthsOpen',\n",
" 'Promo2Weeks', 'StoreType', 'Assortment', 'PromoInterval', 'CompetitionOpenSinceYear', 'Promo2SinceYear',\n",
" 'State', 'Week', 'Events', 'Promo_fw', 'Promo_bw', 'StateHoliday_fw', 'StateHoliday_bw',\n",
" 'SchoolHoliday_fw', 'SchoolHoliday_bw']\n",
"\n",
"contin_vars = ['CompetitionDistance', 'Max_TemperatureC', 'Mean_TemperatureC', 'Min_TemperatureC',\n",
" 'Max_Humidity', 'Mean_Humidity', 'Min_Humidity', 'Max_Wind_SpeedKm_h', \n",
" 'Mean_Wind_SpeedKm_h', 'CloudCover', 'trend', 'trend_DE',\n",
" 'AfterStateHoliday', 'BeforeStateHoliday', 'Promo', 'SchoolHoliday']\n",
"\n",
"n = len(joined); n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dep = 'Sales'\n",
"joined = joined[cat_vars+contin_vars+[dep, 'Date']].copy()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"joined_test[dep] = 0\n",
"joined_test = joined_test[cat_vars+contin_vars+[dep, 'Date', 'Id']].copy()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for v in cat_vars: joined[v] = joined[v].astype('category').cat.as_ordered()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"apply_cats(joined_test, joined)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for v in contin_vars:\n",
" joined[v] = joined[v].fillna(0).astype('float32')\n",
" joined_test[v] = joined_test[v].fillna(0).astype('float32')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We're going to run on a random sample of 150,000 rows to keep iteration fast."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"150000"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"idxs = get_cv_idxs(n, val_pct=150000/n)\n",
"joined_samp = joined.iloc[idxs].set_index(\"Date\")\n",
"samp_size = len(joined_samp); samp_size"
]
},
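{
"cell_type": "markdown",
"metadata": {},
"source": [
"`get_cv_idxs` returns a random subset of row positions. A rough numpy equivalent is sketched below; the helper name and the seed are our own choices, not the library's. Positions drawn from a permutation come out in random order, so rows selected this way are not sorted by date:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def sample_row_idxs(n, n_samp, seed=42):\n",
"    # Hypothetical helper: pick n_samp distinct row positions out of n\n",
"    rng = np.random.RandomState(seed)\n",
"    return rng.permutation(n)[:n_samp]\n",
"\n",
"idxs = sample_row_idxs(844338, 150000)  # n and sample size from the cells above\n",
"print(len(idxs), len(np.unique(idxs)))  # 150000 150000\n",
"```"
]
},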
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To run on the full dataset, use this instead:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"samp_size = n\n",
"joined_samp = joined.set_index(\"Date\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can now process our data. Here are the first two rows before processing:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" Store DayOfWeek Year Month Day StateHoliday CompetitionMonthsOpen \\\n",
"Date \n",
"2014-01-08 781 3 2014 1 8 False 24 \n",
"2014-11-22 626 6 2014 11 22 False 12 \n",
"\n",
" Promo2Weeks StoreType Assortment ... Max_Wind_SpeedKm_h \\\n",
"Date ... \n",
"2014-01-08 0 a a ... 29.0 \n",
"2014-11-22 0 c c ... 23.0 \n",
"\n",
" Mean_Wind_SpeedKm_h CloudCover trend trend_DE AfterStateHoliday \\\n",
"Date \n",
"2014-01-08 14.0 8.0 45.0 55.0 2.0 \n",
"2014-11-22 14.0 5.0 85.0 84.0 3.0 \n",
"\n",
" BeforeStateHoliday Promo SchoolHoliday Sales \n",
"Date \n",
"2014-01-08 -100.0 1.0 0.0 7395 \n",
"2014-11-22 -33.0 0.0 0.0 7884 \n",
"\n",
"[2 rows x 39 columns]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"joined_samp.head(2)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"df, y, nas, mapper = proc_df(joined_samp, 'Sales', do_scale=True)\n",
"yl = np.log(y)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"joined_test = joined_test.set_index(\"Date\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"df_test, _, nas, mapper = proc_df(joined_test, 'Sales', do_scale=True, skip_flds=['Id'],\n",
" mapper=mapper, na_dict=nas)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" Store DayOfWeek Year Month Day StateHoliday \\\n",
"Date \n",
"2014-01-08 781 3 2 1 8 1 \n",
"2014-11-22 626 6 2 11 22 1 \n",
"\n",
" CompetitionMonthsOpen Promo2Weeks StoreType Assortment \\\n",
"Date \n",
"2014-01-08 25 1 1 1 \n",
"2014-11-22 13 1 3 3 \n",
"\n",
" ... Mean_Wind_SpeedKm_h CloudCover trend \\\n",
"Date ... \n",
"2014-01-08 ... 0.367717 1.497856 -1.766166 \n",
"2014-11-22 ... 0.367717 -0.348017 1.731215 \n",
"\n",
" trend_DE AfterStateHoliday BeforeStateHoliday Promo \\\n",
"Date \n",
"2014-01-08 -1.156709 0.0 0.0 1.112686 \n",
"2014-11-22 1.830993 0.0 0.0 -0.898726 \n",
"\n",
" SchoolHoliday CompetitionDistance_na CloudCover_na \n",
"Date \n",
"2014-01-08 -0.491495 -0.051057 -0.295991 \n",
"2014-11-22 -0.491495 -0.051057 -0.295991 \n",
"\n",
"[2 rows x 40 columns]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.head(2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In time series data, the validation set should not be chosen at random. Instead, our holdout data is generally the most recent data, as it would be in a real application. This issue is discussed in detail in [this post](http://www.fast.ai/2017/11/13/validation-sets/) on our web site.\n",
"\n",
"One approach is to take the last 25% of rows (sorted by date) as our validation set."
]
},
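{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here is a sketch of this chronological split. It assumes the frame has a `DatetimeIndex`, and the helper name and toy dates are hypothetical. Ordering by date first matters. Slicing by position only yields the most recent rows if the rows are already in date order, and that need not hold after random sampling:\n",
"\n",
"```python\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"def time_split_idxs(index, train_ratio=0.75):\n",
"    # Row positions ordered from oldest to newest date\n",
"    order = np.argsort(index.values, kind='mergesort')\n",
"    n_train = int(len(order) * train_ratio)\n",
"    # Oldest rows go to training, the most recent rows to validation\n",
"    return order[:n_train], order[n_train:]\n",
"\n",
"dates = pd.to_datetime(['2014-03-01', '2013-01-05', '2015-07-31', '2014-11-22'])\n",
"frame = pd.DataFrame({'Sales': [1, 2, 3, 4]}, index=dates)\n",
"trn_idx, val_idx = time_split_idxs(frame.index)\n",
"print(val_idx.tolist())  # [2]\n",
"```"
]
},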
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train_ratio = 0.75\n",
"# train_ratio = 0.9\n",
"train_size = int(samp_size * train_ratio); train_size\n",
"val_idx = list(range(train_size, len(df)))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"An even better option for picking a validation set is to use a period of the same length as the test set. Here that is 1 August to 17 September 2014, which is the test period (1 August to 17 September 2015) shifted back by one year:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"val_idx = np.flatnonzero(\n",
" (df.index<=datetime.datetime(2014,9,17)) & (df.index>=datetime.datetime(2014,8,1)))"
]
},
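{
"cell_type": "markdown",
"metadata": {},
"source": [
"The same selection can be written as a small standalone function and applied to toy dates; the helper name and the dates are hypothetical. Both ends of the window are inclusive, matching the `<=` and `>=` comparisons above:\n",
"\n",
"```python\n",
"import datetime\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"def window_idxs(index, start, end):\n",
"    # Positions of rows dated within [start, end], inclusive at both ends\n",
"    return np.flatnonzero((index >= start) & (index <= end))\n",
"\n",
"dates = pd.to_datetime(['2014-07-31', '2014-08-01', '2014-09-17', '2014-09-18'])\n",
"val = window_idxs(dates, datetime.datetime(2014, 8, 1), datetime.datetime(2014, 9, 17))\n",
"print(val.tolist())  # [1, 2]\n",
"```"
]
},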
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"val_idx=[0]  # overrides the split above: a one-row validation set, so (almost) all rows are used for training"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## DL"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
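"We're ready to put together our models.\n",
"\n",
"Root mean squared percentage error (RMSPE) is the metric Kaggle used for this competition. Because the model is trained on log sales, `exp_rmspe` below exponentiates both predictions and targets before computing it."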
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def inv_y(a): return np.exp(a)\n",
"\n",
"def exp_rmspe(y_pred, targ):\n",
" targ = inv_y(targ)\n",
" pct_var = (targ - inv_y(y_pred))/targ\n",
" return math.sqrt((pct_var**2).mean())\n",
"\n",
"max_log_y = np.max(yl)\n",
"y_range = (0, max_log_y*1.2)"
]
},
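{
"cell_type": "markdown",
"metadata": {},
"source": [
"RMSPE is the square root of the mean of ((actual - predicted) / actual) squared. Below is a sketch on raw values with toy numbers; the `rmspe` name is our own. It also checks that exponentiating log-space values first, as `exp_rmspe` does, gives the same result:\n",
"\n",
"```python\n",
"import math\n",
"import numpy as np\n",
"\n",
"def rmspe(y_pred, targ):\n",
"    # Root mean squared percentage error on untransformed values\n",
"    pct_var = (targ - y_pred) / targ\n",
"    return math.sqrt((pct_var ** 2).mean())\n",
"\n",
"targ = np.array([100.0, 200.0])\n",
"pred = np.array([110.0, 180.0])\n",
"print(round(rmspe(pred, targ), 6))  # 0.1\n",
"# Going through log space and back, as exp_rmspe does, gives the same value\n",
"print(round(rmspe(np.exp(np.log(pred)), np.exp(np.log(targ))), 6))  # 0.1\n",
"```\n",
"\n",
"Because each error is divided by the actual value, rows with zero sales cannot be scored this way (Kaggle's metric ignores them). In any case, `np.log` of zero is `-inf`."
]
},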
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can create a ModelData object directly from our data frame."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"md = ColumnarModelData.from_data_frame(PATH, val_idx, df, yl.astype(np.float32), cat_flds=cat_vars, bs=128,\n",
" test_df=df_test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Some categorical variables have a lot more levels than others. Store, in particular, has over a thousand!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"cat_sz = [(c, len(joined_samp[c].cat.categories)+1) for c in cat_vars]  # +1 leaves room for code 0 (missing or unseen)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[('Store', 1116),\n",
" ('DayOfWeek', 8),\n",
" ('Year', 4),\n",
" ('Month', 13),\n",
" ('Day', 32),\n",
" ('StateHoliday', 3),\n",
" ('CompetitionMonthsOpen', 26),\n",
" ('Promo2Weeks', 27),\n",
" ('StoreType', 5),\n",
" ('Assortment', 4),\n",
" ('PromoInterval', 4),\n",
" ('CompetitionOpenSinceYear', 24),\n",
" ('Promo2SinceYear', 9),\n",
" ('State', 13),\n",
" ('Week', 53),\n",
" ('Events', 22),\n",
" ('Promo_fw', 7),\n",
" ('Promo_bw', 7),\n",
" ('StateHoliday_fw', 4),\n",
" ('StateHoliday_bw', 4),\n",
" ('SchoolHoliday_fw', 9),\n",
" ('SchoolHoliday_bw', 9)]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"cat_sz"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We use the *cardinality* of each variable (that is, its number of unique values) to decide how large to make its *embeddings*. Each level is associated with a vector whose length is half the variable's size in `cat_sz` (rounded up), capped at 50:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"emb_szs = [(c, min(50, (c+1)//2)) for _,c in cat_sz]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[(1116, 50),\n",
" (8, 4),\n",
" (4, 2),\n",
" (13, 7),\n",
" (32, 16),\n",
" (3, 2),\n",
" (26, 13),\n",
" (27, 14),\n",
" (5, 3),\n",
" (4, 2),\n",
" (4, 2),\n",
" (24, 12),\n",
" (9, 5),\n",
" (13, 7),\n",
" (53, 27),\n",
" (22, 11),\n",
" (7, 4),\n",
" (7, 4),\n",
" (4, 2),\n",
" (4, 2),\n",
" (9, 5),\n",
" (9, 5)]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"emb_szs"
]
},
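{
"cell_type": "markdown",
"metadata": {},
"source": [
"An embedding is a learnable lookup table with one row per level, and an integer code selects a row. Below is a numpy sketch for `Store`. It uses the Store size from `cat_sz` above and the same sizing rule as `emb_szs`, with random weights standing in for learned ones:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"n_levels = 1116                          # size of Store in cat_sz, including code 0\n",
"emb_dim = min(50, (n_levels + 1) // 2)   # same rule as emb_szs\n",
"rng = np.random.RandomState(0)\n",
"weights = rng.normal(0, 0.01, size=(n_levels, emb_dim))  # learned during training in practice\n",
"\n",
"store_codes = np.array([781, 626, 0])    # integer codes like those in df above\n",
"vectors = weights[store_codes]           # one row per code\n",
"print(emb_dim, vectors.shape)  # 50 (3, 50)\n",
"```"
]
},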
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
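"m = md.get_learner(emb_szs, len(df.columns)-len(cat_vars),  # embedding sizes, number of continuous inputs\n",
" 0.04, 1, [1000,500], [0.001,0.01], y_range=y_range)  # embedding dropout, output size, hidden layer sizes, hidden layer dropouts\n",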
"m.summary()"
]
},
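{
"cell_type": "markdown",
"metadata": {},
"source": [
"Conceptually, the learner concatenates the embedding vectors with the continuous inputs and feeds them through fully connected layers. The numpy forward pass below is a simplified sketch of that idea, using random weights, toy sizes and hypothetical names. It omits dropout and batch norm. It also maps the output into `y_range` with a scaled sigmoid, which is our understanding of how fastai 0.7 uses `y_range`; check the library source before relying on it:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"rng = np.random.RandomState(0)\n",
"\n",
"def dense(n_in, n_out):\n",
"    # Random weights and zero bias for one fully connected layer\n",
"    return rng.normal(0, 0.05, size=(n_in, n_out)), np.zeros(n_out)\n",
"\n",
"def forward(cat_codes, conts, emb_tables, layers, y_range):\n",
"    # One embedding lookup per categorical column, concatenated with the continuous inputs\n",
"    embs = [tbl[cat_codes[:, i]] for i, tbl in enumerate(emb_tables)]\n",
"    x = np.concatenate(embs + [conts], axis=1)\n",
"    for w, b in layers[:-1]:\n",
"        x = np.maximum(x @ w + b, 0)  # ReLU hidden layers\n",
"    w, b = layers[-1]\n",
"    out = 1 / (1 + np.exp(-(x @ w + b)))  # sigmoid, in (0, 1)\n",
"    lo, hi = y_range\n",
"    return lo + out * (hi - lo)  # rescaled into y_range\n",
"\n",
"emb_tables = [rng.normal(0, 0.01, size=(1116, 50)),  # Store\n",
"              rng.normal(0, 0.01, size=(8, 4))]      # DayOfWeek\n",
"n_cont = 18  # number of continuous inputs (toy value)\n",
"layers = [dense(50 + 4 + n_cont, 1000), dense(1000, 500), dense(500, 1)]\n",
"cat_codes = np.array([[781, 3], [626, 6]])\n",
"conts = rng.normal(size=(2, n_cont))\n",
"preds = forward(cat_codes, conts, emb_tables, layers, (0, 12.0))\n",
"print(preds.shape)  # (2, 1)\n",
"```"
]
},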
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" 70%|██████▉ | 611/879 [00:06<00:02, 110.07it/s, loss=0.234] \n",
" \r"
]
}
],
"source": [
"lr = 1e-3\n",
"m.lr_find()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZEAAAEOCAYAAABIESrBAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3Xl8XXWd//HX52ZPszZJm+4tpXSDtrShpUXZN9GhjCsV\nEURFHZcZZ36O+JuZnw7ODI6z/EZxwYoI6k8QQbAisslugTYtLbSULpTuS5KmS5o0283n98c9rZc0\ny81tbk6W9/PxuI+ce7b76SHcd77ne873mLsjIiKSjEjYBYiIyMClEBERkaQpREREJGkKERERSZpC\nREREkqYQERGRpClEREQkaQoRERFJmkJERESSphAREZGkpYddQG8qLS31iRMnhl2GiMiAsWrVqhp3\nL0t2+0EVIhMnTqSysjLsMkREBgwz234q2+t0loiIJC1lLREzuwt4H1Dl7md2sPwrwHVxdUwHyty9\n1sy2AXVAFGh194pU1SkiIslLZUvkbuDKzha6+3+4+xx3nwN8DXjO3WvjVrkoWK4AERHpp1IWIu7+\nPFDb7YoxS4B7U1WLiIikRuh9ImaWS6zF8mDcbAeeMLNVZnZzOJWJiEh3+sPVWX8B/Kndqazz3H2P\nmY0AnjSzN4OWzUmCkLkZYPz48amvVkRETgi9JQJcS7tTWe6+J/hZBTwEzO9sY3df6u4V7l5RVtbz\nS52jbc4Lm6t5c9+RHm8rIjLUhRoiZlYIXAD8Nm7eMDPLPz4NXA6sS1kNwGd+vor7VuxM1UeIiAxa\nqbzE917gQqDUzHYBXwcyANz9jmC1vwSecPf6uE1HAg+Z2fH6funuj6WqzkjEmDIyn4376lL1ESIi\ng1bKQsTdlySwzt3ELgWOn7cVmJ2aqjo2bWQ+T27Yj7sThJeIiCSgP/SJhG5qeT619c3UHG0OuxQR\nkQFFIQJMK88H0CktEZEeUogQa4kAukJLRKSHFCJASV4WpXlZaomIiPSQQiQwrTyfjfsVIiIiPaEQ\nCUwtz2fT/jqibR52KSIiA4ZCJDC1PJ/GljZ21DaEXYqIyIChEAn8+Qotda6LiCRKIRKYMiIfM3hT\nnesiIglTiARyMtOYWDJMV2iJiPSAQiTOVI2hJSLSIwqROFPL89l2oJ7GlmjYpYiIDAgKkTjTyvNp\nc9i8/2jYpYiIDAgKkTga/kREpGcUInEmlAwjKz2ifhERkQQpROKkRYwpI/M0/ImISIIUIu1MLy9g\nw94juGv4ExGR7ihE2pk5uoCao81U1TWFXYqISL+nEGln5phCANbvORxyJSIi/Z9CpJ3jY2it360r\ntEREuqMQaSc/O4OJJbms36MQERHpTspCxMzuMrMqM1vXyfILzeywma0JXv8nbtmVZrbRzLaY2S2p\nqrEzM0cXsn6vTmeJiHQnlS2Ru4Eru1nnBXefE7xuBTCzNOD7wHuAGcASM5uRwjpPMmN0ATtrj3H4\nWEtffqyIyICTshBx9+eB2iQ2nQ9scfet7t4M3Acs7tXiujFzdAEAG/bqlJaISFfC7hNZaGZrzewP\nZjYzmDcG2Bm3zq5gXp+ZEYSI+kVERLqWHuJnrwYmuPtRM7sKeBiYAlgH63Z655+Z3QzcDDB+/Phe\nKWxEfjZl+Vm6zFdEpBuhtUTc/Yi7Hw2mHwUyzKyUWMtjXNyqY4E9XexnqbtXuHtFWVlZr9U3c3QB\nb6glIiLSpdBCxMzKzcyC6flBLQeAlcAUM5tkZpnAtcCyvq5v5ugCNlcd1bNFRES6kLLTWWZ2L3Ah\nUGpmu4CvAxkA7n4H8EHgc2bWChwDrvXYgFWtZvYF4HEgDbjL3denqs7OzBxdSLTN2bS/jllji/r6\n40VEBoSUhYi7L+lm+feA73Wy7FHg0VTUlagZo2Kd6+t2H1GIiIh0Iuyrs/qtCSW5FGSn8/ruQ2GX\nIiLSbylEOmFmzB5XxJqdukJLRKQzCpEuzB5b
xKb9dRxrVue6iEhHFCJdmD2uiGib634REZFOKES6\nMHts7Nkia3aqX0REpCMKkS6MKMhmVGE2a3epJSIi0hGFSDdmjy3itV1qiYiIdEQh0o3Z44rYfqCB\ng/XNYZciItLvKES6caJfRK0REZGTKES6MWtcERGDV3coRERE2lOIdCMvK52p5QWs3n4w7FJERPod\nhUgC5k0o4tUdB4m2dfpYExGRIUkhkoB5E4qpb46ycV9d2KWIiPQrCpEEVEwYDsCqHTqlJSISTyGS\ngLHFOZTlZ6lfRESkHYVIAsyMeeOLWaUQERF5B4VIguZNKGZHbQNVdY1hlyIi0m8oRBI0d0IxAKu3\n634REZHjFCIJOnNMAZlpEVZtrw27FBGRfkMhkqCs9DTOGluofhERkTgKkR6omFDM67sP09iiJx2K\niEAKQ8TM7jKzKjNb18ny68zsteC13Mxmxy3bZmavm9kaM6tMVY09NX/ScFqizmrdLyIiAqS2JXI3\ncGUXy98GLnD3WcA3gaXtll/k7nPcvSJF9fVYxcThmMErW9UvIiICkJ6qHbv782Y2sYvly+PevgyM\nTVUtvaUwJ4MZowp45e0DYZciItIv9Jc+kU8Cf4h778ATZrbKzG7uakMzu9nMKs2ssrq6OqVFAiyY\nVMKrOw7R1Kp+ERGR0EPEzC4iFiJfjZt9nrvPBd4DfN7Mzu9se3df6u4V7l5RVlaW4mphwWnDaWpt\nY+1OPXddRCTUEDGzWcCdwGJ3P3GOyN33BD+rgIeA+eFUeLIFk473i+iUlohIaCFiZuOB3wDXu/um\nuPnDzCz/+DRwOdDhFV5hKMrNZOrIfF55W53rIiIp61g3s3uBC4FSM9sFfB3IAHD3O4D/A5QAPzAz\ngNbgSqyRwEPBvHTgl+7+WKrqTMa5p5Xwq5U7aYm2kZEW+hlBEZHQpPLqrCXdLP8U8KkO5m8FZp+8\nRf+xYNJw7l6+jdd2HWZeMKaWiMhQpD+jkzB/UuwhVbrUV0SGOoVIEkryspgyIk83HYrIkKcQSdKC\n04ZTua2W1mhb2KWIiIRGIZKkc08rob45yuu7db+IiAxdCpEkLZpcihm8sLkm7FJEREKjEEnS8GGZ\nnDWmkBc2p36oFRGR/kohcgrePaWU1TsOcaSxJexSRERCoRA5BedPKSPa5izfokt9RWRoUoicgrPH\nFzMsM02ntERkyFKInILM9AgLJ5fy/OZq3D3sckRE+pxC5BRdcEYpO2uP8XZNfdiliIj0OYXIKbpw\n6ggA/rihKuRKRET6nkLkFI0bnsu08nye3LA/7FJERPqcQqQXXDZjJJXbajlY3xx2KSIifUoh0gsu\nmzGSNodnNuqUlogMLQqRXnDm6EJGFmTxlE5picgQoxDpBZGIccn0kTy3sZqm1mjY5YiI9BmFSC+5\nbPpI6pujvPSW7l4XkaFDIdJLFk4uIScjTae0RGRIUYj0kuyMNM4/o5Sn3qjS3esiMmSkNETM7C4z\nqzKzdZ0sNzP7rpltMbPXzGxu3LIbzGxz8LohlXX2lkunj2TfkUbW7zkSdikiIn0i1S2Ru4Eru1j+\nHmBK8LoZ+CGAmQ0Hvg4sAOYDXzez4pRW2gsunjaCtIjx6Ot7wy5FRKRPpDRE3P15oLaLVRYDP/OY\nl4EiMxsFXAE86e617n4QeJKuw6hfKMnL4rzTS/ntmj20temUlogMfmH3iYwBdsa93xXM62x+v3fN\nnNHsPnSM1TsOhl2KiEjKhR0i1sE872L+yTswu9nMKs2ssro6/Od6XD6znOyMCA+v2R12KSIiKRd2\niOwCxsW9Hwvs6WL+Sdx9qbtXuHtFWVlZygpNVF5WOpdOH8nvX9tLS7Qt7HJERFIq7BBZBnw8uErr\nXOCwu+8FHgcuN7PioEP98mDegHDNnDEcbGjREw9FZNBLT+XOzexe4EKg1Mx2EbviKgPA3e8AHgWu\nArYADcAn
gmW1ZvZNYGWwq1vdvasO+n7l/DPKKMrN4OFX93DxtJFhlyMikjIpDRF3X9LNcgc+38my\nu4C7UlFXqmWmR7jqrFE8tHo39U2tDMtK6WEWEQlNQqezzOyvzawgOO30EzNbbWaXp7q4geyaOWM4\n1hLlyTc0DIqIDF6J9onc5O5HiPVNlBE77fStlFU1CFRMKGZMUY6u0hKRQS3REDl+ye1VwE/dfS0d\nX4YrgUjEuHrOaF7YXEPN0aawyxERSYlEQ2SVmT1BLEQeN7N8QNevdmPxnNFE25zfv6ZhUERkcEo0\nRD4J3AKc4+4NxK6w+kTKqhokppUXMK08n9/qlJaIDFKJhshCYKO7HzKzjwH/CBxOXVmDx+I5Y1i9\n4xA7DjSEXYqISK9LNER+CDSY2Wzg74HtwM9SVtUg8hezRwGoNSIig1KiIdIa3NOxGPiOu38HyE9d\nWYPH2OJc5k8czsNrduthVSIy6CQaInVm9jXgeuD3ZpZGcOe5dG/x2aN5q7peD6sSkUEn0RD5CNBE\n7H6RfcSGZf+PlFU1yFx15igy0kyntERk0EkoRILg+H9AoZm9D2h0d/WJJKh4WCYXnFHGsrV7aNXI\nviIyiCQ67MmHgRXAh4APA6+Y2QdTWdhg88F549h/pInnNbKviAwiiY4M+A/E7hGpAjCzMuAp4IFU\nFTbYXDJ9BKV5mdy3YqdG9hWRQSPRPpHI8QAJHOjBtgJkpEX4wLyx/PHNKqrqGsMuR0SkVyQaBI+Z\n2eNmdqOZ3Qj8ntizQKQHPlIxjmib8+AqdbCLyOCQaMf6V4ClwCxgNrDU3b+aysIGo9PK8lgwaTj3\nrdxBW5vuGRGRgS/hU1Lu/qC7/627f9ndH0plUYPZ9QsnsP1AA89uqup+ZRGRfq7LEDGzOjM70sGr\nzsx051wSrphZTnlBNj/907awSxEROWVdhoi757t7QQevfHcv6KsiB5OMtAjXL5zAC5tr2FJVF3Y5\nIiKnRFdYheDac8aRmR7hnuXbwy5FROSUKERCUJKXxeLZo3lw9S4OH2sJuxwRkaSlNETM7Eoz22hm\nW8zslg6W/18zWxO8NpnZobhl0bhly1JZZxhuWDSRhuYov67cGXYpIiJJS/SO9R4LRvr9PnAZsAtY\naWbL3P2N4+u4+5fj1v8icHbcLo65+5xU1Re2M8cUMn/icO55aRufOG8SaRE9sl5EBp5UtkTmA1vc\nfau7NwP3EXseSWeWAPemsJ5+56Z3TWRn7TEeeW1P2KWIiCQllSEyBog/V7MrmHcSM5sATAKejpud\nbWaVZvaymV2TujLDc/mMcqaV5/OdpzZrdF8RGZBSGSIdnZ/p7Dbta4EH3D0aN2+8u1cAHwX+x8wm\nd/ghZjcHYVNZXT2wRsiNRIy/uXQKW2vqWbZWrRERGXhSGSK7gHFx78cCnX1TXku7U1nuvif4uRV4\nlnf2l8Svt9TdK9y9oqys7FRr7nOXzyhnxqgCvvtHtUZEZOBJZYisBKaY2SQzyyQWFCddZWVmU4Fi\n4KW4ecVmlhVMlwLnAW+033YwON4a2XaggYde1cCMIjKwpCxE3L0V+ALwOLABuN/d15vZrWZ2ddyq\nS4D73D3+VNd0oNLM1gLPAN+Kv6prsLlsxkjOHFPA7U9voUWtEREZQOyd390DW0VFhVdWVoZdRlL+\nuGE/n7ynktvefxZL5o8PuxwRGSLMbFXQ/5wU3bHeT1w8bQRnjy/iu3/cTGNLtPsNRET6AYVIP2Fm\nfOWKqew93MgvXtaYWiIyMChE+pFFk0t595RSvv/MFuoaNaaWiPR/CpF+5u+vmMbBhhbufOHtsEsR\nEemWQqSfOWtsIVedVc6dL2zlwNGmsMsREemSQqQf+tvLpnKsJcrtT28JuxQRkS4pRPqh00fk8dEF\n4/nZS9t4Y4+eQiwi/ZdCpJ/6yuXTKM7N5J9+u462tsFzL4+IDC4KkX6qMD
eDW94zjVXbD/LAql1h\nlyMi0iGFSD/2gbljqZhQzLcee5NDDc1hlyMichKFSD8WiRjfvOZMDh9r4duPbwy7HBGRkyhE+rnp\nowq4cdFE7l2xgzU7D3W/gYhIH1KIDAB/c+kURuZn85Vfr9W4WiLSryhEBoD87Az+40Oz2Fx1lG8/\nptNaItJ/KEQGiHdPKePGRRO5609v86ctNWGXIyICKEQGlK9eOY3Tyobxd/evpbZeV2uJSPgUIgNI\nTmYa3/nI2dQ2NPOFX67WM9lFJHQKkQHmrLGF/Os1Z7L8rQN86w9vhl2OiAxx6WEXID33oYpxrNt9\nmDtffJuzxhayeM6YsEsSkSFKLZEB6h/fN4P5E4fz1QdfY93uw2GXIyJDlEJkgMpIi/D96+ZSnJvJ\nZ36+Sh3tIhKKlIaImV1pZhvNbIuZ3dLB8hvNrNrM1gSvT8Utu8HMNgevG1JZ50BVlp/FHR+bR/XR\nJnW0i0goUhYiZpYGfB94DzADWGJmMzpY9VfuPid43RlsOxz4OrAAmA983cyKU1XrQDZ7XNGJjvZ/\n+u163DVsvIj0nVS2ROYDW9x9q7s3A/cBixPc9grgSXevdfeDwJPAlSmqc8D7UMU4PnfhZO5dsYPv\n6WmIItKHUhkiY4Cdce93BfPa+4CZvWZmD5jZuB5uK4G/v2Iq7587hv96chP3V+7sfgMRkV6QyhCx\nDua1P9fyO2Ciu88CngLu6cG2sRXNbjazSjOrrK6uTrrYgc7M+PcPzOLdU0r52m9e55mNVWGXJCJD\nQCpDZBcwLu79WGBP/ArufsDdm4K3PwbmJbpt3D6WunuFu1eUlZX1SuEDVUZahB9+bB7TyvP5q1+s\nZq2GjheRFEtliKwEppjZJDPLBK4FlsWvYGaj4t5eDWwIph8HLjez4qBD/fJgnnQjLyudn37iHEry\nMrnp7pVsq6kPuyQRGcRSFiLu3gp8gdiX/wbgfndfb2a3mtnVwWpfMrP1ZrYW+BJwY7BtLfBNYkG0\nErg1mCcJGJGfzT03zafNnRt+uoKao03dbyQikgQbTJeEVlRUeGVlZdhl9Burdxzkoz9+mTNG5nPv\np89lWJZGuRGRdzKzVe5ekez2umN9EJs7vpjbl8xl3e7DfPTOV6iuU4tERHqXQmSQu2zGSH74sXls\n3HeEa77/Jzbvrwu7JBEZRBQiQ8AVM8u5/zMLaY628f4fLOfFzXoyooj0DoXIEDFrbBEPf/48Rhfl\ncONPV3Dfih1hlyQig4BCZAgZU5TDA59byKLTS7nlN69z2x820NY2eC6sEJG+pxAZYvKzM7jrhgqu\nWzCeHz23lc//cjWNLdGwyxKRAUohMgSlp0X4l2vO5B/fO53H1u/jI0tfpupIY9hlicgApBAZosyM\nT737NO742Dw27avjvbe/SOU23c8pIj2jEBnirphZzkOfX8SwzDSuXfoy9yzfpmeSiEjCFCLCtPIC\nfvuFd3HBGWV8fdl6/u7+teonEZGEKEQEgMKcDH788Qq+fOkZ/ObV3Xz8Jyv03HYR6ZZCRE6IRIy/\nvnQKty85mzU7D3Hpfz/Hc5uG7jNaRKR7ChE5yV/MHs2yL57HiPwsbvzpCr6xbD2HG1rCLktE+iGF\niHRoWnkBv/mrRVy3YDw/e2kbF/7nM/z85e20RtvCLk1E+hGFiHQqNzOdf7nmLB754ruZWp7PPz28\njvfd/iLLt2jsLRGJUYhIt2aMLuDeT5/LD6+by9GmVj565yt89uer2HGgIezSRCRkekqRJMTMeM9Z\no7ho2gjufGEr33/mLZ7eWMWn3jWJT7/7NIqHZYZdooiEQE82lKTsO9zItx97k9+8upvsjAjXnzuB\nL1w8hcKcjLBLE5EeONUnGypE5JRs3FfHj55/i4de3U1eZjofPmccH184gQklw8IuTUQSoBCJoxAJ\nz7rdh1n6/FYefX0vUXcumTaCT5w3iU
WTSzCzsMsTkU4oROIoRMK3/0gjv3h5O798ZQcH6puZMiKP\njy+ayMXTRjCmKCfs8kSknX4dImZ2JfAdIA24092/1W753wKfAlqBauAmd98eLIsCrwer7nD3q7v7\nPIVI/9HYEuWR1/by0z+9zfo9RzCDq84cxUcXjOfc00pIi6h1ItIf9NsQMbM0YBNwGbALWAkscfc3\n4ta5CHjF3RvM7HPAhe7+kWDZUXfP68lnKkT6H3dnw946HnltDz9/eTt1ja2MLMji6tmjuXb+eCaX\n9eg/sYj0slMNkVRe4jsf2OLuWwHM7D5gMXAiRNz9mbj1XwY+lsJ6JARmxozRBcwYXcCXLpnCHzdU\n8dCru7l7+TbufPFtrphRzqfPP42zxxURUetEZMBJZYiMAXbGvd8FLOhi/U8Cf4h7n21mlcROdX3L\n3R/u/RKlL2VnpPHeWaN476xR1Bxt4p7l27hn+TYeW7+PwpwMLp0+kvfPHaPTXSIDSCpDpKNvgQ7P\nnZnZx4AK4IK42ePdfY+ZnQY8bWavu/tbHWx7M3AzwPjx40+9aukTpXlZ/N3lU/nMBZN5fN0+/vRW\nDU+s38eDq3dRXpDN4jmjWTxnDNNH5evqLpF+LJV9IguBb7j7FcH7rwG4+23t1rsUuB24wN2rOtnX\n3cAj7v5AV5+pPpGBrbElylMb9vPQ6t08t6ma1jZndGE2F00bwSXTR7DwtFJyMtPCLlNkUOnPHevp\nxDrWLwF2E+tY/6i7r49b52zgAeBKd98cN78YaHD3JjMrBV4CFsd3yndEITJ4HDjaxJNv7OfpN6t4\ncUsNDc1RstIjLJpcwsXTR+qSYZFe0m9DBMDMrgL+h9glvne5+7+a2a1ApbsvM7OngLOAvcEmO9z9\najNbBPwIaCM2SOT/uPtPuvs8hcjg1NQa5ZWttTz9ZhXPbKxiezDw49SR+Vw8fQQXTxvB2eOKSE/T\neKIiPdWvQ6SvKUQGP3dna009T2+o4uk3q1i5rZbWNqcwJ4MLp5Zx8bQRXHBGGUW5GhBShobWaBtp\nEUu671AhEkchMvQcaWzhhU01PP1mFc9urOJAfTMRg3kTirloWqyVMnWkOudl8Pre05t5blM1P7tp\nQVJ9hgqROAqRoa2tzVm76xDPvFnF0xurWLf7CABjinK4aFqslaLOeRksGluiPL5+H3//wGtcNHUE\nd1w/L6n9KETiKEQk3v4jjbFAieucz0yLMHdCEYsml7JocgmzxxWRob4UGWAamlu5/icrWLX9IJNK\nh/GLTy1I+kIThUgchYh05njn/Itbalj+Vg3r9xzBHXIz0zhn4nAWTS7hvNNLmT6qQDc6Sr/V2BLl\nv57YyG/X7KHmaBO3vf8s3j937Cn9IdSfhz0R6Tey0tM4/4wyzj+jDIBDDc28vPUAy9+KvW77w5sA\nFOZkcO5pw7lw6ggumzGS0rysMMsWAWIPgXt8/T6eeGMff9pygPPPKOO6BeO5YmZ52KWpJSICUHWk\nkZe2HmD5lgO8uKWG3YeOETGYPqqAueOLWTi5hEWTS3TVl/S5dbsP87GfvMKhhhZyMtL43++dzvXn\nTui1/et0VhyFiPSG4yMPP/HGPlZuq2XNjkPUN0cxg1ljCnnXlFLedXoZcycUkZWuTnpJjWPNUb73\nzGZ+tnw7BTkZ/OTGCk4rzSMzvXf78BQicRQikgot0TbW7jzEC5treHFLDWt2HiLa5mRnRKiYMJx5\nE4qZN6GYOeOLKMjWM+bl1LVE2/j0zyp5dmM1l80YyS3vmZayxyYoROIoRKQv1DW2nOikf3nrATbt\nr6PNwSx2F/3cCcXMHR8LloklubpHRRLm7vxq5U5+s3o3K7bVctv7z2LJ/NQOLKsQiaMQkTDUNbaw\ndudhVm0/yKodB3l1x0HqGlsBGD4s80SgzJtQzKyxhWRn6BSYnMzd+bdHN/DjF96mNC+TL192Btct\n6L
2+j87o6iyRkOVnZ8T6SaaUArGbHrdUH42FyvaDrN5+kKc27AcgPWLMHF3A2eOLOX1EHpNKhzGx\ndBijCrL1UK4h7KW3DvDPv1vPm/vquGHhBL5x9cwB04JVS0SkD9TWN/PqjoMnguW1XYc51hI9sTwr\nPcKEklwmlgxjQkkuo4tyGF2Uw7jiXMYOz1FfyyDU1ua8sfcIz22q5kfPvUXxsEw+e8FkPlwxrk/v\nVVJLRGQAGD4sk0umj+SS6SOB2BfIviONbKup5+0D9bGfNQ1srannuU3VNLW2vWP7vKx0RhdlM6ow\nh9K8LErzMxlbnMvYohwKcjIYVZjNiPwsjWTcz7RG2zja1EpdY+y1v66RFW/XsnFfHSu31Z447Tm5\nbBh33nAOk0qHhVxxzylEREIQidiJ1sai00vfsczdqa1vZvehY+w6eIydtQ3sPdzInkPH2Hekkc37\n66g52kxz9J1BEzEYPiyL0rxMyvKzKMvLoqwgixH5sYAZWRD7OaIgi9xM/a/fmcaWKPVNrew6eIyW\naBstUae1re3EdLTNaYm2EW1zWtuc1qhT39RKXVMr1XWNbKtp4GBDMztrG6hvjp60/7SIMaEkl/fN\nGs2M0QVcOn0E5QXZA+b0VXv6TRLpZ8yMkrwsSvKymDW2qMN1jrdk9h1p5HBDC3sPN7L38DFqjjZR\nXddM9dEmtlbXU13XdFLYAAzLTKMkLxY4JXlZlAzLpDAng/zsdLIz0qhvipKTGSEnM53cjDRyM9PI\nyUwjNzM9bjqN3Ix0cjLTqGtsASAzPUJORlrKWkRtbU7N0SZ2HmxgZ+0xdh1sYPehYzS1tOFAtK2D\nL/m2NlqjTsSMrIwImUFt9c2t1DdFOdYcpb65lWPNUeqaWmluPfl4Jao4N4NJpcMYXZTDwsklFOdm\nkpeVTn527FWQncGc8UWDKsQHz79EZAiJb8l0xd051NBCVV0TVXWN7D/SxP4jjRw42syB+qbYF3Jt\nA2t2HqKusYXGluS/QONlpkXIzoiQmxkLmZyMtBM/szPSyEqPkBYx2jz2l3zsL/24n9E2Wtv8HdO1\n9c0caWyhfTduaV7miZGZ0yOx/aZHjPQ0Iy0SIT1ipEWM1rY26uv/HBLDgi/38oLsE8GYF3zR52Sk\nMbY4h6yMNDLSjIy0CBlpkRP7TY+8czo3K428zPQheXGEQkRkEDMziodlUjwsk6nl+d2u39Qa+8s8\nLyudxtY2GoK/0BuCV2y6lWMt8fNayc1MJy1iNLVGaWxpo6E5SmPL8XXbONYc5VhLKw3NrdQcbTrR\nUogc/8Js2annAAAJnUlEQVSPRMhIM9KDL+rczPQTX9DH5+dnpzM8N5MRBVmMK85l3PAcxhTlamj/\nkClEROSErPS0E0O55KVFyMvSV4R0TZdyiIhI0hQiIiKSNIWIiIgkLaUhYmZXmtlGM9tiZrd0sDzL\nzH4VLH/FzCbGLftaMH+jmV2RyjpFRCQ5KQsRM0sDvg+8B5gBLDGzGe1W+yRw0N1PB/4v8O/BtjOA\na4GZwJXAD4L9iYhIP5LKlsh8YIu7b3X3ZuA+YHG7dRYD9wTTDwCXWOy2zcXAfe7e5O5vA1uC/YmI\nSD+SyhAZA+yMe78rmNfhOu7eChwGShLcVkREQpbKEOno1s32QwZ3tk4i28Z2YHazmVWaWWV1dXUP\nSxQRkVORyjuJdgHj4t6PBfZ0ss4uM0sHCoHaBLcFwN2XAksBzKzazLZ3Uk8hsZZOV7pbp7PlPZnf\nfl4pUNNNXb0tkWPRm9snun5X6/V02WA99sns41R/91Nx7KHvj/9APPZdLe+t751Te/KVu6fkRSyg\ntgKTgExgLTCz3TqfB+4Ipq8F7g+mZwbrZwXbbwXSTrGepae6TmfLezK//TygMlX/DU7lWPTm9omu\n39V6PV02WI99qo5/Xx/7MI7/QDz2XS3vL987KWuJuHurmX0BeBxI
A+5y9/VmdmvwD1gG/AT4uZlt\nIdYCuTbYdr2Z3Q+8AbQCn3f3k8dU7pnf9cI6nS3vyfxE6ki1U62hp9snun5X6/V02WA99sns41R/\n93Xsk9/HoP/eGVRPNhyIzKzST+GpYpI8Hftw6fiHpzePve5YD9/SsAsYwnTsw6XjH55eO/ZqiYiI\nSNLUEhERkaQpREREJGkKERERSZpCpJ8zs2FmtsrM3hd2LUOJmU03szvM7AEz+1zY9QwlZnaNmf3Y\nzH5rZpeHXc9QY2anmdlPzOyBRNZXiKSImd1lZlVmtq7d/C6Hx+/AV4H7U1Pl4NQbx97dN7j7Z4EP\nA7oMNUG9dOwfdvdPAzcCH0lhuYNOLx3/re7+yYQ/U1dnpYaZnQ8cBX7m7mcG89KATcBlxIZ2WQks\nIXYz5m3tdnETMIvY8ATZQI27P9I31Q9svXHs3b3KzK4GbgG+5+6/7Kv6B7LeOvbBdv8F/D93X91H\n5Q94vXz8H3D3D3b3makcO2tIc/fn4x+yFTgxPD6Amd0HLHb324CTTleZ2UXAMGLPYzlmZo+6e1tK\nCx8EeuPYB/tZBiwzs98DCpEE9NLvvQHfAv6gAOmZ3vrd7wmFSN/qaIj7BZ2t7O7/AGBmNxJriShA\nktejY29mFwLvJzZ+26MprWzw69GxB74IXAoUmtnp7n5HKosbAnr6u18C/Ctwtpl9LQibTilE+lbC\nQ9y/YwX3u3u/lCGnR8fe3Z8Fnk1VMUNMT4/9d4Hvpq6cIaenx/8A8NlEd66O9b6V8BD30ut07MOj\nYx+ulB5/hUjfWglMMbNJZpZJbNTiZSHXNFTo2IdHxz5cKT3+CpEUMbN7gZeAqWa2y8w+6bFHAB8f\nHn8DseenrA+zzsFIxz48OvbhCuP46xJfERFJmloiIiKSNIWIiIgkTSEiIiJJU4iIiEjSFCIiIpI0\nhYiIiCRNISKhMbOjffAZVyc45H5vfuaFZrYoie3ONrM7g+kbzex7vV9dz5nZxPZDi3ewTpmZPdZX\nNUn/oRCRAS8Y6rpD7r7M3b+Vgs/saty5C4Eehwjwv4HbkyooZO5eDew1s/PCrkX6lkJE+gUz+4qZ\nrTSz18zsn+PmPxw82XG9md0cN/+omd1qZq8AC81sm5n9s5mtNrPXzWxasN6Jv+jN7G4z+66ZLTez\nrWb2wWB+xMx+EHzGI2b26PFl7Wp81sz+zcyeA/7azP7CzF4xs1fN7CkzGxkMw/1Z4MtmtsbM3h38\nlf5g8O9b2dEXrZnlA7PcfW0HyyaY2R+DY/NHMxsfzJ9sZi8H+7y1o5adxZ6M+XszW2tm68zsI8H8\nc4LjsNbMVphZftDieCE4hqs7ak2ZWZqZ/Ufcf6vPxC1+GLiuw//AMni5u156hfICjgY/LweWEhtt\nNAI8ApwfLBse/MwB1gElwXsHPhy3r23AF4PpvwLuDKZvJPZQKYC7gV8HnzGD2DMWAD5IbLj3CFAO\nHAQ+2EG9zwI/iHtfzJ9HffgU8F/B9DeA/xW33i+BdwXT44ENHez7IuDBuPfxdf8OuCGYvgl4OJh+\nBFgSTH/2+PFst98PAD+Oe18IZAJbgXOCeQXERvTOBbKDeVOAymB6IrAumL4Z+MdgOguoBCYF78cA\nr4f9e6VX3740FLz0B5cHr1eD93nEvsSeB75kZn8ZzB8XzD8ARIEH2+3nN8HPVcSeBdKRhz32XJY3\nzGxkMO9dwK+D+fvM7Jkuav1V3PRY4FdmNorYF/PbnWxzKTDD7MSI3AVmlu/udXHrjAKqO9l+Ydy/\n5+fAt+PmXxNM/xL4zw62fR34TzP7d+ARd3/BzM4C9rr7SgB3PwKxVgvwPTObQ+z4ntHB/i4HZsW1\n1AqJ/Td5G6gCRnfyb5BBSiEi/YEBt7n7j94xM/ZgqEuBhe7eYGbPEntUMECju0fb7acp+Bml89/t\nprhpa/czEfVx07cD/+3uy4Ja
v9HJNhFi/4ZjXez3GH/+t3Un4QHv3H2Tmc0DrgJuM7MniJ126mgf\nXwb2A7ODmhs7WMeItfge72BZNrF/hwwh6hOR/uBx4CYzywMwszFmNoLYX7kHgwCZBpybos9/EfhA\n0DcykljHeCIKgd3B9A1x8+uA/Lj3TxAbRRWA4C/99jYAp3fyOcuJDd8NsT6HF4Ppl4mdriJu+TuY\n2Wigwd1/QaylMhd4ExhtZucE6+QHFwoUEmuhtAHXE3sGd3uPA58zs4xg2zOCFgzEWi5dXsUlg49C\nRELn7k8QOx3zkpm9DjxA7Ev4MSDdzF4DvknsSzMVHiT24J51wI+AV4DDCWz3DeDXZvYCUBM3/3fA\nXx7vWAe+BFQEHdFv0MFT49z9TWKPg81vvyzY/hPBcbge+Otg/t8Af2tmK4idDuuo5rOAFWa2BvgH\n4F/cvRn4CHC7ma0FniTWivgBcIOZvUwsEOo72N+dwBvA6uCy3x/x51bfRcDvO9hGBjENBS8CmFme\nux+12POlVwDnufu+Pq7hy0Cdu9+Z4Pq5wDF3dzO7llgn++KUFtl1Pc8Di939YFg1SN9Tn4hIzCNm\nVkSsg/ybfR0ggR8CH+rB+vOIdYQbcIjYlVuhMLMyYv1DCpAhRi0RERFJmvpEREQkaQoRERFJmkJE\nRESSphAREZGkKURERCRpChEREUna/wdBgdlE7WADOAAAAABJRU5ErkJggg==\n",
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"m.sched.plot(100)"
]
},
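{
"cell_type": "markdown",
"metadata": {},
"source": [
"`lr_find` performs a learning rate range test. It trains while increasing the learning rate exponentially from batch to batch, records the loss, and stops once the loss diverges. That is presumably why the progress bar above ends at 70%. A common choice from the plot is a rate where the loss is still falling steeply. The geometric schedule can be sketched as follows; the start and end rates here are illustrative and not necessarily the library defaults:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def lr_schedule(start_lr, end_lr, n_steps):\n",
"    # Learning rates spaced geometrically from start_lr to end_lr\n",
"    ratio = end_lr / start_lr\n",
"    return start_lr * ratio ** (np.arange(n_steps) / (n_steps - 1))\n",
"\n",
"lrs = lr_schedule(1e-5, 10.0, 879)  # 879 batches per epoch, as in the progress bar above\n",
"print(lrs[:3])\n",
"```"
]
},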
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Sample"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"m = md.get_learner(emb_szs, len(df.columns)-len(cat_vars),\n",
" 0.04, 1, [1000,500], [0.001,0.01], y_range=y_range)\n",
"lr = 1e-3"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2127e0eaa44d4189b569de2c6544e6a5",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"A Jupyter Widget"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[ 0. 0.02479 0.02205 0.19309] \n",
"[ 1. 0.02044 0.01751 0.18301] \n",
"[ 2. 0.01598 0.01571 0.17248] \n",
"\n"
]
}
],
"source": [
"m.fit(lr, 3, metrics=[exp_rmspe])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e03e0e29fbdb4098b8297a41f50a9054",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"A Jupyter Widget"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[ 0. 0.01258 0.01278 0.16 ] \n",
"[ 1. 0.01147 0.01214 0.15758] \n",
"[ 2. 0.01157 0.01157 0.15585] \n",
"[ 3. 0.00984 0.01124 0.15251] \n",
"[ 4. 0.00946 0.01094 0.15197] \n",
"\n"
]
}
],
"source": [
"m.fit(lr, 5, metrics=[exp_rmspe], cycle_len=1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "1fe2c401969c45a39dc02c2c7c86291d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"A Jupyter Widget"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[ 0. 0.01179 0.01242 0.15512] \n",
"[ 1. 0.00921 0.01098 0.15003] \n",
"[ 2. 0.00771 0.01031 0.14431] \n",
"[ 3. 0.00632 0.01016 0.14358] \n",
"[ 4. 0.01003 0.01305 0.16574] \n",
"[ 5. 0.00827 0.01087 0.14937] \n",
"[ 6. 0.00628 0.01025 0.14506] \n",
"[ 7. 0.0053 0.01 0.14449] \n",
"\n"
]
}
],
"source": [
"m.fit(lr, 2, metrics=[exp_rmspe], cycle_len=4)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### All"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"m = md.get_learner(emb_szs, len(df.columns)-len(cat_vars),\n",
" 0.04, 1, [1000,500], [0.001,0.01], y_range=y_range)\n",
"lr = 1e-3"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "15d0c7d5e9634030a624fb91c90b081c",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"A Jupyter Widget"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[ 0. 0.01456 0.01544 0.1148 ] \n",
"\n"
]
}
],
"source": [
"m.fit(lr, 1, metrics=[exp_rmspe])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "58bd52a856754c359eae9fa25f951dd6",
"version_major": 2,
"version_minor": 0
},
"text/html": [
"Failed to display Jupyter Widget of type HBox
.
\n",
"\n",
" If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n",
" that the widgets JavaScript is still loading. If this message persists, it\n",
" likely means that the widgets JavaScript library is either not installed or\n",
" not enabled. See the Jupyter\n",
" Widgets Documentation for setup instructions.\n",
"
\n",
"\n",
" If you're reading this message in another frontend (for example, a static\n",
" rendering on GitHub or NBViewer),\n",
" it may mean that your frontend doesn't currently support widgets.\n",
"
\n"
],
"text/plain": [
"HBox(children=(IntProgress(value=0, description='Epoch', max=3), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[ 0. 0.01418 0.02066 0.12765] \n",
"[ 1. 0.01081 0.01276 0.11221] \n",
"[ 2. 0.00976 0.01233 0.10987] \n",
"\n"
]
}
],
"source": [
"m.fit(lr, 3, metrics=[exp_rmspe])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "738899f6b0574b50be6d3f6efff9e578",
"version_major": 2,
"version_minor": 0
},
"text/html": [
"Failed to display Jupyter Widget of type HBox
.
\n",
"\n",
" If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n",
" that the widgets JavaScript is still loading. If this message persists, it\n",
" likely means that the widgets JavaScript library is either not installed or\n",
" not enabled. See the Jupyter\n",
" Widgets Documentation for setup instructions.\n",
"
\n",
"\n",
" If you're reading this message in another frontend (for example, a static\n",
" rendering on GitHub or NBViewer),\n",
" it may mean that your frontend doesn't currently support widgets.\n",
"
\n"
],
"text/plain": [
"HBox(children=(IntProgress(value=0, description='Epoch', max=3), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[ 0. 0.00801 0.01081 0.09899] \n",
"[ 1. 0.00714 0.01083 0.09846] \n",
"[ 2. 0.00707 0.01088 0.09878] \n",
"\n"
]
}
],
"source": [
"m.fit(lr, 3, metrics=[exp_rmspe], cycle_len=1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Test"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"m = md.get_learner(emb_szs, len(df.columns)-len(cat_vars),\n",
" 0.04, 1, [1000,500], [0.001,0.01], y_range=y_range)\n",
"lr = 1e-3"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "74827eff47794b5eb8d09982d30202fe",
"version_major": 2,
"version_minor": 0
},
"text/html": [
"Failed to display Jupyter Widget of type HBox
.
\n",
"\n",
" If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n",
" that the widgets JavaScript is still loading. If this message persists, it\n",
" likely means that the widgets JavaScript library is either not installed or\n",
" not enabled. See the Jupyter\n",
" Widgets Documentation for setup instructions.\n",
"
\n",
"\n",
" If you're reading this message in another frontend (for example, a static\n",
" rendering on GitHub or NBViewer),\n",
" it may mean that your frontend doesn't currently support widgets.\n",
"
\n"
],
"text/plain": [
"HBox(children=(IntProgress(value=0, description='Epoch', max=3), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[ 0. 0.01413 0.0063 0.07628] \n",
"[ 1. 0.01022 0.00859 0.08851] \n",
"[ 2. 0.00932 0.00001 0.00243] \n",
"\n"
]
}
],
"source": [
"m.fit(lr, 3, metrics=[exp_rmspe])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "07f0b0fea31c4afa8c6770085d233230",
"version_major": 2,
"version_minor": 0
},
"text/html": [
"Failed to display Jupyter Widget of type HBox
.
\n",
"\n",
" If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n",
" that the widgets JavaScript is still loading. If this message persists, it\n",
" likely means that the widgets JavaScript library is either not installed or\n",
" not enabled. See the Jupyter\n",
" Widgets Documentation for setup instructions.\n",
"
\n",
"\n",
" If you're reading this message in another frontend (for example, a static\n",
" rendering on GitHub or NBViewer),\n",
" it may mean that your frontend doesn't currently support widgets.\n",
"
\n"
],
"text/plain": [
"HBox(children=(IntProgress(value=0, description='Epoch', max=3), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[ 0. 0.00748 0. 0.00167] \n",
"[ 1. 0.00717 0.00009 0.00947] \n",
"[ 2. 0.00643 0.00013 0.01147] \n",
"\n"
]
}
],
"source": [
"m.fit(lr, 3, metrics=[exp_rmspe], cycle_len=1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"m.save('val0')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"m.load('val0')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"x,y=m.predict_with_targs()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.01147316926177568"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"exp_rmspe(x,y)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"pred_test=m.predict(True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"pred_test = np.exp(pred_test)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"joined_test['Sales']=pred_test"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"csv_fn=f'{PATH}tmp/sub.csv'"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"joined_test[['Id','Sales']].to_csv(csv_fn, index=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"data/rossmann/tmp/sub.csv
"
],
"text/plain": [
"/home/ubuntu/fastai/courses/dl1/data/rossmann/tmp/sub.csv"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"FileLink(csv_fn)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## RF"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.ensemble import RandomForestRegressor"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"((val,trn), (y_val,y_trn)) = split_by_idx(val_idx, df.values, yl)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"m = RandomForestRegressor(n_estimators=40, max_features=0.99, min_samples_leaf=2,\n",
" n_jobs=-1, oob_score=True)\n",
"m.fit(trn, y_trn);"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(0.98086411192483902,\n",
" 0.92614447508562714,\n",
" 0.9193358549649463,\n",
" 0.11557443993375387)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"preds = m.predict(val)\n",
"m.score(trn, y_trn), m.score(val, y_val), m.oob_score_, exp_rmspe(preds, y_val)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}