{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Two-moons classification with a tinygrad MLP\n", "\n", "Trains a small multilayer perceptron (2-16-16-1) built from `tinygrad` on scikit-learn's `make_moons` data (labels mapped to -1/+1), using `MaxMarginLoss` and `SimpleSGD` with a linearly decaying learning rate. It then plots the training loss curve and the learned decision boundary." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from tinygrad.core import Tensor\n", "from tinygrad.nn import SimpleMLP\n", "from tinygrad.losses import MSELoss, MaxMarginLoss\n", "from tinygrad.optimizers import SimpleSGD\n", "\n", "import random\n", "import numpy as np\n", "from sklearn.datasets import make_blobs, make_moons\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "np.random.seed(1337)\n", "random.seed(1337)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "x, y = make_moons(n_samples=100, noise=0.1)\n", "\n", "y = y*2 - 1 # make y be -1 or 1\n", "\n", "plt.figure(figsize=(5,5))\n", "plt.scatter(x[:,0], x[:,1], c=y, s=20, cmap='jet')\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# now we initialize the model\n", "# for demo purposes we initialize a 2, 16, 16, 1 model\n", "model = SimpleMLP(2, 1, [16, 16])\n", "\n", "# Workaround (temporary): turn off the nonlinearity on the output layer's neurons so the\n", "# final layer is linear; its scores are later thresholded at 0 against labels in {-1, 1}.\n", "for n in model.layers[-1].neurons:\n", "    n.nonlin = False" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "SimpleMLP(\n", "Linear(ins:2 outs:16 num_parameters:48)\n", "Linear(ins:16 outs:16 num_parameters:272)\n", "Linear(ins:16 outs:1 num_parameters:17)\n", ")\n", "No of model parameters: 337\n" ] } ], "source": [ "model.summary()\n", "print(\"No of model parameters: \", len(model.parameters()))" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "For epoch: 0, loss: 0.8862514464368222\n", "For epoch: 5, loss: 0.3039345997419534\n", "For epoch: 10, loss: 0.23570270563284787\n", "For epoch: 15, loss: 0.21692423765783492\n", "For epoch: 20, loss: 0.1520994943457478\n", "For epoch: 25, loss: 0.11578019365064655\n", "For epoch: 30, loss: 0.08059487803507807\n", "For epoch: 35, loss: 0.10226911924959595\n", "For epoch: 40, loss: 0.06718934359576106\n", "For epoch: 45, loss: 0.05021609346959929\n", "For epoch: 50, loss: 0.07339858866572174\n", "For epoch: 55, loss: 0.014531426535936509\n", "For epoch: 60, loss: 0.013640569867793939\n", "For epoch: 65, loss: 0.00821577637588291\n", "For epoch: 70, loss: 0.013711454799957803\n", "For epoch: 75, loss: 
0.010083769334297063\n", "For epoch: 80, loss: 0.0035957132908880254\n", "For epoch: 85, loss: 0.00016450233490777455\n", "For epoch: 90, loss: 0.00042543722337187175\n", "For epoch: 95, loss: 0.0\n" ] } ], "source": [ "epochs = 100\n", "# linear lr decay: lr_start at i = 0, approaching lr_end as i -> epochs\n", "lr_start, lr_end = 1.0, 0.1\n", "\n", "# now we define the loss function and optimizer\n", "loss_fn = MaxMarginLoss\n", "optim = SimpleSGD(model.parameters(), lr=lr_start)\n", "\n", "X = [list(map(Tensor, _x)) for _x in x]  # wrap each input coordinate in a Tensor\n", "losslist = []  # loss.data recorded every 5 epochs\n", "\n", "for i in range(epochs):\n", "    preds = list(map(model, X))\n", "    loss = loss_fn(preds, y)\n", "\n", "    model.zero_grad()\n", "    loss.backward()\n", "\n", "    # set the lr by hand every step because schedulers haven't been implemented yet\n", "    lr = lr_start - (lr_start - lr_end) * i / epochs\n", "    optim.lr = lr\n", "    optim.step()\n", "\n", "    if i % 5 == 0:\n", "        print(\"For epoch: {}, loss: {}\".format(i, loss.data))\n", "        losslist.append(loss.data)\n" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[]" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAYwAAAEKCAYAAAAB0GKPAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzt3Xl8VfWd//HX596bhZBAQMK+L4Ir0UYFcUFtO8h06lLbgtXWage12taZzu/XznSmnUe3nzNdfq3jQnGtM1atdW2LW60biEtAUJRFFoWwBtlCgIQkn/njHiDc3IQDyb0ny/v5eNxHzj33e28+OSzvfM/3e87X3B0REZHDiUVdgIiIdAwKDBERCUWBISIioSgwREQkFAWGiIiEosAQEZFQFBgiIhKKAkNEREJRYIiISCiJqAtoS3369PHhw4dHXYaISIcxf/78Le5eEqZtpwqM4cOHU15eHnUZIiIdhpl9FLatTkmJiEgoCgwREQlFgSEiIqEoMEREJBQFhoiIhKLAEBGRUBQYIiISigIDuOWFD3h5eWXUZYiItGsKDOCOl1Yyd8WWqMsQEWnXFBhAPGbU1XvUZYiItGsKDJKBUd/QEHUZIiLtmgIDSMSMelcPQ0SkJQoMIBYz6hsUGCIiLVFgEPQwFBgiIi1SYAAxM+oUGCIiLVJgAIm40aDAEBFpkQKDYFqtAkNEpEUKDCBuGsMQETkcBQb7r8NQYIiItESBETCLugIRkfYtkakPNrN7gM8Am939xGDfw8DYoEkxsN3dS9O890OgCqgH6ty9LFN1AjS4E1NiiIi0KGOBAdwH3Arcv3+Hu39x/7aZ/QLY0cL7z3P3rNwRsL5BgSEicjgZCwx3f8XMhqd7zcwM+AJwfqa+/5FwT17tLSIizYtqDONsYJO7f9DM6w48Z2bzzWxGpotJnpLK9HcREenYMnlKqiXTgQdbeH2Su683s77A82a21N1fSdcwCJQZAEOHDj2qYhocnZISETmMrPcwzCwBXAo83Fwbd18ffN0MPA6c3kLbWe5e5u5lJSUlR1VTg7tmSYmIHEYUp6Q+CSx194p0L5pZdzMr2r8NfBpYnMmCXD0MEZHDylhgmNmDwDxgrJlVmNk1wUvTSDkdZWYDzWx28LQfMMfMFgFvAn9292cyVSdoDENEJIxMzpKa3sz+q9LsWw9MDbZXAeMzVVc6ug5DROTwdKU30C0nzq6auqjLEBFp1xQYQP+e+WzYsTfqMkRE2jUFBjCwZzc2bN8TdRkiIu2aAgMYUJzPpqoa3bFWRKQFCgxgQM9u1Dc4m6t0WkpEpDkKDGBgcT4A67crMEREmqPAINnDANiwQ+MYIiLNUWCQHPQG2KAehohIsxQYQI9uCQpy46xXD0NEpFkKDMDMGNAzXz0MEZEWKDACg3sV8NHW3VGXISLSbikwAicO6sHyTVXsqa2PuhQRkXZJgREoHdKL+gbnvfUtLTMuItJ1KTAC44f0BGDh2u0RVyIi0j4pMAJ9i/IZVNyNtxUYIiJpKTAaKR1SzMI1CgwRkXQUGI2UDilm3fY9VFbVRF2KiEi7k8klWu8xs81mtrjRvn83s3VmtjB4TG3mvVPMbJmZrTCz72aqxlSlQ4sBWKTTUiIiTWSyh3EfMCXN/v/v7qXBY3bqi2YWB24DLgSOB6ab2fEZrPOAEwf2JB4zDXyLiKSRscBw91eArUfx1tOBFe6+yt1rgYeAi9q0uGZ0y40zrn+RAkNEJI0oxjBuNLN3glNWvdK8PghY2+h5RbAvK0qHFLNo7XbctZiSiEhj2Q6MO4BRQCmwAfhFmjaWZl+z/3ub2QwzKzez8srKylYXOKZvIVU1dXxcXdvqzxIR6UyyGhjuvsnd6929AbiT5OmnVBXAkEbPBwPrW/jMWe5e5u5lJSUlra5xUK8CANZt051rRUQay2pgmNmARk8vARanafYWMMbMRphZLjANeCob9cHB1ffWbVdgiIg
0lsjUB5vZg8BkoI+ZVQA/ACabWSnJU0wfAtcGbQcCd7n7VHevM7MbgWeBOHCPu7+XqTpTDS5O9jDWKzBERA6RscBw9+lpdt/dTNv1wNRGz2cDTabcZkOPbgkK8xJU6JSUiMghdKV3CjNjYHG+TkmJiKRQYKQxqLibTkmJiKRQYKQxsLibehgiIikUGGkM6tWN7bv3UV1TF3UpIiLthgIjjUHF3QDNlBIRaUyBkUZeIg7Abq3vLSJygAIjjeWbqjCD0X0Loy5FRKTdUGCksWTDTob1LqB7XsYuUxER6XAUGGks3VjFuP49oi5DRKRdUWCk2F1bx4cfV3PcAAWGiEhjCowUyzZW4Q7jBhRFXYqISLuiwEixZEMVAMerhyEicggFRoqlG3dSmJc4cC2GiIgkKTBSLN1Qxbj+RcRi6Rb+ExHpuhQYjbg7Szbu1PiFiEgaCoxGKqtqqNpbx+gSXbAnIpJKgdFIcUEuiZixuaom6lJERNqdjAWGmd1jZpvNbHGjfT8zs6Vm9o6ZPW5mxc2890Mze9fMFppZeaZqTJWbiDG6byFLN1Zl61uKiHQYmexh3AdMSdn3PHCiu58MLAf+uYX3n+fupe5elqH60hrbv4hlCgwRkSYyFhju/gqwNWXfc+6+f5GJ14HBmfr+R2tc/x6s276HHXv2RV2KiEi7EuUYxtXA08285sBzZjbfzGZksaYDM6TUyxAROVQkgWFm3wPqgAeaaTLJ3U8FLgRuMLNzWvisGWZWbmbllZWVra7tuOCmg0s37mz1Z4mIdCZZDwwz+wrwGeBL7u7p2rj7+uDrZuBx4PTmPs/dZ7l7mbuXlZSUtLq+fj3yKC7IOXCLEBERScpqYJjZFOA7wGfdfXczbbqbWdH+beDTwOJ0bTNUI+P6F6mHISKSIpPTah8E5gFjzazCzK4BbgWKgOeDKbMzg7YDzWx28NZ+wBwzWwS8CfzZ3Z/JVJ3pjOvfg2Ubq2hoSNsBEhHpkjK2pJy7T0+z++5m2q4Hpgbbq4DxmaorjOMGFLG7tp6123Yz7JjuUZYiItJu6ErvNAb3KgBg4469EVciItJ+KDDSyE0kD0ttfUPElYiItB8KjDTy9gdGnQJDRGQ/BUYaeYk4ADUKDBGRAxQYaeSqhyEi0oQCI439p6Rq6uojrkREpP1QYKShHoaISFMKjDQO9jAUGCIi+ykw0shVYIiINKHASCM3rsAQEUmVsVuDdGRmRv8e+dw7dzU5MeOrZ42gME+HSkS6NvUwmnH/NaczceQx/OL55Zz9H3/lNy+vZE+tZk2JSNdlzSxJ0SGVlZV5eXl5m37morXb+eXzy3l5eSV9CvO48bxRTD9j6IGL+0REOjIzm+/uZaHaKjDCKf9wKz9/bhmvr9rKgJ75fOP8MXy+bDA5cXXSRKTjUmBk0GsrtvDz55axYM12hvYu4JsXjOHi0oEkFBwi0gEdSWDof7kjdOboPjx6/Znce9Vp9OiW4J8eWcSnf/UKf1y0XgsuiUinpsA4CmbGeeP68scbz2LmFZ8gETO+8eDbTL3lVZ59byOdqdcmIrJfqMAws1FmlhdsTzazb5pZ8WHec4+ZbTazxY329Taz583sg+Brr2beO8XMlpnZCjP77pH8QNlkZkw5sT9Pf+scfj2tlNq6Bq797/lcdNtcXlq2WcEhIp1K2B7Go0C9mY0muczqCOB3h3nPfcCUlH3fBV5w9zHAC8HzQ5hZHLgNuBA4HphuZseHrDMS8ZhxUekgnvuHc/jZZSeztbqWq+59i8tmzuO1lVuiLk9EpE2EDYwGd68DLgF+5e7/AAxo6Q3u/gqwNWX3RcBvg+3fAheneevpwAp3X+XutcBDwfvavUQ8xufLhvDXb0/mxxefyLpte7j8zje4/M7Xmf9R6qEQEelYwgbGPjObDnwF+FOwL+covl8/d98AEHztm6bNIGBto+cVwb4OIzcR44oJw3jp/0zm+585nuWbqvjcHfP45fPLoy5NROSohQ2
MrwITgZ+4+2ozGwH8T4ZqsjT7mh0MMLMZZlZuZuWVlZUZKuno5OfEufqsEVx7zigAeuTr9iIi0nGFCgx3f9/dv+nuDwYD1UXufvNRfL9NZjYAIPi6OU2bCmBIo+eDgfUt1DbL3cvcvaykpOQoSsqst9ds4z+fXcoF4/py9aQRUZcjInLUws6SesnMephZb2ARcK+Z/fIovt9TJE9rEXx9Mk2bt4AxZjbCzHKBacH7Opxt1bXc8MAC+vXI5xdfGE8slq7zJCLSMYQ9JdXT3XcClwL3uvsngE+29AYzexCYB4w1swozuwa4GfiUmX0AfCp4jpkNNLPZAMHg+o3As8AS4Pfu/t6R/2jRamhwbnp4IVt21XL7l06luCA36pJERFol7En1RHAK6QvA98K8wd2nN/PSBWnargemNno+G5gdsrZ26bYXV/Dy8kp+fPGJnDy4xUtWREQ6hLA9jB+S/I1/pbu/ZWYjgQ8yV1bHNueDLfzyL8u5uHQgXzpjaNTliIi0iVA9DHd/BHik0fNVwOcyVVRHtnHHXr710NuMLinkJ5echJnGLUSkcwg76D3YzB4PbvWxycweNbPBmS6uo9lX38CNv1vAnn313HHFqXTXKn0i0omEPSV1L8mZSgNJXkT3x2CfNPKfzyyl/KNt3Py5kxndtyjqckRE2lTYwChx93vdvS543Ae0v4seIvTsexu589XVXDlhGJ8dPzDqckRE2lzYwNhiZleYWTx4XAF8nMnCOpq/vL8JgBvOGx1xJSIimRE2MK4mOaV2I7ABuIzk7UIkcO25I4nHjNtfWhF1KSIiGRH21iBr3P2z7l7i7n3d/WKSF/FJYHTfIi4/fSgPvLGGFZuroi5HRKTNtWbFvX9ssyo6iZs+OYaCnDj/b/bSqEsREWlzrQkMXWCQ4pjCPG44fzQvLN3M3BVaOElEOpfWBIbWH03jqjOHM7hXN3785yXUN+gQiUjn0WJgmFmVme1M86gieU2GpMjPifOdKeNYsmEnjy6oiLocEZE202JguHuRu/dI8yhyd13G3IzPnDyAU4YW8/Nnl1FdUxd1OSIibaI1p6SkGWbGv/7t8WyuqmHWK6uiLkdEpE0oMDLkE8N68ZmTB/CbV1ayccfeqMsREWk1BUYGfWfKOBoa4JsPvc2OPfuiLkdEpFUUGBk0pHcBP/v8yby9Zhufu+M11m7dHXVJIiJHLeuBYWZjzWxho8dOM7sppc1kM9vRqM33s11nW7modBD3X30Gm3fu5ZLb57Jw7faoSxIROSpZDwx3X+bupe5eCnwC2A08nqbpq/vbufsPs1tl25o46hge+/qZdMuNM23WPJ5ZvDHqkkREjljUp6QuILns60cR15Fxo/sW8fjXJzGufw+uf2A+d726Cndd2CciHUfUgTENeLCZ1yaa2SIze9rMTshmUZnSpzCPh2ZMYMoJ/fnxn5fwg6feo66+IeqyRERCiSwwzCwX+CyN1gpvZAEwzN3HA/8FPNHC58wws3IzK6+srMxMsW0oPyfObZefyoxzRnL/vI+Y8d/zdXGfiHQIUfYwLgQWuPum1Bfcfae77wq2ZwM5ZtYn3Ye4+yx3L3P3spKSjrEIYCxm/MvU4/jRxSfy0rLNfOE389i0U9dqiEj7FmVgTKeZ01Fm1t/MLNg+nWSdnW6FvysnDOPuq07jwy3VXHzbXJZs2Bl1SSIizYokMMysAPgU8FijfdeZ2XXB08uAxWa2CLgFmOaddIT4vLF9+f11E3GHz8+cx8vL2/9pNRHpmqwz/T9cVlbm5eXlUZdxVDbs2MPV95WzfFMV/zX9FKaeNCDqkkSkCzCz+e5eFqZt1LOkJDCgZzeunzyK+gZnVeWuqMsREWlCgdFObNq5lx88uZgTBvZgxjmjoi5HRKQJBUY70NDgfPv3i9izr55fTzuF3IT+WESk/dH/TO3A3XNWM2fFFn7wdycwum9h1OWIiKSlwIjY4nU7+M9nl/I3J/Rj2mlDoi5HRKRZCow
I7amt51sPvc0x3fO4+dKTCS49ERFpl7Qud4R+9Of3WbWlmgeuOYNe3XOjLkdEpEXqYUTkmcUb+d0ba5hxzkjOHJ32riciIu2KAiMCG3fs5buPvcNJg3ry7U+NjbocEZFQFBhZ1tDgfPuRhdTsa+DX00o1hVZEOgz9b5VlSzdWMXfFx5QOKWZI74KoyxERCU2BkWXHDSjiG+ePZt6qj7nq3jfZsWdf1CWJiISiwMgyM+Pbnx7Lzy47mTdWbeWyO15j7dbdUZclInJYCoyIfL5sCPdffTobd+7lkttfY9Ha7VGXJCLSIgVGhM4c3YfHv34m+TkxvjhrHs8s3hh1SSIizVJgRGx03yKeuGES4/r34PoH5nPnK6voTGuUiEjnocBoB/oU5vHQjAlMOaE/P5m9hH97cjF19Q1RlyUicoiolmj90MzeNbOFZtZkiTxLusXMVpjZO2Z2ahR1ZlN+TpzbLj+Va88dyf+8voav3V/Orpq6qMsSETkgyh7Gee5e2szSgBcCY4LHDOCOrFYWkVjM+OcLj+Onl5zEqx9s4bI7XmPDjj1RlyUiArTfU1IXAfd70utAsZl1mUWuLz9jKPdcdRoV2/Zw8W1zWbxuR9QliYhEFhgOPGdm881sRprXBwFrGz2vCPZ1GeceW8Ifrp9I3Iwv/GYef3l/U9QliUgXF1VgTHL3U0meerrBzM5JeT3dwhBppw6Z2QwzKzez8srKyrauM1Lj+vfgiRsmMbpvIX//3+XMfHmlZlCJSGQiCQx3Xx983Qw8Dpye0qQCaLz83GBgfTOfNcvdy9y9rKSkJBPlRqpvj3wenjGRqScN4Oanl/JPj7xDTV191GWJSBeU9cAws+5mVrR/G/g0sDil2VPAl4PZUhOAHe6+IculthvdcuPcOv0UbvrkGB5dUMGX7nyDj3fVRF2WiHQxUfQw+gFzzGwR8CbwZ3d/xsyuM7PrgjazgVXACuBO4OsR1NmumBk3ffJYbr38FN5dt4OLbpvLso1VUZclIl2IdaZz4mVlZV5e3uSyjk5n0drt/P395VTX1HHL9FO44Lh+UZckIh2Umc1v5vKGJtrrtFppwfghxTx141mMLCnka/eXM+sVDYaLSOYpMDqo/j3z+f21E7nwxP78dPZS/u8fNBguIpmlwOjAkoPhp/LNC8bwyPwKrrhLg+EikjkKjA4uFjP+8VPHcsv0U3inQoPhIpI5CoxO4rPjB/LwtROprWvg0tvn8tSi9TQ0aFxDRNqOAqMTKR1SzJM3TmJESXe++eDbfPpXr/CH+RXU1ulW6SLSegqMTmZAz2488fVJ/HpaKYmY8U+PLGLyz17k7jmrqdbt0kWkFXQdRifm7ry0vJKZL63kjdVbKS7I4csTh3PVmcPp3T036vJEpB04kuswFBhdxII125j50kqee38T+Tkxpp02lK+dPYLBvQqiLk1EIqTAkGat2FzFzJdX8cTb63DgovEDufbcUYztXxR1aSISAQWGHNb67Xu4e85qHnxzDbtr67lgXF+unzyKsuG9oy5NRLJIgSGhbd9dy/3zPuLeuavZtnsfZcN6cf3kUZw3ti+xWLplSUSkM1FgyBHbXVvH799ay52vrmbd9j2M7VfEteeO5O/GDyQnrsl0Ip2VAkOO2r76Bv70znpmvrSKZZuqGFTcja+dPYIvnjaEgtxE1OWJSBtTYEiruTsvLtvMHS+t5K0Pt9GrIIerzhzBlycOo5em5Ip0GgoMaVPlH25l5ssr+cuSzfQqyOHpb51D/575UZclIm1A62FImyob3pu7vnIaf7zxLHbX1vPDP70XdUkiEoEo1vQeYmYvmtkSM3vPzL6Vps1kM9thZguDx/ezXac0ddLgntx43mhmv7uRF5dujrocEcmyKHoYdcC33f04YAJwg5kdn6bdq+5eGjx+mN0SpTkzzh3JqJLu/NuTi9lTqwWbRLqSrAeGu29w9wXBdhWwBBiU7Trk6OQl4vzkkpOo2LaHW/76QdTliEg
WRTqGYWbDgVOAN9K8PNHMFpnZ02Z2QlYLkxZNGHkMnzt1MHe+sorlm7RYk0hXEVlgmFkh8Chwk7vvTHl5ATDM3ccD/wU80cLnzDCzcjMrr6yszFzBcojv/e1xFOYn+N7j72qhJpEuIpLAMLMckmHxgLs/lvq6u+90913B9mwgx8z6pPssd5/l7mXuXlZSUpLRuuWg3t1z+ZcLj+OtD7fxh/kVUZcjIlkQxSwpA+4Glrj7L5tp0z9oh5mdTrLOj7NXpYRx2ScGc9rwXvz06SV8vKsm6nJEJMOi6GFMAq4Ezm80bXaqmV1nZtcFbS4DFpvZIuAWYJp3pisMO4lYzPjJJSexa28dP529NOpyRCTDsn5zIHefA7R4G1R3vxW4NTsVSWsc26+IGeeM5PaXVlI6pCdfOG0IeYl41GWJSAboSm9ptW+cP4ZThhbzb0++x6SbX+SWFz7QKSqRTkj3kpI24e7MXfExd89ZxYvLKslLxLj01EFcPWkEY/ppNT+R9upI7iWl+1VLmzAzzhrTh7PG9GHF5irunvMhjy2o4ME313LusSV87ewRnDW6D8FcBhHpgNTDkIzZWl3LA69/xG/nfcSWXTWM7VfENWeP4LPjB5Kfo3EOkfZAtzeXdqWmrp6nFq7n7jmrWbqxij6FuVw5YThXTBjKMYV5UZcn0qUpMKRdcndeW/kxd72aHOfITcS49JRBXH3WCI7VOIdIJDSGIe2SmTFpdB8mje7Dis27uGfuah6dX8FDb63lnGNL+NpZIzh7jMY5RNor9TAkUlura/ndG8lxjsqqGo7tV8g1Z43gotJBGucQyQKdkpIOp6aunj8u2sBdr65i6cYqjumeyxUThnHlxGH00TiHSMYoMKTDcnfmrfyYu+as5q9LN5ObiHFJ6SAuPmUQvbrnUJiXoDAvQfe8BDlxXXcq0loaw5AOy8w4c3QfzgzGOe6du5pHF1TwcPnaJm3zErED4XEwSOIU5udQmBene26CwvxEmjap7RO6nYlICOphSLu3rbqWRRXbqa6pp7qmjqqaOqqDxyHbe+uorq2juqY+uV1Tx5594ZaRzYlbk1DpnpegKAiVg9vB/vwE3XMbbQftivJyyM+JaeBeOgz1MKRT6dU9l8lj+x7Ve+sbnOraOnYFAbIreCS369m1dx/VtfXJ/SltduyuZd223UH7eqpr6wjz+1XMCBEw8RZ6PcF2foKCnDixmMJH2gcFhnRq8ZjRIz+HHvk5rf6shgZnz776Q0Nnb7Bdu3+7Pk0wJR+bdu5t1L6e+pArFXbPTZ42azlg4k3CpnuatnGFj7SCAkMkpFjMDvQY+rXys9ydvfsamoRK+l5QsN2op7S1eveBNtU19dTWN4T6vvk5MQrzgjGe1NDZP95zYOyn5V5QbkKTDroaBYZIBMyMbrlxuuXGKSlq/bThmrr6A2M8u1JDZ3+vpqaeXTX7mvSCNu7ce0hY7d0XLnxyD0w6SE4wODiWk3I6Lq/RGE+asZ/CvAR5CY37dAQKDJFOIC8RJy8Rp3f33FZ/Vl19QzJcahtNJkg3yaBRD6fqQM+nljUfN+r91IabdJCIpU46iLcwuWD/Kbic5Cy3lJ5PQW5c4ZMhkQSGmU0Bfg3Egbvc/eaU1y14fSqwG7jK3RdkvVCRLigRj9GzIEbPgrYZ99k/c63JabdgVtuhEw6SvaD9IbRhx94D7atr6ggz7BMzDoTMwbGceDNjP00nGuw/Fdc9OD2ncZ+Dsh4YZhYHbgM+BVQAb5nZU+7+fqNmFwJjgscZwB3BVxHpQGIxoyg/h6I2mHTgfnDSQXUwtnP4MaCDExG2VO0+ZIJCXchJBwW5qdOq44eOA+UnKMxNGQdK6S0VBb2hRAe/2DSKHsbpwAp3XwVgZg8BFwGNA+Mi4H5PXiTyupkVm9kAd9+Q/XJFpD0wMwpyExTkJqCVNzd2d2rqGg6eUgt6NWmv7QlC5uCpuXrWbd9zyKm52rpw4z55idjBU2u5qTPaWug
F5ae2j0dysWkUgTEIaHzZbgVNew/p2gwCFBgi0mpmRn5OnPycOMcUtv7z9tU3pFw8mjLDrYUe0OaqvVRvOXgt0JFcbLo/VAb27Mbvr5vY+h/kMKIIjHQnBFP7hmHaJBuazQBmAAwdOrR1lYmIHIWceIziglyKC1o/6aD+wLjP4We47W+TrSnOUQRGBTCk0fPBwPqjaAOAu88CZkHy1iBtV6aISPYdcrFpz6irOVQUIzBvAWPMbISZ5QLTgKdS2jwFfNmSJgA7NH4hIhKtrPcw3L3OzG4EniU5rfYed3/PzK4LXp8JzCY5pXYFyWm1X812nSIicqhIrsNw99kkQ6HxvpmNth24Idt1iYhI8zr2pGAREckaBYaIiISiwBARkVAUGCIiEooCQ0REQulUa3qbWSXw0VG+vQ+wpQ3L6eh0PJrSMWlKx6SpjnZMhrl7SZiGnSowWsPMysMuhN4V6Hg0pWPSlI5JU535mOiUlIiIhKLAEBGRUBQYB82KuoB2RsejKR2TpnRMmuq0x0RjGCIiEop6GCIiEkqXCgwzm2Jmy8xshZl9N83rZma3BK+/Y2anRlFnNoU4Jl8KjsU7ZvaamY2Pos5sOtwxadTuNDOrN7PLsllftoU5HmY22cwWmtl7ZvZytmvMthD/bnqa2R/NbFFwTDrHHbfdvUs8SN5KfSUwEsgFFgHHp7SZCjxNcsW/CcAbUdfdDo7JmUCvYPtCHZND2v2V5F2XL4u67oj/jhQD7wNDg+d9o667HRyTfwH+I9guAbYCuVHX3tpHV+phnA6scPdV7l4LPARclNLmIuB+T3odKDazAdkuNIsOe0zc/TV33xY8fZ3k6oedWZi/JwDfAB4FNmezuAiEOR6XA4+5+xoAd9cxSS4pXWRmBhSSDIy67JbZ9rpSYAwC1jZ6XhHsO9I2ncmR/rzXkOyBdWaHPSZmNgi4BJhJ5xfm78ixQC8ze8nM5pvZl7NWXTTCHJNbgeNILi39LvAtd2/ITnmZE8kCShGxNPtSp4iFadOZhP55zew8koFxVkYril6YY/Ir4DvuXp/8BbJTC3M8EsAngAuAbsA8M3vd3ZdnuriIhDkmfwMsBM4HRgHPm9mr7r4z08VlUldjFjnPAAADU0lEQVQKjApgSKPng0mm/5G26UxC/bxmdjJwF3Chu3+cpdqiEuaYlAEPBWHRB5hqZnXu/kR2SsyqsP9utrh7NVBtZq8A44HOGhhhjslXgZs9OYixwsxWA+OAN7NTYmZ0pVNSbwFjzGyEmeUC04CnUto8BXw5mC01Adjh7huyXWgWHfaYmNlQ4DHgyk78G2Njhz0m7j7C3Ye7+3DgD8DXO2lYQLh/N08CZ5tZwswKgDOAJVmuM5vCHJM1JHtcmFk/YCywKqtVZkCX6WG4e52Z3Qg8S3KWwz3u/p6ZXRe8PpPkjJepwApgN8nfEjqtkMfk+8AxwO3Bb9R13klvrAahj0mXEeZ4uPsSM3sGeAdoAO5y98XRVZ1ZIf+O/Ai4z8zeJXkK6zvu3pHuYJuWrvQWEZFQutIpKRERaQUFhoiIhKLAEBGRUBQYIiISigJDRERCUWCIRCi4y+ufoq5DJAwFhoiIhKLAEAnBzK4wszeDNR9+Y2ZxM9tlZr8wswVm9oKZlQRtS83s9WANkcfNrFewf7SZ/SVYI2GBmY0KPr7QzP5gZkvN7IHgDqeY2c1m9n7wOT+P6EcXOUCBIXIYZnYc8EVgkruXAvXAl4DuwAJ3PxV4GfhB8Jb7SV7ZezLJO5Xu3/8AcJu7jye5zsj+286cAtwEHE9yjYVJZtab5B1xTwg+58eZ/SlFDk+BIXJ4F5C8G+tbZrYweD6S5G0wHg7a/A9wlpn1BIrdff+qc78FzjGzImCQuz8O4O573X130OZNd68Ibn+9EBgO7AT2AneZ2aUkb1UjEikFhsjhGfBbdy8NHmPd/d/TtGvpPjst3Qe9ptF2PZBw9zqSC/U8ClwMPHOENYu0OQWGyOG9AFx
mZn0BzKy3mQ0j+e9n/3relwNz3H0HsM3Mzg72Xwm8HKyDUGFmFwefkRfc2TUtMysEerr7bJKnq0oz8YOJHIkuc7dakaPl7u+b2b8Cz5lZDNgH3ABUAyeY2XxgB8lxDoCvADODQFjFwbseXwn8xsx+GHzG51v4tkXAk2aWT7J38g9t/GOJHDHdrVbkKJnZLncvjLoOkWzRKSkREQlFPQwREQlFPQwREQlFgSEiIqEoMEREJBQFhoiIhKLAEBGRUBQYIiISyv8C+5XOIigtTL0AAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# losslist holds one value every 5 epochs: epoch numbers go on the x-axis, losses on the y-axis\n", "loss_epochs = range(0, 5 * len(losslist), 5)\n", "\n", "plt.plot(loss_epochs, losslist)\n", "plt.xlabel(\"epochs\")\n", "plt.ylabel(\"Loss\")\n", "plt.title(\"Training loss\")\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# visualize the decision boundary: classify every point of a grid covering the data\n", "# (padded by 1 on each side) and shade it by the predicted class\n", "\n", "h = 0.25  # grid step\n", "x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1\n", "y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1\n", "xx, yy = np.meshgrid(np.arange(x_min, x_max, h),\n", "                     np.arange(y_min, y_max, h))\n", "Xmesh = np.c_[xx.ravel(), yy.ravel()]\n", "inputs = [list(map(Tensor, xrow)) for xrow in Xmesh]\n", "scores = list(map(model, inputs))\n", "# True where the model predicts class +1 (score > 0), reshaped back onto the grid\n", "Z = np.array([s.data > 0 for s in scores]).reshape(xx.shape)\n", "\n", "plt.figure()\n", "plt.contourf(xx, yy, Z, cmap=plt.cm.Spectral, alpha=0.8)\n", "plt.scatter(x[:, 0], x[:, 1], c=y, s=40, cmap=plt.cm.Spectral)\n", "plt.xlim(xx.min(), xx.max())\n", "plt.ylim(yy.min(), yy.max())\n", "plt.title(\"Decision boundary\")\n", "plt.show()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.6" } }, "nbformat": 4, "nbformat_minor": 4 }