---
title: "Solution — Day 4 microbiome (models + VIP + SHAP)"
format:
  html:
    toc: true
    code-tools: true
---

Tasks **4.1–4.5** on the [lab exercises page](../index.qmd).

```{r setup, include=FALSE}
knitr::opts_chunk$set(echo = TRUE, message = FALSE, warning = FALSE)
suppressPackageStartupMessages({
  library(tidymodels)
  library(dplyr)
  library(ggplot2)
  library(vip)
  library(kernelshap)
  library(shapviz)
})
theme_set(theme_minimal())
# Provides load_microbiome(), mic_otu_cols() and mic_group_folds() used below.
source("_load_microbiome.R")
```

## 4.1 Recipe

```{r data}
set.seed(7)
mic <- load_microbiome()
otu_cols <- mic_otu_cols(mic)

# Preprocessing shared by all three models:
#   1. sample metadata is given the "id" role so it is not used as a predictor,
#   2. OTU columns are log1p-transformed,
#   3. zero-variance predictors are dropped,
#   4. numeric predictors are centred and scaled.
rec <- recipe(Label ~ ., data = mic) |>
  update_role(sample_id, Individual, Sex, Day, new_role = "id") |>
  step_mutate(across(all_of(otu_cols), ~ log1p(.x))) |>
  step_zv(all_predictors()) |>
  step_normalize(all_numeric_predictors())

# 5 resamples from the project helper; metric set requested for every model.
folds <- mic_group_folds(mic, v = 5)
metrics_cls <- metric_set(roc_auc, accuracy)
```

## 4.2 Fit RF, XGBoost, MLP

```{r specs}
# importance = "impurity" makes ranger store importance scores, which the
# VIP section reads back from the fitted model.
rf_spec <- rand_forest(mtry = 8, trees = 300, min_n = 2) |>
  set_engine("ranger", probability = TRUE, importance = "impurity") |>
  set_mode("classification")

xgb_spec <- boost_tree(trees = 100, tree_depth = 3, learn_rate = 0.05) |>
  set_engine("xgboost") |>
  set_mode("classification")

mlp_spec <- mlp(hidden_units = 10, penalty = 0.1, epochs = 150) |>
  set_engine("nnet", trace = FALSE, MaxNWts = 5000) |>
  set_mode("classification")
```

```{r fit-cv, cache=TRUE}
# Grouped CV on ~300 OTUs is the slowest step (~1–3 min). Cached after first knit.
#
# `resamples` and `metrics` are passed BY NAME: in fit_resamples() the
# `metrics` argument sits after `...`, so a third positional argument is
# swallowed by the dots and the requested metric set is silently ignored.
set.seed(7)
rs_rf <- fit_resamples(
  workflow() |> add_recipe(rec) |> add_model(rf_spec),
  resamples = folds,
  metrics = metrics_cls
)

set.seed(7)
rs_xgb <- fit_resamples(
  workflow() |> add_recipe(rec) |> add_model(xgb_spec),
  resamples = folds,
  metrics = metrics_cls
)

# NOTE(review): the MLP uses seed 9 while the other two models use 7 —
# confirm this is intentional.
set.seed(9)
rs_mlp <- fit_resamples(
  workflow() |> add_recipe(rec) |> add_model(mlp_spec),
  resamples = folds,
  metrics = metrics_cls
)

# One long table of resampled metrics, tagged with a display name per model.
cmp <- bind_rows(
  collect_metrics(rs_rf) |> mutate(model = "Random forest"),
  collect_metrics(rs_xgb) |> mutate(model = "XGBoost"),
  collect_metrics(rs_mlp) |> mutate(model = "MLP")
)

knitr::kable(
  cmp |>
    filter(.metric == "roc_auc") |>
    select(model, mean, std_err),
  digits = 3
)
```

## 4.3 VIP (random forest)

```{r rf-final}
# Refit the random forest on the full data set; used for VIP and to pick
# the OTUs explained with SHAP.
wf_rf <- workflow() |>
  add_recipe(rec) |>
  add_model(rf_spec)
fit_rf <- fit(wf_rf, mic)
```

```{r vip, fig.width=8, fig.height=6}
fit_rf |>
  extract_fit_parsnip() |>
  vip(geom = "point", num_features = 15) +
  labs(
    title = "Variable importance (random forest)",
    subtitle = "Microbiome OTUs — descriptive ranking, not causal"
  )
```

