{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Learning Molecular Representation using Graph Neural Network - Training Molecular Graph\n", "\n", "> Taking a look at how graph neural networks operate for molecular representations\n", "\n", "- toc: true \n", "- badges: true\n", "- comments: true\n", "- categories: [rdkit, machine learning, graph neural network]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Motivation\n", "\n", "In the [previous post](https://sunhwan.github.io/blog/2021/02/20/Learning-Molecular-Representation-Using-Graph-Neural-Network-Molecular-Graph.html), I looked into how a molecular graph is constructed and how messages are passed around in an MPNN architecture. In this post, I'll take a look at how the graph neural net can be trained. The details of the message passing update vary by implementation; here we use the one from this [paper](http://dx.doi.org/10.1021/acs.jcim.9b00237).\n", "\n", "Again, many code examples were taken from the [chemprop](https://github.com/chemprop/chemprop) repository and edited for the sake of simplicity." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Recap\n", "\n", "![feature](files/molecular_graph_features.png)\n", "\n", "Let's briefly recap how the molecular graph is constructed and how messages are passed around. As shown in the figure above, atoms and bonds are labeled $x$ and $e$, respectively. Here we are discussing D-MPNN, which represents the graph with directed edges: for each bond between two atoms $v$ and $w$, there are two directed bonds, $e_{vw}$ and $e_{wv}$. \n", "\n", "The initial hidden message is constructed as $h_{vw}^0 = \\tau (W_i \\mathrm{cat}(x_v, e_{vw}))$ where $W_i$ is a learned matrix and $\\tau$ is an activation function. 
\n", "\n", "![message](files/molecular_graph_message.png)\n", "\n", "The figure above shows how two messages, $m_{57}$ and $m_{54}$, are constructed. First, the message from atom $v$ to $w$ is the sum of the hidden states of all bonds incoming to $v$ (excluding the one originating from $w$). Then the message is multiplied by the learned matrix $W_m$ and the initial hidden message is added to form the new hidden message at depth 1. This is repeated several times so that messages propagate over multiple depths. \n", "\n", "After message passing has run for the given number of depths, the hidden states of all incoming bonds are summed into a final message for each atom, and the hidden state of each atom is computed as follows:\n", "\n", "\n", "$$m_v = \\sum_{k \\in N(v)} h_{kv}^t$$\n", "\n", "$$h_v = \\tau(W_a \\mathrm{cat} (x_v, m_v))$$\n", "\n", "Finally, the readout phase sums all $h_v$ to obtain the feature vector of the molecule (the code in this post defaults to `aggregation = 'mean'`, which divides that sum by the number of atoms), and property prediction is carried out by a fully-connected feed-forward network." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Train Data\n", "\n", "As an example, I'll use Enamine Real's diversity discovery set, which is composed of 10240 compounds. This dataset contains some molecular properties, such as ClogP and TPSA, so we should be able to train a graph neural network that predicts those properties.\n", "\n", "For this example, let's train using ClogP values." 
] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2020.03.2\n" ] } ], "source": [ "#collapse-hide\n", "%matplotlib inline\n", "import matplotlib.pyplot as plt\n", "import matplotlib\n", "\n", "from io import BytesIO\n", "import pandas as pd\n", "import numpy as np\n", "from IPython.display import SVG\n", "\n", "# RDKit \n", "import rdkit\n", "from rdkit.Chem import PandasTools\n", "from rdkit import Chem\n", "from rdkit.Chem import AllChem\n", "from rdkit.Chem import DataStructs\n", "from rdkit.Chem import rdMolDescriptors\n", "from rdkit.Chem import rdRGroupDecomposition\n", "from rdkit.Chem.Draw import IPythonConsole #Needed to show molecules\n", "from rdkit.Chem import Draw\n", "from rdkit.Chem import rdDepictor\n", "from rdkit.Chem.Draw import rdMolDraw2D\n", "from rdkit.Chem.Draw.MolDrawing import MolDrawing, DrawingOptions #Only needed if modifying defaults\n", "\n", "DrawingOptions.bondLineWidth=1.8\n", "IPythonConsole.ipython_useSVG=True\n", "from rdkit import RDLogger\n", "RDLogger.DisableLog('rdApp.warning')\n", "print(rdkit.__version__)\n", "\n", "# pytorch\n", "import torch\n", "from torch.utils.data import DataLoader, Dataset, Sampler\n", "from torch import nn\n", "\n", "# misc\n", "from typing import Dict, Iterator, List, Optional, Union, OrderedDict, Tuple\n", "from tqdm.notebook import tqdm\n", "from functools import reduce" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "#collapse-hide\n", "# we will define a class which holds various parameter for D-MPNN\n", "class TrainArgs:\n", " smiles_column = None\n", " no_cuda = False\n", " gpu = None\n", " num_workers = 8\n", " batch_size = 50\n", " no_cache_mol = False\n", " dataset_type = 'regression'\n", " task_names = []\n", " seed = 0\n", " hidden_size = 300\n", " bias = False\n", " depth = 3\n", " dropout = 0.0\n", " undirected = False\n", " aggregation = 'mean'\n", " 
aggregation_norm = 100\n", " ffn_num_layers = 2\n", " ffn_hidden_size = 300\n", " init_lr = 1e-4\n", " max_lr = 1e-3\n", " final_lr = 1e-4\n", " num_lrs = 1\n", " warmup_epochs = 2.0\n", " epochs = 30\n", "\n", " @property\n", " def device(self) -> torch.device:\n", " \"\"\"The :code:`torch.device` on which to load and process data and models.\"\"\"\n", " if not self.cuda:\n", " return torch.device('cpu')\n", "\n", " return torch.device('cuda', self.gpu)\n", "\n", " @device.setter\n", " def device(self, device: torch.device) -> None:\n", " self.cuda = device.type == 'cuda'\n", " self.gpu = device.index\n", "\n", " @property\n", " def cuda(self) -> bool:\n", " \"\"\"Whether to use CUDA (i.e., GPUs) or not.\"\"\"\n", " return not self.no_cuda and torch.cuda.is_available()\n", "\n", " @cuda.setter\n", " def cuda(self, cuda: bool) -> None:\n", " self.no_cuda = not cuda" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "args = TrainArgs()\n", "args.data_path = 'files/enamine_discovery_diversity_set_10240.csv'\n", "args.target_column = 'ClogP'\n", "args.smiles_column = 'SMILES'\n", "args.dataset_type = 'regression'\n", "args.task_names = [args.target_column]\n", "args.num_tasks = 1" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
"<table>\n",
"<thead>\n",
"<tr><th></th><th>Name</th><th>SMILES</th><th>Catalog ID</th><th>PlateID</th><th>Well</th><th>MW (desalted)</th><th>ClogP</th><th>HBD</th><th>TPSA</th><th>RotBonds</th></tr>\n",
"</thead>\n",
"<tbody>\n",
"<tr><th>0</th><td>NaN</td><td>CN(C(=O)NC1CCOc2ccccc21)C(c1ccccc1)c1ccccn1</td><td>Z447596076</td><td>1186474-R-001</td><td>A02</td><td>373.448</td><td>2.419</td><td>1</td><td>54.46</td><td>4</td></tr>\n",
"<tr><th>1</th><td>NaN</td><td>Cn1cc(C(=O)N2CCC(OC3CCOC3)CC2)c(C2CC2)n1</td><td>Z2180753156</td><td>1186474-R-001</td><td>A03</td><td>319.399</td><td>-0.570</td><td>0</td><td>56.59</td><td>4</td></tr>\n",
"<tr><th>2</th><td>NaN</td><td>CC(=O)N(C)C1CCN(C(=O)c2ccccc2-c2ccccc2C(=O)O)CC1</td><td>Z2295858832</td><td>1186474-R-001</td><td>A04</td><td>380.437</td><td>0.559</td><td>1</td><td>77.92</td><td>4</td></tr>\n",
"<tr><th>3</th><td>NaN</td><td>COCC1(CNc2cnccc2C#N)CCNCC1</td><td>Z2030994006</td><td>1186474-R-001</td><td>A05</td><td>260.335</td><td>0.902</td><td>2</td><td>69.97</td><td>5</td></tr>\n",
"<tr><th>4</th><td>NaN</td><td>CCCCOc1ccc(-c2nnc3n2CCCC3)cc1OC</td><td>Z273627850</td><td>1186474-R-001</td><td>A06</td><td>301.383</td><td>3.227</td><td>0</td><td>49.17</td><td>6</td></tr>\n",
"</tbody>\n",
"</table>\n",
"</div>\n", "
" ], "text/plain": [ " Name SMILES Catalog ID \\\n", "0 NaN CN(C(=O)NC1CCOc2ccccc21)C(c1ccccc1)c1ccccn1 Z447596076 \n", "1 NaN Cn1cc(C(=O)N2CCC(OC3CCOC3)CC2)c(C2CC2)n1 Z2180753156 \n", "2 NaN CC(=O)N(C)C1CCN(C(=O)c2ccccc2-c2ccccc2C(=O)O)CC1 Z2295858832 \n", "3 NaN COCC1(CNc2cnccc2C#N)CCNCC1 Z2030994006 \n", "4 NaN CCCCOc1ccc(-c2nnc3n2CCCC3)cc1OC Z273627850 \n", "\n", " PlateID Well MW (desalted) ClogP HBD TPSA RotBonds \n", "0 1186474-R-001 A02 373.448 2.419 1 54.46 4 \n", "1 1186474-R-001 A03 319.399 -0.570 0 56.59 4 \n", "2 1186474-R-001 A04 380.437 0.559 1 77.92 4 \n", "3 1186474-R-001 A05 260.335 0.902 2 69.97 5 \n", "4 1186474-R-001 A06 301.383 3.227 0 49.17 6 " ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df = pd.read_csv(args.data_path)\n", "df.head()" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "#collapse-hide\n", "\n", "from random import Random\n", "\n", "# Cache of graph featurizations\n", "CACHE_GRAPH = True\n", "SMILES_TO_GRAPH = {}\n", "\n", "def cache_graph():\n", " return CACHE_GRAPH\n", "\n", "def set_cache_graph(cache_graph):\n", " global CACHE_GRAPH\n", " CACHE_GRAPH = cache_graph\n", "\n", "# Cache of RDKit molecules\n", "CACHE_MOL = True\n", "SMILES_TO_MOL: Dict[str, Chem.Mol] = {}\n", "\n", "def cache_mol() -> bool:\n", " r\"\"\"Returns whether RDKit molecules will be cached.\"\"\"\n", " return CACHE_MOL\n", "\n", "def set_cache_mol(cache_mol: bool) -> None:\n", " r\"\"\"Sets whether RDKit molecules will be cached.\"\"\"\n", " global CACHE_MOL\n", " CACHE_MOL = cache_mol\n", " \n", "# Atom feature sizes\n", "MAX_ATOMIC_NUM = 100\n", "ATOM_FEATURES = {\n", " 'atomic_num': list(range(MAX_ATOMIC_NUM)),\n", " 'degree': [0, 1, 2, 3, 4, 5],\n", " 'formal_charge': [-1, -2, 1, 2, 0],\n", " 'chiral_tag': [0, 1, 2, 3],\n", " 'num_Hs': [0, 1, 2, 3, 4],\n", " 'hybridization': [\n", " Chem.rdchem.HybridizationType.SP,\n", " 
Chem.rdchem.HybridizationType.SP2,\n", " Chem.rdchem.HybridizationType.SP3,\n", " Chem.rdchem.HybridizationType.SP3D,\n", " Chem.rdchem.HybridizationType.SP3D2\n", " ],\n", "}\n", "\n", "# Distance feature sizes\n", "PATH_DISTANCE_BINS = list(range(10))\n", "THREE_D_DISTANCE_MAX = 20\n", "THREE_D_DISTANCE_STEP = 1\n", "THREE_D_DISTANCE_BINS = list(range(0, THREE_D_DISTANCE_MAX + 1, THREE_D_DISTANCE_STEP))\n", "\n", "# len(choices) + 1 to include room for uncommon values; + 2 at end for IsAromatic and mass\n", "ATOM_FDIM = sum(len(choices) + 1 for choices in ATOM_FEATURES.values()) + 2\n", "EXTRA_ATOM_FDIM = 0\n", "BOND_FDIM = 14\n", "\n", "\n", "def get_atom_fdim():\n", " \"\"\"Gets the dimensionality of the atom feature vector.\"\"\"\n", " return ATOM_FDIM + EXTRA_ATOM_FDIM\n", "\n", "def get_bond_fdim(atom_messages=False):\n", " \"\"\"Gets the dimensionality of the bond feature vector.\n", " \"\"\"\n", " return BOND_FDIM + (not atom_messages) * get_atom_fdim()\n", "\n", "def onek_encoding_unk(value: int, choices: List[int]):\n", " encoding = [0] * (len(choices) + 1)\n", " index = choices.index(value) if value in choices else -1\n", " encoding[index] = 1\n", "\n", " return encoding\n", "\n", "def atom_features(atom: Chem.rdchem.Atom, functional_groups: List[int] = None):\n", " \"\"\"Builds a feature vector for an atom.\n", " \"\"\"\n", " features = onek_encoding_unk(atom.GetAtomicNum() - 1, ATOM_FEATURES['atomic_num']) + \\\n", " onek_encoding_unk(atom.GetTotalDegree(), ATOM_FEATURES['degree']) + \\\n", " onek_encoding_unk(atom.GetFormalCharge(), ATOM_FEATURES['formal_charge']) + \\\n", " onek_encoding_unk(int(atom.GetChiralTag()), ATOM_FEATURES['chiral_tag']) + \\\n", " onek_encoding_unk(int(atom.GetTotalNumHs()), ATOM_FEATURES['num_Hs']) + \\\n", " onek_encoding_unk(int(atom.GetHybridization()), ATOM_FEATURES['hybridization']) + \\\n", " [1 if atom.GetIsAromatic() else 0] + \\\n", " [atom.GetMass() * 0.01] # scaled to about the same range as other features\n", " 
if functional_groups is not None:\n", " features += functional_groups\n", " return features\n", "\n", "def bond_features(bond: Chem.rdchem.Bond):\n", " \"\"\"Builds a feature vector for a bond.\n", " \"\"\"\n", " if bond is None:\n", " fbond = [1] + [0] * (BOND_FDIM - 1)\n", " else:\n", " bt = bond.GetBondType()\n", " fbond = [\n", " 0, # bond is not None\n", " bt == Chem.rdchem.BondType.SINGLE,\n", " bt == Chem.rdchem.BondType.DOUBLE,\n", " bt == Chem.rdchem.BondType.TRIPLE,\n", " bt == Chem.rdchem.BondType.AROMATIC,\n", " (bond.GetIsConjugated() if bt is not None else 0),\n", " (bond.IsInRing() if bt is not None else 0)\n", " ]\n", " fbond += onek_encoding_unk(int(bond.GetStereo()), list(range(6)))\n", " return fbond\n", " \n", "class MoleculeDatapoint:\n", " def __init__(self,\n", " smiles: str,\n", " targets: List[Optional[float]] = None,\n", " row: OrderedDict = None):\n", " \n", " self.smiles = smiles\n", " self.targets = targets\n", " self.features = []\n", " self.row = row\n", "\n", " @property\n", " def mol(self) -> Chem.Mol:\n", " \"\"\"Gets the corresponding list of RDKit molecules for the corresponding SMILES.\"\"\"\n", " mol = SMILES_TO_MOL.get(self.smiles, Chem.MolFromSmiles(self.smiles))\n", " if cache_mol():\n", " SMILES_TO_MOL[self.smiles] = mol\n", " return mol\n", "\n", " def set_features(self, features: np.ndarray) -> None:\n", " \"\"\"Sets the features of the molecule.\n", " \"\"\"\n", " self.features = features\n", "\n", " def extend_features(self, features: np.ndarray) -> None:\n", " \"\"\"Extends the features of the molecule.\n", " \"\"\"\n", " self.features = np.append(self.features, features) if self.features is not None else features\n", "\n", " def num_tasks(self) -> int:\n", " \"\"\"Returns the number of prediction tasks.\n", " \"\"\"\n", " return len(self.targets)\n", "\n", " def set_targets(self, targets: List[Optional[float]]):\n", " \"\"\"Sets the targets of a molecule.\n", " \"\"\"\n", " self.targets = targets\n", "\n", " def 
reset_features_and_targets(self) -> None:\n", " \"\"\"Resets the features and targets to their raw values.\"\"\"\n", " self.features, self.targets = self.raw_features, self.raw_targets\n", " \n", " \n", "class MoleculeDataset(Dataset):\n", " def __init__(self, data: List[MoleculeDatapoint]):\n", " self._data = data\n", " self._scaler = None\n", " self._batch_graph = None\n", " self._random = Random()\n", "\n", " def smiles(self) -> List[str]:\n", " return [d.smiles for d in self._data]\n", "\n", " def mols(self) -> List[Chem.Mol]:\n", " return [d.mol for d in self._data]\n", "\n", " def targets(self) -> List[List[Optional[float]]]:\n", " return [d.targets for d in self._data]\n", "\n", " def num_tasks(self) -> int:\n", " return self._data[0].num_tasks() if len(self._data) > 0 else None\n", "\n", " def set_targets(self, targets: List[List[Optional[float]]]) -> None:\n", " assert len(self._data) == len(targets)\n", " for i in range(len(self._data)):\n", " self._data[i].set_targets(targets[i])\n", "\n", " def reset_features_and_targets(self) -> None:\n", " for d in self._data:\n", " d.reset_features_and_targets()\n", "\n", " def __len__(self) -> int:\n", " return len(self._data)\n", "\n", " def __getitem__(self, item) -> Union[MoleculeDatapoint, List[MoleculeDatapoint]]:\n", " return self._data[item]\n", " \n", " def batch_graph(self):\n", " if self._batch_graph is None:\n", " self._batch_graph = []\n", "\n", " mol_graphs = []\n", " for d in self._data:\n", " mol_graphs_list = []\n", " if d.smiles in SMILES_TO_GRAPH:\n", " mol_graph = SMILES_TO_GRAPH[d.smiles]\n", " else:\n", " mol_graph = MolGraph(d.mol)\n", " if cache_graph():\n", " SMILES_TO_GRAPH[d.smiles] = mol_graph\n", " mol_graphs.append([mol_graph])\n", "\n", " self._batch_graph = [BatchMolGraph([g[i] for g in mol_graphs]) for i in range(len(mol_graphs[0]))]\n", "\n", " return self._batch_graph\n", " \n", " def features(self) -> List[np.ndarray]:\n", " \"\"\"\n", " Returns the features associated with each 
molecule (if they exist).\n", "\n", " :return: A list of 1D numpy arrays containing the features for each molecule or None if there are no features.\n", " \"\"\"\n", " if len(self._data) == 0 or self._data[0].features is None:\n", " return None\n", "\n", " return [d.features for d in self._data]\n", " \n", "\n", "def index_select_ND(source: torch.Tensor, index: torch.Tensor) -> torch.Tensor:\n", " \"\"\"Selects the message features from source corresponding to the atom or bond indices in index.\n", " \"\"\"\n", " index_size = index.size() # (num_atoms/num_bonds, max_num_bonds)\n", " suffix_dim = source.size()[1:] # (hidden_size,)\n", " final_size = index_size + suffix_dim # (num_atoms/num_bonds, max_num_bonds, hidden_size)\n", "\n", " target = source.index_select(dim=0, index=index.view(-1)) # (num_atoms/num_bonds * max_num_bonds, hidden_size)\n", " target = target.view(final_size) # (num_atoms/num_bonds, max_num_bonds, hidden_size)\n", " return target\n", "\n", "class MolGraph:\n", " def __init__(self, mol, atom_descriptors=None):\n", " # Convert SMILES to RDKit molecule if necessary\n", " if type(mol) == str:\n", " mol = Chem.MolFromSmiles(mol)\n", "\n", " self.n_atoms = 0 # number of atoms\n", " self.n_bonds = 0 # number of bonds\n", " self.f_atoms = [] # mapping from atom index to atom features\n", " self.f_bonds = [] # mapping from bond index to concat(in_atom, bond) features\n", " self.a2b = [] # mapping from atom index to incoming bond indices\n", " self.b2a = [] # mapping from bond index to the index of the atom the bond is coming from\n", " self.b2revb = [] # mapping from bond index to the index of the reverse bond\n", "\n", " # Get atom features\n", " self.f_atoms = [atom_features(atom) for atom in mol.GetAtoms()]\n", " if atom_descriptors is not None:\n", " self.f_atoms = [f_atoms + descs.tolist() for f_atoms, descs in zip(self.f_atoms, atom_descriptors)]\n", "\n", " self.n_atoms = len(self.f_atoms)\n", "\n", " # Initialize atom to bond mapping for each 
atom\n", " for _ in range(self.n_atoms):\n", " self.a2b.append([])\n", "\n", " # Get bond features\n", " for a1 in range(self.n_atoms):\n", " for a2 in range(a1 + 1, self.n_atoms):\n", " bond = mol.GetBondBetweenAtoms(a1, a2)\n", "\n", " if bond is None:\n", " continue\n", "\n", " f_bond = bond_features(bond)\n", " self.f_bonds.append(self.f_atoms[a1] + f_bond)\n", " self.f_bonds.append(self.f_atoms[a2] + f_bond)\n", "\n", " # Update index mappings\n", " b1 = self.n_bonds\n", " b2 = b1 + 1\n", " self.a2b[a2].append(b1) # b1 = a1 --> a2\n", " self.b2a.append(a1)\n", " self.a2b[a1].append(b2) # b2 = a2 --> a1\n", " self.b2a.append(a2)\n", " self.b2revb.append(b2)\n", " self.b2revb.append(b1)\n", " self.n_bonds += 2\n", "\n", "\n", "class BatchMolGraph:\n", " \"\"\"A `BatchMolGraph` represents the graph structure and featurization of a batch of molecules.\n", " \"\"\"\n", "\n", " def __init__(self, mol_graphs: List[MolGraph]):\n", " self.atom_fdim = get_atom_fdim()\n", " self.bond_fdim = get_bond_fdim()\n", "\n", " # Start n_atoms and n_bonds at 1 b/c zero padding\n", " self.n_atoms = 1 # number of atoms (start at 1 b/c need index 0 as padding)\n", " self.n_bonds = 1 # number of bonds (start at 1 b/c need index 0 as padding)\n", " self.a_scope = [] # list of tuples indicating (start_atom_index, num_atoms) for each molecule\n", " self.b_scope = [] # list of tuples indicating (start_bond_index, num_bonds) for each molecule\n", "\n", " # All start with zero padding so that indexing with zero padding returns zeros\n", " f_atoms = [[0] * self.atom_fdim] # atom features\n", " f_bonds = [[0] * self.bond_fdim] # combined atom/bond features\n", " a2b = [[]] # mapping from atom index to incoming bond indices\n", " b2a = [0] # mapping from bond index to the index of the atom the bond is coming from\n", " b2revb = [0] # mapping from bond index to the index of the reverse bond\n", " for mol_graph in mol_graphs:\n", " f_atoms.extend(mol_graph.f_atoms)\n", " 
f_bonds.extend(mol_graph.f_bonds)\n", "\n", " for a in range(mol_graph.n_atoms):\n", " a2b.append([b + self.n_bonds for b in mol_graph.a2b[a]])\n", "\n", " for b in range(mol_graph.n_bonds):\n", " b2a.append(self.n_atoms + mol_graph.b2a[b])\n", " b2revb.append(self.n_bonds + mol_graph.b2revb[b])\n", "\n", " self.a_scope.append((self.n_atoms, mol_graph.n_atoms))\n", " self.b_scope.append((self.n_bonds, mol_graph.n_bonds))\n", " self.n_atoms += mol_graph.n_atoms\n", " self.n_bonds += mol_graph.n_bonds\n", "\n", " self.max_num_bonds = max(1, max(len(in_bonds) for in_bonds in a2b)) # max with 1 to fix a crash in rare case of all single-heavy-atom mols\n", "\n", " self.f_atoms = torch.FloatTensor(f_atoms)\n", " self.f_bonds = torch.FloatTensor(f_bonds)\n", " self.a2b = torch.LongTensor([a2b[a] + [0] * (self.max_num_bonds - len(a2b[a])) for a in range(self.n_atoms)])\n", " self.b2a = torch.LongTensor(b2a)\n", " self.b2revb = torch.LongTensor(b2revb)\n", " self.b2b = None # try to avoid computing b2b b/c O(n_atoms^3)\n", " self.a2a = None # only needed if using atom messages\n", "\n", " def get_components(self, atom_messages: bool = False) -> Tuple[torch.FloatTensor, torch.FloatTensor,\n", " torch.LongTensor, torch.LongTensor, torch.LongTensor,\n", " List[Tuple[int, int]], List[Tuple[int, int]]]:\n", " return self.f_atoms, self.f_bonds, self.a2b, self.b2a, self.b2revb, self.a_scope, self.b_scope\n", "\n", " def get_b2b(self) -> torch.LongTensor:\n", " \"\"\"Computes (if necessary) and returns a mapping from each bond index to all the incoming bond indices.\n", " \"\"\"\n", " if self.b2b is None:\n", " b2b = self.a2b[self.b2a] # num_bonds x max_num_bonds\n", " # b2b includes reverse edge for each bond so need to mask out\n", " revmask = (b2b != self.b2revb.unsqueeze(1).repeat(1, b2b.size(1))).long() # num_bonds x max_num_bonds\n", " self.b2b = b2b * revmask\n", "\n", " return self.b2b\n", "\n", " def get_a2a(self) -> torch.LongTensor:\n", " \"\"\"Computes (if necessary) 
and returns a mapping from each atom index to all neighboring atom indices.\n", "        \"\"\"\n", "        if self.a2a is None:\n", "            # b = a1 --> a2\n", "            # a2b maps a2 to all incoming bonds b\n", "            # b2a maps each bond b to the atom it comes from a1\n", "            # thus b2a[a2b] maps atom a2 to neighboring atoms a1\n", "            self.a2a = self.b2a[self.a2b]  # num_atoms x max_num_bonds\n", "\n", "        return self.a2a" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "\n", "# prepare data set\n", "data = MoleculeDataset([\n", "    MoleculeDatapoint(\n", "        smiles=row[args.smiles_column],\n", "        targets=[row[args.target_column]]\n", "    ) for i, row in df.iterrows()\n", "])\n", "\n", "# split data into train, validation 
and test set\n", "random = Random()\n", "sizes = [0.8, 0.1, 0.1]\n", "\n", "indices = list(range(len(data)))\n", "random.shuffle(indices)\n", "\n", "train_size = int(sizes[0] * len(data))\n", "train_val_size = int((sizes[0] + sizes[1]) * len(data))\n", "\n", "train = [data[i] for i in indices[:train_size]]\n", "val = [data[i] for i in indices[train_size:train_val_size]]\n", "test = [data[i] for i in indices[train_val_size:]]\n", "\n", "train_data = MoleculeDataset(train)\n", "val_data = MoleculeDataset(val)\n", "test_data = MoleculeDataset(test)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# MPNN Model\n", "\n", "Let's create an MPNN model. The model is composed of an encoder and a feed-forward network (FFN). The encoder is the same as the one discussed in the recap, and the FFN is a plain fully-connected network." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "#collapse-hide\n", "\n", "# Atom feature sizes\n", "MAX_ATOMIC_NUM = 100\n", "ATOM_FEATURES = {\n", "    'atomic_num': list(range(MAX_ATOMIC_NUM)),\n", "    'degree': [0, 1, 2, 3, 4, 5],\n", "    'formal_charge': [-1, -2, 1, 2, 0],\n", "    'chiral_tag': [0, 1, 2, 3],\n", "    'num_Hs': [0, 1, 2, 3, 4],\n", "    'hybridization': [\n", "        Chem.rdchem.HybridizationType.SP,\n", "        Chem.rdchem.HybridizationType.SP2,\n", "        Chem.rdchem.HybridizationType.SP3,\n", "        Chem.rdchem.HybridizationType.SP3D,\n", "        Chem.rdchem.HybridizationType.SP3D2\n", "    ],\n", "}\n", "\n", "# Distance feature sizes\n", "PATH_DISTANCE_BINS = list(range(10))\n", "THREE_D_DISTANCE_MAX = 20\n", "THREE_D_DISTANCE_STEP = 1\n", "THREE_D_DISTANCE_BINS = list(range(0, THREE_D_DISTANCE_MAX + 1, THREE_D_DISTANCE_STEP))\n", "\n", "# len(choices) + 1 to include room for uncommon values; + 2 at end for IsAromatic and mass\n", "ATOM_FDIM = sum(len(choices) + 1 for choices in ATOM_FEATURES.values()) + 2\n", "EXTRA_ATOM_FDIM = 0\n", "BOND_FDIM = 14\n", "\n", "\n", "def get_atom_fdim() -> int:\n", "    \"\"\"Gets the 
dimensionality of the atom feature vector.\"\"\"\n", " return ATOM_FDIM + EXTRA_ATOM_FDIM\n", "\n", "def get_bond_fdim() -> int:\n", " \"\"\"Gets the dimensionality of the bond feature vector.\n", " \"\"\"\n", " return BOND_FDIM + get_atom_fdim()\n", "\n", "\n", "def onek_encoding_unk(value: int, choices: List[int]) -> List[int]:\n", " encoding = [0] * (len(choices) + 1)\n", " index = choices.index(value) if value in choices else -1\n", " encoding[index] = 1\n", "\n", " return encoding\n", "\n", "def atom_features(atom: Chem.rdchem.Atom, functional_groups: List[int] = None) -> List[Union[bool, int, float]]:\n", " \"\"\"Builds a feature vector for an atom.\n", " \"\"\"\n", " features = onek_encoding_unk(atom.GetAtomicNum() - 1, ATOM_FEATURES['atomic_num']) + \\\n", " onek_encoding_unk(atom.GetTotalDegree(), ATOM_FEATURES['degree']) + \\\n", " onek_encoding_unk(atom.GetFormalCharge(), ATOM_FEATURES['formal_charge']) + \\\n", " onek_encoding_unk(int(atom.GetChiralTag()), ATOM_FEATURES['chiral_tag']) + \\\n", " onek_encoding_unk(int(atom.GetTotalNumHs()), ATOM_FEATURES['num_Hs']) + \\\n", " onek_encoding_unk(int(atom.GetHybridization()), ATOM_FEATURES['hybridization']) + \\\n", " [1 if atom.GetIsAromatic() else 0] + \\\n", " [atom.GetMass() * 0.01] # scaled to about the same range as other features\n", " if functional_groups is not None:\n", " features += functional_groups\n", " return features\n", "\n", "\n", "def initialize_weights(model: nn.Module) -> None:\n", " \"\"\"Initializes the weights of a model in place.\n", " \"\"\"\n", " for param in model.parameters():\n", " if param.dim() == 1:\n", " nn.init.constant_(param, 0)\n", " else:\n", " nn.init.xavier_normal_(param)\n", "\n", "class MPNEncoder(nn.Module):\n", " def __init__(self, args, atom_fdim, bond_fdim):\n", " super(MPNEncoder, self).__init__()\n", " self.atom_fdim = atom_fdim\n", " self.bond_fdim = bond_fdim\n", " self.hidden_size = args.hidden_size\n", " self.bias = args.bias\n", " self.depth = 
args.depth\n", " self.dropout = args.dropout\n", " self.layers_per_message = 1\n", " self.undirected = False\n", " self.atom_messages = False\n", " self.device = args.device\n", " self.aggregation = args.aggregation\n", " self.aggregation_norm = args.aggregation_norm\n", "\n", " self.dropout_layer = nn.Dropout(p=self.dropout)\n", " self.act_func = nn.ReLU()\n", " self.cached_zero_vector = nn.Parameter(torch.zeros(self.hidden_size), requires_grad=False)\n", "\n", " # Input\n", " input_dim = self.bond_fdim\n", " self.W_i = nn.Linear(input_dim, self.hidden_size, bias=self.bias)\n", " w_h_input_size = self.hidden_size\n", "\n", " # Shared weight matrix across depths (default)\n", " self.W_h = nn.Linear(w_h_input_size, self.hidden_size, bias=self.bias)\n", " self.W_o = nn.Linear(self.atom_fdim + self.hidden_size, self.hidden_size)\n", "\n", " def forward(self, mol_graph):\n", " \"\"\"Encodes a batch of molecular graphs.\n", " \"\"\"\n", " f_atoms, f_bonds, a2b, b2a, b2revb, a_scope, b_scope = mol_graph.get_components()\n", " f_atoms, f_bonds, a2b, b2a, b2revb = f_atoms.to(self.device), f_bonds.to(self.device), a2b.to(self.device), b2a.to(self.device), b2revb.to(self.device)\n", "\n", " input = self.W_i(f_bonds) # num_bonds x hidden_size\n", " message = self.act_func(input) # num_bonds x hidden_size\n", "\n", " # Message passing\n", " for depth in range(self.depth - 1):\n", " # m(a1 -> a2) = [sum_{a0 \\in nei(a1)} m(a0 -> a1)] - m(a2 -> a1)\n", " # message a_message = sum(nei_a_message) rev_message\n", " nei_a_message = index_select_ND(message, a2b) # num_atoms x max_num_bonds x hidden\n", " a_message = nei_a_message.sum(dim=1) # num_atoms x hidden\n", " rev_message = message[b2revb] # num_bonds x hidden\n", " message = a_message[b2a] - rev_message # num_bonds x hidden\n", "\n", " message = self.W_h(message)\n", " message = self.act_func(input + message) # num_bonds x hidden_size\n", " message = self.dropout_layer(message) # num_bonds x hidden\n", "\n", " a2x = a2b\n", " 
nei_a_message = index_select_ND(message, a2x) # num_atoms x max_num_bonds x hidden\n", " a_message = nei_a_message.sum(dim=1) # num_atoms x hidden\n", " a_input = torch.cat([f_atoms, a_message], dim=1) # num_atoms x (atom_fdim + hidden)\n", " atom_hiddens = self.act_func(self.W_o(a_input)) # num_atoms x hidden\n", " atom_hiddens = self.dropout_layer(atom_hiddens) # num_atoms x hidden\n", "\n", " # Readout\n", " mol_vecs = []\n", " for i, (a_start, a_size) in enumerate(a_scope):\n", " if a_size == 0:\n", " mol_vecs.append(self.cached_zero_vector)\n", " else:\n", " cur_hiddens = atom_hiddens.narrow(0, a_start, a_size)\n", " mol_vec = cur_hiddens # (num_atoms, hidden_size)\n", " if self.aggregation == 'mean':\n", " mol_vec = mol_vec.sum(dim=0) / a_size\n", " elif self.aggregation == 'sum':\n", " mol_vec = mol_vec.sum(dim=0)\n", " elif self.aggregation == 'norm':\n", " mol_vec = mol_vec.sum(dim=0) / self.aggregation_norm\n", " mol_vecs.append(mol_vec)\n", "\n", " mol_vecs = torch.stack(mol_vecs, dim=0) # (num_molecules, hidden_size)\n", "\n", " return mol_vecs # num_molecules x hidden\n", " \n", "\n", "class MPN(nn.Module):\n", " def __init__(self, args, atom_fdim=None, bond_fdim=None):\n", " super(MPN, self).__init__()\n", " self.atom_fdim = atom_fdim or get_atom_fdim()\n", " self.bond_fdim = bond_fdim or get_bond_fdim()\n", " self.device = args.device\n", " self.encoder = MPNEncoder(args, self.atom_fdim, self.bond_fdim)\n", "\n", " def forward(self, batch):\n", " \"\"\"Encodes a batch of molecules.\n", " \"\"\"\n", " if type(batch[0]) != BatchMolGraph:\n", " batch = [mol2graph(b) for b in batch]\n", "\n", " encodings = [self.encoder(batch[0])]\n", " output = reduce(lambda x, y: torch.cat((x, y), dim=1), encodings)\n", " return output\n", " \n", "\n", "class MoleculeModel(nn.Module):\n", " def __init__(self, args, featurizer=False):\n", " super(MoleculeModel, self).__init__()\n", "\n", " self.classification = args.dataset_type == 'classification'\n", " self.featurizer 
= featurizer\n", "\n", " self.output_size = args.num_tasks\n", "\n", " if self.classification:\n", " self.sigmoid = nn.Sigmoid()\n", " \n", " self.create_encoder(args)\n", " self.create_ffn(args)\n", "\n", " initialize_weights(self)\n", "\n", " def create_encoder(self, args):\n", " self.encoder = MPN(args)\n", "\n", " def create_ffn(self, args):\n", " first_linear_dim = args.hidden_size\n", " dropout = nn.Dropout(args.dropout)\n", " activation = nn.ReLU()\n", "\n", " # Create FFN layers\n", " if args.ffn_num_layers == 1:\n", " ffn = [\n", " dropout,\n", " nn.Linear(first_linear_dim, self.output_size)\n", " ]\n", " else:\n", " ffn = [\n", " dropout,\n", " nn.Linear(first_linear_dim, args.ffn_hidden_size)\n", " ]\n", " for _ in range(args.ffn_num_layers - 2):\n", " ffn.extend([\n", " activation,\n", " dropout,\n", " nn.Linear(args.ffn_hidden_size, args.ffn_hidden_size),\n", " ])\n", " ffn.extend([\n", " activation,\n", " dropout,\n", " nn.Linear(args.ffn_hidden_size, self.output_size),\n", " ])\n", "\n", " # Create FFN model\n", " self.ffn = nn.Sequential(*ffn)\n", "\n", " def featurize(self, batch, features_batch=None, atom_descriptors_batch=None):\n", " \"\"\"Computes feature vectors of the input by running the model except for the last layer.\n", " \"\"\"\n", " return self.ffn[:-1](self.encoder(batch, features_batch, atom_descriptors_batch))\n", "\n", " def forward(self, batch):\n", " output = self.ffn(self.encoder(batch))\n", "\n", " # Don't apply sigmoid during training b/c using BCEWithLogitsLoss\n", " if self.classification and not self.training:\n", " output = self.sigmoid(output)\n", " \n", " return output" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "model = MoleculeModel(args)\n", "model = model.to(args.device)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As it is shown below, the model is comprised of an encoder and FFN. 
The encoder has three learned matrices (`W_i`, `W_h`, and `W_o`) and the FFN has two fully connected layers." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "MoleculeModel(\n", "  (encoder): MPN(\n", "    (encoder): MPNEncoder(\n", "      (dropout_layer): Dropout(p=0.0, inplace=False)\n", "      (act_func): ReLU()\n", "      (W_i): Linear(in_features=147, out_features=300, bias=False)\n", "      (W_h): Linear(in_features=300, out_features=300, bias=False)\n", "      (W_o): Linear(in_features=433, out_features=300, bias=True)\n", "    )\n", "  )\n", "  (ffn): Sequential(\n", "    (0): Dropout(p=0.0, inplace=False)\n", "    (1): Linear(in_features=300, out_features=300, bias=True)\n", "    (2): ReLU()\n", "    (3): Dropout(p=0.0, inplace=False)\n", "    (4): Linear(in_features=300, out_features=1, bias=True)\n", "  )\n", ")" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Train the MPNN" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "#collapse-hide\n", "from typing import List, Union\n", "\n", "import numpy as np\n", "from torch.optim.lr_scheduler import _LRScheduler\n", "from torch.optim import Adam, Optimizer\n", "\n", "class NoamLR(_LRScheduler):\n", "    \"\"\"\n", "    Noam learning rate scheduler with piecewise linear increase and exponential decay.\n", "\n", "    The learning rate increases linearly from init_lr to max_lr over the course of\n", "    the first warmup_steps (where :code:`warmup_steps = warmup_epochs * steps_per_epoch`).\n", "    Then the learning rate decreases exponentially from :code:`max_lr` to :code:`final_lr` over the\n", "    course of the remaining :code:`total_steps - warmup_steps` (where :code:`total_steps =\n", "    total_epochs * steps_per_epoch`). 
This is roughly based on the learning rate\n", "    schedule from `Attention is All You Need`, section 5.3.\n", "    \"\"\"\n", "    def __init__(self,\n", "                 optimizer: Optimizer,\n", "                 warmup_epochs: List[Union[float, int]],\n", "                 total_epochs: List[int],\n", "                 steps_per_epoch: int,\n", "                 init_lr: List[float],\n", "                 max_lr: List[float],\n", "                 final_lr: List[float]):\n", "\n", "        assert len(optimizer.param_groups) == len(warmup_epochs) == len(total_epochs) == len(init_lr) == \\\n", "               len(max_lr) == len(final_lr)\n", "\n", "        self.num_lrs = len(optimizer.param_groups)\n", "\n", "        self.optimizer = optimizer\n", "        self.warmup_epochs = np.array(warmup_epochs)\n", "        self.total_epochs = np.array(total_epochs)\n", "        self.steps_per_epoch = steps_per_epoch\n", "        self.init_lr = np.array(init_lr)\n", "        self.max_lr = np.array(max_lr)\n", "        self.final_lr = np.array(final_lr)\n", "\n", "        self.current_step = 0\n", "        self.lr = init_lr\n", "        self.warmup_steps = (self.warmup_epochs * self.steps_per_epoch).astype(int)\n", "        self.total_steps = self.total_epochs * self.steps_per_epoch\n", "        self.linear_increment = (self.max_lr - self.init_lr) / self.warmup_steps\n", "\n", "        self.exponential_gamma = (self.final_lr / self.max_lr) ** (1 / (self.total_steps - self.warmup_steps))\n", "\n", "        super(NoamLR, self).__init__(optimizer)\n", "\n", "    def get_lr(self) -> List[float]:\n", "        return list(self.lr)\n", "\n", "    def step(self, current_step: int = None):\n", "        if current_step is not None:\n", "            self.current_step = current_step\n", "        else:\n", "            self.current_step += 1\n", "\n", "        for i in range(self.num_lrs):\n", "            if self.current_step <= self.warmup_steps[i]:\n", "                self.lr[i] = self.init_lr[i] + self.current_step * self.linear_increment[i]\n", "            elif self.current_step <= self.total_steps[i]:\n", "                self.lr[i] = self.max_lr[i] * (self.exponential_gamma[i] ** (self.current_step - self.warmup_steps[i]))\n", "            else: # theoretically this case should never be reached since training should stop at 
total_steps\n", "                self.lr[i] = self.final_lr[i]\n", "\n", "            self.optimizer.param_groups[i]['lr'] = self.lr[i]" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "#collapse-hide\n", "import threading\n", "\n", "def construct_molecule_batch(data):\n", "    data = MoleculeDataset(data)\n", "    data.batch_graph() # Forces computation and caching of the BatchMolGraph for the molecules\n", "    return data\n", "\n", "class MoleculeSampler(Sampler):\n", "    def __init__(self, dataset, shuffle=False, seed=0):\n", "        super(Sampler, self).__init__()\n", "\n", "        self.dataset = dataset\n", "        self.shuffle = shuffle\n", "        self._random = Random(seed)\n", "        self.positive_indices = self.negative_indices = None\n", "        self.length = len(self.dataset)\n", "\n", "    def __iter__(self):\n", "        indices = list(range(len(self.dataset)))\n", "        if self.shuffle:\n", "            self._random.shuffle(indices)\n", "        return iter(indices)\n", "\n", "    def __len__(self):\n", "        return self.length\n", "\n", "\n", "class MoleculeDataLoader(DataLoader):\n", "    def __init__(self,\n", "                 dataset: MoleculeDataset,\n", "                 batch_size: int = 50,\n", "                 num_workers: int = 8,\n", "                 shuffle: bool = False,\n", "                 seed: int = 0):\n", "\n", "        self._dataset = dataset\n", "        self._batch_size = batch_size\n", "        self._num_workers = num_workers\n", "        self._shuffle = shuffle\n", "        self._seed = seed\n", "        self._context = None\n", "        self._class_balance = False\n", "        self._timeout = 0\n", "        is_main_thread = threading.current_thread() is threading.main_thread()\n", "\n", "        if not is_main_thread and self._num_workers > 0:\n", "            self._context = 'forkserver' # Prevents the DataLoader from hanging\n", "            self._timeout = 3600 # Extra safeguard so that the DataLoader won't hang\n", "\n", "        self._sampler = MoleculeSampler(\n", "            dataset=self._dataset,\n", "            shuffle=self._shuffle,\n", "            seed=self._seed\n", "        )\n", "\n", "        super(MoleculeDataLoader, self).__init__(\n", "            dataset=self._dataset,\n", "            batch_size=self._batch_size,\n", 
" sampler=self._sampler,\n", " num_workers=self._num_workers,\n", " collate_fn=construct_molecule_batch,\n", " multiprocessing_context=self._context,\n", " timeout=self._timeout\n", " )\n", "\n", " @property\n", " def targets(self) -> List[List[Optional[float]]]:\n", " if self._class_balance or self._shuffle:\n", " raise ValueError('Cannot safely extract targets when class balance or shuffle are enabled.')\n", "\n", " return [self._dataset[index].targets for index in self._sampler]\n", "\n", " @property\n", " def iter_size(self) -> int:\n", " return len(self._sampler)\n", "\n", " def __iter__(self) -> Iterator[MoleculeDataset]:\n", " return super(MoleculeDataLoader, self).__iter__()" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "\n", "# Create data loaders\n", "train_data_loader = MoleculeDataLoader(\n", " dataset=train_data,\n", " batch_size=args.batch_size,\n", " num_workers=8,\n", " shuffle=True,\n", " seed=args.seed\n", ")\n", "val_data_loader = MoleculeDataLoader(\n", " dataset=val_data,\n", " batch_size=args.batch_size,\n", " num_workers=8\n", ")\n", "test_data_loader = MoleculeDataLoader(\n", " dataset=test_data,\n", " batch_size=args.batch_size,\n", " num_workers=8\n", ")" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "\n", "# optimizer\n", "params = [{'params': model.parameters(), 'lr': args.init_lr, 'weight_decay': 0}]\n", "optimizer = Adam(params)\n", "\n", "# scheduler\n", "scheduler = NoamLR(\n", " optimizer=optimizer,\n", " warmup_epochs=[args.warmup_epochs],\n", " total_epochs=[args.epochs] * args.num_lrs,\n", " steps_per_epoch=len(train_data) // args.batch_size,\n", " init_lr=[args.init_lr],\n", " max_lr=[args.max_lr],\n", " final_lr=[args.final_lr]\n", ")\n", "\n", "# loss function\n", "loss_func = nn.MSELoss(reduction='none')" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { 
"application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/164 [00:00" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "fit = plt.figure(figsize=(4,4))\n", "plt.scatter(valid_targets[i], valid_preds[i])\n", "plt.xlabel('Target ClogP')\n", "plt.ylabel('Predicted ClogP')\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Almost no correlation as of first epoch. Let's train a few more epochs and see if the prediciton improves." ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "359ffb3caf3c4c6384cb067ac569d05b", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/30 [00:00" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "fit = plt.figure(figsize=(4,4))\n", "plt.scatter(valid_targets[i], valid_preds[i])\n", "plt.xlabel('Target ClogP')\n", "plt.ylabel('Predicted ClogP')\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Conclusion\n", "\n", "In this post, I have trained a graph neural network that can predict ClogP property. Within 30 epochs, it was able to predict the property pretty accurately in less than 0.1 MSE. Given only very simple features were used in atom and bond features, it was able to \"learn\" to predict the property fairly quickly.\n", "\n", "Now that we have a trained model, a few things I'd like to try:\n", "- compare this model with other traditional model and compare performance\n", "- try different parameters, such as `depth`\n", "- try alternative featurization, i.e., add if bond is rotatable in the bond_features and so on.\n", "- add long-range connection;current network is limited to chemical bonds, but longer range interaction may also be important." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.4" } }, "nbformat": 4, "nbformat_minor": 4 }