{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|default_exp test" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "from fastcore.imports import *\n", "from collections import Counter\n", "from contextlib import redirect_stdout" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|hide\n", "from nbdev.showdoc import *\n", "from fastcore.nb_imports import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Test\n", "\n", "> Helper functions to quickly write tests in notebooks" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Simple test functions" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can check that code raises an exception when that's expected (`test_fail`).\n", "\n", "To test for equality or inequality (with different types of things) we define a simple function `test` that compares two objects with a given `cmp` operator." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def test_fail(f, msg='', contains='', args=None, kwargs=None):\n", " \"Fails with `msg` unless `f()` raises an exception and (optionally) has `contains` in `e.args`\"\n", " args, kwargs = args or [], kwargs or {}\n", " try: f(*args, **kwargs)\n", " except Exception as e:\n", " assert not contains or contains in str(e)\n", " return\n", " assert False,f\"Expected exception but none raised. {msg}\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def _fail(): raise Exception(\"foobar\")\n", "test_fail(_fail, contains=\"foo\")\n", "\n", "def _fail(): raise Exception()\n", "test_fail(_fail)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can also pass `args` and `kwargs` to function to check if it fails with special inputs." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def _fail_args(a):\n", " if a == 5:\n", " raise ValueError\n", "test_fail(_fail_args, args=(5,))\n", "test_fail(_fail_args, kwargs=dict(a=5))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def test(a, b, cmp, cname=None):\n", " \"`assert` that `cmp(a,b)`; display inputs and `cname or cmp.__name__` if it fails\"\n", " if cname is None: cname=cmp.__name__\n", " assert cmp(a,b),f\"{cname}:\\n{a}\\n{b}\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test([1,2],[1,2], operator.eq)\n", "test_fail(lambda: test([1,2],[1], operator.eq))\n", "test([1,2],[1], operator.ne)\n", "test_fail(lambda: test([1,2],[1,2], operator.ne))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "---\n", "\n", "### all_equal\n", "\n", "> all_equal (a, b)\n", "\n", "Compares whether `a` and `b` are the same length and have the same contents" ], "text/plain": [ "---\n", "\n", "### all_equal\n", "\n", "> all_equal (a, b)\n", "\n", "Compares whether `a` and `b` are the same length and have the same contents" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "show_doc(all_equal)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test(['abc'], ['abc'], all_equal)\n", "test_fail(lambda: test(['abc'],['cab'], all_equal))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "---\n", "\n", "### equals\n", "\n", "> equals (a, b)\n", "\n", "Compares `a` and `b` for equality; supports sublists, tensors and arrays too" ], "text/plain": [ "---\n", "\n", "### equals\n", "\n", "> equals (a, b)\n", "\n", "Compares `a` and `b` for equality; supports sublists, tensors and arrays too" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "show_doc(equals)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test([['abc'],['a']], [['abc'],['a']], equals)\n", "test([['abc'],['a'],'b', [['x']]], [['abc'],['a'],'b', [['x']]], equals) # supports any depth and nested structure" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def nequals(a,b):\n", " \"Compares `a` and `b` for `not equals`\"\n", " return not equals(a,b)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test(['abc'], ['ab' ], nequals)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## test_eq test_ne, etc..." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Just use `test_eq`/`test_ne` to test for `==`/`!=`. `test_eq_type` checks things are equal and of the same type. We define them using `test`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def test_eq(a,b):\n", " \"`test` that `a==b`\"\n", " test(a,b,equals, cname='==')" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": false }, "outputs": [], "source": [ "test_eq([1,2],[1,2])\n", "test_eq([1,2],map(int,[1,2]))\n", "test_eq(array([1,2]),array([1,2]))\n", "test_eq(array([1,2]),array([1,2]))\n", "test_eq([array([1,2]),3],[array([1,2]),3])\n", "test_eq(dict(a=1,b=2), dict(b=2,a=1))\n", "test_fail(lambda: test_eq([1,2], 1), contains=\"==\")\n", "test_fail(lambda: test_eq(None, np.array([1,2])), contains=\"==\")\n", "test_eq({'a', 'b', 'c'}, {'c', 'a', 'b'})" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|hide\n", "import pandas as pd\n", "import torch" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "df1 = pd.DataFrame(dict(a=[1,2],b=['a','b']))\n", "df2 = pd.DataFrame(dict(a=[1,2],b=['a','b']))\n", "df3 = pd.DataFrame(dict(a=[1,2],b=['a','c']))\n", "\n", "test_eq(df1,df2)\n", "test_eq(df1.a,df2.a)\n", "test_fail(lambda: test_eq(df1,df3), contains='==')\n", "class T(pd.Series): pass\n", "test_eq(df1.iloc[0], T(df2.iloc[0])) # works with subclasses" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test_eq(torch.zeros(10), torch.zeros(10, dtype=torch.float64))\n", "test_eq(torch.zeros(10), torch.ones(10)-1)\n", "test_fail(lambda:test_eq(torch.zeros(10), torch.ones(1, 10)), contains='==')\n", "test_eq(torch.zeros(3), [0,0,0])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def test_eq_type(a,b):\n", " \"`test` that `a==b` and are same type\"\n", " test_eq(a,b)\n", " test_eq(type(a),type(b))\n", " if isinstance(a,(list,tuple)): test_eq(map(type,a),map(type,b))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": false }, "outputs": [], "source": [ "test_eq_type(1,1)\n", "test_fail(lambda: test_eq_type(1,1.))\n", "test_eq_type([1,1],[1,1])\n", "test_fail(lambda: test_eq_type([1,1],(1,1)))\n", "test_fail(lambda: test_eq_type([1,1],[1,1.]))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def test_ne(a,b):\n", " \"`test` that `a!=b`\"\n", " test(a,b,nequals,'!=')" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": false }, "outputs": [], "source": [ "test_ne([1,2],[1])\n", "test_ne([1,2],[1,3])\n", "test_ne(array([1,2]),array([1,1]))\n", "test_ne(array([1,2]),array([1,1]))\n", "test_ne([array([1,2]),3],[array([1,2])])\n", "test_ne([3,4],array([3]))\n", "test_ne([3,4],array([3,5]))\n", "test_ne(dict(a=1,b=2), ['a', 'b'])\n", "test_ne(['a', 'b'], dict(a=1,b=2))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def is_close(a,b,eps=1e-5):\n", " \"Is `a` within `eps` of `b`\"\n", " if hasattr(a, '__array__') or hasattr(b,'__array__'):\n", " return (abs(a-b) 0 else '')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test_stdout(lambda: print('hi'), 'hi')\n", "test_fail(lambda: test_stdout(lambda: print('hi'), 'ho'))\n", "test_stdout(lambda: 1+1, '')\n", "test_stdout(lambda: print('hi there!'), r'^hi.*!$', regex=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def test_warns(f, show=False):\n", " with warnings.catch_warnings(record=True) as w:\n", " f()\n", " assert w, \"No warnings raised\"\n", " if show:\n", " for e in w: print(f\"{e.category}: {e.message}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test_warns(lambda: warnings.warn(\"Oh no!\"))\n", "test_fail(lambda: test_warns(lambda: 2+2), contains='No warnings raised')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ ": Oh no!\n" ] } ], "source": [ "test_warns(lambda: warnings.warn(\"Oh no!\"), show=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "TEST_IMAGE = 'images/puppy.jpg'" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "im = Image.open(TEST_IMAGE).resize((128,128)); im" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "TEST_IMAGE_BW = 'images/mnist3.png'" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAIAAAACACAAAAADmVT4XAAAK3ElEQVR4nO2a2XIjubGGM7HUXlxKXKTunvEaYb//uzjixJnweNyj1kJRXGqvwpK+ILV1j20UWzo3R/8VCyQLH4BEApkAwLve9a53vetd7/r/LnT+ISJDxhgyxMOfjDGGAJDIEp0MINx/KT0vCALP8yRHADB1XlQaEK3qu5PrdwdgXhin4/E4SZLAZwCg1l+u1i0xZupSmbcHkOF4cjafz7LpJIkYALSX/+vLkgRXOXVvD+Al02y2OD9fzs6yUcwAoE2tooIE731Sp3eBCwDjXMTZcrFYLBbzbDoNAAAgWmxLqEiwLhEErdLmJEt0AEDhB+Hk4tPH5Ww6GaVpcCyP5l3UAmPdbSjEvqybU+p3ApDhaDL/3R9+vJiEoR/Ix79OWaaBYTMNuFitqddvBhBPFh9+/OPvz1POgXN79AMsDWaECG3MCAVasz9lDBwAuJ9MF+cXF+cL71BgDxVxcRiMlLpWMy6CsO6NtQSDOP47AJPRZPnx43KaHOsHJAB8cqIsyVoIR/PtZr8vqk4bO4TAAcCLs4tPnxYj77EIX7pwObb+9GNRblZXV6u8ARgyJ516YDw7X0xCdnj+tnk84vFC9f3my09cG0t2QP0OAMhlGCdxIMAygG9eTgAYBikhmF1qyn2rzaDZ4GCE1hpjDBEyAGAEQPDc0BA4AAcAOa9nZ9Oq7/vXBbC6LffbyPfDY32HUaCHhyfxMJ1kVdsM8kgOAF25XaVChsljERHRcSa8sEcWjLKyKvJXBujr7W3Ag2QUPpQYq60hYMgYZ+wRwIKXTst8477HcAIg3Rb3gRePUs8/NN4Y1Xa9QS6E9D2PAQEhAhAPknGehpINmAcOtNSXW8GD0DM+AiDnYOrtJlfohVE0GvHHVyD3ozSNQ19odwIXAFUxrbnAOmAIIgiE3v36+brFeDTNlhgGj7/kMkqSJPS9AZ7IZbx0Y5qeMSpCDsxLRoFZ/fK3nyuczJYXEE8B8DgrmQziJI4Cr7OEjv7YBcDYvlJcQBlyZMGkic3t55/+p+DZRQvjc/NsFhwAfI+7d4ELABFAEYSiDTiyoKgSe3VzsyKgIC0b9Wy4uQySNIkCj6Pzkug6ZXS1wcpjiP5ok9DqZk8A1pinlY8AgHlAzSSNfKm1K4HznO1zk3MGKJM0wv1dDQC+5PzBCxACAAfG+3ESeV5vXQfBGUBVHUcEEEEUsnbXA4ShL4V4JAAAFFymceRLwY3rRHQGMMcmoRf4XDcKuJSCM3y+HCAgCMHZy9JXAngQdbrjVpnDOvDVd2Rt1yttrPumaDAAgCEGBg7tfRICkNW62xdV2yv3bdkJAED2WxNHACCjmnKzLapOufviUwB+c4YhAJi+3m/2ed0p9yjppB54+vSyHtu3ZVHUbW+sc+KBnQDwULs1xr5oqunrMi/rrh+wGH0nwEt7N31TlmXdKvvbw/TKAN/Kqqaq6qZXAyKT7wBgQnpS8GeDbXXftm2v9P8JgEEvGY9j7+kNiAdfNCg0+x4bCKeL+TTiT+/iQkjBmbsf/j4Ans4/XGQRf4oQuBDiaXl6c4B4ujhfPPQAAQBYa4noNxaJ/6ATHBEiImcsujhfzLPUPw4BEeimKouyVa8cnH5TPeNemKRx+sOfP81GoWQAAEjWqHazurq8vq/U2wIw4cfZxYfldHZxfpYEAgkBwOquzVe//vzz7S7vXzdB8TUA98LxxZ//8odFFMeBlHjoAaua/P7m889/X3ftK0fHXwOIIM2Wv//LXxeCIT3am+2r/d3t9Zcv24GJ6+EAfpotzs+Xi+jwfKzPtvndzc1qkw9NmQ6ehjzMPvzuh8X4MSA75u6ru19/+bzKB3X/SQAsmn388UMWPHb04YMu15f//HI/PFs6HCAYzRazVGiAQ7B+KLZtfr/e1sNz1sM9IROeJ/kh9qRnFkcAyKUc5AbhlG25aop8NCIOcEhZHsTD6XnLBKPWDDvAGQxgm81NIOLsAHDIVyGQmHzCZBJzC+2gPOUJAPXaQz5ePmslAYCcisl86ulW255eN1P6NUCz4xjMP9YpPNsVk0jjs4nX7XetJufA9DQA1eT+/epmrIU4+mEAAIZcYjk7yypt9JDVaLgR2r4t7r/EdpOOx9GDyRMCQBCPslndd6+cqPxa1vT15ldbnC8/ifCxlBAAZJzN67Z65UTlVyKyut3banOXszhJnn9lWTRZNPV++6YAQFa31JV5Bck4ohDpyRBENKnaarutevdTk+E9YIxhPZOVklFgz2KPe76QB6/gRVNtmqJSVa9cc/anBKfGKIBKITP72TiJkzThh42hiCxD1XS0zcu3BAAAALVBatazsyybzsE7nOWJkAfSdAp8VK5T4WQAsAV2u2w+X9aWB+EhLvC9SKi2t0w1tePOwB0AERnnHI87ACKm96ppmq6zXPCQcQbIAdKsarVtm7Z0i9FcARAZl2GSRBLJEiDavqmbsmvqcl91XTeKwkAAAHjxmQJmDW3bziVMdgZg3Iumy+VZxK22jDNTra+vd01Z7tb7qqnnWcYPhuCP0ZNEGOxzcvDJ7hcYuBef/fCnTyNplOGS6/1n2dZN1+SbWlndWS8GAACSqR9woxSibh1S5u5DwGU4Wv74p0ya3nBP6K1stkXXK9WLeJREo8MlBiLBA6mKzaYsPf7f3joAABCZCNJsfnboAaG9ervrZdBqlAysedoGIfj+y8TFqwAAMi79KE2E1YZJZu38hz6Y3RW1TbNRFMhnNVqj+67tnKLUAQCAXEjPY5ZZxpHCTMlsfbvZq3A6n6ah9wTQtVVZlnXvki0c4IgQkAnBgTOLDNAbYTzbzO+2rZdk03F0HHEi0m1dFXlRO11tGeIJkXEuGRwOasFDESZJEKUtD0ej5BinAxAZ3TdVWbklCoZ4QoZcPMsISSYERxl3zIui6DFbRdaorinzsnNKGLvPAkBgiM8CGeQQIHojhcLz/MDnBAhIYI3u6mJftuo1XTEAPEVih/cSMO6zQFlEzrl46ByyRvdNmZdu93qcAYgs2eOG9yke4vJwVvPs9ICs0apr6sZtV+QOYI1WfdcpzoAdIV4GlofeQUREssY1XesanJI1qmubuuqOloX422EociE96eoHB0THVquuKYu8/ncr3NEXI+NSuHti5yEg0n1XF9tNiMx/LDxUSc8ejsYC6NqyATag+3q3vuGqi0OOCPRgdi8Gm7Tpq7xsnKMz92lIti/Xl1hm4zT2fd8XQiCDx/ssDxah6v3m5vKueu1dMQHYvriF/HqSpqN0PBnHIePwOAAP9VO/vbq8/Hy5e4NtOanS1ndJEqeTs8X5ckqMPyM4StXrf/70y9XtmwDous2lH0bJZFH0VluAiD26RQQAoL7c3X7++z9Wef0WgQkZoxr0g6jqLZim7XsdiOMtGjLWKKXb/Ob6+vp22zif2gyPDVttLFlV3G3y5WwcHxOmSHVZ1W21v/3Hl3VeK3ptP/BMpgNS1Wq03lddTw87T52v7/dFvl1fXe8b436p8KTgtNPNTob7RlvDg6NbqjZXN+vtbrvd3BdD8vUnAJAx0ABIwyUnzmEMAFCsri+v7jbbXV6WzgZ4IsBRancXsq7NVyMAgPL2y+XqfpcXVdO97YnJk9qd6Her6SgCAGh29+tdUdVNq4acmw64Xf+teBRHYRT6EgBAN3Xddr1SA6+Ufg8AcsYZe/AD1hpLZIkG3qp917ve9a53vetd7/oXsYyZrAWZvNgAAAAASUVORK5CYII=", "text/plain": [ "" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "im = Image.open(TEST_IMAGE_BW).resize((128,128)); im" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def test_fig_exists(ax):\n", " \"Test there is a figure displayed in `ax`\"\n", " assert ax and len(ax.figure.canvas.tostring_argb())" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "fig,ax = plt.subplots()\n", "ax.imshow(array(im));" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test_fig_exists(ax)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "class ExceptionExpected:\n", " \"Context manager that tests if an exception is raised\"\n", " def __init__(self, ex=Exception, regex=''): self.ex,self.regex = ex,regex\n", " def __enter__(self): pass\n", " def __exit__(self, type, value, traceback):\n", " if not isinstance(value, self.ex) or (self.regex and not re.search(self.regex, f'{value.args}')):\n", " raise TypeError(f\"Expected {self.ex.__name__}({self.regex}) not raised.\")\n", " return True" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def _tst_1(): assert False, \"This is a test\"\n", "def _tst_2(): raise SyntaxError\n", "\n", "with ExceptionExpected(): _tst_1()\n", "with ExceptionExpected(ex=AssertionError, regex=\"This is a test\"): _tst_1()\n", "with ExceptionExpected(ex=SyntaxError): _tst_2()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`exception` is an abbreviation for `ExceptionExpected()`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "exception = ExceptionExpected()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "with exception: _tst_1()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|hide\n", "def _f():\n", " with ExceptionExpected(): 1\n", "test_fail(partial(_f))\n", "\n", "def _f():\n", " with ExceptionExpected(SyntaxError): assert False\n", "test_fail(partial(_f))\n", "\n", "def _f():\n", " with ExceptionExpected(AssertionError, \"Yes\"): assert False, \"No\"\n", "test_fail(partial(_f))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Export -" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Converted 00_test.ipynb.\n", "Converted 01_basics.ipynb.\n", "Converted 02_foundation.ipynb.\n", "Converted 03_xtras.ipynb.\n", "Converted 03a_parallel.ipynb.\n", "Converted 03b_net.ipynb.\n", "Converted 04_dispatch.ipynb.\n", "Converted 05_transform.ipynb.\n", "Converted 07_meta.ipynb.\n", "Converted 08_script.ipynb.\n", "Converted index.ipynb.\n" ] } ], "source": [ "#|hide\n", "#|eval: false\n", "from nbdev import nbdev_export\n", "nbdev_export()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "jupytext": { "split_at_heading": true }, "kernelspec": { "display_name": "Python 3.7.15 ('fastcore-wZINr0y3')", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 4 }