{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Scikit-Learn Train-Test Split" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This notebook shows how to use `scikit-learn` to split a dataset into a training set and a test set, so that machine learning models can be validated on out-of-sample data.\n", "\n", "It uses the 2013 hourly weather observations from the `nycflights13` dataset, recorded at the weather stations (`origin`) of the three New York City airports: EWR, JFK and LGA." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Packages" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This tutorial uses:\n", "* [pandas](https://pandas.pydata.org/docs/)\n", "* [statsmodels](https://www.statsmodels.org/stable/index.html)\n", "    * [statsmodels.api](https://www.statsmodels.org/stable/api.html#statsmodels-api)\n", "* [scikit-learn](https://scikit-learn.org/stable/)\n", "    * [sklearn.model_selection.train_test_split](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.train_test_split.html)" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import statsmodels.api as sm\n", "import pandas as pd\n", "from sklearn.model_selection import train_test_split" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Reading the data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The data comes from the `Rdatasets` collection and is downloaded with the `statsmodels` function `sm.datasets.get_rdataset`."
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "<class 'pandas.core.frame.DataFrame'>\n", "RangeIndex: 26115 entries, 0 to 26114\n", "Data columns (total 15 columns):\n", " #   Column      Non-Null Count  Dtype  \n", "---  ------      --------------  -----  \n", " 0   origin      26115 non-null  object \n", " 1   year        26115 non-null  int64  \n", " 2   month       26115 non-null  int64  \n", " 3   day         26115 non-null  int64  \n", " 4   hour        26115 non-null  int64  \n", " 5   temp        26114 non-null  float64\n", " 6   dewp        26114 non-null  float64\n", " 7   humid       26114 non-null  float64\n", " 8   wind_dir    25655 non-null  float64\n", " 9   wind_speed  26111 non-null  float64\n", " 10  wind_gust   5337 non-null   float64\n", " 11  precip      26115 non-null  float64\n", " 12  pressure    23386 non-null  float64\n", " 13  visib       26115 non-null  float64\n", " 14  time_hour   26115 non-null  object \n", "dtypes: float64(9), int64(4), object(2)\n", "memory usage: 3.0+ MB\n" ] } ], "source": [ "df = sm.datasets.get_rdataset('weather', 'nycflights13').data\n", "df.info()" ] },
{ "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array(['EWR', 'JFK', 'LGA'], dtype=object)" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.origin.unique()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Fix dates" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**time_hour** holds the time of each observation as a string. Convert it to a datetime column, **observation_time**. The **year**, **month**, **day** and **hour** columns duplicate this information, so they are dropped from the dataframe together with the original **time_hour** column." ] },
{ "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "  origin   temp   dewp  humid  wind_dir  wind_speed  wind_gust  precip  \\\n", "0    EWR  39.02  26.06  59.37     270.0    10.35702        NaN     0.0   \n", "1    EWR  39.02  26.96  61.63     250.0     8.05546        NaN     0.0   \n", "2    EWR  39.02  28.04  64.43     240.0    11.50780        NaN     0.0   \n", "3    EWR  39.92  28.04  62.21     250.0    12.65858        NaN     0.0   \n", "4    EWR  39.02  28.04  64.43     260.0    12.65858        NaN     0.0   \n", "\n", "   pressure  visib    observation_time  \n", "0    1012.0   10.0 2013-01-01 01:00:00  \n", "1    1012.3   10.0 2013-01-01 02:00:00  \n", "2    1012.5   10.0 2013-01-01 03:00:00  \n", "3    1012.2   10.0 2013-01-01 04:00:00  \n", "4    1011.9   10.0 2013-01-01 05:00:00  " ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df['observation_time'] = pd.to_datetime(df.time_hour)\n", "df.drop(columns=['year', 'month', 'day', 'hour', 'time_hour'], inplace=True)\n", "df.head()" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Train-test splitting" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`train_test_split` shuffles the rows of `df` and returns two dataframes: `test_df` receives 20% of the rows (`test_size=.2`) and `train_df` the remaining 80%, here 20,892 of the 26,115 rows. No `random_state` is passed, so a different split is drawn each time the cell is run." ] },
{ "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "      origin   temp   dewp  humid  wind_dir  wind_speed  wind_gust  precip  \\\n", "9030     JFK  53.06  33.98  48.16     340.0     9.20624        NaN     0.0   \n", "22499    LGA  75.92  64.94  68.78     180.0    10.35702   18.41248     0.0   \n", "6287     EWR  78.08  55.94  46.49     160.0     8.05546        NaN     0.0   \n", "15793    JFK  48.02  33.98  58.07     310.0    13.80936        NaN     0.0   \n", "11971    JFK  64.94  39.92  39.79     300.0    10.35702   21.86482     0.0   \n", "...      ...    ...    ...    ...       ...         ...        ...     ...   \n", "3598     EWR  73.94  64.94  73.49     220.0     6.90468        NaN     0.0   \n", "4973     EWR  80.96  62.96  54.35     170.0    11.50780   19.56326     0.0   \n", "6147     EWR  64.94  46.04  50.32     290.0    16.11092   21.86482     0.0   \n", "15586    JFK  53.06  51.08  92.96      40.0     3.45234        NaN     0.0   \n", "9050     JFK  39.02  24.98  56.77     350.0     9.20624        NaN     0.0   \n", "\n", "       pressure  visib    observation_time  \n", "9030     1021.6   10.0 2013-01-14 17:00:00  \n", "22499    1014.3   10.0 2013-08-01 09:00:00  \n", "6287     1017.0   10.0 2013-09-20 13:00:00  \n", "15793    1007.6   10.0 2013-10-23 22:00:00  \n", "11971    1018.3   10.0 2013-05-17 10:00:00  \n", "...         ...    ...                 ...  \n", "3598     1019.0   10.0 2013-05-31 04:00:00  \n", "4973     1016.6   10.0 2013-07-27 13:00:00  \n", "6147     1017.3   10.0 2013-09-14 17:00:00  \n", "15586    1023.8   10.0 2013-10-15 07:00:00  \n", "9050     1025.3   10.0 2013-01-15 13:00:00  \n", "\n", "[20892 rows x 11 columns]" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_df, test_df = train_test_split(df, test_size=.2)\n", "train_df" ] },
{ "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Train: ['JFK' 'LGA' 'EWR']\n", "Test: ['LGA' 'JFK' 'EWR']\n", "Train: 2013-01-01 01:00:00 2013-12-30 18:00:00\n", "Test: 2013-01-01 02:00:00 2013-12-30 18:00:00\n" ] } ], "source": [ "print(\"Train:\", train_df.origin.unique())\n", "print(\"Test:\", test_df.origin.unique())\n", "print(\"Train:\", train_df.observation_time.min(), train_df.observation_time.max())\n", "print(\"Test:\", test_df.observation_time.min(), test_df.observation_time.max())" ] },
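{ "cell_type": "markdown", "metadata": {}, "source": [ "Both sets contain observations from all three stations and from across the whole of 2013, because the rows are sampled at random. Hourly weather readings are usually strongly correlated with the neighbouring hours, so testing a model on hours that sit between its training hours can make it look better than it would on genuinely unseen data.\n", "\n", "The cell below is a minimal sketch of two alternatives that are not part of the original walkthrough. The first repeats the random split with a fixed `random_state` (the seed `0` is arbitrary) and `stratify=df.origin`, so the split is reproducible and each station keeps its share of rows. The second sorts by `observation_time` and passes `shuffle=False`, so the earliest 80% of the rows form the training set and the latest 20% the test set. The variable names `strat_train_df`, `strat_test_df`, `df_sorted`, `train_ts` and `test_ts` are illustrative." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Reproducible random split: a fixed (arbitrary) seed, stratified by station\n", "strat_train_df, strat_test_df = train_test_split(\n", "    df, test_size=.2, random_state=0, stratify=df.origin)\n", "print(strat_train_df.origin.value_counts(normalize=True))\n", "print(strat_test_df.origin.value_counts(normalize=True))\n", "\n", "# Chronological split: sort by time and disable shuffling, so the\n", "# test set holds only the latest 20% of the observations\n", "df_sorted = df.sort_values('observation_time')\n", "train_ts, test_ts = train_test_split(df_sorted, test_size=.2, shuffle=False)\n", "print(\"Train:\", train_ts.observation_time.min(), train_ts.observation_time.max())\n", "print(\"Test:\", test_ts.observation_time.min(), test_ts.observation_time.max())" ] },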
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.10" } }, "nbformat": 4, "nbformat_minor": 4 }