{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Overview\n",
    "\n",
    "This notebook contains all experiment results exhibited in our paper."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import glob\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import json\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib\n",
    "sns.set(style='white')\n",
    "matplotlib.rcParams['pdf.fonttype'] = 42\n",
    "matplotlib.rcParams['ps.fonttype'] = 42\n",
    "\n",
    "from tqdm.auto import tqdm\n",
    "from joblib import Parallel, delayed\n",
    "\n",
    "def func(x, N=80):\n",
    "    \"\"\"Score one cross-section (the rows of a single date) of predictions.\n",
    "\n",
    "    Args:\n",
    "        x: DataFrame with `score`, `label` and `ret` columns.\n",
    "        N: number of rows taken from each end of the score ranking.\n",
    "\n",
    "    Returns:\n",
    "        Series with `MSE`, `MAE` and `IC` computed on the percentile ranks of\n",
    "        score and label, and `R`: the mean raw `ret` of the N highest-scored\n",
    "        rows minus the mean raw `ret` of the N lowest-scored rows.\n",
    "    \"\"\"\n",
    "    raw_ret = x.ret.copy()\n",
    "    x = x.rank(pct=True)  # every column becomes a percentile rank ...\n",
    "    x['ret'] = raw_ret  # ... except `ret`, which has to stay a real return\n",
    "    diff = x.score.sub(x.label)\n",
    "    r = x.nlargest(N, columns='score').ret.mean()\n",
    "    r -= x.nsmallest(N, columns='score').ret.mean()\n",
    "    return pd.Series({\n",
    "        'MSE': diff.pow(2).mean(),\n",
    "        'MAE': diff.abs().mean(),\n",
    "        'IC': x.score.corr(x.label),\n",
    "        'R': r\n",
    "    })\n",
    "\n",
    "ret = pd.read_pickle(\"data/ret.pkl\").clip(-0.1, 0.1)\n",
    "def backtest(fname, **kwargs):\n",
    "    \"\"\"Backtest one prediction file over the test period.\n",
    "\n",
    "    Args:\n",
    "        fname: path of a pickled prediction frame with `score` and `label`\n",
    "            columns; its first index level is sliced and grouped by date.\n",
    "        **kwargs: forwarded to `func` (e.g. `N`).\n",
    "\n",
    "    Returns:\n",
    "        (metrics, r): `metrics` maps MSE / MAE / IC / ICIR / AR / AV / SR / MDD\n",
    "        to numbers; `r` is the daily long-short return reindexed to every\n",
    "        calendar day between the first and the last date.\n",
    "    \"\"\"\n",
    "    # copy the slice so that adding a column never writes into a view of the loaded frame\n",
    "    pred = pd.read_pickle(fname).loc['2018-09-21':'2020-06-30'].copy() # test period\n",
    "    pred['ret'] = ret  # index-aligned with the clipped returns loaded above\n",
    "    dates = pred.index.unique(level=0)\n",
    "    res = Parallel(n_jobs=-1)(delayed(func)(pred.loc[d], **kwargs) for d in dates)\n",
    "    res = pd.DataFrame(dict(zip(dates, res))).T  # one row of metrics per date\n",
    "    r = res['R'].copy()\n",
    "    r.index = pd.to_datetime(r.index)\n",
    "    # the paper annualizes with 365 days, so days without a value count as zero-return days\n",
    "    r = r.reindex(pd.date_range(r.index[0], r.index[-1])).fillna(0)\n",
    "    return {\n",
    "        'MSE': res['MSE'].mean(),\n",
    "        'MAE': res['MAE'].mean(),\n",
    "        'IC': res['IC'].mean(),\n",
    "        'ICIR': res['IC'].mean()/res['IC'].std(),\n",
    "        'AR': r.mean()*365,\n",
    "        'AV': r.std()*365**0.5,\n",
    "        'SR': r.mean()/r.std()*365**0.5,\n",
    "        'MDD': (r.cumsum().cummax() - r.cumsum()).max()\n",
    "    }, r\n",
    "\n",
    "def fmt(x, p=3, scale=1, std=False):\n",
    "    \"\"\"Format a metric as a fixed-precision string.\n",
    "\n",
    "    Args:\n",
    "        x: a plain number, or an object with `.mean()` / `.std()` such as a\n",
    "            Series holding one value per run.\n",
    "        p: number of decimals.\n",
    "        scale: multiplier applied before formatting (100 for percentages).\n",
    "        std: when True and `x` holds more than one value, append ' (std)'.\n",
    "\n",
    "    Returns:\n",
    "        The formatted mean (or the number itself), optionally followed by the\n",
    "        standard deviation in parentheses.\n",
    "    \"\"\"\n",
    "    is_scalar = isinstance(x, (int, float, np.integer, np.floating))\n",
    "    _fmt = '{:.%df}'%p\n",
    "    string = _fmt.format((x if is_scalar else x.mean()) * scale)\n",
    "    # a plain number has neither len() nor a spread, so never append a std for it\n",
    "    if std and not is_scalar and len(x) > 1:\n",
    "        string += ' ('+_fmt.format(x.std()*scale)+')'\n",
    "    return string\n",
    "\n",
    "def backtest_multi(files, **kwargs):\n",
    "    \"\"\"Backtest several prediction files (e.g. one per seed) and summarize them.\n",
    "\n",
    "    Args:\n",
    "        files: iterable of prediction file paths, each passed to `backtest`.\n",
    "        **kwargs: forwarded to `backtest`.\n",
    "\n",
    "    Returns:\n",
    "        (summary, pnl): `summary` maps each metric name to a formatted string\n",
    "        of its mean over the files (MSE and MAE also show the std when there\n",
    "        is more than one file); `pnl` has one daily-return column per file.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `files` is empty.\n",
    "    \"\"\"\n",
    "    files = list(files)\n",
    "    if not files:\n",
    "        # an empty glob result would otherwise surface as an opaque pandas error\n",
    "        raise ValueError('backtest_multi: no prediction files given (did the glob pattern match anything?)')\n",
    "    res = []\n",
    "    pnl = []\n",
    "    for fname in files:\n",
    "        metric, r = backtest(fname, **kwargs)\n",
    "        res.append(metric)\n",
    "        pnl.append(r)\n",
    "    res = pd.DataFrame(res)\n",
    "    pnl = pd.concat(pnl, axis=1)\n",
    "    return {\n",
    "        'MSE': fmt(res['MSE'], std=True),\n",
    "        'MAE': fmt(res['MAE'], std=True),\n",
    "        'IC': fmt(res['IC']),\n",
    "        'ICIR': fmt(res['ICIR']),\n",
    "        'AR': fmt(res['AR'], scale=100, p=1)+'%',\n",
    "        'VR': fmt(res['AV'], scale=100, p=1)+'%',  # reported as 'VR'; this is the 'AV' value from backtest()\n",
    "        'SR': fmt(res['SR']),\n",
    "        'MDD': fmt(res['MDD'], scale=100, p=1)+'%'\n",
    "    }, pnl"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Preparation\n",
    "\n",
    "\n",
    "You could prepare the source data as below for the backtest code:\n",
    "1. Linear: see Qlib examples\n",
    "2. LightGBM: see Qlib examples\n",
    "3. MLP: see Qlib examples\n",
    "4. SFM: see Qlib examples\n",
    "5. ALSTM: `qrun` configs/config_alstm.yaml\n",
    "6. Transformer: `qrun` configs/config_transformer.yaml\n",
    "7. ALSTM+TRA: `qrun` configs/config_alstm_tra_init.yaml && `qrun` configs/config_alstm_tra.yaml\n",
    "8. 
Tranformer+TRA: `qrun` configs/config_transformer_tra_init.yaml && `qrun` configs/config_transformer_tra.yaml" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "exps = {\n", " 'Linear': ['output/Linear/pred.pkl'],\n", " 'LightGBM': ['output/GBDT/lr0.05_leaves128/pred.pkl'],\n", " 'MLP': glob.glob('output/search/MLP/hs128_bs512_do0.3_lr0.001_seed*/pred.pkl'),\n", " 'SFM': glob.glob('output/search/SFM/hs32_bs512_do0.5_lr0.001_seed*/pred.pkl'),\n", " 'ALSTM': glob.glob('output/search/LSTM_Attn/hs256_bs1024_do0.1_lr0.0002_seed*/pred.pkl'),\n", " 'Trans.': glob.glob('output/search/Transformer/head4_hs64_bs1024_do0.1_lr0.0002_seed*/pred.pkl'),\n", " 'ALSTM+TS':glob.glob('output/LSTM_Attn_TS/hs256_bs1024_do0.1_lr0.0002_seed*/pred.pkl'),\n", " 'Trans.+TS':glob.glob('output/Transformer_TS/head4_hs64_bs1024_do0.1_lr0.0002_seed*/pred.pkl'),\n", " 'ALSTM+TRA(Ours)': glob.glob('output/search/finetune/LSTM_Attn_tra/K10_traHs16_traSrcLR_TPE_traLamb2.0_hs256_bs1024_do0.1_lr0.0001_seed*/pred.pkl'),\n", " 'Trans.+TRA(Ours)': glob.glob('output/search/finetune/Transformer_tra/K3_traHs16_traSrcLR_TPE_traLamb1.0_head4_hs64_bs512_do0.1_lr0.0005_seed*/pred.pkl')\n", "}" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "scrolled": true }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "0acd535e05944e539fd001009ed0748d", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/10 [00:00\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
MSEMAEICICIRARVRSRMDD
Linear0.1630.3270.0200.132-3.2%16.8%-0.19132.1%
LightGBM0.1600.3230.0410.2927.8%15.5%0.50325.7%
MLP0.160 (0.002)0.323 (0.003)0.0370.2733.7%15.3%0.26426.2%
SFM0.159 (0.001)0.321 (0.001)0.0470.3817.1%14.3%0.49722.9%
ALSTM0.158 (0.001)0.320 (0.001)0.0530.41912.3%13.7%0.89720.2%
Trans.0.158 (0.001)0.322 (0.001)0.0510.40014.5%14.2%1.02822.5%
ALSTM+TS0.160 (0.002)0.321 (0.002)0.0390.2916.7%14.6%0.48022.3%
Trans.+TS0.160 (0.004)0.324 (0.005)0.0370.27810.4%14.7%0.72223.7%
ALSTM+TRA(Ours)0.157 (0.000)0.318 (0.000)0.0590.46012.4%14.0%0.88520.4%
Trans.+TRA(Ours)0.157 (0.000)0.320 (0.000)0.0560.44216.1%14.2%1.13323.1%
\n", "" ], "text/plain": [ " MSE MAE IC ICIR AR VR \\\n", "Linear 0.163 0.327 0.020 0.132 -3.2% 16.8% \n", "LightGBM 0.160 0.323 0.041 0.292 7.8% 15.5% \n", "MLP 0.160 (0.002) 0.323 (0.003) 0.037 0.273 3.7% 15.3% \n", "SFM 0.159 (0.001) 0.321 (0.001) 0.047 0.381 7.1% 14.3% \n", "ALSTM 0.158 (0.001) 0.320 (0.001) 0.053 0.419 12.3% 13.7% \n", "Trans. 0.158 (0.001) 0.322 (0.001) 0.051 0.400 14.5% 14.2% \n", "ALSTM+TS 0.160 (0.002) 0.321 (0.002) 0.039 0.291 6.7% 14.6% \n", "Trans.+TS 0.160 (0.004) 0.324 (0.005) 0.037 0.278 10.4% 14.7% \n", "ALSTM+TRA(Ours) 0.157 (0.000) 0.318 (0.000) 0.059 0.460 12.4% 14.0% \n", "Trans.+TRA(Ours) 0.157 (0.000) 0.320 (0.000) 0.056 0.442 16.1% 14.2% \n", "\n", " SR MDD \n", "Linear -0.191 32.1% \n", "LightGBM 0.503 25.7% \n", "MLP 0.264 26.2% \n", "SFM 0.497 22.9% \n", "ALSTM 0.897 20.2% \n", "Trans. 1.028 22.5% \n", "ALSTM+TS 0.480 22.3% \n", "Trans.+TS 0.722 23.7% \n", "ALSTM+TRA(Ours) 0.885 20.4% \n", "Trans.+TRA(Ours) 1.133 23.1% " ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "report\n", "# print(report.to_latex())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# RQ1\n", "\n", "Case study" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ],
   "source": [
    "df = pd.read_pickle('output/search/finetune/Transformer_tra/K3_traHs16_traSrcLR_TPE_traLamb0.0_head4_hs64_bs512_do0.1_lr0.0005_seed1000/pred.pkl')\n",
    "code = 'SH600157'\n",
    "date = '2018-09-28'\n",
    "lookbackperiod = 50  # number of rows kept, counted forward from `date`\n",
    "\n",
    "def _stock_window(frame):\n",
    "    \"\"\"Rows of `code` only, indexed by the first index level, limited to `lookbackperiod` rows from `date` on.\"\"\"\n",
    "    return frame.loc(axis=0)[:, code].reset_index(level=1, drop=True).loc[date:].iloc[:lookbackperiod]\n",
    "\n",
    "# NOTE(review): the last three columns are presumably the router's per-predictor probabilities -- confirm against the model output\n",
    "prob = _stock_window(df.iloc[:, -3:])\n",
    "pred = _stock_window(df.loc[:,[\"score_0\",\"score_1\",\"score_2\",\"label\"]])\n",
    "# squared error of each predictor, shifted so that the best predictor of each row sits at 0\n",
    "e_all = pred.iloc[:,:-1].sub(pred.iloc[:,-1], axis=0).pow(2)\n",
    "e_all = e_all.sub(e_all.min(axis=1), axis=0)\n",
    "e_all.columns = [r'$\\theta_%d$'%d for d in range(1, 4)]\n",
    "# position of the largest value per row, smoothed with a 7-row rolling mean and rounded back to a position\n",
    "selected = pd.Series(np.argmax(prob.values, axis=1), index=prob.index).rolling(7).mean().round()\n",
    "\n",
    "fig, axes = plt.subplots(1, 2, figsize=(7, 3))\n",
    "e_all.plot(ax=axes[0], xlabel='', rot=30)\n",
    "selected.plot(ax=axes[1], xlabel='', rot=30, color='red', linestyle='None', marker='^', markersize=5)\n",
    "# label the right panel explicitly instead of relying on pyplot's current axes\n",
    "axes[1].set_yticks(np.array([0, 1, 2]))\n",
    "axes[1].set_yticklabels(e_all.columns.values)\n",
    "axes[0].set_ylabel('Predictor Loss')\n",
    "axes[1].set_ylabel('Router Selection')\n",
    "plt.tight_layout()\n",
    "# plt.savefig('select.pdf', bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# RQ2\n",
    "\n",
    "You could prepare the source data for this test as below:\n",
    "1. Random: Setting `src_info` = \"NONE\"\n",
    "2. LR: Setting `src_info` = \"LR\"\n",
    "3. TPE: Setting `src_info` = \"TPE\"\n",
    "4. 
LR+TPE: Setting `src_info` = \"LR_TPE\"" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "exps = {\n", " 'Random': glob.glob('output/search/LSTM_Attn_tra/K10_traHs16_traSrcNONE_traLamb1.0_hs256_bs1024_do0.1_lr0.0001_seed*/pred.pkl'),\n", " 'LR': glob.glob('output/search/LSTM_Attn_tra/K10_traHs16_traSrcLR_traLamb1.0_hs256_bs1024_do0.1_lr0.0001_seed*/pred.pkl'),\n", " 'TPE': glob.glob('output/search/LSTM_Attn_tra/K10_traHs16_traSrcTPE_traLamb1.0_hs256_bs1024_do0.1_lr0.0001_seed*/pred.pkl'),\n", " 'LR+TPE': glob.glob('output/search/finetune/LSTM_Attn_tra/K10_traHs16_traSrcLR_TPE_traLamb2.0_hs256_bs1024_do0.1_lr0.0001_seed*/pred.pkl')\n", "}" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "910721fc4a7b46eea5ba6d50647320d4", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/4 [00:00\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
MSEMAEICICIRARVRSRMDD
Random0.159 (0.001)0.321 (0.002)0.0480.36211.4%14.1%0.81021.1%
LR0.158 (0.001)0.320 (0.001)0.0530.40910.3%13.4%0.77220.8%
TPE0.158 (0.001)0.321 (0.001)0.0490.38110.3%14.0%0.74121.2%
LR+TPE0.157 (0.000)0.318 (0.000)0.0590.46012.4%14.0%0.88520.4%
\n", "" ], "text/plain": [ " MSE MAE IC ICIR AR VR SR MDD\n", "Random 0.159 (0.001) 0.321 (0.002) 0.048 0.362 11.4% 14.1% 0.810 21.1%\n", "LR 0.158 (0.001) 0.320 (0.001) 0.053 0.409 10.3% 13.4% 0.772 20.8%\n", "TPE 0.158 (0.001) 0.321 (0.001) 0.049 0.381 10.3% 14.0% 0.741 21.2%\n", "LR+TPE 0.157 (0.000) 0.318 (0.000) 0.059 0.460 12.4% 14.0% 0.885 20.4%" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "report\n", "# print(report.to_latex())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# RQ3\n", "\n", "Set `lamb` = 0 to obtain results without Optimal Transport(OT)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAUAAAAEDCAYAAABEXN1oAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAUjUlEQVR4nO3de0zV9/3H8RceRExBDcjlMEdtaWRk1EvXOGzqthQNpD30uMz2dNgtSye2GtalW1a1S7nMpM5s7pfW3iLZ7Cgmpdi1JzBvc6ZbaWi3NlvAHS/TYSzdERB+ririgcP5/UE4vyHYc5Dv8cj5PB9JEzl++J43X+3Tc/2cuEAgEBAAGGhatAcAgGghgACMRQABGIsAAjBWfLQHkKT+/n4dOXJEaWlpstls0R4HQIzw+/3q7u5Wfn6+EhMTx/z+TRHAI0eOaM2aNdEeA0CM2r17t+6+++4xl98UAUxLS5M0PGRmZmaUpwEQK86ePas1a9YEG3O1myKAI3d7MzMzNW/evChPAyDWXOuhNZ4EAWAsAgjAWAQQgLFuiscAgalqYGBAHR0d6u/vj/YoRrPZbJozZ47mzp2radPCv11HAIFJ6OjoUHJysubPn6+4uLhoj2OkQCCggYEBdXZ2qqOjQ9nZ2WF/L3eBgUno7+9Xamoq8YuiuLg4JSQk6Atf+IIuXbo0oe8lgMAkEb+bw0Tu+ga/JwJzAMCUwGOAmJJ8/gEl2KbfdMf2DfiVMN3697NbeVyn06n6+nolJibqtddeU0lJiVJTUyVJO3bsUF9fnzZu3Bh6Jp9Pv/rVr3To0CHFx8crMTFR5eXlWrFihd577z398pe/lCSdO3dOQ0NDSk9PlySVl5dr5cqVlvwsk0UAMSUl2Kbr4fr1ETn2m65Xrvt7E6bbVPJjt4XTDGvc7rTsWG73/89XW1ure+65JxjAiaiqqlJfX59+//vfa8aMGTpx4oTWrl2r2bNna/ny5Vq+fLmkiUX1RuMuMBAj3njjDVVXV0uSWltblZubq9bWVknDsaqvr5ck5ebm6tKlS3rllVfU1dWlJ598Uk6nUydPnpQkdXZ2qqysTMXFxVq3bp0uX7485ro+/fRT7du3T1VVVZoxY4YkacGCBXriiSf04osv3ogf1xIEEIgRy5YtU0tLiySppaVFS5Ys0QcffBD8etmyZaPWr1+/Xunp6XrhhRfkdrt1xx13SBrenWn79u3at2+fBgcH1djYOOa6Tpw4oezsbM2ZM2fU5YsXL9axY8ci
8NNFBgEEYsStt96qK1eu6OzZs2ppadGPfvQjtbS0yOv1amBgIOzXx917772aNWuW4uLitHDhQp05c2bMmlj5LDUCCMSQgoICvfvuu+rp6dHSpUvV3d2td999V1/96lfDPsbIXVpp+B0Wfr9/zJoFCxbozJkzOn/+/KjL//73vys3N/e657/RCCAQQwoKCrRz504tWbJEknTXXXeppqZmzN3fEbfccosuXLgw4euZN2+eiouLVVVVpStXrkgavlv86quvqry8/Pp/gBuMZ4EBC/kG/JY+Y/vfxw3nZTAFBQV6+umng8ErKChQfX29CgoKxl3/3e9+V88884wSExO1ffv2Cc1UVVWl7du36/7779f06dM1Y8YM/fSnP9XSpUsndJxoirsZPhi9o6NDhYWF+uMf/8iGqAjbzfAymKNHjyovLy8ic2Dirv7zCNUW7gIDMBYBBGAsAgjAWAQQgLEIIABjEUAAxiKAgIV8/oEpdVzT8UJowEKR2qZrMlt0Xc2q/QDDdfV1XC2a+woSQMAwVu0HGK5Q1xHNfQXDugvc3t4ul8uloqIiuVwunT59esyanp4erVu3TiUlJcH3CA4ODlo2KIDPF+n9AC9duqTNmzfL4XDI4XBo586dweu+7777dOLEiTFfX+s6RkR7X8GwAlhZWanS0lIdOHBApaWlqqioGLPm1VdfVU5OjhobG9XY2Kh//OMfOnjwoOUDAxhfpPcDfPnllzU0NKTGxka98cYbcrvd+tOf/vS5M13rOkZEe1/BkAHs6emRx+ORw+GQJDkcDnk8HvX29o5aFxcXp0uXLmloaEg+n08DAwPKyMiIzNQAxoj0foAtLS166KGHFBcXp6SkJD3wwAPB4F6vaG9FEDKAXq9XGRkZstmGd6Kw2WxKT0+X1+sdtW7Dhg1qb2/XvffeG/zvK1/5SmSmBjCuSO4HGAgExnwE6MjXNptNQ0NDwctHtsgKJdr7Clr2Mpj9+/crNzdXzc3N+vOf/6yPPvpI+/fvt+rwAMIQyf0A77nnHu3Zs0eBQEAXL17U3r17g8fNzs5WW1ubpOFbiufOnQvrOqK9r2DIZ4Htdrs6Ozvl9/uD/xp0dXXJbrePWldXV6fnnntO06ZNU3Jysu677z59+OGHKi4ujtjwwM3G5x+w9CUr/33ccD6qM5L7AW7YsEFbtmxRSUmJJOnBBx/U1772NUnSD3/4Q23atEkNDQ266667lJWVdc3ruPpxwGjuKxjWfoDf+c53tHr1ajmdTrndbu3Zs0evv/76qDVPPPGE8vPzVV5eLp/Pp8cff1wrV65UaWlpyCHYDxDXg/0AcbWI7AdYVVWluro6FRUVqa6uLvhUe1lZWfBm7zPPPKOPP/5YJSUlWrVqlebPn6+HH37Yip8JACIirBdC5+TkqKGhYczlNTU1wV9nZ2dr165d1k0GABHGe4GBSYr2Szkw7L+fhQ4XAQQmITExUT09PUQwigKBgHw+nz799FPdcsstE/pe3gsMTMK8efPU0dGh7u7uaI9itPj4eM2ePVtz586d2PdFaB7ACNOnT9dtt90W7TFwnbgLDMBYBBCAsQggAGMRQADGIoAAjEUAARiLAAIwFgEEYCwCCMBYBBCAsQggAGMRQADGIoAAjEUAARiLAAIwFgEEYCwCCMBYBBCAsQggAGMRQADGIoAAjEUAARiLAAIwFgEEYCwCCMBYBBCAsQggAGMRQADGIoAAjEUAARiLAAIwFgEEYCwCCMBYBBCAscIKYHt7u1wul4qKiuRyuXT69Olx1+3du1clJSVyOBwqKSnRuXPnrJwVACwVH86iyspKlZaWyul0yu12q6KiQrW1taPWtLW16cUXX9Rvf/tbpaWl6cKFC0pISIjI0ABghZC3AHt6euTxeORwOCRJDodDHo9Hvb29o9a99tpreuyxx5SWliZJSk5O1owZMyIwMgBYI2QAvV6vMjIyZLPZJEk2m03p6enyer2j1p06dUqffPKJ1qxZo29+
85t6+eWXFQgEIjM1AFggrLvA4fD7/Tp+/Lh27doln8+ntWvXKisrS6tWrbLqKgDAUiFvAdrtdnV2dsrv90saDl1XV5fsdvuodVlZWSouLlZCQoKSkpJUWFio1tbWyEwNABYIGcDU1FTl5eWpqalJktTU1KS8vDylpKSMWudwONTc3KxAIKCBgQF98MEH+tKXvhSZqQHAAmG9DKaqqkp1dXUqKipSXV2dqqurJUllZWVqa2uTJD3wwANKTU3V/fffr1WrVumOO+7Q6tWrIzc5AExSWI8B5uTkqKGhYczlNTU1wV9PmzZNmzdv1ubNm62bDgAiiHeCADAWAQRgLAIIwFgEEICxCCAAYxFAAMYigACMRQABGIsAAjAWAQRgLAIIwFgEEICxCCAAYxFAAMYigACMRQABGIsAAjAWAQRgLAIIwFgEEICxCCAAYxFAAMYigACMRQABGIsAAjAWAQRgLAIIwFgEEICxCCAAYxFAAMYigACMRQABGIsAAjAWAQRgLAIIwFgEEICxCCAAY4UVwPb2drlcLhUVFcnlcun06dPXXPuvf/1LixYt0rZt26yaEVOQb8Af7RGAkOLDWVRZWanS0lI5nU653W5VVFSotrZ2zDq/36/KykqtWLHC8kExtSRMt6nkx+6IHb9xuzNix4Y5Qt4C7OnpkcfjkcPhkCQ5HA55PB719vaOWbtz50594xvf0Pz58y0fFACsFjKAXq9XGRkZstlskiSbzab09HR5vd5R644dO6bm5mZ973vfi8igAGC1sO4ChzIwMKBnn31WW7duDYYSAG52IQNot9vV2dkpv98vm80mv9+vrq4u2e324Jru7m6dOXNG69atkyR99tlnCgQCunjxorZs2RK56QFgEkIGMDU1VXl5eWpqapLT6VRTU5Py8vKUkpISXJOVlaUPP/ww+PWOHTvU19enjRs3RmZqALBAWC+DqaqqUl1dnYqKilRXV6fq6mpJUllZmdra2iI6IABESliPAebk5KihoWHM5TU1NeOu/8EPfjC5qQDgBuCdIACMRQABGIsAAjAWAQRgLAIIwFgEEICxCCAAYxFAIEZEeg/GWNzj0ZLNEABEH3swThy3AAEYiwACMBYBBGAsAgjAWAQQgLEIIABjEUAAxiKAAIxFAAEYiwACMBYBBGAsAgjAWAQQgLEIIABjEUAAxiKAAIxFAAEYiwACMBYBBGAsAgjAWAQQgLGMCCAfFwhgPEZ8LCYfFwhgPEbcAgSA8RBAAMYigACMRQABGIsAAjBWWM8Ct7e3a9OmTTp//rzmzJmjbdu2af78+aPWvPTSS9q7d69sNpvi4+P11FNPafny5ZGYGQAsEVYAKysrVVpaKqfTKbfbrYqKCtXW1o5as3DhQj322GOaOXOmjh07pkcffVTNzc1KTEyMyOAAMFkh7wL39PTI4/HI4XBIkhwOhzwej3p7e0etW758uWbOnClJys3NVSAQ0Pnz562fGAAsEjKAXq9XGRkZstlskiSbzab09HR5vd5rfs8777yj7OxsZWZmWjcpAFjM8neC/OUvf9Hzzz+v3/zmN1YfGgAsFfIWoN1uV2dnp/z+4fe7+v1+dXV1yW63j1n7t7/9TT/5yU/00ksv6fbbb7d+WgCwUMgApqamKi8vT01NTZKkpqYm5eXlKSUlZdS61tZWPfXUU3rhhRf05S9/OTLTAoCFwnodYFVVlerq6lRUVKS6ujpVV1dLksrKytTW1iZJqq6uVn9/vyoqKuR0OuV0OnX8+PHITQ4AkxTWY4A5OTlqaGgYc3lNTU3w12+99ZZ1UwHADcA7QQAYiwACMBYBBGAsAgjAWAQQgLEIIABjEUAAxiKAAIxFAAEYiwACMBYBBGAsAgjAWAQQgLEIIABjEUAAxiKAAIxFAAEYiwACMBYBBGAsAgjAWAQQgLEIIABjEUAAxiKAAIxFAAEYiwACMBYBtIDPPzAljw2YLj7aA8SCBNt0PVy/PiLHftP1SkSOC4BbgAAM
RgABGIsAAjAWAQQQllh8so8nQQCEJRaf7OMWIABjEUAAxiKAAIxFAAEYiwACMFZYAWxvb5fL5VJRUZFcLpdOnz49Zo3f71d1dbVWrFihlStXqqGhwepZAcBSYQWwsrJSpaWlOnDggEpLS1VRUTFmTWNjo86cOaODBw+qvr5eO3bsUEdHh+UDA4BVQr4OsKenRx6PR7t27ZIkORwObdmyRb29vUpJSQmu27t3rx566CFNmzZNKSkpWrFihfbv36+1a9eGHMLv90uSzp49e70/R0gDfb0RO3ZHR4d8/3s5YseeqjjnNx7nfLSRpow05mohA+j1epWRkSGbzSZJstlsSk9Pl9frHRVAr9errKys4Nd2uz3soHV3d0uS1qxZE9b6m03h4Z9H7tj/UxixY09lnPMbbyqf8+7ubt16661jLr8p3gmSn5+v3bt3Ky0tLRhaAJgsv9+v7u5u5efnj/v7IQNot9vV2dkpv98vm80mv9+vrq4u2e32Mev+/e9/a+HChZLG3iL8PImJibr77rvDWgsAEzHeLb8RIZ8ESU1NVV5enpqamiRJTU1NysvLG3X3V5KKi4vV0NCgoaEh9fb26tChQyoqKprk6AAQOXGBQCAQatGpU6e0adMmffbZZ5o1a5a2bdum22+/XWVlZXryySd15513yu/362c/+5nef/99SVJZWZlcLlfEfwAAuF5hBRAAYhHvBAFgLAIIwFgEEICxCCAAYxFAAMYigACMRQABGIsAAjAWAZwEXkN+43HOb7xYPucEcII6Ojp07NgxSVJcXFyUpzHDeOd8aGgomiPFvPHO+bX21JvKeCtcmAYHB/Xzn/9chw8f1qxZs/Ttb39bq1evZvuuCBrvnH/rW99SfHx88PdHfg1rhPp7HmvnnFuAYWpra9OFCxd0+PBhbd26VXv27JHH45Ek7d69W++9956k2PxXMlrGO+cnT56UNLwD+bPPPqt33nlHUmzfTbuRxjvnx48flyQdOHBAGzduDH7eTyyccwIYwsgfckdHR3CvwytXrujUqVM6ePCgPvnkEx04cEA1NTWShnfMjoW/GNF0rXN+8uRJ7dmzR83NzWpra9Ojjz6qP/zhD/rPf/7DwxGT9Hnn3O126+OPP9aJEydUXl6uv/71r+rr64uJc85d4DA999xzmjlzpi5cuKB//vOfWrRokdxut/bt26ekpCT94he/0Be/+EU98sgjMXc3IVquPudLlizR7373O7355pvBzXZ//etf6/vf/35ww15MztXnfPHixXr77bf19ttvKy0tTY2NjTp06JAWLlyoBx98UGlpadEeeVK4BRjCyL8PBQUFqq2tVWZmpl5//XWtX79eixYtCn6eyYYNG7R//37t2rVLR44ciebIU961zvm6deu0ePHi4MMMR48e1eXLl+V2u3X06NFojjzlXeucP/7441q8eLF8Pp8uX76spUuX6vnnn9fRo0fV19cX5aknjwCGMHIzPzMzUw6HQ8uWLZMkdXV1qa2tTXPnzlUgENDFixd17tw5dXV1acGCBdEcecq71jnv7u5WW1ubZs+eLWn4ManDhw/rypUruu2226I2bywIdc6Tk5M1c+ZMnT17Vlu3blV+fr4yMzOjObIluJ8WpszMTCUlJcntdis3N1fnz59XYWGhkpOTJUlJSUnasWMH/yNa6FrnfNasWfL5fFqwYIEKCwt15513RnvUmPF557y/v199fX165JFHYubvOY8BTsBHH32kt956S62trRocHNTTTz+twsLRH+cXCAQUCAQ0bRo3rq0Q7jmXeF2mVcI559LweZ/q55wATtDg4KDef/99ff3rX4/2KMbgnN94ppxzAjgBV/+LxzOPkcc5v/FMOucEEICxeKAKgLEIIABjEUAAxiKAAIxFAAEYiwACMBYBBGAsAgjAWP8HimDX59TKOMMAAAAASUVORK5CYII=", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ],
   "source": [
    "def selection_share(fname, num_states=3):\n",
    "    \"\"\"Share of rows whose largest value among the last `num_states` columns falls on each column.\n",
    "\n",
    "    Args:\n",
    "        fname: path of a pickled prediction frame.\n",
    "        num_states: how many trailing columns to compare (presumably one per predictor -- confirm).\n",
    "\n",
    "    Returns:\n",
    "        Array of length `num_states` that sums to 1.\n",
    "    \"\"\"\n",
    "    chosen = pd.read_pickle(fname).iloc[:, -num_states:].values.argmax(axis=1)\n",
    "    # counting the winners directly gives the same shares as summing a one-hot matrix\n",
    "    return np.bincount(chosen, minlength=num_states) / len(chosen)\n",
    "\n",
    "res = pd.DataFrame({\n",
    "    'with OT': selection_share('output/search/finetune/Transformer_tra/K3_traHs16_traSrcLR_TPE_traLamb2.0_head4_hs64_bs512_do0.1_lr0.0005_seed3000/pred.pkl'),\n",
    "    'without OT': selection_share('output/search/finetune/Transformer_tra/K3_traHs16_traSrcLR_TPE_traLamb0.0_head4_hs64_bs512_do0.1_lr0.0005_seed3000/pred.pkl')\n",
    "},index=[r'$\\theta_1$',r'$\\theta_2$',r'$\\theta_3$'])\n",
    "res.plot.bar(rot=30, figsize=(5, 4), color=['b', 'g']);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# RQ4\n",
    "\n",
    "You could prepare the source data for this test as below:\n",
    "1. K=1: which is exactly the ALSTM model\n",
    "2. K=3: Setting `num_states` = 3\n",
    "3. K=5: Setting `num_states` = 5\n",
    "4. K=10: Setting `num_states` = 10\n",
    "5. 
K=20: Setting `num_states` = 20\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# one entry per number of predictors K; each glob collects the info.json of every seed\n",
    "exps = {\n",
    "    'K=1': glob.glob('output/search/LSTM_Attn/hs256_bs1024_do0.1_lr0.0002_seed*/info.json'),\n",
    "    'K=3': glob.glob('output/search/finetune/LSTM_Attn_tra/K3_traHs16_traSrcLR_TPE_traLamb2.0_hs256_bs1024_do0.1_lr0.0001_seed*/info.json'),\n",
    "    'K=5': glob.glob('output/search/finetune/LSTM_Attn_tra/K5_traHs16_traSrcLR_TPE_traLamb2.0_hs256_bs1024_do0.1_lr0.0001_seed*/info.json'),\n",
    "    'K=10': glob.glob('output/search/finetune/LSTM_Attn_tra/K10_traHs16_traSrcLR_TPE_traLamb2.0_hs256_bs1024_do0.1_lr0.0001_seed*/info.json'),\n",
    "    'K=20': glob.glob('output/search/finetune/LSTM_Attn_tra/K20_traHs16_traSrcLR_TPE_traLamb2.0_hs256_bs1024_do0.1_lr0.0001_seed*/info.json')\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _seed_metrics(path):\n",
    "    \"\"\"Read the IC and MSE entries stored under `metric` in one info.json file.\"\"\"\n",
    "    with open(path) as fh:\n",
    "        metric = json.load(fh)['metric']\n",
    "    return {'IC': metric['IC'], 'MSE': metric['MSE']}\n",
    "\n",
    "# rows: settings of K; columns: IC / MSE averaged over the files of each setting\n",
    "report = pd.DataFrame({\n",
    "    setting: pd.DataFrame({path: _seed_metrics(path) for path in paths}).T.mean()\n",
    "    for setting, paths in exps.items()\n",
    "}).T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "",
      "text/plain": [
       "
" ] }, "metadata": {}, "output_type": "display_data" } ],
   "source": [
    "# side-by-side bar charts of the two metrics per K; the y-ranges are zoomed in on the observed values\n",
    "fig, axes = plt.subplots(1, 2, figsize=(6,3))\n",
    "ax_ic, ax_mse = axes\n",
    "\n",
    "report['IC'].plot.bar(rot=30, ax=ax_ic)\n",
    "ax_ic.set_ylim(0.045, 0.062)\n",
    "ax_ic.set_title('IC performance')\n",
    "\n",
    "report['MSE'].astype(float).plot.bar(rot=30, ax=ax_mse, color='green')\n",
    "ax_mse.set_ylim(0.155, 0.1585)\n",
    "ax_mse.set_title('MSE performance')\n",
    "\n",
    "plt.tight_layout()\n",
    "# plt.savefig('sensitivity.pdf')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
ICMSE
K=10.0532470.157792
K=30.0555350.157410
K=50.0592240.156796
K=100.0594030.156766
K=200.0591930.156801
\n", "
" ], "text/plain": [ " IC MSE\n", "K=1 0.053247 0.157792\n", "K=3 0.055535 0.157410\n", "K=5 0.059224 0.156796\n", "K=10 0.059403 0.156766\n", "K=20 0.059193 0.156801" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "report" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "interpreter": { "hash": "9de784e21d4a351f53a5792b09a6ae66a23802b850ad98f62e10c0156e418c04" }, "kernelspec": { "display_name": "Python 3.8.5 64-bit ('base': conda)", "name": "python3" }, "language_info": { "name": "python", "version": "" } }, "nbformat": 4, "nbformat_minor": 5 }