{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import unittest\n",
    "from config import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from fastai.vision import *\n",
    "from fastai.gen_doc.doctest import this_tests\n",
    "from fastai.callbacks import *\n",
    "from math import isclose\n",
    "from fastai.train import ClassificationInterpretation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch the MNIST_TINY archive and build a normalized ImageDataBunch from its folder layout.\n",
    "path = untar_data(URLs.MNIST_TINY)\n",
    "data = ImageDataBunch.from_folder(path, ds_tfms=(rand_pad(2, 28), []), num_workers=2)\n",
    "data = data.normalize()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train a small batch-normed CNN with fit_one_cycle(3).\n",
    "# The test classes below inspect this module-level `learn` object, so this cell must run first.\n",
    "learn = Learner(data, simple_cnn((3,16,16,16,2), bn=True), metrics=[accuracy, error_rate])\n",
    "learn.fit_one_cycle(3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class IntegrationTest(unittest.TestCase):\n",
    "    \"\"\"End-to-end checks on the `learn` object trained in an earlier cell.\"\"\"\n",
    "\n",
    "    def test_accuracy(self):\n",
    "        \"\"\"Accuracy computed from `learn.get_preds()` must exceed 0.9.\"\"\"\n",
    "        self.assertGreater(accuracy(*learn.get_preds()), 0.9)\n",
    "\n",
    "    def test_error_rate(self):\n",
    "        \"\"\"Error rate computed from `learn.get_preds()` must be below 0.1.\"\"\"\n",
    "        self.assertLess(error_rate(*learn.get_preds()), 0.1)\n",
    "\n",
    "    def test_preds(self):\n",
    "        \"\"\"`learn.predict` must rank the true class first for at least one of the first three validation items.\"\"\"\n",
    "\n",
    "        def scored_correctly(idx):\n",
    "            img, label = learn.data.valid_ds[idx]\n",
    "            outputs = learn.predict(img)[2]\n",
    "            target = int(label)\n",
    "            # Two-class problem: the competing class index is 1 - target.\n",
    "            return bool(outputs[target] > outputs[1 - target])\n",
    "\n",
    "        # any() short-circuits on the first success, so no further predictions are run after a hit.\n",
    "        self.assertTrue(any(scored_correctly(i) for i in range(3)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class OneCycleTest(unittest.TestCase):\n",
    "    \"\"\"Checks the learning-rate and momentum schedules recorded by `learn.recorder`.\n",
    "\n",
    "    Schedule values are computed floats, so peak and trough values are compared\n",
    "    with a tolerance (`assertAlmostEqual`) rather than exact equality.\n",
    "    \"\"\"\n",
    "\n",
    "    def test_lrs(self):\n",
    "        \"\"\"Learning rate starts below 1e-3, ends below 1e-4, and peaks at 3e-3.\"\"\"\n",
    "        lrs = learn.recorder.lrs\n",
    "        self.assertLess(lrs[0], 0.001)\n",
    "        self.assertLess(lrs[-1], 0.0001)\n",
    "        self.assertAlmostEqual(float(np.max(lrs)), 3e-3)\n",
    "\n",
    "    def test_moms(self):\n",
    "        \"\"\"Momentum starts at 0.95, dips to 0.85, and ends within 0.01 of 0.95.\"\"\"\n",
    "        moms = learn.recorder.moms\n",
    "        self.assertAlmostEqual(float(moms[0]), 0.95)\n",
    "        self.assertLess(abs(moms[-1] - 0.95), 0.01)\n",
    "        self.assertAlmostEqual(float(np.min(moms)), 0.85)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}