{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Biostat 257 Homework 6\n", "\n", "**Due June 10 @ 11:59PM**" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "versioninfo()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Again we continue with the linear mixed effects model (LMM)\n", "$$\n", " \\mathbf{Y}_i = \\mathbf{X}_i \\boldsymbol{\\beta} + \\mathbf{Z}_i \\boldsymbol{\\gamma}_i + \\boldsymbol{\\epsilon}_i, \\quad i=1,\\ldots,m,\n", "$$\n", "where \n", "- $\\mathbf{Y}_i \\in \\mathbb{R}^{n_i}$ is the response vector of the $i$-th individual, \n", "- $\\mathbf{X}_i \\in \\mathbb{R}^{n_i \\times p}$ is the fixed effects predictor matrix of the $i$-th individual, \n", "- $\\mathbf{Z}_i \\in \\mathbb{R}^{n_i \\times q}$ is the random effects predictor matrix of the $i$-th individual, \n", "- $\\boldsymbol{\\epsilon}_i \\in \\mathbb{R}^{n_i}$ are multivariate normal $N(\\mathbf{0}_{n_i},\\sigma^2 \\mathbf{I}_{n_i})$, \n", "- $\\boldsymbol{\\beta} \\in \\mathbb{R}^p$ are fixed effects, and \n", "- $\\boldsymbol{\\gamma}_i \\in \\mathbb{R}^q$ are random effects assumed to be $N(\\mathbf{0}_q, \\boldsymbol{\\Sigma}_{q \\times q})$ independent of $\\boldsymbol{\\epsilon}_i$.\n", "\n", "The log-likelihood of the $i$-th datum $(\\mathbf{y}_i, \\mathbf{X}_i, \\mathbf{Z}_i)$ is \n", "$$\n", " \\ell_i(\\boldsymbol{\\beta}, \\boldsymbol{\\Sigma}, \\sigma^2) = - \\frac{n_i}{2} \\log (2\\pi) - \\frac{1}{2} \\log \\det \\boldsymbol{\\Omega}_i - \\frac{1}{2} (\\mathbf{y}_i - \\mathbf{X}_i \\boldsymbol{\\beta})^T \\boldsymbol{\\Omega}_i^{-1} (\\mathbf{y}_i - \\mathbf{X}_i \\boldsymbol{\\beta}),\n", "$$\n", "where\n", "$$\n", " \\boldsymbol{\\Omega}_i = \\sigma^2 \\mathbf{I}_{n_i} + \\mathbf{Z}_i \\boldsymbol{\\Sigma} \\mathbf{Z}_i^T.\n", "$$\n", "Given $m$ independent data points $(\\mathbf{y}_i, \\mathbf{X}_i, \\mathbf{Z}_i)$, $i=1,\\ldots,m$, we seek the maximum likelihood estimate (MLE) by maximizing the 
log-likelihood\n", "$$\n", "\\ell(\\boldsymbol{\\beta}, \\boldsymbol{\\Sigma}, \\sigma^2) = \\sum_{i=1}^m \\ell_i(\\boldsymbol{\\beta}, \\boldsymbol{\\Sigma}, \\sigma^2).\n", "$$\n", "\n", "In HW5, we used the nonlinear programming (NLP) approach (Newton-type algorithms) for optimization. In this assignment, we derive and implement an expectation-maximization (EM) algorithm for the same problem." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# load necessary packages; make sure to install them first\n", "using BenchmarkTools, Distributions, LinearAlgebra, Random, Revise" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Q1. (10 pts) Refresher on normal-normal model\n", "\n", "Assume the conditional distribution\n", "$$\n", "\\mathbf{y} \\mid \\boldsymbol{\\gamma} \\sim N(\\mathbf{X} \\boldsymbol{\\beta} + \\mathbf{Z} \\boldsymbol{\\gamma}, \\sigma^2 \\mathbf{I}_n)\n", "$$\n", "and the prior distribution\n", "$$\n", "\\boldsymbol{\\gamma} \\sim N(\\mathbf{0}_q, \\boldsymbol{\\Sigma}).\n", "$$\n", "By Bayes' theorem, the posterior distribution is\n", "\\begin{eqnarray*}\n", "f(\\boldsymbol{\\gamma} \\mid \\mathbf{y}) &=& \\frac{f(\\mathbf{y} \\mid \\boldsymbol{\\gamma}) \\times f(\\boldsymbol{\\gamma})}{f(\\mathbf{y})}, \\end{eqnarray*}\n", "where $f$ denotes the corresponding density. 
\n", "\n", "Show that the posterior distribution of random effects $\\boldsymbol{\\gamma}$ is a multivariate normal with mean\n", "\\begin{eqnarray*}\n", "\\mathbb{E} (\\boldsymbol{\\gamma} \\mid \\mathbf{y}) &=& \\sigma^{-2} (\\sigma^{-2} \\mathbf{Z}^T \\mathbf{Z} + \\boldsymbol{\\Sigma}^{-1})^{-1 } \\mathbf{Z}^T (\\mathbf{y} - \\mathbf{X} \\boldsymbol{\\beta}) \\\\\n", "&=& \\boldsymbol{\\Sigma} \\mathbf{Z}^T (\\mathbf{Z} \\boldsymbol{\\Sigma} \\mathbf{Z}^T + \\sigma^2 \\mathbf{I})^{-1} (\\mathbf{y} - \\mathbf{X} \\boldsymbol{\\beta})\n", "\\end{eqnarray*}\n", "and covariance\n", "\\begin{eqnarray*}\n", "\\text{Var} (\\boldsymbol{\\gamma} \\mid \\mathbf{y}) &=& (\\sigma^{-2} \\mathbf{Z}^T \\mathbf{Z} + \\boldsymbol{\\Sigma}^{-1})^{-1} \\\\\n", "&=& \\boldsymbol{\\Sigma} - \\boldsymbol{\\Sigma} \\mathbf{Z}^T (\\mathbf{Z} \\boldsymbol{\\Sigma} \\mathbf{Z}^T + \\sigma^2 \\mathbf{I})^{-1} \\mathbf{Z} \\boldsymbol{\\Sigma}.\n", "\\end{eqnarray*}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Q2. (20 pts) Derive EM algorithm\n", "\n", "1. Write down the complete log-likelihood\n", "$$\n", "\\sum_{i=1}^m \\log f(\\mathbf{y}_i, \\boldsymbol{\\gamma}_i \\mid \\boldsymbol{\\beta}, \\boldsymbol{\\Sigma}, \\sigma^2)\n", "$$\n", "\n", "2. Derive the $Q$ function (E-step).\n", "$$\n", "Q(\\boldsymbol{\\beta}, \\boldsymbol{\\Sigma}, \\sigma^2 \\mid \\boldsymbol{\\beta}^{(t)}, \\boldsymbol{\\Sigma}^{(t)}, \\sigma^{2(t)}).\n", "$$\n", "\n", "3. Derive the EM (or ECM) update of $\\boldsymbol{\\beta}$, $\\boldsymbol{\\Sigma}$, and $\\sigma^2$ (M-step). " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Q3. (20 pts) Objective of a single datum\n", "\n", "We modify the code from HW5 to evaluate the objective, the conditional mean of $\\boldsymbol{\\gamma}$, and the conditional variance of $\\boldsymbol{\\gamma}$. Start-up code is provided below. You do _not_ have to use this code." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# define a type that holds an LMM datum\n", "struct LmmObs{T <: AbstractFloat}\n", " # data\n", " y :: Vector{T}\n", " X :: Matrix{T}\n", " Z :: Matrix{T}\n", " # posterior mean and variance of random effects γ\n", " μγ :: Vector{T} # posterior mean of random effects\n", " νγ :: Matrix{T} # posterior variance of random effects\n", " # TODO: add whatever intermediate arrays you may want to pre-allocate\n", " yty :: T\n", " rtr :: Vector{T}\n", " xty :: Vector{T}\n", " zty :: Vector{T}\n", " ztr :: Vector{T}\n", " ltztr :: Vector{T}\n", " xtr :: Vector{T}\n", " storage_p :: Vector{T}\n", " storage_q :: Vector{T}\n", " xtx :: Matrix{T}\n", " ztx :: Matrix{T}\n", " ztz :: Matrix{T}\n", " ltztzl :: Matrix{T}\n", " storage_qq :: Matrix{T}\n", "end\n", "\n", "\"\"\"\n", " LmmObs(y::Vector, X::Matrix, Z::Matrix)\n", "\n", "Create an LMM datum of type `LmmObs`.\n", "\"\"\"\n", "function LmmObs(\n", " y::Vector{T}, \n", " X::Matrix{T}, \n", " Z::Matrix{T}) where T <: AbstractFloat\n", " n, p, q = size(X, 1), size(X, 2), size(Z, 2)\n", " μγ = Vector{T}(undef, q)\n", " νγ = Matrix{T}(undef, q, q)\n", " yty = abs2(norm(y))\n", " rtr = Vector{T}(undef, 1)\n", " xty = transpose(X) * y\n", " zty = transpose(Z) * y\n", " ztr = similar(zty)\n", " ltztr = similar(zty)\n", " xtr = Vector{T}(undef, p)\n", " storage_p = similar(xtr)\n", " storage_q = Vector{T}(undef, q)\n", " xtx = transpose(X) * X\n", " ztx = transpose(Z) * X\n", " ztz = transpose(Z) * Z\n", " ltztzl = similar(ztz)\n", " storage_qq = similar(ztz)\n", " LmmObs(y, X, Z, μγ, νγ, \n", " yty, rtr, xty, zty, ztr, ltztr, xtr,\n", " storage_p, storage_q, \n", " xtx, ztx, ztz, ltztzl, storage_qq)\n", "end\n", "\n", "\"\"\"\n", " logl!(obs::LmmObs, β, Σ, L, σ², updater = false)\n", "\n", "Evaluate the log-likelihood of a single LMM datum at parameter values `β`, `Σ`, \n", "and `σ²`. 
The lower triangular Cholesky factor `L` of `Σ` must be supplied too.\n", "The fields `obs.μγ` and `obs.νγ` are overwritten by the posterior mean and \n", "posterior variance of random effects. If `updater==true`, fields `obs.ztr`, \n", "`obs.xtr`, and `obs.rtr` are updated according to input parameter values. \n", "Otherwise, it assumes these three fields are pre-computed. \n", "\"\"\"\n", "function logl!(\n", " obs :: LmmObs{T}, \n", " β :: Vector{T}, \n", " Σ :: Matrix{T},\n", " L :: Matrix{T},\n", " σ² :: T,\n", " updater :: Bool = false\n", " ) where T <: AbstractFloat\n", " n, p, q = size(obs.X, 1), size(obs.X, 2), size(obs.Z, 2)\n", " σ²inv = inv(σ²)\n", " ####################\n", " # Evaluate objective\n", " ####################\n", " # form the q-by-q matrix: Lt Zt Z L\n", " copy!(obs.ltztzl, obs.ztz)\n", " BLAS.trmm!('L', 'L', 'T', 'N', T(1), L, obs.ltztzl) # O(q^3)\n", " BLAS.trmm!('R', 'L', 'N', 'N', T(1), L, obs.ltztzl) # O(q^3) \n", " # form the q-by-q matrix: M = σ² I + Lt Zt Z L\n", " copy!(obs.storage_qq, obs.ltztzl)\n", " @inbounds for j in 1:q\n", " obs.storage_qq[j, j] += σ²\n", " end\n", " LAPACK.potrf!('U', obs.storage_qq) # O(q^3)\n", " # Zt * res\n", " updater && BLAS.gemv!('N', T(-1), obs.ztx, β, T(1), copy!(obs.ztr, obs.zty)) # O(pq)\n", " # Lt * (Zt * res)\n", " BLAS.trmv!('L', 'T', 'N', L, copy!(obs.ltztr, obs.ztr)) # O(q^2)\n", " # storage_q = (Mchol.U') \\ (Lt * (Zt * res))\n", " BLAS.trsv!('U', 'T', 'N', obs.storage_qq, copy!(obs.storage_q, obs.ltztr)) # O(q^2)\n", " # Xt * res = Xt * y - Xt * X * β\n", " updater && BLAS.gemv!('N', T(-1), obs.xtx, β, T(1), copy!(obs.xtr, obs.xty))\n", " # squared l2 norm of residual vector\n", " updater && (obs.rtr[1] = obs.yty - dot(obs.xty, β) - dot(obs.xtr, β))\n", " # assemble pieces\n", " logl::T = n * log(2π) + (n - q) * log(σ²) # constant term\n", " @inbounds for j in 1:q # log det term\n", " logl += 2log(obs.storage_qq[j, j])\n", " end\n", " qf = abs2(norm(obs.storage_q)) # quadratic form term\n", "
logl += (obs.rtr[1] - qf) * σ²inv \n", " logl /= -2\n", " ######################################\n", " # TODO: Evaluate posterior mean and variance\n", " ###################################### \n", " # TODO\n", " ###################\n", " # Return\n", " ################### \n", " return logl\n", "end\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "It is a good idea to test the correctness and efficiency of the single-datum objective/posterior mean/variance evaluator here. It's the same test datum as in HW3 and HW5." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "Random.seed!(257)\n", "# dimension\n", "n, p, q = 2000, 5, 3\n", "# predictors\n", "X = [ones(n) randn(n, p - 1)]\n", "Z = [ones(n) randn(n, q - 1)]\n", "# parameter values\n", "β = [2.0; -1.0; rand(p - 2)]\n", "σ² = 1.5\n", "Σ = fill(0.1, q, q) + 0.9I # compound symmetry \n", "L = Matrix(cholesky(Symmetric(Σ)).L)\n", "# generate y\n", "y = X * β + Z * rand(MvNormal(Σ)) + sqrt(σ²) * randn(n)\n", "\n", "# form the LmmObs object\n", "obs = LmmObs(y, X, Z);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Correctness" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "@show logl = logl!(obs, β, Σ, L, σ², true)\n", "@show obs.μγ\n", "@show obs.νγ;" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You will lose all 20 points if the following statement throws `AssertionError`." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "@assert abs(logl - (-3256.1793358058258)) < 1e-4\n", "@assert norm(obs.μγ - [0.10608689301333621, \n", " -0.25104190602577225, -1.4653979409855415]) < 1e-4\n", "@assert norm(obs.νγ - [\n", " 0.0007494356395909563 -1.2183420093769967e-6 -2.176783643112221e-6; \n", " -1.2183420282298223e-6 0.0007542331467601107 2.1553464632686345e-5; \n", " -2.1767836636008638e-6 2.1553464641863096e-5 0.0007465271342535443\n", " ]) < 1e-4" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Efficiency" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Benchmark for efficiency." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "bm_obj = @benchmark logl!($obs, $β, $Σ, $L, $σ², true)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "My median run time is 1.8μs. You will get full credit if the median run time is within 10μs. The points you will get are" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "clamp(10 / (median(bm_obj).time / 1e3) * 10, 0, 10)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# # check for type stability\n", "# @code_warntype logl!(obs, β, Σ, L, σ²)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# using Profile\n", "\n", "# Profile.clear()\n", "# @profile for i in 1:10000; logl!(obs, β, Σ, L, σ²); end\n", "# Profile.print(format=:flat)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Q4. LmmModel type\n", "\n", "We modify the `LmmModel` type in HW5 to hold all data points, model parameters, and intermediate arrays." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# define a type that holds LMM model (data + parameters)\n", "struct LmmModel{T <: AbstractFloat}\n", " # data\n", " data :: Vector{LmmObs{T}}\n", " # parameters\n", " β :: Vector{T}\n", " Σ :: Matrix{T}\n", " L :: Matrix{T}\n", " σ² :: Vector{T} \n", " # TODO: add whatever intermediate arrays you may want to pre-allocate\n", " xty :: Vector{T}\n", " xtr :: Vector{T}\n", " ztr2 :: Vector{T}\n", " xtxinv :: Matrix{T}\n", " ztz2 :: Matrix{T}\n", "end\n", "\n", "\"\"\"\n", " LmmModel(data::Vector{LmmObs})\n", "\n", "Create an LMM model that contains data and parameters.\n", "\"\"\"\n", "function LmmModel(obsvec::Vector{LmmObs{T}}) where T <: AbstractFloat\n", " # dims\n", " p = size(obsvec[1].X, 2)\n", " q = size(obsvec[1].Z, 2)\n", " # parameters\n", " β = Vector{T}(undef, p)\n", " Σ = Matrix{T}(undef, q, q)\n", " L = Matrix{T}(undef, q, q)\n", " σ² = Vector{T}(undef, 1) \n", " # intermediate arrays\n", " xty = zeros(T, p)\n", " xtr = similar(xty)\n", " ztr2 = Vector{T}(undef, abs2(q))\n", " xtxinv = zeros(T, p, p)\n", " # pre-calculate \\sum_i Xi^T Xi and \\sum_i Xi^T y_i\n", " @inbounds for i in eachindex(obsvec)\n", " obs = obsvec[i]\n", " BLAS.axpy!(T(1), obs.xtx, xtxinv)\n", " BLAS.axpy!(T(1), obs.xty, xty)\n", " end\n", " # invert X'X\n", " LAPACK.potrf!('U', xtxinv)\n", " LAPACK.potri!('U', xtxinv)\n", " LinearAlgebra.copytri!(xtxinv, 'U')\n", " ztz2 = Matrix{T}(undef, abs2(q), abs2(q))\n", " LmmModel(obsvec, β, Σ, L, σ², xty, xtr, ztr2, xtxinv, ztz2)\n", "end" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Q5. Implement EM update\n", "\n", "Let's write the key function `update_em!` that performs one iteration of EM update." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "\"\"\"\n", " update_em!(m::LmmModel, updater::Bool = false)\n", "\n", "Perform one iteration of EM update. 
It returns the log-likelihood calculated \n", "from input `m.β`, `m.Σ`, `m.L`, and `m.σ²`. These fields are then overwritten \n", "by the next EM iterate. The fields `m.data[i].xtr`, `m.data[i].ztr`, and \n", "`m.data[i].rtr` are updated according to the resultant `m.β`. If `updater==true`, \n", "the function first updates `m.data[i].xtr`, `m.data[i].ztr`, and \n", "`m.data[i].rtr` according to `m.β`. If `updater==false`, it assumes these fields \n", "are pre-computed.\n", "\"\"\"\n", "function update_em!(m::LmmModel{T}, updater::Bool = false) where T <: AbstractFloat\n", " logl = zero(T)\n", " # TODO: update m.β\n", " # TODO: update m.data[i].ztr, m.data[i].xtr, m.data[i].rtr\n", " # TODO: update m.σ²\n", " # update m.Σ and m.L\n", " # return log-likelihood at input parameter values\n", " logl\n", "end" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Q6. (30 pts) Test data\n", "\n", "Let's generate a fake longitudinal data set (same as HW5) to test our algorithm." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "Random.seed!(257)\n", "\n", "# dimension\n", "m = 1000 # number of individuals\n", "ns = rand(1500:2000, m) # numbers of observations per individual\n", "p = 5 # number of fixed effects, including intercept\n", "q = 3 # number of random effects, including intercept\n", "obsvec = Vector{LmmObs{Float64}}(undef, m)\n", "# true parameter values\n", "βtrue = [0.1; 6.5; -3.5; 1.0; 5]\n", "σ²true = 1.5\n", "σtrue = sqrt(σ²true)\n", "Σtrue = Matrix(Diagonal([2.0; 1.2; 1.0]))\n", "Ltrue = Matrix(cholesky(Symmetric(Σtrue)).L)\n", "# generate data\n", "for i in 1:m\n", " # first column intercept, remaining entries iid std normal\n", " X = Matrix{Float64}(undef, ns[i], p)\n", " X[:, 1] .= 1\n", " @views Distributions.rand!(Normal(), X[:, 2:p])\n", " # first column intercept, remaining entries iid std normal\n", " Z = Matrix{Float64}(undef, ns[i], q)\n", " Z[:, 1] .= 1\n", " @views 
Distributions.rand!(Normal(), Z[:, 2:q])\n", " # generate y\n", " y = X * βtrue .+ Z * (Ltrue * randn(q)) .+ σtrue * randn(ns[i])\n", " # form an LmmObs instance\n", " obsvec[i] = LmmObs(y, X, Z)\n", "end\n", "# form an LmmModel instance\n", "lmm = LmmModel(obsvec);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Correctness\n", "\n", "Starting from the true parameter values, run two EM updates. `obj1` is the log-likelihood at the true parameter values and `obj2` is the log-likelihood at the first EM iterate." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "copy!(lmm.β, βtrue)\n", "copy!(lmm.Σ, Σtrue)\n", "copy!(lmm.L, Ltrue)\n", "lmm.σ²[1] = σ²true\n", "@show obj1 = update_em!(lmm, true)\n", "@show lmm.β\n", "@show lmm.Σ\n", "@show lmm.L\n", "@show lmm.σ²\n", "println()\n", "@show obj2 = update_em!(lmm, false)\n", "@show lmm.β\n", "@show lmm.Σ\n", "@show lmm.L\n", "@show lmm.σ²" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Test correctness. You will lose all 30 points if the following code throws `AssertionError`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "@assert abs(obj1 - (-2.840068438369969e6)) < 1e-4\n", "@assert abs(obj2 - (-2.84006046054206e6)) < 1e-4" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Efficiency\n", "\n", "Test the efficiency of the EM update." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "bm_emupdate = @benchmark update_em!($lmm, true) setup=(\n", " copy!(lmm.β, βtrue);\n", " copy!(lmm.Σ, Σtrue);\n", " copy!(lmm.L, Ltrue);\n", " lmm.σ²[1] = σ²true)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "My median run time is 2.4ms. You will get full credit if your median run time is within 10ms. 
The points you will get are" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "clamp(10 / (median(bm_emupdate).time / 1e6) * 10, 0, 10)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Memory\n", "\n", "You will lose 1 point for every 100 bytes of memory allocated. The points you will get are" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "clamp(10 - median(bm_emupdate).memory / 100, 0, 10)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Q7. Starting point\n", "\n", "We use the same least squares estimates as in HW5 as the starting point. " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "\"\"\"\n", " init_ls!(m::LmmModel)\n", "\n", "Initialize the parameters of an `LmmModel` object from the least squares estimate. \n", "`m.β`, `m.Σ`, `m.L`, and `m.σ²` are overwritten with the least squares estimates.\n", "\"\"\"\n", "function init_ls!(m::LmmModel{T}) where T <: AbstractFloat\n", " p, q = size(m.data[1].X, 2), size(m.data[1].Z, 2)\n", " # LS estimate for β\n", " mul!(m.β, m.xtxinv, m.xty)\n", " # LS estimate for σ² and Σ\n", " rss, ntotal = zero(T), 0\n", " fill!(m.ztz2, 0)\n", " fill!(m.ztr2, 0) \n", " @inbounds for i in eachindex(m.data)\n", " obs = m.data[i]\n", " ntotal += length(obs.y)\n", " # update Xt * res\n", " BLAS.gemv!('N', T(-1), obs.xtx, m.β, T(1), copy!(obs.xtr, obs.xty))\n", " # rss of i-th individual\n", " rss += obs.yty - dot(obs.xty, m.β) - dot(obs.xtr, m.β)\n", " # update Zi' * res\n", " BLAS.gemv!('N', T(-1), obs.ztx, m.β, T(1), copy!(obs.ztr, obs.zty))\n", " # Zi'Zi ⊗ Zi'Zi\n", " kron_axpy!(obs.ztz, obs.ztz, m.ztz2)\n", " # Zi'res ⊗ Zi'res\n", " kron_axpy!(obs.ztr, obs.ztr, m.ztr2)\n", " end\n", " m.σ²[1] = rss / ntotal\n", " # LS estimate for Σ = LLt\n", " LAPACK.potrf!('U', m.ztz2)\n", " BLAS.trsv!('U', 'T', 'N', m.ztz2, m.ztr2)\n", " BLAS.trsv!('U', 'N', 'N', m.ztz2, m.ztr2)\n", "
copyto!(m.Σ, m.ztr2)\n", " copy!(m.L, m.Σ)\n", " LAPACK.potrf!('L', m.L)\n", " for j in 2:q, i in 1:j-1\n", " m.L[i, j] = 0\n", " end\n", " m\n", "end\n", "\n", "\"\"\"\n", " kron_axpy!(A, X, Y)\n", "\n", "Overwrite `Y` with `A ⊗ X + Y`. Same as `Y += kron(A, X)` but\n", "more memory efficient.\n", "\"\"\"\n", "function kron_axpy!(\n", " A::AbstractVecOrMat{T},\n", " X::AbstractVecOrMat{T},\n", " Y::AbstractVecOrMat{T}\n", " ) where T <: Real\n", " m, n = size(A, 1), size(A, 2)\n", " p, q = size(X, 1), size(X, 2)\n", " @assert size(Y, 1) == m * p\n", " @assert size(Y, 2) == n * q\n", " @inbounds for j in 1:n\n", " coffset = (j - 1) * q\n", " for i in 1:m\n", " a = A[i, j]\n", " roffset = (i - 1) * p \n", " for l in 1:q\n", " r = roffset + 1\n", " c = coffset + l\n", " for k in 1:p \n", " Y[r, c] += a * X[k, l]\n", " r += 1\n", " end\n", " end\n", " end\n", " end\n", " Y\n", "end" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "init_ls!(lmm)\n", "@show lmm.β\n", "@show lmm.Σ\n", "@show lmm.L\n", "@show lmm.σ²" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Q8. Estimation by EM\n", "\n", "We write a function `fit!` that implements the EM algorithm for estimating the LMM." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "\"\"\"\n", " fit!(m::LmmModel)\n", "\n", "Fit an `LmmModel` object by MLE using an EM algorithm. 
The starting point \n", "should be provided in `m.β`, `m.Σ`, `m.L`, and `m.σ²`.\n", "\"\"\"\n", "function fit!(\n", " m :: LmmModel;\n", " maxiter :: Integer = 10_000,\n", " ftolrel :: AbstractFloat = 1e-12,\n", " prtfreq :: Integer = 0\n", " )\n", " obj = update_em!(m, true)\n", " for iter in 0:maxiter\n", " obj_old = obj\n", " # EM update\n", " obj = update_em!(m, false)\n", " # print obj\n", " prtfreq > 0 && rem(iter, prtfreq) == 0 && println(\"iter=$iter, obj=$obj\")\n", " # check monotonicity\n", " obj < obj_old && (@warn \"monotonicity violated\")\n", " # check convergence criterion\n", " (obj - obj_old) < ftolrel * (abs(obj_old) + 1) && break\n", " # warning about non-convergence\n", " iter == maxiter && (@warn \"maximum iterations reached\")\n", " end\n", " m\n", "end\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Q9. (20 pts) Test drive\n", "\n", "Now we can run our EM algorithm to compute the MLE." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# initialize from least squares\n", "init_ls!(lmm)\n", "\n", "@time fit!(lmm, prtfreq = 1);\n", "\n", "println(\"objective value at solution: \", update_em!(lmm)); println()\n", "println(\"solution values:\")\n", "@show lmm.β\n", "@show lmm.σ²\n", "@show lmm.L * transpose(lmm.L)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Correctness\n", "\n", "You get 10 points if the following code does not throw `AssertionError`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# objective at solution should be close enough to the optimum\n", "@assert update_em!(lmm) > -2.840059e6" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Efficiency\n", "\n", "My median run time is 12ms. You get 10 points if your median run time is within 1s." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "bm_em = @benchmark fit!($lmm) setup = (init_ls!(lmm))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# these are the points you get\n", "clamp(1 / (median(bm_em).time / 1e9) * 10, 0, 10)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Q10. (10 pts) EM vs Newton-type algorithms\n", "\n", "Contrast the EM algorithm with the Newton-type algorithms (gradient free, gradient based, using Hessian) in HW5 in terms of stability, convergence rate (how fast the algorithm converges), final objective value, total run time, derivation, and implementation effort. " ] } ], "metadata": { "@webio": { "lastCommId": null, "lastKernelId": null }, "kernelspec": { "display_name": "Julia 1.7.1", "language": "julia", "name": "julia-1.7" }, "language_info": { "file_extension": ".jl", "mimetype": "application/julia", "name": "julia", "version": "1.7.1" }, "toc": { "colors": { "hover_highlight": "#DAA520", "running_highlight": "#FF0000", "selected_highlight": "#FFD700" }, "moveMenuLeft": true, "nav_menu": { "height": "87px", "width": "252px" }, "navigate_menu": true, "number_sections": true, "sideBar": true, "skip_h1_title": true, "threshold": 4, "toc_cell": false, "toc_section_display": "block", "toc_window_display": false, "widenNotebook": false } }, "nbformat": 4, "nbformat_minor": 4 }