{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# XOR with a tiny fully connected network\n",
    "\n",
    "Fits a fastai `SimpleNet` with layer sizes `[2, 10, 2]` to the four XOR points, using plain SGD and cross-entropy loss. XOR is not linearly separable, so a model without a hidden layer could not fit it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "%reload_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "from torch import nn, optim\n",
    "\n",
    "# fastai 0.7 star imports supply ImageClassifierData, Learner, SimpleNet and accuracy\n",
    "from fastai.learner import *\n",
    "from fastai.dataset import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "SEED = 42        # fixes the random weight initialisation (and any data shuffling) so reruns match\n",
    "LR = 1.0         # SGD learning rate\n",
    "N_EPOCHS = 30\n",
    "BATCH_SIZE = 4   # all four XOR points in a single batch\n",
    "\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data\n",
    "\n",
    "XOR has only four possible inputs, so the same four points are used as both the training and the validation set. The reported `accuracy` therefore measures how well the network fits XOR, not how well it generalises."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "X = np.array([[0.,0.], [0,1], [1,0], [1,1]])\n",
    "y = np.array([0,1,1,0])   # y[i] = X[i,0] XOR X[i,1]\n",
    "data = (X,y)\n",
    "\n",
    "assert np.array_equal(y, np.logical_xor(X[:, 0], X[:, 1])), \"labels must be XOR of the inputs\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "md = ImageClassifierData.from_arrays('.', data, data, bs=BATCH_SIZE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = Learner.from_model_data(SimpleNet([2, 10, 2]), md)\n",
    "learn.crit = nn.CrossEntropyLoss()\n",
    "learn.opt_fn = optim.SGD"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train\n",
    "\n",
    "Exact losses depend on the random initialisation. The log below is from one recorded run: validation accuracy first reaches 1.0 at epoch 7 and stays at 1.0 from epoch 14 onward."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch trn_loss val_loss accuracy \n",
      " 0 0.703408 0.686116 0.5 \n",
      " 1 0.694675 0.680413 0.5 \n",
      " 2 0.689824 0.672467 0.5 \n",
      " 3 0.685353 0.6662 0.75 \n",
      " 4 0.681366 0.656064 0.75 \n",
      " 5 0.676933 0.645323 0.75 \n",
      " 6 0.672139 0.634129 0.75 \n",
      " 7 0.667045 0.614466 1.0 \n",
      " 8 0.66072 0.59939 0.75 \n",
      " 9 0.654014 0.579936 0.75 \n",
      " 10 0.646579 0.552088 1.0 \n",
      " 11 0.637801 0.547516 0.75 \n",
      " 12 0.629983 0.531916 1.0 \n",
      " 13 0.622022 0.483599 0.75 \n",
      " 14 0.611432 0.464332 1.0 \n",
      " 15 0.600781 0.439199 1.0 \n",
      " 16 0.589663 0.439463 1.0 \n",
      " 17 0.57981 0.3815 1.0 \n",
      " 18 0.567367 0.345051 1.0 \n",
      " 19 0.553991 0.324375 1.0 \n",
      " 20 0.540708 0.295096 1.0 \n",
      " 21 0.527019 0.266743 1.0 \n",
      " 22 0.513012 0.237544 1.0 \n",
      " 23 0.498673 0.230024 1.0 \n",
      " 24 0.485123 0.199286 1.0 \n",
      " 25 0.471132 0.196474 1.0 \n",
      " 26 0.458067 0.200001 1.0 \n",
      " 27 0.44612 0.152343 1.0 \n",
      " 28 0.432868 0.143447 1.0 \n",
      " 29 0.420133 0.126178 1.0 \n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[0.12617818, 1.0]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "learn.fit(LR, N_EPOCHS, metrics=[accuracy])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}