{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.\n", "- Author: Sebastian Raschka\n", "- GitHub Repository: https://github.com/rasbt/deeplearning-models" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sebastian Raschka \n", "\n", "CPython 3.7.1\n", "IPython 7.2.0\n", "\n", "torch 1.0.0\n" ] } ], "source": [ "%load_ext watermark\n", "%watermark -a 'Sebastian Raschka' -v -p torch" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- Runs on CPU or GPU (if available)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Model Zoo -- Logistic Regression" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Implementation of *classic* logistic regression for binary class labels." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Imports" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "%matplotlib inline\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "from io import BytesIO\n", "\n", "import torch\n", "import torch.nn.functional as F" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Preparing a toy dataset" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAa4AAACqCAYAAAD1E6s4AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAFftJREFUeJzt3XGMFNd9B/Dv7zbn6KRERwinWD7ugDruOchQIU6GiD9QwBXEhfiCGxTcWqZFQpEcpQkRCsgWspErUyGRBtX9AxWLVnaIqLAvDrgiNlSxauWo70wCOITIjmu4c6TgWpBIPYnj7tc/5pbbvZ3Zndl5s++9me9HQucd9mZ/rOftb/bNb35PVBVERES+aLMdABERURJMXERE5BUmLiIi8goTFxEReYWJi4iIvMLERUREXmHiIiIirzBxERGRV5i4iIjIK5+w8aLz5s3ThQsX2nhpIiNGRkY+UtUu23GUcUxRHsQdV1YS18KFCzE8PGzjpYmMEJEPbMdQiWOK8iDuuOJUIREReYWJi4iIvMLERUREXmHiyrPzx4Dv3wc8NSf4ef6Y7YiIssVjvhCsFGdQC5w/BvzkW8DEePD4xtXgMQAs3WwvLqKs8JgvDH7jyqvTe2cGcNnEeLCdKI94zBcGE1de3RhNtp3IdzzmC4OJK6865yfbTuQ7HvOFwcSVV2v3AO0d1dvaO4LtRHnEY74wmLjyaulmYONBoLMHgAQ/Nx7kRWpyh+kKQB7zhcGqwjxbupmDltyUVQUgj/lC4DcuImo9VgBSCkxcRNR6rACkFJi4iKj1WAFIKTBxEVHrsQKQUmDiIqLWYwUgpcCqQiLHiEgPgH8DcCeAKQCHVPUHdqPKACsAqUlMXETuuQXgu6r6toh8GsCIiLymqr+yHRiRCzhVSOQYVf2dqr49/d9/BHAJQLfdqIjcwcRF5DARWQhgGYCzIX+3XUSGRWT42rVrrQ6NyBomLiJHicinABwH8G1V/cPsv1fVQ6rar6r9XV1drQ+QyBImLiIHiUg7gqT1oqq+ZDseIpcwcRE5RkQEwGEAl1T1gO14iFzDxFVEprtyk2mrADwKYI2I/GL6z4O2gyJyRepy+MLcc5IXWXXlJmNU9b8AiO04iFxl4htX+Z6TLwBYCeBxEVlsYL+UBXblJiLPpU5cvOfEM+zKTUSeM9o5o9E9JwC2A0Bvb6/Jl6UkOucH04Nh24lMOrEDGDkC6CQgJWD5VmADa00oPWPFGbznxLATO4Cn5wJPdQY/T+wws1925aZWOLEDGD4cJC0g+Dl82NxxTIVmJHHxnhPDshz07MpNrTByJNl2ogRMVBXynhPT6g16E1Mt7MpNWSufdMXdTpSAiW9cvOfENA568p2Ukm0nSiD1Ny7ec5IBKYUnKQ568sXyrcH0dth2opTYOcM0E10pogZ30kHPDhlky4YDQP+2mZMtKQWPTUx187guPC4kaZKprhTlwZ2mlJgdMsi2DQfMl7/zuCYwcZlVrytF0kGVdtCbjIXIFTyuCZwqNMulrhQuxUJkCo9rAhOXWVHdJ2x0pXApFiJTeFwTmLjMWrsHaGuv3tbWXr8rRVYXmtkhg/KIxzWB17jME6n/uFKWF5rLv396bzCN0jk/GNy8DkA+43FNYOIy6/ReYPJm9bbJm9EXjrO+0MwOGZRHPK4Lj1OFJiW9cMwLzUREiTFxmZT0wjEvNBMRJZbvxJXlHfZh+0564ZgXmimCiDwvIr8XkYu2YyFyTX4TV7nw4cZVADpT+GAieUXtG0i2ZAiXGKFoRwCstx0EkYvyW5yRZeFDvX1/52Ky/fNCM4VQ1TemVxT31/ljra/+s/Ga1HL5TVxZFj6wqIIcICLbAWwHgN7eXsvRzGKjpyD7GBZGfqcKsyx8YFEFOUBVD6lqv6r2d3V12Q6nWr1ZiTy9JlmR38SVZeHD2j1A26y1sdpKwfaoghAuxUBFYmNWgjMhhZHfqcIs77C/MgR
MzVrocWoSOPcCMPrftVMVV4aAX/6QUxhUHJ3zp4uXQrbn6TXJivx+4wKCpPCdi8BT15MXTdQzciR8+/s/C5+qGDnCKQxKRESOAvg5gD4RGRWRbbZjSsTGrR68vaQw8vuNK0s62fg5cZ7PKQyKoKpbbMeQytLNwUxD5WKof/aImZPHqMpB9jEsDCauZkgpWfKKej6nMCivzh8LpsfLx71OBo97V6ZLJI0qB3l7SSHke6owaUHEiR3A03OBpzqDnyd2hD9v+dbw7YtWh09VLN/qznInTRg8N4ZV+85g0a6TWLXvDAbPjVmLhTyRVYUfKwcJeU5cSTtnnNgBDB+uPkMcPhydvMJ89vPhnTB6Vza33EkWXT8SGjw3ht0vXcDY9XEogLHr49j90gUmL6ovqwo/Vg4S8py4kp6ZRRVchG2v99ywgpB6y52YiD1D+09dxvhE9TTn+MQk9p+63PJYyCNZ3evIeygJeU5cSc/Moq5ZhW1P8txmYnHorPLD6+OJthMByK7Cj5WDhDwnrqRnZlKKvz3Jc5uJxaGzyrvmdCTaTgQguwbSbEwNgNed85u4kp6ZRRVchG2v99ycLXeyc10fOtqrE3JHewk71/W1PBbyTFb3UWa1X0/wunOeE1fSM7MNB4D+bTPfmqQUPN5woPa5vStrv12VH+dsuZOBZd14dtMSdM/pgADontOBZzctwcCy7pbHQkS87gwAoqotf9H+/n4dHh5u+esa8/37wlvLRN6v1ROcGVJuiMiIqvbbjqPM+zFFsS3adRJhn9oC4P19f9HqcIyKO67y+40rS0kLPFiqS0SG8LozE1dzkhZ4sFSXiAzhdWffWj4lXd006vlpV0lduwf48ePV92aV7gCWPVrdBR7wqlR38NwY9p+6jA+vj+OuOR3Yua6vcNeycv0ecHXgXCgfj1HHaa6P4Wn+JK6kq5tGPd/UEiOzrw2qBkUbvSu9/HAoVyqVL/qWK5UA5O6gj5Lr94CrA+fKwLLu0GMy18dwBX+mCpN2k4h6voklRk7vBaYmqrdNTQTbPS3VZaVSzt8Dh7qxUHZyfQxX8CdxmeoyYaKAwqHOFqawQ0bO34McHrNUK9fHcAV/EpepLhMmCigc6mxhCiuVcv4e5PCYpVq5PoYr+JO4THWfWL41fVcKhzpbmMJKpZy/Bzk8ZqlWro/hCkaKM0RkPYAfACgB+BdV3Wdiv1WSrm4atQLrhgPRBRRRVVf/+hXg/Z/N7HvR6qCThYdFGFEaVSq5zkQlle/vQV1cHbgQsjqG04yvLKocU3fOEJESgN8A+HMAowDeArBFVX8V9Tstuct/dhUVEJxhRrVOinp+5wLgo1/XPn/RauCxV8zHTYnNrqQCgrPMLFtTZd05I+nJIDtnUFbSjK+kv9vKzhn3A3hXVX+rqjcB/AjAQwb2m46pKsSwpAVUfwMjq/JWSTV9MvgcgC8DWAxgi4gsthsVFVWa8ZXV2DSRuLoBVDbuG53eVkVEtovIsIgMX7t2zcDLNuDBWldkRg4rqdw8GaRCSjO+shqbJhJX2Br0NfOPqnpIVftVtb+rq8vAyzbgwVpXZEYOK6ncPBmkQkozvrIamyaKM0YB9FQ8ng/gw1R7NNGqae2e8GtW9aoQX/5G9X1eUgI+e0/4dOG8e6e7xLtxofvJwQs4evYqJlVREsGWFT3oXzA30UXRpBdRbbSWCXvNnev6sPPff4mJqZnzpfY28bmSKvbJIIBDQHCNK+ugqJh2rusLvU4VZ3yl+d16TCSutwDcIyKLAIwB+DqAR5rem6lWTUmrqK4M1d6crJPApz8HfHQZNZ8bH7830z3DcvucJwcv4IWhK7cfT6rihaEr+OHQFUxNb2vU+iVpqxgbrWWiXvPh5d21H/VhH/3+MH8ySNSkNJWKWVU5GlmPS0QeBPCPCCqgnlfVv6/3/LoVULbWunp6bnRXjbgsrbt19+5XMRnz/2P3nA68uWtNzfZV+85gLGTe2dTzTYh6zZJI6L8
/y1iyrCoUkU8gqNRdi+Bk8C0Aj6jqO1G/U6SqwiI0kS2quOPKyH1cqvoqgFdN7MvaWldpkxZgrcAjbtICkl8sNbXdhKh9R/37fS3OUNVbIvJNAKcwczIYmbSKpChNZKk+9zpn2FrrKmr/SVgq8ChJ/HmxpBdLTW03IWrfUf9+j4szoKqvquqfqurdjWYwiiRvtz5Qc9xLXPVaNbW1V29vazfXsmb51vDti1bXxlO6ozYWi+1ztqzoafwk1C9YSNoqZue6PrS3VSeMZgsiBs+NYdW+M1i06yRW7TuDwXNj0a9ZmvWapaAQpQhtbiiXtz5QE9xLXEs3B90tOnsASPBz48GgTdPsM+sE3zQa2nAA6N82881LSsHjx16pjeeh54CBf66N0VJVYf+CuSjNSiJtEvypUuftGljWjWc3LUH3nA4IgutDDe+MN1AQUZ76Gbs+DsXM1E9U8qqprdPg3584dvJSDm99oCYYKc5IqqkLyVFFG5YKIlwSVbQQxlTBgqnijCT7sVEQEiXrlk9JFaU4w0Z7L2qdlhZntAQ7XkRKMk1iakrF1JRNkv1wmoh8a4TcbAUkKyfr8ydxdc6P+MbFjhd3zemI/Y3L1JRK1Gsm3X+S/Zh6TfJb1LL1rmm2ApKVk425d40rCtcTihRWWNFekppLTiWD3SRMFWdEFYV86d6umoKNZtYailv4QWRasxWQrJxszJ/EFVW0wfWEQgsr7l/4mZo6hskpxfAHH5t7YQPFGWGxP7y8G8dHxmoKNgAkKsJIXPhBZFCzU9ucEm/Mn6lCIEhSTFShZk+f3L07/H7wo2ev4pmBJalfb/+py5iYrE6NE5OK/acuN7WAY+XvrNp3JvKM881da2Lvv96ZK6dcKGvNTm1zSrwxf75xUSJR3SSSdNmoJ8uzQhuFH0SmNTO1neb3isSvb1wUW1T/viRdNurJ8qzQRuEHkWmNKiCjKgd9q5y0gYmrRbIub529/5V/8hm8+V7t9awtK3qMLF+S1XIFgLmlELKMkSiOqArIRpWDvlRO2sKpwhbIukggbP9vX7mBVXfPvf0NqySCv17Zi/4FcxPFEhU7kKxQIommunhkuB8i01g5mA6/cbVA1kUCUfv/n/8dx3vPPli1vV7hQ1gs9WJPUiiRlKkzTp65kot4/TUdfuNqgawP0iy7T3CAEZnHnovpMHG1QNYHaZL9+7B8CVHesXIwHSauFjB5kIZ1gsiy+wQHGJH5Diy8/pqOP93hPWeiqrBeZ2ygunz2S/d24fjIWKznNlNVWPQBxu7wxcGO9K0Td1wxcXnE1yVA8oiJqzg4llon7rjiVKFHuARI/onI10TkHRGZEhFnEmORcSy5h4nLI1kWYZAzLgLYBOAN24FQgGPJPUxchmW5jMbOdX1oL81aSqQUvpSIi0UVXGKkMVW9pKq8C9UhLo6lomPiMqgly2jMviQZcYnStaolLjFinohsF5FhERm+du2a7XBya2BZNx5e3l3Vhebh5byx3SZ2zjCoFR0yJqZmLSUyFb2UiEtdI7jEyAwReR3AnSF/9YSq/jjuflT1EIBDQFCcYSg8mmXw3BiOj4zdblo9qYrjI2PoXzC3cMeuK5i4DHKpQ4ZrfI7dNFV9wHYMFB9PutzDqUKDXOqQ4RqfY6di40mXe5i4YohbVJD1RVyfLxL7HHsrichXRWQUwBcBnBSRU7ZjKjqedLmHiauBJEUFWRdEuFZwkYTPsbeSqr6sqvNV9ZOq+jlVXWc7pqLjSZd72DmjAd41T2HYOcM/aVqXse1Za8QdVyzOaIDz20T+a7TicCMuVegSpwob4vw2kf+44nC+5OMb1/ljwOm9wI1RoHM+sHYPsHSzkV3vXNcX2hna1vy2z1MWPsdOfuPMSb74n7jOHwN+8i1gYvoAvHE1eAwYSV7lD1YXPnDTTnfY5HPs5L+75nSEXqvmzImf/E9cp/fOJK2yifFgu6FvXa7Mb/t8I6TPsZP/XJs5oXT8T1w3RpNt95j
P0x0+x07+eHLwAo6evYpJVZREsGVFD54ZWOLUzAml53/i6pwfTA+Gbc8Zn6c7fI6d/PDk4AW8MHTl9uNJ1duPy8mLiSofUlUVish+Efm1iJwXkZdFZI6pwGJbuwdon/Xh194RbM8ZX26EDOs04kvs5K+jZ0NOYOtsJ3+lLYd/DcB9qroUwG8A7E4fUkJLNwMbDwKdPQAk+LnxoLHrWy7xoftEVKcRAM7HTn6bjGimELWd/JVqqlBVf1rxcAjAX6YLp0lLN+cyUYVxfbqjXhHGm7vWOB07+a0kEpqkyutoUX6YvAH5bwH8R9RfctG7YmARBtmyZUVPou3kr4bfuOIseiciTwC4BeDFqP1w0btiYBFGsbh0U/kzA0sAILSqkPKlYeJqtOidiDwGYAOAtWqjYy85hffLFIeLN5U/M7CEiaoA0lYVrgfwPQBfUdX/MxMS+cyHAhIyg/3/yJa093H9E4BPAnhNggugQ6r6jdRRkddcLyAhM3g9k2xJW1X4eVOBEJFfeD2TbOGyJkTUFN5UTrb43/LJEpeqqSg/RGQ/gI0AbgJ4D8DfqOp1u1GFY/+/xvg5kQ0mria4WE1FufEagN2qektE/gFBN5rvWY4pEq9nRuPnRHY4VdgEVlNRVlT1p6p6a/rhEID8dYsuCH5OZIeJqwmspqIWYTcaj/FzIjtMXE2IqppiNRXFISKvi8jFkD8PVTwnVjcaVe1X1f6urq5WhE4J8HMiO0xcTWA1FaWhqg+o6n0hf8ot1MrdaP6K3Wj8xc+J7LA4owmspqKsVHSjWc1uNH7j50R2mLiaxGoqygi70eQIPyeywcRF5BB2oyFqjNe4iIjIK2Lj2q+IXAPwQctfuL55AD6yHUQEV2MrclwLVNWZUr6UY8rV/49xMHY7soo91riykrhcJCLDqtpvO44wrsbGuPLB5/eLsdthO3ZOFRIRkVeYuIiIyCtMXDMO2Q6gDldjY1z54PP7xdjtsBo7r3EREZFX+I2LiIi8wsRFREReYeKqICJfE5F3RGRKRKyXqYrIehG5LCLvisgu2/GUicjzIvJ7EbloO5ZKItIjIv8pIpem/z/+ne2YfODacR+Hq2MjDlfHTyMujS8mrmoXAWwC8IbtQESkBOA5AF8GsBjAFhFZbDeq244AWG87iBC3AHxXVb8AYCWAxx16z1zmzHEfh+NjI44jcHP8NOLM+GLiqqCql1TVleVJ7wfwrqr+VlVvAvgRgIca/E5LqOobAD62Hcdsqvo7VX17+r//COASAHY4bcCx4z4OZ8dGHK6On0ZcGl9MXO7qBnC14vEo+CEcm4gsBLAMwFm7kVAGODYssz2+CtcdXkReB3BnyF89UV7IzxESso33LsQgIp8CcBzAt1X1D7bjcYFHx30cHBsWuTC+Cpe4VPUB2zHENAqgp+LxfAAfWorFGyLSjmBQvaiqL9mOxxUeHfdxcGxY4sr44lShu94CcI+ILBKROwB8HcArlmNymgQrLx4GcElVD9iOhzLDsWGBS+OLiauCiHxVREYBfBHASRE5ZSsWVb0F4JsATiG4CHpMVd+xFU8lETkK4OcA+kRkVES22Y5p2ioAjwJYIyK/mP7zoO2gXOfScR+Hy2MjDofHTyPOjC+2fCIiIq/wGxcREXmFiYuIiLzCxEVERF5h4iIiIq8wcRERkVeYuIiIyCtMXERE5JX/B/s1FOrL5PWOAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "##########################\n", "### DATASET\n", "##########################\n", "\n", "ds = np.lib.DataSource()\n", "fp = ds.open('https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data')\n", "\n", "x = np.genfromtxt(BytesIO(fp.read().encode()), delimiter=',', usecols=range(2), max_rows=100)\n", "y = np.zeros(100)\n", "y[50:] = 1\n", "\n", "np.random.seed(1)\n", "idx = np.arange(y.shape[0])\n", "np.random.shuffle(idx)\n", "X_test, y_test = x[idx[:25]], y[idx[:25]]\n", "X_train, y_train = x[idx[25:]], y[idx[25:]]\n", "mu, std = np.mean(X_train, axis=0), np.std(X_train, axis=0)\n", "X_train, X_test = (X_train - mu) / std, (X_test - mu) / std\n", "\n", "fig, ax = plt.subplots(1, 2, figsize=(7, 2.5))\n", "ax[0].scatter(X_train[y_train == 1, 0], X_train[y_train == 1, 1])\n", "ax[0].scatter(X_train[y_train == 0, 0], X_train[y_train == 0, 1])\n", "ax[1].scatter(X_test[y_test == 1, 0], X_test[y_test == 1, 1])\n", "ax[1].scatter(X_test[y_test == 0, 0], X_test[y_test == 0, 1])\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Low-level implementation with manual gradients" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", "\n", "\n", "def custom_where(cond, x_1, x_2):\n", " return (cond * x_1) + ((1-cond) * x_2)\n", "\n", "\n", "class LogisticRegression1():\n", " def __init__(self, num_features):\n", " self.num_features = num_features\n", " self.weights = torch.zeros(num_features, 1, \n", " dtype=torch.float32, device=device)\n", " self.bias = torch.zeros(1, dtype=torch.float32, device=device)\n", "\n", " def forward(self, x):\n", " linear = torch.add(torch.mm(x, self.weights), self.bias)\n", " probas = self._sigmoid(linear)\n", " return probas\n", " \n", " def backward(self, probas, y): \n", " errors = y - 
probas.view(-1)\n", " return errors\n", " \n", " def predict_labels(self, x):\n", " probas = self.forward(x)\n", " labels = custom_where(probas >= .5, 1, 0)\n", " return labels \n", " \n", " def evaluate(self, x, y):\n", " labels = self.predict_labels(x).float()\n", " accuracy = torch.sum(labels.view(-1) == y) / y.size()[0]\n", " return accuracy\n", " \n", " def _sigmoid(self, z):\n", " return 1. / (1. + torch.exp(-z))\n", " \n", " def _logit_cost(self, y, proba):\n", " tmp1 = torch.mm(-y.view(1, -1), torch.log(proba))\n", " tmp2 = torch.mm((1 - y).view(1, -1), torch.log(1 - proba))\n", " return tmp1 - tmp2\n", " \n", " def train(self, x, y, num_epochs, learning_rate=0.01):\n", " for e in range(num_epochs):\n", " \n", " #### Compute outputs ####\n", " probas = self.forward(x)\n", " \n", " #### Compute gradients ####\n", " errors = self.backward(probas, y)\n", " neg_grad = torch.mm(x.transpose(0, 1), errors.view(-1, 1))\n", " \n", " #### Update weights ####\n", " self.weights += learning_rate * neg_grad\n", " self.bias += learning_rate * torch.sum(errors)\n", " \n", " #### Logging ####\n", " print('Epoch: %03d' % (e+1), end=\"\")\n", " print(' | Train ACC: %.3f' % self.evaluate(x, y), end=\"\")\n", " print(' | Cost: %.3f' % self._logit_cost(y, self.forward(x)))" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 001 | Train ACC: 0.000 | Cost: 5.581\n", "Epoch: 002 | Train ACC: 0.000 | Cost: 4.882\n", "Epoch: 003 | Train ACC: 1.000 | Cost: 4.381\n", "Epoch: 004 | Train ACC: 1.000 | Cost: 3.998\n", "Epoch: 005 | Train ACC: 1.000 | Cost: 3.693\n", "Epoch: 006 | Train ACC: 1.000 | Cost: 3.443\n", "Epoch: 007 | Train ACC: 1.000 | Cost: 3.232\n", "Epoch: 008 | Train ACC: 1.000 | Cost: 3.052\n", "Epoch: 009 | Train ACC: 1.000 | Cost: 2.896\n", "Epoch: 010 | Train ACC: 1.000 | Cost: 2.758\n", "\n", "Model parameters:\n", " Weights: tensor([[ 4.2267],\n", " [-2.9613]], 
device='cuda:0')\n", " Bias: tensor([0.0994], device='cuda:0')\n" ] } ], "source": [ "X_train_tensor = torch.tensor(X_train, dtype=torch.float32, device=device)\n", "y_train_tensor = torch.tensor(y_train, dtype=torch.float32, device=device)\n", "\n", "logr = LogisticRegression1(num_features=2)\n", "logr.train(X_train_tensor, y_train_tensor, num_epochs=10, learning_rate=0.1)\n", "\n", "print('\\nModel parameters:')\n", "print(' Weights: %s' % logr.weights)\n", "print(' Bias: %s' % logr.bias)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Evaluating the Model" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Test set accuracy: 100.00%\n" ] } ], "source": [ "X_test_tensor = torch.tensor(X_test, dtype=torch.float32, device=device)\n", "y_test_tensor = torch.tensor(y_test, dtype=torch.float32, device=device)\n", "\n", "test_acc = logr.evaluate(X_test_tensor, y_test_tensor)\n", "print('Test set accuracy: %.2f%%' % (test_acc*100))" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAa4AAADFCAYAAAAMsRa3AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAIABJREFUeJzt3Xl8VNX9//HXISYk7AJhSwiLIotAWAIJpWoVLVgXKCiWVUBC69e61P5wqVT9ClrQ1hWFsgsEkEXBokJZxIWvBBLWQMAYEEhAloSIbCYk5/fHJDHLJJnJ3Jl778zn+Xj4eJibyZ0PkDPvued+5hyltUYIIYSwixpmFyCEEEK4Q4JLCCGErUhwCSGEsBUJLiGEELYiwSWEEMJWJLiEEELYigSXEEIIW5HgEkIIYSsSXEIIIWzlGjOetHHjxrp169ZmPLUQhkhOTj6rtQ43u44iMqaEP3B1XJkSXK1btyYpKcmMpxbCEEqpo148dyjwJVATxxhdqbV+obKfkTEl/IGr48qU4BJCVOpn4Dat9QWlVDDwtVLqM631NrMLE8IKJLiEsBjtWPn6QuGXwYX/yWrYQhSS5gwhLEgpFaSU2g2cBjZorROdPGaCUipJKZV05swZ3xcphEnkiksIC9Ja5wPdlFINgI+UUp211illHjMLmAUQExNT7oosLy+PjIwMrly54pOa7SI0NJTIyEiCg4PNLkVUk8fBVZ0byUII12itc5RSW4ABQEoVDy8lIyODunXr0rp1a5RSXqnPbrTWZGVlkZGRQZs2bcwux/JyLuXy9Kq9TLqrEy0b1jK7nGJGTBUW3UiOBroBA5RScQacVxhs9a5M+k7dTJtnPqHv1M2s3pVpdknCCaVUeOGVFkqpMOB24KC757ly5QqNGjWS0CpBKUWjRo3kKtQF56/k8eC87Xx+8AxHsy6ZXU4pHl9xyY1ke1i9K5NnP9zH5bx8ADJzLvPsh/sAGNQ9wszSRHnNgfeVUkE43lwu11qvrc6JJLTKk7+Tql38+Spj5+9g/4nzzBjZk1+3a2x2SaUYco+rcIAlA9cD71Z0IxmYABAVFWXE0wo3vLb+UHFoFbmcl89r6w9JcFmM1nov0N3sOkRgupybz0Pv72DXsXNMH96DOzo1NbukcgzpKtRa52utuwGRQG+lVGcnj5mltY7RWseEh1tmwYGAcSLnslvH7UhrzYYDp3BMAgirevHFF/nnP//plXMnJyfTpUsXrr/+eh577DH5XXDTlbx8JixKIvFINq8P7cbvujQ3uySnDG2H11rnAFtw3EgWFtKiQZhbx+1Ga83Ln6QSvzCJ9ftPmV2OLfnDPdCHH36YWbNmkZaWRlpaGuvWrTO7JNvIvVrAn5fs5Ku0s0wb3NXSMzEeB5dRN5KFd03s356w4KBSx8KCg5jYv71JFRlHa80rn6Yy5+sjjPlVa/rfaL2pDasrugeamXMZzS/3QD0Nr4ULF9K1a1eio6MZNWpUue/Pnj2bXr16ER0dzZAhQ7h0ydEEsGLFCjp37kx0dDQ333wzAPv376d3795069aNrl27kpaWVupcJ0+e5Pz58/Tp0welFKNHj2b16tUe1R8oruYX8PiyXWxMPc3kgTcytFdLs0uqlBH3uAy7kSy8p+jd02vrD3Ei5zItGoQxsX97S7+rckVRaM3+6ggP9mnFC/d0kpvv1eCNe6D79+/n5ZdfZuvWrTRu3Jjs7Oxyjxk8eDDx8fEATJo0iblz5/Loo4/y0ksvsX79eiIiIsjJyQFg5syZPP7444wYMYLc3Fzy80vXm5mZSWRkZPHXkZGRZGba76rR1/ILNH9dsYfPUn5g0l0dGdWntdklVcmIrkK5kWwTg7pH2D6oStJa84/PDhaH1ov33iihVU3euAe6efNm7rvvPho3dnSkNWzYsNxjUlJSmDRpEjk5OVy4cIH+/fsD0LdvX8aMGcPQoUMZPHgwAH369OHll18mIyODwYM
H065du1LncnY/qzq/D6t3ZfrdG7yKFBRonv1wL2t2n2Bi//aMv6mt2SW5RJZ8EraktWbqZweZ9eVhRktoecwb90C11lX+m4wZM4bp06ezb98+XnjhheLPV82cOZMpU6Zw/PhxunXrRlZWFsOHD+fjjz8mLCyM/v37s3nz5lLnioyMJCMjo/jrjIwMWrRo4VbN3poytSKtNc9/nMLypAweu+16Hrn1erNLcpkEl7CdotD695eHGRXXiv+V0PKYN+6B9uvXj+XLl5OVlQXgdKrwp59+onnz5uTl5ZGQkFB8PD09ndjYWF566SUaN27M8ePHOXz4MG3btuWxxx7j3nvvZe/evaXO1bx5c+rWrcu2bdvQWrNw4UIGDhzoVs2VTZn6E601Uz5JZfG2Y/zx5rb85Y4bzC7JLbJWobAVrTVT1/0SWi8NlNAygjfugd54440899xz3HLLLQQFBdG9e3cWLFhQ6jGTJ08mNjaWVq1a0aVLF3766ScAJk6cSFpaGlpr+vXrR3R0NFOnTmXx4sUEBwfTrFkznn/++XLPOWPGDMaMGcPly5e58847ufPOO92qOVA+NvLa+kPMLWxmeubODrYbQ8qMzznExMRo2fROuEtrzbR1h5j5RToj46KYPLCzaQNOKZWstY4x5cmdcDamUlNT6dixo0kVWVtFfzd9p24m00lIRTQIY+szt/miNK97e1Mar2/4lmG9o3jl9+aNIWdcHVcyVShsQWvNq+sdoTUiNoqX7rXWgBP+wZ8/NgIw84t0Xt/wLYN7RPDyIPuOIZkqFJZXFFoztjhCa/LAztSoYc8BJ7zr3KVcTv14hdz8AkKCatC0fijX1gpx+ef99WMjAPO3HmHqZwe5u2tzXh3S1dZjSIJLWFrRfPyMLekMl9ASlTh3KZfMc5cpKLz9kZtfQOY5x7Sfu+HlD0FV0pLEY/zvfw7w205NeeOBblwTZO/JNntXL/ya1pp//vcQ721JZ1jvKKZIaIlKnPrxSnFoFSnQmlM/BvYWJiuTM3hu9T5ubR/OO8O7E2zz0AIJLmFRRaH17ueO0Hp5kISWqFxufoFbxwPBx3tO8NTKPfzqukbMGNmTmtcEVf1DNiDBJSxHa82//vttYWi1lNASLgmp4EqiouP+bl3KD/zlg93EtGrI7NExhAb7R2iBBJewGK01r2/4lumff8cferXk5UFdJLT8jLe2NWlaP5Tpr07ht71vJK69Y83CGkrRtH6o4c9ldZ8fPM2jS3fSNbI+88b2olaIf7Uz+NefRtia1po3NnzLO5sdofXK7yW0fOaVCMi9UP54SB34mz2WO7q2Vgj3DR7IyHETGNC3R7W6Cv3B12ln+ePiZNo3q8uCsb2pU9P/XubliktYQlFovb35Ox6ICezQUkq1VEp9rpRKVUrtV0o97vUndRZalR13kS+3NQH47W9u4uZuN1BDQYfm9QIutBIPZzF+4Q7aNq7NonGx1A8LNrskr/C/KBa29MbGtOLQ+sfgwA2tQleBv2qtdyql6gLJSqkNWusDZhfmDl9vaxLoko+eY9yCHUQ0CGPRQ7FcW9t/Q1uuuITp3tjwLW9vSmNoTKSEFqC1Pqm13ln4/z8BqYDtPljk6rYmN910E126dCEhIYH9+/cDv2xrMnv27OKA6tOnD6+88grTpk3j6NGjhIX5x+7dRtiX8SNj5m+ncd2aLImPI7xuTbNL8ioJrgBn9nbtb2z4lrc2pXF/z0imDrb3p/m9QSnVGsd+d4lOvjdBKZWklEo6c+aMr0urkq+3NQlUqSfPM2peIvVCg1kSH0fTev7fjOJxcJkyHy8MYfbeQ29u/CW0ptl8CRpvUErVAVYBT2itz5f9vtZ6ltY6RmsdEx4e7vsCq+DrbU0C0Xenf2LknERCrwliaXwcER7sn2YnRlxxFc3HdwTigEeUUp0MOK/wMjP3HnprYxpvbkzjPgktp5RSwThCK0Fr/aHXnzCkjnvHXVByW5Po6GiefPLJco8p2tb
kjjvuoEOHDsXHJ06cSJcuXejcuTM333wz0dHRfPDBB3Tu3Jlu3bpx8OBBRo8eXe58Tz31FJGRkVy6dInIyEhefPHFatdvdUfOXmT47ESUUiyJjyWqUS2zS/IZw7c1UUqtAaZrrTdU9BjZ1sQa2jzzCc7+9RVwZOpdXnvetzam8cbGb4tDK8iGoeXNbU2UY37tfSBba/2EKz8j25q4x+5/N8ezL/HAv7/hytUClk2I44amdc0uyRCmbGti5/n4QOSN7dqr8vYmR2gN6WHf0PKBvsAo4Dal1O7C/35ndlHCGk7+eJnhc7Zx4eerLHqot9+EljsMa4d3ZT4emAWOd4dGPW+gmLR6H0sTj5OvNUFKMSy2JVMGdfHonBP7t+fZD/eVmi705t5D7xRuYDe4RwSv3iehVRGt9dc4LnyFKOX0+SsMn53IuYt5JIyP5cYW9c0uyRSGBJfP5+MDzKTV+1i87Vjx1/laF3/tSXj5cu+hdzal8a8N3zK4ewSv3RctoeUjrnT2BRozdn03QtaFnxkxJ5FT56+wcFxvols2MLsk03gcXIXz8XOBVK31656XJMpamni8wuOeXnX5Yu+h6ZtLhNb9Elq+EhoaSlZWFo0aNZLwKqS1Jisri9BQe7WM51zKZeTc7RzLvsT8sb2IaV3+M3GBxIgrrqL5+H1Kqd2Fx/6mtf7UgHMLHFdY7hy3knc//45//ldCywyRkZFkZGTgr/eUcy7lcvHnfDSOedXaNYNo4MIST6GhoURGRnq9PqOcv5LHg/O2k376ArMfjOFX1zU2uyTTeRxcMh/vfUFKOQ2pIIu/i3738+94bf0hfi+hZYrg4GDatGljdhle4Zg+zyh3fGRclMezEFZy8eerjJ2/g/0nzjNzZE9uucF6n9czg6ycYQPDYlu6ddwKikJrULcW/FNCSxissulzf3E5N5+H3t/BrmPneHtYd27v1NTskixDFtn1kdW7MqvdBFH0DtLTrkJPanDHe1t+Ca1/De0moSUMZ+fpc1dcyctnwqIkEo9k88bQbvyuS3OzS7IUCS4fKFpaqajtvGhpJcCt8PJkCsSIGlwxY0s6r647xEAJLeFFdp0+d0Xu1QL+vGQnX6Wd5dUhXb3ePGVHMlXoA2YureTLGmZsSWfauoPcG92Cf8n0oPAiI6bPzV5g2pmr+QU8vmwXG1NPM3lQZ4b2su7tADPJFZcPnMi57NZxO9Yw84tfQuv1odFcEyTviYT3eDp97qsZCHfkF2j+umIPn6X8wKS7OjIqrpUpddiBBJcPtGgQRqaTgPDm0kq+rOHfX6Qz9bOD3COhJXzIk+nzymYgzAiuggLNM6v2smb3CSb2b8/4m9r6vAY7kVcYH5jYvz1hwUGljnlzaSVf1vDvL9L5R2FovSGhJWzCCrMgRbTWPP9xCiuSM3isXzseufV6n9dgN3LF5QODukeQdDS71LTGkJ4Vr1jhje4/byzvNOtLR2jd3bW5hJawFSvMgoAjtCavTWXxtmP88Za2/OX2dj59fruS4PKB1bsyWZWcWdwFla81q5IziWnVsFxweHPu3cjlnWZ/eZhXPnWE1psPdJPQErbi6wWmndFa8+r6Q8zbeoQxv2rNMwM6yNJcLpJXGx9wp6PPCh2IVZnz1WFe/jSVuyS0hE0N6h7BPwZ3IaJBGAqIaBDGPwZ38en9rbc3fceMLekM6x3FC/d0ktByg1xx+YA78+lWmnt3Zs5Xh5nySSp3dWnOWxJawsZ8scB0RWZ+kV68L93LgzpLaLlJXnV8wJ0NG83Y3NFVJUPrzT9IaAlRHfO3Hinuwn31vq7UkM87ui1gX3m89eFDZ+d1p6PPCh2IzhSF1u+6NOPNP3QjWEJLCLclJB7lf/9zgP43NuX1ofIh/eoKyFefogaIzJzLaH5pgPA0vCo6L+DyfLoV5t7Lmvv1EaZ8ksqdnZvx1h+6S2j5gFJqnlLqtFIqxexahDFWJmfw3Ecp3No+nHeG9ZBx5IGAvMflrQ8
fVnberc/c5vK5zZx7L2vu10eYvPYAd3ZuxtvDJLR8aAEwHVhoch3CAB/vOcFTK/fw6+sbM2NkT0KukXHkiYAMLm81QFi9scJd8yS0TKO1/lIp1drsOqzMW7sdGH3edSk/8JcPdhPTuiGzRvcktMytAOE+Q16J7Dat4a0GCCs3Vrhr/tYjvLT2AANulNCyKqXUBKVUklIqyV93Oa6Ir6f7q3vezQdP8ejSnXSNrM+8Mb2oFRKQ1wqGM+rVaAEwwKBzeZ23GiCs2ljhrgVbjxTfQH5nuISWVWmtZ2mtY7TWMeHhgbUzrrc+72jkeb9OO8ufFu+kfbO6LBjbmzo1JbSMYsjfpN2mNbyx/FHReVckHWNrenbxsR5R9RnUPcLp9IM3avDUgq1HeLEotOQGsrAoq0/3Jx7OYvzCHbRtXJtF42KpHxbsUV2iNJ+9BVBKTQAmAERFRfnqaSvkjQaISav3lQotgK3p2YyY/Q07j/1YahmniSv3gIa8Al18zOxtFd7/v+958T8H+G0nR2jJDWRhVd5aa9CI8yYfPce4BTuIaBDG4vGxXFs7xKOaRHk+e2UKhGmNpYnHnR7fmp5dbvohL18Xh1YRM5d2WvjN97zw8X7u6NSU6cMltMymlFoKfAO0V0plKKUeMrsmK7HqdP/ejBzGzNtOeN2aLImPo3Gdmh7VI5yTSVcDOdtK3F1mdCAu/OZ7nl/jCK13JbQsQWs9zOwarMzdHRecqax7sDpT+KknzzNq7nbqhQWTEB9H03qh1f7zicpJcBkoSCmPw8vXHYiLvpHQEvbjzo4LFf18ZbswuDtd/93pnxg5J5Gw4CCWxscRYcNOYjsxqh1epjWAYbEtnR7ve13DctMPwUGK4DLLvfi6A3HRtqP8fc1+bu8ooSXsxdPuPyO7B4+cvcjw2YkopVgSH0tUo1pun0O4x6iuQttNa7jzIcNJq/eVmpIYFtvS6ZbhUwZ1IfFwFmmnLxYfa9ekNgnxfZw+n5mbSy7adpS/r07h9o5NeG+EC6H1SgTkXih/PKQO/M2YdR6FcJWn3X9GdQ8ez77E8NnbuFqgWTYhjrbhddz6eVE9ATlV6M5mjZNW72PxtmPFX+drXfx12fCatHpfqdACSDt9kUmr9zFlUOn1Bs3cXHJxidB615XQAuehVdlxIbzI0+4/I7oHT/54meFztnEpN5+l8XHc0LSuyz8rPBOQc0PuTBNU1Cno7Lg7jzVrc8mExKNMWp1Cvw6O0Kp5jSw/I+zH0+4/T3/+9PkrDJ+dSM7FPBaO602nFvVcK1wYIiCvuNyZJqio2cLZcXcea8bmkgmJR3nuoxRu69CE90ZKaAn78nQRAU9+PuvCz4yYk8ip81dYOK430S0bVP8P4okAnr4PyOByZ5qgok7BICc7lrrzWHdqMGJaY0niseLQmiGhJfyAp4sIVOfncy7lMnLudo5lX2LB2N7EtG5Y7ef3WABP3wfkVKE70wQVdQo6O17ZY8tuMHlrh3CfbS65JPEYf/ton4SWEB44fyWP0fO2k376ArNHx9DnukZmlxSwAjK43NmsccqgLoyMiyq+agpSipFxUU67CmNaNaTshqZFX5ddcXpVciZDekZ4fXPJpdsdoXVr+3DPQiukgm6pio4L4Ucu/nyVsfN3cODEed4b0YObb/DP1X/sQmkDVntwV0xMjE5KSvL583pb36mbnU7pVTSFGNEgjK3P3Oa1epZtP8YzHxaFluwDZCSlVLLWOsbsOor465iygsu5+YxdsJ0d359j+rDu3NmludklObxYv5Lv/ei7Ogzk6rgKyHtc3lJRs0RFTRveXN6pKLR+I6ElTOatDR994UpePhMWJZF4JJs3H+hmndCqirNQ86OmjYCcKvSWipolnDVnVPZ4T32wwxFat9wQzkwJLWEib2346Au5Vwt4JGEnX6WdZdqQrgzsZrGwdXea3o+aNuSKy0AT+7cv9UFhcDRRDOkZwarkzHLHvbG80/Idx4tD69+jJLS
EuSr7DKKVr7qu5hfw+LJdbDp4msmDOjM0xnnjlakqunqqbArRT/hVcLkzJVHRYz2Z1qhoxeopg7oQ06qh16dLlu84ztMf7uWmdh6GVqB9PiTQ/rw+5K0NH70pv0Dz1xV7+CzlBybd1ZFRca3MLkmU4TfB5c6ySBU9NulodqkrI3eXVqpqGSdvvsNcnvRLaM3y9Eor0D4fEmh/Xh/y1oaP3lJQoHlm1V7W7D7BUwPaM/6mtmaXJJzwm3tcRiyhtDTxuGVWnHbHiqTjPL1qL7++vrHnoSWEgby14aM3aK35+5oUViRn8Hi/dvzPb643uyRRAb+54jJiCSVPu//MmBZZkXScpwpDa/boGAktYSmeLs3kK1prJq9NJSHxGH+8pS1P3N7O7JKqL6ROxVPfrnBn6tykaXa/CS4jllCq6PNWvlxx2h0rkzMktPyUUmoA8BYQBMzRWk81uaRq8/Y0uae01ry6/hDzth5hzK9a88yADqgKOoFtwdPAcGfq3KRpdr8Jroo6+ipaQmniyj3k5f8SUsFBigd6tXTa/Xdrh3D6Tt1c7h3jiNnfsDU9u/ix7ZrUJiw4yCfdgyuTM5i4co+Elh82ViilgoB3gTuADGCHUupjrfUBcyvzT29v+o4ZW9IZHhvFC/d0sndoBQijdkAeoJQ6pJT6Tin1jBHndJfbyyKVvbDSjiWbyp6jqJW97OdQ7nh9S6nQAsfeW5HXhlZraSZ3rCoMrb7XeSm07LS8kxHv+Kz35+0NfKe1Pqy1zgWWAQPNKsafzdiSzhsbv+W+npFMGdhZQssmPL7istK7Q1enJF5bf4i8gtLJlVegeW39IbY+c1upc/Sdutlpw0XZDSOLpJ2+yPdT76pG9a5ZlZzB//NmaIFtr1SqzXp/3gig5CZuGUBs2QcppSYAEwCioqJ8U5kfmff1EaatO8g90S2YNqQrNcouNCosy4grLtu9OzSikcMMH+50hNavrmvE7NExhIUE6PSg/3P2Clru5qvWepbWOkZrHRMeLou+uiMh8SgvrT1A/xub8vrQaIIktGzFiHtctnt3aEQjh699tCuDv67YQ5+2jZgzupeEln/LAEou1RAJnDCpFr+zMjnDsTddjZ28890bBE8uMaNi43ujhnGnK9HTDsZqMiK4XH53CMwCx0rWBjxvtbnbyOHssZHXhjqdLux7nfEby320K4MnlztCa+6DEloBYAfQTinVBsgE/gAMN7ck/7BmdyZPrdzDr2vs473gtwhRpW8DyIfOcS+4TQp5I4LLlHeHzpZmAtc+L+LOZ0sGdY9gRdKxUo0YPaLqkxDfhzte31IqvNo1qc39MVFOOxA9+XP+1Z3QcnedMk8+m+Hrjr6Kns/PaK2vKqX+DKzH0Q4/T2u93+SybG9dykmeXL6HmNYNmX3iX4SqPLNLEtVkRHD5/N2hsyWbJq7cA5ripouqlmtytZFj0up95boHt6ZnM2L2N2Scu1Lq+PdZl5i4Yo/LNVRlze5Mnly+m9g2XrzS8uSzGb7+DEcAhFYRrfWnwKdm1+EvNh88xaNLdxEdWZ95Y3oR9o9cs0sSHvC4OUNrfRUoeneYCiz39rtDZ0sr5eXrcp2CRiy3tDTxuNPjW9OzvVrDmt2Z/OWD3fRu05C5Y6QRQ4jq+irtDH9avJMOzeoxf2xv6tT0m4+vBixD/gV9/e7QnU4/T7sCK1oGyh3u1lAytOaN6UWtEBloQlTHtsNZxC9Mom3j2iwc15v6YcFmlyQMYMtXRHc6/TxdbqmiZaDc4U4NElpCGCP56DnGLdhB5LW1WDw+lmtrh/zyTZO64UqxwZqAVmXLV0VnnX7BQarUPS4wZrmlYbEtWbztWLnjfa9ryM5jPxpaw8d7TvCXD3bTq7WElhCe2JuRw5h522lStyZLxsfSuE7N0g+wwou9DdYEtCpbvjJW1BXobBNHT5dbmjKoC0Cp8w6LbcmUQV086mws6+M9J3hi2S5iWjdk/liTQ8uX70a99U7Sist
TCZ84cOI8o+Zup36tYJbEx9GkXqjZJQmD2TK4oHxXYFWbOHpiyqAuxQFWWQ0lj7vjPyVCa4HZoQW+fTdq1DvJF3/0vBZhe2mnfmLU3ERqhQSxND7OshtWCs8E5EaSVrJ27wme+GA3Ma0aMl+mB4WotiNnLzJ8TiI1aigSxsfSsmEts0sSXuI3wWXGJo6eWrv3BI8v203PqGuZP7YXtaVNV4hqOZ59ieGzt5FfoFkyPpa24TJV7M/85pXS15s4euqTvSd5fNluekQ1kNASwgMnci4zfM42LuXmszQ+jnZN65pdkmsqupcMrq+AE6D3cv3m1dKd9QfN9snekzy2bBc9ohqwYGxvCa2KuLt8VUnSPhwQTp+/wog5ieRczGPx+Fg6tahndkmuq+j3sLLfe7mXC/hRcLmz/qCZPt3nCK3uLRsw3xuhVdm7OCsyol5n7zqlfdjvZV34mRFzEjl1/gqLHupNdMsGZpckfMRvggtcX3/QLJ/uO8mjSx2htWCcl5aeqc67ODM5q1fecYoq5FzKZeTc7RzLvsSCsb3p2cr4XRmEdflNc4bVfeaL0BIiAJy/ksfoedtJP32B2aNj6HNdI7NLEj4mweUD61IcodVNQksIj1z4+Spj5m3nwInzzBjZg5tvkJ2fA5G8gnrZupST/HnJLrpG1mfB2F4SWkJU0+XcfB5asIM9GT8yfVh3+nVsanZJ5RnRFGSFdRQtTl5FvWhdyg/FofX+uN7UDS2xMrW3ut6qs9li2XtK7izy6YynA8yIgSuD369cyctnwqIktn+fzZsPdOPOLs3NLsk5I5qCpOu1ShJcXuIIrZ3OQwu81/VW2c+XbWyoqAnCnUU+nZ3XU0YMXBn8fiP3agGPJOzkq7SzvHZfVwZ2s24DlvANucflBev3O0KrS0WhJUQFlFL3K6X2K6UKlFIxZtdjtqv5BTy2dBebDp5myqDO3B/T0uyShAV4FFwyyMpbv/8HHkmQ0BLVlgIMBr40uxCz5Rdonly+h3X7f+Dvd3diZFwrs0sSFuHpFZcMshL+WxhanSMcoVVPQku4SWudqrW29srQPlBQoHl61V4+3nOCpwd04KFftzG7JGEhHt3j0lqnAiiljKnGxjYOw9vIAAANO0lEQVQcOMUjSxyhtfAhCS3hfUqpCcAEgKioKJOrMY7Wmr+vSWFlcgaP92vHw7+5zuyShMX4rDnDXwcZOELrfxKS6dTCjdAyouvN3Q5CTxbulC49wyilNgLNnHzrOa31GlfPo7WeBcwCiImJ0VU83Ba01ry09gAJicf40y3X8cTt7cwuSVhQlcElg6xyG0uE1iJ3rrSM6HozooPQ2WOdkS49w2itbze7BivSWjNt3SHmb/2esX1b8/SA9jKbI5yqMrhkkFVs44FTPJyQTKfm9Vgo97SE8Mhbm9KY+UU6I2KjeP7uThJaokLSDl9Nm1JLhNZDsdQPk9ASnlNK/V4plQH0AT5RSq03uyZfmLElnTc3pnFfz0gmD+wsoSUq5dE9LqXU74F3gHAcg2y31rq/IZVZ2KbUU/xpcTIdJbSEwbTWHwEfmV2HL837+gjT1h3knugWTBvSlRo1LBZa7qxyI/eCfcLTrsKAG2SbD57i4cU76di8HotcDS3Z1FD+DoRTi7cd5aW1BxhwYzNeHxpNkNVCC9xb5UZ+l31Cpgrd8PnB0/xp0U7aN6vLonFuXGlZYVPDit7x+eqdoBX+DoSlrEg6zqTVKdzWoQlvD+tOcJC8HAnXyFqFLvr84Gn+uCiZ9s3qsvihWOrXstn0oLwTFBayZncmT6/ay03tGvPeiB6EXCOhJVwnvy0usH1oCWEh61JO8uTyPcS0bsisUTGEBgeZXZKwGQmuKnx+yBFaNzSrI6ElhIc2pZ7i0aW7iI6sz7wxvQgLkdAS7pOpwkpsKQytdk0ltITw1FdpZ4obmyy7E7irq8uAdAqayIK/Odaw5dBpJixKpl2TOiSMj6VBrZDSD7BCi6ydWm/tVKs
w3LbDWcQvTKJteG37fljf6H3nRLVJcDnxxbdnKg8tsEaLrJ0aLuxUqzBU8tFsxi3YQeS1tVhc0XgSwg1yj6uML749Q/zCJK4PryS0hBAu2ZuRw5h5O2hStyZLxsfSuE5Ns0sSfkCCqwQJLSGMc+DEeUbN3U79WsEsiY+jSb1Qs0sSfkKCq9CXZULr2toSWkJUV9qpnxg1N5FaIUEsjY+jRYMws0sSfkTuceHodopfmMR1ElpCeOzI2YsMn5NIjRqKJfFxtGxYq/onk6XChBMBEVyrd2Xy2vpDnMi5TIsGYUzs355B3SMAR2iNfz+JNo1ruxdaVuiSs9OgtlOtotqOZ19i+Oxt5BdoPpgQR5vGtT07oa+XCpPOQVvw++BavSuTZz/cx+W8fAAycy7z7If7AGhcp2ZxaC2Jj6OhO1daVnixtdP6f3aqVVTLiZzLDJu9jUu5+SyNj6Nd07pmlyT8lN8H12vrDxWHVpHLeflMXnuACz9frV5oCSFKOX3+CiPmJPLjpTwS4mPp1KKe2SUJP+b3zRknci47PZ51MVdCSwgDnL3wM8PnJHLq/BUWjOtF18gGZpck/JzfB1dF3UzX1FAkjI+V0BLCAzmXchk5J5GMc5eYN6YXPVs1NLskEQD8fqpwYv/2pe5xASjgxXtupJF8GNI7KmrEEFVSSr0G3APkAunAWK11jrlVOXf+Sh6j5m7n8NmLzH0whri2jYx/Eis0QXmLNCxVm0fBZYdBVtQ9OGXtAc5ezOWaGooX77mRkX1amVyZAaw6qN0JLbNrtZ4NwLNa66tKqWnAs8DTJtdUzoWfrzJm3nYO/nCemSN7clO7cO88kT+/gEvDUrV5esVli0HWpF5NLuRe5YamdVgSH+c/y87YcVBLu3GltNb/LfHlNuA+s2qpyOXcfB5asIM9GT/y7vDu9OvY1OySRIDx6B6X1vq/WuurhV9uAyI9L8lY/5d+lnELdhDVsJZ/hZYIBOOAzyr6plJqglIqSSmVdObMGZ8UdCUvnwmLktj+fTavD41mQOfmPnleIUoysjnDcoPsm/QsCS1hOUqpjUqpFCf/DSzxmOeAq0BCRefRWs/SWsdorWPCw700VVdC7tUC/idhJ1+lneXVIV0Z2C3C688phDNVThUqpTYCzZx86zmt9ZrCx7g0yIBZADExMbpa1bqhKLRaXiuhJaxFa317Zd9XSj0I3A3001p7fay44mp+AY8t3cXmg6d5+feduT+mpdkliQBWZXDZcZBtO5xVuP9PmISWGazaNGIDSqkBOO4T36K1vmR2PQD5BZonl+9h3f4feP7uToyI9YPGJiuQcVJtnnYVWm6QbTucxdj5v4RWeF0JLZ+zY9OIdUwHagIblFIA27TWfzKrmIICzdOr9vLxnhM8PaAD437dxqxS/I+Mk2rztKvQUoMssTC0IiS0hE1pra83u4YiWmv+viaFlckZPHF7Ox7+zXVmlyQE4GFwWWmQJR7OYuyCHbRoEMqS+FgJLSE8oLXmpbUHSEg8xsO/uY7H+7UzuyQhivnFkk/bj2QzdsEOmtcPZemEOJrUlZ1WhagurTXT1h1i/tbvGde3DU/1b0/hjIoQlmD74Np+JJsx87dLaAlhkLc2pTHzi3RGxEbx97s7SmgJy7H1WoU7vneEVrP6oSyNt0Boydpjwube2/Idb25M476ekUwe2FlCS1iSba+4dnyfzYPzHKG1LD6OJvUscKUla48JG5v79RFeXXeIe6NbMG1IV2rUkNAS1mTL4Er6PpsxVgstIWxs8bajTF57gDs7N+P1odEESWgJC7NdcCUVXmk1rSehJYQRlicdZ9LqFPp1aMJbf+jONUG2e1kQAcZWv6HJR38JraUTJLSE8NSa3Zk8vWovN7VrzLsjehByja1eEkSAss1vafLRbEbP/SW0mkpoCeGRz/ad5Mnle+jduiGzRsUQGhxkdklCuMQWwZV89BwPzttBE6uHVkVrjMnaY8JiNqWe4rFlu+jWsgH
zxvQiLERCS9iH5dvhHaG1nfC6NVkab+HQAml5F7bw5bdneHjxTjo2r8f8sb2oXdPyLwNClGLpK66i0GpcJ4Sl8XE0q2/h0BLCBr5Jz2LCoiTahtdm4bje1AsNNrskIdxm2eDaeeyX0Fo2oY+ElhAeSj6azUPvO/aoSxgfS4NaIWaXJES1WDK4dh07x4Nzt9OoTghLJ8iVlhCe2puRw5h5O2haL5SE8bE0kj3qhI1ZLrh2HTvH6LnbaVgnhGUT4mheP8zskoSwtQMnzjNq7nbq1womYXysfIxE2J6lgqtoKwUJLSGM84/PUqkVEsTS+DhaNJAxJezPUu1ESin+PaonV/O1hJYQBpk+rAc5l3Np2bCW2aUIYQiPgkspNRkYCBQAp4ExWusTnpzT9BXehTCRN8ZU/VrB1K8l3YPCf3g6Vfia1rqr1robsBZ43oCahAhkMqaEqIJHwaW1Pl/iy9qA9qwcIQKbjCkhqubxPS6l1MvAaOBH4NZKHjcBmAAQFRXl6dMK4bdkTAlROaV15W/olFIbgWZOvvWc1npNicc9C4RqrV+o6kljYmJ0UlKSu7UKYRlKqWStdUw1f1bGlBBOuDquqrzi0lrf7uJzLgE+AaocZEIEMhlTQnjG067CdlrrtMIv7wUOuvJzycnJZ5VSRyt5SGPgrCe1GUzqqVwg1tPKGyeVMWUaqadyvqrHpXFV5VRhpT+s1CqgPY7W3aPAn7TWHi+RrpRKqu40jDdIPZWTeowjY8ocUk/lrFaPR1dcWushRhUihJAxJYQrLLXkkxBCCFEVqwbXLLMLKEPqqZzUY31W+zuReion9VTCo3tcQgghhK9Z9YpLCCGEcEqCSwghhK1YNriUUq8ppQ4qpfYqpT5SSjUwuZ77lVL7lVIFSilT2kKVUgOUUoeUUt8ppZ4xo4Yy9cxTSp1WSqVYoJaWSqnPlVKphf9Oj5tdk9XImHJag4ypimux7JiybHABG4DOWuuuwLfAsybXkwIMBr4048mVUkHAu8CdQCdgmFKqkxm1lLAAGGByDUWuAn/VWncE4oBHLPD3YzUypkqQMVUly44pywaX1vq/WuurhV9uAyJNridVa33IxBJ6A99prQ9rrXOBZTj2bTKN1vpLINvMGoporU9qrXcW/v9PQCoQYW5V1iJjqhwZU5Ww8piybHCVMQ74zOwiTBYBHC/xdQYW+SWyGqVUa6A7kGhuJZYmY0rGlMusNqY83tbEE66skq2Ueg7HJWuCFeoxkXJyTD7LUIZSqg6wCniizN5WAUHGlFtkTLnAimPK1OCqapVspdSDwN1AP+2DD5y5sWq3GTKAliW+jgQ82tLd3yilgnEMsASt9Ydm12MGGVNukTFVBauOKctOFSqlBgBPA/dqrS+ZXY8F7ADaKaXaKKVCgD8AH5tck2UopRQwF0jVWr9udj1WJGOqHBlTlbDymLJscAHTgbrABqXUbqXUTDOLUUr9XimVAfQBPlFKrffl8xfeVP8zsB7HTdLlWuv9vqyhLKXUUuAboL1SKkMp9ZCJ5fQFRgG3Ff6+7FZK/c7EeqxIxlQJMqaqZNkxJUs+CSGEsBUrX3EJIYQQ5UhwCSGEsBUJLiGEELYiwSWEEMJWJLiEEELYigSXEEIIW5HgEkIIYSv/Hx6L8u+u5Y/dAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "##########################\n", "### 2D Decision Boundary\n", "##########################\n", "\n", "w, b = logr.weights, logr.bias\n", "\n", "x_min = -2\n", "y_min = ( (-(w[0] * x_min) - b[0]) \n", " / w[1] )\n", "\n", "x_max = 2\n", "y_max = ( (-(w[0] * x_max) - b[0]) \n", " / w[1] )\n", "\n", "\n", "fig, ax = plt.subplots(1, 2, sharex=True, figsize=(7, 3))\n", "\n", "ax[0].plot([x_min, x_max], [y_min, y_max])\n", "ax[1].plot([x_min, x_max], [y_min, y_max])\n", "\n", "ax[0].scatter(X_train[y_train==0, 0], X_train[y_train==0, 1], label='class 0', marker='o')\n", "ax[0].scatter(X_train[y_train==1, 0], X_train[y_train==1, 1], label='class 1', marker='s')\n", "\n", "ax[1].scatter(X_test[y_test==0, 0], X_test[y_test==0, 1], label='class 0', marker='o')\n", "ax[1].scatter(X_test[y_test==1, 0], X_test[y_test==1, 1], label='class 1', marker='s')\n", "\n", "ax[1].legend(loc='upper left')\n", "plt.show()" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Low-level implementation using autograd" ] },
{ "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [
"def custom_where(cond, x_1, x_2):\n",
"    \"\"\"Arithmetic stand-in for `torch.where`.\n",
"\n",
"    `cond` is expected to hold 0./1. values (e.g. a boolean mask cast\n",
"    with `.float()`); the result takes `x_1` where `cond` is 1 and `x_2`\n",
"    where it is 0.\n",
"    \"\"\"\n",
"    return (cond * x_1) + ((1-cond) * x_2)\n",
"\n",
"\n",
"class LogisticRegression2():\n",
"    \"\"\"Binary logistic regression whose gradients are computed by autograd.\n",
"\n",
"    Weights and bias start at zero and are updated by full-batch\n",
"    gradient descent in `train`.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, num_features):\n",
"        self.num_features = num_features\n",
"        \n",
"        self.weights = torch.zeros(num_features, 1, \n",
"                                   dtype=torch.float32,\n",
"                                   device=device,\n",
"                                   requires_grad=True) # req. for autograd!\n",
"        self.bias = torch.zeros(1, \n",
"                                dtype=torch.float32,\n",
"                                device=device,\n",
"                                requires_grad=True) # req. for autograd!\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"Return P(y=1 | x) as an (n_examples, 1) tensor for a 2-D input x.\"\"\"\n",
"        linear = torch.add(torch.mm(x, self.weights), self.bias)\n",
"        probas = self._sigmoid(linear)\n",
"        return probas\n",
"        \n",
"    def predict_labels(self, x):\n",
"        \"\"\"Return 0./1. labels, predicting class 1 when P(y=1 | x) >= 0.5.\"\"\"\n",
"        probas = self.forward(x)\n",
"        labels = custom_where((probas >= .5).float(), 1, 0)\n",
"        return labels\n",
"    \n",
"    def evaluate(self, x, y):\n",
"        \"\"\"Return the accuracy of the predicted labels on (x, y) as a 0-dim tensor.\"\"\"\n",
"        labels = self.predict_labels(x)\n",
"        accuracy = (torch.sum(labels.view(-1) == y.view(-1))).float() / y.size()[0]\n",
"        return accuracy\n",
"    \n",
"    def _sigmoid(self, z):\n",
"        return 1. / (1. + torch.exp(-z))\n",
"    \n",
"    def _logit_cost(self, y, proba):\n",
"        \"\"\"Summed (not averaged) binary cross-entropy as a 1x1 tensor:\n",
"        -sum_i [y_i * log(p_i) + (1 - y_i) * log(1 - p_i)]\n",
"        \"\"\"\n",
"        tmp1 = torch.mm(-y.view(1, -1), torch.log(proba))\n",
"        tmp2 = torch.mm((1 - y).view(1, -1), torch.log(1 - proba))\n",
"        return tmp1 - tmp2\n",
"    \n",
"    def train(self, x, y, num_epochs, learning_rate=0.01):\n",
"        \"\"\"Run `num_epochs` full-batch gradient-descent steps, printing\n",
"        the training accuracy and cost after every step.\"\"\"\n",
"        \n",
"        for e in range(num_epochs):\n",
"            \n",
"            #### Compute outputs ####\n",
"            proba = self.forward(x)\n",
"            cost = self._logit_cost(y, proba)\n",
"            \n",
"            #### Compute gradients ####\n",
"            cost.backward()\n",
"            \n",
"            #### Update weights ####\n",
"            # In-place updates of leaf tensors that require grad must not\n",
"            # be recorded by autograd; torch.no_grad() is the supported way.\n",
"            with torch.no_grad():\n",
"                self.weights -= learning_rate * self.weights.grad\n",
"                self.bias -= learning_rate * self.bias.grad\n",
"            \n",
"            #### Reset gradients to zero for next iteration ####\n",
"            self.weights.grad.zero_()\n",
"            self.bias.grad.zero_()\n",
"            \n",
"            #### Logging ####\n",
"            # The monitoring passes need no gradients, so skip graph building.\n",
"            with torch.no_grad():\n",
"                train_acc = self.evaluate(x, y)\n",
"                train_cost = self._logit_cost(y, self.forward(x))\n",
"            print('Epoch: %03d' % (e+1), end=\"\")\n",
"            print(' | Train ACC: %.3f' % train_acc, end=\"\")\n",
"            print(' | Cost: %.3f' % train_cost)"
] },
{ "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 001 | Train ACC: 0.987 | Cost: 5.581\n", "Epoch: 002 | Train ACC: 0.987 | Cost: 4.882\n", "Epoch: 003 | Train ACC: 1.000 | Cost: 4.381\n", "Epoch: 004 | Train ACC: 1.000 | Cost: 3.998\n", "Epoch: 005 | Train ACC: 1.000 | Cost: 3.693\n", "Epoch: 006 | Train ACC: 1.000 | Cost: 3.443\n", "Epoch: 007 | Train ACC: 1.000 | Cost: 3.232\n", "Epoch: 008 | Train ACC: 1.000 | Cost: 3.052\n", "Epoch: 009 | Train ACC: 1.000 | Cost: 2.896\n", "Epoch: 010 | Train ACC: 1.000 | Cost: 2.758\n", "\n", "Model parameters:\n", " Weights: tensor([[ 4.2267],\n", " [-2.9613]], device='cuda:0', requires_grad=True)\n", " Bias: tensor([0.0994], device='cuda:0', requires_grad=True)\n" ] } ], "source": [ "X_train_tensor = torch.tensor(X_train, dtype=torch.float32, device=device)\n", "y_train_tensor = torch.tensor(y_train, dtype=torch.float32, device=device)\n", "\n", "logr = LogisticRegression2(num_features=2)\n", "logr.train(X_train_tensor, y_train_tensor, num_epochs=10, learning_rate=0.1)\n", "\n", "print('\\nModel parameters:')\n", "print(' Weights: %s' % logr.weights)\n", "print(' Bias: %s' % logr.bias)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "#### Evaluating the Model" ] },
{ "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Test set accuracy: 100.00%\n" ] } ], "source": [ "X_test_tensor = torch.tensor(X_test, dtype=torch.float32, device=device)\n", "y_test_tensor = torch.tensor(y_test, dtype=torch.float32, device=device)\n", "\n", "test_acc = logr.evaluate(X_test_tensor, y_test_tensor)\n", "print('Test set accuracy: %.2f%%' % (test_acc*100))" ] },
{ "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAa4AAADFCAYAAAAMsRa3AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAIABJREFUeJzt3Xl8VNX9//HXISYk7AJhSwiLIotAWAIJpWoVLVgXKCiWVUBC69e61P5wqVT9ClrQ1hWFsgsEkEXBokJZxIWvBBLWQMAYEEhAloSIbCYk5/fHJDHLJJnJ3Jl778zn+Xj4eJibyZ0PkDPvued+5hyltUYIIYSwixpmFyCEEEK4Q4JLCCGErUhwCSGEsBUJLiGEELYiwSWEEMJWJLiEEELYigSXEEIIW5HgEkIIYSsSXEIIIWzlGjOetHHjxrp169ZmPLUQhkhOTj6rtQ43u44iMqaEP3B1XJkSXK1btyYpKcmMpxbCEEqpo148dyjwJVATxxhdqbV+obKfkTEl/IGr48qU4BJCVOpn4Dat9QWlVDDwtVLqM631NrMLE8IKJLiEsBjtWPn6QuGXwYX/yWrYQhSS5gwhLEgpFaSU2g2cBjZorROdPGaCUipJKZV05swZ3xcphEnkiksIC9Ja5wPdlFINgI+UUp211illHjMLmAUQExNT7oosLy+PjIwMrly54pOa7SI0NJTIyEiCg4PNLkVUk8fBVZ0byUII12itc5RSW4ABQEoVDy8lIyODunXr0rp1a5RSXqnPbrTWZGVlkZGRQZs2bcwux/JyLuXy9Kq9TLqrEy0b1jK7nGJGTBUW3UiOBroBA5RScQacVxhs9a5M+k7dTJtnPqHv1M2s3pVpdknCCaVUeOGVFkqpMOB24KC757ly5QqNGjWS0CpBKUWjRo3kKtQF56/k8eC87Xx+8AxHsy6ZXU4pHl9xyY1ke1i9K5NnP9zH5bx8ADJzLvPsh/sAGNQ9wszSRHnNgfeVUkE43lwu11qvrc6JJLTKk7+Tql38+Spj5+9g/4nzzBjZk1+3a2x2SaUYco+rcIAlA9cD71Z0IxmYABAVFWXE0wo3vLb+UHFoFbmcl89r6w9JcFmM1nov0N3sOkRgupybz0Pv72DXsXNMH96DOzo1NbukcgzpKtRa52utuwGRQG+lVGcnj5mltY7RWseEh1tmwYGAcSLnslvH7UhrzYYDp3BMAgirevHFF/nnP//plXMnJyfTpUsXrr/+eh577DH5XXDTlbx8JixKIvFINq8P7cbvujQ3uySnDG2H11rnAFtw3EgWFtKiQZhbx+1Ga83Ln6QSvzCJ9ftPmV2OLfnDPdCHH36YWbNmkZaWRlpaGuvWrTO7JNvIvVrAn5fs5Ku0s0wb3NXSMzEeB5dRN5KFd03s356w4KBSx8KCg5jYv71JFRlHa80rn6Yy5+sjjPlVa/rfaL2pDasrugeamXMZzS/3QD0Nr4ULF9K1a1eio6MZNWpUue/Pnj2bXr16ER0dzZAhQ7h0ydEEsGLFCjp37kx0dDQ333wzAPv376d3795069aNrl27kpaWVupcJ0+e5Pz58/Tp0welFKNHj2b16tUe1R8oruYX8PiyXWxMPc3kgTcytFdLs0uqlBH3uAy7kSy8p+jd02vrD3Ei5zItGoQxsX97S7+rckVRaM3+6ggP9mnFC/d0kpvv1eCNe6D79+/n5ZdfZuvWrTRu3Jjs7Oxyjxk8eDDx8fEATJo0iblz5/Loo4/y0ksvsX79eiIiIsjJyQFg5syZPP7444wYMYLc3Fzy80vXm5mZSWRkZPHXkZGRZGba76rR1/ILNH9dsYfPUn5g0l0dGdWntdklVcmIrkK5kWwTg7pH2D6oStJa84/PDhaH1ov33iihVU3euAe6efNm7rvvPho3dnSkNWzYsNxjUlJSmDRpEjk5OVy4cIH+/fsD0LdvX8aMGcPQoUMZPHgwAH369OHll18mIyODwYM
H065du1LncnY/qzq/D6t3ZfrdG7yKFBRonv1wL2t2n2Bi//aMv6mt2SW5RJZ8EraktWbqZweZ9eVhRktoecwb90C11lX+m4wZM4bp06ezb98+XnjhheLPV82cOZMpU6Zw/PhxunXrRlZWFsOHD+fjjz8mLCyM/v37s3nz5lLnioyMJCMjo/jrjIwMWrRo4VbN3poytSKtNc9/nMLypAweu+16Hrn1erNLcpkEl7CdotD695eHGRXXiv+V0PKYN+6B9uvXj+XLl5OVlQXgdKrwp59+onnz5uTl5ZGQkFB8PD09ndjYWF566SUaN27M8ePHOXz4MG3btuWxxx7j3nvvZe/evaXO1bx5c+rWrcu2bdvQWrNw4UIGDhzoVs2VTZn6E601Uz5JZfG2Y/zx5rb85Y4bzC7JLbJWobAVrTVT1/0SWi8NlNAygjfugd54440899xz3HLLLQQFBdG9e3cWLFhQ6jGTJ08mNjaWVq1a0aVLF3766ScAJk6cSFpaGlpr+vXrR3R0NFOnTmXx4sUEBwfTrFkznn/++XLPOWPGDMaMGcPly5e58847ufPOO92qOVA+NvLa+kPMLWxmeubODrYbQ8qMzznExMRo2fROuEtrzbR1h5j5RToj46KYPLCzaQNOKZWstY4x5cmdcDamUlNT6dixo0kVWVtFfzd9p24m00lIRTQIY+szt/miNK97e1Mar2/4lmG9o3jl9+aNIWdcHVcyVShsQWvNq+sdoTUiNoqX7rXWgBP+wZ8/NgIw84t0Xt/wLYN7RPDyIPuOIZkqFJZXFFoztjhCa/LAztSoYc8BJ7zr3KVcTv14hdz8AkKCatC0fijX1gpx+ef99WMjAPO3HmHqZwe5u2tzXh3S1dZjSIJLWFrRfPyMLekMl9ASlTh3KZfMc5cpKLz9kZtfQOY5x7Sfu+HlD0FV0pLEY/zvfw7w205NeeOBblwTZO/JNntXL/ya1pp//vcQ721JZ1jvKKZIaIlKnPrxSnFoFSnQmlM/BvYWJiuTM3hu9T5ubR/OO8O7E2zz0AIJLmFRRaH17ueO0Hp5kISWqFxufoFbxwPBx3tO8NTKPfzqukbMGNmTmtcEVf1DNiDBJSxHa82//vttYWi1lNASLgmp4EqiouP+bl3KD/zlg93EtGrI7NExhAb7R2iBBJewGK01r2/4lumff8cferXk5UFdJLT8jLe2NWlaP5Tpr07ht71vJK69Y83CGkrRtH6o4c9ldZ8fPM2jS3fSNbI+88b2olaIf7Uz+NefRtia1po3NnzLO5sdofXK7yW0fOaVCMi9UP54SB34mz2WO7q2Vgj3DR7IyHETGNC3R7W6Cv3B12ln+ePiZNo3q8uCsb2pU9P/XubliktYQlFovb35Ox6ICezQUkq1VEp9rpRKVUrtV0o97vUndRZalR13kS+3NQH47W9u4uZuN1BDQYfm9QIutBIPZzF+4Q7aNq7NonGx1A8LNrskr/C/KBa29MbGtOLQ+sfgwA2tQleBv2qtdyql6gLJSqkNWusDZhfmDl9vaxLoko+eY9yCHUQ0CGPRQ7FcW9t/Q1uuuITp3tjwLW9vSmNoTKSEFqC1Pqm13ln4/z8BqYDtPljk6rYmN910E126dCEhIYH9+/cDv2xrMnv27OKA6tOnD6+88grTpk3j6NGjhIX5x+7dRtiX8SNj5m+ncd2aLImPI7xuTbNL8ioJrgBn9nbtb2z4lrc2pXF/z0imDrb3p/m9QSnVGsd+d4lOvjdBKZWklEo6c+aMr0urkq+3NQlUqSfPM2peIvVCg1kSH0fTev7fjOJxcJkyHy8MYfbeQ29u/CW0ptl8CRpvUErVAVYBT2itz5f9vtZ6ltY6RmsdEx4e7vsCq+DrbU0C0Xenf2LknERCrwliaXwcER7sn2YnRlxxFc3HdwTigEeUUp0MOK/wMjP3HnprYxpvbkzjPgktp5RSwThCK0Fr/aHXnzCkjnvHXVByW5Po6GiefPLJco8p2tb
kjjvuoEOHDsXHJ06cSJcuXejcuTM333wz0dHRfPDBB3Tu3Jlu3bpx8OBBRo8eXe58Tz31FJGRkVy6dInIyEhefPHFatdvdUfOXmT47ESUUiyJjyWqUS2zS/IZw7c1UUqtAaZrrTdU9BjZ1sQa2jzzCc7+9RVwZOpdXnvetzam8cbGb4tDK8iGoeXNbU2UY37tfSBba/2EKz8j25q4x+5/N8ezL/HAv7/hytUClk2I44amdc0uyRCmbGti5/n4QOSN7dqr8vYmR2gN6WHf0PKBvsAo4Dal1O7C/35ndlHCGk7+eJnhc7Zx4eerLHqot9+EljsMa4d3ZT4emAWOd4dGPW+gmLR6H0sTj5OvNUFKMSy2JVMGdfHonBP7t+fZD/eVmi705t5D7xRuYDe4RwSv3iehVRGt9dc4LnyFKOX0+SsMn53IuYt5JIyP5cYW9c0uyRSGBJfP5+MDzKTV+1i87Vjx1/laF3/tSXj5cu+hdzal8a8N3zK4ewSv3RctoeUjrnT2BRozdn03QtaFnxkxJ5FT56+wcFxvols2MLsk03gcXIXz8XOBVK31656XJMpamni8wuOeXnX5Yu+h6ZtLhNb9Elq+EhoaSlZWFo0aNZLwKqS1Jisri9BQe7WM51zKZeTc7RzLvsT8sb2IaV3+M3GBxIgrrqL5+H1Kqd2Fx/6mtf7UgHMLHFdY7hy3knc//45//ldCywyRkZFkZGTgr/eUcy7lcvHnfDSOedXaNYNo4MIST6GhoURGRnq9PqOcv5LHg/O2k376ArMfjOFX1zU2uyTTeRxcMh/vfUFKOQ2pIIu/i3738+94bf0hfi+hZYrg4GDatGljdhle4Zg+zyh3fGRclMezEFZy8eerjJ2/g/0nzjNzZE9uucF6n9czg6ycYQPDYlu6ddwKikJrULcW/FNCSxissulzf3E5N5+H3t/BrmPneHtYd27v1NTskixDFtn1kdW7MqvdBFH0DtLTrkJPanDHe1t+Ca1/De0moSUMZ+fpc1dcyctnwqIkEo9k88bQbvyuS3OzS7IUCS4fKFpaqajtvGhpJcCt8PJkCsSIGlwxY0s6r647xEAJLeFFdp0+d0Xu1QL+vGQnX6Wd5dUhXb3ePGVHMlXoA2YureTLGmZsSWfauoPcG92Cf8n0oPAiI6bPzV5g2pmr+QU8vmwXG1NPM3lQZ4b2su7tADPJFZcPnMi57NZxO9Yw84tfQuv1odFcEyTviYT3eDp97qsZCHfkF2j+umIPn6X8wKS7OjIqrpUpddiBBJcPtGgQRqaTgPDm0kq+rOHfX6Qz9bOD3COhJXzIk+nzymYgzAiuggLNM6v2smb3CSb2b8/4m9r6vAY7kVcYH5jYvz1hwUGljnlzaSVf1vDvL9L5R2FovSGhJWzCCrMgRbTWPP9xCiuSM3isXzseufV6n9dgN3LF5QODukeQdDS71LTGkJ4Vr1jhje4/byzvNOtLR2jd3bW5hJawFSvMgoAjtCavTWXxtmP88Za2/OX2dj59fruS4PKB1bsyWZWcWdwFla81q5IziWnVsFxweHPu3cjlnWZ/eZhXPnWE1psPdJPQErbi6wWmndFa8+r6Q8zbeoQxv2rNMwM6yNJcLpJXGx9wp6PPCh2IVZnz1WFe/jSVuyS0hE0N6h7BPwZ3IaJBGAqIaBDGPwZ38en9rbc3fceMLekM6x3FC/d0ktByg1xx+YA78+lWmnt3Zs5Xh5nySSp3dWnOWxJawsZ8scB0RWZ+kV68L93LgzpLaLlJXnV8wJ0NG83Y3NFVJUPrzT9IaAlRHfO3Hinuwn31vq7UkM87ui1gX3m89eFDZ+d1p6PPCh2IzhSF1u+6NOPNP3QjWEJLCLclJB7lf/9zgP43NuX1ofIh/eoKyFefogaIzJzLaH5pgPA0vCo6L+DyfLoV5t7Lmvv1EaZ8ksqdnZvx1h+6S2j5gFJqnlLqtFIqxexahDFWJmfw3Ecp3No+nHeG9ZBx5IGAvMflrQ8
fVnberc/c5vK5zZx7L2vu10eYvPYAd3ZuxtvDJLR8aAEwHVhoch3CAB/vOcFTK/fw6+sbM2NkT0KukXHkiYAMLm81QFi9scJd8yS0TKO1/lIp1drsOqzMW7sdGH3edSk/8JcPdhPTuiGzRvcktMytAOE+Q16J7Dat4a0GCCs3Vrhr/tYjvLT2AANulNCyKqXUBKVUklIqyV93Oa6Ir6f7q3vezQdP8ejSnXSNrM+8Mb2oFRKQ1wqGM+rVaAEwwKBzeZ23GiCs2ljhrgVbjxTfQH5nuISWVWmtZ2mtY7TWMeHhgbUzrrc+72jkeb9OO8ufFu+kfbO6LBjbmzo1JbSMYsjfpN2mNbyx/FHReVckHWNrenbxsR5R9RnUPcLp9IM3avDUgq1HeLEotOQGsrAoq0/3Jx7OYvzCHbRtXJtF42KpHxbsUV2iNJ+9BVBKTQAmAERFRfnqaSvkjQaISav3lQotgK3p2YyY/Q07j/1YahmniSv3gIa8Al18zOxtFd7/v+958T8H+G0nR2jJDWRhVd5aa9CI8yYfPce4BTuIaBDG4vGxXFs7xKOaRHk+e2UKhGmNpYnHnR7fmp5dbvohL18Xh1YRM5d2WvjN97zw8X7u6NSU6cMltMymlFoKfAO0V0plKKUeMrsmK7HqdP/ejBzGzNtOeN2aLImPo3Gdmh7VI5yTSVcDOdtK3F1mdCAu/OZ7nl/jCK13JbQsQWs9zOwarMzdHRecqax7sDpT+KknzzNq7nbqhQWTEB9H03qh1f7zicpJcBkoSCmPw8vXHYiLvpHQEvbjzo4LFf18ZbswuDtd/93pnxg5J5Gw4CCWxscRYcNOYjsxqh1epjWAYbEtnR7ve13DctMPwUGK4DLLvfi6A3HRtqP8fc1+bu8ooSXsxdPuPyO7B4+cvcjw2YkopVgSH0tUo1pun0O4x6iuQttNa7jzIcNJq/eVmpIYFtvS6ZbhUwZ1IfFwFmmnLxYfa9ekNgnxfZw+n5mbSy7adpS/r07h9o5NeG+EC6H1SgTkXih/PKQO/M2YdR6FcJWn3X9GdQ8ez77E8NnbuFqgWTYhjrbhddz6eVE9ATlV6M5mjZNW72PxtmPFX+drXfx12fCatHpfqdACSDt9kUmr9zFlUOn1Bs3cXHJxidB615XQAuehVdlxIbzI0+4/I7oHT/54meFztnEpN5+l8XHc0LSuyz8rPBOQc0PuTBNU1Cno7Lg7jzVrc8mExKNMWp1Cvw6O0Kp5jSw/I+zH0+4/T3/+9PkrDJ+dSM7FPBaO602nFvVcK1wYIiCvuNyZJqio2cLZcXcea8bmkgmJR3nuoxRu69CE90ZKaAn78nQRAU9+PuvCz4yYk8ip81dYOK430S0bVP8P4okAnr4PyOByZ5qgok7BICc7lrrzWHdqMGJaY0niseLQmiGhJfyAp4sIVOfncy7lMnLudo5lX2LB2N7EtG5Y7ef3WABP3wfkVKE70wQVdQo6O17ZY8tuMHlrh3CfbS65JPEYf/ton4SWEB44fyWP0fO2k376ArNHx9DnukZmlxSwAjK43NmsccqgLoyMiyq+agpSipFxUU67CmNaNaTshqZFX5ddcXpVciZDekZ4fXPJpdsdoXVr+3DPQiukgm6pio4L4Ucu/nyVsfN3cODEed4b0YObb/DP1X/sQmkDVntwV0xMjE5KSvL583pb36mbnU7pVTSFGNEgjK3P3Oa1epZtP8YzHxaFluwDZCSlVLLWOsbsOor465iygsu5+YxdsJ0d359j+rDu3NmludklObxYv5Lv/ei7Ogzk6rgKyHtc3lJRs0RFTRveXN6pKLR+I6ElTOatDR994UpePhMWJZF4JJs3H+hmndCqirNQ86OmjYCcKvSWipolnDVnVPZ4T32wwxFat9wQzkwJLWEib2346Au5Vwt4JGEnX6WdZdqQrgzsZrGwdXea3o+aNuSKy0AT+7cv9UFhcDRRDOkZwarkzHLHvbG80/Idx4tD69+jJLS
EuSr7DKKVr7qu5hfw+LJdbDp4msmDOjM0xnnjlakqunqqbArRT/hVcLkzJVHRYz2Z1qhoxeopg7oQ06qh16dLlu84ztMf7uWmdh6GVqB9PiTQ/rw+5K0NH70pv0Dz1xV7+CzlBybd1ZFRca3MLkmU4TfB5c6ySBU9NulodqkrI3eXVqpqGSdvvsNcnvRLaM3y9Eor0D4fEmh/Xh/y1oaP3lJQoHlm1V7W7D7BUwPaM/6mtmaXJJzwm3tcRiyhtDTxuGVWnHbHiqTjPL1qL7++vrHnoSWEgby14aM3aK35+5oUViRn8Hi/dvzPb643uyRRAb+54jJiCSVPu//MmBZZkXScpwpDa/boGAktYSmeLs3kK1prJq9NJSHxGH+8pS1P3N7O7JKqL6ROxVPfrnBn6tykaXa/CS4jllCq6PNWvlxx2h0rkzMktPyUUmoA8BYQBMzRWk81uaRq8/Y0uae01ry6/hDzth5hzK9a88yADqgKOoFtwdPAcGfq3KRpdr8Jroo6+ipaQmniyj3k5f8SUsFBigd6tXTa/Xdrh3D6Tt1c7h3jiNnfsDU9u/ix7ZrUJiw4yCfdgyuTM5i4co+Elh82ViilgoB3gTuADGCHUupjrfUBcyvzT29v+o4ZW9IZHhvFC/d0sndoBQijdkAeoJQ6pJT6Tin1jBHndJfbyyKVvbDSjiWbyp6jqJW97OdQ7nh9S6nQAsfeW5HXhlZraSZ3rCoMrb7XeSm07LS8kxHv+Kz35+0NfKe1Pqy1zgWWAQPNKsafzdiSzhsbv+W+npFMGdhZQssmPL7istK7Q1enJF5bf4i8gtLJlVegeW39IbY+c1upc/Sdutlpw0XZDSOLpJ2+yPdT76pG9a5ZlZzB//NmaIFtr1SqzXp/3gig5CZuGUBs2QcppSYAEwCioqJ8U5kfmff1EaatO8g90S2YNqQrNcouNCosy4grLtu9OzSikcMMH+50hNavrmvE7NExhIUE6PSg/3P2Clru5qvWepbWOkZrHRMeLou+uiMh8SgvrT1A/xub8vrQaIIktGzFiHtctnt3aEQjh699tCuDv67YQ5+2jZgzupeEln/LAEou1RAJnDCpFr+zMjnDsTddjZ28890bBE8uMaNi43ujhnGnK9HTDsZqMiK4XH53CMwCx0rWBjxvtbnbyOHssZHXhjqdLux7nfEby320K4MnlztCa+6DEloBYAfQTinVBsgE/gAMN7ck/7BmdyZPrdzDr2vs473gtwhRpW8DyIfOcS+4TQp5I4LLlHeHzpZmAtc+L+LOZ0sGdY9gRdKxUo0YPaLqkxDfhzte31IqvNo1qc39MVFOOxA9+XP+1Z3QcnedMk8+m+Hrjr6Kns/PaK2vKqX+DKzH0Q4/T2u93+SybG9dykmeXL6HmNYNmX3iX4SqPLNLEtVkRHD5/N2hsyWbJq7cA5ripouqlmtytZFj0up95boHt6ZnM2L2N2Scu1Lq+PdZl5i4Yo/LNVRlze5Mnly+m9g2XrzS8uSzGb7+DEcAhFYRrfWnwKdm1+EvNh88xaNLdxEdWZ95Y3oR9o9cs0sSHvC4OUNrfRUoeneYCiz39rtDZ0sr5eXrcp2CRiy3tDTxuNPjW9OzvVrDmt2Z/OWD3fRu05C5Y6QRQ4jq+irtDH9avJMOzeoxf2xv6tT0m4+vBixD/gV9/e7QnU4/T7sCK1oGyh3u1lAytOaN6UWtEBloQlTHtsNZxC9Mom3j2iwc15v6YcFmlyQMYMtXRHc6/TxdbqmiZaDc4U4NElpCGCP56DnGLdhB5LW1WDw+lmtrh/zyTZO64UqxwZqAVmXLV0VnnX7BQarUPS4wZrmlYbEtWbztWLnjfa9ryM5jPxpaw8d7TvCXD3bTq7WElhCe2JuRw5h522lStyZLxsfSuE7N0g+wwou9DdYEtCpbvjJW1BXobBNHT5dbmjKoC0Cp8w6LbcmUQV086mws6+M9J3hi2S5iWjdk/liTQ8uX70a99U7Sist
TCZ84cOI8o+Zup36tYJbEx9GkXqjZJQmD2TK4oHxXYFWbOHpiyqAuxQFWWQ0lj7vjPyVCa4HZoQW+fTdq1DvJF3/0vBZhe2mnfmLU3ERqhQSxND7OshtWCs8E5EaSVrJ27wme+GA3Ma0aMl+mB4WotiNnLzJ8TiI1aigSxsfSsmEts0sSXuI3wWXGJo6eWrv3BI8v203PqGuZP7YXtaVNV4hqOZ59ieGzt5FfoFkyPpa24TJV7M/85pXS15s4euqTvSd5fNluekQ1kNASwgMnci4zfM42LuXmszQ+jnZN65pdkmsqupcMrq+AE6D3cv3m1dKd9QfN9snekzy2bBc9ohqwYGxvCa2KuLt8VUnSPhwQTp+/wog5ieRczGPx+Fg6tahndkmuq+j3sLLfe7mXC/hRcLmz/qCZPt3nCK3uLRsw3xuhVdm7OCsyol5n7zqlfdjvZV34mRFzEjl1/gqLHupNdMsGZpckfMRvggtcX3/QLJ/uO8mjSx2htWCcl5aeqc67ODM5q1fecYoq5FzKZeTc7RzLvsSCsb3p2cr4XRmEdflNc4bVfeaL0BIiAJy/ksfoedtJP32B2aNj6HNdI7NLEj4mweUD61IcodVNQksIj1z4+Spj5m3nwInzzBjZg5tvkJ2fA5G8gnrZupST/HnJLrpG1mfB2F4SWkJU0+XcfB5asIM9GT8yfVh3+nVsanZJ5RnRFGSFdRQtTl5FvWhdyg/FofX+uN7UDS2xMrW3ut6qs9li2XtK7izy6YynA8yIgSuD369cyctnwqIktn+fzZsPdOPOLs3NLsk5I5qCpOu1ShJcXuIIrZ3OQwu81/VW2c+XbWyoqAnCnUU+nZ3XU0YMXBn8fiP3agGPJOzkq7SzvHZfVwZ2s24DlvANucflBev3O0KrS0WhJUQFlFL3K6X2K6UKlFIxZtdjtqv5BTy2dBebDp5myqDO3B/T0uyShAV4FFwyyMpbv/8HHkmQ0BLVlgIMBr40uxCz5Rdonly+h3X7f+Dvd3diZFwrs0sSFuHpFZcMshL+WxhanSMcoVVPQku4SWudqrW29srQPlBQoHl61V4+3nOCpwd04KFftzG7JGEhHt3j0lqnAiiljKnGxjYOw9vIAAANO0lEQVQcOMUjSxyhtfAhCS3hfUqpCcAEgKioKJOrMY7Wmr+vSWFlcgaP92vHw7+5zuyShMX4rDnDXwcZOELrfxKS6dTCjdAyouvN3Q5CTxbulC49wyilNgLNnHzrOa31GlfPo7WeBcwCiImJ0VU83Ba01ry09gAJicf40y3X8cTt7cwuSVhQlcElg6xyG0uE1iJ3rrSM6HozooPQ2WOdkS49w2itbze7BivSWjNt3SHmb/2esX1b8/SA9jKbI5yqMrhkkFVs44FTPJyQTKfm9Vgo97SE8Mhbm9KY+UU6I2KjeP7uThJaokLSDl9Nm1JLhNZDsdQPk9ASnlNK/V4plQH0AT5RSq03uyZfmLElnTc3pnFfz0gmD+wsoSUq5dE9LqXU74F3gHAcg2y31rq/IZVZ2KbUU/xpcTIdJbSEwbTWHwEfmV2HL837+gjT1h3knugWTBvSlRo1LBZa7qxyI/eCfcLTrsKAG2SbD57i4cU76di8HotcDS3Z1FD+DoRTi7cd5aW1BxhwYzNeHxpNkNVCC9xb5UZ+l31Cpgrd8PnB0/xp0U7aN6vLonFuXGlZYVPDit7x+eqdoBX+DoSlrEg6zqTVKdzWoQlvD+tOcJC8HAnXyFqFLvr84Gn+uCiZ9s3qsvihWOrXstn0oLwTFBayZncmT6/ay03tGvPeiB6EXCOhJVwnvy0usH1oCWEh61JO8uTyPcS0bsisUTGEBgeZXZKwGQmuKnx+yBFaNzSrI6ElhIc2pZ7i0aW7iI6sz7wxvQgLkdAS7pOpwkpsKQytdk0ltITw1FdpZ4obmyy7E7irq8uAdAqayIK/Odaw5dBpJixKpl2TOiSMj6VBrZDSD7BCi6ydWm/tVKs
w3LbDWcQvTKJteG37fljf6H3nRLVJcDnxxbdnKg8tsEaLrJ0aLuxUqzBU8tFsxi3YQeS1tVhc0XgSwg1yj6uML749Q/zCJK4PryS0hBAu2ZuRw5h5O2hStyZLxsfSuE5Ns0sSfkCCqwQJLSGMc+DEeUbN3U79WsEsiY+jSb1Qs0sSfkKCq9CXZULr2toSWkJUV9qpnxg1N5FaIUEsjY+jRYMws0sSfkTuceHodopfmMR1ElpCeOzI2YsMn5NIjRqKJfFxtGxYq/onk6XChBMBEVyrd2Xy2vpDnMi5TIsGYUzs355B3SMAR2iNfz+JNo1ruxdaVuiSs9OgtlOtotqOZ19i+Oxt5BdoPpgQR5vGtT07oa+XCpPOQVvw++BavSuTZz/cx+W8fAAycy7z7If7AGhcp2ZxaC2Jj6OhO1daVnixtdP6f3aqVVTLiZzLDJu9jUu5+SyNj6Nd07pmlyT8lN8H12vrDxWHVpHLeflMXnuACz9frV5oCSFKOX3+CiPmJPLjpTwS4mPp1KKe2SUJP+b3zRknci47PZ51MVdCSwgDnL3wM8PnJHLq/BUWjOtF18gGZpck/JzfB1dF3UzX1FAkjI+V0BLCAzmXchk5J5GMc5eYN6YXPVs1NLskEQD8fqpwYv/2pe5xASjgxXtupJF8GNI7KmrEEFVSSr0G3APkAunAWK11jrlVOXf+Sh6j5m7n8NmLzH0whri2jYx/Eis0QXmLNCxVm0fBZYdBVtQ9OGXtAc5ezOWaGooX77mRkX1amVyZAaw6qN0JLbNrtZ4NwLNa66tKqWnAs8DTJtdUzoWfrzJm3nYO/nCemSN7clO7cO88kT+/gEvDUrV5esVli0HWpF5NLuRe5YamdVgSH+c/y87YcVBLu3GltNb/LfHlNuA+s2qpyOXcfB5asIM9GT/y7vDu9OvY1OySRIDx6B6X1vq/WuurhV9uAyI9L8lY/5d+lnELdhDVsJZ/hZYIBOOAzyr6plJqglIqSSmVdObMGZ8UdCUvnwmLktj+fTavD41mQOfmPnleIUoysjnDcoPsm/QsCS1hOUqpjUqpFCf/DSzxmOeAq0BCRefRWs/SWsdorWPCw700VVdC7tUC/idhJ1+lneXVIV0Z2C3C688phDNVThUqpTYCzZx86zmt9ZrCx7g0yIBZADExMbpa1bqhKLRaXiuhJaxFa317Zd9XSj0I3A3001p7fay44mp+AY8t3cXmg6d5+feduT+mpdkliQBWZXDZcZBtO5xVuP9PmISWGazaNGIDSqkBOO4T36K1vmR2PQD5BZonl+9h3f4feP7uToyI9YPGJiuQcVJtnnYVWm6QbTucxdj5v4RWeF0JLZ+zY9OIdUwHagIblFIA27TWfzKrmIICzdOr9vLxnhM8PaAD437dxqxS/I+Mk2rztKvQUoMssTC0IiS0hE1pra83u4YiWmv+viaFlckZPHF7Ox7+zXVmlyQE4GFwWWmQJR7OYuyCHbRoEMqS+FgJLSE8oLXmpbUHSEg8xsO/uY7H+7UzuyQhivnFkk/bj2QzdsEOmtcPZemEOJrUlZ1WhagurTXT1h1i/tbvGde3DU/1b0/hjIoQlmD74Np+JJsx87dLaAlhkLc2pTHzi3RGxEbx97s7SmgJy7H1WoU7vneEVrP6oSyNt0Boydpjwube2/Idb25M476ekUwe2FlCS1iSba+4dnyfzYPzHKG1LD6OJvUscKUla48JG5v79RFeXXeIe6NbMG1IV2rUkNAS1mTL4Er6PpsxVgstIWxs8bajTF57gDs7N+P1odEESWgJC7NdcCUVXmk1rSehJYQRlicdZ9LqFPp1aMJbf+jONUG2e1kQAcZWv6HJR38JraUTJLSE8NSa3Zk8vWovN7VrzLsjehByja1eEkSAss1vafLRbEbP/SW0mkpoCeGRz/ad5Mnle+jduiGzRsUQGhxkdklCuMQWwZV89BwPzttBE6uHVkVrjMnaY8JiNqWe4rFlu+jWsgH
zxvQiLERCS9iH5dvhHaG1nfC6NVkab+HQAml5F7bw5bdneHjxTjo2r8f8sb2oXdPyLwNClGLpK66i0GpcJ4Sl8XE0q2/h0BLCBr5Jz2LCoiTahtdm4bje1AsNNrskIdxm2eDaeeyX0Fo2oY+ElhAeSj6azUPvO/aoSxgfS4NaIWaXJES1WDK4dh07x4Nzt9OoTghLJ8iVlhCe2puRw5h5O2haL5SE8bE0kj3qhI1ZLrh2HTvH6LnbaVgnhGUT4mheP8zskoSwtQMnzjNq7nbq1womYXysfIxE2J6lgqtoKwUJLSGM84/PUqkVEsTS+DhaNJAxJezPUu1ESin+PaonV/O1hJYQBpk+rAc5l3Np2bCW2aUIYQiPgkspNRkYCBQAp4ExWusTnpzT9BXehTCRN8ZU/VrB1K8l3YPCf3g6Vfia1rqr1robsBZ43oCahAhkMqaEqIJHwaW1Pl/iy9qA9qwcIQKbjCkhqubxPS6l1MvAaOBH4NZKHjcBmAAQFRXl6dMK4bdkTAlROaV15W/olFIbgWZOvvWc1npNicc9C4RqrV+o6kljYmJ0UlKSu7UKYRlKqWStdUw1f1bGlBBOuDquqrzi0lrf7uJzLgE+AaocZEIEMhlTQnjG067CdlrrtMIv7wUOuvJzycnJZ5VSRyt5SGPgrCe1GUzqqVwg1tPKGyeVMWUaqadyvqrHpXFV5VRhpT+s1CqgPY7W3aPAn7TWHi+RrpRKqu40jDdIPZWTeowjY8ocUk/lrFaPR1dcWushRhUihJAxJYQrLLXkkxBCCFEVqwbXLLMLKEPqqZzUY31W+zuReion9VTCo3tcQgghhK9Z9YpLCCGEcEqCSwghhK1YNriUUq8ppQ4qpfYqpT5SSjUwuZ77lVL7lVIFSilT2kKVUgOUUoeUUt8ppZ4xo4Yy9cxTSp1WSqVYoJaWSqnPlVKphf9Oj5tdk9XImHJag4ypimux7JiybHABG4DOWuuuwLfAsybXkwIMBr4048mVUkHAu8CdQCdgmFKqkxm1lLAAGGByDUWuAn/VWncE4oBHLPD3YzUypkqQMVUly44pywaX1vq/WuurhV9uAyJNridVa33IxBJ6A99prQ9rrXOBZTj2bTKN1vpLINvMGoporU9qrXcW/v9PQCoQYW5V1iJjqhwZU5Ww8piybHCVMQ74zOwiTBYBHC/xdQYW+SWyGqVUa6A7kGhuJZYmY0rGlMusNqY83tbEE66skq2Ueg7HJWuCFeoxkXJyTD7LUIZSqg6wCniizN5WAUHGlFtkTLnAimPK1OCqapVspdSDwN1AP+2DD5y5sWq3GTKAliW+jgQ82tLd3yilgnEMsASt9Ydm12MGGVNukTFVBauOKctOFSqlBgBPA/dqrS+ZXY8F7ADaKaXaKKVCgD8AH5tck2UopRQwF0jVWr9udj1WJGOqHBlTlbDymLJscAHTgbrABqXUbqXUTDOLUUr9XimVAfQBPlFKrffl8xfeVP8zsB7HTdLlWuv9vqyhLKXUUuAboL1SKkMp9ZCJ5fQFRgG3Ff6+7FZK/c7EeqxIxlQJMqaqZNkxJUs+CSGEsBUrX3EJIYQQ5UhwCSGEsBUJLiGEELYiwSWEEMJWJLiEEELYigSXEEIIW5HgEkIIYSv/Hx6L8u+u5Y/dAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [
"##########################\n",
"### 2D Decision Boundary\n",
"##########################\n",
"\n",
"w, b = logr.weights, logr.bias\n",
"\n",
"# The boundary is where w[0]*x + w[1]*y + b = 0, i.e. y = (-w[0]*x - b) / w[1].\n",
"# .item() turns each 1-element tensor into a plain Python float, so\n",
"# matplotlib never receives CUDA tensors or tensors that require grad.\n",
"x_min = -2\n",
"y_min = ( (-(w[0] * x_min) - b[0]) \n",
"          / w[1] ).item()\n",
"\n",
"x_max = 2\n",
"y_max = ( (-(w[0] * x_max) - b[0]) \n",
"          / w[1] ).item()\n",
"\n",
"\n",
"fig, ax = plt.subplots(1, 2, sharex=True, figsize=(7, 3))\n",
"\n",
"ax[0].plot([x_min, x_max], [y_min, y_max])\n",
"ax[1].plot([x_min, x_max], [y_min, y_max])\n",
"\n",
"ax[0].scatter(X_train[y_train==0, 0], X_train[y_train==0, 1], label='class 0', marker='o')\n",
"ax[0].scatter(X_train[y_train==1, 0], X_train[y_train==1, 1], label='class 1', marker='s')\n",
"\n",
"ax[1].scatter(X_test[y_test==0, 0], X_test[y_test==0, 1], label='class 0', marker='o')\n",
"ax[1].scatter(X_test[y_test==1, 0], X_test[y_test==1, 1], label='class 1', marker='s')\n",
"\n",
"ax[1].legend(loc='upper left')\n",
"plt.show()"
] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## High-level implementation using the nn.Module API" ] },
{ "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [
"class LogisticRegression3(torch.nn.Module):\n",
"    \"\"\"Binary logistic regression: one linear layer followed by a sigmoid.\"\"\"\n",
"\n",
"    def __init__(self, num_features):\n",
"        super(LogisticRegression3, self).__init__()\n",
"        self.linear = torch.nn.Linear(num_features, 1)\n",
"        # initialize weights to zeros here,\n",
"        # since we used zero weights in the\n",
"        # manual approach\n",
"        \n",
"        self.linear.weight.detach().zero_()\n",
"        self.linear.bias.detach().zero_()\n",
"        # Note: the trailing underscore\n",
"        # means \"in-place operation\" in the context\n",
"        # of PyTorch\n",
"        \n",
"    def forward(self, x):\n",
"        \"\"\"Return P(y=1 | x), the sigmoid of the linear layer's output.\"\"\"\n",
"        logits = self.linear(x)\n",
"        probas = torch.sigmoid(logits)\n",
"        return probas\n",
"\n",
"model = LogisticRegression3(num_features=2).to(device)"
] },
{ "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [
"##### Define cost function and set up optimizer #####\n",
"cost_fn = torch.nn.BCELoss(reduction='sum')\n",
"# reduction='sum' (which replaces the deprecated size_average=False)\n",
"# matches the manual approach, where we did not normalize\n",
"# the cost by the batch size\n",
"optimizer = torch.optim.SGD(model.parameters(), lr=0.1)"
] },
{ "cell_type": "code", "execution_count": 14, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 001 | Train ACC: 0.987 | Cost: 5.581\n", "Epoch: 002 | Train ACC: 0.987 | Cost: 4.882\n", "Epoch: 003 | Train ACC: 1.000 | Cost: 4.381\n", "Epoch: 004 | Train ACC: 1.000 | Cost: 3.998\n", "Epoch: 005 | Train ACC: 1.000 | Cost: 3.693\n", "Epoch: 006 | Train ACC: 1.000 | Cost: 3.443\n", "Epoch: 007 | Train ACC: 1.000 | Cost: 3.232\n", "Epoch: 008 | Train ACC: 1.000 | Cost: 3.052\n", "Epoch: 009 | Train ACC: 1.000 | Cost: 2.896\n", "Epoch: 010 | Train ACC: 1.000 | Cost: 2.758\n", "\n", "Model parameters:\n", " Weights: Parameter containing:\n", "tensor([[ 4.2267, -2.9613]], device='cuda:0', requires_grad=True)\n", " Bias: Parameter containing:\n", "tensor([0.0994], device='cuda:0', requires_grad=True)\n" ] } ], "source": [
"def comp_accuracy(label_var, pred_probas):\n",
"    \"\"\"Return the fraction of examples whose predicted label matches `label_var`.\n",
"\n",
"    A probability >= 0.5 is predicted as class 1, the same threshold\n",
"    used by `LogisticRegression2.predict_labels`.\n",
"    \"\"\"\n",
"    pred_labels = custom_where((pred_probas >= 0.5).float(), 1, 0).view(-1)\n",
"    acc = torch.sum(pred_labels == label_var.view(-1)).float() / label_var.size(0)\n",
"    return acc\n",
"\n",
"\n",
"num_epochs = 10\n",
"\n",
"X_train_tensor = torch.tensor(X_train, dtype=torch.float32, device=device)\n",
"y_train_tensor = torch.tensor(y_train, dtype=torch.float32, device=device).view(-1, 1)\n",
"\n",
"\n",
"for epoch in range(num_epochs):\n",
"    \n",
"    #### Compute outputs ####\n",
"    out = model(X_train_tensor)\n",
"    \n",
"    #### Compute gradients ####\n",
"    cost = cost_fn(out, y_train_tensor)\n",
"    optimizer.zero_grad()\n",
"    cost.backward()\n",
"    \n",
"    #### Update weights ####\n",
"    optimizer.step()\n",
"    \n",
"    #### Logging ####\n",
"    # The monitoring pass needs no gradients, so skip graph building.\n",
"    with torch.no_grad():\n",
"        pred_probas = model(X_train_tensor)\n",
"        acc = comp_accuracy(y_train_tensor, pred_probas)\n",
"        train_cost = cost_fn(pred_probas, y_train_tensor)\n",
"    print('Epoch: %03d' % (epoch + 1), end=\"\")\n",
"    print(' | Train ACC: %.3f' % acc, end=\"\")\n",
"    print(' | Cost: %.3f' % train_cost)\n",
"\n",
"print('\\nModel parameters:')\n",
"print(' Weights: %s' % model.linear.weight)\n",
"print(' Bias: %s' % model.linear.bias)"
] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "#### Evaluating the Model" ] },
{ "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Test set accuracy: 100.00%\n" ] } ], "source": [
"X_test_tensor = torch.tensor(X_test, dtype=torch.float32, device=device)\n",
"y_test_tensor = torch.tensor(y_test, dtype=torch.float32, device=device)\n",
"\n",
"with torch.no_grad():\n",
"    pred_probas = model(X_test_tensor)\n",
"test_acc = comp_accuracy(y_test_tensor, pred_probas)\n",
"\n",
"print('Test set accuracy: %.2f%%' % (test_acc*100))"
] },
{ "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAa4AAADFCAYAAAAMsRa3AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAIABJREFUeJzt3Xl8VNX9//HXISYk7AJhSwiLIotAWAIJpWoVLVgXKCiWVUBC69e61P5wqVT9ClrQ1hWFsgsEkEXBokJZxIWvBBLWQMAYEEhAloSIbCYk5/fHJDHLJJnJ3Jl778zn+Xj4eJibyZ0PkDPvued+5hyltUYIIYSwixpmFyCEEEK4Q4JLCCGErUhwCSGEsBUJLiGEELYiwSWEEMJWJLiEEELYigSXEEIIW5HgEkIIYSsSXEIIIWzlGjOetHHjxrp169ZmPLUQhkhOTj6rtQ43u44iMqaEP3B1XJkSXK1btyYpKcmMpxbCEEqpo148dyjwJVATxxhdqbV+obKfkTEl/IGr48qU4BJCVOpn4Dat9QWlVDDwtVLqM631NrMLE8IKJLiEsBjtWPn6QuGXwYX/yWrYQhSS5gwhLEgpFaSU2g2cBjZorROdPGaCUipJKZV05swZ3xcphEnkiksIC9Ja5wPdlFINgI+UUp211illHjMLmAUQExNT7oosLy+PjIwMrly54pOa7SI0NJTIyEiCg4PNLkVUk8fBVZ0byUII12itc5RSW4ABQEoVDy8lIyODunXr0rp1a5RSXqnPbrTWZGVlkZGRQZs2bcwux/JyLuXy9Kq9TLqrEy0b1jK7nGJGTBUW3UiOBroBA5RScQacVxhs9a5M+k7dTJtnPqHv1M2s3pVpdknCCaVUeOGVFkqpMOB24KC757ly5QqNGjWS0CpBKUWjRo3kKtQF56/k8eC87Xx+8AxHsy6ZXU4pHl9xyY1ke1i9K5NnP9zH5bx8ADJzLvPsh/sAGNQ9wszSRHnNgfeVUkE43lwu11qvrc6JJLTKk7+Tql38+Spj5+9g/4nzzBjZk1+3a2x2SaUYco+rcIAlA9cD71Z0IxmYABAVFWXE0wo3vLb+UHFoFbmcl89r6w9JcFmM1nov0N3sOkRgupybz0Pv72DXsXNMH96DOzo1NbukcgzpKtRa52utuwGRQG+lVGcnj5mltY7RWseEh1tmwYGAcSLnslvH7UhrzYYDp3BMAgirevHFF/nnP//plXMnJyfTpUsXrr/+eh577DH5XXDTlbx8JixKIvFINq8P7cbvujQ3uySnDG2H11rnAFtw3EgWFtKiQZhbx+1Ga83Ln6QSvzCJ9ftPmV2OLfnDPdCHH36YWbNmkZaWRlpaGuvWrTO7JNvIvVrAn5fs5Ku0s0wb3NXSMzEeB5dRN5KFd03s356w4KBSx8KCg5jYv71JFRlHa80rn6Yy5+sjjPlVa/rfaL2pDasrugeamXMZzS/3QD0Nr4ULF9K1a1eio6MZNWpUue/Pnj2bXr16ER0dzZAhQ7h0ydEEsGLFCjp37kx0dDQ333wzAPv376d3795069aNrl27kpaWVupcJ0+e5Pz58/Tp0welFKNHj2b16tUe1R8oruYX8PiyXWxMPc3kgTcytFdLs0uqlBH3uAy7kSy8p+jd02vrD3Ei5zItGoQxsX97S7+rckVRaM3+6ggP9mnFC/d0kpvv1eCNe6D79+/n5ZdfZuvWrTRu3Jjs7Oxyjxk8eDDx8fEATJo0iblz5/Loo4/y0ksvsX79eiIiIsjJyQFg5syZPP7444wYMYLc3Fzy80vXm5mZSWRkZPHXkZGRZGba76rR1/ILNH9dsYfPUn5g0l0dGdWntdklVcmIrkK5kWwTg7pH2D6oStJa84/PDhaH1ov33iihVU3euAe6efNm7rvvPho3dnSkNWzYsNxjUlJSmDRpEjk5OVy4cIH+/fsD0LdvX8aMGcPQoUMZPHgwAH369OHll18mIyODwYM
H065du1LncnY/qzq/D6t3ZfrdG7yKFBRonv1wL2t2n2Bi//aMv6mt2SW5RJZ8EraktWbqZweZ9eVhRktoecwb90C11lX+m4wZM4bp06ezb98+XnjhheLPV82cOZMpU6Zw/PhxunXrRlZWFsOHD+fjjz8mLCyM/v37s3nz5lLnioyMJCMjo/jrjIwMWrRo4VbN3poytSKtNc9/nMLypAweu+16Hrn1erNLcpkEl7CdotD695eHGRXXiv+V0PKYN+6B9uvXj+XLl5OVlQXgdKrwp59+onnz5uTl5ZGQkFB8PD09ndjYWF566SUaN27M8ePHOXz4MG3btuWxxx7j3nvvZe/evaXO1bx5c+rWrcu2bdvQWrNw4UIGDhzoVs2VTZn6E601Uz5JZfG2Y/zx5rb85Y4bzC7JLbJWobAVrTVT1/0SWi8NlNAygjfugd54440899xz3HLLLQQFBdG9e3cWLFhQ6jGTJ08mNjaWVq1a0aVLF3766ScAJk6cSFpaGlpr+vXrR3R0NFOnTmXx4sUEBwfTrFkznn/++XLPOWPGDMaMGcPly5e58847ufPOO92qOVA+NvLa+kPMLWxmeubODrYbQ8qMzznExMRo2fROuEtrzbR1h5j5RToj46KYPLCzaQNOKZWstY4x5cmdcDamUlNT6dixo0kVWVtFfzd9p24m00lIRTQIY+szt/miNK97e1Mar2/4lmG9o3jl9+aNIWdcHVcyVShsQWvNq+sdoTUiNoqX7rXWgBP+wZ8/NgIw84t0Xt/wLYN7RPDyIPuOIZkqFJZXFFoztjhCa/LAztSoYc8BJ7zr3KVcTv14hdz8AkKCatC0fijX1gpx+ef99WMjAPO3HmHqZwe5u2tzXh3S1dZjSIJLWFrRfPyMLekMl9ASlTh3KZfMc5cpKLz9kZtfQOY5x7Sfu+HlD0FV0pLEY/zvfw7w205NeeOBblwTZO/JNntXL/ya1pp//vcQ721JZ1jvKKZIaIlKnPrxSnFoFSnQmlM/BvYWJiuTM3hu9T5ubR/OO8O7E2zz0AIJLmFRRaH17ueO0Hp5kISWqFxufoFbxwPBx3tO8NTKPfzqukbMGNmTmtcEVf1DNiDBJSxHa82//vttYWi1lNASLgmp4EqiouP+bl3KD/zlg93EtGrI7NExhAb7R2iBBJewGK01r2/4lumff8cferXk5UFdJLT8jLe2NWlaP5Tpr07ht71vJK69Y83CGkrRtH6o4c9ldZ8fPM2jS3fSNbI+88b2olaIf7Uz+NefRtia1po3NnzLO5sdofXK7yW0fOaVCMi9UP54SB34mz2WO7q2Vgj3DR7IyHETGNC3R7W6Cv3B12ln+ePiZNo3q8uCsb2pU9P/XubliktYQlFovb35Ox6ICezQUkq1VEp9rpRKVUrtV0o97vUndRZalR13kS+3NQH47W9u4uZuN1BDQYfm9QIutBIPZzF+4Q7aNq7NonGx1A8LNrskr/C/KBa29MbGtOLQ+sfgwA2tQleBv2qtdyql6gLJSqkNWusDZhfmDl9vaxLoko+eY9yCHUQ0CGPRQ7FcW9t/Q1uuuITp3tjwLW9vSmNoTKSEFqC1Pqm13ln4/z8BqYDtPljk6rYmN910E126dCEhIYH9+/cDv2xrMnv27OKA6tOnD6+88grTpk3j6NGjhIX5x+7dRtiX8SNj5m+ncd2aLImPI7xuTbNL8ioJrgBn9nbtb2z4lrc2pXF/z0imDrb3p/m9QSnVGsd+d4lOvjdBKZWklEo6c+aMr0urkq+3NQlUqSfPM2peIvVCg1kSH0fTev7fjOJxcJkyHy8MYfbeQ29u/CW0ptl8CRpvUErVAVYBT2itz5f9vtZ6ltY6RmsdEx4e7vsCq+DrbU0C0Xenf2LknERCrwliaXwcER7sn2YnRlxxFc3HdwTigEeUUp0MOK/wMjP3HnprYxpvbkzjPgktp5RSwThCK0Fr/aHXnzCkjnvHXVByW5Po6GiefPLJco8p2tb
kjjvuoEOHDsXHJ06cSJcuXejcuTM333wz0dHRfPDBB3Tu3Jlu3bpx8OBBRo8eXe58Tz31FJGRkVy6dInIyEhefPHFatdvdUfOXmT47ESUUiyJjyWqUS2zS/IZw7c1UUqtAaZrrTdU9BjZ1sQa2jzzCc7+9RVwZOpdXnvetzam8cbGb4tDK8iGoeXNbU2UY37tfSBba/2EKz8j25q4x+5/N8ezL/HAv7/hytUClk2I44amdc0uyRCmbGti5/n4QOSN7dqr8vYmR2gN6WHf0PKBvsAo4Dal1O7C/35ndlHCGk7+eJnhc7Zx4eerLHqot9+EljsMa4d3ZT4emAWOd4dGPW+gmLR6H0sTj5OvNUFKMSy2JVMGdfHonBP7t+fZD/eVmi705t5D7xRuYDe4RwSv3iehVRGt9dc4LnyFKOX0+SsMn53IuYt5JIyP5cYW9c0uyRSGBJfP5+MDzKTV+1i87Vjx1/laF3/tSXj5cu+hdzal8a8N3zK4ewSv3RctoeUjrnT2BRozdn03QtaFnxkxJ5FT56+wcFxvols2MLsk03gcXIXz8XOBVK31656XJMpamni8wuOeXnX5Yu+h6ZtLhNb9Elq+EhoaSlZWFo0aNZLwKqS1Jisri9BQe7WM51zKZeTc7RzLvsT8sb2IaV3+M3GBxIgrrqL5+H1Kqd2Fx/6mtf7UgHMLHFdY7hy3knc//45//ldCywyRkZFkZGTgr/eUcy7lcvHnfDSOedXaNYNo4MIST6GhoURGRnq9PqOcv5LHg/O2k376ArMfjOFX1zU2uyTTeRxcMh/vfUFKOQ2pIIu/i3738+94bf0hfi+hZYrg4GDatGljdhle4Zg+zyh3fGRclMezEFZy8eerjJ2/g/0nzjNzZE9uucF6n9czg6ycYQPDYlu6ddwKikJrULcW/FNCSxissulzf3E5N5+H3t/BrmPneHtYd27v1NTskixDFtn1kdW7MqvdBFH0DtLTrkJPanDHe1t+Ca1/De0moSUMZ+fpc1dcyctnwqIkEo9k88bQbvyuS3OzS7IUCS4fKFpaqajtvGhpJcCt8PJkCsSIGlwxY0s6r647xEAJLeFFdp0+d0Xu1QL+vGQnX6Wd5dUhXb3ePGVHMlXoA2YureTLGmZsSWfauoPcG92Cf8n0oPAiI6bPzV5g2pmr+QU8vmwXG1NPM3lQZ4b2su7tADPJFZcPnMi57NZxO9Yw84tfQuv1odFcEyTviYT3eDp97qsZCHfkF2j+umIPn6X8wKS7OjIqrpUpddiBBJcPtGgQRqaTgPDm0kq+rOHfX6Qz9bOD3COhJXzIk+nzymYgzAiuggLNM6v2smb3CSb2b8/4m9r6vAY7kVcYH5jYvz1hwUGljnlzaSVf1vDvL9L5R2FovSGhJWzCCrMgRbTWPP9xCiuSM3isXzseufV6n9dgN3LF5QODukeQdDS71LTGkJ4Vr1jhje4/byzvNOtLR2jd3bW5hJawFSvMgoAjtCavTWXxtmP88Za2/OX2dj59fruS4PKB1bsyWZWcWdwFla81q5IziWnVsFxweHPu3cjlnWZ/eZhXPnWE1psPdJPQErbi6wWmndFa8+r6Q8zbeoQxv2rNMwM6yNJcLpJXGx9wp6PPCh2IVZnz1WFe/jSVuyS0hE0N6h7BPwZ3IaJBGAqIaBDGPwZ38en9rbc3fceMLekM6x3FC/d0ktByg1xx+YA78+lWmnt3Zs5Xh5nySSp3dWnOWxJawsZ8scB0RWZ+kV68L93LgzpLaLlJXnV8wJ0NG83Y3NFVJUPrzT9IaAlRHfO3Hinuwn31vq7UkM87ui1gX3m89eFDZ+d1p6PPCh2IzhSF1u+6NOPNP3QjWEJLCLclJB7lf/9zgP43NuX1ofIh/eoKyFefogaIzJzLaH5pgPA0vCo6L+DyfLoV5t7Lmvv1EaZ8ksqdnZvx1h+6S2j5gFJqnlLqtFIqxexahDFWJmfw3Ecp3No+nHeG9ZBx5IGAvMflrQ8
fVnberc/c5vK5zZx7L2vu10eYvPYAd3ZuxtvDJLR8aAEwHVhoch3CAB/vOcFTK/fw6+sbM2NkT0KukXHkiYAMLm81QFi9scJd8yS0TKO1/lIp1drsOqzMW7sdGH3edSk/8JcPdhPTuiGzRvcktMytAOE+Q16J7Dat4a0GCCs3Vrhr/tYjvLT2AANulNCyKqXUBKVUklIqyV93Oa6Ir6f7q3vezQdP8ejSnXSNrM+8Mb2oFRKQ1wqGM+rVaAEwwKBzeZ23GiCs2ljhrgVbjxTfQH5nuISWVWmtZ2mtY7TWMeHhgbUzrrc+72jkeb9OO8ufFu+kfbO6LBjbmzo1JbSMYsjfpN2mNbyx/FHReVckHWNrenbxsR5R9RnUPcLp9IM3avDUgq1HeLEotOQGsrAoq0/3Jx7OYvzCHbRtXJtF42KpHxbsUV2iNJ+9BVBKTQAmAERFRfnqaSvkjQaISav3lQotgK3p2YyY/Q07j/1YahmniSv3gIa8Al18zOxtFd7/v+958T8H+G0nR2jJDWRhVd5aa9CI8yYfPce4BTuIaBDG4vGxXFs7xKOaRHk+e2UKhGmNpYnHnR7fmp5dbvohL18Xh1YRM5d2WvjN97zw8X7u6NSU6cMltMymlFoKfAO0V0plKKUeMrsmK7HqdP/ejBzGzNtOeN2aLImPo3Gdmh7VI5yTSVcDOdtK3F1mdCAu/OZ7nl/jCK13JbQsQWs9zOwarMzdHRecqax7sDpT+KknzzNq7nbqhQWTEB9H03qh1f7zicpJcBkoSCmPw8vXHYiLvpHQEvbjzo4LFf18ZbswuDtd/93pnxg5J5Gw4CCWxscRYcNOYjsxqh1epjWAYbEtnR7ve13DctMPwUGK4DLLvfi6A3HRtqP8fc1+bu8ooSXsxdPuPyO7B4+cvcjw2YkopVgSH0tUo1pun0O4x6iuQttNa7jzIcNJq/eVmpIYFtvS6ZbhUwZ1IfFwFmmnLxYfa9ekNgnxfZw+n5mbSy7adpS/r07h9o5NeG+EC6H1SgTkXih/PKQO/M2YdR6FcJWn3X9GdQ8ez77E8NnbuFqgWTYhjrbhddz6eVE9ATlV6M5mjZNW72PxtmPFX+drXfx12fCatHpfqdACSDt9kUmr9zFlUOn1Bs3cXHJxidB615XQAuehVdlxIbzI0+4/I7oHT/54meFztnEpN5+l8XHc0LSuyz8rPBOQc0PuTBNU1Cno7Lg7jzVrc8mExKNMWp1Cvw6O0Kp5jSw/I+zH0+4/T3/+9PkrDJ+dSM7FPBaO602nFvVcK1wYIiCvuNyZJqio2cLZcXcea8bmkgmJR3nuoxRu69CE90ZKaAn78nQRAU9+PuvCz4yYk8ip81dYOK430S0bVP8P4okAnr4PyOByZ5qgok7BICc7lrrzWHdqMGJaY0niseLQmiGhJfyAp4sIVOfncy7lMnLudo5lX2LB2N7EtG5Y7ef3WABP3wfkVKE70wQVdQo6O17ZY8tuMHlrh3CfbS65JPEYf/ton4SWEB44fyWP0fO2k376ArNHx9DnukZmlxSwAjK43NmsccqgLoyMiyq+agpSipFxUU67CmNaNaTshqZFX5ddcXpVciZDekZ4fXPJpdsdoXVr+3DPQiukgm6pio4L4Ucu/nyVsfN3cODEed4b0YObb/DP1X/sQmkDVntwV0xMjE5KSvL583pb36mbnU7pVTSFGNEgjK3P3Oa1epZtP8YzHxaFluwDZCSlVLLWOsbsOor465iygsu5+YxdsJ0d359j+rDu3NmludklObxYv5Lv/ei7Ogzk6rgKyHtc3lJRs0RFTRveXN6pKLR+I6ElTOatDR994UpePhMWJZF4JJs3H+hmndCqirNQ86OmjYCcKvSWipolnDVnVPZ4T32wwxFat9wQzkwJLWEib2346Au5Vwt4JGEnX6WdZdqQrgzsZrGwdXea3o+aNuSKy0AT+7cv9UFhcDRRDOkZwarkzHLHvbG80/Idx4tD69+jJLS
EuSr7DKKVr7qu5hfw+LJdbDp4msmDOjM0xnnjlakqunqqbArRT/hVcLkzJVHRYz2Z1qhoxeopg7oQ06qh16dLlu84ztMf7uWmdh6GVqB9PiTQ/rw+5K0NH70pv0Dz1xV7+CzlBybd1ZFRca3MLkmU4TfB5c6ySBU9NulodqkrI3eXVqpqGSdvvsNcnvRLaM3y9Eor0D4fEmh/Xh/y1oaP3lJQoHlm1V7W7D7BUwPaM/6mtmaXJJzwm3tcRiyhtDTxuGVWnHbHiqTjPL1qL7++vrHnoSWEgby14aM3aK35+5oUViRn8Hi/dvzPb643uyRRAb+54jJiCSVPu//MmBZZkXScpwpDa/boGAktYSmeLs3kK1prJq9NJSHxGH+8pS1P3N7O7JKqL6ROxVPfrnBn6tykaXa/CS4jllCq6PNWvlxx2h0rkzMktPyUUmoA8BYQBMzRWk81uaRq8/Y0uae01ry6/hDzth5hzK9a88yADqgKOoFtwdPAcGfq3KRpdr8Jroo6+ipaQmniyj3k5f8SUsFBigd6tXTa/Xdrh3D6Tt1c7h3jiNnfsDU9u/ix7ZrUJiw4yCfdgyuTM5i4co+Elh82ViilgoB3gTuADGCHUupjrfUBcyvzT29v+o4ZW9IZHhvFC/d0sndoBQijdkAeoJQ6pJT6Tin1jBHndJfbyyKVvbDSjiWbyp6jqJW97OdQ7nh9S6nQAsfeW5HXhlZraSZ3rCoMrb7XeSm07LS8kxHv+Kz35+0NfKe1Pqy1zgWWAQPNKsafzdiSzhsbv+W+npFMGdhZQssmPL7istK7Q1enJF5bf4i8gtLJlVegeW39IbY+c1upc/Sdutlpw0XZDSOLpJ2+yPdT76pG9a5ZlZzB//NmaIFtr1SqzXp/3gig5CZuGUBs2QcppSYAEwCioqJ8U5kfmff1EaatO8g90S2YNqQrNcouNCosy4grLtu9OzSikcMMH+50hNavrmvE7NExhIUE6PSg/3P2Clru5qvWepbWOkZrHRMeLou+uiMh8SgvrT1A/xub8vrQaIIktGzFiHtctnt3aEQjh699tCuDv67YQ5+2jZgzupeEln/LAEou1RAJnDCpFr+zMjnDsTddjZ28890bBE8uMaNi43ujhnGnK9HTDsZqMiK4XH53CMwCx0rWBjxvtbnbyOHssZHXhjqdLux7nfEby320K4MnlztCa+6DEloBYAfQTinVBsgE/gAMN7ck/7BmdyZPrdzDr2vs473gtwhRpW8DyIfOcS+4TQp5I4LLlHeHzpZmAtc+L+LOZ0sGdY9gRdKxUo0YPaLqkxDfhzte31IqvNo1qc39MVFOOxA9+XP+1Z3QcnedMk8+m+Hrjr6Kns/PaK2vKqX+DKzH0Q4/T2u93+SybG9dykmeXL6HmNYNmX3iX4SqPLNLEtVkRHD5/N2hsyWbJq7cA5ripouqlmtytZFj0up95boHt6ZnM2L2N2Scu1Lq+PdZl5i4Yo/LNVRlze5Mnly+m9g2XrzS8uSzGb7+DEcAhFYRrfWnwKdm1+EvNh88xaNLdxEdWZ95Y3oR9o9cs0sSHvC4OUNrfRUoeneYCiz39rtDZ0sr5eXrcp2CRiy3tDTxuNPjW9OzvVrDmt2Z/OWD3fRu05C5Y6QRQ4jq+irtDH9avJMOzeoxf2xv6tT0m4+vBixD/gV9/e7QnU4/T7sCK1oGyh3u1lAytOaN6UWtEBloQlTHtsNZxC9Mom3j2iwc15v6YcFmlyQMYMtXRHc6/TxdbqmiZaDc4U4NElpCGCP56DnGLdhB5LW1WDw+lmtrh/zyTZO64UqxwZqAVmXLV0VnnX7BQarUPS4wZrmlYbEtWbztWLnjfa9ryM5jPxpaw8d7TvCXD3bTq7WElhCe2JuRw5h522lStyZLxsfSuE7N0g+wwou9DdYEtCpbvjJW1BXobBNHT5dbmjKoC0Cp8w6LbcmUQV086mws6+M9J3hi2S5iWjdk/liTQ8uX70a99U7Sist
TCZ84cOI8o+Zup36tYJbEx9GkXqjZJQmD2TK4oHxXYFWbOHpiyqAuxQFWWQ0lj7vjPyVCa4HZoQW+fTdq1DvJF3/0vBZhe2mnfmLU3ERqhQSxND7OshtWCs8E5EaSVrJ27wme+GA3Ma0aMl+mB4WotiNnLzJ8TiI1aigSxsfSsmEts0sSXuI3wWXGJo6eWrv3BI8v203PqGuZP7YXtaVNV4hqOZ59ieGzt5FfoFkyPpa24TJV7M/85pXS15s4euqTvSd5fNluekQ1kNASwgMnci4zfM42LuXmszQ+jnZN65pdkmsqupcMrq+AE6D3cv3m1dKd9QfN9snekzy2bBc9ohqwYGxvCa2KuLt8VUnSPhwQTp+/wog5ieRczGPx+Fg6tahndkmuq+j3sLLfe7mXC/hRcLmz/qCZPt3nCK3uLRsw3xuhVdm7OCsyol5n7zqlfdjvZV34mRFzEjl1/gqLHupNdMsGZpckfMRvggtcX3/QLJ/uO8mjSx2htWCcl5aeqc67ODM5q1fecYoq5FzKZeTc7RzLvsSCsb3p2cr4XRmEdflNc4bVfeaL0BIiAJy/ksfoedtJP32B2aNj6HNdI7NLEj4mweUD61IcodVNQksIj1z4+Spj5m3nwInzzBjZg5tvkJ2fA5G8gnrZupST/HnJLrpG1mfB2F4SWkJU0+XcfB5asIM9GT8yfVh3+nVsanZJ5RnRFGSFdRQtTl5FvWhdyg/FofX+uN7UDS2xMrW3ut6qs9li2XtK7izy6YynA8yIgSuD369cyctnwqIktn+fzZsPdOPOLs3NLsk5I5qCpOu1ShJcXuIIrZ3OQwu81/VW2c+XbWyoqAnCnUU+nZ3XU0YMXBn8fiP3agGPJOzkq7SzvHZfVwZ2s24DlvANucflBev3O0KrS0WhJUQFlFL3K6X2K6UKlFIxZtdjtqv5BTy2dBebDp5myqDO3B/T0uyShAV4FFwyyMpbv/8HHkmQ0BLVlgIMBr40uxCz5Rdonly+h3X7f+Dvd3diZFwrs0sSFuHpFZcMshL+WxhanSMcoVVPQku4SWudqrW29srQPlBQoHl61V4+3nOCpwd04KFftzG7JGEhHt3j0lqnAiiljKnGxjYOw9vIAAANO0lEQVQcOMUjSxyhtfAhCS3hfUqpCcAEgKioKJOrMY7Wmr+vSWFlcgaP92vHw7+5zuyShMX4rDnDXwcZOELrfxKS6dTCjdAyouvN3Q5CTxbulC49wyilNgLNnHzrOa31GlfPo7WeBcwCiImJ0VU83Ba01ry09gAJicf40y3X8cTt7cwuSVhQlcElg6xyG0uE1iJ3rrSM6HozooPQ2WOdkS49w2itbze7BivSWjNt3SHmb/2esX1b8/SA9jKbI5yqMrhkkFVs44FTPJyQTKfm9Vgo97SE8Mhbm9KY+UU6I2KjeP7uThJaokLSDl9Nm1JLhNZDsdQPk9ASnlNK/V4plQH0AT5RSq03uyZfmLElnTc3pnFfz0gmD+wsoSUq5dE9LqXU74F3gHAcg2y31rq/IZVZ2KbUU/xpcTIdJbSEwbTWHwEfmV2HL837+gjT1h3knugWTBvSlRo1LBZa7qxyI/eCfcLTrsKAG2SbD57i4cU76di8HotcDS3Z1FD+DoRTi7cd5aW1BxhwYzNeHxpNkNVCC9xb5UZ+l31Cpgrd8PnB0/xp0U7aN6vLonFuXGlZYVPDit7x+eqdoBX+DoSlrEg6zqTVKdzWoQlvD+tOcJC8HAnXyFqFLvr84Gn+uCiZ9s3qsvihWOrXstn0oLwTFBayZncmT6/ay03tGvPeiB6EXCOhJVwnvy0usH1oCWEh61JO8uTyPcS0bsisUTGEBgeZXZKwGQmuKnx+yBFaNzSrI6ElhIc2pZ7i0aW7iI6sz7wxvQgLkdAS7pOpwkpsKQytdk0ltITw1FdpZ4obmyy7E7irq8uAdAqayIK/Odaw5dBpJixKpl2TOiSMj6VBrZDSD7BCi6ydWm/tVKs
w3LbDWcQvTKJteG37fljf6H3nRLVJcDnxxbdnKg8tsEaLrJ0aLuxUqzBU8tFsxi3YQeS1tVhc0XgSwg1yj6uML749Q/zCJK4PryS0hBAu2ZuRw5h5O2hStyZLxsfSuE5Ns0sSfkCCqwQJLSGMc+DEeUbN3U79WsEsiY+jSb1Qs0sSfkKCq9CXZULr2toSWkJUV9qpnxg1N5FaIUEsjY+jRYMws0sSfkTuceHodopfmMR1ElpCeOzI2YsMn5NIjRqKJfFxtGxYq/onk6XChBMBEVyrd2Xy2vpDnMi5TIsGYUzs355B3SMAR2iNfz+JNo1ruxdaVuiSs9OgtlOtotqOZ19i+Oxt5BdoPpgQR5vGtT07oa+XCpPOQVvw++BavSuTZz/cx+W8fAAycy7z7If7AGhcp2ZxaC2Jj6OhO1daVnixtdP6f3aqVVTLiZzLDJu9jUu5+SyNj6Nd07pmlyT8lN8H12vrDxWHVpHLeflMXnuACz9frV5oCSFKOX3+CiPmJPLjpTwS4mPp1KKe2SUJP+b3zRknci47PZ51MVdCSwgDnL3wM8PnJHLq/BUWjOtF18gGZpck/JzfB1dF3UzX1FAkjI+V0BLCAzmXchk5J5GMc5eYN6YXPVs1NLskEQD8fqpwYv/2pe5xASjgxXtupJF8GNI7KmrEEFVSSr0G3APkAunAWK11jrlVOXf+Sh6j5m7n8NmLzH0whri2jYx/Eis0QXmLNCxVm0fBZYdBVtQ9OGXtAc5ezOWaGooX77mRkX1amVyZAaw6qN0JLbNrtZ4NwLNa66tKqWnAs8DTJtdUzoWfrzJm3nYO/nCemSN7clO7cO88kT+/gEvDUrV5esVli0HWpF5NLuRe5YamdVgSH+c/y87YcVBLu3GltNb/LfHlNuA+s2qpyOXcfB5asIM9GT/y7vDu9OvY1OySRIDx6B6X1vq/WuurhV9uAyI9L8lY/5d+lnELdhDVsJZ/hZYIBOOAzyr6plJqglIqSSmVdObMGZ8UdCUvnwmLktj+fTavD41mQOfmPnleIUoysjnDcoPsm/QsCS1hOUqpjUqpFCf/DSzxmOeAq0BCRefRWs/SWsdorWPCw700VVdC7tUC/idhJ1+lneXVIV0Z2C3C688phDNVThUqpTYCzZx86zmt9ZrCx7g0yIBZADExMbpa1bqhKLRaXiuhJaxFa317Zd9XSj0I3A3001p7fay44mp+AY8t3cXmg6d5+feduT+mpdkliQBWZXDZcZBtO5xVuP9PmISWGazaNGIDSqkBOO4T36K1vmR2PQD5BZonl+9h3f4feP7uToyI9YPGJiuQcVJtnnYVWm6QbTucxdj5v4RWeF0JLZ+zY9OIdUwHagIblFIA27TWfzKrmIICzdOr9vLxnhM8PaAD437dxqxS/I+Mk2rztKvQUoMssTC0IiS0hE1pra83u4YiWmv+viaFlckZPHF7Ox7+zXVmlyQE4GFwWWmQJR7OYuyCHbRoEMqS+FgJLSE8oLXmpbUHSEg8xsO/uY7H+7UzuyQhivnFkk/bj2QzdsEOmtcPZemEOJrUlZ1WhagurTXT1h1i/tbvGde3DU/1b0/hjIoQlmD74Np+JJsx87dLaAlhkLc2pTHzi3RGxEbx97s7SmgJy7H1WoU7vneEVrP6oSyNt0Boydpjwube2/Idb25M476ekUwe2FlCS1iSba+4dnyfzYPzHKG1LD6OJvUscKUla48JG5v79RFeXXeIe6NbMG1IV2rUkNAS1mTL4Er6PpsxVgstIWxs8bajTF57gDs7N+P1odEESWgJC7NdcCUVXmk1rSehJYQRlicdZ9LqFPp1aMJbf+jONUG2e1kQAcZWv6HJR38JraUTJLSE8NSa3Zk8vWovN7VrzLsjehByja1eEkSAss1vafLRbEbP/SW0mkpoCeGRz/ad5Mnle+jduiGzRsUQGhxkdklCuMQWwZV89BwPzttBE6uHVkVrjMnaY8JiNqWe4rFlu+jWsgH
zxvQiLERCS9iH5dvhHaG1nfC6NVkab+HQAml5F7bw5bdneHjxTjo2r8f8sb2oXdPyLwNClGLpK66i0GpcJ4Sl8XE0q2/h0BLCBr5Jz2LCoiTahtdm4bje1AsNNrskIdxm2eDaeeyX0Fo2oY+ElhAeSj6azUPvO/aoSxgfS4NaIWaXJES1WDK4dh07x4Nzt9OoTghLJ8iVlhCe2puRw5h5O2haL5SE8bE0kj3qhI1ZLrh2HTvH6LnbaVgnhGUT4mheP8zskoSwtQMnzjNq7nbq1womYXysfIxE2J6lgqtoKwUJLSGM84/PUqkVEsTS+DhaNJAxJezPUu1ESin+PaonV/O1hJYQBpk+rAc5l3Np2bCW2aUIYQiPgkspNRkYCBQAp4ExWusTnpzT9BXehTCRN8ZU/VrB1K8l3YPCf3g6Vfia1rqr1robsBZ43oCahAhkMqaEqIJHwaW1Pl/iy9qA9qwcIQKbjCkhqubxPS6l1MvAaOBH4NZKHjcBmAAQFRXl6dMK4bdkTAlROaV15W/olFIbgWZOvvWc1npNicc9C4RqrV+o6kljYmJ0UlKSu7UKYRlKqWStdUw1f1bGlBBOuDquqrzi0lrf7uJzLgE+AaocZEIEMhlTQnjG067CdlrrtMIv7wUOuvJzycnJZ5VSRyt5SGPgrCe1GUzqqVwg1tPKGyeVMWUaqadyvqrHpXFV5VRhpT+s1CqgPY7W3aPAn7TWHi+RrpRKqu40jDdIPZWTeowjY8ocUk/lrFaPR1dcWushRhUihJAxJYQrLLXkkxBCCFEVqwbXLLMLKEPqqZzUY31W+zuReion9VTCo3tcQgghhK9Z9YpLCCGEcEqCSwghhK1YNriUUq8ppQ4qpfYqpT5SSjUwuZ77lVL7lVIFSilT2kKVUgOUUoeUUt8ppZ4xo4Yy9cxTSp1WSqVYoJaWSqnPlVKphf9Oj5tdk9XImHJag4ypimux7JiybHABG4DOWuuuwLfAsybXkwIMBr4048mVUkHAu8CdQCdgmFKqkxm1lLAAGGByDUWuAn/VWncE4oBHLPD3YzUypkqQMVUly44pywaX1vq/WuurhV9uAyJNridVa33IxBJ6A99prQ9rrXOBZTj2bTKN1vpLINvMGoporU9qrXcW/v9PQCoQYW5V1iJjqhwZU5Ww8piybHCVMQ74zOwiTBYBHC/xdQYW+SWyGqVUa6A7kGhuJZYmY0rGlMusNqY83tbEE66skq2Ueg7HJWuCFeoxkXJyTD7LUIZSqg6wCniizN5WAUHGlFtkTLnAimPK1OCqapVspdSDwN1AP+2DD5y5sWq3GTKAliW+jgQ82tLd3yilgnEMsASt9Ydm12MGGVNukTFVBauOKctOFSqlBgBPA/dqrS+ZXY8F7ADaKaXaKKVCgD8AH5tck2UopRQwF0jVWr9udj1WJGOqHBlTlbDymLJscAHTgbrABqXUbqXUTDOLUUr9XimVAfQBPlFKrffl8xfeVP8zsB7HTdLlWuv9vqyhLKXUUuAboL1SKkMp9ZCJ5fQFRgG3Ff6+7FZK/c7EeqxIxlQJMqaqZNkxJUs+CSGEsBUrX3EJIYQQ5UhwCSGEsBUJLiGEELYiwSWEEMJWJLiEEELYigSXEEIIW5HgEkIIYSv/Hx6L8u+u5Y/dAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "##########################\n", "### 2D Decision Boundary\n", "##########################\n", "\n", "w, b = logr.weights, logr.bias\n", "\n", "x_min = -2\n", "y_min = ( (-(w[0] * x_min) - b[0]) \n", " / w[1] )\n", "\n", "x_max = 2\n", "y_max = ( (-(w[0] * x_max) - b[0]) \n", " / w[1] )\n", "\n", "\n", "fig, ax = plt.subplots(1, 2, sharex=True, figsize=(7, 3))\n", "ax[0].plot([x_min, x_max], [y_min, y_max])\n", "ax[1].plot([x_min, x_max], [y_min, y_max])\n", "\n", "ax[0].scatter(X_train[y_train==0, 0], X_train[y_train==0, 1], label='class 0', marker='o')\n", "ax[0].scatter(X_train[y_train==1, 0], X_train[y_train==1, 1], label='class 1', marker='s')\n", "\n", "ax[1].scatter(X_test[y_test==0, 0], X_test[y_test==0, 1], label='class 0', marker='o')\n", "ax[1].scatter(X_test[y_test==1, 0], X_test[y_test==1, 1], label='class 1', marker='s')\n", "\n", "ax[1].legend(loc='upper left')\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "matplotlib 3.0.2\n", "numpy 1.15.4\n", "torch 1.0.0\n", "\n" ] } ], "source": [ "%watermark -iv" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.1" }, "toc": { "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 2 }