{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Data Analysis with `sl3` and Writing Custom `sl3` Learners\n", "\n", "## Lab 06 for PH 290: Targeted Learning in Biomedical Big Data\n", "\n", "### Author: [Nima Hejazi](https://nimahejazi.org)\n", "\n", "### Date: 21 February 2018" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# I. Data Analysis with `sl3`" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We begin by illustrating a simple execution of the Super Learner algorithm using the `chspred` data and default algorithms. Start by loading the necessary packages:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "library(usethis)\n", "usethis::create_project(\".\")\n", "library(here)\n", "library(tidyverse)\n", "library(sl3)\n", "library(knitr)\n", "library(R6)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# prediction data set\n", "chspred <- read_csv(here(\"data\", \"chspred.csv\"))\n", "head(chspred)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We now illustrate the \"default\" functionality of the Super Learner algorithm (as implemented in `sl3`). Using the `chspred` data, we are interested in predicting the `hdl` variable using the available covariate data. " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "chspred_task <- make_sl3_Task(\n", " data = chspred,\n", " outcome = \"hdl\",\n", " covariates = colnames(chspred)[!(colnames(chspred) %in% \"hdl\")]\n", ")\n", "chspred_task" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For the sake of computational expediency, we will initially consider only a simple library of algorithms: an unadjusted (i.e., intercept-only) model and a fast main-effects GLM. Later, we will look at how these algorithms are constructed for usage with `sl3`. 
We'll use nonnegative least squares to fit the meta-learning step." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [ "lrn1 <- Lrnr_mean$new()\n", "lrn2 <- Lrnr_glm_fast$new()\n", "sl_lrn <- Lrnr_sl$new(learners = list(lrn1, lrn2),\n", " metalearner = Lrnr_nnls$new())" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "chspred_sl <- sl_lrn$train(chspred_task)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "chspred_sl_pred <- chspred_sl$predict()\n", "head(chspred_sl_pred)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sl_mse <- mean((chspred$hdl - chspred_sl_pred)^2)\n", "sl_mse" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Exercise\n", "\n", "We can also obtain predictions on a new observation:\n", "\n", "1. Generate a new observation set to the mean of each variable\n", "2. Predict using the trained Super Learner model on this new observation" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# II. Writing Custom `sl3` Learners" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This guide describes the process of implementing a learner class for a new\n", "machine learning algorithm. By writing a learner class for your favorite machine\n", "learning algorithm, you will be able to use it in all the places you could\n", "otherwise use any other `sl3` learners, including `Pipeline`s, `Stack`s, and\n", "Super Learner. We have done our best to streamline the process of creating new\n", "`sl3` learners.\n", "\n", "Before diving into defining a new learner, it will likely be helpful to read\n", "some background material. 
If you haven't already read it, the [\"Modern Machine\n", "Learning in R\"](https://sl3.tlverse.org/articles/intro_sl3.html) vignette is a\n", "good introduction to the `sl3` package and its underlying architecture. The\n", "[`R6`](https://cran.r-project.org/web/packages/R6/vignettes/Introduction.html)\n", "documentation will help you understand how `R6` classes are defined. In\n", "addition, the help files for [`sl3_Task`](https://sl3.tlverse.org/reference/sl3_Task.html) and\n", "[`Lrnr_base`](https://sl3.tlverse.org/reference/Lrnr_base.html) are good resources for how those objects can be\n", "used. If you're interested in defining learners that fit sub-learners, reading\n", "the documentation of the\n", "[`delayed`](https://delayed.tlverse.org/articles/delayed.html)\n", "package will be helpful.\n", "\n", "In the following sections, we introduce and review a template for a new `sl3`\n", "learner, describing the sections that can be used to define your new learner.\n", "This is followed by a discussion of the important task of documenting and\n", "testing your new learner. Finally, we conclude by explaining how you can add\n", "your learner to `sl3` so that others may make use of it." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Learner Template\n", "\n", "`sl3` provides a template of a learner for use in defining new learners. You can\n", "make a copy of the template to work on by invoking `write_learner_template`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [ "#write_learner_template(\"path/to/write/Learner_template.R\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The template has comments indicating where details specific to the learner you're trying to implement should be filled in. 
In the next section, we will discuss those details further.\n", "\n", "## Defining your Learner\n", "\n", "### Learner Name and Class\n", "\n", "At the top of the template, we define an object `Lrnr_template` and set\n", "`classname = \"Lrnr_template\"`. You should modify these to match the name of your\n", "new learner, which should also match the name of the corresponding R file. Note\n", "that the name should be prefixed by `Lrnr_` and use\n", "[`snake_case`](https://en.wikipedia.org/wiki/Snake_case).\n", "\n", "### `public$initialize`\n", "\n", "This function defines the constructor for your learner, and it stores the\n", "arguments (if any) provided when a user calls\n", "`make_learner(Lrnr_your_learner, ...)`. You can also provide default parameter\n", "values, just as the template does with `param_1 = \"default_1\"`, and\n", "`param_2 = \"default_2\"`. All parameters used by your newly defined learners\n", "should have defaults whenever possible. This will allow users to use your\n", "learner without having to figure out what reasonable parameter values might be.\n", "Parameter values should be documented; see the section below on\n", "[documentation](#doctest) for details.\n", "\n", "### `public$special_function`s\n", "\n", "You can of course define functions for things only your learner can do. These\n", "should be public functions like the `special_function` defined in the example.\n", "These should be documented; see the section below on documentation for details.\n", "\n", "### `private$.properties`\n", "\n", "This field defines properties supported by your learner. This may include\n", "different outcome types that are supported, offsets and weights, amongst many\n", "other possibilities. 
To see a list of all properties supported/used by at least\n", "one learner, you may invoke `sl3_list_properties`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sl3_list_properties()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### `private$.required_packages`\n", "\n", "This field defines other R packages required for your learner to work properly.\n", "These will be loaded when an object of your new learner class is initialized.\n", "\n", "### User Interface for Learners\n", "\n", "If you've used `sl3` before, you may have noticed that while users are\n", "instructed to use `learner$train`, `learner$predict`, and `learner$chain`, to\n", "train, generate predictions, and generate a chained task for a given learner\n", "object, respectively, the template does not implement these methods. Instead,\n", "the template implements private methods called `.train`, `.predict`, and\n", "`.chain`. The specifics of these methods are explained below; however, it is\n", "helpful to first understand how the two sets of methods are related. At the risk\n", "of complicating things further, it is worth noting that there is actually a\n", "third set of methods (`learner$base_train`, `learner$base_predict`, and\n", "`learner$base_chain`) of which you may not be aware.\n", "\n", "So, what happens when a user calls `learner$train`? That method generates a\n", "`delayed` object using the `delayed_learner_train` function, and then computes\n", "that delayed object. In turn, `delayed_learner_train` defines a delayed\n", "computation that calls `base_train`, a user-facing function that can be used to\n", "train tasks without using the facilities of the `delayed` package. `base_train`\n", "validates the user input, and in turn calls `private$.train`. 
When\n", "`private$.train` returns a `fit_object`, `base_train` takes that fit object,\n", "generates a learner fit object, and returns it to the user.\n", "\n", "Each call to `learner$train` involves three separate training methods:\n", "\n", "1. The user-facing `learner$train` -- trains a learner in a manner that can be\n", " parallelized using `delayed`, which calls ...\n", "2. ... the user-facing `learner$base_train` that validates user input, and which\n", " calls ...\n", "3. ... the internal `private$.train`, which does the actual work of fitting the\n", " learner and returning the fit object.\n", "\n", "The logic in the user-facing `learner$train` and `learner$base_train` is defined\n", "in the `Lrnr_base` base class and is shared across all learners. As such, these\n", "methods need not be reimplemented in individual learners. By contrast,\n", "`private$.train` contains the behavior that is specific to each individual\n", "learner and should be reimplemented at the level of each individual learner.\n", "Since `learner$base_train` does not use `delayed`, it may be helpful to use it\n", "when debugging the training code in a new learner. The program flow used for\n", "prediction and chaining is analogous.\n", "\n", "### `private$.train`\n", "\n", "This is the main training function, which takes in a task and returns a\n", "`fit_object` that contains all information needed to generate predictions. The\n", "fit object should not contain more data than is absolutely necessary, as\n", "including excess information will create needless inefficiencies. Many learner\n", "functions (like `glm`) store one or more copies of their training data -- this\n", "uses unnecessary memory and will hurt learner performance for large sample\n", "sizes. Thus, these copies of the data should be removed from the fit object\n", "before it is returned. You may make use of `true_obj_size` to estimate the size\n", "of your `fit_object`. 
For most learners, `fit_object` size should _not grow_\n", "linearly with training sample size. If it does, and this is unexpected, please\n", "try to reduce the size of the `fit_object`.\n", "\n", "Most of the time, the learner you are implementing will be fit using a function\n", "that already exists elsewhere. We've built some tools to facilitate passing\n", "parameter values directly to such functions. The `private$.train` function in\n", "the template uses a common pattern: it builds up an argument list starting with\n", "the parameter values, adds data from the task, and then uses `call_with_args`\n", "to call `my_ml_fun` with that argument list. It's not required that learners use\n", "this pattern, but it will be helpful in the common case where the learner is\n", "simply wrapping an underlying `my_ml_fun`.\n", "\n", "By default, `call_with_args` will pass only those arguments in the argument list\n", "that are matched by the definition of the function that it is calling. This allows the\n", "learner to silently drop irrelevant parameters from the call to `my_ml_fun`.\n", "Some learners either capture important arguments using dot arguments (`...`) or\n", "pass important arguments through such dot arguments on to a secondary\n", "function. Both of these cases can be handled using the `other_valid` and\n", "`keep_all` options to `call_with_args`. The former allows you to list other\n", "valid arguments and the latter disables argument filtering altogether.\n", "\n", "### `private$.predict`\n", "\n", "This is the main prediction function. It takes in a task and generates\n", "predictions for that task using the `fit_object`. If those predictions are\n", "1-dimensional, they will be coerced to a vector by `base_predict`.\n", "\n", "### `private$.chain`\n", "\n", "This is the main chaining function. It takes in a task and generates a chained\n", "task (based on the input task) using the given `fit_object`. 
If this method is\n", "not implemented, your learner will use the default chaining behavior, which is\n", "to return a new task where the covariates are defined as your learner's\n", "predictions for the current task." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Documenting and Testing your Learner\n", "\n", "If you want other people to be able to use your learner, you will need to\n", "document and provide unit tests for it. The above template has example\n", "documentation, written in the [roxygen](http://r-pkgs.had.co.nz/man.html)\n", "format. Most importantly, you should describe what your learner does, reference\n", "any external code it uses, and document any parameters and public methods\n", "defined by it.\n", "\n", "It's also important to [test](http://r-pkgs.had.co.nz/tests.html) your learner.\n", "You should write unit tests to verify that your learner can train and predict on\n", "new data, and, if applicable, generate a chained task. It might also be a good\n", "idea to use the `risk` function in `sl3` to verify your learner's performance on\n", "a sample dataset. That way, if you change your learner and performance drops,\n", "you know something may have gone wrong.\n", "\n", "## Submitting your Learner to `sl3`\n", "\n", "Once you've implemented your new learner (and made sure that it has quality\n", "documentation and unit tests), please consider adding it to the `sl3` project.\n", "This will make it possible for other `sl3` users to use and build on your work.\n", "Make sure to add any R packages listed in `.required_packages` to the\n", "`Suggests:` field of the `DESCRIPTION` file of the `sl3` package. Once this is\n", "done, please submit a __Pull Request__ to the `sl3` package [on\n", "GitHub](https://github.com/tlverse/sl3) to request that your learner be\n", "added. If you've never made a \"Pull Request\" before, see this helpful\n", "guide: https://yangsu.github.io/pull-request-tutorial/." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [ "##' Template of a \\code{sl3} Learner.\n", "##'\n", "##' This is a template for defining a new learner.\n", "##' This can be copied to a new file using \\code{\\link{write_learner_template}}. \n", "##' The remainder of this documentation is an example of how you might write documentation for your new learner.\n", "##' This learner uses \\code{\\link[my_package]{my_ml_fun}} from \\code{my_package} to fit my favorite machine learning algorithm.\n", "##' \n", "##' @docType class\n", "##' @importFrom R6 R6Class\n", "##' @export\n", "##' @keywords data\n", "##' @return Learner object with methods for training and prediction. See \\code{\\link{Lrnr_base}} for documentation on learners.\n", "##' @format \\code{\\link{R6Class}} object.\n", "##' @family Learners\n", "##' \n", "##' @section Parameters:\n", "##' \\describe{\n", "##' \\item{\\code{param_1=\"default_1\"}}{ This parameter does something.\n", "##' }\n", "##' \\item{\\code{param_2=\"default_2\"}}{ This parameter does something else.\n", "##' }\n", "##' \\item{\\code{...}}{ Other parameters passed directly to \\code{\\link[my_package]{my_ml_fun}}. 
See its documentation for details.\n", "##' }\n", "##' }\n", "##' \n", "##' @section Methods:\n", "##' \\describe{\n", "##' \\item{\\code{special_function(arg_1)}}{\n", "##' My learner is special so it has a special function.\n", "##' \n", "##' \\itemize{\n", "##' \\item{\\code{arg_1}: A very special argument.\n", "##' }\n", "##' }\n", "##' }\n", "##' }\n", "Lrnr_template <- R6Class(classname = \"Lrnr_template\", inherit = Lrnr_base,\n", " portable = TRUE, class = TRUE,\n", "# Above, you should change Lrnr_template (in both the object name and the classname argument)\n", "# to a name that indicates what your learner does\n", " public = list(\n", " \n", " # you can define default parameter values here\n", " # if possible, your learner should define defaults for all required parameters\n", " initialize = function(param_1=\"default_1\", param_2=\"default_2\", ...) {\n", " # this captures all parameters to initialize and saves them as self$params \n", " params <- args_to_list()\n", " super$initialize(params = params, ...)\n", " },\n", " \n", " # you can define public functions that allow your learner to do special things here\n", " # for instance glm learner might return prediction standard errors\n", " special_function = function(arg_1){\n", " \n", " }\n", " ),\n", " private = list(\n", " # list properties your learner supports here. 
\n", " # Use sl3_list_properties() for a list of options\n", " .properties = c(\"\"),\n", " \n", " # list any packages required for your learner here.\n", " .required_packages = c(\"my_package\"),\n", " \n", " # .train takes task data and returns a fit object that can be used to generate predictions\n", " .train = function(task) {\n", " # generate an argument list from the parameters that were\n", " # captured when your learner was initialized.\n", " # this allows users to pass arguments directly to your ml function\n", " args <- self$params\n", " \n", " \n", " # get outcome variable type\n", " # preferring learner$params$outcome_type first, then task$outcome_type\n", " outcome_type <- self$get_outcome_type(task)\n", " # should pass something on to your learner indicating outcome_type\n", " # e.g. family or objective\n", " \n", " # add task data to the argument list\n", " # what these arguments are called depends on the learner you are wrapping\n", " args$x <- as.matrix(task$X_intercept)\n", " args$y <- outcome_type$format(task$Y)\n", " \n", " # only add arguments on weights and offset \n", " # if those were specified when the task was generated\n", " if(task$has_node(\"weights\")){\n", " args$weights <- task$weights\n", " }\n", " \n", " if(task$has_node(\"offset\")){\n", " args$offset <- task$offset\n", " }\n", " \n", " # call a function that fits your algorithm\n", " # with the argument list you constructed\n", " fit_object <- call_with_args(my_ml_fun, args)\n", " \n", " # return the fit object, which will be stored\n", " # in a learner object and returned from the call\n", " # to learner$train\n", " return(fit_object)\n", " },\n", " \n", " # .predict takes a task and returns predictions from that task\n", " .predict = function(task = NULL) {\n", " self$training_task\n", " self$training_outcome_type\n", " self$fit_object\n", " \n", " predictions <- predict(self$fit_object, task$X)\n", " return(predictions)\n", " }\n", " )\n", ")" ] } ], "metadata": { 
"kernelspec": { "display_name": "R", "language": "R", "name": "ir" }, "language_info": { "codemirror_mode": "r", "file_extension": ".r", "mimetype": "text/x-r-source", "name": "R", "pygments_lexer": "r", "version": "3.4.3" } }, "nbformat": 4, "nbformat_minor": 2 }