{ "cells": [ { "cell_type": "markdown", "id": "a596ae7f-73b5-4674-85a8-011a971e4f66", "metadata": {}, "source": [ "# Running Inference on Legacy (Non-SageWorks) Models\n", "This notebook walks through setting up and running inference on Models that weren't generated in SageWorks.\n", "\n", "```\n", "model = Model(\"the old model\")\n", "model.onboard()\n", "```\n", "Answer all the Model onboarding questions; the tricky one is 'features' (note: capitalization matters).\n", "\n", "```\n", "end = Endpoint(\"the old endpoint\")\n", "end.onboard()\n", "```\n", "\n", "Answer all the Endpoint onboarding questions; if you're asked for input, enter the Model Name.\n", "\n", "After this is done, you'll need to follow this notebook to remap the evaluation data into a form that the old model can use." ] }, { "cell_type": "code", "execution_count": null, "id": "43bbee7d-709e-42b5-b818-934fe5be755a", "metadata": {}, "outputs": [], "source": [ "import os\n", "\n", "# Set the SageWorks config environment variable (adjust the path for your machine)\n", "os.environ['SAGEWORKS_CONFIG'] = '/Users/briford/.sageworks/scp_sandbox.json'" ] }, { "cell_type": "code", "execution_count": null, "id": "314a562d-776b-404a-8c4f-1ea56078d2a9", "metadata": {}, "outputs": [], "source": [ "# Grab the target and features from the model (these are set during onboard())\n", "from sageworks.api.model import Model\n", "model = Model(\"LogS-pH7-Class-0-231025\")\n", "target = model.target()\n", "print(target)\n", "features = model.features()\n", "print(f\"{features[:5]}...\")" ] }, { "cell_type": "code", "execution_count": null, "id": "026249ac-18bd-4041-b73b-87e01a764450", "metadata": {}, "outputs": [], "source": [ "# Grab the evaluation data (rows where training=0)\n", "from sageworks.api.feature_set import FeatureSet\n", "fs = FeatureSet(\"solubility_featurized_fs\")\n", "training_view = fs.view(\"training\").table\n", "eval_data = fs.query(f\"select * from {training_view} where training=0\")" ] }, { "cell_type": "code", "execution_count": null, "id": 
"429c64b8-3e0c-4de4-9dee-9f02ec4cac85", "metadata": {}, "outputs": [], "source": [ "data_columns = eval_data.columns" ] }, { "cell_type": "code", "execution_count": null, "id": "88cad15d-31c4-4089-aa56-c0613278021d", "metadata": {}, "outputs": [], "source": [ "def build_map(data_columns, features):\n", " map = dict()\n", " for f in features:\n", " if f.lower() in data_columns:\n", " map[f.lower()] = f\n", " return map" ] }, { "cell_type": "code", "execution_count": null, "id": "10d6a542-0cef-4d89-b275-2325604178ef", "metadata": {}, "outputs": [], "source": [ "map = build_map(data_columns, features)\n", "map[\"class\"] = target # Add the target variable name change" ] }, { "cell_type": "code", "execution_count": null, "id": "b5d613d8-8e5c-4cdd-9edc-e10fa81f6c4c", "metadata": {}, "outputs": [], "source": [ "# Rename the feature and target columns to 'match' what the Legacy Model expects\n", "eval_data.rename(columns=map, inplace=True)" ] }, { "cell_type": "code", "execution_count": null, "id": "f8a55517-71d4-458e-a486-0f6365d43823", "metadata": {}, "outputs": [], "source": [ "# Grab the Legacy Endpoint\n", "from sageworks.api.endpoint import Endpoint\n", "end = Endpoint(\"LogS-pH7-Class-0-231025\")" ] }, { "cell_type": "code", "execution_count": null, "id": "57788b7b-42ef-487a-abba-f0ae35188133", "metadata": {}, "outputs": [], "source": [ "# Run inference on the Legacy Endpoint\n", "end.inference(eval_data, capture_uuid=\"2024_03_19_holdout\")" ] }, { "cell_type": "code", "execution_count": null, "id": "b840f3b0-fabd-45b1-8d44-fcd188f8e242", "metadata": {}, "outputs": [], "source": [ "# This might not be strictly necssary but it will make sure the details are up to date\n", "model.details(recompute=True)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", 
"nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.13" } }, "nbformat": 4, "nbformat_minor": 5 }