{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# Generate data for SVM classifier with L1 regularization.\n", "srand(1);\n", "n = 20;\n", "m = 1000;\n", "TEST = m;\n", "DENSITY = 0.2;\n", "beta_true = randn(n,1);\n", "idxs = randperm(n)[1:int((1-DENSITY)*n)];\n", "for idx in idxs\n", " beta_true[idx] = 0;\n", "end\n", "offset = 0;\n", "sigma = 45;\n", "X = 5 * randn(m, n);\n", "Y = sign(X * beta_true + offset + sigma * randn(m,1));\n", "X_test = 5 * randn(TEST, n);" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Form SVM with L1 regularization problem.\n", "using Convex, SCS\n", "set_default_solver(SCSSolver(verbose=0));\n", "beta = Variable(n);\n", "v = Variable();\n", "loss = sum(pos(1 - Y .* (X*beta - v)));\n", "reg = norm(beta, 1);\n", "\n", "# Compute a trade-off curve and record train and test error.\n", "TRIALS = 100\n", "train_error = zeros(TRIALS);\n", "test_error = zeros(TRIALS);\n", "lambda_vals = logspace(-2, 0, TRIALS);\n", "beta_vals = zeros(length(beta), TRIALS);\n", "for i = 1:TRIALS\n", " lambda = lambda_vals[i];\n", " problem = minimize(loss/m + lambda*reg);\n", " solve!(problem);\n", " train_error[i] = sum(float(sign(X*beta_true + offset) .!= sign(evaluate(X*beta - v))))/m;\n", " test_error[i] = sum(float(sign(X_test*beta_true + offset) .!= sign(evaluate(X_test*beta - v))))/TEST;\n", " beta_vals[:, i] = evaluate(beta);\n", "end" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": false }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " λ\n", " \n", " \n", " 10-2.0\n", " 10-1.5\n", " 10-1.0\n", " 10-0.5\n", " 100.0\n", " \n", " \n", " \n", " Train error\n", " Test error\n", " \n", " \n", " \n", " \n", " \n", " \n", " label\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 10-1.1\n", " 10-1.0\n", " 10-0.9\n", " 10-0.8\n", " 10-0.7\n", " 10-0.6\n", " \n", " \n", " errors\n", " \n", "\n", "\n", "\n", " \n", "\n", "\n" ], "text/html": [ "\n", "\n", "\n", " \n", " λ\n", " \n", " \n", " 10-4.5\n", " 10-4.0\n", " 10-3.5\n", " 10-3.0\n", " 10-2.5\n", " 10-2.0\n", " 10-1.5\n", " 10-1.0\n", " 10-0.5\n", " 100.0\n", " 100.5\n", " 101.0\n", " 101.5\n", " 102.0\n", " 102.5\n", " 10-4.0\n", " 10-3.9\n", " 10-3.8\n", " 10-3.7\n", " 10-3.6\n", " 10-3.5\n", " 10-3.4\n", " 10-3.3\n", " 10-3.2\n", " 10-3.1\n", " 10-3.0\n", " 10-2.9\n", " 10-2.8\n", " 10-2.7\n", " 10-2.6\n", " 10-2.5\n", " 10-2.4\n", " 10-2.3\n", " 10-2.2\n", " 10-2.1\n", " 10-2.0\n", " 10-1.9\n", " 10-1.8\n", " 10-1.7\n", " 10-1.6\n", " 10-1.5\n", " 10-1.4\n", " 10-1.3\n", " 10-1.2\n", " 10-1.1\n", " 10-1.0\n", " 10-0.9\n", " 10-0.8\n", " 10-0.7\n", " 10-0.6\n", " 10-0.5\n", " 10-0.4\n", " 10-0.3\n", " 10-0.2\n", " 10-0.1\n", " 100.0\n", " 100.1\n", " 100.2\n", " 100.3\n", " 100.4\n", " 100.5\n", " 100.6\n", " 100.7\n", " 100.8\n", " 100.9\n", " 101.0\n", " 101.1\n", " 101.2\n", " 101.3\n", " 101.4\n", " 101.5\n", " 101.6\n", " 101.7\n", " 101.8\n", " 101.9\n", " 102.0\n", " 10-4\n", " 10-2\n", " 100\n", " 102\n", " 10-4.0\n", " 10-3.8\n", " 10-3.6\n", " 10-3.4\n", " 10-3.2\n", " 10-3.0\n", " 10-2.8\n", " 10-2.6\n", " 10-2.4\n", " 10-2.2\n", " 10-2.0\n", " 10-1.8\n", " 10-1.6\n", " 10-1.4\n", " 10-1.2\n", " 10-1.0\n", " 10-0.8\n", " 10-0.6\n", " 10-0.4\n", " 10-0.2\n", " 100.0\n", " 100.2\n", " 100.4\n", " 100.6\n", " 100.8\n", " 101.0\n", " 101.2\n", " 101.4\n", " 101.6\n", " 101.8\n", " 102.0\n", " \n", " \n", " \n", " Train error\n", " Test error\n", " \n", " \n", " \n", " \n", " \n", " \n", " label\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 10-1.7\n", " 10-1.6\n", " 10-1.5\n", " 10-1.4\n", " 10-1.3\n", " 10-1.2\n", " 10-1.1\n", " 10-1.0\n", " 10-0.9\n", " 10-0.8\n", " 10-0.7\n", " 10-0.6\n", " 10-0.5\n", " 10-0.4\n", " 10-0.3\n", " 10-0.2\n", " 10-0.1\n", " 100.0\n", " 10-1.60\n", " 10-1.58\n", " 10-1.56\n", " 10-1.54\n", " 10-1.52\n", " 10-1.50\n", " 10-1.48\n", " 10-1.46\n", " 10-1.44\n", " 10-1.42\n", " 10-1.40\n", " 10-1.38\n", " 10-1.36\n", " 10-1.34\n", " 10-1.32\n", " 10-1.30\n", " 10-1.28\n", " 10-1.26\n", " 10-1.24\n", " 10-1.22\n", " 10-1.20\n", " 10-1.18\n", " 10-1.16\n", " 10-1.14\n", " 10-1.12\n", " 10-1.10\n", " 10-1.08\n", " 10-1.06\n", " 10-1.04\n", " 10-1.02\n", " 10-1.00\n", " 10-0.98\n", " 10-0.96\n", " 10-0.94\n", " 10-0.92\n", " 10-0.90\n", " 10-0.88\n", " 10-0.86\n", " 10-0.84\n", " 10-0.82\n", " 10-0.80\n", " 10-0.78\n", " 10-0.76\n", " 10-0.74\n", " 10-0.72\n", " 10-0.70\n", " 10-0.68\n", " 10-0.66\n", " 10-0.64\n", " 10-0.62\n", " 10-0.60\n", " 10-0.58\n", " 10-0.56\n", " 10-0.54\n", " 10-0.52\n", " 10-0.50\n", " 10-0.48\n", " 10-0.46\n", " 10-0.44\n", " 10-0.42\n", " 10-0.40\n", " 10-0.38\n", " 10-0.36\n", " 10-0.34\n", " 10-0.32\n", " 10-0.30\n", " 10-0.28\n", " 10-0.26\n", " 10-0.24\n", " 10-0.22\n", " 10-0.20\n", " 10-0.18\n", " 10-0.16\n", " 10-0.14\n", " 10-0.12\n", " 10-0.10\n", " 10-0.08\n", " 10-0.06\n", " 10-0.04\n", " 10-0.02\n", " 100.00\n", " 10-2.0\n", " 10-1.5\n", " 10-1.0\n", " 10-0.5\n", " 100.0\n", " 10-1.60\n", " 10-1.55\n", " 10-1.50\n", " 10-1.45\n", " 10-1.40\n", " 10-1.35\n", " 10-1.30\n", " 10-1.25\n", " 10-1.20\n", " 10-1.15\n", " 10-1.10\n", " 10-1.05\n", " 10-1.00\n", " 10-0.95\n", " 10-0.90\n", " 10-0.85\n", " 10-0.80\n", " 10-0.75\n", " 10-0.70\n", " 10-0.65\n", " 10-0.60\n", " 10-0.55\n", " 10-0.50\n", " 10-0.45\n", " 10-0.40\n", " 10-0.35\n", " 10-0.30\n", " 10-0.25\n", " 10-0.20\n", " 10-0.15\n", " 10-0.10\n", " 10-0.05\n", " 100.00\n", " \n", " \n", " errors\n", " \n", "\n", "\n", "\n", " \n", "\n", "\n", "\n", "\n" ], "text/plain": [ "Plot(...)" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Plot the train and test error over the trade-off curve.\n", "using Gadfly, DataFrames\n", "df1 = DataFrame(λ=lambda_vals, errors=train_error, label=\"Train error\");\n", "df2 = DataFrame(λ=lambda_vals, errors=test_error, label=\"Test error\");\n", "df = vcat(df1, df2);\n", "\n", "plot(df, x=\"λ\", y=\"errors\", color=\"label\", Geom.line,\n", " Scale.x_log10, Scale.y_log10)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": false }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " λ\n", " \n", " \n", " 10-2.0\n", " 10-1.5\n", " 10-1.0\n", " 10-0.5\n", " 100.0\n", " \n", " \n", " \n", " beta11\n", " beta12\n", " beta13\n", " beta14\n", " beta15\n", " beta16\n", " beta17\n", " beta18\n", " beta19\n", " beta20\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " beta1\n", " beta2\n", " beta3\n", " beta4\n", " beta5\n", " beta6\n", " beta7\n", " beta8\n", " beta9\n", " beta10\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " label\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " -0.15\n", " -0.10\n", " -0.05\n", " 0.00\n", " 0.05\n", " 0.10\n", " \n", " \n", " betas\n", " \n", "\n", "\n", "\n", " \n", "\n", "\n" ], "text/html": [ "\n", "\n", "\n", " \n", " λ\n", " \n", " \n", " 10-4.5\n", " 10-4.0\n", " 10-3.5\n", " 10-3.0\n", " 10-2.5\n", " 10-2.0\n", " 10-1.5\n", " 10-1.0\n", " 10-0.5\n", " 100.0\n", " 100.5\n", " 101.0\n", " 101.5\n", " 102.0\n", " 102.5\n", " 10-4.0\n", " 10-3.9\n", " 10-3.8\n", " 10-3.7\n", " 10-3.6\n", " 10-3.5\n", " 10-3.4\n", " 10-3.3\n", " 10-3.2\n", " 10-3.1\n", " 10-3.0\n", " 10-2.9\n", " 10-2.8\n", " 10-2.7\n", " 10-2.6\n", " 10-2.5\n", " 10-2.4\n", " 10-2.3\n", " 10-2.2\n", " 10-2.1\n", " 10-2.0\n", " 10-1.9\n", " 10-1.8\n", " 10-1.7\n", " 10-1.6\n", " 10-1.5\n", " 10-1.4\n", " 10-1.3\n", " 10-1.2\n", " 10-1.1\n", " 10-1.0\n", " 10-0.9\n", " 10-0.8\n", " 10-0.7\n", " 10-0.6\n", " 10-0.5\n", " 10-0.4\n", " 10-0.3\n", " 10-0.2\n", " 10-0.1\n", " 100.0\n", " 100.1\n", " 100.2\n", " 100.3\n", " 100.4\n", " 100.5\n", " 100.6\n", " 100.7\n", " 100.8\n", " 100.9\n", " 101.0\n", " 101.1\n", " 101.2\n", " 101.3\n", " 101.4\n", " 101.5\n", " 101.6\n", " 101.7\n", " 101.8\n", " 101.9\n", " 102.0\n", " 10-4\n", " 10-2\n", " 100\n", " 102\n", " 10-4.0\n", " 10-3.8\n", " 10-3.6\n", " 10-3.4\n", " 10-3.2\n", " 10-3.0\n", " 10-2.8\n", " 10-2.6\n", " 10-2.4\n", " 10-2.2\n", " 10-2.0\n", " 10-1.8\n", " 10-1.6\n", " 10-1.4\n", " 10-1.2\n", " 10-1.0\n", " 10-0.8\n", " 10-0.6\n", " 10-0.4\n", " 10-0.2\n", " 100.0\n", " 100.2\n", " 100.4\n", " 100.6\n", " 100.8\n", " 101.0\n", " 101.2\n", " 101.4\n", " 101.6\n", " 101.8\n", " 102.0\n", " \n", " \n", " \n", " beta11\n", " beta12\n", " beta13\n", " beta14\n", " beta15\n", " beta16\n", " beta17\n", " beta18\n", " beta19\n", " beta20\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " beta1\n", " beta2\n", " beta3\n", " beta4\n", " beta5\n", " beta6\n", " beta7\n", " beta8\n", " beta9\n", " beta10\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " label\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " -0.45\n", " -0.40\n", " -0.35\n", " -0.30\n", " -0.25\n", " -0.20\n", " -0.15\n", " -0.10\n", " -0.05\n", " 0.00\n", " 0.05\n", " 0.10\n", " 0.15\n", " 0.20\n", " 0.25\n", " 0.30\n", " 0.35\n", " 0.40\n", " -0.40\n", " -0.39\n", " -0.38\n", " -0.37\n", " -0.36\n", " -0.35\n", " -0.34\n", " -0.33\n", " -0.32\n", " -0.31\n", " -0.30\n", " -0.29\n", " -0.28\n", " -0.27\n", " -0.26\n", " -0.25\n", " -0.24\n", " -0.23\n", " -0.22\n", " -0.21\n", " -0.20\n", " -0.19\n", " -0.18\n", " -0.17\n", " -0.16\n", " -0.15\n", " -0.14\n", " -0.13\n", " -0.12\n", " -0.11\n", " -0.10\n", " -0.09\n", " -0.08\n", " -0.07\n", " -0.06\n", " -0.05\n", " -0.04\n", " -0.03\n", " -0.02\n", " -0.01\n", " 0.00\n", " 0.01\n", " 0.02\n", " 0.03\n", " 0.04\n", " 0.05\n", " 0.06\n", " 0.07\n", " 0.08\n", " 0.09\n", " 0.10\n", " 0.11\n", " 0.12\n", " 0.13\n", " 0.14\n", " 0.15\n", " 0.16\n", " 0.17\n", " 0.18\n", " 0.19\n", " 0.20\n", " 0.21\n", " 0.22\n", " 0.23\n", " 0.24\n", " 0.25\n", " 0.26\n", " 0.27\n", " 0.28\n", " 0.29\n", " 0.30\n", " 0.31\n", " 0.32\n", " 0.33\n", " 0.34\n", " 0.35\n", " -0.4\n", " -0.2\n", " 0.0\n", " 0.2\n", " 0.4\n", " -0.40\n", " -0.38\n", " -0.36\n", " -0.34\n", " -0.32\n", " -0.30\n", " -0.28\n", " -0.26\n", " -0.24\n", " -0.22\n", " -0.20\n", " -0.18\n", " -0.16\n", " -0.14\n", " -0.12\n", " -0.10\n", " -0.08\n", " -0.06\n", " -0.04\n", " -0.02\n", " 0.00\n", " 0.02\n", " 0.04\n", " 0.06\n", " 0.08\n", " 0.10\n", " 0.12\n", " 0.14\n", " 0.16\n", " 0.18\n", " 0.20\n", " 0.22\n", " 0.24\n", " 0.26\n", " 0.28\n", " 0.30\n", " 0.32\n", " 0.34\n", " 0.36\n", " \n", " \n", " betas\n", " \n", "\n", "\n", "\n", " \n", "\n", "\n", "\n", "\n" ], "text/plain": [ "Plot(...)" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Plot the regularization path for beta.\n", "df = DataFrame(λ=lambda_vals, betas=vec(beta_vals[1,:]), label=\"beta1\")\n", "for i=2:n\n", " df = vcat(df, DataFrame(λ=lambda_vals, betas=vec(beta_vals[i,:]), label=string(\"beta\", i)));\n", "end\n", "plot(df, x=\"λ\", y=\"betas\", color=\"label\", Geom.line, Scale.x_log10)" ] } ], "metadata": { "kernelspec": { "display_name": "Julia 0.3.9", "language": "julia", "name": "julia-0.3" }, "language_info": { "name": "julia", "version": "0.3.9" } }, "nbformat": 4, "nbformat_minor": 0 }