{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Learning Molecular Representation using Graph Neural Network - Training Molecular Graph\n", "\n", "> Taking a look at how graph neural networks operate for molecular representations\n", "\n", "- toc: true \n", "- badges: true\n", "- comments: true\n", "- categories: [rdkit, machine learning, graph neural network]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Motivation\n", "\n", "In the [previous post](https://sunhwan.github.io/blog/2021/02/20/Learning-Molecular-Representation-Using-Graph-Neural-Network-Molecular-Graph.html), I looked into how a molecular graph is constructed and how messages are passed around in an MPNN architecture. In this post, I'll take a look at how the graph neural net can be trained. The details of the message passing update vary by implementation; here we follow the one used in this [paper](http://dx.doi.org/10.1021/acs.jcim.9b00237).\n", "\n", "Again, many code examples were taken from the [chemprop](https://github.com/chemprop/chemprop) repository, and I edited them for the sake of simplicity." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Recap\n", "\n", "![feature](files/molecular_graph_features.png)\n", "\n", "Let's briefly recap how the molecular graph is constructed and how messages are passed around. As shown in the figure above, atoms and bonds are labeled $x$ and $e$, respectively. Here we are discussing D-MPNN, which represents the graph with directed edges: for each bond between two atoms $v$ and $w$, there are two directed bonds, $e_{vw}$ and $e_{wv}$. \n", "\n", "The initial hidden message is constructed as $h_{vw}^0 = \\tau (W_i \\mathrm{cat}(x_v, e_{vw}))$ where $W_i$ is a learned matrix and $\\tau$ is an activation function. 
\n", "\n", "![message](files/molecular_graph_message.png)\n", "\n", "The figure above shows how two messages, $m_{57}$ and $m_{54}$, are constructed. First, the message from atom $v$ to $w$ is the sum of the hidden states of all bonds incoming to $v$ (excluding the one originating from $w$). Then the message is multiplied by the learned matrix $W_m$, and the initial hidden message is added to form the new hidden message at depth 1. This step is repeated several times so that messages propagate over multiple depths. \n", "\n", "After the messages have been passed for the given number of depths, the incoming hidden states of each atom are summed into a final per-atom message, and the hidden state of each atom is computed as follows:\n", "\n", "\n", "$$m_v = \\sum_{k \\in N(v)} h_{kv}^t$$\n", "\n", "$$h_v = \\tau(W_a \\mathrm{cat} (x_v, m_v))$$\n", "\n", "Finally, the readout phase uses the sum of all $h_v$ to obtain the feature vector of the molecule, and property prediction is carried out using a fully-connected feed-forward network." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Train Data\n", "\n", "As an example, I'll use Enamine REAL's diversity discovery set, which is composed of 10240 compounds. This dataset contains some molecular properties, such as ClogP and TPSA, so we should be able to train a GCNN that predicts those properties.\n", "\n", "For this example, let's train using ClogP values." 
] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2020.03.2\n" ] } ], "source": [ "#collapse-hide\n", "%matplotlib inline\n", "import matplotlib.pyplot as plt\n", "import matplotlib\n", "\n", "from io import BytesIO\n", "import pandas as pd\n", "import numpy as np\n", "from IPython.display import SVG\n", "\n", "# RDKit \n", "import rdkit\n", "from rdkit.Chem import PandasTools\n", "from rdkit import Chem\n", "from rdkit.Chem import AllChem\n", "from rdkit.Chem import DataStructs\n", "from rdkit.Chem import rdMolDescriptors\n", "from rdkit.Chem import rdRGroupDecomposition\n", "from rdkit.Chem.Draw import IPythonConsole #Needed to show molecules\n", "from rdkit.Chem import Draw\n", "from rdkit.Chem import rdDepictor\n", "from rdkit.Chem.Draw import rdMolDraw2D\n", "from rdkit.Chem.Draw.MolDrawing import MolDrawing, DrawingOptions #Only needed if modifying defaults\n", "\n", "DrawingOptions.bondLineWidth=1.8\n", "IPythonConsole.ipython_useSVG=True\n", "from rdkit import RDLogger\n", "RDLogger.DisableLog('rdApp.warning')\n", "print(rdkit.__version__)\n", "\n", "# pytorch\n", "import torch\n", "from torch.utils.data import DataLoader, Dataset, Sampler\n", "from torch import nn\n", "\n", "# misc\n", "from typing import Dict, Iterator, List, Optional, Union, OrderedDict, Tuple\n", "from tqdm.notebook import tqdm\n", "from functools import reduce" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "#collapse-hide\n", "# we will define a class which holds the various parameters for D-MPNN\n", "class TrainArgs:\n", " smiles_column = None\n", " no_cuda = False\n", " gpu = None\n", " num_workers = 8\n", " batch_size = 50\n", " no_cache_mol = False\n", " dataset_type = 'regression'\n", " task_names = []\n", " seed = 0\n", " hidden_size = 300\n", " bias = False\n", " depth = 3\n", " dropout = 0.0\n", " undirected = False\n", " aggregation = 'mean'\n", " 
aggregation_norm = 100\n", " ffn_num_layers = 2\n", " ffn_hidden_size = 300\n", " init_lr = 1e-4\n", " max_lr = 1e-3\n", " final_lr = 1e-4\n", " num_lrs = 1\n", " warmup_epochs = 2.0\n", " epochs = 30\n", "\n", " @property\n", " def device(self) -> torch.device:\n", "     \"\"\"The :code:`torch.device` on which to load and process data and models.\"\"\"\n", "     if not self.cuda:\n", "         return torch.device('cpu')\n", "\n", "     return torch.device('cuda', self.gpu)\n", "\n", " @device.setter\n", " def device(self, device: torch.device) -> None:\n", "     self.cuda = device.type == 'cuda'\n", "     self.gpu = device.index\n", "\n", " @property\n", " def cuda(self) -> bool:\n", "     \"\"\"Whether to use CUDA (i.e., GPUs) or not.\"\"\"\n", "     return not self.no_cuda and torch.cuda.is_available()\n", "\n", " @cuda.setter\n", " def cuda(self, cuda: bool) -> None:\n", "     self.no_cuda = not cuda" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "args = TrainArgs()\n", "args.data_path = 'files/enamine_discovery_diversity_set_10240.csv'\n", "args.target_column = 'ClogP'\n", "args.smiles_column = 'SMILES'\n", "args.dataset_type = 'regression'\n", "args.task_names = [args.target_column]\n", "args.num_tasks = 1" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "|   | Name | SMILES | Catalog ID | PlateID | Well | MW (desalted) | ClogP | HBD | TPSA | RotBonds |\n",
"|---|---|---|---|---|---|---|---|---|---|---|\n",
"| 0 | NaN | CN(C(=O)NC1CCOc2ccccc21)C(c1ccccc1)c1ccccn1 | Z447596076 | 1186474-R-001 | A02 | 373.448 | 2.419 | 1 | 54.46 | 4 |\n",
"| 1 | NaN | Cn1cc(C(=O)N2CCC(OC3CCOC3)CC2)c(C2CC2)n1 | Z2180753156 | 1186474-R-001 | A03 | 319.399 | -0.570 | 0 | 56.59 | 4 |\n",
"| 2 | NaN | CC(=O)N(C)C1CCN(C(=O)c2ccccc2-c2ccccc2C(=O)O)CC1 | Z2295858832 | 1186474-R-001 | A04 | 380.437 | 0.559 | 1 | 77.92 | 4 |\n",
"| 3 | NaN | COCC1(CNc2cnccc2C#N)CCNCC1 | Z2030994006 | 1186474-R-001 | A05 | 260.335 | 0.902 | 2 | 69.97 | 5 |\n",
"| 4 | NaN | CCCCOc1ccc(-c2nnc3n2CCCC3)cc1OC | Z273627850 | 1186474-R-001 | A06 | 301.383 | 3.227 | 0 | 49.17 | 6 |\n", "