{ "cells": [
  { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Experiment: ensemble prediction" ] },
  { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [
    "Goal: use a very simple setup (only the `[0, -1]` timesteps), early stopping, and a simple learning-rate scheduler (or none) to take a look at the ensemble predictions.\n",
    "\n",
    "The dataset that matches this setup appears to be dataset ID 278771; a second dataset, 93EB03, is also loaded below."
  ] },
  { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [
    "We have three experiments:\n",
    "\n",
    "- one with a constant learning rate of 1e-4 (MID: 67D778),\n",
    "- one with a learning rate that decreases on plateaus (MID: 542704),\n",
    "- one without early stopping (MID: BB7742).\n",
    "\n",
    "**TODO:** the model IDs loaded below (8841D8 and BCE7BB) do not appear in this list; confirm which experiments they belong to."
  ] },
  { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "import xarray as xr\n", "import numpy as np\n", "import torch\n", "from benchmark.bm.score import compute_weighted_rmse, compute_weighted_mae, compute_weighted_acc\n", "\n", "import matplotlib\n", "import matplotlib.pyplot as plt\n", "\n", "import copy\n", "import matplotlib.gridspec as gridspec\n", "\n", "from WD.plotting import plot_map, add_label_to_axes\n", "\n", "import cartopy.crs as ccrs" ] },
  { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "ds = xr.open_dataset(\"/data/compoundx/WeatherDiff/model_input/13B689_output_min_max.nc\")" ] },
  { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [
    "ds_id_2 = \"278771\"\n",
    "ds_id_3 = \"93EB03\"\n",
    "\n",
    "model_id_2 = \"8841D8\"\n",
    "model_id_3 = \"BCE7BB\"\n",
    "\n",
    "# Root directory holding one sub-directory of model outputs per dataset ID.\n",
    "MODEL_OUTPUT_DIR = \"/data/compoundx/WeatherDiff/model_output\"\n",
    "\n",
    "\n",
    "def load_run(ds_id, model_id):\n",
    "    \"\"\"Load the target and generated datasets of one (dataset, model) run.\n",
    "\n",
    "    Reads ``{MODEL_OUTPUT_DIR}/{ds_id}/{model_id}_target.nc`` and\n",
    "    ``{MODEL_OUTPUT_DIR}/{ds_id}/{model_id}_gen.nc`` eagerly with\n",
    "    ``xr.load_dataset`` and returns them as ``(targets, predictions)``.\n",
    "    \"\"\"\n",
    "    run_dir = f\"{MODEL_OUTPUT_DIR}/{ds_id}\"\n",
    "    targets = xr.load_dataset(f\"{run_dir}/{model_id}_target.nc\")\n",
    "    predictions = xr.load_dataset(f\"{run_dir}/{model_id}_gen.nc\")\n",
    "    return targets, predictions\n",
    "\n",
    "\n",
    "targets_2, predictions_2 = load_run(ds_id_2, model_id_2)\n",
    "targets_3, predictions_3 = load_run(ds_id_3, model_id_3)"
  ] },
  { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<xarray.Dataset>\n",
"Dimensions: (lat: 32, lon: 64, lead_time: 1, ensemble_member: 1,\n",
" init_time: 2903)\n",
"Coordinates:\n",
" * lat (lat) float64 -87.19 -81.56 -75.94 ... 75.94 81.56 87.19\n",
" * lon (lon) float64 0.0 5.625 11.25 16.88 ... 343.1 348.8 354.4\n",
" * lead_time (lead_time) int64 12\n",
" * ensemble_member (ensemble_member) int64 0\n",
" * init_time (init_time) datetime64[ns] 2017-01-01T12:00:00 ... 2018-...\n",
"Data variables:\n",
" z (ensemble_member, init_time, lead_time, lat, lon) float32 ...