{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.\n", "- Author: Sebastian Raschka\n", "- GitHub Repository: https://github.com/rasbt/deeplearning-models" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sebastian Raschka \n", "\n", "CPython 3.6.8\n", "IPython 7.2.0\n", "\n", "torch 1.0.0\n" ] } ], "source": [ "%load_ext watermark\n", "%watermark -a 'Sebastian Raschka' -v -p torch" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- Runs on CPU or GPU (if available)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Model Zoo -- Rosenblatt Perceptron" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Implementation of the classic Perceptron by Frank Rosenblatt for binary classification (here: 0/1 class labels)." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Imports" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import torch\n", "%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Preparing a toy dataset" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Class label counts: [50 50]\n", "X.shape: (100, 2)\n", "y.shape: (100,)\n" ] } ], "source": [ "##########################\n", "### DATASET\n", "##########################\n", "\n", "# Tab-separated file: two feature columns followed by the class label\n", "data = np.genfromtxt('../data/perceptron_toydata.txt', delimiter='\\t')\n", "X, y = data[:, :2], data[:, 2]\n", "# np.bincount needs integer input; the builtin int replaces the np.int alias,\n", "# which was removed in NumPy 1.24\n", "y = y.astype(int)\n", "\n", "print('Class label counts:', np.bincount(y))\n", "print('X.shape:', X.shape)\n", "print('y.shape:', y.shape)\n", "\n", "# Shuffling & train/test split\n", "shuffle_idx = np.arange(y.shape[0])\n", "shuffle_rng = np.random.RandomState(123)\n", "shuffle_rng.shuffle(shuffle_idx)\n", "X, y = X[shuffle_idx], y[shuffle_idx]\n", "\n", "# NOTE: shuffle_idx is applied a second time to the already shuffled arrays.\n", "# The first 70 / remaining indices still form two disjoint subsets, and the\n", "# double permutation is kept so the recorded results stay reproducible.\n", "X_train, X_test = X[shuffle_idx[:70]], X[shuffle_idx[70:]]\n", "y_train, y_test = y[shuffle_idx[:70]], y[shuffle_idx[70:]]\n", "\n", "# Normalize (mean zero, unit variance) using training-set statistics only\n", "mu, sigma = X_train.mean(axis=0), X_train.std(axis=0)\n", "X_train = (X_train - mu) / sigma\n", "X_test = (X_test - mu) / sigma" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAY4AAAEKCAYAAAAFJbKyAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAIABJREFUeJzt3X+cVPV97/HXW7JettFII3sj7kLAllJ/VKRuRMu9xmoSlEYxGFM1V0Pzg9vcGK03koviI6HGpCSkTZOYFLFaY675ZWMIvZgQDTW2pv5YxF9ICMQml11pJFgUA1R+fPrHnIVld2Z3ZnfmnDMz7+fjMQ/mnPky57PDcj7z/a2IwMzMrFyHZR2AmZnVFycOMzOriBOHmZlVxInDzMwq4sRhZmYVceIwM7OKZJY4JI2X9I+SnpW0TtLVRcpI0hckbZL0lKTfzyJWMzM76DUZXnsv8JGIeFzSkcAaSfdFxLN9ypwHTE4e04G/Sf40M7OMZFbjiIgtEfF48nwHsB5o71dsNnBnFDwMjJE0LuVQzcysjyxrHAdImghMAx7p91I7sLnPcXdybstg7zd27NiYOHFi9QI0M2twa9as+VVEtJVTNvPEIekI4NvAn0XEyyN4n3nAPIAJEybQ1dVVpQjNzBqfpF+UWzbTUVWSWigkjbsi4p4iRXqA8X2OO5JzA0TEsojojIjOtraykqaZmQ1DlqOqBNwGrI+IvypRbAVwRTK66nTgpYgYtJnKzMxqK8umqhnA5cDTkp5Izl0PTACIiKXAvcAsYBOwE/iTDOI0M7M+MkscEfHPgIYoE8CHqnG9PXv20N3dze7du6vxdg1h9OjRdHR00NLSknUoZlZHMu8cT0t3dzdHHnkkEydOpNBK1twigm3bttHd3c2kSZOyDsfM6kjTLDmye/dujj76aCeNhCSOPvpo18DMrGJNkzgAJ41+/HmY2XA0TVOVmaXkU+3w6isDzx9+BFxfdDS91RknjowtWrSII444gmuvvbbq771mzRrmzp3Lrl27mDVrFp///Oddy7DhqSQZFCs32HmrO03VVNVsPvjBD3LrrbeyceNGNm7cyPe///2sQ7J6Va1ksOioQhKyuubEUcLytT3MWLyaSQtWMmPxapavHXkV+8477+Tkk09m6tSpXH755QNev/XWW3nTm97E1KlTueiii9i5cycAd999NyeddBJTp07lzDPPBGDdunWcdtppnHLKKZx88sls3LjxkPfasmULL7/8MqeffjqSuOKKK1i+fPmIfwazEXPNo+65qaqI5Wt7uO6ep9m1Zx8APdt3cd09TwNw4bThfVtat24dN910Ez/+8Y8ZO3YsL7744oAyc+bM4QMf+AAAN9xwA7fddhsf/vCHufHGG1m1ahXt7e1s374dgKVLl3L11Vfz7ne/m1dffZV9+/Yd8l49PT10dHQcOO7o6KCnx+3LZjZyThxFLFm14UDS6LVrzz6WrNow7MSxevVqLr74YsaOHQvA61//+gFlnnnmGW644Qa2b9/OK6+8wsyZMwGYMWMGc+fO5V3vehdz5swB4IwzzuCTn/wk3d3dzJkzh8mTJw8rLrPccKd63XBTVRHPb99V0flqmTt3LjfffDNPP/00H//4xw/MsVi6dCk33XQTmzdv5tRTT2Xbtm1cdtllrFixgtbWVmbNmsXq1asPea/29na6u7sPHHd3d9Pe7rZlS8HhRwzv77lTvW44cRRx7JjWis6X4+yzz+buu+9m27ZtAEWbqnbs2MG4cePYs2cPd91114HzP/vZz5g+fTo33ngjbW1tbN68meeee47jjjuOq666itmzZ/PUU08d8l7jxo3jda97HQ8//DARwZ133sns2bOHHb81uVLJoNj563tg0Uu1jccy5aaqIubPnHJIHwdAa8so5s+cMuz3PPHEE1m4cCFvfvObGTVqFNOmTeOOO+44pMwnPvEJpk+fTltbG9OnT2fHjh2FeObPZ+P
GjUQE55xzDlOnTuXTn/40X/3qV2lpaeGYY47h+uuvH3DNL3/5yweG45533nmcd955w47fmtxwmooOP6J005PVNRXWEWwsnZ2d0X8jp/Xr13P88ceX/R7L1/awZNUGnt++i2PHtDJ/5pRh92/kWaWfi1nNLDpqkNdcg6k1SWsiorOcsq5xlHDhtPaGTBRmZiPlPg4zy4dK+lEsU65xmFk+eMht3ch6z/HbJb0g6ZkSr58l6SVJTySPj6Udo5mZHSrrGscdwM3AnYOU+aeIeHs64ZhZXfBkwUxlWuOIiAeBgRMazMwG48mCmaqHzvEzJD0p6XuSTixVSNI8SV2SurZu3ZpmfCOyaNEiPvvZz9bkvRcuXMj48eM54gh3LppZ9eQ9cTwOvDEipgJfBEou7xoRyyKiMyI629raUgswz84//3weffTRrMMwswaT68QRES9HxCvJ83uBFklja37hT7UXJiP1f4xwH4E0l1UHOP300xk3btyIYjYz6y/rzvFBSToG+GVEhKTTKCS6bTW/cA3aT9NeVt3MrFYyTRySvg6cBYyV1A18HGgBiIilwDuBD0raC+wCLok6XSPFy6qbVZHXwcpUpokjIi4d4vWbKQzXbQpz585l+fLlTJ06lTvuuIMHHngAKNQuHnnkEVauXMmpp57KmjVruOyyy5g+fTorV65k1qxZ3HLLLZx99tnZ/gBmafGQ20zluo+jkaS9rLqZWa3kuo+jkWSxrPpHP/pRvva1r7Fz5046Ojp4//vfz6JFi1L4ac2s5jKcBOll1YtpolmpXlbdGlqe/y+PNLYqL0PvZdVHKutfKDOrjjzPMM9zbENwH4eZmVWkqWocEYGkrMPIjUZsprScy3PTkZWtaWoco0ePZtu2bb5ZJiKCbdu2MXr06KxDsWZSx80zdlDT1Dg6Ojro7u6mnhZArLXRo0fT0dGRdRiWN5XWClyLyEaGkyCbJnG0tLQwadKkrMMwy79KawV5rkXkeYb5SGPLMCk3TeIwsyaU5xpPnmMbQtP0cZiZWXW4xmFmlSnVp1GOPDcdWdmcOMysMiPpu6jj5hk7yInDrBGNZKRTtWoFrkU0LCcOs0Y0kpFOI6kVDGONJKs/mXaOS7pd0guSninxuiR9QdImSU9J+v20YzQzs0NlXeO4g8JGTXeWeP08YHLymA78TfKnmVn1eBJjRTKtcUTEg8DAHY0Omg3cGQUPA2MkjUsnOjMrqlTfRT33aeR5EmMOZV3jGEo7sLnPcXdybks24ZiZv4Fbw0wAlDRPUpekLq9HZU2vEWsFlht5r3H0AOP7HHck5waIiGXAMijsAFj70MxyzLUCq6G81zhWAFcko6tOB16KCDdTmVl6PtWedQS5k2mNQ9LXgbOAsZK6gY8DLQARsRS4F5gFbAJ2An+STaRmDa7ZRxWVmvQI7iAvItPEERGXDvF6AB9KKRyz5lXtm2a9JaLre2DRUVlHUTfy3lRlZvXI394bmhOHmZlVxInDzMwqkvfhuGaWV4P1Y2RpuP0r3iukbE4cZja8m2Ze+zEGi2vRUaUTSB477XPKicPMqn/TzPO396wTWwNw4jCz6vO394bmznEzM6uIE4eZmVXEicPMhievK/Bmff0m4D4Os2ZSzaVA8tqP0RtXXocLNwAnDrNmktchtLWQ18TWAJw4zFK2fG0PS1Zt4Pntuzh2TCvzZ07hwmleutvqhxOHWYqWr+3hunueZteefQD0bN/Fdfc8DeDkYXXDicMsRUtWbTiQNHrt2rOPJas2NHfiyGoZ9npb/j0nnDjMUvT89l0Vnc+dWt1os+p7aaY+nyrKdDiupHMlbZC0SdKCIq/PlbRV0hPJ4/1ZxGlWLceOaa3ofNWNdAitb7RGhjUOSaOALwFvBbqBxyStiIhn+xX9ZkRcmXqAZjUwf+aUQ/o4AFpbRjF/5pR0Aki7+cVNQQ0pyxrHacCmiHguIl4FvgHMzjAes5q7cFo7fzHn92gf04qA9jGt/MWc32u
M/o1FRxUSRV+uoTSkLPs42oHNfY67gelFyl0k6Uzgp8A1EbG5SBkkzQPmAUyYMKHKoZpVz4XT2hsjURTjhNAU8r7kyD8AEyPiZOA+4CulCkbEsojojIjOtra21AI0syrIavmSvC6bknNZ1jh6gPF9jjuScwdExLY+h38LfCaFuMyslFL7bIxUVv0d7mcZliwTx2PAZEmTKCSMS4DL+haQNC4itiSHFwDr0w0xW55hbLnTe6NddFS2cVimMkscEbFX0pXAKmAUcHtErJN0I9AVESuAqyRdAOwFXgTmZhVv2jzD2CqVyy8aed4J0IZNEZF1DFXX2dkZXV1dWYcxIjMWr6anyKSw9jGtPLTg7Awisjzr/0UDCsN8azZiK0/DbGsZS55+zhqTtCYiOssp65njOVX3M4wtVakvZZKnm2Yth/x6OHFRThw5deyY1qI1jtRmGFtdaZgvGk30Db+e5X04btOaP3MKrS2jDjmX6gzjOrN8bQ8zFq9m0oKVzFi8muVrm+smk/lSJtXib/h1wYkjpxp6hnGV9bbv92zfRXBwIEEzJQ9/0bA0uakqxxp6hnEVeanygyPtcjeqyhqSE4fVvYZp3x8hf9GoAQ8nLsqJw+qeBxLUgVp2etfy5u4O+aKcOKzuZb5UuQ2t3E7v4SQB39xTVzJxSBoPLKGwiu33gCURsSd5bXlEXJhOiGaDc/t+A3ESqAuD1ThuB74NPAy8D/iRpPOThQffmEZwZuVy+75ZegZLHG0RsTR5/mFJ/wN4MFk7qvHWKTEzs7IMljhaJI2OiN0AEfF/Jf0bhUUJX5tKdGZmI+XZ6FU32ATAv6XfjnwRcT9wMfBMLYMyswaT5YZJno1edSVrHBHxuRLn1wJvrVlEZlWUy6XGm5G/2TcUD8e1huU9TcxqI9O1qiSdK2mDpE2SFhR5/b9I+mby+iOSJqYfpdWrwZYiMbPhyyxxSBoFfAk4DzgBuFTSCf2KvQ/494j4beBzwKfTjdLqmZciMauNIROHpDdIuk3S95LjEyS9rwrXPg3YFBHPRcSrwDeA2f3KzAa+kjz/e+AcSarCta0JNMxS4zYyWXbMN6hy+jjuAP4OWJgc/xT4JnDbCK/dDmzuc9xNv1Fcfcske5S/BBwN/GqE17Ym4KVIDHDHfA2U01Q1NiK+BeyHwg0c2Df4X0mfpHmSuiR1bd26NetwLAe8p4lZbZRT4/i1pKNJZotLOh14qQrX7gHG9znuSM4VK9Mt6TXAUcC2Ym8WEcuAZQCdnZ2e2V5Hajlk1kuRmFVfOYnjfwMrgN+S9BDQBryzCtd+DJgsaRKFBHEJcFm/MiuA9wD/klxzdUQ4KTQQD5kdnOehWB4NmjgkHQaMBt4MTAEEbOhdJXckkj6LKyksYTIKuD0i1km6EeiKiBUU+lG+KmkT8CKF5GINxLv3leakank1aOKIiP2SvhQR04B11b54RNwL3Nvv3Mf6PN9NYYkTa1AeMluak6rlVTmd4z+UdJGHwVoteMhsaU6qllflJI7/CdwN/IeklyXtkPRyjeOyJjF/5hRaW0Ydcs5DZgucVC2vhkwcEXFkRBwWEYdHxOuS49elEZw1Pg+ZLc1J1fJqyFFVks4sdj4iHqx+ONaMPGS2OG+Ja3lVznDc+X2ej6awVMga4OyaRGRWI/U4tNVJ1fJoyMQREef3PZY0HvjrmkVkVgMe2mpWPcNZHbcbOL7agZjVkpdYN6uecvo4vkiy3AiFRHMK8Hgtg7J01GPTzXB5aKtZ9ZTTx9HV5/le4OsR8VCN4rGUNFvTzbFjWukpkiQ8tNWscuU0VY2JiK8kj7si4iFJV9c8MqupZmu68dBWs+opJ3G8p8i5uVWOw1LWbE03ni9iVj0lm6okXUphtdpJklb0eelICgsOWp3p26dxmMS+IgsNN3LTjYe2mlXHYH0cPwa2AGOBv+xzfgfwVC2Dsurr36dRLGm46cbMylEycUTEL4BfAGekF47VSrE+DYBREvsjGn5
UlZlVTznDcU8Hvkhh7sbhFPbO+LXXq6ovpfou9kfwr4v/KOVozKyelTMc92YKGyjdDXQCVwC/U8ugstDocxo8HNXMqqWsmeMRsQkYFRH7IuLvgHNHclFJr5d0n6SNyZ+/WaLcPklPJI8VxcpUQ2/7f8/2XQQH5zQsX9t/C/T65eGoZlYt5SSOnZIOB56Q9BlJ15T59wazAPhhREwGfpgcF7MrIk5JHheM8JolNcOcBg9HNbNqKaep6nIKieJK4BpgPHDRCK87Gzgref4V4AHg/4zwPYetWeY0eDiqmVVDOavj/kJSKzAuIv68Std9Q0RsSZ7/G/CGEuVGS+qisNTJ4ohYXuoNJc0D5gFMmDChomDc/m9mVr4hm5wknQ88AXw/OT6lnP4GSfdLeqbIY3bfchERHFxEsb83RkQnhYmIfy3pt0pdLyKWRURnRHS2tbUNFd4h3P5vZla+cpqqFlHYvOkBgIh4QtKkof5SRLyl1GuSfilpXERskTQOeKHEe/Qkfz4n6QFgGvCzMmKuiHdaMzMrXzmJY09EvCSp77lSNYRyraCwBtbi5M/v9i+QjLTaGRH/IWksMAP4zAivW5Lb/+tXow+lNsubchLHOkmXAaMkTQauorAcyUgsBr4l6X0UZqe/C0BSJ/CnEfF+ChMOb5G0n0KT2uKIeHaE17URyNMNujeWnu27EAe/yTT68vDVlKd/T6sviiJrFh1SQPoNYCHwtuTUKuCmiNhd49iGrbOzM7q6uoYuaGXrv9YVFPqBshjSWyyW/trHtPLQgrMPlPcN8lB5+ve0fJC0JulTHlLJznFJX02efiAiFkbEm5LHDXlOGlYbeZrrUmrdrb56h1IPZ3Ln8rU9zFi8mkkLVjJj8eqGmgjaK0//nlZ/BmuqOlXSscB7Jd0JHNrJEeGl1ZtInua6lHPNwyQmLVhZdPn43htksW/WzbIzYp7+Pa3+DDYcdymFWd2/C6zp93A7UJMpNacli7ku5VxzXwRB8eXjofQNslm+iefp39PqT8nEERFfiIjjgdsj4riImNTncVyKMVoO5GmuS7FYeqvDow4d/VdSqRtks3wTz9O/p9WfIScARsQH0wjE8i1Pa10Vi+Vzf3wKP1/8R+wfYrAHDH6DbJZv4nn697T6M+SoqnrkUVXNa8bi1UWXjyl3wyqPNrJmVcmoqnLmcZilohrDZufPnDKiG79XETAbmhOH5UK1RjNV48Y/nFUE+k5IHJWM5Gp30rEG5cRhuTDYaKZKb7xpLx/TP+n1juRq1KG8Zk4cVpZaz74uNWqpWH9F3gw2IXG4yc8sz0a6k581gTS21i01aknJ9fNsqKG6jTaU18yJw4aUxqS4+TOnUGwGRiTXz7Ohhuo22lBeMycOG1Iak+IunNZecq3+vH9jLzaZrpcn1VkjcuKwIaU1Ka69Tiff9Z1MBwdnr3tSnTUqd47bkErNjaj2N+m0rlML3gjMmkkmNQ5JF0taJ2l/snlTqXLnStogaZOkBWnGaAeltTyFl8Ewqw+ZLDki6XhgP3ALcG1EDFgfRNIo4KfAW4Fu4DHg0nJ2AfSSI5ZX3lTK8ir3S45ExHoADb6S6WnApoh4Lin7DWA24O1jraQ835grmR2f55/DLM99HO3A5j7H3cD0jGLJBd9MBpf3TZjKnR0/3J/Dvx+Wlpr1cUi6X9IzRR6za3S9eZK6JHVt3bq1FpfIVBqT8Opd3jdhKndY83B+Dv9+WJpqljgi4i0RcVKRx3fLfIseYHyf447kXKnrLYuIzojobGtrG0nouZT3m2Ie5H0TpnKHNQ/n5/Dvh6Upz/M4HgMmS5ok6XDgEmBFxjFlJu83xZFYvraHGYtXM2nBSmYsXj3sb8l534Sp3F33hvNzNPLvh+VPVsNx3yGpGzgDWClpVXL+WEn3AkTEXuBKYBWwHvhWRKzLIt48yPtNcbiq2cSS9+1Qyx1uPJyfo1F/PyyfvANgnainnemWr+3hz/9hHf++cw8AY1pbWHT
BiUXjLLVjX/uYVh5acPawrt0IHcSV/hz19Pth+ZT74bhWuXrZmW752h7m//2T7Nl38AvJ9l17mH/3k8DAUUHVbmIpNYO73hJKpTPR6+X3wxqDE0cdqYdlLZas2nBI0ui1Z38c6Kjte3M7qrWF7bv2DChfzSaWvA/TrZZ6+P2wxpDnznGrQ4PVFHpv2H37M3796l5aDjt0Imi1+yU84sisupw4rKoGqymMkgbcwPfsC44Y/Zqark/lEUdm1eWmKquq+TOnDOjjAGg5TOzZX3wgxvade1j7sbfVLKZjx7QW7YD3iCOz4XGNw6rqwmntLHnnVH7zN1oOnBvT2sKSi6dmtt9G3ofpmtUb1zis6gbrpM1ivw2PODKrLicOS02WN3CPODKrHicOS5Vv4Gb1z4nDrI96myholgUnDrNEs0wUNBspj6oyS3iioFl5nDjMEp4oaFYeJw6zhJcmNyuPE4dZwhMFzcrjznGzhCcKmpUnk8Qh6WJgEXA8cFpEFN11SdLPgR3APmBvuZuMmA2X55mYDS2rGsczwBzgljLK/mFE/KrG8ZiZWZkySRwRsR5A0lBFzcwsZ/LeOR7ADyStkTRvsIKS5knqktS1devWlMIzM2s+NatxSLofOKbISwsj4rtlvs1/i4geSf8VuE/STyLiwWIFI2IZsAygs7Oz+MYPZmY2YjVLHBHxliq8R0/y5wuSvgOcBhRNHGZmlo7cNlVJeq2kI3ufA2+j0KluZmYZyiRxSHqHpG7gDGClpFXJ+WMl3ZsUewPwz5KeBB4FVkbE97OI18zMDspqVNV3gO8UOf88MCt5/hwwNeXQzMxsCJ45bqnyfhdm9c+Jw1Lj/S7MGkNuO8et8Xi/C7PG4MRhqfF+F2aNwYnDUuP9LswagxOHpcb7XZg1BneOW2q834VZY3DisFR5vwuz+uemKjMzq4gTh5mZVcSJw8zMKuLEYWZmFXHnuFkRXlPLrDQnDrN+vKaW2eDcVGXWj9fUMhtcVhs5LZH0E0lPSfqOpDElyp0raYOkTZIWpB2nNSevqWU2uKxqHPcBJ0XEycBPgev6F5A0CvgScB5wAnCppBNSjdKaktfUMhtcJokjIn4QEXuTw4eBjiLFTgM2RcRzEfEq8A1gdloxWvPymlpmg8tD5/h7gW8WOd8ObO5z3A1MTyUiG1Ijjzrymlpmg6tZ4pB0P3BMkZcWRsR3kzILgb3AXVW43jxgHsCECRNG+nY2iGYYdeQ1tcxKq1niiIi3DPa6pLnA24FzIiKKFOkBxvc57kjOlbreMmAZQGdnZ7H3syoZbNSRb7ZmjS+rUVXnAh8FLoiInSWKPQZMljRJ0uHAJcCKtGK00jzqyKy5ZTWq6mbgSOA+SU9IWgog6VhJ9wIknedXAquA9cC3ImJdRvFaHx51ZNbcMukcj4jfLnH+eWBWn+N7gXvTisvKM3/mlEP6OMCjjsyaSR5GVVmd8agjs+bmxGHD4lFHZs3LicPK1shzN8ysfE4cVpZmmLthZuXx6rhWFq8Ya2a9nDisLJ67YWa9nDisLJ67YWa9nDisLF4x1sx6uXPcyuK5G2bWy4nDyua5G2YGbqoyM7MKOXGYmVlFnDjMzKwiThxmZlYRJw4zM6uIE4eZmVVExbf7rm+StgK/yDqOQYwFfpV1EDnjz2QgfyYD+TMZqFqfyRsjoq2cgg2ZOPJOUldEdGYdR574MxnIn8lA/kwGyuIzcVOVmZlVxInDzMwq4sSRjWVZB5BD/kwG8mcykD+TgVL/TNzHYWZmFXGNw8zMKuLEkQFJSyT9RNJTkr4jaUzWMWVN0sWS1knaL6mpR81IOlfSBkmbJC3IOp48kHS7pBckPZN1LHkhabykf5T0bPJ/5+q0ru3EkY37gJMi4mTgp8B1GceTB88Ac4AHsw4kS5JGAV8CzgNOAC6VdEK2UeXCHcC5WQeRM3uBj0TECcDpwIfS+l1x4shARPwgIvYmhw8DHVnGkwcRsT4iNmQdRw6cBmy
KiOci4lXgG8DsjGPKXEQ8CLyYdRx5EhFbIuLx5PkOYD2QyoY5ThzZey/wvayDsNxoBzb3Oe4mpZuB1S9JE4FpwCNpXM87ANaIpPuBY4q8tDAivpuUWUihunlXmrFlpZzPxMwqI+kI4NvAn0XEy2lc04mjRiLiLYO9Lmku8HbgnGiSMdFDfSYGQA8wvs9xR3LObABJLRSSxl0RcU9a13VTVQYknQt8FLggInZmHY/lymPAZEmTJB0OXAKsyDgmyyFJAm4D1kfEX6V5bSeObNwMHAncJ+kJSUuzDihrkt4hqRs4A1gpaVXWMWUhGTRxJbCKQmfntyJiXbZRZU/S14F/AaZI6pb0vqxjyoEZwOXA2cl95AlJs9K4sGeOm5lZRVzjMDOzijhxmJlZRZw4zMysIk4cZmZWEScOMzOriBOHWRGSrpK0XlLFs/olTZR0WS3iSt7/TEmPS9or6Z21uo5ZKU4cZsX9L+CtEfHuYfzdiUDFiSNZGbcc/x+YC3yt0muYVYMTh1k/yYTM44DvSbpG0muT/SAelbRW0uyk3ERJ/5R8+39c0h8kb7EY+O/JhKxrJM2VdHOf9/9/ks5Knr8i6S8lPQmcIelUST+StEbSKknj+scXET+PiKeA/TX+KMyK8lpVZv1ExJ8my8L8YUT8StKngNUR8d5k061HkwUbX6BQK9ktaTLwdaATWABcGxFvhwPrkpXyWuCRiPhIsu7Qj4DZEbFV0h8Dn6SwgrJZbjhxmA3tbcAFkq5NjkcDE4DngZslnQLsA35nGO+9j8IidQBTgJMoLEUDMArYMoK4zWrCicNsaAIu6r/RlKRFwC+BqRSafXeX+Pt7ObRZeHSf57sjYl+f66yLiDOqEbRZrbiPw2xoq4APJ6uRImlacv4oYEtE7Kew2Fxv5/YOCotY9vo5cIqkwySNp7DLXzEbgDZJZyTXaZF0YlV/ErMqcOIwG9ongBbgKUnrkmOALwPvSTq2fxf4dXL+KWCfpCclXQM8BPwr8CzwBeDxYhdJtop9J/Dp5D2fAP6gfzlJb0pWEr4YuCWJySw1Xh3XzMwq4hqHmZlVxInDzMwq4sRhZmYVceIwM7OKOHGYmVlFnDjMzKwiThxmZlbOp1OnAAAADUlEQVQRJw4zM6vIfwIA1FhR7S4AaAAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.scatter(X_train[y_train==0, 0], X_train[y_train==0, 1], label='class 0', marker='o')\n", "plt.scatter(X_train[y_train==1, 0], X_train[y_train==1, 1], label='class 1', marker='s')\n", "plt.xlabel('feature 1')\n", "plt.ylabel('feature 2')\n", "plt.legend()\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Defining the Perceptron model" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", "\n", "\n", "def custom_where(cond, x_1, x_2):\n", "    \"\"\"Select x_1 where cond is true and x_2 elsewhere via arithmetic masking.\n", "\n", "    Args:\n", "        cond: Mask tensor, e.g. the result of an elementwise comparison.\n", "        x_1: Value(s) used where the mask is 1.\n", "        x_2: Value(s) used where the mask is 0.\n", "\n", "    Returns:\n", "        The tensor mask * x_1 + (1 - mask) * x_2, where mask is cond cast to float32.\n", "    \"\"\"\n", "    # Comparisons yield bool tensors in PyTorch >= 1.2, and `1 - cond` raises a\n", "    # RuntimeError for bool tensors, so cast the mask to float before the arithmetic.\n", "    mask = cond.float()\n", "    return (mask * x_1) + ((1 - mask) * x_2)\n", "\n", "\n", "class Perceptron():\n", "    \"\"\"Rosenblatt perceptron for binary classification with 0/1 class labels.\n", "\n", "    Attributes:\n", "        num_features: Number of input features.\n", "        weights: float32 tensor of shape (num_features, 1), initialized to zeros.\n", "        bias: float32 tensor of shape (1,), initialized to zero.\n", "    \"\"\"\n", "\n", "    def __init__(self, num_features):\n", "        self.num_features = num_features\n", "        self.weights = torch.zeros(num_features, 1,\n", "                                   dtype=torch.float32, device=device)\n", "        self.bias = torch.zeros(1, dtype=torch.float32, device=device)\n", "\n", "    def forward(self, x):\n", "        \"\"\"Return 0./1. predictions of shape (n, 1) for x of shape (n, num_features).\"\"\"\n", "        # Net input x @ weights + bias, thresholded at zero\n", "        linear = torch.add(torch.mm(x, self.weights), self.bias)\n", "        predictions = custom_where(linear > 0., 1, 0).float()\n", "        return predictions\n", "\n", "    def backward(self, x, y):\n", "        \"\"\"Return the prediction errors y - forward(x).\"\"\"\n", "        predictions = self.forward(x)\n", "        errors = y - predictions\n", "        return errors\n", "\n", "    def train(self, x, y, epochs):\n", "        \"\"\"Apply the perceptron learning rule one training example at a time.\n", "\n", "        Args:\n", "            x: Feature tensor of shape (n, num_features).\n", "            y: Label tensor whose first dimension indexes the n examples.\n", "            epochs: Number of passes over the training examples.\n", "        \"\"\"\n", "        for e in range(epochs):\n", "\n", "            for i in range(y.size()[0]):\n", "                # use view because backward expects a matrix (i.e., 2D tensor)\n", "                errors = self.backward(x[i].view(1, self.num_features), y[i]).view(-1)\n", "                # errors is label minus prediction, so with 0/1 labels the\n", "                # parameters only change for a misclassified example\n", "                self.weights += (errors * x[i]).view(self.num_features, 1)\n", "                self.bias += errors\n", "\n", "    def evaluate(self, x, y):\n", "        \"\"\"Return the fraction of examples in x whose prediction equals y.\"\"\"\n", "        predictions = self.forward(x).view(-1)\n", "        accuracy = torch.sum(predictions == y).float() / y.size()[0]\n", "        return accuracy" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training the Perceptron" ] }, { "cell_type": "code", "execution_count": 6, 
"metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model parameters:\n", " Weights: tensor([[1.2734],\n", " [1.3464]], device='cuda:0')\n", " Bias: tensor([-1.], device='cuda:0')\n" ] } ], "source": [ "ppn = Perceptron(num_features=2)\n", "\n", "# Copy the NumPy training split into float32 tensors on the selected device\n", "X_train_tensor, y_train_tensor = (\n", "    torch.tensor(split, dtype=torch.float32, device=device)\n", "    for split in (X_train, y_train)\n", ")\n", "\n", "ppn.train(X_train_tensor, y_train_tensor, epochs=5)\n", "\n", "# Report the learned parameters\n", "print('Model parameters:')\n", "print(' Weights: %s' % ppn.weights)\n", "print(' Bias: %s' % ppn.bias)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Evaluating the model" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Test set accuracy: 93.33%\n" ] } ], "source": [ "# Copy the held-out split into float32 tensors on the same device as the model\n", "X_test_tensor, y_test_tensor = (\n", "    torch.tensor(split, dtype=torch.float32, device=device)\n", "    for split in (X_test, y_test)\n", ")\n", "\n", "# evaluate returns the fraction of correct predictions; print it as a percentage\n", "test_acc = ppn.evaluate(X_test_tensor, y_test_tensor)\n", "print('Test set accuracy: %.2f%%' % (test_acc*100))" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAa4AAADFCAYAAAAMsRa3AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAIABJREFUeJzt3Xl4VNX5wPHvyWRCErJBErYECPuWEELCGtzQVqoCAVllETeUzR1/WFpF3EttVXYoiCAqoBBEqlhLrWWVBMIOsgkkUMFgAEmQ7f7+SEIDZJJZ7sy9M/N+nofnkUly5yXOmXfOed9zrtI0DSGEEMJbBBgdgBBCCOEISVxCCCG8iiQuIYQQXkUSlxBCCK8iiUsIIYRXkcQlhBDCq0jiEkII4VUkcQkhhPAqkriEEEJ4lUAjnjQmJkZLSEgw4qmF0EV2dvZPmqbFGh1HKRlTwhfYO64MSVwJCQlkZWUZ8dRC6EIpddjoGMqSMSV8gb3jSpYKhRBCeBVJXEIIIbyKJC4hhBBexZAalxDC/S5evEhubi7nz583OhRTCQ4OJj4+HqvVanQowkk+lbj+tecE//7+JOPvboHVIpNJ4d9yc3MJDw8nISEBpZRT1zj36yV+PneBOtVCCHDyGmaiaRr5+fnk5ubSoEEDo8MRTvKpd/ecowXMW/cDg2Zv5OTZX40ORwhDnT9/nujoaKeTFkDhhcucKrzAoZ/OcenyFR2jM4ZSiujoaJmFejmfSlxP/aYp7wxow7a8ArpPXkPO0QKjQxLCUK4kLYDY8CrUrR5K4YXLHDh5jl8vXdYpMuO4+jsRxvOpxAXQs00cn47oTKBF0W/GehZvOmp0SEJ4tWqhQTSMqcqlK1c4cOIc5369ZHRIws/5XOICaFUnkhWju9C+QXWe+3Qbf8zcwYVL3r/MIYRRqlYJpHFsGJYAOPjTOQoKLzh9rQkTJvDnP/9Zx+j+Jzs7m6SkJBo3bszjjz+OpmlueR5hLJ9MXADVqgYx74F2PHpzQxZsOMx9szdw4qysawthS+aWPNLfWE2DcStJf2M1mVvyrvl6FauFRrFhhFotHDlVyIkz502XGEaMGMHs2bPZt28f+/bt48svvzQ6JOEGPpu4AAItATx/VwsmD0xh57EzdJ+8hi1HfjY6LCFMJ3NLHs8v3U5eQREakFdQxPNLt9+QvAItATSIrUpUaBD/PXOe3J+LuFJB8po/fz6tW7cmOTmZIUOG3PD12bNn065dO5KTk7n33nspLCwEYMmSJSQmJpKcnMzNN98MwM6dO2nfvj1t2rShdevW7Nu375prHT9+nDNnztCxY0eUUgwdOpTMzEwXfzPCjHw6cZXqnlyHpSM7ExQYQP+ZG1i06YjRIQlhKpNW7aXo4rWNF0UXLzNp1d4bvjdAKepWC6FmRDA/F17gBxsdhzt37uSVV15h9erVbN26lXfeeeeG7+nduzebNm1i69attGjRgjlz5gAwceJEVq1axdatW/nss88AmDFjBk888QQ5OTlkZWURHx9/zbXy8vKueSw+Pp68vGsTr/ANfpG4AFrUjmDF6C50aFid//t0O+OXbZe6lxAljhUUOfS4UoqaEcHUrRbKORsdh6tXr6Zv377ExMQAUL169Ruus2PHDm666SaSkpJYuHAhO3fuBCA9PZ1hw4Yxe/ZsLl8uvm6nTp147bXXePPNNzl8+DAhISFO/3uFd/ObxAUQFRrEvAfaM+LWRizceISBszdw4ozUvYSoE1V+ErD1eKlqVYNo4ELH4bBhw5gyZQrbt2/nxRdfvLq/asaMGbzyyiscPXqU1NRU8vPzue+++/jss88ICQnhrrvuYvXq1ddcKy4ujtzc3Kt/z83NJS4uzqF4hHfwq8QFYAlQ/F+35ky9ry27jp3hnslryD4sdS/h38be2YwQq+Wax0KsFsbe2azSnw2rEkij2DACrus47Nq1K0uWLCE/Px+AU6dO3fCzZ8+epXb
t2ly8eJGFCxdeffzAgQN06NCBiRMnEhsby9GjRzl48CANGzbk8ccfp2fPnmzbtu2aa9WuXZuIiAg2bNiApmnMnz+fnj17Ovy7EObnd4mr1N2ta7NsVGdCgiwMmLWeDzdK3Uv4r4yUOF7vnURcVAgKiIsK4fXeSWSk2DdjCbZaaHxdx2HLli0ZP348t9xyC8nJyTz99NM3/NzLL79Mhw4dSE9Pp3nz5lcfHzt2LElJSSQmJtK5c2eSk5NZvHgxiYmJtGnThh07djB06NAbrjdt2jQefvhhGjduTKNGjfjd737n9O9EmJcyop01LS1NM8tN704XXuTxj7fw7+9PMrB9XSb0aEWVQEvlPyj8mlIqW9O0NKPjKFXemNq9ezctWrTwaBxXrmjk/lxEQdEFqocGmfaMQyN+N6Jy9o4rv51xlYoMtTJ3WDtG3tqIj747yoBZG/hR6l5COCUgQFG3egg1woM5VdpxeEWaoIS+/D5xQXHd67luzZk2qC17/3uWeyavIeuHG9fjhRCVU0pRKzKY+NKOwxPnuOADZxwK85DEVcZdSbXJHJVO1SALA2dv4IMNh013MoAQ3qJ61SAaRIdy6coV9p84R6GccSh0IonrOk1rhrN8dBe6NI7hD5k7GPfpdp84EVsII4QFW6/pODztwhmHQpRyOXEppeoqpf6llNqllNqplHpCj8CMFBliZc797RjTtTGLso7Sf+YG/nta6l5COKO04zDYauHwqUJOnDXfGYfCu+gx47oEPKNpWkugIzBKKdVSh+saKiBA8cxvmzFjcCr7fiyue313SOpeQjgj0BJAw5iqRIVY+e/p8+QVVHzGoRAVcTlxaZp2XNO0zSX/fRbYDfjMdvVuibXIHJVORHAg983ewPz1P8inReFWvriKAaUdh6HMnzKJv771Fj/8dI7LOnccjh8/nrp16xIWFqbrdYW5BOp5MaVUApACbCzna8OB4QD16tXT82ndrknNcDJHp/PUxzm8sHwn23NP83JGIsFW2e8l3KJ0FWOzUiocyFZK/UPTtF1ue8bX4uDCLzc+HhQGv9fvoFqlFGHBVgKCrJz7tfiMw4ToUIJ02jvZvXt3Ro8eTZMmTXS5njAn3ZozlFJhwKfAk5qmnbn+65qmzdI0LU3TtLTY2Fi9ntZjIoKtzB6axuO3N2FJdi79Z663eQCpEK4wZBWjvKRV0eN2snVbk9CgQBrEhPLR/Pdom9aOpNatXb6tCUDHjh2pXbu2SzEL89NlxqWUslKctBZqmrZUj2uaUUCA4unfNCWxTgRPL95KjylrmHpfWzo0jDY6NOGjvHkVo/S2JuvWrSMmJuaGswrDgq08PGQA/QcP49IVjffffYM5c+YwZsyYq7c1iYuLo6CgAPjfbU0GDRrEhQsXrp4aL/yPHl2FCpgD7NY07S+uh2R+v21VUvcKsTLobxuZt/aQ1L2E7rx9FcOe25rs37ubh/vezb13dObDjz4kK2cbmqbJbU1EhfRYKkwHhgBdlVI5JX/u0uG6pta4RhiZo9K5tVkNJqzYxbNLtnH+oh9/AnwtDiZE3vjnNZ/p0/Eof1nFGDZsGFOnTmH3zh089dzvKThzjryCIqZPn+7wbU2E/3B5qVDTtDWA+U7R9ICIYCuzhqQyefV+/vr193z/41lmDEklrpJ7GHk1W0V8Wy78UvwzOhb4fZ2vrGJ07dqVXr168fTTTxMdHc2pU6dumHWV3tbk8uVLrFr+CdVja3Hq3AX27z/ATe3a0aFDB7744guOHj3K6dOnr97W5MiRI2zbto2uXbsa9K8TRpKTM1wUEKB44o4m/G1oGj/8dI7uk9ew/kC+0WG5jzPFehcL/H7I86sYQTbax209bodWrVo5fFuTkCAL8dVCeOWF39OyVRKtHLytyXPPPUd8fDyFhYXEx8czYcIEp+N3C1mZ0IXf39ZETwdO/sLw+Vn8kF/I+Lta8EB6AsqEt3RwyYRI/a6lc6u1J8ltTdzr7PmLHMkvRAU
oEqJDCQ3SdeeOcb+bisbPhNOei8Ok5LYmBmgUW1z3ur15DSZ+votnFm/177pXZWQmJmwID7bSqEYYAcDBk+c4XXTR6JCEiUji0ll4sJUZg1N55jdNWZaTx73T15H7c6HRYQnhdYKtFhrVKDnjMP8cJ8/+Kt27ApDE5RYBAYoxtzdhzv1pHMkvpPvkNazb/5PRYQk/5O1v9NaSMw4jQ6wcP13EsQLXD+j19t+JkMTlVl2b12T56HRiwqowZO53/O0/B71/0FRUxHehkC/0FxwcTH5+vte/5gICFPWqhxIbXoX8c7/yQ34hl68492/SNI38/HyCg4N1jlJ4kr4VT3GDhrFhLBuVzrOLt/LKyt1szzvNG71bExLkpeccOtNMoWdDh7BbfHw8ubm5nDx50uhQdHPh10scKLzIYYsiJqwKlgDHm5+Cg4OJj493Q3R2CAqzfSakh86L9AWSuDwgrEog0we3Zdo3B/jzV3vZ9+MvzBySSt3qoUaH5hm2BiuUn9RkoOrCarXSoEEDo8PQ3b+/P8mohZsJDbIwd1g7EuO86INRRa9rWx/wzNjEZHCSlaVCD1FKMeq2xsy9vx1Hfy6kx5Q1rNnnJ3Wv3+cVt/pe/8cWMw5UYRq3NI3l0xGdsVoC6DtjPV/v+tHokPyPmw5ltpckLg+7rXkNVozuQmx4FYbO3cjsb32g7iWEhzWrFc6ykZ1pUjOMRxZk8d7aQ0aHJEp5YGO1JC4DJMRUZdnIdLol1uLVv+/m8Y9zKLxwyeiwhPAqNSKC+Xh4R37ToiYvrdjFi8t3cOmyvjemFC5y0wxMEpdBqlYJZOp9bXmuWzM+33aM3tPWcfSU7PcSwhGhQYFMH5zKIzc14P31hxm+IJtzv8qHQF8nzRkGUkox8tbGtKwdweMfbaH7lDVMHpjCTU3Md4sKnyWdXF7PEqAYf3dL6kVX5cXlO+g7Yz1zhqVRO9LLDruuqONQXEMSlwnc2qwGK8Z04dEF2dw/9zue69acR29u6HvnHF7PkW5DZxNJZYnJ4CKz0M+QjvWJrxbC6IWbyZi6ljn3e0HHobd+cKpo7HqAJC6TqB9dlaUjOzP2k2288cUetuedZlKf1rofLuox9gzI8gam3i3Bkpj8ym3NavDJiM48OG8T/WauZ/LAFG5vUdPosGzz1tfn9WPXw3s1vfRd0TeFBgUyZWAKSXGR/OnLPRw4Ubzfq350VaNDc5y3Dkjh9VrUjiBzVDoPvb+JR+Zn8WL3VtzfOcHosBy/l5038fAypyQuk1FK8dgtjWhZO4IxH22hx5S1vDswhVuaGlj3qmj2BL47GIXXqhkRzOJHO/H4Rzm8+NlOfsg/xx/ubunUSRu68eVx4uFlTUlcJnVz01hWjO7C8AVZDHvvO8be2YwRtzQypu4lsyfhhUKDApk5JJVXV+5m7tpDHD1VyDsDUqhaxY/f9ry1pnYdaYc3sXrRoSwd2Zl7WtfhT1/uZdSHm6XVV29uuPOvMA9LgOKF7i2Z2LMVq/ecoN/M9fx45rzRYRnHRz6E+vFHD+8QGhTIuwPa0Doukte/2M3+E78wa0gaCTFeWPeyh95r5ZVdz4s+ZQrnDe2UQN1qoYz+8H8dhy3rRBgdlm3ywalCkri8gFKKR25uSIvaEYz+aDM9pqzhnYEp3Nashn5PUlkdSy+VXU/vRCKJSZS4rXkNFj/WiYfmZdF3xjqmDGqr7xhyRUVnd9rDR5YA7aVL4lJKzQXuAU5ompaoxzXFjbo0iWHF6OL9Xg/O28Szv23GyFt1qnu5YwnB1cFYys8GpXCfVnUiyRyVzoPzNvHQvE281DORIR3rGx2W63xkCdBeetW45gHddLqWqEDd6qF8OqIzPZLrMGnVXkZ8sJlf3F33cmbWpedMzc8GpXCvWpHBLHmsE7c2q8EfM3fwyue7nL4xpTCGLjMuTdO+VUol6HEtUbmQIAtv929DUlwkr3+xh15T1zJraBoN3FX3cmSjMOg30xLCTapWCWT
20DRe/nwXf1tziMOnCnlnQBv3bvg3w5FOZohBBx6rcSmlhgPDAerVq+epp/VZSikevqkhLWtHMOrDkrrXgDZ0be6hUwJ8ZAAI/2UJUEzo0Yr60aG8/Pku+s/cwJz706gREeyeJzRyWdvHlts9lrg0TZsFzAJIS0uTeblOOjeOuXrO4UPvZ/HUHU0ZfVtjAvTcaGnrLsUys3IbqRt7zgPpDahXPZQxH20hY+pa5j7Qjua1TNxx6IzKltu9LLFJV6EPiK9WXPd6ful2/vKP79mRd5q3+iUTHmy1/yKOHppZ3vd62Yvf5OYBU4D5BsfhF25vUZPFj3biofc30Wf6eqYOaqvvaTXuHhuuroB4WR1ZEpePCLZa+Eu/ZJLiInn177vJKKl7NYq184Vra/A4cniml734zUzqxp6XGFfacZjFg/M2MbFnKwZ10Knj0N1jw88+GOrVDv8RcCsQo5TKBV7UNG2OHtcW9lNK8WCXBrQoqXtlTFnLX/u34Y6Wbqp76XloqKf2kfkQqRvrr3ZkCEse68SYDzczftkODucXMq5bc32X3oXL9OoqHKjHdYQ+OjWKZsWYLjy2IJuH52fx5B1NeLxrE/0Hn54zKZmtOUzqxu4RVtJxOPHzXcz69iCH88/xdv8UQoIs7n9yWW63i5xV6KPiooo/OfZuG8fbX+9j+IJszpy/aHRYQniFQEsAL/VoxQv3tOSrXT8yYNZ6Tpz1wBmH7voA52NnckqNy4cFWy281TeZ1nGRvLyypO41JI3GNRx4sVZU9DV6NuSlg054h9Kl9/hqITzxcQ69pq5j7rB2NKsVbnRojqtstuZl21skcfk4pRTD0hvQvHYEo0puaf6Xfsn8tlUt2z9k73KFvY0berz4/az1XurG5vHbVrVY/GgnHnx/E32mr2Pa4Lbc1MTBjkOzJQYvX5KUxOUnOjYsrnuN+CCb4Quyefz2Jjx5u426lx7LFX6WaPQmdWNzSYov7jh8aN4mhr23iVcyEhnY3oGGGLMlAy+vKUuNywdlbskj/Y3VNBi3kvQ3VpO5pXjQ1IkKYdGjneiTGs+7/9zHI/OzOF3kQt1Lz3Vzky5JCFGqtG7cpXEMzy/dzutf7OaKnHFoCJlx+ZjMLXk8v3Q7RRcvA5BXUMTzS7cDkJESR7DVwqQ+rUmOj+SlFbtK6l6pNKnpxLq9PZ8i7V2S+H1exe311y9LesmShvAt4cFW5tyfxouf7WTmvw9yJL+Qv/ZvQ7BVp45Dsy0pmpQkLh8zadXeq0mrVNHFy0xatZeMlDiguO41pFMCzWpFMHJhNhlT1/JWvzZ0S6yg7lWWI+vjjixJOHKYr5csaQjfE2gJ4JWMRBrEVOXVv+/m+KwNzB6aRmx4FdcvLh/G7CKJy8ccKyiy+/H2DaoX7/f6YDOPfZDNmK6NefKOplT62bGiZOTISRtCeKnSQ67jq4Xy5KIt9Jq2lveGtbNv5cLLGyPMQGpcPqZOVIhDj9eODGHR8I70S4tn8ur9PPz+Jk5bbXRMyXKFENfolliLRcM7cf7iFXpPX8fa/T9V/kNmaIzw8n1dMuPyMWPvbHZNjQsgxGph7J3NbP5MsNXCm/e2pnV8FC+t2EnPqNnMejiNps7UvYTwM8l1o8gc1ZkH523i/rnf8VqvJPq1q2t0WBXz8pmdX8y4bHXZ+aKMlDhe751EXFQIiuJOqNd7J12tb9milGJwx/p89EhHfvn1MhlT1/LF9uOeCVoILxdfLZRPRnSmU6Nonvt0G3/6co90HLqRz8+4Kuuy80UZKXEO/dsyt+QxadVejhUUUScqhNG3NWL51mOMWLiZkbc24pnfNsOi9zmHpUsSFXUSBoVVfEJH2Xqa1AeEwSKCrcwd1o4Xlu9k2jcHOJxfyFv9kvXrOBRX+XzisqfLzp+Vl9jf/HIvL/dsRfNa4Uz75gA7j53h3QEpRIaW3N9Lj+OeLvxS+enyF365cSOzdBkKE7NaAnitVyINYkJ57e97OHa6iNl
D04gJ06HjUFzl80uFjnTZeZoZljBtJfa/fr2P13u35rVeSaw78BM9pq5h73/PFn+DXjMbSTbCBymlGH5zI6YPasuuY2foNW0t+0+c/d83eHljhBn4/IyrTlQIeeUkKVtddp5iliXMyhL7fR3q0axWOCM+yKbXtLVM6pPM3a1rV7xR8vrEJi3ywg/9Lqk2tSKDeWR+Fr2nrWPGkFQ6N4qRJW0d+PyMa+ydzQi5bo25si47T6hoCdOT7GmfT61fjRVjutC8VjijPtzMG1/s4fK43OJlvOv/QHGiKvtHCD+VUq8ay0amUzMimKFzvmNJ1lGjQ/IJPp+4nO2yczezLGHam9hrRgTz8fBODOpQjxn/PsCw976joPDCjReU5T8hrlG3enHHYceG0Yz9ZBt/XrVXOg5d5PNLheB4l50nmGUJs/T3UrarcOydzcr9fQUFBvBqrySS4iJ5YflOuk9Zw6whabSoHeF8ABU1epS35i9nuQkvFBli5b0H2vGHZTuY8q/9HD5VyKQ+raXj0El+kbjMyJmNwu7iaGIf0L4eTUvqXr2nrePNPq3pkVzH8Sd2poVd6gPCS1ktAbxxbxIJMVV588s9HCsoYtaQVKKl49BhkrgMUt5M57bmsUxatZenFuVUOPMxg7b1iuteIz/YzOMfbWFn3mnG3tms4heUPffoknPchA9TSjHi1kbUqx7KU4tz6D29+K7KjWJlxcARkrgMVHamY5YuQ0fUCA/mw0c6MvHzncz89iA7j51hshZGNeVCncsM57gJ4WZ3ty7uOBxe0nE4c0gqHRtGGx2W19ClOUMp1U0ptVcptV8pNU6Pa/obs3QZOiooMIBXMpJ4894kvjt0iu4XXmPnlfrlfKN8ohSirNT6xR2HMWFBDJmzkU+zc40OyT6vxd3YOTwhsvhxD3F5xqWUsgBTgd8AucAmpdRnmqbtcvXa/sQsXYbO6t+uHs1qRfDYgmzuLXqTNzNa07ONOWeKQphFvehQlo5IZ8TCbJ5ZspXDpwp56o4mKKXzEWt6MsGqiB4zrvbAfk3TDmqadgH4GOipw3X9iqO3IzGjNnWjWDGmC63jonji4xxeXbmLS5evGB2WEOZUMnOJ/FMM83Lvoq/lG9795z6e/OOL/HrpcuU/78f0SFxxQNlddbklj11DKTVcKZWllMo6efKkDk9rbo4e52TWjdKOig2vwsJHOnB/p/rM/s8hhs79jlPnytnvJYS/KzNDCVKX+VPgLMYGLmL5pQ4M/ttGGTcV8NgGZE3TZmmalqZpWlpsrI0bFfqI0kaLvIIiNP7XaFFR8vLERmlPnY1otQTwUs9EJvVpTdbhn+k+eQ078uzoKAQ5x034DEfHm1IwKnA5k63vsjX3NL2nreXQT+c8FK130aOrMA8oe9e0+JLH/JazJ9K7c6O0EV2LfdPq0qxWOI8uyObe6et4897WlT+XtLwLH+DKeOtu2UCdBzrwyPzi80FnDk6lg3QcXkOPGdcmoIlSqoFSKggYAHymw3W9lhkbLYzqWmwdX1z3Sq4bxZOLcpi4Qupe9pBOXe/m6nhLrV+dZSM7U71qEIPnbGTZFhN1HJpgVcTlGZemaZeUUqOBVYAFmKtp2k6XI/NiZjnOqSwjk2lMWBUWPtyBV1fuZu7aQ+w6fpqp97WVEwNskE5d76fHeKsfXZVlI9J59IMsnlq0lcP5hTxxuwk6Dk2wKqJLjUvTtL9rmtZU07RGmqa9qsc1vZkrjRbuqkMZ3bVotQQwoUcr3uqbzJYjBXSfvIbtuXbWvfyPdOp6ObvGmx0zl8hQK/Mf7MC9beN5++t9PL14q3QcIidnuIUjB9eW5c46lFnORrw3NZ6mNcN5dEEWfWas4/XeSfRuG+/RGLxAeZ26Ha7/JqXUcGA4QL169TwTmbCLXePNzplLUGAAf+7bmoToUN76x/fklZxxGBUapHfYXkNpmueP109LS9OysrI8/rz2ytyS53DS0UP6G6vLXWKMiwph7biuLl/
fqH9XefJ/+ZVRH25mw8FTDOucwPi7W2C1eM9ddpRS2Zqmpbnp2n2AbpqmPVzy9yFAB03TRtv6GbOPKX/kjvG2PCePsUu2EVcthPeGtSMhpqpO0ZqDveNKZlzXMfLMQHfXocx0e5fosCp88FAHXv9iD3PWHGL38TNMHdSWGKl7gXTq+gR3jLeebeKoExXC8PlZ9Jq2lllD02iXUF3X5/AG3vMR1wMyt+TxzOKthp0ZaHQdytMCLQH88Z6WvN2/DTlHi+te23ILjA7LDKRTV9jULqE6y0amExUaxKDZG1me43+faWTGVaJ0pnXZxtKpJ7rvzFKH8rSMlDga1wjj0QXZ9JmxnlczEumbVrfyH/RRvtypa6blatOy49Y+CTFVWTqiM49+kM0TH+dwJL+Q0V0bG99x6CGmnXF56pSHUuXtuyjLE7MeT5yeYVaJcZGsGNOFtPrVGPvJNl5cvoOLfrzfyxc7dZ05UcYv2XmIbbWqQSx4qD29UuJ46x/f8+ySbVy45B9jxpQzLiPqTBXNqDw56zFTHcrTqlcNYv6D7Xnzyz3M/s8hdh8/y9RBbYkNl7qXL3D2RBlhW5VAC3/pl0z96FDe/nofeQWFzBycRmSo1ejQ3MqUMy4jTnmwNaOyKOU3sx4zCLQEMP7ulrwzoA3b8orrXjlHpe7lC8x4oowvUErx5B1N+Wv/ZDYfLqDX9LUczvftMw5NmbiMeIHb2jT8Vr9kSVoG6Nkmjk9HdCbQoug3Yz2LNx2t/IeEqflb85Gn9UqJZ8FD7Tl17gK9pq0j+/Apo0NyG1MmLiNe4P5cXzKrVnUiWTG6C+0bVOe5T7fxh8ztfrOG74t85dY9ZtahYTRLR3QmIjiQgbM3smLrMaNDcgtT1riM6q7zVH2ptLMqr6AIi1Jc1jTipMOqXNWqBjHvgXZMWrWXmd8eZM/xs0wb3JYa4cFGhyYc5OyJMn4nKMx2V6EdGsaGsXRkOo8uyGLMR1s4cqqQkbc28qmOQ9OenGHGtlk9Yrq+8aQsq0VRNSiQ00UXTfFvNtv/g8+2HuO5T7YSGWJl+uBU2tarZlgs7jw5wxlycoZ9zPaadqdfL13muU+2sTznGH1T43m1VxJBgaZcZLuY0GpFAAAOa0lEQVTK60/OMFt3nV6djhW13V+8rFFQdPGG65f+nCcHm5EniNjSI7kOjWPDePSDLAbM3MDEnq0Y0F7O6BP2MeNr2p2qBFp4u38b6kdX5d1/7iOvoIjpg1J9ouPQ3OnXRPTqdHSkwaTo4mUmfLbTkL0vRt2/qzIt60SwYnQXOjSszril2xm/TOpewj5mfU27k1KKp3/TlLf6JrPph1P0nr6WI/mFRoflMklcdtKr09HRBpOCoouGDDYzty5HhQYx74H2PHZLIxZuPMLA2Rs4cea80WEJkzPza9rd7k2NZ8FDHfjplwv0mraWzUd+Njokl0jispMznY7lnf5RXmeVM9w92MzeumwJUIz7XXOm3JfCrmNnuGfyGp9u/xWuM/tr2t06Noxm6cjOVK0SyMBZG1i57bjRITlNEpedHG3ltXW8DXC17b4yIVYL1WysR5cOtswtebR56SsSxq0kYdxKUiZ+pcsyore0Lt/Tug7LRnUmJMjCgFkb+HDjEaNDEiblLa9pd2oUG8aykZ1JjItk1Iebmf7NAYxo0HOVabsKzciRjiRb99aKCrGS8+JvK/0+i1K81S8ZoNytAa/3TgJg7JKtXLxy7f9Dq0UxqY/rG6e9qQPrdOFFxny8hW+/P8nA9nWZ0KMVVQJdn9naIl2FxbzpNQLeF6+7nL94mWeXbOXzbccZ0K4uL2ckmuJ+eF7fVejtbC3lFRRdJHNL3tXBYmvP2vWbn8sbbOlvrL4haUFxd6Ie57+ZrbOzIpGhVt4b1o63vtrLtG8OsOe/Z5kxOJWaEbLfy13M3qVnK0mZITajBVstvDsghYToqkz5135yfy5i6qC2RIZ4R8ehJC47OTpI60SFlDuTAq5JKvZ
syrQ12Cqqc/lDwfl6lgDFc92akxgXybNLtnLP5DVMH9SWND+80Z4nmPnQXLMnVTMICFA8e2cz6keH8vzS7fSZvo65w9pRt3qo0aFVyvi5oZdwtJW2onXz65NKRkoca8d15dAbd7N2XFe7B1ZFRWV/KTiX566k2mSOSqdqkIWBszfwwYbDXrmOb3Zm7tLzx9Z3Z/VNq8v8B9vz45nz9Jq21isOtZbEZSdHB2lGSlyljRWuGntnM6wBNx7jYrUovyo4l6dpzXCWj+pCeuMY/pC5g3Gfbud8BfdbE44zc5eemZOqGXVuHMPSkcVNTv1nrueL7ebuOHQpcSml+iqldiqlriilTFOodgdnBumL3VtV2sXkyg0zM1LimNQ3magy69LVQq30b1eXSav2XnNNT9+Y0wwiQ63Mub8dY7o2ZlHWUfrP2sDx0/LGpRczd+mZOamaVeMa4SwbmU7LOhGM/HAzM/9t3o5Dl7oKlVItgCvATOBZTdPsamvyxg6o8s4YLK+JwpHndPaajsZpDVCgips2ynsef+i0+nLHcZ5ZvJWQIAvTBqXSvoFrdS/pKixm1teOO8aWvzh/8TLPLNnKym3HGdi+HhN7tvJYx6FHugo1Tdtd8mSuXMYjXC3WOnuydUVdTO4obpd3zfI6D8uu9/tDEbtbYm0axYYxfEE2983ewAvdWzKkY32veO2amVm79OQkeucFWy1MHpBC/eqhTPvmALk/FzJ1UFsigs3TcajLPi6l1DdUMuNSSg0HhgPUq1cv9fDhwy4/ryNs7ZeKiwph7biubnnOyj6NNhi3kvJ++wo49MbdTj2nrWuWR2G7+9GdvxcjnS66yNOLcvjnnhP0TY3n5YxEgp04yURmXMIfLN50lN8v207D2KrMHdaO+Gru7TjUbcallPoaqFXOl8Zrmrbc3oA0TZsFzILiQWbvz+nFXcXazC15vLRiJz8XFp/qHhViZUKPVkDlMxlbScOVdfiK2vDL+15/K2JHhliZPTSNt/+5j3f/uY+9Pxbv95LahxA36teuLnHVQnjsg2wypq5jzv1pJNeNMjqsypszNE27Q9O0xHL+2J20zMAdxdrMLXmM/WTr1aQFxRuMn1yUw5OLciptxy2vuK0oTnLONlCUd01rgMJquXZJrLSI7o9F7ICA4hOzZw5J5eDJc/SYsoaNB/ONDksIU0pvHMPSEZ0JtgbQf9Z6vtzxX6ND8p92eHd0QE1atfeahgd7lJ3JZKTEXXNuoYKry3zO3r6k7DUVxUt+k/omM6lP8jWPlRapzdwZ5m53tqpF5qjORIRYGfS3jby/7gejQxLClJrULO44bF4rghELs/nbfw4a2nHoUnOGUqoXMBmIBVYqpXI0TbtTl8h05o5irTPLadfPZEqL2+XV4Jxt1LBVMK/oMX8tYjeuEU7mqHSeXrSVwguyz0sIW2LDq/Dx8I48vTiHV1bu5tBP53ipRysCDTjj0NWuwmXAMp1icTu9O6AcqSdBxTMZI2tNZu0M85SIYCuzhqRihgZDpVRfYALQAmhv7xYTITwh2GphysC2/Kn6Xmb8+wC5Pxcx5b4Uwj3cceg3S4XuMPbOZjfUjmwpuzxXHn+sNZlJQIAyS2v8DqA38K3RgQhRnoCSe+G93juJNft/ou+M9R5v5pLE5YKMlDgm9Um2ebQTFM+y3u7fptIzCG3Vmm5rHut3J174M03TdmuaJgfqCdMb2L4e8x5oR97PRWRMXcv23NMee245Hd5F1y+zOXuSQHm1ptuax/Jpdp7Pbw4Wzrlub6TB0Qh/dFOTWD4Z0ZkH522i38z1vDswhd+0rOn255UbSZqYEZumhX1c2YBsz95Iezb1lyVjShjpxNnzPPJ+FtvyTvOHu1vyYHqCU0vvciNJH+Bvm4P9haZpdxgdgxB6qhEezMfDO/Hkoi28/PkuDuef44V7Wrqt41BqXCYmDRtCCG8REmRh+qBUht/ckPnrD/PI/Cx++fWSW55LEpeJ+fPmYH+llOq
llMoFOlG8N3KV0TEJYa+AAMXv72rBKxmJfLvvJ0Z8kO2W55GlQhPz983B/sjb9kYKUZ7BHetTt3ooEcHuSTGSuEzO3zcHC99n1nt6Cdfc0jTWbdeWxCWEMIyr98kT/klqXEIIw1R0M1UhbJEZl85k2UMI+8mWD+EMSVw6kmUPIRzjjpup6kk+iJqTLBXqSJY9hHCMmbd8lH4QzSsoQsP5e+QJ/cmMS0ey7CGEYzy55cPR2VNFH0Rl1mUsSVw6MvuyhxBm5IktH84s48sHUfOSpUIdmXnZQwh/5swyvhy5Zl6SuHSUkRLH672TiIsKQVH5zSOFEJ7hzOzJ3R9EM7fkyb32nCRLhTpz17KHdDcJ4TxnlvHdWX+TDmTXSOLyAvIiF8I1Y+9sds0YAvtmT+76ICqNH66RpUIvIG32QrjGbMv40vjhGpdmXEqpSUB34AJwAHhA07QCPQIT/yMvciFcZ6YDq6UD2TWuzrj+ASRqmtYa+B543vWQzM/TRVXpbhLCXFx9D5AOZNe4lLg0TftK07TSW1xuAOJdD8ncjNhNLy9yIcxDj/cAsy1dehs9mzMeBBbZ+qJSajgwHKBevXo6Pq1nGVFUlRtKCmEeer0HmGnp0ttUmriUUl8Dtcr50nhN05aXfM944BKw0NZ1NE2bBcwCSEtL05yK1gSMqjfJi1wIc5Cas/EqTVyapt1R0deVUsOAe4DbNU3z2oRkLymqCuEab9+TKO8BxnOpxqWU6gY8B/TQNK1Qn5DMK3NLHud+vXTD41JvEsI+vnDiuidrznK6Rvlc7SqcAoQD/1BK5SilZugQkymVDriCoovXPF4t1CpFVSHs5At7Ej3VWOELSd5dXGrO0DStsV6BmF15Aw4gNChQkpYQdvKV+pAnas5yuoZtcnKGnXxlwAlhJNmTaD95z7FNEpedZMAJT1BKTVJK7VFKbVNKLVNKRRkdk55kT6L95D3HNklcdpIBJzzEp0+jkY239pP3HNvkdHg7ySZg4Qmapn1V5q8bgD5GxeIusifRPvKeY5skLgfIgBMe5hen0Qjb5D2nfJK4hPAwOY1GCNdI4hLCw+Q0GiFcI4lLCBMpcxrNLf5wGo0QzpCuQiHMxW9OoxHCWTLjEsJE/Ok0GiGcpYxYQldKnQQOu+nyMcBPbrq2MyQe28wUCzgWT31N02LdGYwjZEwZykzxmCkWcDweu8aVIYnLnZRSWZqmpRkdRymJxzYzxQLmi8cszPZ7kXhsM1Ms4L54pMYlhBDCq0jiEkII4VV8MXHNMjqA60g8tpkpFjBfPGZhtt+LxGObmWIBN8XjczUuIYQQvs0XZ1xCCCF8mCQuIYQQXsUnE5fZbsanlOqrlNqplLqilDKkVVUp1U0ptVcptV8pNc6IGMrEMlcpdUIptcPIOEpiqauU+pdSalfJ/6MnjI7JjGRMlRuDjKnyY3H7mPLJxIX5bsa3A+gNfGvEkyulLMBU4HdAS2CgUqqlEbGUmAd0M/D5y7oEPKNpWkugIzDK4N+NWcmYKkPGVIXcPqZ8MnFpmvaVpmmXSv66AYg3OJ7dmqbtNTCE9sB+TdMOapp2AfgY6GlUMJqmfQucMur5y9I07bimaZtL/vsssBuQGyBdR8bUDWRM2eCJMeWTies6DwJfGB2EweKAo2X+nou8Od9AKZUApAAbjY3E9GRMyZiyi7vGlNcesqvXzfg8GY8wL6VUGPAp8KSmaWeMjscIMqaEntw5prw2cZntZnyVxWOwPKBumb/HlzwmAKWUleIBtlDTtKVGx2MUGVMOkTFVAXePKZ9cKixzM74ecjM+ADYBTZRSDZRSQcAA4DODYzIFpZQC5gC7NU37i9HxmJWMqRvImLLBE2PKJxMXJrsZn1Kql1IqF+gErFRKrfLk85cU1UcDqygulC7WNG2nJ2MoSyn1EbAeaKaUylVKPWRULEA6MAToWvJayVFK3WVgPGYlY6oMGVMVcvuYkiOfhBBCeBVfnXEJIYTwUZK4hBB
CeBVJXEIIIbyKJC4hhBBeRRKXEEIIryKJSwghhFeRxCWEEMKr/D/VKrdRWCwXbQAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "##########################\n", "### 2D Decision Boundary\n", "##########################\n", "\n", "# Copy the parameters to host memory as NumPy arrays: Matplotlib cannot\n", "# convert tensors that live on a CUDA device\n", "w = ppn.weights.cpu().numpy().ravel()\n", "b = ppn.bias.cpu().numpy()\n", "\n", "# The boundary is where the net input is zero: w[0]*x + w[1]*y + b[0] = 0,\n", "# i.e. y = (-w[0]*x - b[0]) / w[1]; evaluate it at both ends of the x-range\n", "x_min = -2\n", "y_min = ( (-(w[0] * x_min) - b[0]) \n", " / w[1] )\n", "\n", "x_max = 2\n", "y_max = ( (-(w[0] * x_max) - b[0]) \n", " / w[1] )\n", "\n", "\n", "# Left panel: training set, right panel: test set, same boundary in both\n", "fig, ax = plt.subplots(1, 2, sharex=True, figsize=(7, 3))\n", "\n", "ax[0].plot([x_min, x_max], [y_min, y_max])\n", "ax[1].plot([x_min, x_max], [y_min, y_max])\n", "\n", "ax[0].scatter(X_train[y_train==0, 0], X_train[y_train==0, 1], label='class 0', marker='o')\n", "ax[0].scatter(X_train[y_train==1, 0], X_train[y_train==1, 1], label='class 1', marker='s')\n", "\n", "ax[1].scatter(X_test[y_test==0, 0], X_test[y_test==0, 1], label='class 0', marker='o')\n", "ax[1].scatter(X_test[y_test==1, 0], X_test[y_test==1, 1], label='class 1', marker='s')\n", "\n", "ax[1].legend(loc='upper left')\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "numpy 1.15.4\n", "torch 1.0.0\n", "\n" ] } ], "source": [ "%watermark -iv" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.1" }, "toc": { "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 2 }