{
 "nbformat": 4,
 "nbformat_minor": 0,
 "metadata": {
  "colab": {
   "name": "Transfer Learning on Specific Counties.ipynb",
   "provenance": [],
   "collapsed_sections": [],
   "toc_visible": true
  },
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3"
  },
  "accelerator": "GPU"
 },
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "vHa7we1aSUVt",
    "colab_type": "text"
   },
   "source": [
    "## Examining Transfer Effects\n",
    "\n",
    "Fine-tunes a `MultiAttnHeadSimple` model pretrained on solar data on the daily new COVID-19 case counts of selected US counties, running a W&B grid sweep (batch size, learning rate, history length, forecast length) for each county.\n",
    "\n",
    "Requires Colab with access to the `gmap-997` GCP project and the `gs://coronaviruspublicdata` bucket."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "T63Nnp4iSSQD",
    "colab_type": "code",
    "colab": {}
   },
   "source": [
    "import os\n",
    "import urllib.request\n",
    "from datetime import datetime\n",
    "\n",
    "import pandas as pd\n",
    "from google.colab import auth\n",
    "\n",
    "auth.authenticate_user()\n",
    "!gcloud source repos clone github_aistream-peelout_flow-forecast --project=gmap-997\n",
    "os.chdir('/content/github_aistream-peelout_flow-forecast')\n",
    "!git checkout -t origin/covid_fixes\n",
    "!python setup.py develop\n",
    "!pip install -r requirements.txt\n",
    "!mkdir -p data\n",
    "from flood_forecast.trainer import train_function\n",
    "!pip install git+https://github.com/CoronaWhy/task-geo.git\n",
    "!wandb login\n",
    "import wandb"
   ],
   "execution_count": 0,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "uWwx1RxjX1k2",
    "colab_type": "code",
    "colab": {}
   },
   "source": [
    "# Pretrained solar data\n",
    "!mkdir -p weights\n",
    "!gsutil cp -r gs://coronaviruspublicdata/pretrained/model_save weights/"
   ],
   "execution_count": 0,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "bjevk3exSpp_",
    "colab_type": "code",
    "colab": {}
   },
   "source": [
    "def make_config_file(file_path, df_len, weight_path=None, epochs=10):\n",
    "    \"\"\"Build the flow-forecast config for one county inside a W&B sweep run.\n",
    "\n",
    "    Starts a W&B run in the \"pretrain-counties\" project, reads the\n",
    "    hyperparameters the sweep chose for that run from `wandb.config`, and\n",
    "    logs the resulting config back to the run.\n",
    "\n",
    "    Args:\n",
    "        file_path: CSV written by `format_corona_data`; used as the training,\n",
    "            validation, test and inference file.\n",
    "        df_len: number of rows in that CSV; `train_end` and `valid_end` are\n",
    "            set to 70% and 90% of it.\n",
    "        weight_path: optional pretrained weights to start from.\n",
    "        epochs: number of training epochs.\n",
    "\n",
    "    Returns:\n",
    "        dict: config to pass to `train_function(\"PyTorch\", ...)`.\n",
    "    \"\"\"\n",
    "    wandb.init(project=\"pretrain-counties\")\n",
    "    wandb_config = wandb.config\n",
    "    train_number = df_len * .7\n",
    "    validation_number = df_len * .9\n",
    "    relevant_cols = [\"new_cases\", \"month\", \"weekday\"]\n",
    "    target_col = [\"new_cases\"]\n",
    "    config_default = {\n",
    "        \"model_name\": \"MultiAttnHeadSimple\",\n",
    "        \"model_type\": \"PyTorch\",\n",
    "        \"model_params\": {\n",
    "            # One input series per relevant column.\n",
    "            \"number_time_series\": len(relevant_cols),\n",
    "            \"seq_len\": wandb_config[\"forecast_history\"],\n",
    "            \"output_seq_len\": wandb_config[\"out_seq_length\"],\n",
    "            \"forecast_length\": wandb_config[\"out_seq_length\"]\n",
    "        },\n",
    "        # Presumably skipped when loading `weight_path` so the output layer is\n",
    "        # re-learned for the new target -- confirm in flood_forecast.\n",
    "        \"weight_path_add\": {\n",
    "            \"excluded_layers\": [\"last_layer.weight\", \"last_layer.bias\"]\n",
    "        },\n",
    "        \"dataset_params\": {\n",
    "            \"class\": \"default\",\n",
    "            \"training_path\": file_path,\n",
    "            \"validation_path\": file_path,\n",
    "            \"test_path\": file_path,\n",
    "            \"batch_size\": wandb_config[\"batch_size\"],\n",
    "            \"forecast_history\": wandb_config[\"forecast_history\"],\n",
    "            \"forecast_length\": wandb_config[\"out_seq_length\"],\n",
    "            \"train_end\": int(train_number),\n",
    "            \"valid_start\": int(train_number + 1),\n",
    "            \"valid_end\": int(validation_number),\n",
    "            \"target_col\": list(target_col),\n",
    "            \"relevant_cols\": list(relevant_cols),\n",
    "            \"scaler\": \"StandardScaler\",\n",
    "            \"interpolate\": False\n",
    "        },\n",
    "        \"training_params\": {\n",
    "            \"criterion\": \"MSE\",\n",
    "            \"optimizer\": \"Adam\",\n",
    "            \"optim_params\": {},\n",
    "            \"lr\": wandb_config[\"lr\"],\n",
    "            \"epochs\": epochs,\n",
    "            \"batch_size\": wandb_config[\"batch_size\"]\n",
    "        },\n",
    "        \"GCS\": False,\n",
    "        \"sweep\": True,\n",
    "        \"wandb\": False,\n",
    "        \"forward_params\": {},\n",
    "        \"metrics\": [\"MSE\"],\n",
    "        \"inference_params\": {\n",
    "            \"datetime_start\": \"2020-04-21\",\n",
    "            \"hours_to_forecast\": 10,\n",
    "            \"test_csv_path\": file_path,\n",
    "            \"decoder_params\": {\n",
    "                \"decoder_function\": \"simple_decode\",\n",
    "                \"unsqueeze_dim\": 1\n",
    "            },\n",
    "            \"dataset_params\": {\n",
    "                \"file_path\": file_path,\n",
    "                \"forecast_history\": wandb_config[\"forecast_history\"],\n",
    "                \"forecast_length\": wandb_config[\"out_seq_length\"],\n",
    "                \"relevant_cols\": list(relevant_cols),\n",
    "                \"target_col\": list(target_col),\n",
    "                \"scaling\": \"StandardScaler\",\n",
    "                \"interpolate_param\": False\n",
    "            }\n",
    "        }\n",
    "    }\n",
    "    if weight_path:\n",
    "        config_default[\"weight_path\"] = weight_path\n",
    "    wandb.config.update(config_default)\n",
    "    return config_default\n",
    "\n",
    "\n",
    "sweep_config = {\n",
    "    \"name\": \"Default sweep\",\n",
    "    \"method\": \"grid\",\n",
    "    \"parameters\": {\n",
    "        \"batch_size\": {\"values\": [2, 3]},\n",
    "        \"lr\": {\"values\": [0.001, 0.01]},\n",
    "        \"forecast_history\": {\"values\": [1, 2, 3]},\n",
    "        \"out_seq_length\": {\"values\": [1, 2, 3]}\n",
    "    }\n",
    "}"
   ],
   "execution_count": 0,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "Bf549BXjdqDv",
    "colab_type": "code",
    "colab": {}
   },
   "source": [
    "def format_corona_data(region_df: pd.DataFrame, region_name: str):\n",
    "    \"\"\"Format one region's rows for flow forecast and write them to CSV.\n",
    "\n",
    "    Adds `datetime`, `precip` and `temp` (constant 0) columns, fills missing\n",
    "    values with 0 and derives `new_cases` as the row-to-row difference of\n",
    "    `cases` (0 for the first row).\n",
    "\n",
    "    Args:\n",
    "        region_df: rows of a single region, e.g. one element returned by\n",
    "            `loop_through_geo_codes`. Needs `date` and `cases` columns plus\n",
    "            `full_county` or `state` (see `region_name`). Not modified.\n",
    "        region_name: `'county'` names the output after `full_county`; any\n",
    "            other value names it after `state`.\n",
    "\n",
    "    Returns:\n",
    "        tuple: (formatted DataFrame, its row count, CSV file name).\n",
    "    \"\"\"\n",
    "    if region_name == 'county':\n",
    "        region_name = region_df['full_county'].iloc[0]\n",
    "    else:\n",
    "        region_name = region_df['state'].iloc[0]\n",
    "    print(region_name)\n",
    "    # Work on a copy: callers pass slices of a larger frame, and assigning\n",
    "    # columns to a slice triggers SettingWithCopyWarning.\n",
    "    region_df = region_df.copy()\n",
    "    region_df['datetime'] = region_df['date']\n",
    "    region_df['precip'] = 0\n",
    "    region_df['temp'] = 0\n",
    "    region_df = region_df.fillna(0)\n",
    "    region_df['new_cases'] = region_df['cases'].diff()\n",
    "    # The first row has no previous row to diff against.\n",
    "    region_df.loc[region_df.index[0], 'new_cases'] = 0\n",
    "    region_df = region_df.fillna(0)\n",
    "    csv_path = region_name + \".csv\"\n",
    "    region_df.to_csv(csv_path)\n",
    "    return region_df, len(region_df), csv_path\n",
    "\n",
    "\n",
    "def loop_through_geo_codes(df, column='full_county', min_len=60):\n",
    "    \"\"\"Split the raw time series into one DataFrame per county.\n",
    "\n",
    "    Adds a `column` key to `df` in place (`state + \"_\" + county`) and returns\n",
    "    the per-key groups, in order of first appearance, that have more than\n",
    "    `min_len` rows. Rows with a missing state or county get no key and are\n",
    "    dropped. `format_corona_data` reads `full_county`, so keep the default\n",
    "    `column` when passing the result there.\n",
    "    \"\"\"\n",
    "    df[column] = df['state'] + \"_\" + df['county']\n",
    "    return [\n",
    "        df_code\n",
    "        for _, df_code in df.groupby(column, sort=False)\n",
    "        if len(df_code) > min_len\n",
    "    ]\n",
    "\n",
    "\n",
    "def fetch_time_series(url: str = \"https://coronadatascraper.com/timeseries.csv\",\n",
    "                      local_path: str = \"timeseries.csv\") -> pd.DataFrame:\n",
    "    \"\"\"Download the raw time series from coronadatascraper.com and load it.\n",
    "\n",
    "    The CSV is downloaded to `local_path` on every call, overwriting any\n",
    "    previous copy.\n",
    "\n",
    "    Returns:\n",
    "        pd.DataFrame: raw timeseries data at county/sub-region level\n",
    "    \"\"\"\n",
    "    urllib.request.urlretrieve(url, local_path)\n",
    "    # low_memory=False reads the file in one pass so mixed-type columns get a\n",
    "    # single dtype (avoids pandas' DtypeWarning).\n",
    "    return pd.read_csv(local_path, low_memory=False)"
   ],
   "execution_count": 0,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "9uWO6ZxlepDP",
    "colab_type": "code",
    "outputId": "424840ab-835f-4cda-fc87-acc67dd063b2",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 989
    }
   },
   "source": [
    "df = fetch_time_series()\n",
    "dates = pd.to_datetime(df['date'])\n",
    "df['month'] = dates.dt.month\n",
    "df['weekday'] = dates.dt.weekday\n",
    "df['year'] = dates.dt.year\n",
    "df_list = loop_through_geo_codes(df)\n",
    "region_df, full_len, file_path = format_corona_data(df_list[9], 'county')\n",
    "region_df.head()"
   ],
   "execution_count": 8,
   "outputs": [
    {
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.6/dist-packages/IPython/core/interactiveshell.py:2822: DtypeWarning: Columns (2) have mixed types.Specify dtype option on import or set low_memory=False.\n",
      " if self.run_code(code, result):\n"
     ],
     "name": "stderr"
    },
    {
     "output_type": "stream",
     "text": [
      "Washington, D.C._District of Columbia\n"
     ],
     "name": "stdout"
    },
    {
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:13: SettingWithCopyWarning: \n",
      "A value is trying to be set on a copy of a slice from a DataFrame.\n",
      "Try using .loc[row_indexer,col_indexer] = value instead\n",
      "\n",
      "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
      " del sys.path[0]\n",
      "/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:14: SettingWithCopyWarning: \n",
      "A value is trying to be set on a copy of a slice from a DataFrame.\n",
      "Try using .loc[row_indexer,col_indexer] = value instead\n",
      "\n",
      "See the caveats in the 
documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", " \n", "/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:15: SettingWithCopyWarning: \n", "A value is trying to be set on a copy of a slice from a DataFrame.\n", "Try using .loc[row_indexer,col_indexer] = value instead\n", "\n", "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", " from ipykernel import kernelapp as app\n", "/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:18: SettingWithCopyWarning: \n", "A value is trying to be set on a copy of a slice from a DataFrame\n", "\n", "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n" ], "name": "stderr" }, { "output_type": "execute_result", "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
namelevelcitycountystatecountrypopulationlatlongurlaggregatetzcasesdeathsrecoveredactivetestedhospitalizeddischargedicugrowthFactordatemonthweekdayyearfull_countydatetimepreciptempnew_cases
13412District of Columbia, Washington, D.C., United...county0District of ColumbiaWashington, D.C.United States705749.038.894-77.0145https://coronavirus.dc.gov/page/coronavirus-data0America/New_York2.00.00.02.00.00.00.00.00.02020-03-07352020Washington, D.C._District of Columbia2020-03-07000.0
13413District of Columbia, Washington, D.C., United...county0District of ColumbiaWashington, D.C.United States705749.038.894-77.0145https://coronavirus.dc.gov/page/coronavirus-data0America/New_York2.00.00.02.00.00.00.00.01.02020-03-08362020Washington, D.C._District of Columbia2020-03-08000.0
13414District of Columbia, Washington, D.C., United...county0District of ColumbiaWashington, D.C.United States705749.038.894-77.0145https://coronavirus.dc.gov/page/coronavirus-data0America/New_York4.00.00.04.00.00.00.00.02.02020-03-09302020Washington, D.C._District of Columbia2020-03-09002.0
13415District of Columbia, Washington, D.C., United...county0District of ColumbiaWashington, D.C.United States705749.038.894-77.0145https://coronavirus.dc.gov/page/coronavirus-data0America/New_York4.00.00.04.00.00.00.00.01.02020-03-10312020Washington, D.C._District of Columbia2020-03-10000.0
13416District of Columbia, Washington, D.C., United...county0District of ColumbiaWashington, D.C.United States705749.038.894-77.0145https://coronavirus.dc.gov/page/coronavirus-data0America/New_York10.00.00.010.00.00.00.00.02.52020-03-11322020Washington, D.C._District of Columbia2020-03-11006.0
\n", "
" ],
     "text/plain": [
      " name ... new_cases\n",
      "13412 District of Columbia, Washington, D.C., United... ... 0.0\n",
      "13413 District of Columbia, Washington, D.C., United... ... 0.0\n",
      "13414 District of Columbia, Washington, D.C., United... ... 2.0\n",
      "13415 District of Columbia, Washington, D.C., United... ... 0.0\n",
      "13416 District of Columbia, Washington, D.C., United... ... 6.0\n",
      "\n",
      "[5 rows x 30 columns]"
     ]
    },
    "metadata": {
     "tags": []
    },
    "execution_count": 8
   }
  ]
 },
  {
   "cell_type": "code",
   "metadata": {
    "id": "VIF20U5NSV14",
    "colab_type": "code",
    "colab": {}
   },
   "source": [
    "# `full_county` keys (state_county) of the counties to fine-tune on.\n",
    "special_counties = [\n",
    "    \"California_Los Angeles County\",\n",
    "    \"Illinois_Cook County\",\n",
    "    \"Arizona_Maricopa County\",\n",
    "    \"Massachusetts_Middlesex County\",\n",
    "    \"Texas_Dallas County\",\n",
    "    \"Texas_Harris County\",\n",
    "    \"Florida_Miami Dade County\",\n",
    "    \"California_Riverside County\",\n",
    "    \"Colorado_Denver County\",\n",
    "    \"Ohio_Cuyahoga County\",\n",
    "    \"New York_Queens County\",\n",
    "    \"New York_Bronx County\",\n",
    "]"
   ],
   "execution_count": 0,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "J0g8l3oIUsIP",
    "colab_type": "code",
    "colab": {}
   },
   "source": [
    "# Map each requested county key to its per-county frame from `df_list`.\n",
    "selected_counties = {\n",
    "    county_df['full_county'].iloc[0]: county_df\n",
    "    for county_df in df_list\n",
    "    if county_df['full_county'].iloc[0] in special_counties\n",
    "}\n",
    "# Requested keys that matched no county (spelling differences in the data,\n",
    "# or too few rows to survive `loop_through_geo_codes`).\n",
    "sorted(set(special_counties) - set(selected_counties))"
   ],
   "execution_count": 0,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "q7AABWBdSXMl",
    "colab_type": "code",
    "colab": {}
   },
   "source": [
    "import wandb\n",
    "\n",
    "# Solar pretrained model. The weights cell copies into weights/model_save/;\n",
    "# NOTE(review): confirm train_function resolves this file name from there.\n",
    "PRETRAINED_WEIGHTS = \"12_May_202004_39AM_model.pth\"\n",
    "\n",
    "for county_df in selected_counties.values():\n",
    "    _, full_len, file_path = format_corona_data(county_df, 'county')\n",
    "    sweep_id = wandb.sweep(sweep_config, project=\"pretrain-counties\")\n",
    "    # Bind this county's values as defaults so the sweep function does not\n",
    "    # depend on the loop variables at the time the agent calls it.\n",
    "    wandb.agent(\n",
    "        sweep_id,\n",
    "        lambda file_path=file_path, full_len=full_len: train_function(\n",
    "            \"PyTorch\", make_config_file(file_path, full_len, weight_path=PRETRAINED_WEIGHTS)\n",
    "        ),\n",
    "    )\n",
    "    !gsutil cp -r -n model_save gs://coronaviruspublicdata/pretrained/"
   ],
   "execution_count": 0,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "eu9tqu4mk0lB",
    "colab_type": "text"
   },
   "source": [
    "**Check out the sweeps here:** *https://app.wandb.ai/pranjalya/pretrain-counties*"
   ]
  }
 ]
}