{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# scatter_selector widget\n", "\n", "A set of custom matplotlib widgets that allow you to select points on a scatter plot and use them as input to other interactive plots. There are three variants that differ only in what they pass to their callbacks:\n", "\n", "1. `scatter_selector`: callbacks will receive `index, (x, y)`, where `index` is the position of the point in the array of points.\n", "2. `scatter_selector_values`: callbacks will receive `x, y`\n", "3. `scatter_selector_index`: callbacks will receive `index`\n", "\n", "\n", "In this example we will use `scatter_selector_index` along with the `indexer` convenience function to make line plots of stock data. However, you can use custom functions for the interactive plots, or even attach your own callbacks to the scatter_selector widgets.\n", "\n", "\n", "## PCA of Stock Data\n", "\n", "For this example we will make an interactive visualization of the companies in the S&P 500, placing each company on a scatter plot according to its first two principal components from principal component analysis ([PCA](https://towardsdatascience.com/a-one-stop-shop-for-principal-component-analysis-5582fb7e0a9c)). 
The data was originally obtained from https://www.kaggle.com/camnugent/sandp500 and was cleaned using code derived from https://github.com/Hekstra-Lab/scientific-python-bootcamp/tree/master/day3\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%matplotlib ipympl\n", "import pickle\n", "\n", "import ipywidgets as widgets\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "\n", "import mpl_interactions.ipyplot as iplt\n", "from mpl_interactions import panhandler, zoom_factory\n", "from mpl_interactions.utils import indexer\n", "from mpl_interactions.widgets import scatter_selector_index" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Data loading/cleaning\n", "\n", "For this example we have pre-cleaned data that we will just load. If you are curious about how the data was originally processed, you can see the full code at the bottom of this notebook.\n", "\n", "The data files that we load for this example are available for download at https://github.com/ianhi/mpl-interactions/tree/master/examples/data" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "with open(\"data/stock-metadata.pickle\", \"rb\") as f:\n", "    meta = pickle.load(f)\n", "prices = np.load(\"data/stock-prices.npz\")[\"prices\"]\n", "names = meta[\"names\"]\n", "# only plot the ones for which we were able to parse sector info\n", "good_idx = meta[\"good_idx\"]\n", "data_colors = meta[\"data_colors\"]\n", "\n", "# calculate the daily price difference\n", "price_changes = np.diff(prices)\n", "\n", "# standardize each company's price changes to zero mean and unit variance\n", "normalized_price_changes = price_changes - price_changes.mean(axis=-1, keepdims=True)\n", "normalized_price_changes /= price_changes.std(axis=-1, keepdims=True)\n", "\n", "# calculate the covariance matrix\n", "covariance = np.cov(normalized_price_changes.T)\n", "\n", 
"# Calculate the eigenvectors (i.e. the principal components)\n", "evals, evecs = np.linalg.eig(covariance)\n", "evecs = np.real(evecs)\n", "\n", "# project the companies onto the principal components\n", "transformed = normalized_price_changes @ evecs\n", "\n", "# take only the first two components for plotting\n", "# we also take only the subset of companies for which it was easy to extract a sector and a name\n", "x, y = transformed[good_idx][:, 0], transformed[good_idx][:, 1]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Making the plot\n", "\n", "We create the left scatter plot using `scatter_selector_index`, which will tell us the index of the company that was clicked on. Since this is just a Matplotlib `AxesWidget`, it can be passed directly to `iplt.plot` as a kwarg and the `controls` object will handle it appropriately.\n", "\n", "In this example we also make use of the function `mpl_interactions.utils.indexer`. This is a convenience function that handles indexing an array for you. 
So these two statements are equivalent:\n", "\n", "```python\n", "# set up data\n", "arr = np.random.randn(4, 100).cumsum(-1)\n", "\n", "def f(idx):\n", "    return arr[idx]\n", "\n", "iplt.plot(f, idx=np.arange(4))\n", "\n", "# or equivalently\n", "iplt.plot(indexer(arr), idx=np.arange(4))\n", "```" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "gif": "scatter-selector-stocks.apng" }, "outputs": [], "source": [ "fig, axs = plt.subplots(1, 2, figsize=(10, 5), gridspec_kw={\"width_ratios\": [1.5, 1]})\n", "index = scatter_selector_index(axs[0], x, y, c=data_colors, cmap=\"tab20\")\n", "\n", "# plot all the stock traces in light gray on the current axes (the right subplot)\n", "plt.plot(prices.T, color=\"k\", alpha=0.05)\n", "\n", "# add interactive components to the subplot on the right\n", "# note the use of indexer\n", "controls = iplt.plot(indexer(prices), idx=index, color=\"r\")\n", "iplt.title(indexer(names), controls=controls[\"idx\"])\n", "\n", "# styling + zooming\n", "axs[0].set_xlabel(\"PC-1\")\n", "axs[0].set_ylabel(\"PC-2\")\n", "axs[1].set_xlabel(\"days\")\n", "axs[1].set_ylabel(\"Price in $\")\n", "axs[1].set_yscale(\"log\")\n", "cid = zoom_factory(axs[0])\n", "ph = panhandler(fig)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Data cleaning\n", "\n", "Below is the code we used to clean and save the datasets. Although we start out with 500 companies, we end up with only 468, because some of them could not be parsed easily and correctly and were discarded." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# NBVAL_SKIP\n", "# Download the data from https://www.kaggle.com/camnugent/sandp500\n", "# and save it into a folder named `data`\n", "import glob\n", "\n", "test = np.loadtxt(\"data/A_data.csv\", delimiter=\",\", skiprows=1, usecols=1)\n", "sp500_glob = glob.glob(\"data/*.csv\")\n", "names = []\n", "prices = np.zeros((len(sp500_glob), test.shape[0]))\n", "prices_good = []\n", "fails = []\n", "for i, f in enumerate(sp500_glob):\n", "    fname = f.split(\"/\")[-1]\n", "    names.append(fname.split(\"_\")[0])\n", "    try:\n", "        prices[i] = np.loadtxt(f, delimiter=\",\", skiprows=1, usecols=1)\n", "        prices_good.append(True)\n", "    except Exception:\n", "        # skip files that cannot be parsed or have a different number of rows\n", "        fails.append(fname.split(\"_\")[0])\n", "        prices_good.append(False)\n", "prices = prices[prices_good]\n", "# `real_names` was never defined in the original code; we assume it holds\n", "# the names of the companies whose prices loaded successfully\n", "real_names = np.array(names)[prices_good]\n", "np.savez_compressed(\"data/stock-prices.npz\", prices=prices)\n", "\n", "# processing names and sector info\n", "\n", "arr = np.loadtxt(\n", "    \"data/SP500_names.csv\", delimiter=\"|\", skiprows=1, dtype=str, encoding=\"utf-8\"\n", ")\n", "name_dict = {a[0].strip(): a[[1, 2, 3]] for a in arr}\n", "# idx_to_info = {i:name_dict[real_names[i]] for i in range(468)}\n", "good_names = []\n", "primary = []\n", "secondary = []\n", "good_idx = np.zeros(real_names.shape[0], dtype=bool)\n", "for i, name in enumerate(real_names):\n", "    try:\n", "        info = name_dict[name]\n", "        good_idx[i] = True\n", "        good_names.append(info[0])\n", "        primary.append(info[1])\n", "        secondary.append(info[2])\n", "    except KeyError:\n", "        # no name/sector info available for this company\n", "        pass\n", "psector_dict = {val: i for i, val in enumerate(np.unique(primary))}\n", "data_colors = np.array([psector_dict[val] for val in primary], dtype=int)\n", "import pickle\n", "\n", "meta = {\n", "    \"good_idx\": good_idx,\n", "    \"names\": good_names,\n", "    \"sector\": psector_dict,\n", "    \"data_colors\": data_colors,\n", "}\n", "with open(\"data/stock-metadata.pickle\", \"wb\") as outfile:\n", "    pickle.dump(meta, 
outfile)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.8" } }, "nbformat": 4, "nbformat_minor": 4 }