{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "related-stream", "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "from sklearn.ensemble import RandomForestClassifier\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.metrics import plot_confusion_matrix" ] }, { "cell_type": "markdown", "id": "unauthorized-damage", "metadata": {}, "source": [ "# Material classification\n", "\n", "Neutron scattering can provide complementary information about materials, such as structural information from diffraction and electronic information from spectroscopy. \n", "In this exercise, we will use random forest classification of structural and electronic information to build a model that predicts the [space group](https://en.wikipedia.org/wiki/Space_group) of new materials.\n", "\n", "The first thing that we need is some data. I have obtained data for lithium-, sodium- and potassium-containing materials from the four most common space groups ([space group numbers](https://en.wikipedia.org/wiki/List_of_space_groups): 2, 14, 15 and 19). \n", "This data can be found in the file [`materials_data.csv`](https://github.com/arm61/trad_ml_methods/raw/main/materials_data.csv). " ] }, { "cell_type": "code", "execution_count": 2, "id": "tribal-lying", "metadata": {}, "outputs": [], "source": [ "data = pd.read_csv('materials_data.csv')" ] }, { "cell_type": "markdown", "id": "first-christian", "metadata": {}, "source": [ "The data is read in as a `pandas.DataFrame`, which can be displayed by evaluating it in a cell. " ] }, { "cell_type": "code", "execution_count": 3, "id": "determined-korea", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\"><th></th><th>pretty_formula</th><th>spacegroup.number</th><th>band_gap</th><th>density</th><th>volume</th><th>formation_energy_per_atom</th><th>nsites</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>0</th><td>KAl9O14</td><td>14</td><td>4.1426</td><td>3.142646</td><td>534.648960</td><td>-3.364229</td><td>48</td></tr>\n",
"    <tr><th>1</th><td>K2Se2O</td><td>15</td><td>0.0975</td><td>3.303683</td><td>253.443517</td><td>-0.940536</td><td>10</td></tr>\n",
"    <tr><th>2</th><td>K2AgF4</td><td>14</td><td>0.6151</td><td>3.343987</td><td>260.263084</td><td>-2.371859</td><td>14</td></tr>\n",
"    <tr><th>3</th><td>Na2TeS3</td><td>14</td><td>1.7919</td><td>2.904619</td><td>616.908734</td><td>-1.000993</td><td>24</td></tr>\n",
"    <tr><th>4</th><td>K3Zn2Cl7</td><td>2</td><td>3.9024</td><td>2.345705</td><td>702.644832</td><td>-1.909423</td><td>24</td></tr>\n",
"    <tr><th>...</th><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td></tr>\n",
"    <tr><th>537</th><td>Na2Cr2O7</td><td>2</td><td>2.5253</td><td>2.654095</td><td>655.601817</td><td>-1.869509</td><td>44</td></tr>\n",
"    <tr><th>538</th><td>K4SnO4</td><td>2</td><td>2.3423</td><td>2.977324</td><td>378.252471</td><td>-1.839865</td><td>18</td></tr>\n",
"    <tr><th>539</th><td>LiTa3O8</td><td>15</td><td>3.3841</td><td>7.889695</td><td>285.303789</td><td>-3.263699</td><td>24</td></tr>\n",
"    <tr><th>540</th><td>LiSmO2</td><td>14</td><td>3.7655</td><td>6.485182</td><td>193.881819</td><td>-3.185751</td><td>16</td></tr>\n",
"    <tr><th>541</th><td>K4UO5</td><td>2</td><td>2.1910</td><td>4.207670</td><td>374.454975</td><td>-2.645721</td><td>20</td></tr>\n",
"  </tbody>\n",
"</table>\n",
"<p>542 rows × 7 columns</p>\n",
"</div>" ], "text/plain": [ " pretty_formula spacegroup.number band_gap density volume \\\n", "0 KAl9O14 14 4.1426 3.142646 534.648960 \n", "1 K2Se2O 15 0.0975 3.303683 253.443517 \n", "2 K2AgF4 14 0.6151 3.343987 260.263084 \n", "3 Na2TeS3 14 1.7919 2.904619 616.908734 \n", "4 K3Zn2Cl7 2 3.9024 2.345705 702.644832 \n", ".. ... ... ... ... ... \n", "537 Na2Cr2O7 2 2.5253 2.654095 655.601817 \n", "538 K4SnO4 2 2.3423 2.977324 378.252471 \n", "539 LiTa3O8 15 3.3841 7.889695 285.303789 \n", "540 LiSmO2 14 3.7655 6.485182 193.881819 \n", "541 K4UO5 2 2.1910 4.207670 374.454975 \n", "\n", " formation_energy_per_atom nsites \n", "0 -3.364229 48 \n", "1 -0.940536 10 \n", "2 -2.371859 14 \n", "3 -1.000993 24 \n", "4 -1.909423 24 \n", ".. ... ... \n", "537 -1.869509 44 \n", "538 -1.839865 18 \n", "539 -3.263699 24 \n", "540 -3.185751 16 \n", "541 -2.645721 20 \n", "\n", "[542 rows x 7 columns]" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data" ] }, { "cell_type": "markdown", "id": "assisted-better", "metadata": {}, "source": [ "The first step is to separate our data into a training set and a validation set. This is standard practice in machine learning: the model is fitted to the training set, and the held-back validation set shows how well it performs on materials it has not seen. \n", "For this, we will use the `scikit-learn` function [`train_test_split`](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.train_test_split.html). " ] }, { "cell_type": "code", "execution_count": 4, "id": "falling-sailing", "metadata": {}, "outputs": [], "source": [ "train, validate = train_test_split(data, test_size=0.2)" ] }, { "cell_type": "markdown", "id": "immune-surprise", "metadata": {}, "source": [ "We have split the data for the 542 materials so that 80 % (433 materials) will be used for training and the remaining 20 % (109 materials) is reserved for validation. \n", "Because the split is random, the materials in each set, and hence the results below, will change each time the cell is run unless a fixed `random_state` is passed to `train_test_split`. " ] }, { "cell_type": "code", "execution_count": 5, "id": "loose-pressure", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\"><th></th><th>pretty_formula</th><th>spacegroup.number</th><th>band_gap</th><th>density</th><th>volume</th><th>formation_energy_per_atom</th><th>nsites</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>53</th><td>KInS2</td><td>15</td><td>2.0103</td><td>3.080923</td><td>940.171139</td><td>-1.270707</td><td>32</td></tr>\n",
"    <tr><th>479</th><td>Na4SiO4</td><td>2</td><td>3.2269</td><td>2.568896</td><td>237.930444</td><td>-2.400403</td><td>18</td></tr>\n",
"    <tr><th>212</th><td>Na2(ReSe2)3</td><td>15</td><td>1.2790</td><td>6.859253</td><td>1044.230215</td><td>-0.603486</td><td>44</td></tr>\n",
"    <tr><th>188</th><td>KC4N3</td><td>2</td><td>3.7294</td><td>1.393763</td><td>307.767109</td><td>-0.042663</td><td>16</td></tr>\n",
"    <tr><th>171</th><td>Na7In3Se8</td><td>2</td><td>1.1676</td><td>3.824842</td><td>987.301832</td><td>-0.915512</td><td>36</td></tr>\n",
"    <tr><th>...</th><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td></tr>\n",
"    <tr><th>407</th><td>K2Al2Sb3</td><td>14</td><td>0.9992</td><td>3.561968</td><td>1855.194833</td><td>-0.363389</td><td>56</td></tr>\n",
"    <tr><th>357</th><td>Li4P2O7</td><td>14</td><td>5.6533</td><td>2.258198</td><td>593.292328</td><td>-2.751962</td><td>52</td></tr>\n",
"    <tr><th>327</th><td>Na2Te2O5</td><td>15</td><td>3.0831</td><td>4.378071</td><td>578.299060</td><td>-1.757621</td><td>36</td></tr>\n",
"    <tr><th>423</th><td>K4ZnO3</td><td>2</td><td>1.6956</td><td>2.756436</td><td>650.135219</td><td>-1.503639</td><td>32</td></tr>\n",
"    <tr><th>425</th><td>KPbO2</td><td>2</td><td>1.4532</td><td>6.172046</td><td>598.988692</td><td>-1.581835</td><td>32</td></tr>\n",
"  </tbody>\n",
"</table>\n",
"<p>433 rows × 7 columns</p>\n",
"</div>
" ], "text/plain": [ " pretty_formula spacegroup.number band_gap density volume \\\n", "53 KInS2 15 2.0103 3.080923 940.171139 \n", "479 Na4SiO4 2 3.2269 2.568896 237.930444 \n", "212 Na2(ReSe2)3 15 1.2790 6.859253 1044.230215 \n", "188 KC4N3 2 3.7294 1.393763 307.767109 \n", "171 Na7In3Se8 2 1.1676 3.824842 987.301832 \n", ".. ... ... ... ... ... \n", "407 K2Al2Sb3 14 0.9992 3.561968 1855.194833 \n", "357 Li4P2O7 14 5.6533 2.258198 593.292328 \n", "327 Na2Te2O5 15 3.0831 4.378071 578.299060 \n", "423 K4ZnO3 2 1.6956 2.756436 650.135219 \n", "425 KPbO2 2 1.4532 6.172046 598.988692 \n", "\n", " formation_energy_per_atom nsites \n", "53 -1.270707 32 \n", "479 -2.400403 18 \n", "212 -0.603486 44 \n", "188 -0.042663 16 \n", "171 -0.915512 36 \n", ".. ... ... \n", "407 -0.363389 56 \n", "357 -2.751962 52 \n", "327 -1.757621 36 \n", "423 -1.503639 32 \n", "425 -1.581835 32 \n", "\n", "[433 rows x 7 columns]" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train" ] }, { "cell_type": "code", "execution_count": 6, "id": "preliminary-marketplace", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\"><th></th><th>pretty_formula</th><th>spacegroup.number</th><th>band_gap</th><th>density</th><th>volume</th><th>formation_energy_per_atom</th><th>nsites</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>46</th><td>Na3CoO3</td><td>2</td><td>0.8575</td><td>2.947379</td><td>198.203241</td><td>-1.453332</td><td>14</td></tr>\n",
"    <tr><th>480</th><td>RbLiF2</td><td>15</td><td>6.1608</td><td>3.242957</td><td>267.093984</td><td>-3.028814</td><td>16</td></tr>\n",
"    <tr><th>509</th><td>Na3BS3</td><td>15</td><td>2.6404</td><td>1.964923</td><td>297.430370</td><td>-1.235997</td><td>14</td></tr>\n",
"    <tr><th>251</th><td>KEu2I5</td><td>14</td><td>1.3409</td><td>4.987760</td><td>1301.792889</td><td>-1.579361</td><td>32</td></tr>\n",
"    <tr><th>348</th><td>Na6Sn2S7</td><td>15</td><td>2.0948</td><td>2.597359</td><td>766.943648</td><td>-1.218262</td><td>30</td></tr>\n",
"    <tr><th>...</th><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td></tr>\n",
"    <tr><th>368</th><td>Li3CrF6</td><td>15</td><td>3.9646</td><td>2.839623</td><td>655.448686</td><td>-3.142304</td><td>60</td></tr>\n",
"    <tr><th>43</th><td>KGaCl4</td><td>14</td><td>4.6049</td><td>2.201932</td><td>756.038494</td><td>-1.881241</td><td>24</td></tr>\n",
"    <tr><th>14</th><td>Na6Ge2Se7</td><td>15</td><td>1.6000</td><td>3.363290</td><td>825.447012</td><td>-0.846623</td><td>30</td></tr>\n",
"    <tr><th>292</th><td>LiIO5</td><td>14</td><td>0.2279</td><td>3.483491</td><td>407.744687</td><td>-0.730449</td><td>28</td></tr>\n",
"    <tr><th>197</th><td>Na3AlTe3</td><td>14</td><td>2.2355</td><td>3.517580</td><td>904.012910</td><td>-0.804205</td><td>28</td></tr>\n",
"  </tbody>\n",
"</table>\n",
"<p>109 rows × 7 columns</p>\n",
"</div>" ], "text/plain": [ " pretty_formula spacegroup.number band_gap density volume \\\n", "46 Na3CoO3 2 0.8575 2.947379 198.203241 \n", "480 RbLiF2 15 6.1608 3.242957 267.093984 \n", "509 Na3BS3 15 2.6404 1.964923 297.430370 \n", "251 KEu2I5 14 1.3409 4.987760 1301.792889 \n", "348 Na6Sn2S7 15 2.0948 2.597359 766.943648 \n", ".. ... ... ... ... ... \n", "368 Li3CrF6 15 3.9646 2.839623 655.448686 \n", "43 KGaCl4 14 4.6049 2.201932 756.038494 \n", "14 Na6Ge2Se7 15 1.6000 3.363290 825.447012 \n", "292 LiIO5 14 0.2279 3.483491 407.744687 \n", "197 Na3AlTe3 14 2.2355 3.517580 904.012910 \n", "\n", " formation_energy_per_atom nsites \n", "46 -1.453332 14 \n", "480 -3.028814 16 \n", "509 -1.235997 14 \n", "251 -1.579361 32 \n", "348 -1.218262 30 \n", ".. ... ... \n", "368 -3.142304 60 \n", "43 -1.881241 24 \n", "14 -0.846623 30 \n", "292 -0.730449 28 \n", "197 -0.804205 28 \n", "\n", "[109 rows x 7 columns]" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "validate" ] }, { "cell_type": "markdown", "id": "increasing-journalism", "metadata": {}, "source": [ "We then define the columns of interest, i.e. the features used for the classification, and split each set into the feature matrix `X` and the target labels `y` (the space-group numbers). " ] }, { "cell_type": "code", "execution_count": 7, "id": "dental-asset", "metadata": {}, "outputs": [], "source": [ "# Features (inputs) used to classify each material\n", "columns = ['band_gap', 'formation_energy_per_atom']\n", "# Feature matrices (X) and target labels, the space-group numbers (y)\n", "X_train = train[columns]\n", "y_train = train['spacegroup.number']\n", "X_validate = validate[columns]\n", "y_validate = validate['spacegroup.number']" ] }, { "cell_type": "markdown", "id": "mobile-france", "metadata": {}, "source": [ "With the data split up, we can define our random forest model, an ensemble of 100 decision trees (`n_estimators=100`), and fit it to the training data. 
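As a rough illustration of what `RandomForestClassifier` does internally, the sketch below builds a small forest by hand. Each decision tree is trained on a bootstrap sample of the rows (drawn with replacement) and considers only a random subset of the features at each split; the forest then predicts by averaging the trees' class probabilities, which is also how scikit-learn combines its trees (rather than by a hard majority vote). This is a simplified sketch, not scikit-learn's actual implementation: the helper names `fit_forest` and `predict_forest` are hypothetical, and synthetic data from `make_classification` stand in for `materials_data.csv` so that it runs on its own.

```python
# Simplified, hand-rolled random forest for illustration only; in practice
# use sklearn.ensemble.RandomForestClassifier as in the notebook.
# Assumption: synthetic data from make_classification replaces the materials
# data; fit_forest and predict_forest are hypothetical helper names.
import numpy as np
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier


def fit_forest(X, y, n_estimators=100, seed=0):
    '''Fit n_estimators trees, each on a bootstrap sample of the rows.'''
    rng = np.random.default_rng(seed)
    n_samples = len(X)
    trees = []
    for _ in range(n_estimators):
        # Bootstrap: draw n_samples row indices with replacement
        rows = rng.integers(0, n_samples, size=n_samples)
        # max_features='sqrt' considers a random subset of features at each split
        tree = DecisionTreeClassifier(max_features='sqrt',
                                      random_state=int(rng.integers(2**31 - 1)))
        trees.append(tree.fit(X[rows], y[rows]))
    return trees


def predict_forest(trees, X, classes):
    '''Average the trees' class probabilities and return the likeliest class.'''
    proba = np.zeros((len(X), len(classes)))
    for tree in trees:
        # A bootstrap sample can miss a class, so map each tree's classes_
        # onto the column positions of the full (sorted) class list
        cols = np.searchsorted(classes, tree.classes_)
        proba[:, cols] += tree.predict_proba(X)
    return classes[np.argmax(proba / len(trees), axis=1)]


# Four synthetic classes standing in for the four space groups
X_syn, y_syn = make_classification(n_samples=500, n_features=5, n_informative=3,
                                   n_classes=4, n_clusters_per_class=1,
                                   class_sep=2.0, random_state=0)
Xs_train, ys_train = X_syn[:400], y_syn[:400]
Xs_test, ys_test = X_syn[400:], y_syn[400:]

classes = np.unique(ys_train)
trees = fit_forest(Xs_train, ys_train)
y_pred = predict_forest(trees, Xs_test, classes)
accuracy = np.mean(y_pred == ys_test)
print(f'validation accuracy: {accuracy:.2f}')
```

Swapping the helpers for `RandomForestClassifier(n_estimators=100)` gives the same kind of model with many refinements; the hand-rolled version only makes the bootstrap-and-average idea explicit.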
" ] }, { "cell_type": "code", "execution_count": 8, "id": "maritime-underground", "metadata": {}, "outputs": [], "source": [ "# Fit an ensemble of 100 decision trees to the training data\n", "model = RandomForestClassifier(n_estimators=100).fit(X_train, y_train)" ] }, { "cell_type": "markdown", "id": "commercial-decimal", "metadata": {}, "source": [ "Let's see how well our model classifies the validation data." ] }, { "cell_type": "code", "execution_count": 9, "id": "electric-morrison", "metadata": {}, "outputs": [], "source": [ "plot_confusion_matrix(model, X_validate, y_validate, normalize='true')\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "annoying-fluid", "metadata": {}, "source": [ "The confusion matrix compares the actual space group of each material in the validation data (rows) with the space group predicted by the classifier (columns). \n", "Because `normalize='true'` normalises each row, every entry is the fraction of materials with a given true space group that were assigned to each predicted space group, so a perfect classification would give an [identity matrix](https://en.wikipedia.org/wiki/Identity_matrix) with one row and one column per label. \n", "\n", "## Exercise\n", "\n", "It is clear from the above confusion matrix that our current selection of columns of interest is not sufficient to classify the data accurately. \n", "In this exercise, you should try different combinations of the columns of interest (`band_gap`, `density`, `volume`, `formation_energy_per_atom` and `nsites`) to see which gives the best classification of the space group for these materials. " ] }, { "cell_type": "code", "execution_count": null, "id": "unauthorized-association", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.9" } }, "nbformat": 4, "nbformat_minor": 5 }