{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Huber regression with sign-flipped measurements\n",
    "\n",
    "We draw a random linear model `Y = X' * beta_true + v` with `n = 300` features and `1.5n` measurements, then flip the sign of each measurement independently with probability $p$. For each $p$ we compare three estimators of `beta_true`:\n",
    "\n",
    "- **Least squares**, which ignores the sign flips;\n",
    "- **Huber regression** (Convex.jl's `huber` penalty), which also does not know which signs were flipped;\n",
    "- **Prescient** least squares, which is told exactly which signs were flipped.\n",
    "\n",
    "Each estimator is scored by its relative reconstruction error `norm(beta - beta_true) / norm(beta_true)`, labelled *Fit* in the plots."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Packages used by the experiment (plotting packages are loaded in the plotting cell).\n",
    "using Convex, SCS, Distributions\n",
    "\n",
    "# Generate data for Huber regression: Y = X' * beta_true + v, with sign flips applied per p below.\n",
    "srand(1);                 # fixed seed so Restart & Run All reproduces the same data\n",
    "n = 300;                  # number of features\n",
    "SAMPLES = int(1.5*n);     # number of measurements\n",
    "beta_true = 5*randn(n);   # ground-truth coefficients\n",
    "X = randn(n, SAMPLES);    # column k is the feature vector of measurement k\n",
    "v = randn(SAMPLES);       # additive Gaussian noise"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Sweep over the sign-flip probability $p$\n",
    "\n",
    "**Warning:** this cell solves `3 * TESTS = 150` convex problems with SCS and takes a few minutes to run."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# For each value of p, generate sign-flipped data and solve the\n",
    "# least-squares, prescient, and Huber regression problems.\n",
    "# WARNING this script takes a few minutes to run.\n",
    "set_default_solver(SCSSolver(verbose=0));\n",
    "TESTS = 50;     # number of p values in the sweep\n",
    "HUBER_M = 1;    # threshold argument passed to Convex.jl's huber()\n",
    "lsq_data = zeros(TESTS);\n",
    "huber_data = zeros(TESTS);\n",
    "prescient_data = zeros(TESTS);\n",
    "p_vals = linspace(0, 0.15, TESTS);\n",
    "for i = 1:length(p_vals)\n",
    "    p = p_vals[i];\n",
    "    # signs[k] is -1 with probability p (a flipped measurement) and +1 otherwise.\n",
    "    # (Named `signs` rather than `factor` to avoid shadowing Base.factor.)\n",
    "    signs = float(2 * rand(Binomial(1, 1-p), SAMPLES) - 1);\n",
    "    Y = signs .* (X' * beta_true) + v;\n",
    "\n",
    "    # One variable is shared by all three problems; `fit` (relative\n",
    "    # reconstruction error) is re-evaluated after each solve.\n",
    "    beta = Variable(n);\n",
    "    fit = norm(beta - beta_true) / norm(beta_true);\n",
    "\n",
    "    # Standard least-squares regression, unaware of the sign flips.\n",
    "    cost = norm(X' * beta - Y);\n",
    "    solve!(minimize(cost));\n",
    "    lsq_data[i] = evaluate(fit);\n",
    "\n",
    "    # Prescient regression, i.e., where the sign changes are known.\n",
    "    cost = norm(signs .* (X' * beta) - Y);\n",
    "    solve!(minimize(cost));\n",
    "    prescient_data[i] = evaluate(fit);\n",
    "\n",
    "    # Huber regression, also without knowledge of the sign flips.\n",
    "    cost = sum(huber(X' * beta - Y, HUBER_M));\n",
    "    solve!(minimize(cost));\n",
    "    huber_data[i] = evaluate(fit);\n",
    "end"
   ]
  },
  { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": false }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " p\n", " \n", " \n", " 0.00\n", " 0.05\n", " 0.10\n", " 0.15\n", " \n", " \n", " \n", " Huber\n", " Prescient\n", " Least squares\n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " label\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 0.0\n", " 0.5\n", " 1.0\n", " 1.5\n", " \n", " \n", " Fit\n", " \n", "\n", "\n", "\n", " \n", "\n", "\n" ], "text/html": [ "\n", "\n", "\n", " \n", " p\n", " \n", " \n", " -0.20\n", " -0.15\n", " -0.10\n", " -0.05\n", " 0.00\n", " 0.05\n", " 0.10\n", " 0.15\n", " 0.20\n", " 0.25\n", " 0.30\n", " 0.35\n", " -0.155\n", " -0.150\n", " -0.145\n", " -0.140\n", " -0.135\n", " -0.130\n", " -0.125\n", " -0.120\n", " -0.115\n", " -0.110\n", " -0.105\n", " -0.100\n", " -0.095\n", " -0.090\n", " -0.085\n", " -0.080\n", " -0.075\n", " -0.070\n", " -0.065\n", " -0.060\n", " -0.055\n", " -0.050\n", " -0.045\n", " -0.040\n", " -0.035\n", " -0.030\n", " -0.025\n", " -0.020\n", " -0.015\n", " -0.010\n", " -0.005\n", " 0.000\n", " 0.005\n", " 0.010\n", " 0.015\n", " 0.020\n", " 0.025\n", " 0.030\n", " 0.035\n", " 0.040\n", " 0.045\n", " 0.050\n", " 0.055\n", " 0.060\n", " 0.065\n", " 0.070\n", " 0.075\n", " 0.080\n", " 0.085\n", " 0.090\n", " 0.095\n", " 0.100\n", " 0.105\n", " 0.110\n", " 0.115\n", " 0.120\n", " 0.125\n", " 0.130\n", " 0.135\n", " 0.140\n", " 0.145\n", " 0.150\n", " 0.155\n", " 0.160\n", " 0.165\n", " 0.170\n", " 0.175\n", " 0.180\n", " 0.185\n", " 0.190\n", " 0.195\n", " 0.200\n", " 0.205\n", " 0.210\n", " 0.215\n", " 0.220\n", " 0.225\n", " 0.230\n", " 0.235\n", " 0.240\n", " 0.245\n", " 0.250\n", " 0.255\n", " 0.260\n", " 0.265\n", " 0.270\n", " 0.275\n", " 0.280\n", " 0.285\n", " 0.290\n", " 0.295\n", " 0.300\n", " 0.305\n", " -0.2\n", " 0.0\n", " 0.2\n", " 0.4\n", " -0.16\n", " -0.15\n", " -0.14\n", " -0.13\n", " -0.12\n", " -0.11\n", " -0.10\n", " -0.09\n", " -0.08\n", " -0.07\n", " -0.06\n", " -0.05\n", " -0.04\n", " -0.03\n", " -0.02\n", " -0.01\n", " 0.00\n", " 0.01\n", " 0.02\n", " 0.03\n", " 0.04\n", " 0.05\n", " 0.06\n", " 0.07\n", " 0.08\n", " 
0.09\n", " 0.10\n", " 0.11\n", " 0.12\n", " 0.13\n", " 0.14\n", " 0.15\n", " 0.16\n", " 0.17\n", " 0.18\n", " 0.19\n", " 0.20\n", " 0.21\n", " 0.22\n", " 0.23\n", " 0.24\n", " 0.25\n", " 0.26\n", " 0.27\n", " 0.28\n", " 0.29\n", " 0.30\n", " 0.31\n", " \n", " \n", " \n", " Huber\n", " Prescient\n", " Least squares\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " label\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " -2.0\n", " -1.5\n", " -1.0\n", " -0.5\n", " 0.0\n", " 0.5\n", " 1.0\n", " 1.5\n", " 2.0\n", " 2.5\n", " 3.0\n", " 3.5\n", " -1.50\n", " -1.45\n", " -1.40\n", " -1.35\n", " -1.30\n", " -1.25\n", " -1.20\n", " -1.15\n", " -1.10\n", " -1.05\n", " -1.00\n", " -0.95\n", " -0.90\n", " -0.85\n", " -0.80\n", " -0.75\n", " -0.70\n", " -0.65\n", " -0.60\n", " -0.55\n", " -0.50\n", " -0.45\n", " -0.40\n", " -0.35\n", " -0.30\n", " -0.25\n", " -0.20\n", " -0.15\n", " -0.10\n", " -0.05\n", " 0.00\n", " 0.05\n", " 0.10\n", " 0.15\n", " 0.20\n", " 0.25\n", " 0.30\n", " 0.35\n", " 0.40\n", " 0.45\n", " 0.50\n", " 0.55\n", " 0.60\n", " 0.65\n", " 0.70\n", " 0.75\n", " 0.80\n", " 0.85\n", " 0.90\n", " 0.95\n", " 1.00\n", " 1.05\n", " 1.10\n", " 1.15\n", " 1.20\n", " 1.25\n", " 1.30\n", " 1.35\n", " 1.40\n", " 1.45\n", " 1.50\n", " 1.55\n", " 1.60\n", " 1.65\n", " 1.70\n", " 1.75\n", " 1.80\n", " 1.85\n", " 1.90\n", " 1.95\n", " 2.00\n", " 2.05\n", " 2.10\n", " 2.15\n", " 2.20\n", " 2.25\n", " 2.30\n", " 2.35\n", " 2.40\n", " 2.45\n", " 2.50\n", " 2.55\n", " 2.60\n", " 2.65\n", " 2.70\n", " 2.75\n", " 2.80\n", " 2.85\n", " 2.90\n", " 2.95\n", " 3.00\n", " -2\n", " 0\n", " 
2\n", " 4\n", " -1.5\n", " -1.4\n", " -1.3\n", " -1.2\n", " -1.1\n", " -1.0\n", " -0.9\n", " -0.8\n", " -0.7\n", " -0.6\n", " -0.5\n", " -0.4\n", " -0.3\n", " -0.2\n", " -0.1\n", " 0.0\n", " 0.1\n", " 0.2\n", " 0.3\n", " 0.4\n", " 0.5\n", " 0.6\n", " 0.7\n", " 0.8\n", " 0.9\n", " 1.0\n", " 1.1\n", " 1.2\n", " 1.3\n", " 1.4\n", " 1.5\n", " 1.6\n", " 1.7\n", " 1.8\n", " 1.9\n", " 2.0\n", " 2.1\n", " 2.2\n", " 2.3\n", " 2.4\n", " 2.5\n", " 2.6\n", " 2.7\n", " 2.8\n", " 2.9\n", " 3.0\n", " \n", " \n", " Fit\n", " \n", "\n", "\n", "\n", " \n", "\n", "\n", "\n", "\n" ], "text/plain": [ "Plot(...)" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "using Gadfly, DataFrames\n", "df = DataFrame(x=p_vals, y=huber_data, label=\"Huber\");\n", "df = vcat(df, DataFrame(x=p_vals, y=prescient_data, label=\"Prescient\"));\n", "df = vcat(df, DataFrame(x=p_vals, y=lsq_data, label=\"Least squares\"));\n", "plot(df, x=\"x\", y=\"y\", color=\"label\", Geom.line, Guide.XLabel(\"p\"), Guide.YLabel(\"Fit\"))" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": false }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " p\n", " \n", " \n", " 0.00\n", " 0.02\n", " 0.04\n", " 0.06\n", " 0.08\n", " \n", " \n", " \n", " Huber\n", " Prescient\n", " \n", " \n", " \n", " \n", " \n", " \n", " label\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 0.00\n", " 0.01\n", " 0.02\n", " 0.03\n", " 0.04\n", " 0.05\n", " \n", " \n", " Fit\n", " \n", "\n", "\n", "\n", " \n", "\n", "\n" ], "text/html": [ "\n", "\n", "\n", " \n", " p\n", " \n", " \n", " -0.10\n", " -0.08\n", " -0.06\n", " -0.04\n", " -0.02\n", " 0.00\n", " 0.02\n", " 0.04\n", " 0.06\n", " 0.08\n", " 0.10\n", " 0.12\n", " 0.14\n", " 0.16\n", " 0.18\n", " -0.080\n", " -0.075\n", " -0.070\n", " -0.065\n", 
" -0.060\n", " -0.055\n", " -0.050\n", " -0.045\n", " -0.040\n", " -0.035\n", " -0.030\n", " -0.025\n", " -0.020\n", " -0.015\n", " -0.010\n", " -0.005\n", " 0.000\n", " 0.005\n", " 0.010\n", " 0.015\n", " 0.020\n", " 0.025\n", " 0.030\n", " 0.035\n", " 0.040\n", " 0.045\n", " 0.050\n", " 0.055\n", " 0.060\n", " 0.065\n", " 0.070\n", " 0.075\n", " 0.080\n", " 0.085\n", " 0.090\n", " 0.095\n", " 0.100\n", " 0.105\n", " 0.110\n", " 0.115\n", " 0.120\n", " 0.125\n", " 0.130\n", " 0.135\n", " 0.140\n", " 0.145\n", " 0.150\n", " 0.155\n", " 0.160\n", " 0.165\n", " -0.1\n", " 0.0\n", " 0.1\n", " 0.2\n", " -0.080\n", " -0.075\n", " -0.070\n", " -0.065\n", " -0.060\n", " -0.055\n", " -0.050\n", " -0.045\n", " -0.040\n", " -0.035\n", " -0.030\n", " -0.025\n", " -0.020\n", " -0.015\n", " -0.010\n", " -0.005\n", " 0.000\n", " 0.005\n", " 0.010\n", " 0.015\n", " 0.020\n", " 0.025\n", " 0.030\n", " 0.035\n", " 0.040\n", " 0.045\n", " 0.050\n", " 0.055\n", " 0.060\n", " 0.065\n", " 0.070\n", " 0.075\n", " 0.080\n", " 0.085\n", " 0.090\n", " 0.095\n", " 0.100\n", " 0.105\n", " 0.110\n", " 0.115\n", " 0.120\n", " 0.125\n", " 0.130\n", " 0.135\n", " 0.140\n", " 0.145\n", " 0.150\n", " 0.155\n", " 0.160\n", " 0.165\n", " \n", " \n", " \n", " Huber\n", " Prescient\n", " \n", " \n", " \n", " \n", " \n", " \n", " label\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " -0.06\n", " -0.05\n", " -0.04\n", " -0.03\n", " -0.02\n", " -0.01\n", " 0.00\n", " 0.01\n", " 0.02\n", " 0.03\n", " 0.04\n", " 0.05\n", " 0.06\n", " 0.07\n", " 0.08\n", " 0.09\n", " 0.10\n", " 0.11\n", " -0.050\n", " -0.048\n", " -0.046\n", " -0.044\n", " -0.042\n", " -0.040\n", " -0.038\n", " -0.036\n", " -0.034\n", " -0.032\n", " -0.030\n", " -0.028\n", " -0.026\n", " -0.024\n", " -0.022\n", " -0.020\n", " -0.018\n", " -0.016\n", " -0.014\n", " -0.012\n", " -0.010\n", " -0.008\n", " -0.006\n", " -0.004\n", " -0.002\n", " 0.000\n", " 0.002\n", " 0.004\n", " 0.006\n", " 0.008\n", " 0.010\n", " 0.012\n", " 0.014\n", " 0.016\n", " 0.018\n", " 
0.020\n", " 0.022\n", " 0.024\n", " 0.026\n", " 0.028\n", " 0.030\n", " 0.032\n", " 0.034\n", " 0.036\n", " 0.038\n", " 0.040\n", " 0.042\n", " 0.044\n", " 0.046\n", " 0.048\n", " 0.050\n", " 0.052\n", " 0.054\n", " 0.056\n", " 0.058\n", " 0.060\n", " 0.062\n", " 0.064\n", " 0.066\n", " 0.068\n", " 0.070\n", " 0.072\n", " 0.074\n", " 0.076\n", " 0.078\n", " 0.080\n", " 0.082\n", " 0.084\n", " 0.086\n", " 0.088\n", " 0.090\n", " 0.092\n", " 0.094\n", " 0.096\n", " 0.098\n", " 0.100\n", " 0.102\n", " -0.10\n", " -0.05\n", " 0.00\n", " 0.05\n", " 0.10\n", " -0.050\n", " -0.045\n", " -0.040\n", " -0.035\n", " -0.030\n", " -0.025\n", " -0.020\n", " -0.015\n", " -0.010\n", " -0.005\n", " 0.000\n", " 0.005\n", " 0.010\n", " 0.015\n", " 0.020\n", " 0.025\n", " 0.030\n", " 0.035\n", " 0.040\n", " 0.045\n", " 0.050\n", " 0.055\n", " 0.060\n", " 0.065\n", " 0.070\n", " 0.075\n", " 0.080\n", " 0.085\n", " 0.090\n", " 0.095\n", " 0.100\n", " 0.105\n", " \n", " \n", " Fit\n", " \n", "\n", "\n", "\n", " \n", "\n", "\n", "\n", "\n" ], "text/plain": [ "Plot(...)" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ],
   "source": [
    "# Plot the relative reconstruction error for Huber and prescient regression,\n",
    "# zooming in on smaller values of p (least squares is not shown here).\n",
    "P_ZOOM_MAX = 0.08;  # largest p value included in the zoomed-in plot\n",
    "indices = find(p_vals .<= P_ZOOM_MAX);\n",
    "p_zoom = p_vals[indices];\n",
    "df = DataFrame(x=p_zoom, y=huber_data[indices], label=\"Huber\");\n",
    "df = vcat(df, DataFrame(x=p_zoom, y=prescient_data[indices], label=\"Prescient\"));\n",
    "plot(df, x=\"x\", y=\"y\", color=\"label\", Geom.line, Guide.XLabel(\"p\"), Guide.YLabel(\"Fit\"))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Julia 0.3.9",
   "language": "julia",
   "name": "julia-0.3"
  },
  "language_info": {
   "name": "julia",
   "version": "0.3.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}