{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Loss functions\n", "> In this chapter you will discover the conceptual framework behind logistic regression and SVMs. This will let you delve deeper into the inner workings of these models. This is a summary of the lecture \"Linear Classifiers in Python\" from DataCamp.\n", "\n", "- toc: true\n", "- badges: true\n", "- comments: true\n", "- author: Chanseok Kang\n", "- categories: [Python, Datacamp, Machine_Learning]\n", "- image: images/log_hinge.png" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Linear classifiers - the coefficients\n", "- Dot Products\n", "    - `x@y` computes the dot product of `x` and `y`, written $x \\cdot y$\n", "- Linear Classifier predictions\n", "    - raw model output = coefficients $\\cdot$ features + intercept\n", "    - Linear classifier prediction: compute the raw model output, then check its sign\n", "        - if positive, predict one class\n", "        - if negative, predict the other class\n", "    - This is the same for logistic regression and linear SVM\n", "        - `fit` is different but `predict` is the same" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Changing the model coefficients\n", "When you call `fit` with scikit-learn, the logistic regression coefficients are automatically learned from your dataset. In this exercise you will explore how the decision boundary is represented by the coefficients. To do so, you will change the coefficients manually (instead of with `fit`), and visualize the resulting classifiers."
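, "\n", "\n", "As a small worked illustration of that idea (the symbols $w_1, w_2, b$ and the example values below are notation chosen here, not taken from the course): with two features, the raw model output is $w_1 x_1 + w_2 x_2 + b$, and the decision boundary is the set of points where it equals zero:\n", "\n", "$$ w_1 x_1 + w_2 x_2 + b = 0 \\quad\\Longleftrightarrow\\quad x_2 = -\\frac{w_1}{w_2} x_1 - \\frac{b}{w_2} \\qquad (w_2 \\neq 0) $$\n", "\n", "For example, $w = (0, 1)$ and $b = 0$ give the horizontal line $x_2 = 0$: points with $x_2 > 0$ are assigned to one class and points with $x_2 < 0$ to the other. Changing the direction of $w$ rotates the boundary, and changing $b$ shifts it."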
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "#hide\n", "X = np.array([[ 1.78862847, 0.43650985],\n", "              [ 0.09649747, -1.8634927 ],\n", "              [-0.2773882 , -0.35475898],\n", "              [-3.08274148, 2.37299932],\n", "              [-3.04381817, 2.52278197],\n", "              [-1.31386475, 0.88462238],\n", "              [-2.11868196, 4.70957306],\n", "              [-2.94996636, 2.59532259],\n", "              [-3.54535995, 1.45352268],\n", "              [ 0.98236743, -1.10106763],\n", "              [-1.18504653, -0.2056499 ],\n", "              [-1.51385164, 3.23671627],\n", "              [-4.02378514, 2.2870068 ],\n", "              [ 0.62524497, -0.16051336],\n", "              [-3.76883635, 2.76996928],\n", "              [ 0.74505627, 1.97611078],\n", "              [-1.24412333, -0.62641691],\n", "              [-0.80376609, -2.41908317],\n", "              [-0.92379202, -1.02387576],\n", "              [ 1.12397796, -0.13191423]])\n", "\n", "y = np.array([-1, -1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, -1, -1,\n", "              -1, -1, -1])" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "def make_meshgrid(x, y, h=.02, lims=None):\n", "    \"\"\"Create a mesh of points to plot in\n", "\n", "    Parameters\n", "    ----------\n", "    x: data to base x-axis meshgrid on\n", "    y: data to base y-axis meshgrid on\n", "    h: stepsize for meshgrid, optional\n", "    lims: (x_min, x_max, y_min, y_max) limits for the mesh, optional\n", "\n", "    Returns\n", "    -------\n", "    xx, yy : ndarray\n", "    \"\"\"\n", "\n", "    if lims is None:\n", "        x_min, x_max = x.min() - 1, x.max() + 1\n", "        y_min, y_max = y.min() - 1, y.max() + 1\n", "    else:\n", "        x_min, x_max, y_min, y_max = lims\n", "    xx, yy = np.meshgrid(np.arange(x_min, x_max, h),\n", "                         np.arange(y_min, y_max, h))\n", "    return xx, yy\n", "\n", "def plot_contours(ax, clf, xx, yy, proba=False, **params):\n", "    \"\"\"Plot the decision boundaries for a classifier.\n", "\n", "    Parameters\n", "    ----------\n", "    ax: matplotlib axes object\n", "    clf: a classifier\n", "    xx: meshgrid ndarray\n", "    yy: meshgrid ndarray\n", "    params: dictionary of params to pass to contourf, optional\n", "    \"\"\"\n", "    if proba:\n", "        Z = 
clf.predict_proba(np.c_[xx.ravel(), yy.ravel()])[:,-1]\n", "        Z = Z.reshape(xx.shape)\n", "        out = ax.imshow(Z,extent=(np.min(xx), np.max(xx), np.min(yy), np.max(yy)),\n", "                        origin='lower', vmin=0, vmax=1, **params)\n", "        ax.contour(xx, yy, Z, levels=[0.5])\n", "    else:\n", "        Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])\n", "        Z = Z.reshape(xx.shape)\n", "        out = ax.contourf(xx, yy, Z, **params)\n", "    return out\n", "\n", "def plot_classifier(X, y, clf, ax=None, ticks=False, proba=False, lims=None):\n", "    # assumes classifier \"clf\" is already fit\n", "    X0, X1 = X[:, 0], X[:, 1]\n", "    xx, yy = make_meshgrid(X0, X1, lims=lims)\n", "\n", "    if ax is None:\n", "        plt.figure()\n", "        ax = plt.gca()\n", "        show = True\n", "    else:\n", "        show = False\n", "\n", "    # can abstract some of this into a higher-level function for learners to call\n", "    cs = plot_contours(ax, clf, xx, yy, cmap=plt.cm.coolwarm, alpha=0.8, proba=proba)\n", "    if proba:\n", "        cbar = plt.colorbar(cs)\n", "        # raw string so that the LaTeX backslash is not read as a Python escape\n", "        cbar.ax.set_ylabel(r'probability of red $\\Delta$ class', fontsize=20, rotation=270, labelpad=30)\n", "        cbar.ax.tick_params(labelsize=14)\n", "    #ax.scatter(X0, X1, c=y, cmap=plt.cm.coolwarm, s=30, edgecolors='k', linewidth=1)\n", "    labels = np.unique(y)\n", "    if len(labels) == 2:\n", "        ax.scatter(X0[y==labels[0]], X1[y==labels[0]], cmap=plt.cm.coolwarm,\n", "                   s=60, c='b', marker='o', edgecolors='k')\n", "        ax.scatter(X0[y==labels[1]], X1[y==labels[1]], cmap=plt.cm.coolwarm,\n", "                   s=60, c='r', marker='^', edgecolors='k')\n", "    else:\n", "        ax.scatter(X0, X1, c=y, cmap=plt.cm.coolwarm, s=50, edgecolors='k', linewidth=1)\n", "\n", "    ax.set_xlim(xx.min(), xx.max())\n", "    ax.set_ylim(yy.min(), yy.max())\n", "    # ax.set_xlabel(data.feature_names[0])\n", "    # ax.set_ylabel(data.feature_names[1])\n", "    if ticks:\n", "        ax.set_xticks(())\n", "        ax.set_yticks(())\n", "    # ax.set_title(title)\n", "    if show:\n", "        plt.show()\n", "    else:\n", "        return ax" ] }, { "cell_type": "code", "execution_count": 4, "metadata": 
{}, "outputs": [ { "data": { "text/plain": [ "LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,\n", " intercept_scaling=1, l1_ratio=None, max_iter=100,\n", " multi_class='auto', n_jobs=None, penalty='l2',\n", " random_state=None, solver='lbfgs', tol=0.0001, verbose=0,\n", " warm_start=False)" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sklearn.linear_model import LogisticRegression\n", "\n", "model = LogisticRegression()\n", "model.fit(X, y)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXIAAAD4CAYAAADxeG0DAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAASZklEQVR4nO3dfYxVdX7H8c93GBRkVtdkFTswERAqa3HWB4JPpG5Au8AI1m5stVZJ9g9rU5M1alyRKEusJVt30U1304ZWs7iOuhpWV7eQgm59FgSpDhcQZXkQGVaU4o48DjP32z/mjlyHeWLuOfec3z3vVzIJ956bc75zmfnM7/zu78HcXQCAcFUlXQAAoDQEOQAEjiAHgMAR5AAQOIIcAAJXncRFTx0yxGtrapK4NAAEa8OePZ+5+2ldn08kyGtravT01VclcWkACNaERxdv7+55ulYAIHAEOQAEjiAHgMAR5AAQOIIcAAJHkANA4AhyAAgcQQ4AgSPIASBwBDkABI4gB4DAEeQAEDiCHAACR5ADQOAIcgAIHEEOAIEjyAEgcAQ5AASOIEfsWg636pZly9VyuDXpUoCKRJAjdo25nFbv2qXG3PqkSwEqEkGOWLUcbtUT6zfqOUlPrt9AqxyIAUGOWDXmcprprmmSGtxplQMxIMgRm87W+L3t7ZKk+9rbaZUDMSDIEZvO1vjYwuOxolUOxIEgRyy6tsY70SoHokeQIxaNuZwuyec1SNLWoq9Bki7O52mVAxGqTroAVKZtez/XusGDdXkPx23v3rLWA1Qyghyx+NEVU5IuAcgMulYAIHAEOQAEjiAHgMAR5AAQOIIcAAJHkANA4AhyAAgcQQ4AgYssyM1skJn9r5n9NqpzAgD6FmWL/PuSNkZ4PgBAP0QS5GY2UlKDpP+M4nwAgP6LqkX+sKS7JOV7eoGZ3Wxma8xszd5DhyK6LACg5CA3s6sk7Xb3d3p7nbsvcveJ7j7x1CFDSr0sAKAgihb5ZZJmmdk2SU9JmmJmj0dwXgBAP5Qc5O4+x91HuvsoSddJ+p27/13JlQEA+oVx5AAQuEg3lnD3lyW9HOU5AQC9o0UOAIEjyAEgcAQ5AASOIAeAwBHkABA4ghwAAkeQA0DgCHIACBxBjorRcrhVtyxbrpbDrUmXApQVQY6K0ZjLafWuXWrMrU+6FKCsCHJUhJbDrXpi/UY9J+nJ9RtolSNTCHJUhMZcTjPdNU1SgzutcmQKQY7gdbbG721vlyTd195OqxyZQpAjeJ2t8bGFx2NFqxzZQpAjaF1b451olSNLCHIErTGX0yX5vAZJ2lr0NUjSxfk8rXJkQqQbSwDltm3v51o3eLAu7+G
47d1b1nqAJBDkCNqPrpiSdAlA4uhaAYDAEeQRYXo4gKQQ5BFhejiApBDkEWB6OIAkEeQRYHo4gCQR5CVK4/Rw+uuBbCHIS5TG6eH01wPZQpCXIC3Tw4tb4PTXA9lDkJcgLdPDi1vg9NcD2cPMzhKkYXp4cQv8+tx6maS3i/rrL1q/QTdM+DOdfOIJsdcCIBkEeQnSMD28uAU+Op/XN6Vu++v/4cLzkysSQKzoWgnYzi/26RdNOd3b3q7PJW1313z3r7wmDaNoAMSLIA/YvFdeVUNhxMxPJV0qJd5fD6D86FrpRcvhVt31u5f1L1O+nbo+5p1f7NN7uz/VLwuPN0nKSbpc0ueSTjnxRFnR61nOFahcBHkvikeDpK2Ped4rr+rPdbQF/kDRsVurqjR6/NmpqxlAPAjyHhSPBrkhZSM/Wg63asOnn+kUSVMktevYVjgtcCA7CPIedDceOy0t3MZcTn9ZVaXFRRORbho0SCfTCgcyiQ87u5HG9VM6pWU2KYD0IMi7kcb1UzqlZTYpgPSga6WLzhbv2920eNMwSzINs0kBpEvJQW5mdZIek3SGpLykRe7+01LPm5SuLd5OxS3eJPuh0zCbFOgq766lW7bqsdwOfbJ/v4YPG6abJtRpxpjRqjLr+wQoSRQt8jZJd7j7WjP7mqR3zGyFu2+I4NxlR4sXOD55d9320kq91TxEB9t+LKleew41af4b87Vi6yo9NPUiwjxmJQe5u++StKvw7y/MbKOkEZKCDHJavMDxWbplayHEV0kaUnh2nA62NejN5klatmWrGs4ak2SJFS/SDzvNbJSk8yWtivK8ANLrsdwOHWy7T0dDvNMQHWybp8W5HUmUlSmRBbmZ1UhaIuk2d2/p5vjNZrbGzNbsPXQoqssCSNgn+/dLqu/h6LmF44hTJEFuZoPVEeKN7v7r7l7j7ovcfaK7Tzx1SNe/3ABCNXzYMElNPRxdVziOOJUc5GZmkh6RtNHdF5ZeEoCQ3DShTkOr50vqeqd9SEOr52v2hLokysqUKFrkl0m6UdIUM3u38DUjgvMCCMCMMaN1Se1hDa2epI4b8w8kLdHQ6km6tLZV08eMTrjCyhfFqJXXJTG2CMioKjM9PPViLduyVYtzd345jnz2hDpNZxx5WTCzE0DJqszUcNYYhhkmhLVWACBwBDkABI4gB4DAEeRl1HK4VbcsW86a4QAiRZCXUfEeoAAQFYK8TIr3AGUnHwBRIsjLpLs9QAEgCgR5GaR5D1AA4SPIyyDNe4ACCB9BHjN2vQcQN4I8Zux6DyBurLUSM/YABRA3gjxm7AEKIG50rQBA4AhyAAgcQQ4AgaOPHBUv766lW7bqsdyOL3evuWlCnWawew0qBEGOipZ3120vrdRbzUN0sO3Hkuq151CT5r8xXyu2rtJDUy8izBE8ulZQ0ZZu2VoI8VWSvitpnKTv6mDb23qz+QQt27I14QqB0hHkqGiP5XboYNt9koZ0OTJEB9vmaXFuRxJlAZEiyFHRPtm/X1J9D0fPLRwHwkaQo6INHzZMUlMPR9cVjgNhI8hR0W6aUKeh1fMlHepy5JCGVs/X7Al1SZQFRIogR0WbMWa0Lqk9rKHVkyQtkfSBpCUaWj1Jl9a2avqY0QlXCJSO4YeoaFVmenjqxVq2ZasW5+78chz57Al1ms44clQIghwVr8pMDWeNUcNZY5IuBYgFXSsAEDha5AAyp9KWbSDIAWRKJS7bQNcKgEypxGUbCHIAmVKJyzYQ5AAypRKXbSDIAWRKJS7bQJADyJRKXLaBIAeQKZW4bAPDDwFkSiUu20CQA8icSlu2ga4VAAhcJEFuZtPMbJOZbTazu6M4JwCgf0oOcjMbJOnnkqZLOkfS9WZ2TqnnBQD0TxR95JMkbXb3LZJkZk9JulrShgjODSADKm0Rq3KLIshHSCqe0/qxpIu6vsjMbpZ0syQNP+kktexvi+DSAMop764V27fpVx8065MD+zT8pBr9zZ/W6sozRw0
4cPPuuueN1Vr9h6E62H50Easfvj5fSzev1D9fNpEw70MUQd7dO+zHPOG+SNIiSRo+st6f+dY/RXBpYOA8n9f7Tc9r7evPaF9Ls2pOrtUFk6/V+PpZsirGAXTl+bxeaLxdH/2+VUdaOwJ37+EmPbB2gR7/4oBm/u1PBvS+bXz3Ob316Qc60v6ajq5/Mk6H2hv05u7JekCTNP5bV0f5rYTrV43dPh3FT+vHkoqnQo2U1BzBeYHYdIbSS881anfz7Tqwb7l2N9+uF599XC88cYc8n0+6xNR5v+l5ffT7nTrS+pqKVw1sO/K6tn+4Q5uaXhjQede+/oyOtN6t7haxajsyR++88XRphWdAFEG+WtI4MxttZidIuk7S8xGcF4hNXKFUyeIK3H0tzeptEat9f6Rd2JeSg9zd2yTdKum/JW2U9LS7ry/1vECcaAUev7gCt+bkWvW2iFXNKbUDOm+WRDKz092XSloaxbmAckh7KzCN/fc1J9fqwL4mddy9dDXwwL1g8rV68dkFajvSoK/+YT2k6sELdOFlNw7ovFnCJzrIpDS3AtPaf3/B5GtVPXiBuls1sCNw/3pA5x1fP0tnjh2p6sGTVbyIVfXgyTpzXJ3Orp9ZWuEZQJAjk+IKpSiktf8+rsC1qirNvGGhrrzmRp0+YqFOqvmOTh/R8XigI2GyhkWzkEnj62fpw3Uva/vmyWo7MkfSuZLWqXrwgsRbgX333y/U+PPKPxyvM3A3Nb2gd95YqH1/bFbNKbW68LIbdXb9zJIC16qqNP68qxP5vioBQY5MijOUSpXm/nsCN50IcmRWWkMprg8VUbkIcmRKGkeDdMUoDhyvdPzkAmWQ1tEgXTGKA8eLFjky46ujQY6u6dF2pEHbP5ysTU0vpKKbJc3990gnghyZkdbRIN1Ja/890ok/7ciMNI8GAUpBkCMz0jybEygFQY7MSPNsTqAUBDkyg9EgqFR82InMYDQIKhVBjkxhNAgqEUEO9EMIM0KRXQQ50Iejmw7vLIxDr9eBfU168dkF+jD3CkutInEEOdCHUGaEhoA7m3jwzgF9YH/PaISy1k2ICHKgD8wIjUZadz6qBAQ50AdmhEaDO5v4EORAH5gRGg3ubOJDkAN9YEZoNLiziQ+jVoA+MCM0Gux8FB+CHOgHZoSWbnz9LH247mVt3zxZbUfmSDpX0jpVD17AnU2JCHIAZcGdTXwIcgBlw51NPPgTCACBI8gBIHAEOQAEjiAHgMAR5AAQOIIcAAJHkANA4AhyAAgcE4KACsROPNlCkAMVhj1Gs4f/TaDCsBNP9hDkQIVhJ57sKSnIzexBM3vfzJrM7Fkz+3pUhQEYGHbiyZ5SW+QrJE1w93p1bJsyp/SSAJSCnXiyp6Qgd/fl7t5WeLhS0sjSSwJQCvYYzZ4o+8i/J2lZTwfN7GYzW2Nmaw7u3xPhZQEUY4/R7Olz+KGZvSjpjG4OzXX33xReM1dSm6TGns7j7oskLZKk4SPrfUDVAugTO/FkT59B7u5X9HbczGZLukrSVHcnoIEUYCeebClpQpCZTZP0A0mXu/uBaEoCAByPUu+xfibpa5JWmNm7ZvbvEdQEADgOJbXI3X1sVIUAAAaGtVaAAWJhKqQFQQ4MAAtTIU34SQMGgIWpkCYEOTAALEyFNCHIgQFgYSqkCUEODAALUyFNCHJgAFiYCmlCkAMDwMJUSBOGHwIDwMJUSBOCHBggFqZCWtBsAIDAEeQAEDiCHAACR5ADQOAIcgAIHEEOAIEjyAEgcAQ5AASOIAeAwBHkABA4ghwAAkeQA0DgCHIACBxBDgCBI8gBIHAEOQAEjiAHgMAR5AAQOHP38l/U7FNJ+yV9VvaLD8w3RK1xoNZ4UGv00lLnme5+WtcnEwlySTKzNe4+MZGLHydqjQe1xoNao5f2OulaAYDAEeQAELgkg3xRgtc+XtQaD2qNB7VGL9V1JtZHDgCIBl0rABA4ghwAApdokJvZD81sp5m9W/iakWQ9/WF
md5qZm9k3kq6lJ2Z2v5k1Fd7T5WZWm3RN3TGzB83s/UKtz5rZ15OuqSdmdq2ZrTezvJmlchiamU0zs01mttnM7k66np6Y2aNmttvMcknX0hczqzOz/zGzjYX//+8nXVN30tAif8jdzyt8LU26mN6YWZ2kKyV9lHQtfXjQ3evd/TxJv5V0X9IF9WCFpAnuXi/pA0lzEq6nNzlJfyXp1aQL6Y6ZDZL0c0nTJZ0j6XozOyfZqnr0C0nTki6in9ok3eHu35R0saR/TOP7moYgD8lDku6SlOpPiN29pejhMKW0Xndf7u5thYcrJY1Msp7euPtGd9+UdB29mCRps7tvcfdWSU9Jujrhmrrl7q9K+r+k6+gPd9/l7msL//5C0kZJI5Kt6lhpCPJbC7fWj5rZqUkX0xMzmyVpp7u/l3Qt/WFmD5jZDkk3KL0t8mLfk7Qs6SICNkLSjqLHHyuFgRMyMxsl6XxJq5Kt5FjVcV/AzF6UdEY3h+ZK+jdJ96ujxXi/pJ+o4xc6EX3Ueo+kvyhvRT3rrVZ3/427z5U018zmSLpV0ryyFljQV52F18xVxy1sYzlr66o/taaYdfNcKu/EQmRmNZKWSLqtyx1vKsQe5O5+RX9eZ2b/oY7+3MT0VKuZnStptKT3zEzq6AJYa2aT3P0PZSzxS/19XyU9Iem/lFCQ91Wnmc2WdJWkqZ7wpIbjeE/T6GNJdUWPR0pqTqiWimJmg9UR4o3u/uuk6+lO0qNW/qTo4TXq+EApddx9nbuf7u6j3H2UOn5pLkgqxPtiZuOKHs6S9H5StfTGzKZJ+oGkWe5+IOl6Arda0jgzG21mJ0i6TtLzCdcUPOtouT0iaaO7L0y6np4kOrPTzH4p6Tx13AJuk/T37r4rsYL6ycy2SZro7mlY1vIYZrZE0tmS8pK2S7rF3XcmW9WxzGyzpBMl7Sk8tdLdb0mwpB6Z2TWS/lXSaZI+l/Suu38n2aq+qjB892FJgyQ96u4PJFxSt8zsSUnfVsfSsJ9ImufujyRaVA/MbLKk1yStU8fvkyTdk7YRdkzRB4DApWHUCgCgBAQ5AASOIAeAwBHkABA4ghwAAkeQA0DgCHIACNz/A9yoyewOon3mAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Number of errors: 3\n" ] } ], "source": [ "# Set the coefficients\n", "model.coef_ = np.array([[0,1]])\n", "model.intercept_ = np.array([0])\n", "\n", "# Plot the data and decision boundary\n", "plot_classifier(X,y,model)\n", "\n", "# Print the number of errors\n", "num_err = np.sum(y != model.predict(X))\n", "print(\"Number of errors:\", num_err)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## What is a loss function?\n", "- Least squares: the squared loss\n", " - scikit-learn's `LinearRegression` minimizes a loss:\n", " $$ \\sum_{i=1}^{n}(\\text{true ith target value - predicted ith target value})^2 $$\n", " - Minimization is with respect to coefficients or parameters of the model.\n", "- Classification errors: the 0-1 loss\n", " - Squared loss not appropriate for classification problems\n", " - A natrual loss for classification problem is the number of errors\n", " - This is the **0-1 loss**: it's 0 for a correct prediction and 1 for an incorrect prediction\n", " - But this loss is hard to minimize" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Minimizing a loss function\n", "In this exercise you'll implement linear regression \"from scratch\" using `scipy.optimize.minimize`.\n", "\n", "We'll train a model on the Boston housing price data set." 
] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "X = pd.read_csv('./dataset/boston_X.csv').to_numpy()\n", "y = pd.read_csv('./dataset/boston_y.csv').to_numpy()" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[-9.16298525e-02 4.86755134e-02 -3.77647962e-03 2.85635806e+00\n", " -2.88074603e+00 5.92522231e+00 -7.22459484e-03 -9.67997914e-01\n", " 1.70448274e-01 -9.38966357e-03 -3.92421957e-01 1.49830960e-02\n", " -4.16972109e-01]\n", "[[-9.16297843e-02 4.86751203e-02 -3.77930006e-03 2.85636751e+00\n", " -2.88077933e+00 5.92521432e+00 -7.22447929e-03 -9.67995240e-01\n", " 1.70443393e-01 -9.38925373e-03 -3.92425680e-01 1.49832102e-02\n", " -4.16972624e-01]]\n" ] } ], "source": [ "from scipy.optimize import minimize\n", "from sklearn.linear_model import LinearRegression\n", "\n", "# The squared error, summed over training examples\n", "def my_loss(w):\n", "    s = 0\n", "    for i in range(y.size):\n", "        # Get the true and predicted target values for example 'i'\n", "        y_i_true = y[i]\n", "        y_i_pred = w@X[i]\n", "        s = s + (y_i_true - y_i_pred) ** 2\n", "    return s\n", "\n", "# Find the w that makes my_loss(w) smallest, using X[0] as an arbitrary starting guess\n", "w_fit = minimize(my_loss, X[0]).x\n", "print(w_fit)\n", "\n", "# Compare with scikit-learn's LinearRegression coefficients\n", "lr = LinearRegression(fit_intercept=False).fit(X, y)\n", "print(lr.coef_)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Loss function diagrams" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Comparing the logistic and hinge losses\n", "In this exercise you'll create a plot of the logistic and hinge losses using their mathematical expressions, which are provided to you."
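, "\n", "\n", "For reference, the standard forms of these two losses as functions of the raw model output $z$ are given below. The notation is assumed here rather than taken from the exercise, with the convention that a large positive $z$ corresponds to a confident correct prediction:\n", "\n", "$$ L_{\\text{logistic}}(z) = \\log(1 + e^{-z}), \\qquad L_{\\text{hinge}}(z) = \\max(0,\\, 1 - z) $$\n", "\n", "Both losses decrease as $z$ grows. The hinge loss is exactly zero once $z \\ge 1$, while the logistic loss only approaches zero."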
] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAgAElEQVR4nO3dd3wU1frH8c9JT0iDFAIkIaFFegsdKdKbCkgXQZrYuZbfVa9Xr+Xau4iKNBGEIKIC0pFeDRB676GGQEILkHJ+f8zixZCQbNhkdjfP+/XKi012due7Az5Ozpx5jtJaI4QQwvG5mB1ACCGEbUhBF0IIJyEFXQghnIQUdCGEcBJS0IUQwkm4mbXj4OBgHRUVZdbuhcjV3r17AYiJiTE5iRC327Rp0zmtdUhOz5lW0KOiooiPjzdr90LkqlWrVgAsX77c1BxC5EQpdTS352TIRQghnIRpZ+hC2KtXX33V7AhCFIgUdCGyadu2rdkRhCgQKehCZJOQkABAnTp1TE7i/NLT00lMTOTatWtmR7E7Xl5ehIeH4+7unu/XSEEXIptRo0YBclG0KCQmJuLn50dUVBRKKbPj2A2tNcnJySQmJhIdHZ3v1+V5UVQp5aWU2qiU2qqU2qmUeiOHbTyVUnFKqQNKqQ1KqSir0gshiqVr164RFBQkxTwbpRRBQUFW/+aSn1ku14H7tNa1gTpAR6VU42zbDAUuaK0rAZ8C71uVQghRbEkxz1lBjkueBV0bLlu+dbd8Ze+5+wDwveXxTKCNKqy/pSvJMP8lSE8rlLcXQghHla956EopV6VUAnAWWKy13pBtk3LAcQCtdQaQCgTl8D4jlFLxSqn4pKSkgiU+vBw2fAPfd4PLBXwPIYSw8PX1LfBrhw0bxq5du3J9ftKkSZw8eTLf29+tfF0U1VpnAnWUUoHAL0qpGlrrHbdsktPZ+G0rZ2itxwJjAWJjYwu2skaNnuDiDrOGw/i2MGAmBFcu0FsJkZN33nnH7AjCQYwbN+6Oz0+aNIkaNWpQtmzZfG1/t6y6U1RrnQIsBzpmeyoRiABQSrkBAcB5G+TLWbX7YfDvcP0yjGsLR9YU2q5E8dO0aVOaNm1qdgxRxLTWvPjii9SoUYOaNWsSFxcHQFZWFk888QTVq1ena9eudO7cmZkzZwJGm4j4+HgyMzMZPHjwX6/99NNPmTlzJvHx8QwYMIA6deqQlpb21/YACxYsoF69etSuXZs2bdrY5DPkeYaulAoB0rXWKUopb6Att1/0nA0MAtYBDwF/6MJe2y48FoYtgam94IcH4YExUKtXoe5SFA9r164FkKJexN6Ys5NdJy/a9D2rlfXn9W7V87XtrFmzSEhIYOvWrZw7d44GDRrQokUL1qxZw5EjR9i+fTtnz56latWqDBky5G+vTUhI4MSJE+zYYQxcpKSkEBgYyOjRo/noo4+IjY392/ZJSUkMHz6clStXEh0dzfnztjn/zc8ZehlgmVJqG/Anxhj6XKXUm0qp+y3bjAeClFIHgOeAl2ySLi+lomHoIghvCLOGwYoPQdZIFXfplVde4ZVXXjE7hihiq1evpl+/fri6ulK6dGlatmzJn3/+yerVq+nVqxcuLi6EhYXRunXr215boUIFDh06xNNPP82CBQvw9/e/477Wr19PixYt/ppjXqpUKZt8hjzP0LXW24C6Ofz8tVseXwPMOT32KQUDZ8Hsp2HZ25ByBLp+Bq75v7tKCGG+/J5JF5bcBhXyM9hQsmRJtm7dysKFC/nqq6+YMWMGEyZMuOO+CmMioHN0W3TzhO7fQst/wpYpMPUhuJZqdiohhANp0aIFcXFxZGZmkpSUxMqVK2nYsCHNmzfn559/JisrizNnzuR4B/G5c+f
IysqiZ8+evPXWW2zevBkAPz8/Ll26dNv2TZo0YcWKFRw+fBjAZkMuznPrv1LQ+hUILA9znoHxHWDADAiMNDuZEMIBdO/enXXr1lG7dm2UUnzwwQeEhYXRs2dPli5dSo0aNahSpQqNGjUiICDgb689ceIEjz76KFlZWQC8++67AAwePJiRI0fi7e3NunXr/to+JCSEsWPH0qNHD7KysggNDWXx4sV3/RlUYV+7zE1sbKwutAUuDi2HuEfA3Qv6x0HZ20aMhMiVLHBRdHbv3k3VqlXNjpGny5cv4+vrS3JyMg0bNmTNmjWEhYUV+n5zOj5KqU1a69ictneeM/RbVWgFQxcaM2AmdoaHJkBMJ7NTCQfx2WefmR1B2JmuXbuSkpLCjRs3+Pe//10kxbwgnLOgA4RWhWFLYVofmN4fOr4PjUaYnUo4AGmbK7JzlN/WnOOiaG78Shs3IFXpCPNfhAWvQFam2amEnVuyZAlLliwxO4YQVnPeM/SbPEpAnymw8BVY/xWkHIUe34GHj9nJhJ16++23AVm5SDge5z5Dv8nFFTq9Dx3fgz2/w/dd4fJZs1MJIYRNFY+CflPjx6HvVDizC8a1gaS9ZicSQgibKV4FHeCeLvDo70Y/9fHt4PAqsxMJIUx05MgRatSocdvPX3vtNYe7llL8CjpAufrGDBjfMPihO2ydbnYiIYSdefPNNx3uOkrxLOgAJcsbjb0iG8Mvj8Hy96WxlwDg22+/5dtvvzU7hihCmZmZDB8+nOrVq9O+fXvS0tIYPHjwX21yo6KieP3116lXrx41a9Zkz549gNE1sV27dtSrV4/HHnuM8uXLc+7cOQCmTJlCw4YNqVOnDo899hiZmYU/w875Z7nciXcgPDzLaBWw/B24cAS6fQ5uHmYnEyaKiYkxO0LxNP8lOL3dtu8ZVhM6vZfnZvv372fatGl899139O7dm59//vm2bYKDg9m8eTNjxozho48+Yty4cbzxxhvcd999vPzyyyxYsICxY8cCxh2ecXFxrFmzBnd3d5544gmmTp3KI488YtvPl03xLuhgFO8Hv4aS0UZRTz1uTHP0DjQ7mTDJnDlzAOjWrZvJSURRiY6O/uuGsvr163PkyJHbtunRo8dfz8+aNQswWu7+8ssvAHTs2JGSJUsCsHTpUjZt2kSDBg0ASEtLIzQ0tLA/hhR0wGjs1eqfRiOv2U/DhA7Qf4YxLCOKnY8//hiQgl7k8nEmXVg8PT3/euzq6kpa2u2L0N/cxtXVlYyMDODOLXcHDRr0V5OuolJ8x9BzUqcfDPwFLp0ylrY7sdnsREIIO9a8eXNmzJgBwKJFi7hw4QIAbdq0YebMmZw9a9zvcv78eY4ePVroeaSgZxd9LwxdbHRqnNTFuBFJCCFy8Prrr7No0SLq1avH/PnzKVOmDH5+flSrVo23336b9u3bU6tWLdq1a8epU6cKPY9zts+1hctn4cc+cHILdHzXuClJFAvSPrfoOEr73Nxcv34dV1dX3NzcWLduHY8//jgJCQk2e39pn2srvqFGY69Zw2HBS8YMmA7vGG0EhBACOHbsGL179yYrKwsPDw++++47U/NIQb8TDx/oPRkW/dvS2Os49PzOaPglnNYPP/xgdgThICpXrsyWLVvMjvEXGUPPi4srdHwHOn0I++Yb4+qXzpidShSiiIgIIiIizI5RbJg17GvvCnJcpKDnV6MR0PdHo6HXuLZwdrfZiUQhiYuLIy4uzuwYxYKXlxfJyclS1LPRWpOcnIyXl5dVr5OLotY6ucW4WJp+Dfr8ABVamp1I2JhcFC066enpJCYmcu3aNbOj2B0vLy/Cw8Nxd3f/28/loqgtla0Lw5bA1N4wpQfc/yXU6W92KiEckru7O9HR0WbHcBp5DrkopSKUUsuUUruVUjuVUs/msE0rpVSqUirB8vVa4cS1E4GRxiLU5ZvBr4/DsneksZcQwnT5OUPPAJ7XWm9WSvkBm5RSi7XWu7Jtt0pr3dX2Ee2
UVwAMmAlz/wEr3ocLR+H+L8DNM+/XCiFEIcizoGutTwGnLI8vKaV2A+WA7AW9+HHzgAdGQ6ko+ONtuHjCGFf3Lml2MiFEMWTVRVGlVBSwEqihtb54y89bAT8DicBJ4AWt9c4cXj8CGAEQGRlZvyh6GxSZbTPgtyehZBQM+Mn4Uzikm/2sg4ODTU4ixO3udFE03wVdKeULrAD+q7Wele05fyBLa31ZKdUZ+FxrXflO7+ews1zu5MhqmD4AXN2h33QIz/GYCyFEgd2poOdrHrpSyh3jDHxq9mIOoLW+qLW+bHk8D3BXShW/05uo5sYMGI8Sxg1Iu+eYnUgUwKRJk5g0aZLZMYSwWn5muShgPLBba/1JLtuEWbZDKdXQ8r7JtgzqMIIrw9AlULoGxA2EdV/JDBgHIwVdOKr8zHJpBgwEtiulbrYRewWIBNBafwM8BDyulMoA0oC+ujjf+uUbAoPnwqwRsPAVOH8YOr4HrjLtXwhRePIzy2U1oPLYZjQw2lahnIK7N/T6Hpa8Bmu/NJa26zkePH3NTiaEcFLSy6UwubhA+7eh80ewfxFM6gyXTpudSgjhpKSgF4WGw41ZL+cOwHdt4IxM4RdC2J4U9KJSpQMMmQ9ZGcYi1AeXmZ1I5GLevHnMmzfP7BhCWE0KelEqUxuGL4WACJj6EGyWhRTskY+PDz4+PmbHEMJqUtCLWkA4DFkA0S1g9lOw9C2Z1mhnxowZw5gxY8yOIYTVpKCbwcsf+s+Aeo/Aqo+MdUszrpudSljMmDGDGTNmmB1DCKvJxGizuLpDty+Mni9L34TUE9B3KviUMjuZEMJByRm6mZSCe5835qefiIfx7eD8IbNTCSEclBR0e1DzIXhkNlxNNtYrPb7R7ERCCAckBd1elG9i9IDx9Ifvu8Gu38xOJIRwMFLQ7UlwJaNbY1gtmDEI1nwhM2BMsHz5clkgWjgkKej2pkQwDJoN1R6Axf+G35+HzAyzUwkhHIDMcrFH7t7w0ERYGgVrPjMaez00URp7FZGPPvoIgBdeeMHkJEJYR87Q7ZWLC7R7A7p+CgeWwsSOcPGk2amKhblz5zJ37lyzYwhhNSno9i52CPSPM3qqj2sLp3eYnUgIYaekoDuCyu2MdgFaw4SOcGCJ2YmEEHZICrqjCKtpzIApWR6m9oZNk8xOJISwMw5Z0JMvF9O+JwHl4NH5ULE1zHkWlrwBWVlmp3I63t7eeHt7mx1DCKs5XEFfsOM0LT5Yxm8JJ8yOYg4vf+gXB/UHw+pP4OehkH7N7FROZf78+cyfP9/sGEJYzeEKep2IQKqXDeDZ6Qm8PGsb19IzzY5U9FzdoOtn0PYN2DkLJj8AV5LNTiWEMJnDFfSwAC9+HN6Ix1tVZNrG43Qfs5ZDSZfNjlX0lILmo4z56Se3GI29kg+ancopvPXWW7z11ltmxxDCag5X0AHcXF34Z8d7mPhoA06nptHty9XM2VpM52jX6AGD5kDaBWNa47ENZidyeEuXLmXp0qVmxxDCag5Z0G9qHRPK78/cyz1l/Hl62hZe/XV78RyCiWxkzIDxLmk09toxy+xEQggTOHRBBygb6M30EY15rEUFpqw/Rs+v13I0+YrZsYpeUEWjqJerBzMfhdWfSmMvIYqZPAu6UipCKbVMKbVbKbVTKfVsDtsopdQXSqkDSqltSql6hRM3Z+6uLrzcuSrjHokl8UIaXb9Yzbztp4oygn3wKQUDf4UaPWHJf2DuKGnsJUQxkp8z9Azgea11VaAx8KRSqlq2bToBlS1fI4CvbZoyn9pWK83vzzSnYqgvT0zdzOu/7eB6RjEbgnH3gh7joPlzxs1H0/rA9Utmp3IoQUFBBAUFmR1DCKspbeWv5Uqp34DRWuvFt/zsW2C51nqa5fu9QCutda6nybGxsTo+Pr5gqfNwIyOL9xfsYfzqw9QKD2B0v3pEBvkUyr7s2qbvYe4/ILSa0Q8moJzZiYQQd0kptUlrHZvTc1aNoSulooC6QPa
pFOWA47d8n2j5WfbXj1BKxSul4pOSkqzZtVU83Fz4d9dqfDuwPofPXaHLl6tYsON0oe3PbtUfBAN+ggtHYFwbOLXN7ERCiEKU74KulPIFfgZGaa0vZn86h5fcduqvtR6rtY7VWseGhIRYl7QAOlQPY94z9xIdXIKRUzbxxpyd3MgoZrfKV2pjNPZSLjCxE+xfnPdrirmXX36Zl19+2ewYQlgtXwVdKeWOUcynaq1zmhOXCETc8n04YBcTwyNK+fDTyCYMbhrFxDVH6PXtOo4lXzU7VtEKqwHDlkKpaPixD8RPMDuRXVu3bh3r1q0zO4YQVsvPLBcFjAd2a60/yWWz2cAjltkujYHUO42fFzVPN1f+c391vh5Qj0NJl+nyxSrmbrOL/98UHf8yRmOvSm2McfXFr0ljLyGcTH7O0JsBA4H7lFIJlq/OSqmRSqmRlm3mAYeAA8B3wBOFE/fudKpZhnnP3EvFUF+e+nELL8/aRtqNYjQLxtMP+k6D2KGw5nNjvnp6mtmphBA2kueaolrr1eQ8Rn7rNhp40lahCtPNIZhPFu/j6+UHiT9ygdH96xET5md2tKLh6gZdPjaGXxa9aixr12+asTi1EMKhOfydogXhbukFM3lIQy5cvcH9o1fz44ZjWDuF02EpBU2fhl7fw+ltRg+YcwfMTmU3wsPDCQ8PNzuGEFazeh66rRTmPHRrnL10jednbGXV/nN0qVmGd3rUJMDb3exYRef4nzCtL+hMYzimfBOzEwkh7sBm89CdUaifF98/2pCXOt3Dwp2n6fLFKjYfu2B2rKIT0QCGLQafIJh8P2yfaXYiIUQBFfuCDuDiohjZsiIzRhpnp72/WcfXyw+SlVVMhmBKVYChi6FcrLEC0qqPi3Vjr1GjRjFq1CizYwhhNSnot6gXWZLfn7mXDtXDeH/BHgZN3EjSpWKyfqlPKXjkV6jZC5a+CbOfhsx0s1OZIiEhgYSEBLNjCGE1KejZBHi7M7p/Xd7pXpONh8/T6fNVrNpfeG0K7IqbJ/T4Dlq8CFt+gKm94Fqq2amEEPkkBT0HSin6N4pk9lPNKVXCnYHjN/Le/D3Fo22AUnDfq3D/aDiyCiZ0gtREs1MJIfJBCvodxIT58duTzenXMJJvVhyk59drOVhc1i+tNxAGzITU4/BdGzgpQxBC2Dsp6Hnw9nDl3R41+ebh+hy/cJWuX6xm2sZiMme9YmsYshBc3GBiZ9i30OxERaJKlSpUqVLF7BhCWK3Yz0O3xunUazz/UwJrDiTToXpp3utRi5IlPMyOVfgunYYfe8Pp7dDpA2g43OxEQhRbMg/dRsICvPhhSCP+1bkqf+w5S8fPV7LmwDmzYxU+vzAYPA8qt4d5L8DCf0ljLyHskBR0K7m4KIa3qMAvTzTD19ONAeM28M683c6/1J2nL/T9ERqOgHWj4adBTtvYa8SIEYwYMcLsGEJYTQp6AdUoF8Dcp+9lQKNIxq48RI8xazlw1skvmLq4GkMuHd6F3XPg+25w2fmmdO7bt499+/aZHUMIq0lBvwveHq78t3tNxg6sz8mUNLp+uYop64869wVTpaDJE9DnBzi9w1jaLkmKnxD2QAq6DbSvHsbCUS1oEFWKV3/dwfDJm0i+7OR3mFbtBoN/h/SrML4dHFljdiIhij0p6DYS6m80+Xq1S1VW7kui4+erWLHP+YYj/ia8PgxbAr6h8MODsG2G2YmEKNakoNuQi4ti2L0V+PXJZgR6uzNowkZe+22Hc6+KVDIKhi6CiEYwazis+NDhG3vVqVOHOnXqmB1DCKvJPPRCci09k/cX7GHimiNUCC7BJ33qUCci0OxYhSfjhtHQa9t0qPMwdPsMXItRX3khiojMQzeBl7srr3erztRhjUhLz6Tn12v5ZPE+0jOddP62mwd0/wZavgQJU2BKT0hLMTuVEMWKFPRC1qxSMAtGteD+2mX5Yul+y/TGS2bHKhxKQeuX4cGv4egamNARUo6ZncpqDz/8MA8
//LDZMYSwmhT0IhDg7c6nfeowZkA9Ei9cpcsXq5mw+rDzLqBRpz88PMtYgHpcWzi5xexEVklMTCQxUTpMCscjBb0Ida5ZhoWjWtC0YhBvzt3FwAkbOJninHdbUqGlcbHU1dNo7LV3vtmJhHB6UtCLWKi/FxMGN+DdHjXZciyFDp+t5Jctic55M1LoPca0xpAYmN4fNow1O5EQTk0KugmUUvRrGMn8Z++lSmk//hG3lSd/3Mz5KzfMjmZ7fqWNG5CqdIL5L8KClyHLiadxCmGiPAu6UmqCUuqsUmpHLs+3UkqlKqUSLF+v2T6mcyofVIIZjzXh/zrGsHjXGTp8tpLFu86YHcv2PEoYrQIaPQ7rx8CMR+DGVbNT5apJkyY0adLE7BhCWC3PeehKqRbAZWCy1rpGDs+3Al7QWne1ZsfOPg/dWjtPpvL8jK3sOX2J7nXL8Xq3agT6OGGv9fXfwIKXoGxd6B9n3GUqhMi3u5qHrrVeCZy3eSrxN9XLBjD7qeY806Yyc7aepN2nK1nijGfrjUdC36lwdrelsddesxMJ4TRsNYbeRCm1VSk1XylVPbeNlFIjlFLxSqn4pCQn73NSAB5uLjzXrgq/PtmMoBIeDJscz3NxCaReTTc7mm3d0wUe/R3SrxmNvQ6vNDvR3/Ts2ZOePXuaHUMIq9mioG8GymutawNfAr/mtqHWeqzWOlZrHRsSEmKDXTunGuUsZ+v3VeK3rSdp9+kKlu52srP1cpbGXn5l4IcesHW62Yn+kpycTHJystkxhLDaXRd0rfVFrfVly+N5gLtSKviukxVzHm4uPNc+ht+ebEapEh4M/T6e52Y42dl6yfLGItTlm8Avj8Hy9xy+sZcQZrrrgq6UClNKKcvjhpb3lNMbG7l5tv70fZX4LeEk7T9bwR97nOhs3TsQBvwMtfvD8nfh18eNRl9CCKvlZ9riNGAdEKOUSlRKDVVKjVRKjbRs8hCwQym1FfgC6Kud8i4Z83i4ufB8+xh+faIZgd4eDJkUz/MztpKa5iRn624e8OAYaP0v2DoNpvSQxl5CFIBbXhtorfvl8fxoYLTNEolc1QwPYPbTzRj9xwHGLD/Iqv1JvPlADTrWCDM72t1TClr+HwSWh9+ehPHtYcBPxrBMEWvTpk2R71MIW5B+6A5qx4lU/m/mNnadukinGmG88UB1Qv28zI5lG4dXQdwAcPUw5qqXq292IiHshvRDd0I1ygXw21PN+L+OMSzdc5a2H68g7s9jztETJvpeGLoY3L1hYhfY87vZiYRwCFLQHZi7qwtPtKrEgmfv5Z4y/vzz5+0MGLeBo8lXzI5290JiYNhSKF0Npg+A9V8X2a47depEp06dimx/QtiKFHQnUCHEl+nDG/Pf7jXYnphKh89WMnblQTIcfXUk31AYNNe4EWnBSzD/n0XS2CstLY20NCdtayycmhR0J+HiohjQqDyLn2tJ80ohvDNvD93HrGXnyVSzo90dDx/oPRmaPAUbvoG4h+GGE/wGIkQhkILuZMICvPjukfp81b8ep1LTuH/0Gj5YsIdr6Q7cstbFFTr8Fzp/BPsWGAtmXHKiufhC2IgUdCeklKJLrTIsea4l3euWY8zyg3T6fBVrD5wzO9rdaTgc+k6Dc/uMpe3O7jY7kRB2RQq6Ewv08eCjXrX5YWhDsrSm/7gN/CMugXOXr5sdreBiOsKj8yDzOozvAIeW23wXXbt2pWtXq7pBC2EXZB56MXEtPZOvlh3gmxUH8XZ35aVOVenbIAIXF2V2tIJJOQ5Te0Hyfuj2BdQdYHYiIYqEzEMXeLm78nz7GOY/24JqZf155ZftPPTNWnafumh2tIIJjIChCyGqOfz2BPzxX2nsJYo9KejFTKVQX6YNb8zHvWpzJPkqXb9czbvzdnP1RobZ0aznFQADZkLdh2HlB0bHxoy7H05q1aoVrVq1uvt8QhQxKejFkFKKnvXDWfpcS3rVD+fblYdo94mDrpDk6g73j4b7XoVtcUZ
v9bQLZqcSwhRS0IuxkiU8eK9nLWaObIKvpxvDJsczYnI8J1Mc7KYapaDFi9BjHCRuhHHt4Pxhs1MJUeSkoAtio0ox95nmvNTpHlbuT6LNxyv4atkBrmc42Nz1Wr1g4K9wJcmY1pgoF91F8SIFXQBGX5iRLSuy5LmWtKwSwocL99Lh05Us23vW7GjWiWpmLG3n6QuTusCu2WYnEqLISEEXfxNe0odvBtZn8pCGuLgoHp34J8O+j+dY8lWzo+VfcGWjsVdYTZjxCKwdbdUMmN69e9O7d+9CDChE4ZB56CJXNzKymLjmMJ8v3U9GlmZky4o83rIi3h6uZkfLn/Q0mDUCds+GBsOh43vgmueaLkLYNZmHLgrEw82Fx1pW5I/nW9GxehhfLN1P209WsGDHacfou+7uDb2+h6bPwJ/fGYtmXL+c58uuXr3K1asO9BuJEBZS0EWewgK8+KJfXaaPaIyvpxsjp2zikQkbOZiUd3E0nYsLtH8LunwM+xfBxE5w8dQdX9K5c2c6d+5cRAGFsB0p6CLfGlcI4vdnmvN6t2okHEuhw6creXPOLlKvOsBi1Q2GQb84SD5ozIA5s9PsRELYnBR0YRU3VxcebRbNHy+0oldsOBPXHqblR8v4fu0R0u19QY0q7WHIfNCZMKEjHPzD7ERC2JQUdFEgIX6evNujFr8/fS9Vw/x5ffZOOn2+yv6nOZapbUxrDIgwmntt/sHsRELYjBR0cVeqlfXnx+GNGDuwPhmZWTw68U8GTdjI/jOXzI6Wu4BwGLIAolvC7Kdg6VvS2Es4BZnDJe6aUor21cNoFRPK5HVH+Hzpfjp+vooBjSIZ1bYKpUp4mB3xdl7+0D8Ofn8OVn0EF47Ag2PAzZPBgwebnU6IAslzHrpSagLQFTirta6Rw/MK+BzoDFwFBmutN+e1Y5mH7rzOX7nBp4v38ePGY/h4uPJsm8oMbFIeTzc7nL+uNaz+FJa+AZFNoe9U8ClldiohcnW389AnAR3v8HwnoLLlawTwtbUBhXMpVcKDtx6swfxn76VuZEne/n03bT5ewa9bTpCVZWdDG0rBvc/BQxPgxCYY345z++M5d87Bl+sTxVKeBV1rvRI4f4dNHgAma8N6IFApVcZWAYXjqlLaj8lDGjJ5SEP8vdwZFZdAt9GrWbU/yexot6vREx75Da4m81DHe3moWwezEwlhNVtcFC0HHL/l+0TLz26jlBqhlFAPxB0AABdRSURBVIpXSsUnJdnhf9SiULSoEsLcp5vzaZ/apFxNZ+D4jQwcv4EdJ1LNjvZ35ZsYPWBcXOH0dtj5q9mJhLCKLQp6TotS5vh7tdZ6rNY6VmsdGxISYoNdC0fh4qLoXjecpc+35NUuVdl+IpWuX65m1PQtHD9vR7fZB1WEsNrg4Qs/DYI1n8sMGOEwbFHQE4GIW74PB07a4H2FE/Jyd2XYvRVY8WJrRrasyPwdp2nz8QremruLC1dumB3P4OoOYTWgendY/JoxEybTAZfoE8WOLQr6bOARZWgMpGqt79wsQxR7Ad7uvNTpHpa90IoH6pRl4prDtPhgGZ8t2cela3bQSkC5QM8J0GwUxE+A6f3guh3PrReC/E1bnAa0AoKBM8DrgDuA1voby7TF0RgzYa4Cj2qt85yPKNMWxa32nr7EJ4v3snDnGQJ93HmsRUUGNS2Pj0fR3yoRFxcHQJ8+fYwfxE+E35+H0tWg/wzwL1vkmYS46U7TFqUfurAr2xNT+XjxXpbvTSLY15MnW1ekX8NIvNxNnsO+f4kxpu7pDwN+MoZkhDCBFHThcOKPnOejRXtZf+g8ZQK8ePq+yvSKDcfdtfC7VRw/bkzaioiI+PsTp7fD1N7G0EvvSVCpbaFnESI7KejCYa09cI4PF+1ly7EUIkv58GybyjxYtxyuLjlNrrKNVq1aAbB8+fLbn7x40ijqZ3dB10+g/uBCyyFETmTFIuGwmlYKZtbjTZk4uAF+Xm48/9NW2n2ygp83JZJhRrt
e/7JGC96KrWHOs7DkP5Bl522DRbEhBV3YPaUUre8JZc5Tzfnm4Xp4urvy/E9bue/jFUzfeIwbGUVcUD39jMUyYocYfWB+Hgrp14o2gxA5kIIuHIaLi6JjjTLMe6Y54x6JpaSPOy/N2k7rj5bzw/qjXM/ILLowrm7Q5RNo9ybsnAWTH4AryUW3fyFyIAVdOBylFG2rlebXJ5sx6dEGlPb35N+/7qDlB8uZuOYw19KLqLArBc2ehV6T4OQWGN/WWOJOCJPIRVHh8LTWrD2YzBdL97Ph8HmCfT0Z0SKa/o3K4+tp/Tz2OXPmANCtW7f8v+jYBuPmI62h3zSIbGz1foXID5nlIoqNDYeS+fKPA6w+cA4/LzcGNi7Po82iCfHzLPydJx80lrVLTYTuXxsdHIWwMSnootjZejyFb1ceZP6O07i7utCzXjgjWlQgOrhEnq/du3cvADExMdbv+Op5mN4fjq2Dtv8xWgeowptiKYofKeii2Dp87grfrTrEzE2JpGdm0bF6GCNbVqR2RGCur7njPPT8SL8Gvz0BO3425ql3/ti4iCqEDdypoMu/MuHUooNL8E73moxqW5nv1x7hh3VHmb/jNE0qBPFYywq0rBKCsvUZtLsX9BgHJaNg1ceQcty4cOrlb9v9CJGNzHIRxUKonxcvdriHtS+34dUuVTl87gqDJ/5Jh89WMm3jMdvPjHFxgTavwf1fwqHlMLETpJ6w7T6EyEYKuihWfD3dGHZvBVb+X2s+7lUbNxcXXp61nSbvLuXDhXs4nWrjG4TqPWI087pwFMa1gVPbbPv+QtxCxtBFsaa1ZuPh80xYc5hFu87gqhTXfn2NsAAv4tettt2Ozuw0esBcSzGGXyq3s917i2JFerkIkQulFI0qBPHtwFhWvNCaQU2jcK/fk+PlO9Hz67XM3XbSNj1jSleHYUugVAX4sQ/8Of7u31OIbOQMXYhsLl/PYGb8cSauPcLR5KuU9vekX8NI+jaIJCzA6+7e/PplmDkE9i+Eps9A2zeM8XYh8kmmLQphhYSEBABq1qrNsj1nmbLhKCv2JeGiFO2qlubhxuVpWjEIl4K28M3MgAX/hD/HQbUHofs34O5tw08gnJkUdCGskNM89GPJV5m68Sg/xSdy/soNooNLMKBRJA/VDyfQx8P6nWgN676CRa9CeAOjXUCJYNt8AOHUpKALYYU73Vh0LT2T+TtOMWX9MTYdvYCnmwtda5Wlf6NI6kUGWj+nfddvMGsE+JWBATMhuNLdfwDh1OTGIiFsxMvdle51w+leN5zdpy4yZf1Rft1ygp83J1Ip1Jc+sRF0r1eOYN989o6p9gD4lYVpfY1ujX1/hPJNC/dDCKclZ+hCZGPtrf+Xr2cwd+tJZsQfZ/OxFNxcFG2qhtKnQQQtKofglp91UM8fNhp7pRyFB7+Gmg8V/AMIpyZn6EIUIl9PN/o2jKRvw0j2n7nEjPjjzNp8goU7z1Da35Oe9cLpHRtB1J0ag5WKhqGLIO5hYwWkC0fg3uelsZewipyhC5HN2rVrAWjatOBDHzcysvhjz1lmxB9n+d6zZGloGFWKB+uWo0vNMgT4uOf8wozr8NtTsH0G1B0IXT8F11y2FcWSXBQVwkRnLl5j5qZEZm1O5GDSFTxcXbjvnlAerFuO1veE4Onm+vcXaA3L3oGVH0CF1tD7e/AKMCe8sDt3XdCVUh2BzwFXYJzW+r1szw8GPgRudh8arbUed6f3lIIu7JUtztBzorVmx4mL/LLlBLO3nuTc5ev4e7nRpVZZutctR2z5kn+f275lCsx5FoKrQP8ZEBhh0zzCMd1VQVdKuQL7gHZAIvAn0E9rveuWbQYDsVrrp/IbSgq6sFd33Q89HzIys1hzMJlft5xgwY7TpKVnUi7QmwfqlKVLrTJUK+NvTIE8tBziBoK7D/SPg7J1Ci2TcAx328ulIXBAa31Ia30DmA48YMuAQhQ3bq4utKwSwqd96hD/als+61OHiqG+fLvyEF2
+WM19H6/gw4V72OVVDz1kgTGOPrEz7F1gdnRhx/JT0MsBx2/5PtHys+x6KqW2KaVmKqVy/N1QKTVCKRWvlIpPSkoqQFwhnE8JTzcerFuOyUMasvGVNrzTvSblAr35ZsUhOn+xivt+SOLrSt+SFlABPb0fbPzO7MjCTuWnoOc0byr7OM0cIEprXQtYAnyf0xtprcdqrWO11rEhISHWJRWiGAjy9aR/o0imDGv0V3EvG+jFh2tTqJf4D9a51Id5L3Du5xfRWTZelEM4vPzMQ08Ebj3jDgdO3rqB1jr5lm+/A96/+2hCFG83i3v/RpGcu3ydhTtPM2bbG+w7+jmDt49l2c5trKn5X1rViKJRhVK45+cGJuHU8nNR1A3jomgbjFksfwL9tdY7b9mmjNb6lOVxd+CfWuvGd3pfuSgq7NXNbot16tjnBcjky9c5Nu9jau/6gG26IkOvP88NryBax4TSrlppWsWE4Oclc9edlS2mLXYGPsOYtjhBa/1fpdSbQLzWerZS6l3gfiADOA88rrXec6f3lIIuxF3aPRf98zDSPIP4quy7TD/kTfKVG7i7KhpXCDKKe5VQIoN8zE4qbEhuLBLCCkuWLAGgbdu2JifJh8RNMK0PZKaT2XsKm12qs3jXGRbvOsPhc1cAqBBcgpYxIbSKCaVRdCm83F3zeFNhz6SgC2GFopiHblMXjhjrlZ4/BA98BbX7AHD43BWW7z3L8r1JrD+UzPWMLLzcXWhSIYhWMaG0igmhfNAd+ssIuyTNuYRwZiWjYOhC4wakX0YYHRtbvEh0cAmig6N5tFk019IzWXcomRV7k1i+9yzL9hqXwKKDS9CsUhDNKgbTpGJQwRbrEHZDCroQzsC7JDw8C2Y/Dcv+a5y1d/0M3IwC7eXuSuuYUFrHhALVOWI5e1+xL4lfNp9gyvpjKAXVy/rTrGIwTSsF0yCqJD4eUiIcifxtCeEs3DyM9UlLRcPydyE1EXpPBu/A2zaNCi7B4OBoBjeLJj0zi22JKaw5kMyaA+eYuOYI3648hLurom5kSZpWDKJZpWBqhQfc3khM2BUZQxciG4cbQ89Jwo8w+xkIqgQDZkBgZL5fmnYjkz+PnGfNwXOsPZDMjpOpaA0ebi7UCQ+kQXRJGkSVon75kjI90gRyUVQIK+zduxeAmJgYk5PcpcMrYfrD4O4F/aZDuXoFepvUq+lsOJzMn0fOs/HIBXacSCUzS+OioGoZfxpElaJhdCkaRJUixC+fS++JApOCLkRxdXaPsbTd1XPw0ASI6XTXb3n1RgZbjqWw4fB5/jx8ni3HL3AtPQuAqCAf6kaWpG5kIHUiArknzB8PN7mD1ZakoAthhTlz5gDQrVs3k5PYyKUzxiLUpxKg43vQ6DGbvv2NjCx2nEzlz8PniT96gYTjKSRdug4YwzQ1yvpTN7IkdSKMIh9e0ttoDSwKRAq6EFZwijH07G5chVnDYc9caPwEtH8bXArnAqfWmhMpaSQcTyHhWAoJx1PYfiKV6xnGWXywryd1IgKpFR5AjXL+VC8bQKifpxT5fJJ56EIUdx4+xoyXRa/C+jGQcgx6fGf83MaUUoSX9CG8pA9da5UFID0ziz2nLpFw/AJbjhtFfumeM9w8nwz29bQUd39qlA2gRrkAOZMvACnoQhQXLq7Q8V3jRqQFL8GkLsYqSL6hhb5rd1cXaoYHUDM8gIFNjJ9dvp7B7lMX2XEilZ0njT9X7T9HZpZR5f293KheNoDqZf2JCfPjnjB/KoX64u0hUydzIwVdiOKm0WMQEAE/D4VxbWDATAgp+hk9vp5uNIgyZsfcdC09k72nLxkF/mQqO0+kMnn9UW5YhmuUgvKlfIgJ8yOmtB8xYf7EhPkSFVQCN2kfLAVdiGLpns4w+Hf4sQ+Mbwd9pkB0C7NT4eXuSu2IQGpH/O9mqIzMLI6ev8q+05fYc/oS+85cYu+ZSyzedQbLyTweri5UDPUlprQvFUN8qRDiS8XQEkQFlShWzcj
koqgQ2Rw/bqy4GBGR40qKziXlmNHYK/kA3P8l1OlndqJ8u5aeyYGzl9l7S5Hfd/oSJ1Ov/bWNUhBe0psKwTcLfQkqhJSgUogvIQ56IVZmuQghcpeWAjMGGjcitXoZWv7TqIQO6uqNDA6fu8LBpCscSrr815+Hkq6Qlv6/Zfv8PN0oH+xD+VIliAzyIbKUD+VL+RBRyoeygd64utjnMZCCLoQV4uLiAOjTp4/JSYpQxg2YOwoSpkLtftDti78aezmLrCzN6YvXOJR0hYNJlzmYdJmjyVc5dv4qiReukp75v1ro7mrM1ImwFPnIUj5EBvkQUdKHcoHe+Hu7mXZ2LwVdCCs45Tz0/NAaVn4Ey96GqHuhzw9GF8diIDNLcyo1jWOWAn/0/NX/PU6+wsVrGX/b3tfTjbKBXpQN9KZsoDflAr2N7wOM78MCvAptjVeZhy6EyJtS0PJFo5HXb0/C+A4w4CcoWd7sZIXO1eV/c+eb5vB86tV0jp6/wvHzaZxKTeNEShonLqRxMjWNbYmpnL9y42/buygo7e/1V3EP87d8BRhfUUElCqXvjRR0IcTf1e4DAeVg+gBjWmO/OAivb3YqUwX4uFPLJ5Ba4be3IgajQ+XJ1DROphhfJy6kcSLlGidSrrLr5EX+2H32b+P3j7WowMudq9o8pxR0IcTtoprD0MUw9SHjBqSe46BqV7NT2S1vD1cqhhgzaXKiteZiWganL17j9MVrlAnwKpQcMhNfCJGzkCowbCmUrg5xD8P6r81O5LCUUgT4uBMT5kfLKiFUKe1XKPuRM3Qhspk5c6bZEeyHbwgMmmOsVbrgJTh/2GgfUEiNvcTdkTN0IbIJDg4mODjY7Bj2w8MHek2GJk/Bxm+NsfUbV8xOJXIgBV2IbCZNmsSkSZPMjmFfXFygw3+h80ewfyFM7Gz0WRd2JV8FXSnVUSm1Vyl1QCn1Ug7Peyql4izPb1BKRdk6qBBFRQr6HTQcDn2nwbn9xgyYs7vNTiRukWdBV0q5Al8BnYBqQD+lVLVsmw0FLmitKwGfAu/bOqgQwk7EdIRH50FmOoxvD4eWm51IWOTnomhD4IDW+hCAUmo68ACw65ZtHgD+Y3k8ExitlFLarNtQhRCFq2wdGLYEfuwNkx+A4CqgZAQ33+oOhKZP2fxt81PQywHHb/k+EWiU2zZa6wylVCoQBJy7dSOl1AhgBEBkZGQBIwsh7EJgBAxZACs/NLo2ivwrpEVF8lPQc+pAk/3MOz/boLUeC4wFo5dLPvYthLBnXgHG+qTCLuSnoCcCtzaGDgdO5rJNolLKDQgAztskoRBFbN68eWZHEKJA8jPo9SdQWSkVrZTyAPoCs7NtMxsYZHn8EPCHjJ8LR+Xj44OPj+0XTxaisOV5hm4ZE38KWAi4AhO01juVUm8C8Vrr2cB44Ael1AGMM/O+hRlaiMI0ZswYAJ544gmTkwhhHemHLkQ2xbYfunAId+qHLvOMhBDCSUhBF0IIJyEFXQghnIQUdCGEcBKmXRRVSiUBRwv48mCy3YVqJ+w1F9hvNsllHcllHWfMVV5rHZLTE6YV9LuhlIrP7Sqvmew1F9hvNsllHcllneKWS4ZchBDCSUhBF0IIJ+GoBX2s2QFyYa+5wH6zSS7rSC7rFKtcDjmGLoQQ4naOeoYuhBAiGynoQgjhJByioCulPlRK7VFKbVNK/aKUCsxluzsuZl0IuXoppXYqpbKUUrlOQVJKHVFKbVdKJSilCr0jmRW5ivR4WfZZSim1WCm13/JnyVy2y7QcrwSlVPZ2zbbKYpeLn+cj12ClVNItx2dYEeWaoJQ6q5TakcvzSin1hSX3NqVUPTvJ1UoplXrL8XqtiHJFKKWWKaV2W/57fDaHbWx7zLTWdv8FtAfcLI/fB97PYRtX4CBQAfAAtgLVCjlXVSAGWA7E3mG7I0BwER6vPHOZcbws+/0AeMny+KWc/i4tz10u5Bx5fn7gCeAby+O+QFwRHJ/85BoMjC6qf0+
37LcFUA/YkcvznYH5GCuYNQY22EmuVsBcE45XGaCe5bEfsC+Hv0ubHjOHOEPXWi/SWmdYvl2PsWpSdn8tZq21vgHcXMy6MHPt1lrvLcx9FEQ+cxX58bJ4APje8vh74MEi2GdO8vP5b806E2ijlMppucWizmUKrfVK7rwS2QPAZG1YDwQqpcrYQS5TaK1Paa03Wx5fAnZjrL98K5seM4co6NkMwfg/WnY5LWad/eCZRQOLlFKbLAtl2wOzjldprfUpMP7BA7mtluullIpXSq1XShVG0c/P5//b4ufAzcXPC1N+/156Wn5Fn6mUisjheTPY83+DTZRSW5VS85VS1Yt655bhurrAhmxP2fSY5WdN0SKhlFoChOXw1L+01r9ZtvkXkAFMzektcvjZXc/JzE+ufGimtT6plAoFFiul9ljOKszMVSjHC+6czYq3ibQcswrAH0qp7Vrrg7bIZ2Gzxc9tLD/7nANM01pfV0qNxPgt4r5CzpUfZhyv/NiM0f/kslKqM/ArULmodq6U8gV+BkZprS9mfzqHlxT4mNlNQddat73T80qpQUBXoI22DD5lk5/FrG2eK5/vcdLy51ml1C8Yv1bfVUG3Qa5COV5w52xKqTNKqTJa61OWXy3P5vIeN4/ZIaXUcoyzG1sWdHtd/DzPXFrr5Fu+/Q7jupI9KLR/U3fj1iKqtZ6nlBqjlArWWhd60y6llDtGMZ+qtZ6VwyY2PWYOMeSilOoI/BO4X2t9NZfN8rOYdZFTSpVQSvndfIxxgTfHq/FFzKzjdeuC4oOA236bUEqVVEp5Wh4HA82AXTbOYa+Ln+eZK9sY6/0YY7P2YDbwiGXmRmMg9ebwmpmUUmE3r30opRpi1L3kO7/KJvtVGOst79Zaf5LLZrY9ZkV95beAV4sPYIwzJVi+bs48KAvMy3bFeB/Gmdy/iiBXd4z/w14HzgALs+fCmK2w1fK1015ymXG8LPsMApYC+y1/lrL8PBYYZ3ncFNhuOWbbgaGFlOW2zw+8iXHiAOAF/GT597cRqFBExyivXO9a/i1tBZYB9xRRrmnAKSDd8u9rKDASGGl5XgFfWXJv5w4zv4o411O3HK/1QNMiytUcY/hk2y21q3NhHjO59V8IIZyEQwy5CCGEyJsUdCGEcBJS0IUQwklIQRdCCCchBV0IIZyEFHQhhHASUtCFEMJJ/D/W7YPqdLTuIgAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Mathematical functions for logistic and hinge losses\n", "def log_loss(raw_model_output):\n", "    return np.log(1 + np.exp(-raw_model_output))\n", "def hinge_loss(raw_model_output):\n", "    return np.maximum(0, 1 - raw_model_output)\n", "\n", "# Create a grid of values and plot\n", "grid = np.linspace(-2, 2, 1000)\n", "plt.plot(grid, log_loss(grid), label='logistic');\n", "plt.plot(grid, hinge_loss(grid), label='hinge');\n", "plt.axvline(x=0, linestyle='dashed', color='k')\n", "plt.legend();" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Implementing logistic regression\n", "This is very similar to the earlier exercise where you implemented linear regression \"from scratch\" using `scipy.optimize.minimize`. However, this time we'll minimize the logistic loss and compare with scikit-learn's `LogisticRegression`.\n", "\n", "The `log_loss()` function from the previous exercise is already defined in your environment, and the sklearn breast cancer prediction dataset (first 10 features, standardized) is loaded into the variables `X` and `y`." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "X = pd.read_csv('./dataset/breast_X.csv').to_numpy()\n", "y = pd.read_csv('./dataset/breast_y.csv').to_numpy()" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[ 1.03614857 -1.65378453 4.08306703 -9.4092245 -1.06786857 0.07893722\n", " -0.85110258 -2.44102697 -0.45285622 0.43353259]\n", "[[ 1.03665946 -1.65380077 4.08233062 -9.40904867 -1.06787935 0.07901598\n", " -0.85099843 -2.44107473 -0.45288928 0.43348202]]\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "C:\\Users\\kcsgo\\anaconda3\\lib\\site-packages\\sklearn\\utils\\validation.py:760: DataConversionWarning: A column-vector y was passed when a 1d array was expected. 
Please change the shape of y to (n_samples, ), for example using ravel().\n", " y = column_or_1d(y, warn=True)\n" ] } ], "source": [ "from scipy.optimize import minimize\n", "from sklearn.linear_model import LogisticRegression\n", "\n", "# Logistic loss, summed over all training examples\n", "def my_loss(w):\n", "    s = 0\n", "    for i in range(y.size):\n", "        # Raw model output for example i (no intercept term)\n", "        raw_model_output = w@X[i]\n", "        s = s + log_loss(raw_model_output * y[i])\n", "    return s\n", "\n", "# Find the w that makes my_loss(w) smallest, starting the search from X[0]\n", "w_fit = minimize(my_loss, X[0]).x\n", "print(w_fit)\n", "\n", "# Compare with scikit-learn's LogisticRegression\n", "# (a very large C makes the regularization negligible)\n", "lr = LogisticRegression(fit_intercept=False, C=1000000).fit(X, y)\n", "print(lr.coef_)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The two sets of coefficients agree to within about 0.001: with regularization effectively switched off (very large `C`) and no intercept, scikit-learn's `LogisticRegression` is just minimizing the logistic loss we've been looking at." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.6" } }, "nbformat": 4, "nbformat_minor": 4 }