{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# numpy Testing\n", "### Miki Tebeka .:. [353solutions](http://353solutions.com) .:. Highly effective Python, Scientific Python and Go workshops\n", "\n", "We'll explore certain caveats while testing [numpy](http://docs.scipy.org/doc/numpy/reference/) code.\n", "\n", "#### TL;DR\n", "Use [np.allclose](http://docs.scipy.org/doc/numpy/reference/generated/numpy.allclose.html) when comparing numpy arrays. Beware of `nan`." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": true }, "outputs": [], "source": [ "import numpy as np" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## The Naive Approach" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": false }, "outputs": [ { "ename": "ValueError", "evalue": "The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()", "output_type": "error", "traceback": [ "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[1;31mValueError\u001b[0m Traceback (most recent call last)", "\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m()\u001b[0m\n\u001b[0;32m 4\u001b[0m \u001b[1;32massert\u001b[0m \u001b[0marr\u001b[0m \u001b[1;33m*\u001b[0m \u001b[0mv\u001b[0m \u001b[1;33m==\u001b[0m \u001b[0mexpected\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;34m'bad multiplication'\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 5\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 6\u001b[1;33m \u001b[0mtest_mul\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[1;32m\u001b[0m in \u001b[0;36mtest_mul\u001b[1;34m()\u001b[0m\n\u001b[0;32m 2\u001b[0m \u001b[0marr\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0marray\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m0.0\u001b[0m\u001b[1;33m,\u001b[0m 
\u001b[1;36m1.0\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m1.1\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 3\u001b[0m \u001b[0mv\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mexpected\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;36m1.1\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0marray\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m0.0\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m1.1\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m1.21\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 4\u001b[1;33m \u001b[1;32massert\u001b[0m \u001b[0marr\u001b[0m \u001b[1;33m*\u001b[0m \u001b[0mv\u001b[0m \u001b[1;33m==\u001b[0m \u001b[0mexpected\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;34m'bad multiplication'\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 5\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 6\u001b[0m \u001b[0mtest_mul\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;31mValueError\u001b[0m: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()" ] } ], "source": [ "def test_mul():\n", " arr = np.array([0.0, 1.0, 1.1])\n", " v, expected = 1.1, np.array([0.0, 1.1, 1.21])\n", " assert arr * v == expected, 'bad multiplication'\n", " \n", "test_mul()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is because comparing two numpy arrays with `==` does not return a single `True` or `False`; it returns an array of booleans, one result per element." 
] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "array([ True, False, True], dtype=bool)" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "np.array([1, 2, 3]) == np.array([1, 1, 3])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And, as the error says, the truth value of an array with more than one element is ambiguous: should it be `True` if any element is true, or only if all of them are?" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": false }, "outputs": [ { "ename": "ValueError", "evalue": "The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()", "output_type": "error", "traceback": [ "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[1;31mValueError\u001b[0m Traceback (most recent call last)", "\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m()\u001b[0m\n\u001b[1;32m----> 1\u001b[1;33m \u001b[0mbool\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0marray\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m2\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m3\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[1;31mValueError\u001b[0m: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()" ] } ], "source": [ "bool(np.array([1, 2, 3]))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We need to use [np.all](http://docs.scipy.org/doc/numpy/reference/generated/numpy.all.html) to check that all the element-wise comparisons are `True`." 
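, "\n", "\n", "Applied to the element-wise comparison above, `np.all` collapses the boolean array into a single value. The `.all()` method suggested by the error message does the same reduction. A minimal sketch (the array `a` is just for illustration):" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import numpy as np\n", "\n", "a = np.array([1, 2, 3])\n", "# np.all reduces the element-wise comparison to a single boolean\n", "print(np.all(a == np.array([1, 2, 3])))  # True\n", "print(np.all(a == np.array([1, 1, 3])))  # False\n", "# The .all() method named in the error message is equivalent\n", "print((a == np.array([1, 1, 3])).all())  # False"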
] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "np.all([True, True, True])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Using np.all" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": false }, "outputs": [ { "ename": "AssertionError", "evalue": "bad multiplication", "output_type": "error", "traceback": [ "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[1;31mAssertionError\u001b[0m Traceback (most recent call last)", "\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m()\u001b[0m\n\u001b[0;32m 4\u001b[0m \u001b[1;32massert\u001b[0m \u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mall\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0marr\u001b[0m \u001b[1;33m*\u001b[0m \u001b[0mv\u001b[0m \u001b[1;33m==\u001b[0m \u001b[0mexpected\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;34m'bad multiplication'\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 5\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 6\u001b[1;33m \u001b[0mtest_mul\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[1;32m\u001b[0m in \u001b[0;36mtest_mul\u001b[1;34m()\u001b[0m\n\u001b[0;32m 2\u001b[0m \u001b[0marr\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0marray\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m0.0\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m1.0\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m1.1\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 3\u001b[0m \u001b[0mv\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mexpected\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;36m1.1\u001b[0m\u001b[1;33m,\u001b[0m 
\u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0marray\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m0.0\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m1.1\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m1.21\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 4\u001b[1;33m \u001b[1;32massert\u001b[0m \u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mall\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0marr\u001b[0m \u001b[1;33m*\u001b[0m \u001b[0mv\u001b[0m \u001b[1;33m==\u001b[0m \u001b[0mexpected\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;34m'bad multiplication'\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 5\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 6\u001b[0m \u001b[0mtest_mul\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;31mAssertionError\u001b[0m: bad multiplication" ] } ], "source": [ "def test_mul():\n", " arr = np.array([0.0, 1.0, 1.1])\n", " v, expected = 1.1, np.array([0.0, 1.1, 1.21])\n", " assert np.all(arr * v == expected), 'bad multiplication'\n", " \n", "test_mul()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is due to the fact that floating points are not exact." ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "1.2100000000000002" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "1.1 * 1.1" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is *not* a bug in Python but how floating points are implemented. You'll get the same result in C, Java, Go ...\n", "To overcome this we're going to use [np.allclose](http://docs.scipy.org/doc/numpy/reference/generated/numpy.allclose.html).\n", "\n", "BTW: If you're really intersted in floating points, read [this article](http://docs.oracle.com/cd/E19957-01/806-3568/ncg_goldberg.html)." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Using np.allclose" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def test_mul():\n", " arr = np.array([0.0, 1.0, 1.1])\n", " v, expected = 1.1, np.array([0.0, 1.1, 1.21])\n", " assert np.allclose(arr * v, expected), 'bad multiplication'\n", " \n", "test_mul()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Oh nan, Let Me Count the Ways ..." ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "collapsed": false }, "outputs": [ { "ename": "AssertionError", "evalue": "bad nan", "output_type": "error", "traceback": [ "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[1;31mAssertionError\u001b[0m Traceback (most recent call last)", "\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m()\u001b[0m\n\u001b[0;32m 4\u001b[0m \u001b[1;32massert\u001b[0m \u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mallclose\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0marr1\u001b[0m \u001b[1;33m/\u001b[0m \u001b[0marr2\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mexpected\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;34m'bad nan'\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 5\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 6\u001b[1;33m \u001b[0mtest_div\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[1;32m\u001b[0m in \u001b[0;36mtest_div\u001b[1;34m()\u001b[0m\n\u001b[0;32m 2\u001b[0m \u001b[0marr1\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0marr2\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0marray\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m1.0\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0minf\u001b[0m\u001b[1;33m,\u001b[0m 
\u001b[1;36m2.0\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0marray\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m2.0\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0minf\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m2.0\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 3\u001b[0m \u001b[0mexpected\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0marray\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m0.5\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mnan\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m1.0\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 4\u001b[1;33m \u001b[1;32massert\u001b[0m \u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mallclose\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0marr1\u001b[0m \u001b[1;33m/\u001b[0m \u001b[0marr2\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mexpected\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;34m'bad nan'\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 5\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 6\u001b[0m \u001b[0mtest_div\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;31mAssertionError\u001b[0m: bad nan" ] } ], "source": [ "def test_div():\n", " arr1, arr2 = np.array([1.0, np.inf, 2.0]), np.array([2.0, np.inf, 2.0])\n", " expected = np.array([0.5, np.nan, 1.0])\n", " assert np.allclose(arr1 / arr2, expected), 'bad nan'\n", " \n", "test_div()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is because `nan` (which is what `np.inf / np.inf` produces) does not equal anything, not even itself. So by default `np.allclose` reports the two `nan` values as different." 
] }, { "cell_type": "code", "execution_count": 10, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "False" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "np.nan == np.nan" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To check if a number is `nan`, we need to use [np.isnan](http://docs.scipy.org/doc/numpy/reference/generated/numpy.isnan.html)." ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "np.isnan(np.inf/np.inf)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We have two options to solve this:\n", "\n", "1. Convert all `nan` values to a number\n", "2. Use the `equal_nan` argument of `np.allclose`" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Option 1: Convert `nan` to Numbers" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def test_div():\n", " arr1, arr2 = np.array([1.0, np.inf, 2.0]), np.array([2.0, np.inf, 2.0])\n", " expected = np.array([0.5, np.nan, 1.0])\n", " result = arr1 / arr2\n", " \n", " result[np.isnan(result)] = 0.0\n", " expected[np.isnan(expected)] = 0.0\n", " assert np.allclose(result, expected), 'bad nan'\n", " \n", "test_div()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Option 2: Use `equal_nan` in `np.allclose`" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def test_div():\n", " arr1, arr2 = np.array([1.0, np.inf, 2.0]), np.array([2.0, np.inf, 2.0])\n", " expected = np.array([0.5, np.nan, 1.0])\n", " assert np.allclose(arr1 / arr2, expected, equal_nan=True), 'bad nan'\n", " \n", "test_div()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": 
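[ "# Caveat (a minimal sketch with made-up arrays): Option 1 can hide a mismatch.\n", "# Here the result has nan where the expected value is a genuine 0.0.\n", "# Replacing nan with 0.0 makes the arrays look close, while equal_nan=True\n", "# only treats nan as equal to another nan, so it reports the difference.\n", "import numpy as np\n", "\n", "result = np.array([0.5, np.nan, 1.0])\n", "expected = np.array([0.5, 0.0, 1.0])\n", "\n", "# Option 1, applied to copies so the original arrays stay intact\n", "r, e = result.copy(), expected.copy()\n", "r[np.isnan(r)] = 0.0\n", "e[np.isnan(e)] = 0.0\n", "print(np.allclose(r, e))  # True: the mismatch is hidden\n", "\n", "# Option 2\n", "print(np.allclose(result, expected, equal_nan=True))  # False: the mismatch is caught" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": 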
[] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.5.1" } }, "nbformat": 4, "nbformat_minor": 0 }