{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# `auton-survival` Cross Validation Survival Regression\n",
    "\n",
    "`auton-survival` offers a simple-to-use API to train survival regression models that performs cross-validation model selection by minimizing the integrated Brier score. In this notebook we demonstrate the use of `auton-survival` to train survival models on the *SUPPORT* dataset in a cross-validation fashion."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "# Add the parent directory to the module search path before importing `auton_survival`.\n",
    "sys.path.append('../')\n",
    "from auton_survival import datasets\n",
    "from auton_survival.experiments import SurvivalRegressionCV\n",
    "from auton_survival.metrics import survival_regression_metric\n",
    "from auton_survival.preprocessing import Preprocessor"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the *SUPPORT* dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "outcomes, features = datasets.load_support()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Preprocess the features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cat_feats = ['sex', 'dzgroup', 'dzclass', 'income', 'race', 'ca']\n",
    "num_feats = ['age', 'num.co', 'meanbp', 'wblc', 'hrt', 'resp',\n",
    "             'temp', 'pafi', 'alb', 'bili', 'crea', 'sod', 'ph',\n",
    "             'glucose', 'bun', 'urine', 'adlp', 'adls']\n",
    "\n",
    "# Data should be processed in a fold-independent manner when performing cross-validation.\n",
    "# For simplicity in this demo, we process the dataset in a non-independent manner.\n",
    "preprocessor = Preprocessor(cat_feat_strat='ignore', num_feat_strat='mean')\n",
    "x = preprocessor.fit_transform(features, cat_feats=cat_feats, num_feats=num_feats,\n",
    "                               one_hot=True, fill_value=-1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Choose the evaluation horizons"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Horizons are the 0.25, 0.5 and 0.75 quantiles of `time` among rows with `event == 1`.\n",
    "times = np.quantile(outcomes.time[outcomes.event==1], [0.25, 0.5, 0.75]).tolist()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cross-validated model selection\n",
    "\n",
    "Note that `fit` is called with `metric='brs'` below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "param_grid = {'k' : [3],\n",
    "              'distribution' : ['Weibull'],\n",
    "              'learning_rate' : [1e-4, 1e-3],\n",
    "              'layers' : [[100]]}\n",
    "\n",
    "experiment = SurvivalRegressionCV(model='dsm', num_folds=3, hyperparam_grid=param_grid, random_seed=0)\n",
    "model = experiment.fit(x, outcomes, times, metric='brs')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(experiment.folds)\n",
    "model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Predict risk and survival at the chosen horizons"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "out_risk = model.predict_risk(x, times)\n",
    "out_survival = model.predict_survival(x, times)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Per-fold evaluation\n",
    "\n",
    "The predicted survival probabilities are scored separately on each cross-validation fold, first with the `'brs'` metric and then with `'ctd'`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `np.unique` yields the fold labels in sorted order, so the printout order is deterministic.\n",
    "for fold in np.unique(experiment.folds):\n",
    "    in_fold = experiment.folds == fold\n",
    "    print(survival_regression_metric('brs', outcomes[in_fold],\n",
    "                                     out_survival[in_fold],\n",
    "                                     times=times))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `np.unique` yields the fold labels in sorted order, so the printout order is deterministic.\n",
    "for fold in np.unique(experiment.folds):\n",
    "    in_fold = experiment.folds == fold\n",
    "    print(survival_regression_metric('ctd', outcomes[in_fold],\n",
    "                                     out_survival[in_fold],\n",
    "                                     times=times))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): this loop only echoes every horizon once per fold; `fold` itself is unused.\n",
    "for fold in np.unique(experiment.folds):\n",
    "    for time in times:\n",
    "        print(time)"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "c22fbbe4c37d04364aa4e785d7cd9729f94ca3cb878d5f955e35b0076c04a2d7"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}