{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2018.09.1\n" ] } ], "source": [
 "%matplotlib inline\n",
 "import matplotlib.pyplot as plt\n",
 "import os\n",
 "import numpy as np\n",
 "from rdkit import Chem\n",
 "from rdkit import rdBase\n",
 "from rdkit.Chem import AllChem\n",
 "from rdkit.Chem import DataStructs\n",
 "from rdkit.Chem import RDConfig\n",
 "from rdkit.Chem import rdFingerprintGenerator\n",
 "from sklearn.ensemble import RandomForestClassifier\n",
 "from sklearn.ensemble import ExtraTreesClassifier\n",
 "from sklearn.ensemble import GradientBoostingClassifier\n",
 "from sklearn.svm import SVC\n",
 "from xgboost import XGBClassifier\n",
 "from sklearn.metrics import classification_report\n",
 "from sklearn.metrics import confusion_matrix\n",
 "# BlendingClassifier comes from the local blending_classification module\n",
 "from blending_classification import BlendingClassifier\n",
 "\n",
 "datadir = os.path.join(RDConfig.RDDocsDir, \"Book/data\")\n",
 "print(rdBase.rdkitVersion)" ] },
 { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [
 "def mols2feat(mols):\n",
 "    \"\"\"Convert RDKit molecules to Morgan (radius 2) fingerprint bit vectors.\n",
 "\n",
 "    Returns a list with one numpy array per input molecule, in input order.\n",
 "    \"\"\"\n",
 "    generator = rdFingerprintGenerator.GetMorganGenerator(radius=2)\n",
 "    res = []\n",
 "    for mol in mols:\n",
 "        fp = generator.GetFingerprint(mol)\n",
 "        # start empty; ConvertToNumpyArray resizes arr to the fingerprint length\n",
 "        arr = np.zeros((0,))\n",
 "        DataStructs.ConvertToNumpyArray(fp, arr)\n",
 "        res.append(arr)\n",
 "    return res" ] },
 { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [
 "# load train and test data, dropping records RDKit could not parse (returned as None)\n",
 "train_mol = [mol for mol in Chem.SDMolSupplier(os.path.join(datadir, 'solubility.train.sdf')) if mol is not None]\n",
 "test_mol = [mol for mol in Chem.SDMolSupplier(os.path.join(datadir, 'solubility.test.sdf')) if mol is not None]\n",
 "\n",
 "# Sort the class names so the label -> index mapping is identical on every run;\n",
 "# the iteration order of a set of strings depends on hash randomization.\n",
 "cls = sorted(set(mol.GetProp('SOL_classification') for mol in train_mol))\n",
 "cls_dic = {cl: i for i, cl in enumerate(cls)}\n" ] },
 { "cell_type": "code", "execution_count": 4, "metadata": {},
"outputs": [], "source": [
 "# make train X, y and test X, y\n",
 "train_X = np.array(mols2feat(train_mol))\n",
 "train_y = np.array([cls_dic[mol.GetProp('SOL_classification')] for mol in train_mol])\n",
 "\n",
 "test_X = np.array(mols2feat(test_mol))\n",
 "test_y = np.array([cls_dic[mol.GetProp('SOL_classification')] for mol in test_mol])" ] },
 { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [
 "# Fixed seed so the stochastic ensembles give reproducible results across runs\n",
 "SEED = 42\n",
 "\n",
 "rf = RandomForestClassifier(n_estimators=100, n_jobs=-1, random_state=SEED)\n",
 "et = ExtraTreesClassifier(n_estimators=100, n_jobs=-1, random_state=SEED)\n",
 "gbc = GradientBoostingClassifier(learning_rate=0.01, random_state=SEED)\n",
 "xgbc = XGBClassifier(n_estimators=100, n_jobs=-1, random_state=SEED)\n",
 "# To use SVC, probability option must be True.\n",
 "# NOTE: svc is defined for experimentation but is not included in l1_clfs below.\n",
 "svc = SVC(probability=True, gamma='auto', random_state=SEED)\n",
 "\n",
 "l1_clfs = [rf, et, gbc, xgbc]\n",
 "l2_clf = RandomForestClassifier(n_estimators=100, n_jobs=-1, random_state=SEED)" ] },
 { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [
 "# NOTE(review): BlendingClassifier's internal fold split may need its own seed - check its API\n",
 "blendclf = BlendingClassifier(l1_clfs, l2_clf, verbose=1)" ] },
 { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "4\n", "Fitting 4 l1_classifiers...\n", "3 classes classification\n", "1-1th hold, classifier\n", "1-2th hold, classifier\n", "1-3th hold, classifier\n", "1-4th hold, classifier\n", "1-5th hold, classifier\n", "2-1th hold, classifier\n", "2-2th hold, classifier\n", "2-3th hold, classifier\n", "2-4th hold, classifier\n", "2-5th hold, classifier\n", "3-1th hold, classifier\n", "3-2th hold, classifier\n", "3-3th hold, classifier\n", "3-4th hold, classifier\n", "3-5th hold, classifier\n", "4-1th hold, classifier\n", "4-2th hold, classifier\n", "4-3th hold, classifier\n", "4-4th hold, classifier\n", "4-5th hold, classifier\n", "--- Blending ---\n", "(1025, 4, 3)\n" ] } ], "source": [
 "blendclf.fit(train_X, train_y)\n",
 "pred_y = blendclf.predict(test_X)" ] },
 { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ {
"name": "stdout", "output_type": "stream", "text": [ " precision recall f1-score support\n", "\n", " 0 0.56 0.68 0.61 40\n", " 1 0.74 0.83 0.78 102\n", " 2 0.70 0.57 0.63 115\n", "\n", " micro avg 0.69 0.69 0.69 257\n", " macro avg 0.67 0.69 0.68 257\n", "weighted avg 0.70 0.69 0.69 257\n", "\n" ] } ], "source": [ "print(classification_report(test_y, pred_y))" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(
,\n", " )" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQYAAAEKCAYAAADw9/tHAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAFiRJREFUeJzt3Xl4VOXdxvHvL5sgEHbKqoKyiFoVI1rlVdzAKloXrOLS4lJflVZpsRaXV2tR0aK2Lt1waV0qWLCiooCKUBVFQEQWUUEBhaDsEDZDkt/7R0YaeIBMkOE50ftzXXPlzJkzc+4M4c45T845Y+6OiEhFWbEDiEjyqBhEJKBiEJGAikFEAioGEQmoGEQkoGIQkYCKQUQCKgYRCeTEDlBRfv0G3qR5q9gxEqtujdzYERLvq5Ky2BESbdHnn7FyxTKrbLlEFUOT5q24Z8iY2DESq3vHprEjJN6CZetjR0i0s7t3SWs57UqISEDFICIBFYOIBFQMIhJQMYhIQMUgIgEVg4gEVAwiElAxiEhAxSAiARWDiARUDCISUDGISEDFICIBFYOIBFQMIhJQMYhIQMUgIgEVg4gEVAwiElAxiEhAxSAiARWDiARUDCISUDGISEDFICIBFYOIBFQMIhJQMYhIQMUgIoGc2AGSoEZuFp1a1aVGTjaOs2D5Bj5dvp6CvepSe4/ytyg3O4tNpWWMn7M8ctpkeHnMaK791TWUlpbS+5LL+PV1/WNHiu6GX17B+FdG0bBRY14YPwWA++76HWPHjCQrK4sGDRsz8L7BfK9ps8hJK5fRLQYzO9nMPjKzuWaW2J8cd5i1uIjXPl7GG3NX0LrRntTZI5spn61m/JzljJ+znMLVGylcvTF21EQoLS2l79V9eO6FUbw3/QOGDR3C7A8+iB0rujN/fCEPPTVii3mXXtWX51+bxIhXJ9L1pB/y53sHRkpXNRkrBjPLBv4E/BDoCPQys46ZWt838VVJGas3lABQUuYUbSyhRm72Fsu0qFuDRatUDACTJ01i3333o3WbNuTl5XHOuecx8oXnYseK7vAfdKFu/QZbzKtdJ3/z9Ib16zCz3R1rp2RyV6IzMNfdPwUws6HAj4BE/2qpmZtN3Zq5rFy/afO8hrVy+aqkjHXFpRGTJUdh4SJatmy1+X6LFi2ZNOmdiImS7Q8Df8tzw5+iTp18Hhs+KnactGRyV6IF8HmF+wtT8xIrO8vovHc9ZhauoaTMN89vUa8mC1dtiJgsWdw9mFddfhPG8Mvrf8v4dz+mx1nn8uTf/xY7TloyWQzb+kkJfqLM7HIzm2JmU9asjDewZ0DnveuxcNUGFq/5aov5zfL3YJHGFzZr0aIlCxf+t/MXLVpI8+bNIyaqHnqceS6vvDii8gUTIJPFsBBoVeF+S6Bw64XcfbC7F7h7QX79hhmMs2OHtqpL0cYSPlm2fov5jWvnsfarUjZuKouULHkKDj+cuXPnMH/ePIqLixn29FBO7XF67FiJNP/TuZunX3v5RVrv1z5imvRlcoxhMtDWzFoDi4DzgPMzuL6d1mDPXFrVr8nqDZvo2ra8nD74ooglRcW0qFeTRdqN2EJOTg5/uO9BTju1O6Wlpfy09yV0POCA2LGi+9WVP2XyW2+wcsVyju3Ull9cexP/GTuG+Z98jGVl0bzlXtx61/2xY6bFtrW/uMte3OwU4I9ANvCou9++o+X3O+Bgv2fImIzlqe66d2waO0LiLdhqi0+2dHb3Lsx8f2qlA0IZPcDJ3V8CXsrkOkRk19Mh0SISUDGISEDFICIBFYOIBFQMIhJQMYhIQMUgIgEVg4gEVAwiElAxiEhAxSAiARWDiARUDCISUDGISEDFICIBFYOIBFQMI
hJQMYhIQMUgIgEVg4gEVAwiElAxiEhAxSAiARWDiARUDCISUDGISEDFICIBFYOIBFQMIhJQMYhIQMUgIoGc2AEqyq+Ry/HtmsSOkVj1D/957AiJN2XknbEjJJqnuZy2GEQkoGIQkYCKQUQCKgYRCagYRCSgYhCRgIpBRAIqBhEJqBhEJKBiEJGAikFEAioGEQmoGEQksN2zK82siP+ejGWpr56adnfPz3A2EYlku8Xg7nV2ZxARSY60diXMrIuZXZyabmRmrTMbS0RiqrQYzOwW4DfA9alZecCTmQwlInGls8VwJnA6sA7A3QsB7WaIfIulUwzF7u6kBiLNrFZmI4lIbOkUw7/M7G9APTP7GfAq8FBmY4lITJVeDNbd7zazk4A1QDvgZnd/JePJRCSadK8SPQOoSfnuxIzMxRGRJEjnrxKXAZOAs4CewEQzuyTTwUQknnS2GH4NHOruywHMrCHwFvBoJoOJSDzpDD4uBIoq3C8CPs9MHBFJgh2dK/Gr1OQi4B0ze47yMYYfUb5rISLfUjvalfj6IKZPUrevPZe5OCKSBDs6ierW3RlERJKj0sFHM2sMXAccANT4er67H5/BXCISUTqDj/8EPgRaA7cC84HJGcwkIpGlUwwN3f0RYJO7/8fdLwGOzHCuaBZ+/jmndj+BgkMOoHOng/jzg/fHjpQYv7jgON4dfiNTht3AYwN7s0deDoNvvZDZI3/LxKH9mTi0P99v1yJ2zGhu6nclxxzcmjNO6Lx53p/uuYPjD2vH2d2O4uxuR/H62DERE6YvneMYNqW+LjazU4FCoGVlTzKzR4EewBJ3P3DnI+5eOTk53H7nIA45tBNFRUUcc9ThHH/CiXTYv2PsaFE1b1yXq3ody6Fn387Grzbx5F2XcE73wwC44Y8jePbVaZETxnfGORdwfu//5Ya+l28x/6Kf9eHiK66JlGrnpLPFcJuZ1QX6AdcCDwO/TON5/wBO3vlocTRt1oxDDu0EQJ06dWjfoQOFhYsip0qGnOxsau6RS3Z2FjVr5LF46erYkRKl4Mgu1K1XP3aMXaLSYnD3ke6+2t1nuvtx7n6Yuz+fxvNeB1bskpSRLFgwn+nTplFw+BGxo0RXuHQ1f3x8LB+PGsC8V25nzdoNjJ34IQC/7XMak56+nt/3O4u83HRPv/nuGPKPwZx54pHc1O9KVq9aGTtOWrZbDGb2gJndv73b7gwZw9q1a7mo1zncOehe8vN13dt6dWrSo+tB7N/jFtp0u5FaNfM475TDufmB5zn4zAF0uXAQ9evWot/FJ8aOmijn/uQyRk2YzjMvv0XjJk0ZNOCG2JHSsqMthinAuzu47RJmdrmZTTGzKcuWLt1VL/uNbNq0iQt79eTH557P6WecFTtOIhx/RAfmFy5n2cq1lJSUMeK19zny4NZ8sWwNAMWbSnj8uYkUHLBP3KAJ06hxE7Kzs8nKyqLn+b2ZOW2X/dfJqB0d4PTY7gjg7oOBwQCdDivwShbPOHenzxWX0b79/vz8mnSGUr4bPv9iBZ0Pak3NGrls2LiJ4zq3Z+oHn9G0Uf7mcjj9uO/zwSeFkZMmy9Ivv6Dx95oCMHb0C+zXvnoMYmuHcCsT35rA0Kee5IADD+LoI8oHIW++9Ta6n3xK5GRxTZ65gGdffY+3n/oNJaVlvP/hQh55ZgLPPXgljerXwQymf7SQX9w+NHbUaH7d52Imv/0Gq1Ys54SC9lzV7wYmv/0mH82aDma0aLUXt9xZPfbCrfxyjhl4YbMhQFegEfAlcEvqeIjt6nRYgf9ngs7P2p4mP7g6doTEmzLyztgREu3HpxzDrPenWmXLZWyLwd17Zeq1RSSz0rmCUzszG2tmM1P3v29mN2U+mojEks4BTg9R/mEzmwDcfTpwXiZDiUhc6RTDnu6+9Y5/SSbCiEgypFMMy8xsX/77gTM9gcUZTSUiUaUz+NiH8uMMOpjZImAecGFGU4lIVOl84MynwImpj6bLcveiyp4jItVbO
ldwunmr+wC4++8ylElEIktnV2JdhekalF9jYXZm4ohIEqSzK3FPxftmdjdQ6WnXIlJ9pfNXia3tCbTZ1UFEJDnSGWOYQepPlUA20BjQ+ILIt1g6Yww9KkyXAF+6uw5wEvkW22ExmFkW8GJ1upiriHxzOxxjcPcy4H0z22s35RGRBEhnV6IZMMvMJlHhT5fufnrGUolIVOkUgz7DUuQ7Jp1iOMXdf1NxhpndBfwnM5FEJLZ0jmM4aRvzfrirg4hIcmx3i8HMrgSuAtqY2fQKD9UBJmQ6mIjEs6NdiaeAUcBAoH+F+UXuXq0/YUpEdmxHnyuxGlgN6KKuIt8xO3OuhIh8y6kYRCSgYhCRgIpBRAIqBhEJqBhEJKBiEJGAikFEAioGEQmoGEQkkM5p17vNxuJSPlqsD7rankce6V/5Qt9x3e94NXaERFuyeE1ay2mLQUQCKgYRCagYRCSgYhCRgIpBRAIqBhEJqBhEJKBiEJGAikFEAioGEQmoGEQkoGIQkYCKQUQCKgYRCagYRCSgYhCRgIpBRAIqBhEJqBhEJKBiEJGAikFEAioGEQmoGEQkoGIQkYCKQUQCKgYRCagYRCSgYhCRgIpBRAIqBhEJqBhEJJATO0BS/O66Prw5bgz1Gzbm6dFvA/Dx7BncedOvWL9uHc1atmLAHx6idp38yEnjKP5qIwN+1pOS4mJKS0vpfMIp9LyiH0sWfcaD1/dh7ZpV7NPhQK4acB85uXmx40aRXzOXQRccQvtm+TjQ78mpTJ23kouPbUPvY1tTUua8NvNLbh8xK3bUSmVsi8HMWpnZODObbWazzOyaTK1rV+jR83zu//vwLebd1v9q+lx3C0NHv8Vx3XrwxEP3R0oXX27eHtz416cZOPRl7nhqNNPfGs+cGVMZev9AfnjBZdw74g1q5ddj/IihsaNGc2vPgxj/wRK6DhhLtzteY+4XazmqbSO6fb8pJ90xjhNue42/vjondsy0ZHJXogTo5+77A0cCfcysYwbX94106nw0+fXqbzHvs3lz6dT5aAA6dzmOcaNfiBEtEcyMGnvWAqC0pITSkhIMY9bkCXQ+4VQAjunRkynjx8SMGU3tGjkcsV9Dhry1AIBNpc6aDZu46JjW/OnlORSXlAGwfG1xzJhpy1gxuPtid5+ami4CZgMtMrW+TGjTbn9ef/UlAMa+NIIvFy+KnCiustJSru/VnStPOoQDj/wfvtdyb2rVySc7p3yPtEGTZqxc+kXklHHs1agWK9YWc+9FnRjdvyuDzj+EmnnZtGlSmyP2a8gLvz6G4X27cPBe9WJHTctuGXw0s32AQ4F3dsf6dpWb73qQYU88zEWnH8v6dWvJzc2NHSmqrOxsBg4ZwwOjJvHJzGksmr+tzWLb7bmSICfLOLBVXZ54Yx4n3zme9cWl9OnWjuwso+6euZw26HVue3Ymf7n08NhR05LxwUczqw08A/R19zXbePxy4HKAps1bZTpOleyzbzsefPxZABZ8Opc3x70cOVEy1KpTl/0LfsDcGe+xrmgNpSUlZOfksGLJYuo3/l7seFEsXrWBxas28t78lQC8+F4hfbq15YtVGxg1bTEA0xasosyhQe08ViR8lyKjWwxmlkt5KfzT3f+9rWXcfbC7F7h7Qf0GDTMZp8pWLFsKQFlZGY/+aRBnn39x5ETxrFm5nHVFqwEo3riBWe+8QfN99qNjwVFMGvsiAK+PHM5hx3aLGTOapWu+onDleto0qQ1Al/aNmfNFEaPfX8zR7RoB0LpJLfJyLPGlABncYjAzAx4BZrv7vZlaz65y49WX8u47b7Jq5XJOPaojl1/Tn/Xr1zH8iYcB6Nr9NE4758LIKeNZtWwJf73ll5SVluJexhEnnkanY06kZZu2PHBDH4b9eRB7tz+QrmecFztqNP83bAYP9D6MvJwsFixbT78nprK+uIR7LuzEqzcez6aSMvo+PjV2zLSYu2fmhc26AG8AM4Cy1Owb3P2l7T2n40GH+uPPj
89Inm+Dj1cWxY6QeNc9MiV2hERbMuxaipfMrXQgKGNbDO7+Jt/VkSiRak6HRItIQMUgIgEVg4gEVAwiElAxiEhAxSAiARWDiARUDCISUDGISEDFICIBFYOIBFQMIhJQMYhIQMUgIgEVg4gEVAwiElAxiEhAxSAiARWDiARUDCISUDGISEDFICIBFYOIBFQMIhJQMYhIQMUgIgEVg4gEVAwiElAxiEhAxSAiARWDiATM3WNn2MzMlgILYueooBGwLHaIBNP7U7mkvUd7u3vjyhZKVDEkjZlNcfeC2DmSSu9P5arre6RdCREJqBhEJKBi2LHBsQMknN6fylXL90hjDCIS0BaDiARUDNtgZieb2UdmNtfM+sfOkzRm9qiZLTGzmbGzJJGZtTKzcWY228xmmdk1sTNVlXYltmJm2cDHwEnAQmAy0MvdP4gaLEHM7BhgLfC4ux8YO0/SmFkzoJm7TzWzOsC7wBnV6WdIWwyhzsBcd//U3YuBocCPImdKFHd/HVgRO0dSuftid5+ami4CZgMt4qaqGhVDqAXweYX7C6lm/6iSHGa2D3Ao8E7cJFWjYgjZNuZpf0uqzMxqA88Afd19Tew8VaFiCC0EWlW43xIojJRFqikzy6W8FP7p7v+OnaeqVAyhyUBbM2ttZnnAecDzkTNJNWJmBjwCzHb3e2Pn2Rkqhq24ewnwc2AM5YNG/3L3WXFTJYuZDQHeBtqb2UIzuzR2poQ5GrgION7MpqVup8QOVRX6c6WIBLTFICIBFYOIBFQMIhJQMYhIQMUgIgEVw3eYma1NfW1uZsMrWbavme1ZxdfvamYj052/1TK9zezBKq5vvpk1qspzZNtUDN8yqbNDq8TdC929ZyWL9QWqVAxSfakYqgkz28fMPjSzx8xsupkN//o3eOo35c1m9iZwjpnta2ajzexdM3vDzDqklmttZm+b2WQzG7DVa89MTWeb2d1mNiO1nl+Y2dVAc2CcmY1LLdct9VpTzWxY6ryAr69l8WEqy1lpfF+dzewtM3sv9bV9hYdbpb6Pj8zslgrPudDMJqUOHPrbzpShVMLddasGN2Afyk/mOjp1/1Hg2tT0fOC6CsuOBdqmpo8AXktNPw/8JDXdB1hb4bVnpqavpPwY/5zU/QYV1tEoNd0IeB2olbr/G+BmoAblZ6a2pfxktH8BI7fxvXT9ej6QX2FdJwLPpKZ7A4uBhkBNYCZQAOwPvADkppb7c4XvaXNG3b7ZLWcnukTi+dzdJ6SmnwSuBu5O3X8aNp/RdxQwrPyQfQD2SH09Gjg7Nf0EcNc21nEi8FcvPzQcd9/WdReOBDoCE1LryKP8EOkOwDx3n5PK8iRweSXfU13gMTNrS3nx5VZ47BV3X556rX8DXYAS4DBgcmrdNYEllaxDqkjFUL1sffx6xfvrUl+zgFXufkiar7E1S3OZV9y91xYzzQ5J47lbGwCMc/czU9cuGF/hsW19vwY85u7XV3E9UgUaY6he9jKzH6SmewFvbr2Al5/3P8/MzoHyM/3M7ODUwxMoP1sU4ILtrONl4Aozy0k9v0FqfhFQJzU9ETjazPZLLbOnmbUDPgRam9m+FTJWpi6wKDXde6vHTjKzBmZWEzgjlX8s0NPMmnydz8z2TmM9UgUqhuplNvBTM5sONAD+sp3lLgAuNbP3gVn899J01wB9zGwy5f8ht+Vh4DNgeur556fmDwZGmdk4d19K+X/iIaksE4EO7r6R8l2HF1ODj+l8DunvgYFmNgHYehDxTcp3eaZRPvYwxcuvm3gT8HJq3a8AzdJYj1SBzq6sJlKb2SNdF1+V3UBbDCIS0BaDiAS0xSAiARWDiARUDCISUDGISEDFICIBFYOIBP4fjOxdaQC9IikAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [
 "from mlxtend.plotting import plot_confusion_matrix\n",
 "cm = confusion_matrix(test_y, pred_y)\n",
 "fig, ax = plot_confusion_matrix(cm)\n",
 "ax.set_title('Blending classifier: test-set confusion matrix');" ] },
 { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " precision recall f1-score support\n", "\n", " 0 0.71 0.75 0.73 40\n", " 1 0.81 0.81 0.81 102\n", " 2 0.75 0.73 0.74 115\n", "\n", " micro avg 0.77 0.77 0.77 257\n", " macro avg 0.76 0.76 0.76 257\n", "weighted avg 0.77 0.77 0.77 257\n", "\n" ] } ], "source": [
 "# Baseline for comparison: a single random forest trained on the same features\n",
 "mono_rf = RandomForestClassifier(n_estimators=100, n_jobs=-1, random_state=42)\n",
 "mono_rf.fit(train_X, train_y)\n",
 "pred_y2 = mono_rf.predict(test_X)\n",
 "print(classification_report(test_y, pred_y2))" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "## Compare the models' feature importances with PCA" ] },
 { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [
 "# labels must match the order of blendclf.l1_clfs_ (assumed to follow l1_clfs), with mono_rf last\n",
 "labels = [\"rf\", \"et\", \"gbc\", \"xgbc\", \"mono_rf\"]\n",
 "feature_importances_list = [clf.feature_importances_ for clf in blendclf.l1_clfs_]\n",
 "feature_importances_list.append(mono_rf.feature_importances_)\n",
 "assert len(labels) == len(feature_importances_list), \"labels are out of sync with the fitted models\"" ] },
 { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [
 "from sklearn.decomposition import PCA\n",
 "pca = PCA(n_components=2)\n",
 "res = pca.fit_transform(feature_importances_list)" ] },
 { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(5, 2)" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "res.shape" ] },
 { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "6" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [
 "from adjustText import adjust_text\n",
 "x, y = res[:, 0], res[:, 1]\n",
 "fig, ax = plt.subplots()\n",
 "ax.plot(x, y, 'bo')\n",
 "# Autoscale with a margin instead of fixed limits so no model point is clipped\n",
 "ax.margins(0.2)\n",
 "evr = pca.explained_variance_ratio_\n",
 "ax.set(title='PCA of model feature importances',\n",
 "       xlabel='PC1 ({:.0%} of variance)'.format(evr[0]),\n",
 "       ylabel='PC2 ({:.0%} of variance)'.format(evr[1]))\n",
 "\n",
 "texts = [ax.text(xi, yi, label) for xi, yi, label in zip(x, y, labels)]\n",
 "adjust_text(texts, ax=ax);" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "- The PCA plot indicates that the layer-1 RF and ET models and the standalone mono_rf model learned similar feature importances." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.6" } }, "nbformat": 4, "nbformat_minor": 2 }