{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Machine Learning with Python\n", "\n", "We now have our data: we have sanitized it into CSV format and explored it.\n", "\n", "Now let's try to predict some of its properties. As a deliberately simple first example, we fit a linear regression model that predicts `energy_per_atom` from `energy`, `volume` and `nsites`." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "## Load the data\n", "# A local copy can be read from '../data/mpdata.csv' instead, if available.\n", "DATA_URL = 'https://gitlab.com/costrouc/mse-machinelearning-notebooks/raw/master/data/mpdata.csv'\n", "\n", "df = pd.read_csv(DATA_URL)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
material_idenergyvolumensitesenergy_per_atompretty_formulaspacegroupband_gapdensitytotal_magnetizationpoisson_ratiobulk_modulus_voigtbulk_modulus_reussbulk_modulus_vrhshear_modulus_voigtshear_modulus_vrh
2621mp-559777-372.248991732.10601656-6.647303Na5Ca2Al(PO4)494.64962.730746-8.000000e-07NaNNaNNaNNaNNaNNaN
4063mp-603327-355.351207654.01937464-5.552363H3SNO3615.18851.9721481.312010e-02NaNNaNNaNNaNNaNNaN
5443mp-559382-17.05370733.4821243-5.684569CoO21640.00004.5097549.997688e-01NaNNaNNaNNaNNaNNaN
4616mp-558564-281.494573681.18521136-7.819294SiO2125.51131.757625-1.510000e-05NaNNaNNaNNaNNaNNaN
2883mp-667374-1188.8764402214.927434168-7.076645NaAlSiO41694.54142.5559691.502800e-03NaNNaNNaNNaNNaNNaN
\n", "
" ], "text/plain": [ " material_id energy volume nsites energy_per_atom \\\n", "2621 mp-559777 -372.248991 732.106016 56 -6.647303 \n", "4063 mp-603327 -355.351207 654.019374 64 -5.552363 \n", "5443 mp-559382 -17.053707 33.482124 3 -5.684569 \n", "4616 mp-558564 -281.494573 681.185211 36 -7.819294 \n", "2883 mp-667374 -1188.876440 2214.927434 168 -7.076645 \n", "\n", " pretty_formula spacegroup band_gap density total_magnetization \\\n", "2621 Na5Ca2Al(PO4)4 9 4.6496 2.730746 -8.000000e-07 \n", "4063 H3SNO3 61 5.1885 1.972148 1.312010e-02 \n", "5443 CoO2 164 0.0000 4.509754 9.997688e-01 \n", "4616 SiO2 12 5.5113 1.757625 -1.510000e-05 \n", "2883 NaAlSiO4 169 4.5414 2.555969 1.502800e-03 \n", "\n", " poisson_ratio bulk_modulus_voigt bulk_modulus_reuss bulk_modulus_vrh \\\n", "2621 NaN NaN NaN NaN \n", "4063 NaN NaN NaN NaN \n", "5443 NaN NaN NaN NaN \n", "4616 NaN NaN NaN NaN \n", "2883 NaN NaN NaN NaN \n", "\n", " shear_modulus_voigt shear_modulus_vrh \n", "2621 NaN NaN \n", "4063 NaN NaN \n", "5443 NaN NaN \n", "4616 NaN NaN \n", "2883 NaN NaN " ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# fixed random_state so the displayed sample is reproducible\n", "df.sample(5, random_state=0)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
energyvolumensitesenergy_per_atomspacegroupband_gapdensitytotal_magnetizationpoisson_ratiobulk_modulus_voigtbulk_modulus_reussbulk_modulus_vrhshear_modulus_voigtshear_modulus_vrh
energy1.000000-0.852020-0.9651980.3451710.176241-0.3787380.185483-0.1219210.045966-0.080688-0.059895-0.071393-0.128289-0.056726
volume-0.8520201.0000000.862495-0.110264-0.1623440.302136-0.3093430.0341270.000059-0.208067-0.187330-0.201073-0.110852-0.052474
nsites-0.9651980.8624951.000000-0.194668-0.2172660.371992-0.2482810.095329-0.032862-0.032160-0.028786-0.0309920.0423270.017493
energy_per_atom0.345171-0.110264-0.1946681.000000-0.038482-0.221015-0.334075-0.1715080.051704-0.565429-0.465262-0.523804-0.434698-0.211633
spacegroup0.176241-0.162344-0.217266-0.0384821.000000-0.0936780.250417-0.0655600.0284440.1956540.2059960.2044830.1138080.058407
band_gap-0.3787380.3021360.371992-0.221015-0.0936781.000000-0.421409-0.220998-0.068310-0.266141-0.249368-0.262229-0.0366940.000743
density0.185483-0.309343-0.248281-0.3340750.250417-0.4214091.0000000.3221210.0655130.5018740.5380170.5294850.1654600.086553
total_magnetization-0.1219210.0341270.095329-0.171508-0.065560-0.2209980.3221211.0000000.0230070.0801650.0895850.086458-0.006023-0.031027
poisson_ratio0.0459660.000059-0.0328620.0517040.028444-0.0683100.0655130.0230071.000000-0.019499-0.022243-0.021264-0.1434810.082642
bulk_modulus_voigt-0.080688-0.208067-0.032160-0.5654290.195654-0.2661410.5018740.080165-0.0194991.0000000.9304990.9819580.6760280.325090
bulk_modulus_reuss-0.059895-0.187330-0.028786-0.4652620.205996-0.2493680.5380170.089585-0.0222430.9304991.0000000.9829770.5982730.312095
bulk_modulus_vrh-0.071393-0.201073-0.030992-0.5238040.204483-0.2622290.5294850.086458-0.0212640.9819580.9829771.0000000.6479450.324180
shear_modulus_voigt-0.128289-0.1108520.042327-0.4346980.113808-0.0366940.165460-0.006023-0.1434810.6760280.5982730.6479451.0000000.460951
shear_modulus_vrh-0.056726-0.0524740.017493-0.2116330.0584070.0007430.086553-0.0310270.0826420.3250900.3120950.3241800.4609511.000000
\n", "
" ], "text/plain": [ " energy volume nsites energy_per_atom \\\n", "energy 1.000000 -0.852020 -0.965198 0.345171 \n", "volume -0.852020 1.000000 0.862495 -0.110264 \n", "nsites -0.965198 0.862495 1.000000 -0.194668 \n", "energy_per_atom 0.345171 -0.110264 -0.194668 1.000000 \n", "spacegroup 0.176241 -0.162344 -0.217266 -0.038482 \n", "band_gap -0.378738 0.302136 0.371992 -0.221015 \n", "density 0.185483 -0.309343 -0.248281 -0.334075 \n", "total_magnetization -0.121921 0.034127 0.095329 -0.171508 \n", "poisson_ratio 0.045966 0.000059 -0.032862 0.051704 \n", "bulk_modulus_voigt -0.080688 -0.208067 -0.032160 -0.565429 \n", "bulk_modulus_reuss -0.059895 -0.187330 -0.028786 -0.465262 \n", "bulk_modulus_vrh -0.071393 -0.201073 -0.030992 -0.523804 \n", "shear_modulus_voigt -0.128289 -0.110852 0.042327 -0.434698 \n", "shear_modulus_vrh -0.056726 -0.052474 0.017493 -0.211633 \n", "\n", " spacegroup band_gap density total_magnetization \\\n", "energy 0.176241 -0.378738 0.185483 -0.121921 \n", "volume -0.162344 0.302136 -0.309343 0.034127 \n", "nsites -0.217266 0.371992 -0.248281 0.095329 \n", "energy_per_atom -0.038482 -0.221015 -0.334075 -0.171508 \n", "spacegroup 1.000000 -0.093678 0.250417 -0.065560 \n", "band_gap -0.093678 1.000000 -0.421409 -0.220998 \n", "density 0.250417 -0.421409 1.000000 0.322121 \n", "total_magnetization -0.065560 -0.220998 0.322121 1.000000 \n", "poisson_ratio 0.028444 -0.068310 0.065513 0.023007 \n", "bulk_modulus_voigt 0.195654 -0.266141 0.501874 0.080165 \n", "bulk_modulus_reuss 0.205996 -0.249368 0.538017 0.089585 \n", "bulk_modulus_vrh 0.204483 -0.262229 0.529485 0.086458 \n", "shear_modulus_voigt 0.113808 -0.036694 0.165460 -0.006023 \n", "shear_modulus_vrh 0.058407 0.000743 0.086553 -0.031027 \n", "\n", " poisson_ratio bulk_modulus_voigt bulk_modulus_reuss \\\n", "energy 0.045966 -0.080688 -0.059895 \n", "volume 0.000059 -0.208067 -0.187330 \n", "nsites -0.032862 -0.032160 -0.028786 \n", "energy_per_atom 0.051704 -0.565429 -0.465262 \n", 
"spacegroup 0.028444 0.195654 0.205996 \n", "band_gap -0.068310 -0.266141 -0.249368 \n", "density 0.065513 0.501874 0.538017 \n", "total_magnetization 0.023007 0.080165 0.089585 \n", "poisson_ratio 1.000000 -0.019499 -0.022243 \n", "bulk_modulus_voigt -0.019499 1.000000 0.930499 \n", "bulk_modulus_reuss -0.022243 0.930499 1.000000 \n", "bulk_modulus_vrh -0.021264 0.981958 0.982977 \n", "shear_modulus_voigt -0.143481 0.676028 0.598273 \n", "shear_modulus_vrh 0.082642 0.325090 0.312095 \n", "\n", " bulk_modulus_vrh shear_modulus_voigt shear_modulus_vrh \n", "energy -0.071393 -0.128289 -0.056726 \n", "volume -0.201073 -0.110852 -0.052474 \n", "nsites -0.030992 0.042327 0.017493 \n", "energy_per_atom -0.523804 -0.434698 -0.211633 \n", "spacegroup 0.204483 0.113808 0.058407 \n", "band_gap -0.262229 -0.036694 0.000743 \n", "density 0.529485 0.165460 0.086553 \n", "total_magnetization 0.086458 -0.006023 -0.031027 \n", "poisson_ratio -0.021264 -0.143481 0.082642 \n", "bulk_modulus_voigt 0.981958 0.676028 0.325090 \n", "bulk_modulus_reuss 0.982977 0.598273 0.312095 \n", "bulk_modulus_vrh 1.000000 0.647945 0.324180 \n", "shear_modulus_voigt 0.647945 1.000000 0.460951 \n", "shear_modulus_vrh 0.324180 0.460951 1.000000 " ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# correlations are only defined for numeric columns; pandas >= 2.0 raises on string columns\n", "df.select_dtypes('number').corr()" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# same numeric-only correlation matrix as above, drawn as a labelled heatmap\n", "corr_matrix = df.select_dtypes('number').corr()\n", "\n", "fig, ax = plt.subplots(figsize=(8, 8))\n", "im = ax.matshow(corr_matrix)\n", "ax.set_xticks(range(len(corr_matrix.columns)))\n", "ax.set_xticklabels(corr_matrix.columns, rotation=90)\n", "ax.set_yticks(range(len(corr_matrix.columns)))\n", "ax.set_yticklabels(corr_matrix.columns)\n", "fig.colorbar(im, ax=ax);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Let's choose a very simple example to show the methodology\n", "\n", "How about we try to predict `energy_per_atom`? You can see from the correlation plot that there are two very strongly (negatively) correlated pairs, shown in dark purple: `energy` vs. `nsites` (-0.97) and `energy` vs. `volume` (-0.85).\n", "\n", "We will simplify our model and only use the first four numeric columns: `energy`, `volume`, `nsites` and `energy_per_atom`. Since `energy_per_atom` is just `energy / nsites`, `volume` is obviously not useful in the calculation, but we want to see whether our algorithm can determine this automatically." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "simplified_df = df[['energy', 'volume', 'nsites', 'energy_per_atom']]" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
energyvolumensitesenergy_per_atom
0-4.06460011.8527651-4.064600
1-16.38209647.2641584-4.095524
2-8.18695923.6173882-4.093479
3-4.06414211.8747031-4.064142
4-2.157191603.4752101-2.157191
\n", "
" ], "text/plain": [ " energy volume nsites energy_per_atom\n", "0 -4.064600 11.852765 1 -4.064600\n", "1 -16.382096 47.264158 4 -4.095524\n", "2 -8.186959 23.617388 2 -4.093479\n", "3 -4.064142 11.874703 1 -4.064142\n", "4 -2.157191 603.475210 1 -2.157191" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "simplified_df.head(5)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Nearly all (99%) machine learning algorithms need the data as arrays of floating-point numbers, and scikit-learn is no different. This is how easy it is to convert a pandas DataFrame to a numpy array.\n", "\n", "Not covered here, but you will most likely need it at some point: [preprocessing data](http://scikit-learn.org/stable/modules/preprocessing.html), including how to handle categorical data." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(6928, 3) (6928,)\n", "[[ -4.0645998 11.85276501 1. ]\n", " [-16.38209642 47.26415795 4. ]\n", " [ -8.18695876 23.61738783 2. ]] [-4.0645998 -4.0955241 -4.09347938]\n" ] } ], "source": [ "# convert the pandas columns to numpy arrays: feature matrix X and target vector y\n", "X = simplified_df[['energy', 'volume', 'nsites']].values\n", "y = simplified_df['energy_per_atom'].values\n", "\n", "print(X.shape, y.shape)\n", "print(X[:3], y[:3])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Scikit-learn\n", "\n", "Very quick overview: [scikit-learn](http://scikit-learn.org/stable/) provides a unified framework for working with machine learning algorithms. It includes classification, regression, clustering, dimensionality reduction, model tuning, and pre- and post-processing of data.\n", "\n", "Is that a lot? **YES** - scikit-learn is huge and you cannot expect to learn and use everything.\n", "\n", "The flow chart gives some good advice on which algorithms to use for your problem. 
See their [flow chart](http://scikit-learn.org/stable/tutorial/machine_learning_map/index.html).\n", "\n", "![sklearn flowchart](../images/scklearn-flowchart.png)\n", "\n", "There are a ton of algorithms, over 100! This is where sklearn really shines: all algorithms have exactly the same API (this is pseudocode).\n", "\n", "```python\n", "from sklearn import MyImportantModel\n", "\n", "model = MyImportantModel()\n", "model.fit(X, y)\n", "```\n", "\n", "Once you have `fit` your model, you can use it to predict new data.\n", "\n", "```python\n", "y_predict = model.predict(X_predict)\n", "```\n", "\n", "We will be using a linear model to fit our data. Always start with the simplest model! That way you know what sort of improvement a more complex one gets you.\n", "\n", "[sklearn.linear_model.LinearRegression](http://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LinearRegression.html)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "# Let's use a simple linear model\n", "from sklearn.linear_model import LinearRegression\n", "\n", "model = LinearRegression()" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from sklearn.model_selection import cross_val_predict\n", "\n", "predicted = cross_val_predict(model, X, y, cv=10)\n", "\n", "fig, ax = plt.subplots()\n", "ax.scatter(y, predicted, edgecolors=(0, 0, 0))\n", "ax.plot([y.min(), y.max()], [y.min(), y.max()], 'k--', lw=4)\n", "ax.set_xlabel('Measured')\n", "ax.set_ylabel('Predicted')\n", "ax.set_title('10-fold cross-validated predictions of energy_per_atom')\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "# Now let's do a single hold-out train/test split by hand.\n", "# Import the functions explicitly instead of relying on `import sklearn` exposing its submodules.\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.metrics import mean_squared_error" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "# fixed random_state so the split (and the error computed below) is reproducible\n", "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "LinearRegression(copy_X=True, fit_intercept=True, n_jobs=1, normalize=False)" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model = LinearRegression()\n", "model.fit(X_train, y_train)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "1.5827284267819932" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y_predict = model.predict(X_test)\n", "\n", "# calculate the mean squared error on the held-out test set\n", "mean_squared_error(y_test, y_predict)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.5" } }, "nbformat": 4, "nbformat_minor": 2 }