{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Interactive type checking with datashape\n",
    "\n",
    "Annotate a function's parameters with datashape strings such as `\"N*float64\"`, decorate it with `typecheck`, and every call validates its arguments first. Type variables (`N`, `M`, ...) are bound on first use and must agree across arguments."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datashape import *\n",
    "from inspect import *\n",
    "from functools import *\n",
    "from toolz.curried import *\n",
    "from toolz.curried.operator import *\n",
    "\n",
    "# Explicit import, placed after the star imports so none of them can rebind `np`.\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def to_datashape(annotation):\n",
    "    \"\"\"Convert a single annotation into a datashape object.\n",
    "\n",
    "    * ``str``  -> parsed with ``dshape``.\n",
    "    * ``dict`` -> ``Record`` whose field types are converted recursively.\n",
    "    * anything else is passed to ``from_numpy`` with an empty shape tuple\n",
    "      (presumably a NumPy dtype -- confirm against callers).\n",
    "    \"\"\"\n",
    "    if isinstance(annotation, str):\n",
    "        return dshape(annotation)\n",
    "    if isinstance(annotation, dict):\n",
    "        # Iterate the argument itself; the previous code referenced the\n",
    "        # undefined name `annotations` and raised NameError for every dict.\n",
    "        return Record(\n",
    "            [(name, to_datashape(field)) for name, field in annotation.items()])\n",
    "    return from_numpy(tuple(), annotation)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def annotation_to_shapes(callable):\n",
    "    \"\"\"Return ``(argument_shapes, return_shape)`` for an annotated callable.\n",
    "\n",
    "    ``argument_shapes`` maps each annotated parameter name to its datashape.\n",
    "    ``return_shape`` is ``None`` when the callable has no return annotation.\n",
    "    \"\"\"\n",
    "    shapes = pipe(\n",
    "        callable, getfullargspec, attrgetter('annotations'),\n",
    "        valmap(to_datashape)\n",
    "    )\n",
    "    # Default of None: a missing return annotation is not an error.\n",
    "    returns = shapes.pop('return', None)\n",
    "    return shapes, returns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def check_shapes(values, shapes):\n",
    "    \"\"\"Validate argument values against their datashapes.\n",
    "\n",
    "    Type variables are bound the first time they are seen and the bound\n",
    "    value is substituted wherever the same variable appears again.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    values : dict\n",
    "        Argument name -> value. Not modified.\n",
    "    shapes : dict\n",
    "        Argument name -> datashape, as built by ``annotation_to_shapes``.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    InteractiveTypeError\n",
    "        If a value fails validation, an annotated argument has no value,\n",
    "        or a supplied value has no annotation.\n",
    "    \"\"\"\n",
    "    remaining = dict(values)  # work on a copy so the caller's dict is untouched\n",
    "    type_vars = {}\n",
    "    for key, shape in shapes.items():\n",
    "        if key not in remaining:\n",
    "            raise InteractiveTypeError(f\"No value for {key}.\")\n",
    "        value = remaining.pop(key)\n",
    "        discovered = discover(value)\n",
    "        last = len(shape.parameters) - 1\n",
    "        new_parameters = []\n",
    "        for dim, parameter in enumerate(shape.parameters):\n",
    "            if isinstance(parameter, TypeVar):\n",
    "                if parameter not in type_vars:\n",
    "                    if dim == last:\n",
    "                        # Trailing position: the variable stands for the measure.\n",
    "                        type_vars[parameter] = discovered.measure\n",
    "                    elif hasattr(value, 'shape'):\n",
    "                        type_vars[parameter] = Fixed(value.shape[dim])\n",
    "                    else:\n",
    "                        # NOTE(review): len() is the outermost length, so this is\n",
    "                        # only meaningful for dim == 0 on nested sequences.\n",
    "                        type_vars[parameter] = Fixed(len(value))\n",
    "                parameter = type_vars[parameter]\n",
    "            new_parameters.append(parameter)\n",
    "\n",
    "        bound_shape = DataShape(*new_parameters)\n",
    "        if not validate(bound_shape, value):\n",
    "            raise InteractiveTypeError(\n",
    "                f\"{key} expects {bound_shape}, but received {discovered}\")\n",
    "\n",
    "    if remaining:\n",
    "        # Previously called the undefined name `Err` (NameError, never raised properly).\n",
    "        raise InteractiveTypeError(f\"No types for {list(remaining)}.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class InteractiveTypeError(Exception):\n",
    "    \"\"\"Raised when a call does not satisfy its datashape annotations.\n",
    "\n",
    "    Derives from ``Exception`` (not ``BaseException``) so ordinary\n",
    "    ``except Exception`` handlers can catch it.\n",
    "    \"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def typecheck(callable):\n",
    "    \"\"\"Decorator that checks call arguments against datashape annotations.\n",
    "\n",
    "    The annotations are parsed once, at decoration time. The return\n",
    "    annotation is parsed but the returned value is not checked.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    InteractiveTypeError\n",
    "        From the wrapped call, when ``check_shapes`` rejects the arguments.\n",
    "    \"\"\"\n",
    "    argspec = getfullargspec(callable)\n",
    "    shapes, _returns = annotation_to_shapes(callable)\n",
    "\n",
    "    # `defaults` belong to the *last* positional parameters, not the first.\n",
    "    defaults = argspec.defaults or ()\n",
    "    default_values = dict(\n",
    "        zip(argspec.args[len(argspec.args) - len(defaults):], defaults))\n",
    "    default_values.update(argspec.kwonlydefaults or {})\n",
    "\n",
    "    @wraps(callable)\n",
    "    def caller(*args, **kwargs):\n",
    "        # Later dicts win: defaults < positional arguments < keyword arguments.\n",
    "        values = merge(default_values, dict(zip(argspec.args, args)), kwargs)\n",
    "        check_shapes(values, shapes)\n",
    "        return callable(*args, **kwargs)\n",
    "    return caller"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Demonstrate on a simple dot function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "@typecheck\n",
    "def dot(x: \"N*float64\", y: \"N*float64\") -> \"float64\":\n",
    "    \"\"\"Inner product of two float sequences that share the length ``N``.\"\"\"\n",
    "    return sum(_1*_2 for _1, _2 in zip(x, y))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAC0AAAAOBAMAAABJDIgxAAAAMFBMVEX///8AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAv3aB7AAAAD3RSTlMAIpm7MhCriUTv3c12VGZoascqAAAACXBIWXMAAA7EAAAOxAGVKw4bAAAAx0lEQVQYGV2QsQ7BUBRATxMNKUNnU2MxGCVmf6ALqy4WC7UY+AeLhc0X8Ae1MDd+oKOYJAbEUve9W4vlnNyTvnfzitNod3Gb1wKgAwPKb6qpu1NItwMj2HCDhUK6HdjCpNuBLLSQrk5C6U9YpRbSf2bff0g/GRxlbWEqL+cFvdhgCnYQ40WOfNKLDUwvTEOP/t9TC+yqzOzNUt1rfMHx23APLaSrSwE1X57SMu9pSS5cn8/GeKm7VgwPapI8/+Auz/LfDLxI/QUhi1iL7EVjOgAAAABJRU5ErkJggg==\n",
      "text/latex": [
       "$$200.0$$"
      ],
      "text/plain": [
       "200.0"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dot([10.], [20.])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Demonstrate on matrix multiplication."
] },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "@typecheck\n",
    "def matmul(x: \"N*M*float64\", y: \"M*T*float64\") -> \"N*T*float64\":\n",
    "    \"\"\"Matrix product; the shared variable ``M`` ties x's columns to y's rows.\"\"\"\n",
    "    return x@y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 3.72519493e+00, -1.09563082e-01],\n",
       " [ 4.69513259e-01, 3.37871837e+00],\n",
       " [ -4.44569978e+00, -2.15308497e+00],\n",
       " [ 2.63104150e+00, -2.90177959e+00],\n",
       " [ -7.61140141e-01, 2.16359480e+00],\n",
       " [ -2.40202344e-02, -1.48523737e+00],\n",
       " [ -3.08959526e+00, 6.93008449e-01],\n",
       " [ 1.29601051e+00, -2.66912294e-03],\n",
       " [ 1.87770045e-01, -9.63482844e-02],\n",
       " [ 2.59286769e+00, -8.12771658e-01]])"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "matmul(np.random.randn(10,4), np.random.randn(4, 2))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Execute a function with invalid generic types."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pytest import raises"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _validates_types():\n",
    "    \"\"\"Well-typed calls pass the checker and return the expected result.\"\"\"\n",
    "    # Deterministic inputs: only shapes and dtype matter to the checker.\n",
    "    product = matmul(np.ones((10, 4)), np.ones((4, 2)))\n",
    "    assert product.shape == (10, 2)\n",
    "    assert dot([10.], [20.]) == 200.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _finds_type_errors():\n",
    "    \"\"\"A mismatch in the shared dimension ``M`` (4 vs 3) is rejected.\"\"\"\n",
    "    with raises(InteractiveTypeError):\n",
    "        matmul(np.ones((10, 4)), np.ones((3, 2)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the checks so a Restart & Run All actually exercises them.\n",
    "_validates_types()\n",
    "_finds_type_errors()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "p6",
   "language": "python",
   "name": "other-env"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
"nbformat": 4, "nbformat_minor": 2 }