{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "related-stream", "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "from sklearn.ensemble import RandomForestClassifier\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.metrics import plot_confusion_matrix" ] }, { "cell_type": "markdown", "id": "unauthorized-damage", "metadata": {}, "source": [ "# Material classification\n", "\n", "Neutron scattering can give complementary information about materials, such as structural information from diffraction and electronic information from spectroscopy. \n", "In this exercise, we will use random forest classification on structural and electronic information to build a model that can predict the [space group](https://en.wikipedia.org/wiki/Space_group) of new materials.\n", "\n", "The first thing that we need is some data. I have obtained data for lithium-, sodium-, and potassium-containing materials in the four most common space groups ([space group numbers](https://en.wikipedia.org/wiki/List_of_space_groups): 2, 14, 15, 19). \n", "This data can be found in the file [`materials_data.csv`](https://github.com/arm61/trad_ml_methods/raw/main/materials_data.csv). " ] }, { "cell_type": "code", "execution_count": 2, "id": "tribal-lying", "metadata": {}, "outputs": [], "source": [ "data = pd.read_csv('materials_data.csv')" ] }, { "cell_type": "markdown", "id": "first-christian", "metadata": {}, "source": [ "This data is read in as a `pandas.DataFrame`, so it can be visualised as a table, as shown below. " ] }, { "cell_type": "code", "execution_count": 3, "id": "determined-korea", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", " | pretty_formula | \n", "spacegroup.number | \n", "band_gap | \n", "density | \n", "volume | \n", "formation_energy_per_atom | \n", "nsites | \n", "
---|---|---|---|---|---|---|---|
0 | \n", "KAl9O14 | \n", "14 | \n", "4.1426 | \n", "3.142646 | \n", "534.648960 | \n", "-3.364229 | \n", "48 | \n", "
1 | \n", "K2Se2O | \n", "15 | \n", "0.0975 | \n", "3.303683 | \n", "253.443517 | \n", "-0.940536 | \n", "10 | \n", "
2 | \n", "K2AgF4 | \n", "14 | \n", "0.6151 | \n", "3.343987 | \n", "260.263084 | \n", "-2.371859 | \n", "14 | \n", "
3 | \n", "Na2TeS3 | \n", "14 | \n", "1.7919 | \n", "2.904619 | \n", "616.908734 | \n", "-1.000993 | \n", "24 | \n", "
4 | \n", "K3Zn2Cl7 | \n", "2 | \n", "3.9024 | \n", "2.345705 | \n", "702.644832 | \n", "-1.909423 | \n", "24 | \n", "
... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "
537 | \n", "Na2Cr2O7 | \n", "2 | \n", "2.5253 | \n", "2.654095 | \n", "655.601817 | \n", "-1.869509 | \n", "44 | \n", "
538 | \n", "K4SnO4 | \n", "2 | \n", "2.3423 | \n", "2.977324 | \n", "378.252471 | \n", "-1.839865 | \n", "18 | \n", "
539 | \n", "LiTa3O8 | \n", "15 | \n", "3.3841 | \n", "7.889695 | \n", "285.303789 | \n", "-3.263699 | \n", "24 | \n", "
540 | \n", "LiSmO2 | \n", "14 | \n", "3.7655 | \n", "6.485182 | \n", "193.881819 | \n", "-3.185751 | \n", "16 | \n", "
541 | \n", "K4UO5 | \n", "2 | \n", "2.1910 | \n", "4.207670 | \n", "374.454975 | \n", "-2.645721 | \n", "20 | \n", "
542 rows × 7 columns
\n", "\n", " | pretty_formula | \n", "spacegroup.number | \n", "band_gap | \n", "density | \n", "volume | \n", "formation_energy_per_atom | \n", "nsites | \n", "
---|---|---|---|---|---|---|---|
53 | \n", "KInS2 | \n", "15 | \n", "2.0103 | \n", "3.080923 | \n", "940.171139 | \n", "-1.270707 | \n", "32 | \n", "
479 | \n", "Na4SiO4 | \n", "2 | \n", "3.2269 | \n", "2.568896 | \n", "237.930444 | \n", "-2.400403 | \n", "18 | \n", "
212 | \n", "Na2(ReSe2)3 | \n", "15 | \n", "1.2790 | \n", "6.859253 | \n", "1044.230215 | \n", "-0.603486 | \n", "44 | \n", "
188 | \n", "KC4N3 | \n", "2 | \n", "3.7294 | \n", "1.393763 | \n", "307.767109 | \n", "-0.042663 | \n", "16 | \n", "
171 | \n", "Na7In3Se8 | \n", "2 | \n", "1.1676 | \n", "3.824842 | \n", "987.301832 | \n", "-0.915512 | \n", "36 | \n", "
... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "
407 | \n", "K2Al2Sb3 | \n", "14 | \n", "0.9992 | \n", "3.561968 | \n", "1855.194833 | \n", "-0.363389 | \n", "56 | \n", "
357 | \n", "Li4P2O7 | \n", "14 | \n", "5.6533 | \n", "2.258198 | \n", "593.292328 | \n", "-2.751962 | \n", "52 | \n", "
327 | \n", "Na2Te2O5 | \n", "15 | \n", "3.0831 | \n", "4.378071 | \n", "578.299060 | \n", "-1.757621 | \n", "36 | \n", "
423 | \n", "K4ZnO3 | \n", "2 | \n", "1.6956 | \n", "2.756436 | \n", "650.135219 | \n", "-1.503639 | \n", "32 | \n", "
425 | \n", "KPbO2 | \n", "2 | \n", "1.4532 | \n", "6.172046 | \n", "598.988692 | \n", "-1.581835 | \n", "32 | \n", "
433 rows × 7 columns
\n", "\n", " | pretty_formula | \n", "spacegroup.number | \n", "band_gap | \n", "density | \n", "volume | \n", "formation_energy_per_atom | \n", "nsites | \n", "
---|---|---|---|---|---|---|---|
46 | \n", "Na3CoO3 | \n", "2 | \n", "0.8575 | \n", "2.947379 | \n", "198.203241 | \n", "-1.453332 | \n", "14 | \n", "
480 | \n", "RbLiF2 | \n", "15 | \n", "6.1608 | \n", "3.242957 | \n", "267.093984 | \n", "-3.028814 | \n", "16 | \n", "
509 | \n", "Na3BS3 | \n", "15 | \n", "2.6404 | \n", "1.964923 | \n", "297.430370 | \n", "-1.235997 | \n", "14 | \n", "
251 | \n", "KEu2I5 | \n", "14 | \n", "1.3409 | \n", "4.987760 | \n", "1301.792889 | \n", "-1.579361 | \n", "32 | \n", "
348 | \n", "Na6Sn2S7 | \n", "15 | \n", "2.0948 | \n", "2.597359 | \n", "766.943648 | \n", "-1.218262 | \n", "30 | \n", "
... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "
368 | \n", "Li3CrF6 | \n", "15 | \n", "3.9646 | \n", "2.839623 | \n", "655.448686 | \n", "-3.142304 | \n", "60 | \n", "
43 | \n", "KGaCl4 | \n", "14 | \n", "4.6049 | \n", "2.201932 | \n", "756.038494 | \n", "-1.881241 | \n", "24 | \n", "
14 | \n", "Na6Ge2Se7 | \n", "15 | \n", "1.6000 | \n", "3.363290 | \n", "825.447012 | \n", "-0.846623 | \n", "30 | \n", "
292 | \n", "LiIO5 | \n", "14 | \n", "0.2279 | \n", "3.483491 | \n", "407.744687 | \n", "-0.730449 | \n", "28 | \n", "
197 | \n", "Na3AlTe3 | \n", "14 | \n", "2.2355 | \n", "3.517580 | \n", "904.012910 | \n", "-0.804205 | \n", "28 | \n", "
109 rows × 7 columns
\n", "