# ============================================================================
# PSYC 434 — Lab 9: Policy Trees
# self-standing script — run from top to bottom
# ============================================================================

# --- packages ---------------------------------------------------------------
library(causalworkshop)
library(grf)
library(policytree)
library(tidyverse)

# --- fit causal forest (from Labs 5-6) --------------------------------------
d <- simulate_nzavs_data(n = 5000, seed = 2026)

# one data frame per wave: covariates come from wave 0, the exposure from
# wave 1, and the outcome from wave 2
d0 <- filter(d, wave == 0)
d1 <- filter(d, wave == 1)
d2 <- filter(d, wave == 2)

covariate_cols <- c(
  "age", "male", "nz_european", "education", "partner", "employed",
  "log_income", "nz_dep",
  "agreeableness", "conscientiousness", "extraversion",
  "neuroticism", "openness",
  "community_group", "wellbeing"
)

# NOTE(review): X, Y and W are matched by row position only, so this assumes
# the three wave subsets list the same people in the same order — TODO confirm
X <- as.matrix(d0[covariate_cols])
Y <- d2[["wellbeing"]]
W <- d1[["community_group"]]

cf <- causal_forest(
  X = X,
  Y = Y,
  W = W,
  num.trees = 1000,
  honesty = TRUE,
  tune.parameters = "all",
  seed = 2026
)

tau_hat <- predict(cf)[["predictions"]]

# --- construct gamma matrix -------------------------------------------------
# one reward column per action: control is fixed at 0, treatment is the
# forest's predicted effect (the scalar 0 is recycled to length(tau_hat))
gamma_matrix <- cbind(control = 0, treatment = tau_hat)
head(gamma_matrix)

# --- depth-1 policy tree ----------------------------------------------------
# the trees are fitted on a random subsample of n_sample rows
set.seed(2026)
n_sample <- 500
idx <- sample(seq_len(nrow(X)), size = n_sample)

X_sample <- as.data.frame(X[idx, ])
gamma_sample <- gamma_matrix[idx, ]

pt_depth1 <- policy_tree(X_sample, gamma_sample, depth = 1)
print(pt_depth1)
plot(pt_depth1)

# --- depth-2 policy tree ----------------------------------------------------
pt_depth2 <- policy_tree(X_sample, gamma_sample, depth = 2)
print(pt_depth2)
plot(pt_depth2)

# --- evaluate policies ------------------------------------------------------
# both trees are scored on every row of X, not only the fitting subsample
X_full <- as.data.frame(X)
actions_depth1 <- predict(pt_depth1, X_full)
actions_depth2 <- predict(pt_depth2, X_full)

# reward of the chosen action for each row: this script treats action k as
# column k of the gamma matrix (1 = control, 2 = treatment)
policy_reward <- function(actions, gamma) {
  gamma[cbind(seq_along(actions), actions)]
}

# expected reward under each policy
reward_depth1 <- policy_reward(actions_depth1, gamma_matrix)
reward_depth2 <- policy_reward(actions_depth2, gamma_matrix)

# a coin flip between the two actions earns the equally weighted average
reward_random <- drop(gamma_matrix %*% c(0.5, 0.5))

policy_comparison <- tribble(
  ~policy,             ~expected_reward,    ~treat_rate,
  "Random assignment", mean(reward_random), 0.50,
  "Depth-1 tree",      mean(reward_depth1), mean(actions_depth1 == 2),
  "Depth-2 tree",      mean(reward_depth2), mean(actions_depth2 == 2),
  "Treat everyone",    mean(tau_hat),       1.00
)

policy_comparison |>
  mutate(across(where(is.numeric), \(x) round(x, 3))) |>
  print()

# --- interpret rules --------------------------------------------------------
print(pt_depth2)

# --- compare with actual treatment ------------------------------------------
# recode the depth-2 recommendation to 0/1 so it is comparable with W
agreement <- tibble(
  actual = W,
  policy_depth2 = as.numeric(actions_depth2 == 2),
  agree = actual == policy_depth2
)

# print a label followed by the mean of x, rounded to three decimals
report_rate <- function(label, x) {
  cat(label, round(mean(x), 3), "\n")
}

report_rate("Agreement rate:", agreement$agree)
report_rate("Policy treats: ", agreement$policy_depth2)
report_rate("Actual treated:", agreement$actual)