{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Introduction\n", "\n", "This notebook gives a short introduction to using Dask to parallelize model training, particularly when you have multiple learning tasks and want to train an individual model for each.\n", "\n", "For brevity, I will not elaborate on the exact machine learning task here, but will focus on the idioms that we need in order to use Dask for this task." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2\n", "%matplotlib inline\n", "%config InlineBackend.figure_format = 'retina'\n", "\n", "from dask.distributed import LocalCluster, Client\n", "import numpy as np\n", "import pandas as pd\n", "import janitor  # pyjanitor: registers DataFrame methods such as transform_column" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Instantiate a Dask Cluster\n", "\n", "Here, we instantiate a Dask `cluster` (this is only a `LocalCluster`, but other cluster types can be created too, such as an `SGECluster` or `KubeCluster`). We then connect a `client` to the cluster. Calling `Client()` with no arguments does both in one step: it starts a `LocalCluster` and connects to it." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/ericmjl/anaconda/envs/minimal-panel/lib/python3.7/site-packages/distributed/dashboard/core.py:72: UserWarning: \n", "Port 8787 is already in use. \n", "Perhaps you already have a cluster running?\n", "Hosting the diagnostics dashboard on a random port instead.\n", " warnings.warn(\"\\n\" + msg)\n" ] } ], "source": [ "client = Client()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Data Preprocessing\n", "\n", "We will now preprocess our data and get it into a shape suitable for machine learning."
] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "from utils import molecular_weights, featurize_sequence_" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "drugs = ['ATV', 'DRV', 'FPV', 'IDV', 'LPV', 'NFV', 'SQV', 'TPV']\n", "\n", "data = (\n", " pd.read_csv(\"data/hiv-protease-data-expanded.csv\", index_col=0)\n", " .query(\"weight == 1.0\")\n", " .transform_column(\"sequence\", lambda x: len(x), \"seq_length\")\n", " .query(\"seq_length == 99\")\n", " .transform_column(\"sequence\", featurize_sequence_, \"features\")\n", " .transform_columns(drugs, np.log10)\n", ")\n", "\n", "features = pd.DataFrame(np.vstack(data['features'])).set_index(data.index)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " ATV DRV FPV IDV LPV NFV SQV SeqID TPV \\\n", "6 1.50515 NaN 0.477121 1.544068 1.50515 1.462398 2.214844 4426 NaN \n", "7 NaN NaN 0.176091 0.000000 NaN 0.342423 0.041393 4432 NaN \n", "14 NaN NaN 0.491362 0.939519 NaN 1.505150 1.227887 4664 NaN \n", "\n", " seqid sequence \\\n", "6 4426-0 PQITLWQRPIVTIKIGGQLKEALLDTGADDTVLEEMNLPGKWKPKM... \n", "7 4432-0 PQITLWQRPLVTVKIGGQLKEALLDTGADDTVLEEMNLPGRWKPKM... \n", "14 4664-0 PQITLWQRPIVTIKVGGQLIEALLDTGADDTVLEEINLPGRWKPKM... \n", "\n", " sequence_object weight seq_length \\\n", "6 ID: 4426-0\\nName: \\nDescription:... 1.0 99 \n", "7 ID: 4432-0\\nName: \\nDescription:... 1.0 99 \n", "14 ID: 4664-0\\nName: \\nDescription:... 1.0 99 \n", "\n", " features \n", "6 [[115.131, 146.1451, 131.1736, 119.1197, 131.1... \n", "7 [[115.131, 146.1451, 131.1736, 119.1197, 131.1... \n", "14 [[115.131, 146.1451, 131.1736, 119.1197, 131.1... " ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data.head(3)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " 0 1 2 3 4 5 6 \\\n", "6 115.131 146.1451 131.1736 119.1197 131.1736 204.2262 146.1451 \n", "7 115.131 146.1451 131.1736 119.1197 131.1736 204.2262 146.1451 \n", "14 115.131 146.1451 131.1736 119.1197 131.1736 204.2262 146.1451 \n", "\n", " 7 8 9 ... 89 90 91 92 \\\n", "6 174.2017 115.131 131.1736 ... 131.1736 119.1197 146.1451 131.1736 \n", "7 174.2017 115.131 131.1736 ... 131.1736 119.1197 146.1451 131.1736 \n", "14 174.2017 115.131 131.1736 ... 149.2124 119.1197 146.1451 131.1736 \n", "\n", " 93 94 95 96 97 98 \n", "6 75.0669 121.159 119.1197 131.1736 132.1184 165.19 \n", "7 75.0669 121.159 119.1197 131.1736 132.1184 165.19 \n", "14 75.0669 121.159 119.1197 131.1736 132.1184 165.19 \n", "\n", "[3 rows x 99 columns]" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "features.head(3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Define training functions\n", "\n", "When writing code to interface with Dask, a functional paradigm is often preferred. Hence, we will write the procedures that are needed inside functions that can be submitted by the `client` to the `cluster`." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "from utils import featurize_sequence_, fit_model, cross_validate, predict" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, we'll scatter the data around the workers. `dataf` is named as such because this is the \"data futures\", a \"promise\" to the workers that `data` will exist for them and that they can access it. Likewise for `featuresf`." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "dataf = client.scatter(data)\n", "featuresf = client.scatter(features)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, we fit the models, and collect their cross-validated scores." 
] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "models = dict()\n", "scores = dict()\n", "\n", "\n", "for drug in drugs:\n", " models[drug] = client.submit(fit_model, dataf, featuresf, drug)\n", " scores[drug] = client.submit(cross_validate, dataf, featuresf, drug)\n", " \n", "models = client.gather(models)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, let's save the models. To save space on disk, we will pickle and gzip them." ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "import pickle as pkl\n", "import gzip\n", "\n", "for name, model in models.items():\n", " with gzip.open(f\"data/models/{name}.pkl.gz\", 'wb') as f:\n", " pkl.dump(model, f)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "scores = client.gather(scores)\n", "with gzip.open(\"data/scores.pkl.gz\", \"wb\") as f:\n", " pkl.dump(scores, f)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "minimal-panel", "language": "python", "name": "minimal-panel" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.3" } }, "nbformat": 4, "nbformat_minor": 4 }