{ "cells": [
 { "cell_type": "markdown", "metadata": {}, "source": [
  "# Conformal prediction of solubility classes\n",
  "\n",
  "Morgan fingerprints of the molecules in the RDKit solubility SDF files are classified with a random forest wrapped in an inductive conformal predictor (`nonconformist`)."
 ] },
 { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [
  "%matplotlib inline\n",
  "import os\n",
  "\n",
  "import numpy as np\n",
  "import pandas as pd\n",
  "from rdkit import Chem\n",
  "from rdkit.Chem import AllChem\n",
  "from rdkit.Chem import DataStructs\n",
  "from rdkit.Chem import Draw\n",
  "from rdkit.Chem import RDConfig\n",
  "from rdkit.Chem.Draw import IPythonConsole\n",
  "from sklearn.ensemble import RandomForestClassifier\n",
  "from nonconformist.evaluation import ClassIcpCvHelper\n",
  "from nonconformist.evaluation import class_avg_c, class_n_correct\n",
  "from nonconformist.evaluation import cross_val_score\n",
  "from nonconformist.icp import IcpClassifier\n",
  "from nonconformist.nc import ClassifierAdapter\n",
  "from nonconformist.nc import ClassifierNc"
 ] },
 { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [
  "# Solubility train/test SD files, resolved relative to RDConfig.RDDocsDir.\n",
  "train = os.path.join(RDConfig.RDDocsDir, 'Book/data/solubility.train.sdf')\n",
  "test = os.path.join(RDConfig.RDDocsDir, 'Book/data/solubility.test.sdf')"
 ] },
 { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [
  "# SDMolSupplier yields None for records it cannot parse; drop those so the\n",
  "# GetProp / fingerprint calls further down never receive None.\n",
  "trainmol = [m for m in Chem.SDMolSupplier(train) if m is not None]\n",
  "testmol = [m for m in Chem.SDMolSupplier(test) if m is not None]"
 ] },
 { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'(A) low', '(C) high', '(B) medium'}\n" ] } ], "source": [
  "# Distinct class labels present in the training set.\n",
  "labels = {m.GetProp('SOL_classification') for m in trainmol}\n",
  "print(labels)"
 ] },
 { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [
  "# Map each SOL_classification label to an integer class id.\n",
  "label2cls = {'(A) low':0, '(B) medium':1, '(C) high':2}"
 ] },
 { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [
  "def fp2arr(fp):\n",
  "    \"\"\"Convert an RDKit fingerprint into a NumPy array.\n",
  "\n",
  "    ConvertToNumpyArray fills (and resizes) the destination array, so the\n",
  "    initial length of ``arr`` is only a placeholder.\n",
  "    \"\"\"\n",
  "    arr = np.zeros((1,))\n",
  "    DataStructs.ConvertToNumpyArray(fp, arr)\n",
  "    return arr"
 ] },
 { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [
  "# Morgan fingerprint settings shared by the training and test sets.\n",
  "FP_RADIUS = 2\n",
  "FP_NBITS = 1024\n",
  "\n",
  "\n",
  "def mols2fps(mols, radius=FP_RADIUS, n_bits=FP_NBITS):\n",
  "    \"\"\"Return one Morgan bit-vector fingerprint row per molecule as a NumPy array.\"\"\"\n",
  "    bitvects = (AllChem.GetMorganFingerprintAsBitVect(m, radius, n_bits) for m in mols)\n",
  "    return np.array([fp2arr(bv) for bv in bitvects])\n",
  "\n",
  "\n",
  "trainfps = mols2fps(trainmol)\n",
  "testfps = mols2fps(testmol)"
 ] },
 { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(1025, 1024) (1025,) (257, 1024) (257,)\n" ] } ], "source": [
  "# Integer class id for every molecule, in the same order as the fingerprints.\n",
  "train_cls = np.array([label2cls[m.GetProp('SOL_classification')] for m in trainmol])\n",
  "test_cls = np.array([label2cls[m.GetProp('SOL_classification')] for m in testmol])\n",
  "assert trainfps.shape[0] == train_cls.size, 'train features/labels are misaligned'\n",
  "assert testfps.shape[0] == test_cls.size, 'test features/labels are misaligned'\n",
  "print(trainfps.shape, train_cls.shape, testfps.shape, test_cls.shape)"
 ] },
 { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [
  "# The training data is divided into a proper training set and a calibration set.\n",
  "SPLIT_SEED = 794  # fixed seed so the split is identical on every Restart & Run All\n",
  "N_TRAIN = 700  # the first N_TRAIN shuffled samples fit the model; the rest calibrate it\n",
  "assert N_TRAIN < train_cls.size, 'the calibration set would be empty'\n",
  "ids = np.random.RandomState(SPLIT_SEED).permutation(train_cls.size)\n",
  "train_ids, calib_ids = ids[:N_TRAIN], ids[N_TRAIN:]\n",
  "trainX, calibX = trainfps[train_ids, :], trainfps[calib_ids, :]\n",
  "trainY, calibY = train_cls[train_ids], train_cls[calib_ids]"
 ] },
 { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [
  "testX = testfps\n",
  "testY = test_cls"
 ] },
 { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [
  "# Random forest -> nonconformity scorer -> inductive conformal classifier.\n",
  "rf = RandomForestClassifier(n_estimators=500, random_state=794)\n",
  "nc = ClassifierNc(ClassifierAdapter(rf))\n",
  "icp = IcpClassifier(nc)"
 ] },
 { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [
  "# Fit on the proper training set, then calibrate on the held-out calibration set.\n",
  "icp.fit(trainX, trainY)\n",
  "icp.calibrate(calibX, calibY)"
 ] },
 { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [
  "# Without `significance`, predict() returns per-class p-values (see the nonconformist docs).\n",
  "pred = icp.predict(testX)\n",
  "# With `significance`, it returns class-membership flags; cast to int so they print as 0/1.\n",
  "pred95 = icp.predict(testX, significance=0.05).astype(np.int32)\n",
  "pred80 = icp.predict(testX, significance=0.2).astype(np.int32)"
 ] },
 { "cell_type":
"code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "from nonconformist.evaluation import class_avg_c, class_n_correct" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "scrolled": false }, "outputs": [ { "data": { "text/plain": [ "244" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "class_n_correct(pred, testY, significance=0.05)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "232" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "class_n_correct(pred, testY, significance=0.1)" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0 2 False [0 1 1] : [1 1 1]\n", "0 1 False [0 1 1] : [1 1 1]\n", "1 1 True [0 1 0] : [1 1 0]\n", "1 1 True [0 1 0] : [1 1 0]\n", "1 1 True [0 1 0] : [0 1 0]\n", "1 1 True [0 1 0] : [1 1 1]\n", "0 0 True [1 0 0] : [1 0 0]\n", "1 0 False [1 1 0] : [1 1 1]\n", "0 1 False [0 1 0] : [1 1 0]\n", "0 1 False [1 1 0] : [1 1 1]\n", "0 0 True [1 0 0] : [1 1 0]\n", "0 0 True [1 0 0] : [1 0 0]\n", "0 0 True [1 1 0] : [1 1 1]\n", "0 0 True [1 0 0] : [1 1 0]\n", "0 0 True [1 0 0] : [1 0 0]\n", "0 0 True [1 0 0] : [1 0 0]\n", "0 0 True [1 0 0] : [1 1 0]\n", "0 0 True [1 0 0] : [1 0 0]\n", "0 0 True [1 0 0] : [1 0 0]\n", "1 0 False [1 1 0] : [1 1 1]\n", "0 0 True [1 0 0] : [1 0 0]\n", "0 0 True [1 0 0] : [1 0 0]\n", "1 1 True [0 1 0] : [0 1 0]\n", "1 1 True [0 1 0] : [0 1 0]\n", "1 1 True [0 1 0] : [0 1 0]\n", "1 1 True [0 1 0] : [0 1 0]\n", "1 1 True [0 1 0] : [0 1 0]\n", "1 1 True [0 1 0] : [1 1 1]\n", "1 1 True [0 1 0] : [0 1 1]\n", "1 0 False [1 1 0] : [1 1 0]\n", "1 1 True [0 1 0] : [0 1 0]\n", "1 1 True [0 1 0] : [1 1 1]\n", "0 0 True [1 0 0] : [1 0 0]\n", "0 0 True [1 0 0] : [1 0 0]\n", "1 1 True [0 1 0] : [0 1 0]\n", "0 0 True [1 0 0] : [1 0 0]\n", "0 1 False 
[0 1 0] : [0 1 0]\n", "1 0 False [1 0 0] : [1 1 0]\n", "0 0 True [1 0 0] : [1 0 0]\n", "0 0 True [1 0 0] : [1 0 0]\n", "0 0 True [1 0 0] : [1 0 0]\n", "0 0 True [1 0 0] : [1 0 0]\n", "1 1 True [0 1 0] : [1 1 0]\n", "0 0 True [1 0 0] : [1 0 0]\n", "0 0 True [1 0 0] : [1 0 0]\n", "0 0 True [1 0 0] : [1 0 0]\n", "0 0 True [1 0 0] : [1 0 0]\n", "0 0 True [1 0 0] : [1 0 0]\n", "0 0 True [1 0 0] : [1 0 0]\n", "0 0 True [1 0 0] : [1 0 0]\n", "0 0 True [1 0 0] : [1 0 0]\n", "0 0 True [1 0 0] : [1 0 0]\n", "0 0 True [1 0 0] : [1 0 0]\n", "0 0 True [1 0 0] : [1 0 0]\n", "0 0 True [1 0 0] : [1 0 0]\n", "0 0 True [1 0 0] : [1 0 0]\n", "2 2 True [0 0 1] : [0 1 1]\n", "2 2 True [0 0 1] : [0 1 1]\n", "2 2 True [0 0 1] : [0 0 1]\n", "1 2 False [0 0 1] : [0 1 1]\n", "2 2 True [0 0 1] : [0 0 1]\n", "1 1 True [1 1 1] : [1 1 1]\n", "1 2 False [0 0 1] : [0 1 1]\n", "1 2 False [0 0 1] : [0 0 1]\n", "2 2 True [0 0 1] : [0 0 1]\n", "1 1 True [0 1 0] : [1 1 1]\n", "0 0 True [1 0 0] : [1 1 0]\n", "1 2 False [0 1 1] : [1 1 1]\n", "2 1 False [0 1 0] : [0 1 1]\n", "2 1 False [0 1 0] : [1 1 1]\n", "1 1 True [1 1 0] : [1 1 1]\n", "1 0 False [1 0 0] : [1 1 0]\n", "2 2 True [0 1 1] : [1 1 1]\n", "0 1 False [0 1 0] : [0 1 0]\n", "2 1 False [0 1 1] : [1 1 1]\n", "2 0 False [1 1 1] : [1 1 1]\n", "1 0 False [1 1 0] : [1 1 1]\n", "2 1 False [0 1 0] : [1 1 1]\n", "0 1 False [0 1 0] : [0 1 0]\n", "2 2 True [0 0 1] : [0 0 1]\n", "0 0 True [1 0 0] : [1 0 0]\n", "1 1 True [0 1 0] : [0 1 0]\n", "1 1 True [0 1 0] : [1 1 1]\n", "1 1 True [1 1 0] : [1 1 1]\n", "0 1 False [0 1 0] : [0 1 0]\n", "0 0 True [1 0 0] : [1 0 0]\n", "2 2 True [0 0 1] : [0 1 1]\n", "2 2 True [0 0 1] : [0 0 1]\n", "1 2 False [0 0 1] : [0 0 1]\n", "1 2 False [0 0 1] : [0 1 1]\n", "1 0 False [1 0 0] : [1 1 1]\n", "0 0 True [1 0 0] : [1 0 0]\n", "1 1 True [0 1 0] : [0 1 0]\n", "0 0 True [1 0 0] : [1 0 0]\n", "2 2 True [0 0 1] : [0 0 1]\n", "2 2 True [0 0 1] : [1 1 1]\n", "1 2 False [0 1 1] : [1 1 1]\n", "1 0 False [1 1 0] : [1 1 1]\n", "1 0 
False [1 1 0] : [1 1 1]\n", "2 2 True [0 0 1] : [1 1 1]\n", "1 0 False [1 0 0] : [1 0 0]\n", "0 0 True [1 0 0] : [1 0 0]\n", "0 0 True [1 0 0] : [1 0 0]\n", "0 0 True [1 0 0] : [1 0 0]\n", "1 1 True [0 1 0] : [0 1 0]\n", "1 1 True [0 1 0] : [0 1 1]\n", "1 1 True [1 1 0] : [1 1 1]\n", "1 0 False [1 0 0] : [1 1 0]\n", "0 1 False [1 1 0] : [1 1 1]\n", "0 0 True [1 0 0] : [1 0 0]\n", "1 1 True [0 1 0] : [1 1 0]\n", "1 1 True [0 1 0] : [1 1 0]\n", "1 1 True [0 1 0] : [1 1 1]\n", "1 0 False [1 1 0] : [1 1 1]\n", "1 0 False [1 0 0] : [1 1 1]\n", "2 1 False [0 1 0] : [1 1 1]\n", "1 1 True [1 1 1] : [1 1 1]\n", "2 1 False [0 1 1] : [1 1 1]\n", "1 1 True [1 1 0] : [1 1 1]\n", "2 1 False [0 1 0] : [1 1 1]\n", "0 1 False [0 1 0] : [1 1 0]\n", "1 1 True [0 1 0] : [1 1 1]\n", "1 1 True [0 1 0] : [0 1 1]\n", "1 1 True [0 1 0] : [0 1 0]\n", "0 0 True [1 0 0] : [1 1 0]\n", "1 1 True [0 1 0] : [0 1 0]\n", "0 0 True [1 0 0] : [1 0 0]\n", "1 1 True [0 1 0] : [1 1 1]\n", "2 1 False [0 1 0] : [1 1 1]\n", "1 1 True [0 1 0] : [0 1 0]\n", "0 1 False [0 1 0] : [0 1 0]\n", "0 2 False [1 1 1] : [1 1 1]\n", "0 0 True [1 0 0] : [1 0 0]\n", "1 1 True [0 1 0] : [0 1 0]\n", "1 0 False [1 1 0] : [1 1 1]\n", "1 1 True [0 1 0] : [1 1 0]\n", "1 1 True [0 1 0] : [0 1 0]\n", "1 1 True [0 1 0] : [0 1 0]\n", "1 1 True [0 1 0] : [1 1 0]\n", "0 1 False [1 1 0] : [1 1 1]\n", "1 1 True [1 1 1] : [1 1 1]\n", "2 1 False [0 1 0] : [1 1 1]\n", "1 2 False [1 1 1] : [1 1 1]\n", "2 2 True [0 0 1] : [1 1 1]\n", "2 2 True [0 0 1] : [0 0 1]\n", "2 1 False [0 1 1] : [1 1 1]\n", "2 2 True [0 1 1] : [1 1 1]\n", "2 2 True [0 0 1] : [0 0 1]\n", "0 1 False [1 1 0] : [1 1 0]\n", "2 2 True [0 0 1] : [0 0 1]\n", "2 2 True [0 0 1] : [0 1 1]\n", "1 1 True [0 1 0] : [1 1 1]\n", "2 1 False [0 1 0] : [0 1 1]\n", "0 1 False [1 1 0] : [1 1 0]\n", "1 1 True [0 1 0] : [0 1 0]\n", "1 1 True [0 1 0] : [0 1 1]\n", "1 0 False [1 1 0] : [1 1 1]\n", "0 0 True [1 0 0] : [1 0 0]\n", "0 0 True [1 0 0] : [1 1 0]\n", "0 0 True [1 0 0] : [1 0 
0]\n", "0 1 False [1 1 0] : [1 1 1]\n", "2 1 False [0 1 1] : [1 1 1]\n", "2 2 True [0 0 1] : [0 1 1]\n", "1 1 True [0 1 0] : [0 1 0]\n", "2 2 True [0 1 1] : [1 1 1]\n", "0 0 True [1 0 0] : [1 0 0]\n", "0 0 True [1 0 0] : [1 1 0]\n", "0 0 True [1 0 0] : [1 0 0]\n", "0 0 True [1 0 0] : [1 0 0]\n", "2 2 True [0 0 1] : [1 1 1]\n", "1 0 False [1 1 0] : [1 1 1]\n", "1 1 True [0 1 0] : [0 1 0]\n", "0 0 True [1 0 0] : [1 0 0]\n", "0 0 True [1 0 0] : [1 1 0]\n", "0 0 True [1 1 0] : [1 1 1]\n", "1 1 True [1 1 0] : [1 1 1]\n", "1 1 True [0 1 1] : [1 1 1]\n", "0 0 True [1 0 0] : [1 1 1]\n", "1 0 False [1 0 0] : [1 0 0]\n", "1 1 True [0 1 0] : [1 1 0]\n", "0 0 True [1 0 0] : [1 0 0]\n", "0 0 True [1 0 0] : [1 0 0]\n", "1 0 False [1 1 0] : [1 1 1]\n", "1 0 False [1 0 1] : [1 1 1]\n", "1 1 True [1 1 0] : [1 1 1]\n", "0 0 True [1 0 0] : [1 1 0]\n", "1 2 False [0 0 1] : [1 1 1]\n", "0 1 False [0 1 0] : [1 1 1]\n", "1 1 True [0 1 0] : [1 1 1]\n", "2 2 True [0 0 1] : [0 0 1]\n", "2 2 True [0 1 1] : [1 1 1]\n", "2 2 True [0 0 1] : [0 1 1]\n", "1 1 True [0 1 0] : [1 1 1]\n", "2 2 True [0 0 1] : [1 1 1]\n", "1 2 False [0 1 1] : [1 1 1]\n", "1 1 True [0 1 0] : [1 1 1]\n", "0 2 False [0 0 1] : [1 1 1]\n", "2 2 True [1 0 1] : [1 1 1]\n", "0 1 False [0 1 0] : [1 1 0]\n", "1 1 True [0 1 1] : [1 1 1]\n", "1 2 False [0 0 1] : [0 1 1]\n", "1 1 True [1 1 0] : [1 1 1]\n", "0 0 True [1 0 0] : [1 1 0]\n", "1 1 True [0 1 0] : [1 1 0]\n", "1 1 True [0 1 0] : [1 1 0]\n", "1 1 True [0 1 0] : [0 1 0]\n", "1 0 False [1 1 0] : [1 1 1]\n", "1 2 False [0 0 1] : [0 1 1]\n", "1 2 False [0 1 1] : [1 1 1]\n", "0 0 True [1 0 0] : [1 0 0]\n", "1 0 False [1 1 0] : [1 1 1]\n", "1 0 False [1 1 0] : [1 1 1]\n", "1 1 True [0 1 0] : [0 1 0]\n", "1 1 True [0 1 0] : [1 1 1]\n", "1 0 False [1 0 0] : [1 0 0]\n", "1 0 False [1 0 0] : [1 1 0]\n", "0 1 False [0 1 0] : [0 1 0]\n", "1 1 True [0 1 0] : [1 1 1]\n", "1 0 False [1 1 1] : [1 1 1]\n", "1 1 True [1 1 0] : [1 1 1]\n", "1 0 False [1 0 0] : [1 1 1]\n", "1 0 False [1 1 1] 
: [1 1 1]\n",
 "1 0 False [1 0 0] : [1 1 1]\n",
 "0 1 False [1 1 0] : [1 1 1]\n",
 "1 1 True [0 1 1] : [1 1 1]\n",
 "0 0 True [1 0 0] : [1 1 0]\n",
 "0 1 False [1 1 0] : [1 1 1]\n",
 "0 0 True [1 1 0] : [1 1 0]\n",
 "1 1 True [1 1 0] : [1 1 1]\n",
 "1 1 True [0 1 0] : [1 1 1]\n",
 "1 0 False [1 1 0] : [1 1 1]\n",
 "0 0 True [1 0 0] : [1 1 0]\n",
 "0 0 True [1 1 0] : [1 1 0]\n",
 "0 0 True [1 1 0] : [1 1 1]\n",
 "0 0 True [1 0 0] : [1 0 0]\n",
 "1 1 True [0 1 0] : [0 1 0]\n",
 "1 1 True [0 1 0] : [1 1 0]\n",
 "0 1 False [0 1 0] : [1 1 0]\n",
 "1 1 True [0 1 0] : [1 1 1]\n",
 "0 0 True [1 1 0] : [1 1 1]\n",
 "0 0 True [1 1 0] : [1 1 1]\n",
 "0 0 True [1 0 0] : [1 1 0]\n",
 "1 1 True [1 1 0] : [1 1 1]\n",
 "1 1 True [0 1 0] : [1 1 1]\n",
 "0 0 True [1 0 0] : [1 1 1]\n",
 "1 1 True [1 1 0] : [1 1 1]\n",
 "0 2 False [0 0 1] : [0 1 1]\n",
 "0 2 False [0 1 1] : [1 1 1]\n",
 "1 0 False [1 0 0] : [1 1 0]\n",
 "0 0 True [1 0 0] : [1 0 0]\n",
 "0 0 True [1 0 0] : [1 0 0]\n",
 "0 0 True [1 0 0] : [1 1 1]\n",
 "0 0 True [1 0 0] : [1 0 0]\n",
 "0 0 True [1 0 0] : [1 1 0]\n",
 "2 2 True [0 0 1] : [1 1 1]\n",
 "0 0 True [1 0 0] : [1 1 0]\n",
 "0 0 True [1 0 0] : [1 0 0]\n"
 ] } ], "source": [
  "# Point prediction for every test row: the class whose entry in `pred` is largest.\n",
  "# Computed once for the whole array instead of three times per row.\n",
  "point_pred = np.argmax(pred, axis=1)\n",
  "hits = point_pred == testY\n",
  "# Columns: true class, point prediction, match?, prediction set at 0.2 : prediction set at 0.05\n",
  "for truth, guess, hit, set80, set95 in zip(testY, point_pred, hits, pred80, pred95):\n",
  "    print(truth, guess, hit, set80, \":\", set95)\n",
  "# Number of rows whose point prediction equals the true class.\n",
  "tp = int(hits.sum())"
 ] },
 { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.6848249027237354\n" ] } ], "source": [
  "# Fraction of test rows whose point prediction matches the true class.\n",
  "print(tp/testY.size)"
 ] },
 { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2.0155642023346303\n", "1.2996108949416343\n" ] } ], "source": [
  "# class_avg_c at the same two significance levels used for pred95 and pred80.\n",
  "for sig in (0.05, 0.2):\n",
  "    print(class_avg_c(pred, testY, significance=sig))"
 ] },
 { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [
  "# Cross-validate the conformal classifier on the full training set,\n",
  "# scoring class_avg_c at three significance levels.\n",
  "icpmodel = ClassIcpCvHelper(icp)\n",
  "res = cross_val_score(\n",
  "    icpmodel,\n",
  "    trainfps,\n",
  "    train_cls,\n",
  "    iterations=10,\n",
  "    scoring_funcs=[class_avg_c],\n",
  "    significance_levels=[0.05, 0.1, 0.2],\n",
  ")"
 ] },
 { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
iterfoldsignificanceclass_avg_c
0000.052.135922
1000.101.708738
2000.201.330097
3010.052.495146
4010.101.922330
5010.201.281553
6020.052.194175
7020.101.582524
8020.201.271845
9030.052.000000
\n", "
" ], "text/plain": [ " iter fold significance class_avg_c\n", "0 0 0 0.05 2.135922\n", "1 0 0 0.10 1.708738\n", "2 0 0 0.20 1.330097\n", "3 0 1 0.05 2.495146\n", "4 0 1 0.10 1.922330\n", "5 0 1 0.20 1.281553\n", "6 0 2 0.05 2.194175\n", "7 0 2 0.10 1.582524\n", "8 0 2 0.20 1.271845\n", "9 0 3 0.05 2.000000" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "res.head(10)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.7" } }, "nbformat": 4, "nbformat_minor": 4 }