{
"cells": [
{
"cell_type": "markdown",
"id": "5aa74260",
"metadata": {},
"source": [
"# Regression Confidence Experiments\n",
"\n",
"This notebook explores the computation and usage of regression confidence metrics. Unlike many classification algorithms, which provide a `predict_proba()` method whose probabilities can be turned into a confidence, regression models have no direct equivalent.\n",
"\n",
"## Data\n",
"AqSolDB: A curated reference set of aqueous solubility, created by the Autonomous Energy Materials Discovery [AMD] research group, consists of aqueous solubility values of 9,982 unique compounds curated from 9 different publicly available aqueous solubility datasets. AqSolDB also contains some relevant topological and physico-chemical 2D descriptors, as well as validated molecular representations of each of the compounds. This openly accessible dataset, which is the largest of its kind, will serve not only as a useful reference source of measured and calculated solubility data, but also as a much improved and generalizable training data source for building data-driven models. (2019-04-10)\n",
"\n",
"Main Reference:\n",
"https://www.nature.com/articles/s41597-019-0151-1\n",
"\n",
"Data downloaded from the Harvard Dataverse:\n",
"https://dataverse.harvard.edu/dataset.xhtml?persistentId=doi:10.7910/DVN/OVHAW8\n",
"\n",
"® Amazon Web Services, AWS, the Powered by AWS logo, are trademarks of Amazon.com, Inc. or its affiliates."
]
},
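{
"cell_type": "markdown",
"id": "b7c1e0a2-3d4f-4a5b-8c6d-1e2f3a4b5c6d",
"metadata": {},
"source": [
"A minimal sketch of the contrast described in the introduction, using synthetic data and scikit-learn models that are not part of this notebook's pipeline. The `demo_*` variables and the helper name `classifier_confidence` are illustrative assumptions. A classifier exposes `predict_proba()`, and taking the highest class probability per row is one common way to derive a confidence; a regressor's `predict()` returns only point estimates, so a confidence has to be built some other way."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e4d9f3b6-7a8b-4c9d-9e0f-2a3b4c5d6e7f",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.linear_model import LinearRegression, LogisticRegression\n",
"\n",
"\n",
"def classifier_confidence(proba):\n",
"    \"\"\"Confidence as the highest predicted class probability per row (one common choice).\"\"\"\n",
"    return np.asarray(proba).max(axis=1)\n",
"\n",
"\n",
"# Synthetic data (assumption: purely illustrative, unrelated to AqSolDB)\n",
"demo_rng = np.random.default_rng(42)\n",
"demo_X = demo_rng.normal(size=(200, 3))\n",
"demo_y_reg = demo_X @ np.array([1.0, -2.0, 0.5]) + demo_rng.normal(scale=0.1, size=200)\n",
"demo_y_cls = (demo_y_reg > 0).astype(int)\n",
"\n",
"# Classifier: class probabilities are available directly\n",
"demo_clf = LogisticRegression().fit(demo_X, demo_y_cls)\n",
"demo_proba = demo_clf.predict_proba(demo_X[:5])\n",
"demo_conf = classifier_confidence(demo_proba)\n",
"\n",
"# Regressor: only point estimates, no probabilities\n",
"demo_reg = LinearRegression().fit(demo_X, demo_y_reg)\n",
"demo_points = demo_reg.predict(demo_X[:5])\n",
"\n",
"print(\"classifier confidence:\", demo_conf)\n",
"print(\"regressor point estimates:\", demo_points)"
]
},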
{
"cell_type": "code",
"execution_count": null,
"id": "77f0186f-c6ac-4dc7-a804-ce5d248d25b7",
"metadata": {},
"outputs": [],
"source": [
"import sageworks\n",
"import logging\n",
"logging.getLogger(\"sageworks\").setLevel(logging.WARNING)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a7ae1c21",
"metadata": {},
"outputs": [],
"source": [
"# We've already created a FeatureSet so just grab it\n",
"from sageworks.api.feature_set import FeatureSet"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "97243583",
"metadata": {},
"outputs": [],
"source": [
"fs = FeatureSet(\"aqsol_mol_descriptors\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "174e06f0",
"metadata": {},
"outputs": [],
"source": [
"full_df = fs.pull_dataframe()\n",
"full_df.head()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bb268a27-6f18-432e-8eb5-43316f6174cf",
"metadata": {},
"outputs": [],
"source": [
"# Sanity check our solubility and solubility_class\n",
"import pandas as pd\n",
"import seaborn as sns\n",
"import matplotlib.pyplot as plt\n",
"plt.rcParams['font.size'] = 12.0\n",
"plt.rcParams['figure.figsize'] = 14.0, 7.0\n",
"sns.set_theme(style='darkgrid')\n",
"\n",
"# Create a box plot\n",
"sns.boxplot(x='solubility_class', y='solubility', data=full_df, order = ['high', 'medium', 'low'])\n",
"plt.title('Solubility by Solubility Class')\n",
"plt.xlabel('Solubility Class')\n",
"plt.ylabel('Solubility')\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "91c177bc-549d-4060-a001-dea272b42ff6",
"metadata": {},
"outputs": [],
"source": [
"# Grab the Training View\n",
"table = fs.view(\"training\").table\n",
"train_df = fs.query(f\"SELECT * FROM {table} where training = TRUE\")\n",
"hold_out_df = fs.query(f\"SELECT * FROM {table} where training = FALSE\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d36b6567-a800-4536-8684-9e66c2f7df83",
"metadata": {},
"outputs": [],
"source": [
"train_df"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2d83c358-cd7c-4b6e-acb0-0e2e1d3ec3ae",
"metadata": {},
"outputs": [],
"source": [
"# Here we're just grabbing the model to get the target and features\n",
"from sageworks.api.model import Model\n",
"model = Model(\"aqsol-mol-regression\")\n",
"target = model.target()\n",
"features = model.features()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2113b792-36ef-40ad-b6ad-7ef83c97aa76",
"metadata": {},
"outputs": [],
"source": [
"len(full_df)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d800a6b1-5304-422c-a691-4a0fce00f4af",
"metadata": {},
"outputs": [],
"source": [
"import xgboost as xgb\n",
"\n",
"X = train_df[features]\n",
"y = train_df[target]\n",
"\n",
"# Train the main XGBRegressor model\n",
"model = xgb.XGBRegressor()\n",
"model.fit(X, y)\n",
"\n",
"# Run predictions on the hold out\n",
"hold_out_df[\"predictions\"] = model.predict(hold_out_df[features])\n",
"hold_out_df[\"confidence\"] = 0"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f91a4fb3-c14d-4774-9993-c9c9308c8ceb",
"metadata": {},
"outputs": [],
"source": [
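"# Plot predictions vs. actual values (plot_predictions is defined in the Helper Methods cell at the bottom; run that cell first)\n",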
"plot_predictions(hold_out_df)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "22171a02-c138-4119-93d4-2b4c5d1460e8",
"metadata": {},
"outputs": [],
"source": [
"from sageworks.algorithms.dataframe.feature_spider import FeatureSpider\n",
"\n",
"# Create the FeatureSpider class and run the various methods\n",
"f_spider = FeatureSpider(train_df, features, id_column=\"id\", target_column=target, neighbors=5)\n",
"hold_out_df[\"predictions\"] = f_spider.predict(hold_out_df)\n",
"plot_predictions(hold_out_df)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a49b8941-eebf-4d76-b98e-064fcf5f490d",
"metadata": {},
"outputs": [],
"source": [
"# Include Confidence metric\n",
"hold_out_df[\"predictions\"] = model.predict(hold_out_df[features])\n",
"hold_out_df[\"confidence\"] = f_spider.confidence_scores(hold_out_df, hold_out_df[\"predictions\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "37031651-2f34-46b8-b1c7-58c60b65d012",
"metadata": {},
"outputs": [],
"source": [
"plot_predictions(hold_out_df, color=\"confidence\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f9da45bb-4ab6-48f6-9b34-74fa16920ab4",
"metadata": {},
"outputs": [],
"source": [
"high_df = hold_out_df[hold_out_df[\"confidence\"] >= 0.5]\n",
"plot_predictions(high_df)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4fc1f2dc-cbe0-4856-b19c-22dee0d0c271",
"metadata": {},
"outputs": [],
"source": [
"# Classification Model\n",
"class_model = xgb.XGBClassifier(enable_categorical=True)\n",
"\n",
"X = train_df[features]\n",
"y = train_df[\"solubility_class\"]\n",
"y = y.astype(\"category\")\n",
"\n",
"# Train the main XGBClassifier model\n",
"class_model.fit(X, y)\n",
"\n",
"# Run predictions on the hold out\n",
"# hold_out_df[\"predictions\"] = class_model.predict(hold_out_df[features])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "85473d05-6538-439f-af26-cd124015d4f1",
"metadata": {},
"outputs": [],
"source": [
"hold_out_df.head()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bc5c5469-db98-400a-b61c-a80537e2083e",
"metadata": {},
"outputs": [],
"source": [
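"# `df` is not defined earlier in this notebook; assuming it refers to the hold-out set\n",
"df = hold_out_df\n",
"high_df = df[df[\"confidence\"] >= 0.0]"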
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "30f0b4de-75ba-46af-95b0-2a961edce382",
"metadata": {},
"outputs": [],
"source": [
"plot_predictions(high_df, color=\"confidence\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "504204a6-b46e-4aae-97e4-6869c14491fd",
"metadata": {},
"outputs": [],
"source": [
"df[\"predictions\"] = f_spider.predict(df)\n",
"plot_predictions(df)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "57d69fa9-9dba-4a83-aacc-a048b874566e",
"metadata": {},
"outputs": [],
"source": [
"# Use XGBRegressor and Quantiles\n",
"import xgboost as xgb\n",
"import numpy as np\n",
"from sklearn.linear_model import QuantileRegressor\n",
"\n",
"X = df[features]\n",
"y = df[target]\n",
"\n",
"# Train the main XGBRegressor model\n",
"model = xgb.XGBRegressor()\n",
"model.fit(X, y)\n",
"\n",
"# Train QuantileRegressor models for lower and upper bounds\n",
"# Specify a different solver, such as 'highs'\n",
"lower_model = QuantileRegressor(quantile=0.05, solver='highs')\n",
"upper_model = QuantileRegressor(quantile=0.95, solver='highs')\n",
"\n",
"lower_model.fit(X, y)\n",
"upper_model.fit(X, y)\n",
"\n",
"# Make predictions\n",
"predictions = model.predict(X)\n",
"lower_bounds = lower_model.predict(X)\n",
"upper_bounds = upper_model.predict(X)\n",
"\n",
"# Calculate confidence metric\n",
"confidence = 1 - (upper_bounds - lower_bounds) / (np.max(y) - np.min(y))\n",
"\n",
"# Combine predictions and confidence into the existing dataframe\n",
"df[\"predictions\"] = predictions\n",
"df[\"confidence\"] = confidence\n",
"df[\"lower_bound\"] = lower_bounds\n",
"df[\"upper_bound\"] = upper_bounds\n",
"\n",
"print(df.head())"
]
},
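{
"cell_type": "markdown",
"id": "c3a5d7e9-1b2c-4d3e-8f4a-5b6c7d8e9f0a",
"metadata": {},
"source": [
"To make the confidence formula in the cell above concrete, here is a small worked example on made-up numbers. The function name `interval_confidence` and the `toy_*` arrays are assumptions for illustration, not part of the notebook's pipeline. The score is 1 minus the width of the predicted 5%-95% interval divided by the observed target range, so a narrow interval scores close to 1 and an interval as wide as the whole target range scores 0. Note that the score goes negative when an interval is wider than the target range; whether to clip it is a design choice left out of this sketch."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d6f8a0b2-4c5d-4e6f-9a7b-8c9d0e1f2a3b",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"def interval_confidence(lower, upper, y):\n",
"    \"\"\"1 - (interval width / target range), mirroring the formula used above.\"\"\"\n",
"    lower = np.asarray(lower, dtype=float)\n",
"    upper = np.asarray(upper, dtype=float)\n",
"    y = np.asarray(y, dtype=float)\n",
"    target_range = y.max() - y.min()\n",
"    return 1 - (upper - lower) / target_range\n",
"\n",
"\n",
"# Toy values (assumption: invented for the arithmetic, not real solubility data)\n",
"toy_y = np.array([-10.0, -4.0, 2.0])       # target range = 12\n",
"toy_lower = np.array([-5.0, -7.0, -10.0])\n",
"toy_upper = np.array([-3.8, -1.0, 2.0])    # interval widths = 1.2, 6, 12\n",
"\n",
"toy_conf = interval_confidence(toy_lower, toy_upper, toy_y)\n",
"print(toy_conf)  # approximately [0.9, 0.5, 0.0]"
]
},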
{
"cell_type": "code",
"execution_count": null,
"id": "4d08bdd8-6e13-4bad-a47f-bbad4b1e8c70",
"metadata": {},
"outputs": [],
"source": [
"df[\"confidence\"].hist()"
]
},
{
"cell_type": "markdown",
"id": "f31162c1",
"metadata": {},
"source": [
"# Helper Methods"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "90d26e96",
"metadata": {},
"outputs": [],
"source": [
"# Helper to look at predictions vs target\n",
"from math import sqrt\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"\n",
"def plot_predictions(df, color=\"error\", line=True):\n",
"\n",
"    # Dataframe of the targets and predictions\n",
"    target = 'Actual Solubility'\n",
"    pred = 'Predicted Solubility'\n",
"    df_plot = pd.DataFrame({target: df['solubility'], pred: df['predictions'], 'confidence': df['confidence']})\n",
"\n",
"    # Compute the absolute error per prediction\n",
"    # df_plot['RMSError'] = df_plot.apply(lambda x: sqrt((x[pred] - x[target])**2), axis=1)\n",
"    df_plot['PredError'] = (df_plot[pred] - df_plot[target]).abs()\n",
"\n",
"    # Color the points by prediction error or by confidence\n",
"    if color == \"error\":\n",
"        ax = df_plot.plot.scatter(x=target, y=pred, c='PredError', cmap='coolwarm', sharex=False)\n",
"    else:\n",
"        ax = df_plot.plot.scatter(x=target, y=pred, c='confidence', cmap='coolwarm', sharex=False)\n",
"\n",
"    # Diagonal reference line (perfect predictions) with padded axis limits\n",
"    if line:\n",
"        ax.axline((1, 1), slope=1, linewidth=2, c='black')\n",
"        x_pad = (df_plot[target].max() - df_plot[target].min())/10.0\n",
"        y_pad = (df_plot[pred].max() - df_plot[pred].min())/10.0\n",
"        plt.xlim(df_plot[target].min()-x_pad, df_plot[target].max()+x_pad)\n",
"        plt.ylim(df_plot[pred].min()-y_pad, df_plot[pred].max()+y_pad)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dfab1df1-ea45-43dd-a76b-33354f572d8e",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}