{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "93a40bdd-42f0-4c0d-9ffd-8517edbb59fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import numpy as np\n",
    "from constants import *\n",
    "from new_scripts.extraction_data import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "ac367576-beb0-46f1-9be9-cedb5ace1147",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_mean(data, cluster=-1):\n",
    "    t = np.zeros(data.shape[0])\n",
    "    div = 0\n",
    "    for j in range(36):\n",
    "        for k in range(72):\n",
    "            if cluster == -1 or cluster_map[j,k] == cluster:\n",
    "                t += data[:, j, k] * np.cos(np.radians(LAT[j]))\n",
    "    \n",
    "                div += np.cos(np.radians(LAT[j]))\n",
    "    t /= div\n",
    "    return t\n",
    "\n",
    "def get_obs(cluster=-1):\n",
    "    fn = data_dir + 'obs.nc'\n",
    "\n",
    "    f = nc4.Dataset(fn, 'r')\n",
    "    data = f.variables['temperature_anomaly'][:]\n",
    "\n",
    "    return get_mean(data,cluster=cluster)\n",
    "\n",
    "test = get_obs()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "aad48503-f7a3-490f-8a2c-20289506a3ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_pre_ind(type, model='IPSL', phys=1):\n",
    "    \n",
    "    dic = MODELS[model]\n",
    "    giss_cond = ((type == 'hist-aer' and ((phys == 1 and i not in range(5, 10)) or (phys != 1 and i in range(5, 10)))) or \\\n",
    "                (type == 'historical' and ((phys == 1 and i < 10) or (phys != 1 and i >= 10)))) or \\\n",
    "                (type not in ['hist-aer', 'historical'] and i in range(5, 10))\n",
    "    result = np.zeros((36,72))\n",
    "\n",
    "    for i in range(dic[type]):\n",
    "\n",
    "        if (model == 'GISS' and giss_cond) or model != 'GISS':\n",
    "            \n",
    "            fn = f'{data_dir}{model}_{type}_{str(i+1)}.nc'\n",
    "            f = nc4.Dataset(fn, 'r')\n",
    "        \n",
    "            data = f.variables['tas'][0:50]\n",
    "        \n",
    "            result +=np.mean(data,axis=0)\n",
    "        \n",
    "    result /= dic[type]\n",
    "    \n",
    "    return result\n",
    "\n",
    "def get_simu(type, simu, model='IPSL', cluster=-1, filtrage=False):\n",
    "    \n",
    "    if model == 'GISS':\n",
    "        phys = 1\n",
    "        i = simu\n",
    "        if type == 'hist-aer':\n",
    "            if i in range(6, 11):\n",
    "                phys = 2\n",
    "        elif type == 'historical':\n",
    "            if i > 10:\n",
    "                phys = 2\n",
    "        pre_ind = get_pre_ind(type, model=model, phys=phys)\n",
    "\n",
    "    else:\n",
    "        pre_ind = get_pre_ind(type, model=model)\n",
    "\n",
    "    fn = f'{data_dir}{model}_{type}_{str(i+1)}.nc'\n",
    "    f = nc4.Dataset(fn, 'r')\n",
    "    data = f.variables['tas'][50:]\n",
    "\n",
    "    data = data - pre_ind\n",
    "    result = get_mean(data, cluster=cluster)\n",
    "    \n",
    "    if(filtrage):\n",
    "        if(type=='hist-GHG' or type=='hist-aer'):\n",
    "    \n",
    "            result = signal.filtfilt(b, a, result)\n",
    "    return result\n",
    "\n",
    "def get_data_forcage(type, model='IPSL', cluster=-1, filtrage=False):\n",
    "\n",
    "    dic = MODELS[model]\n",
    "    result = np.zeros((dic[type],115))\n",
    "    for i in range(dic[type]):\n",
    "        result[i] = get_simu(type, i+1, model, cluster, filtrage=filtrage)[0:115]\n",
    "    \n",
    "    return result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "11adf3c9-3fd9-4fdb-b639-1eaf42622f5e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_metric(model='IPSL', cluster=-1, normalis=False, filtrage=False, metric='mean', as_tensor=False):\n",
    "\n",
    "    aer = get_data_forcage('hist-aer', model=model, cluster=cluster, filtrage=filtrage)[:,0:115]\n",
    "    ghg = get_data_forcage('hist-GHG', model=model, cluster=cluster, filtrage=filtrage)[:,0:115]\n",
    "    nat = get_data_forcage('hist-nat', model=model, cluster=cluster, filtrage=filtrage)[:,0:115]\n",
    "    historical = get_data_forcage('historical', model=model, cluster=cluster, filtrage=filtrage)[:,0:115]\n",
    "\n",
    "    max_hist = np.max(np.mean(historical, axis=0))\n",
    "    aer = aer / max_hist\n",
    "    ghg = ghg / max_hist\n",
    "    nat = nat / max_hist\n",
    "    historical = historical / max_hist\n",
    "    \n",
    "    if normalis:\n",
    "        if metric == 'std':\n",
    "    \n",
    "            aer = np.std(aer, axis=0)\n",
    "            ghg = np.std(ghg, axis=0)\n",
    "            nat = np.std(nat, axis=0)\n",
    "            historical = np.std(historical, axis=0)\n",
    "\n",
    "        elif metric == 'mean':\n",
    "\n",
    "            aer = np.mean(aer, axis=0)\n",
    "            ghg = np.mean(ghg, axis=0)\n",
    "            nat = np.mean(nat, axis=0)\n",
    "            historical = np.mean(historical, axis=0)\n",
    "\n",
    "    if as_tensor:\n",
    "        \n",
    "        aer = torch.tensor(aer).float()\n",
    "        ghg = torch.tensor(ghg).float(\n",
    "        nat = torch.tensor(nat).float(\n",
    "        historical = torch.tensor(historical).float(\n",
    "\n",
    "    return aer, ghg, nat, historical\n",
    "\n",
    "def get_data_set(model='IPSL', cluster=-1, normalis=False, filtrage=False):\n",
    "    \n",
    "    liste_max = []\n",
    "    if (model != 'ALL'):\n",
    "\n",
    "        aer, ghg, nat, historical = get_metric(model, cluster, normalis, filtrage, as_tensor=True)\n",
    "        max_hist = np.max(np.mean(historical, axis=0))\n",
    "        liste_max.append(max_hist)\n",
    "\n",
    "    elif model == 'ALL':\n",
    "        \n",
    "        liste_models = ['CanESM5', 'CNRM', 'IPSL', 'ACCESS', 'BCC', 'FGOALS', \n",
    "                        'HadGEM3', 'MIRO', 'ESM2', 'NorESM2','CESM2','GISS']\n",
    "\n",
    "        aer = []\n",
    "        ghg = []\n",
    "        nat = []\n",
    "        historical = []\n",
    "\n",
    "        for model_curr in liste_models:\n",
    "\n",
    "            aer_curr, ghg_curr, nat_curr, historical_curr = get_std(model_curr, cluster, normalis, filtrage, as_tensor=True)\n",
    "\n",
    "            max_hist = torch.max(torch.mean(historical_curr, dim=0))\n",
    "            liste_max.append(max_hist)\n",
    "\n",
    "            aer.append(aer_curr)\n",
    "            ghg.append(ghg_curr)\n",
    "            nat.append(nat_curr)\n",
    "            historical.append(historical_curr)\n",
    "\n",
    "    return ghg, aer, nat, historical, np.array(liste_max)\n",
    "\n",
    "def get_metric_data_set(model='IPSL', cluster=-1, normalis=False, filtrage=False, metric='mean'):\n",
    "    \n",
    "    if model != 'ALL':\n",
    "\n",
    "        aer, ghg, nat, historical = get_metric(model, cluster, normalis, filtrage, metric)\n",
    "\n",
    "        result = np.stack((ghg, aer, nat))        \n",
    "\n",
    "    elif model == 'ALL':\n",
    "\n",
    "        liste_models = ['CanESM5', 'CNRM', 'IPSL', 'ACCESS', 'BCC', 'FGOALS', \n",
    "                        'HadGEM3', 'MIRO', 'ESM2', 'NorESM2','CESM2','GISS']\n",
    "        result = []\n",
    "        historical = []\n",
    "        \n",
    "        for model_curr in liste_models:\n",
    "\n",
    "            aer_ipsl, ghg_ipsl, nat_ipsl, historical_ipsl = get_metric(model_curr, cluster, normalis, filtrage, metric)\n",
    "\n",
    "            result_ipsl = np.stack((ghg_ipsl, aer_ipsl, nat_ipsl))\n",
    "            result.append(result_ipsl)\n",
    "            historical.append(historical_ipsl)\n",
    "\n",
    "        result = np.mean(np.array(result), axis=0)\n",
    "        historical = np.mean(np.array(historical), axis=0)\n",
    "    \n",
    "    return torch.tensor(result).unsqueeze(0), historical"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "5bf00c17-a294-42d5-a812-f2f149c11d7b",
   "metadata": {},
   "outputs": [],
   "source": [
    "mean = get_mean(data)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 + Jaspy",
   "language": "python",
   "name": "jaspy"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}