## 4.4 SHAP

**Do not run kernel SHAP on all OTUs** — with 300+ predictors it can run for hours. We take the **top 10 VIP OTUs**, refit a small forest for teaching, and explain **12 samples** (same pattern as the penguin SHAP slides, fewer features).

```{r shap-top-otus}
# Names of the 10 most important variables in the full random forest.
vip_top <- vip::vi(extract_fit_parsnip(fit_rf)) |>
  arrange(desc(Importance)) |>
  slice_head(n = 10) |>
  pull(Variable)

# Outcome + sample metadata + only those 10 OTU columns.
mic_top <- mic |>
  select(Label, sample_id, Individual, Sex, Day, all_of(vip_top))
```

```{r shap-fit-small, cache=TRUE}
# Same preprocessing steps as the main recipe, restricted to the top OTUs.
rec_top <- recipe(Label ~ ., data = mic_top) |>
  update_role(sample_id, Individual, Sex, Day, new_role = "id") |>
  step_mutate(across(all_of(vip_top), ~ log1p(.x))) |>
  step_zv(all_predictors()) |>
  step_normalize(all_numeric_predictors())

# Small probability forest, cheap enough to explain with kernel SHAP.
rf_top_spec <- rand_forest(mtry = 3, trees = 100, min_n = 2) |>
  set_engine("ranger", probability = TRUE) |>
  set_mode("classification")

fit_rf_top <- fit(
  workflow() |> add_recipe(rec_top) |> add_model(rf_top_spec),
  mic_top
)
```

```{r shap-prep, cache=TRUE}
# Kernel SHAP is run on the raw ranger fit, so it needs the predictors exactly
# as the model saw them: bake the trained recipe and keep predictors only.
rec_est <- extract_recipe(fit_rf_top, estimated = TRUE)
X_model <- bake(rec_est, new_data = mic_top, all_predictors())
rf_engine <- extract_fit_parsnip(fit_rf_top)$fit

# Prediction wrapper for kernelshap(): returns the predicted probability of
# the "Late" class as a plain numeric vector, one value per row of `X_new`.
pred_fun <- function(object, X_new) {
  as.numeric(predict(object, data = X_new)$predictions[, "Late"])
}

set.seed(11)
# Cap both sample sizes at the number of available rows.
n_explain <- min(12L, nrow(X_model))
n_bg <- min(6L, nrow(X_model))
X <- X_model[sample.int(nrow(X_model), n_explain), , drop = FALSE]
bg <- dplyr::slice_sample(X_model, n = n_bg)

# Progress output is silenced through the `verbose` argument of kernelshap().
ks <- kernelshap::kernelshap(
  rf_engine,
  X = X,
  pred_fun = pred_fun,
  bg_X = bg,
  verbose = FALSE
)

# For kernelshap objects the feature-value argument of shapviz() is `X`
# (`X_pred` belongs to the xgboost/lightgbm methods).
shp <- shapviz::shapviz(ks, X = X)
```

```{r shap-plot, fig.width=9, fig.height=6}
sv_importance(shp, kind = "beeswarm", max_display = 12) +
  labs(
    title = "SHAP — forest on top 10 VIP OTUs (Late class)",
    subtitle = "Kernel SHAP on 12 samples; never run this on the full OTU matrix"
  )
```

## 4.5 Comparison figure

```{r compare-plot, fig.width=8, fig.height=4}
# Mean cross-validated ROC AUC per model, with +/- one standard error bars.
cmp |>
  filter(.metric == "roc_auc") |>
  ggplot(aes(reorder(model, mean), mean, fill = model)) +
  geom_col(show.legend = FALSE) +
  geom_errorbar(
    aes(ymin = mean - std_err, ymax = mean + std_err),
    width = 0.12
  ) +
  labs(
    title = "ROC AUC — grouped CV by Individual",
    x = NULL,
    y = "Mean ROC AUC"
  )
```

**Humility:** I would **not** claim that the top VIP or SHAP OTU *causes* Early vs Late community shift — these plots describe **this fitted model** on observational counts, not a randomized intervention.

```{r session, echo=FALSE}
sessionInfo()
```