{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## The Highly Adaptive LASSO and the `hal9001` R package\n", "\n", "## Lab 10 for PH 290: Targeted Learning in Biomedical Big Data\n", "\n", "### Author: [Nima Hejazi](https://nimahejazi.org)\n", "\n", "### Date: 21 March 2018" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# I. The Highly Adaptive LASSO (HAL)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Recommended Reading\n", "\n", "* Benkeser, D, and van der Laan, MJ (2016). \"The Highly Adaptive Lasso Estimator.\" IEEE International Conference on Data Science and Advanced Analytics. (http://ieeexplore.ieee.org/abstract/document/7796956/)\n", "* R package: Coyle, JR, and Hejazi, NS (2018). \"hal9001: A fast and scalable Highly Adaptive LASSO.\" (https://github.com/jeremyrcoyle/hal9001)\n", "\n", "### What's HAL All About?\n", "\n", "\"Estimation of a regression function is a common goal of statistical learning. We propose a novel _nonparametric regression estimator_ that, in contrast to many existing methods, does not rely on _local smoothness assumptions_ nor is it constructed using _local smoothing techniques_. Instead, our estimator respects global smoothness constraints by virtue of falling in a class of right-hand continuous functions with left-hand limits that have variation norm bounded by a constant. 
Using empirical process theory, we establish a fast minimal rate of convergence of our proposed estimator and illustrate how such an estimator can be constructed using standard software.\" - excerpted from Benkeser & vdL (2016).\n", "\n", "### Why Care About Prediction?\n", "\n", "* Consider observed data $O = (X, Y)$, where $X$ is a set of covariates and $Y$ a vector outcome of interest.\n", "* Let us further write $O \\sim P_0 \\in \\mathcal{M}^{NP}$, where $P_0$ is the true distribution of the observed data $O$, contained in the nonparametric (infinite-dimensional) statistical model $\\mathcal{M}^{NP}$.\n", "* Let the _target parameter_ be the prediction function $\\psi_0$ such that $\\psi_0 \\in \\Psi$, a class of prediction functions that minimize the average of a scientifically relevant loss function (e.g., negative log-likelihood for binary outcomes).\n", "* Thus, our goal is to estimate $\\Psi(P_0)(X) = \\text{argmin}_{\\psi \\in \\Psi} E_{P_0}[L(\\psi)(X, Y)]$ for an arbitrary loss function $L(\\psi)$.\n", "\n", "### How Can Prediction Go Wrong?\n", "\n", "* \"Parametric and semiparametric methods for estimating the conditional mean assume its form is known up to a finite number of parameters.\"\n", "* Consider GLMs, which express the conditional mean as a transformation of a linear function of the conditioning variables - e.g., $E(Y \\mid X) = \\beta_0 + \\beta_1 X_1 + \\beta_2 X_2 + \\ldots$\n", "* Parametric methods suffer from large bias when the assumed functional form differs from the true conditional mean (e.g., $E(Y \\mid X) = \\beta_0 + \\beta_i X_1 X_2 + \\beta_j X_1^2 + \\beta_k X_2^3 + \\ldots$).\n", "* Even nonparametric methods are plagued by assumptions: \"For example, many methods assume $\\psi_0$ has nearly constant, linear, or low-order polynomial behavior for all points sufficiently close to each other in a given metric.\"\n", "\n", "### Theory and Properties of HAL\n", "\n", "* \"We assume that the true conditional mean function is 
right-hand continuous with left-hand limits (cadlag) and has variation norm smaller than a constant $M$. These are exceedingly mild assumptions that are expected to hold in almost every practical application.\"\n", "* We make two key smoothness assumptions about $\\psi_0$: (1) $\\psi_0$ is an element of the Banach space of $d$-variate cadlag functions; (2) $\\psi_0$ has finite variation norm, where the variation norm is defined as $\\mid\\mid \\psi \\mid\\mid_{v} = \\int_{[0, \\tau]} \\mid \\psi(dx) \\mid$. The definition of HAL follows directly from an alternate formulation of the variation norm...\n", "* For any function $\\psi \\in D[0, \\tau]$, we define the s-th section of $\\psi$ as $\\psi_s(x) = \\psi(x_1 I(1 \\in s), \\ldots, x_d I(d \\in s))$, which is _the function that varies along the variables in $x_s$ according to $\\psi$, but sets the variables in $x_{s^c}$ (those not in $s$) equal to zero_. \n", "* Under the conditions above, $\\psi$ admits the following representation: $$\\psi(x) = \\psi(0) + \\sum_{s \\subset \\{1, \\ldots, d\\}} \\int_{0_s}^{x_s} \\psi_s(du) = \\psi(0) + \\sum_{s \\subset \\{1, \\ldots, d\\}} \\int_{0_s}^{\\tau_s} I(u \\leq x_s) \\psi_s(du).$$\n", "* By approximating $\\psi$ above with a discrete measure $\\psi_m$ over $m$ support points, one may write $$\\psi_m(x) = \\psi(0) + \\sum_{s \\subset \\{1, \\ldots, d\\}} \\sum_{j} I(u_{s,j} \\leq x_s) d\\psi_{m,s,j},$$ where this approximation consists of a linear combination of basis functions with\ncorresponding coefficients $d\\psi_{m,s,j}$ summed over $s$ and $j$.\n", "* The following nice property follows directly from the definition of the HAL estimator: the sum of the absolute values of these coefficients (plus the intercept term) gives the variation norm of $\\psi_m$ - i.e., $$\\mid\\mid \\psi_m \\mid\\mid_v = \\mid \\psi(0) \\mid + \\sum_{s \\subset \\{1, \\ldots, d\\}} \\sum_{j} \\mid d\\psi_{m,s,j} \\mid.$$\n", "\n", "\n", "### How Does HAL Help Us?\n", "\n", "* Since HAL has been shown to consistently estimate the true 
functional relation at a rate faster than $n^{-\\frac{1}{4}}$ as a function of sample size $n$, whenever the true target function has finite variation norm, we are able to obtain very good convergence rates when using it as an estimator.\n", "* The assumption of finite variation norm may be easily included in our statistical model without any concern for misspecification, since it is unlikely that the true functional relations have infinite variation as a function of the different covariates." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# II. Nonparametric Estimation / Prediction with `hal9001`" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First, let's load the packages we'll be using and some core project management tools." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "options(repr.plot.width = 5, repr.plot.height = 5) ## resizing plots\n", "options(scipen = 999) ## has scientific notation ever annoyed you?" 
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "here() starts at /Users/nimahejazi/Dropbox/UC_Berkeley-grad/teaching/2018_Spring/tlbbd-labs/lab_10\n", "── Attaching packages ─────────────────────────────────────── tidyverse 1.2.1 ──\n", "✔ ggplot2 2.2.1 ✔ purrr 0.2.4\n", "✔ tibble 1.4.2 ✔ dplyr 0.7.4\n", "✔ tidyr 0.8.0 ✔ stringr 1.3.0\n", "✔ readr 1.1.1 ✔ forcats 0.3.0\n", "── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──\n", "✖ dplyr::filter() masks stats::filter()\n", "✖ dplyr::lag() masks stats::lag()\n", "\n", "Attaching package: ‘data.table’\n", "\n", "The following objects are masked from ‘package:dplyr’:\n", "\n", " between, first, last\n", "\n", "The following object is masked from ‘package:purrr’:\n", "\n", " transpose\n", "\n" ] } ], "source": [ "library(here)\n", "library(usethis)\n", "library(tidyverse)\n", "library(data.table, quietly = TRUE)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Loading required package: plotmo\n", "Loading required package: plotrix\n", "Loading required package: TeachingDemos\n", "Loading required package: Matrix\n", "\n", "Attaching package: ‘Matrix’\n", "\n", "The following object is masked from ‘package:tidyr’:\n", "\n", " expand\n", "\n", "Loading required package: foreach\n", "\n", "Attaching package: ‘foreach’\n", "\n", "The following objects are masked from ‘package:purrr’:\n", "\n", " accumulate, when\n", "\n", "Loaded glmnet 2.0-13\n", "\n", "randomForest 4.6-12\n", "Type rfNews() to see new features/changes/bug fixes.\n", "\n", "Attaching package: ‘randomForest’\n", "\n", "The following object is masked from ‘package:dplyr’:\n", "\n", " combine\n", "\n", "The following object is masked from ‘package:ggplot2’:\n", "\n", " margin\n", "\n", "hal9001 v0.1.1: A fast and scalable Highly Adaptive LASSO\n" ] } ], 
"source": [ "library(earth)\n", "library(glmnet)\n", "library(randomForest)\n", "library(hal9001)\n", "library(sl3)\n", "set.seed(385971)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We begin by simulating a simple data set and illustrating a basic application of `hal9001` for prediction." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
X1 | X2 | X3 | X4 | Y |
---|---|---|---|---|
0.2086134 | -0.16809762 | 0.3635787 | 0 | 0.4550150 |
-0.9477350 | -0.68371703 | -1.2229561 | 0 | 0.1571233 |
0.9254026 | -1.42357891 | 0.1638252 | 1 | 0.1943260 |
0.5603774 | -0.09578228 | 1.4152984 | 1 | 1.6964338 |
0.6430260 | 1.01289165 | -1.1249601 | 0 | 0.5796670 |
-0.9730101 | 0.18683905 | 0.3577011 | 0 | -0.7109583 |
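A data set like the one shown above can be simulated and fit end-to-end along the following lines. This is a minimal sketch, not the lab's own code: the data-generating process below (the nonlinear outcome model, the noise level, and the sample size) is an assumption made purely for illustration; only `fit_hal()` and its `predict()` method come from `hal9001`.

```r
library(hal9001)
set.seed(385971)

# hypothetical data-generating process: three continuous covariates,
# one binary covariate, and a nonlinear outcome (illustrative only)
n <- 100
x <- data.frame(
  X1 = rnorm(n),
  X2 = rnorm(n),
  X3 = rnorm(n),
  X4 = rbinom(n, 1, prob = 0.5)
)
y <- sin(x$X1) + x$X4 * x$X3 + rnorm(n, sd = 0.25)

# fit the Highly Adaptive LASSO; the LASSO regularization parameter is
# selected internally by cross-validation
hal_fit <- fit_hal(X = as.matrix(x), Y = y)

# predictions on the training data and the empirical MSE of the fit
preds <- predict(hal_fit, new_data = as.matrix(x))
mse <- mean((y - preds)^2)
```

Because HAL reduces to a LASSO regression over indicator basis functions $I(u_{s,j} \leq x_s)$, the fitted object stores the selected basis functions and their coefficients, and the sum of the absolute values of the nonzero coefficients bounds the variation norm of the fitted function.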