{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Introduction\n", "\n", "This notebook gives a short introduction to using Dask to parallelize model training, particularly when you have multiple learning tasks and want to train an individual model for each one.\n", "\n", "For brevity, I will not elaborate on the exact machine learning task here; instead, I will focus on the idioms we need in order to use Dask for this kind of work." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2\n", "%matplotlib inline\n", "%config InlineBackend.figure_format = 'retina'\n", "\n", "from dask.distributed import LocalCluster, Client\n", "import numpy as np\n", "import pandas as pd\n", "import janitor" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Instantiate a Dask Cluster\n", "\n", "Here, we instantiate a Dask cluster and connect a `client` to it. Calling `Client()` with no arguments does both in one step: it starts a `LocalCluster` on this machine and connects to it. Other cluster types, such as an `SGECluster` (from `dask-jobqueue`) or a `KubeCluster` (from `dask-kubernetes`), can be created instead and passed to `Client`." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/ericmjl/anaconda/envs/minimal-panel/lib/python3.7/site-packages/distributed/dashboard/core.py:72: UserWarning: \n", "Port 8787 is already in use. \n", "Perhaps you already have a cluster running?\n", "Hosting the diagnostics dashboard on a random port instead.\n", " warnings.warn(\"\\n\" + msg)\n" ] } ], "source": [ "client = Client()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Data Preprocessing\n", "\n", "We will now preprocess our data and get it into a shape suitable for machine learning." 
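The idiom this notebook builds toward, training one independent model per learning task, can be sketched as a fan-out over futures. This is a minimal illustration under stated assumptions, not the notebook's own code: `fit_line` (a one-feature least-squares fit) and `train_per_task` are hypothetical helpers, and the toy `tasks` data is invented. The helper relies only on a `submit()` method that returns futures with a `.result()` method. Both the standard library's `concurrent.futures` executors and a `dask.distributed.Client` provide this, so a `ThreadPoolExecutor` stands in for the cluster here.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, Dict, List, Tuple

# Hypothetical task data: task name -> list of (x, y) training pairs.
Points = List[Tuple[float, float]]


def fit_line(points: Points) -> Tuple[float, float]:
    # Hypothetical stand-in model: ordinary least squares for y = slope * x + intercept.
    n = len(points)
    mean_x = sum(x for x, _ in points) / n
    mean_y = sum(y for _, y in points) / n
    sxx = sum((x - mean_x) ** 2 for x, _ in points)
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in points)
    slope = sxy / sxx
    return slope, mean_y - slope * mean_x


def train_per_task(executor: Any, tasks: Dict[str, Points], train: Callable) -> Dict[str, Any]:
    # Submit one independent training job per task; `executor` is anything with a
    # concurrent.futures-style submit() (a ThreadPoolExecutor, or a dask Client).
    futures = {name: executor.submit(train, data) for name, data in tasks.items()}
    # Block until every job finishes and collect the fitted models by task name.
    return {name: future.result() for name, future in futures.items()}


# Toy data for two hypothetical tasks.
tasks = {
    'task_a': [(0.0, 1.0), (1.0, 3.0), (2.0, 5.0)],
    'task_b': [(0.0, 0.0), (1.0, -1.0), (2.0, -2.0)],
}

with ThreadPoolExecutor(max_workers=2) as pool:
    models = train_per_task(pool, tasks, fit_line)

print(models)  # {'task_a': (2.0, 1.0), 'task_b': (-1.0, 0.0)}
```

On the Dask cluster, the same call would be `models = train_per_task(client, tasks, fit_line)`. For large per-task inputs, it is usually better to send the data to the workers once with `client.scatter` and submit the resulting futures, rather than serializing the data with every call.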
] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "from utils import molecular_weights, featurize_sequence_" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# Drug columns (one numeric value per drug); these are log10-transformed below.\n", "drugs = ['ATV', 'DRV', 'FPV', 'IDV', 'LPV', 'NFV', 'SQV', 'TPV']\n", "\n", "data = (\n", "    pd.read_csv(\"data/hiv-protease-data-expanded.csv\", index_col=0)\n", "    .query(\"weight == 1.0\")  # keep only rows with weight 1.0\n", "    .transform_column(\"sequence\", len, \"seq_length\")  # store each sequence's length\n", "    .query(\"seq_length == 99\")  # keep only full-length (99-residue) sequences\n", "    .transform_column(\"sequence\", featurize_sequence_, \"features\")  # featurize each sequence\n", "    .transform_columns(drugs, np.log10)  # log10-transform each drug column\n", ")\n", "\n", "# Stack the per-row feature arrays into one DataFrame aligned with data's index.\n", "features = pd.DataFrame(np.vstack(data['features']), index=data.index)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
| | ATV | DRV | FPV | IDV | LPV | NFV | SQV | SeqID | TPV | seqid | sequence | sequence_object | weight | seq_length | features |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 6 | 1.50515 | NaN | 0.477121 | 1.544068 | 1.50515 | 1.462398 | 2.214844 | 4426 | NaN | 4426-0 | PQITLWQRPIVTIKIGGQLKEALLDTGADDTVLEEMNLPGKWKPKM... | ID: 4426-0\\nName: <unknown name>\\nDescription:... | 1.0 | 99 | [[115.131, 146.1451, 131.1736, 119.1197, 131.1... |
| 7 | NaN | NaN | 0.176091 | 0.000000 | NaN | 0.342423 | 0.041393 | 4432 | NaN | 4432-0 | PQITLWQRPLVTVKIGGQLKEALLDTGADDTVLEEMNLPGRWKPKM... | ID: 4432-0\\nName: <unknown name>\\nDescription:... | 1.0 | 99 | [[115.131, 146.1451, 131.1736, 119.1197, 131.1... |
| 14 | NaN | NaN | 0.491362 | 0.939519 | NaN | 1.505150 | 1.227887 | 4664 | NaN | 4664-0 | PQITLWQRPIVTIKVGGQLIEALLDTGADDTVLEEINLPGRWKPKM... | ID: 4664-0\\nName: <unknown name>\\nDescription:... | 1.0 | 99 | [[115.131, 146.1451, 131.1736, 119.1197, 131.1... |

| | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | ... | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 6 | 115.131 | 146.1451 | 131.1736 | 119.1197 | 131.1736 | 204.2262 | 146.1451 | 174.2017 | 115.131 | 131.1736 | ... | 131.1736 | 119.1197 | 146.1451 | 131.1736 | 75.0669 | 121.159 | 119.1197 | 131.1736 | 132.1184 | 165.19 |
| 7 | 115.131 | 146.1451 | 131.1736 | 119.1197 | 131.1736 | 204.2262 | 146.1451 | 174.2017 | 115.131 | 131.1736 | ... | 131.1736 | 119.1197 | 146.1451 | 131.1736 | 75.0669 | 121.159 | 119.1197 | 131.1736 | 132.1184 | 165.19 |
| 14 | 115.131 | 146.1451 | 131.1736 | 119.1197 | 131.1736 | 204.2262 | 146.1451 | 174.2017 | 115.131 | 131.1736 | ... | 149.2124 | 119.1197 | 146.1451 | 131.1736 | 75.0669 | 121.159 | 119.1197 | 131.1736 | 132.1184 | 165.19 |

3 rows × 99 columns
\n", "