{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "name": "Transfer Learning on Specific Counties.ipynb", "provenance": [], "collapsed_sections": [], "toc_visible": true }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "accelerator": "GPU" }, "cells": [ { "cell_type": "markdown", "metadata": { "id": "vHa7we1aSUVt", "colab_type": "text" }, "source": [ "## Examining Transfer Effects" ] }, { "cell_type": "code", "metadata": { "id": "T63Nnp4iSSQD", "colab_type": "code", "colab": {} }, "source": [ "import os\n", "import pandas as pd\n", "from google.colab import auth\n", "from datetime import datetime\n", "auth.authenticate_user()\n", "!gcloud source repos clone github_aistream-peelout_flow-forecast --project=gmap-997\n", "os.chdir('/content/github_aistream-peelout_flow-forecast')\n", "!git checkout -t origin/covid_fixes\n", "!python setup.py develop\n", "!pip install -r requirements.txt\n", "!mkdir data\n", "from flood_forecast.trainer import train_function\n", "!pip install git+https://github.com/CoronaWhy/task-geo.git\n", "!wandb login" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "uWwx1RxjX1k2", "colab_type": "code", "colab": {} }, "source": [ "# Pretrained solar data\n", "!mkdir weights\n", "!gsutil cp -r gs://coronaviruspublicdata/pretrained/model_save weights/" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "bjevk3exSpp_", "colab_type": "code", "colab": {} }, "source": [ "def make_config_file(file_path, df_len, weight_path=None):\n", " run = wandb.init(project=\"pretrain-counties\")\n", " wandb_config = wandb.config\n", " train_number = df_len * .7\n", " validation_number = df_len *.9\n", " config_default={ \n", " \"model_name\": \"MultiAttnHeadSimple\",\n", " \"model_type\": \"PyTorch\",\n", " \"model_params\": {\n", " \"number_time_series\":3,\n", " \"seq_len\":wandb_config[\"forecast_history\"], \n", " 
\"output_seq_len\":wandb_config[\"out_seq_length\"],\n", " \"forecast_length\":wandb_config[\"out_seq_length\"]\n", " },\n", " \"weight_path_add\":{\n", " \"excluded_layers\":[\"last_layer.weight\", \"last_layer.bias\"]\n", " },\n", " \"dataset_params\":\n", " { \"class\": \"default\",\n", " \"training_path\": file_path,\n", " \"validation_path\": file_path,\n", " \"test_path\": file_path,\n", " \"batch_size\":wandb_config[\"batch_size\"],\n", " \"forecast_history\":wandb_config[\"forecast_history\"],\n", " \"forecast_length\":wandb_config[\"out_seq_length\"],\n", " \"train_end\": int(train_number),\n", " \"valid_start\":int(train_number+1),\n", " \"valid_end\": int(validation_number),\n", " \"target_col\": [\"new_cases\"],\n", " \"relevant_cols\": [\"new_cases\", \"month\", \"weekday\"],\n", " \"scaler\": \"StandardScaler\", \n", " \"interpolate\": False\n", " },\n", " \"training_params\":\n", " {\n", " \"criterion\":\"MSE\",\n", " \"optimizer\": \"Adam\",\n", " \"optim_params\":\n", " {\n", "\n", " },\n", " \"lr\": wandb_config[\"lr\"],\n", " \"epochs\": 10,\n", " \"batch_size\":wandb_config[\"batch_size\"]\n", " \n", " },\n", " \"GCS\": False,\n", " \n", " \"sweep\":True,\n", " \"wandb\":False,\n", " \"forward_params\":{},\n", " \"metrics\":[\"MSE\"],\n", " \"inference_params\":\n", " { \n", " \"datetime_start\":\"2020-04-21\",\n", " \"hours_to_forecast\":10, \n", " \"test_csv_path\":file_path,\n", " \"decoder_params\":{\n", " \"decoder_function\": \"simple_decode\", \n", " \"unsqueeze_dim\": 1\n", " },\n", " \"dataset_params\":{\n", " \"file_path\": file_path,\n", " \"forecast_history\":wandb_config[\"forecast_history\"],\n", " \"forecast_length\":wandb_config[\"out_seq_length\"],\n", " \"relevant_cols\": [\"new_cases\", \"month\", \"weekday\"],\n", " \"target_col\": [\"new_cases\"],\n", " \"scaling\": \"StandardScaler\",\n", " \"interpolate_param\": False\n", " }\n", " }\n", " }\n", " if weight_path: \n", " config_default[\"weight_path\"] = weight_path\n", " 
wandb.config.update(config_default)\n", " return config_default\n", "\n", "# W&B grid sweep over batch size, learning rate, and input/output sequence lengths\n", "sweep_config = {\n", " \"name\": \"Default sweep\",\n", " \"method\": \"grid\",\n", " \"parameters\": {\n", " \"batch_size\": {\n", " \"values\": [2, 3]\n", " },\n", " \"lr\":{\n", " \"values\":[0.001, 0.01]\n", " },\n", " \"forecast_history\":{\n", " \"values\":[1, 2, 3]\n", " },\n", " \"out_seq_length\":{\n", " \"values\":[1, 2, 3]\n", " }\n", " }\n", "}" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "Bf549BXjdqDv", "colab_type": "code", "colab": {} }, "source": [ "def format_corona_data(region_df:pd.DataFrame, region_name:str):\n", "    \"\"\"\n", "    Format data for a specific region into\n", "    a format that can be used with flow forecast.\n", "    Writes '<region>.csv' to disk and returns (formatted_df, row_count, csv_name).\n", "    \"\"\"\n", "    # Work on a copy: the input is a slice of the master frame, and writing to a\n", "    # slice triggers SettingWithCopyWarning (visible in the saved outputs below).\n", "    region_df = region_df.copy()\n", "    if region_name == 'county':\n", "        region_name = region_df['full_county'].iloc[0]\n", "    else:\n", "        region_name = region_df['state'].iloc[0]\n", "    print(region_name)\n", "    region_df['datetime'] = region_df['date']\n", "    # flow-forecast expects these columns; no weather data here, so zero-fill\n", "    region_df['precip'] = 0\n", "    region_df['temp'] = 0\n", "    region_df = region_df.fillna(0)\n", "    region_df['new_cases'] = region_df['cases'].diff()\n", "    # .diff() leaves NaN in the first row; fix: assign via .loc, not chained\n", "    # indexing (iloc[0][...] wrote to a temporary copy, not the frame)\n", "    region_df.loc[region_df.index[0], 'new_cases'] = 0\n", "    region_df = region_df.fillna(0)\n", "    region_df.to_csv(region_name + \".csv\")\n", "    return region_df, len(region_df), region_name + \".csv\"\n", "\n", "def loop_through_geo_codes(df, column='full_county'):\n", "    \"\"\"Split df into one frame per value of `column`, keeping only regions with\n", "    more than 60 daily observations (enough history to train on).\n", "    \"\"\"\n", "    df_county_list = []\n", "    df['full_county'] = df['state'] + \"_\" + df['county']\n", "    # fix: honor the `column` argument (was hardcoded to 'full_county');\n", "    # the default preserves the old behavior\n", "    for code in df[column].unique():\n", "        mask = df[column] == code\n", "        df_code = df[mask]\n", "        if len(df_code) > 60:\n", "            df_county_list.append(df_code)\n", "    return df_county_list\n", "\n", "def fetch_time_series() -> pd.DataFrame:\n", "    \"\"\"Fetch raw time series data from coronadatascraper.com\n", "    Returns:\n", "        pd.DataFrame: raw timeseries data at county/sub-region level\n", "    \"\"\"\n", "    if 
1==1:\n", " url = \"https://coronadatascraper.com/timeseries.csv\"\n", " urllib.request.urlretrieve(url, \"timeseries.csv\")\n", "\n", " time_series_df = pd.read_csv(\"timeseries.csv\")\n", " return time_series_df" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "9uWO6ZxlepDP", "colab_type": "code", "outputId": "424840ab-835f-4cda-fc87-acc67dd063b2", "colab": { "base_uri": "https://localhost:8080/", "height": 989 } }, "source": [ "import urllib \n", "df = fetch_time_series()\n", "df['month'] = pd.to_datetime(df['date']).map(lambda x: x.month)\n", "df['weekday'] = pd.to_datetime(df['date']).map(lambda x: x.weekday())\n", "df['year'] = pd.to_datetime(df['date']).map(lambda x: x.year)\n", "df_list = loop_through_geo_codes(df)\n", "region_df, full_len, file_path = format_corona_data(df_list[9], 'county')\n", "region_df.head()" ], "execution_count": 8, "outputs": [ { "output_type": "stream", "text": [ "/usr/local/lib/python3.6/dist-packages/IPython/core/interactiveshell.py:2822: DtypeWarning: Columns (2) have mixed types.Specify dtype option on import or set low_memory=False.\n", " if self.run_code(code, result):\n" ], "name": "stderr" }, { "output_type": "stream", "text": [ "Washington, D.C._District of Columbia\n" ], "name": "stdout" }, { "output_type": "stream", "text": [ "/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:13: SettingWithCopyWarning: \n", "A value is trying to be set on a copy of a slice from a DataFrame.\n", "Try using .loc[row_indexer,col_indexer] = value instead\n", "\n", "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", " del sys.path[0]\n", "/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:14: SettingWithCopyWarning: \n", "A value is trying to be set on a copy of a slice from a DataFrame.\n", "Try using .loc[row_indexer,col_indexer] = value instead\n", "\n", "See the caveats in the 
documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", " \n", "/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:15: SettingWithCopyWarning: \n", "A value is trying to be set on a copy of a slice from a DataFrame.\n", "Try using .loc[row_indexer,col_indexer] = value instead\n", "\n", "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", " from ipykernel import kernelapp as app\n", "/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:18: SettingWithCopyWarning: \n", "A value is trying to be set on a copy of a slice from a DataFrame\n", "\n", "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n" ], "name": "stderr" }, { "output_type": "execute_result", "data": { "text/html": [ "
| \n", " | name | \n", "level | \n", "city | \n", "county | \n", "state | \n", "country | \n", "population | \n", "lat | \n", "long | \n", "url | \n", "aggregate | \n", "tz | \n", "cases | \n", "deaths | \n", "recovered | \n", "active | \n", "tested | \n", "hospitalized | \n", "discharged | \n", "icu | \n", "growthFactor | \n", "date | \n", "month | \n", "weekday | \n", "year | \n", "full_county | \n", "datetime | \n", "precip | \n", "temp | \n", "new_cases | \n", "
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 13412 | \n", "District of Columbia, Washington, D.C., United... | \n", "county | \n", "0 | \n", "District of Columbia | \n", "Washington, D.C. | \n", "United States | \n", "705749.0 | \n", "38.894 | \n", "-77.0145 | \n", "https://coronavirus.dc.gov/page/coronavirus-data | \n", "0 | \n", "America/New_York | \n", "2.0 | \n", "0.0 | \n", "0.0 | \n", "2.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "2020-03-07 | \n", "3 | \n", "5 | \n", "2020 | \n", "Washington, D.C._District of Columbia | \n", "2020-03-07 | \n", "0 | \n", "0 | \n", "0.0 | \n", "
| 13413 | \n", "District of Columbia, Washington, D.C., United... | \n", "county | \n", "0 | \n", "District of Columbia | \n", "Washington, D.C. | \n", "United States | \n", "705749.0 | \n", "38.894 | \n", "-77.0145 | \n", "https://coronavirus.dc.gov/page/coronavirus-data | \n", "0 | \n", "America/New_York | \n", "2.0 | \n", "0.0 | \n", "0.0 | \n", "2.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "2020-03-08 | \n", "3 | \n", "6 | \n", "2020 | \n", "Washington, D.C._District of Columbia | \n", "2020-03-08 | \n", "0 | \n", "0 | \n", "0.0 | \n", "
| 13414 | \n", "District of Columbia, Washington, D.C., United... | \n", "county | \n", "0 | \n", "District of Columbia | \n", "Washington, D.C. | \n", "United States | \n", "705749.0 | \n", "38.894 | \n", "-77.0145 | \n", "https://coronavirus.dc.gov/page/coronavirus-data | \n", "0 | \n", "America/New_York | \n", "4.0 | \n", "0.0 | \n", "0.0 | \n", "4.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "2.0 | \n", "2020-03-09 | \n", "3 | \n", "0 | \n", "2020 | \n", "Washington, D.C._District of Columbia | \n", "2020-03-09 | \n", "0 | \n", "0 | \n", "2.0 | \n", "
| 13415 | \n", "District of Columbia, Washington, D.C., United... | \n", "county | \n", "0 | \n", "District of Columbia | \n", "Washington, D.C. | \n", "United States | \n", "705749.0 | \n", "38.894 | \n", "-77.0145 | \n", "https://coronavirus.dc.gov/page/coronavirus-data | \n", "0 | \n", "America/New_York | \n", "4.0 | \n", "0.0 | \n", "0.0 | \n", "4.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "2020-03-10 | \n", "3 | \n", "1 | \n", "2020 | \n", "Washington, D.C._District of Columbia | \n", "2020-03-10 | \n", "0 | \n", "0 | \n", "0.0 | \n", "
| 13416 | \n", "District of Columbia, Washington, D.C., United... | \n", "county | \n", "0 | \n", "District of Columbia | \n", "Washington, D.C. | \n", "United States | \n", "705749.0 | \n", "38.894 | \n", "-77.0145 | \n", "https://coronavirus.dc.gov/page/coronavirus-data | \n", "0 | \n", "America/New_York | \n", "10.0 | \n", "0.0 | \n", "0.0 | \n", "10.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "2.5 | \n", "2020-03-11 | \n", "3 | \n", "2 | \n", "2020 | \n", "Washington, D.C._District of Columbia | \n", "2020-03-11 | \n", "0 | \n", "0 | \n", "6.0 | \n", "