{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Custom Types\n", "\n", "Often, the behavior for a field needs to be customized to support a particular shape or validation method that ParamTools does not support out of the box. In this case, you may use the `register_custom_type` function to add your new `type` to the ParamTools type registry. Each `type` has a corresponding `field` that is used for serialization and deserialization. ParamTools will then use this `field` any time it is handling a `value`, `label`, or `member` that is of this `type`.\n", "\n", "ParamTools is built on top of [`marshmallow`](https://github.com/marshmallow-code/marshmallow), a general purpose validation library. This means that you must implement a custom `marshmallow` field to go along with your new type. Please refer to the `marshmallow` [docs](https://marshmallow.readthedocs.io/en/stable/) if you have questions about the use of `marshmallow` in the examples below.\n", "\n", "\n", "## 32 Bit Integer Example\n", "\n", "ParamTools's default integer field uses NumPy's `int64` type. 
This example shows you how to define an `int32` type and reference it in your `defaults`.\n", "\n", "First, let's define the `marshmallow` field class:\n" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import marshmallow as ma\n", "import numpy as np\n", "\n", "class Int32(ma.fields.Field):\n", "    \"\"\"\n", "    A custom type for np.int32.\n", "    https://numpy.org/devdocs/reference/arrays.dtypes.html\n", "    \"\"\"\n", "    # minor detail that makes this play nice with array_first\n", "    np_type = np.int32\n", "\n", "    def _serialize(self, value, *args, **kwargs):\n", "        \"\"\"Convert np.int32 to basic, serializable Python int.\"\"\"\n", "        return value.tolist()\n", "\n", "    def _deserialize(self, value, *args, **kwargs):\n", "        \"\"\"Cast value from JSON to NumPy Int32.\"\"\"\n", "        converted = np.int32(value)\n", "        return converted\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, register the type and reference it in the `defaults` JSON/dict object:\n" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "value: 2, type: <class 'numpy.int32'>\n" ] } ], "source": [ "import paramtools as pt\n", "\n", "\n", "# add int32 type to the paramtools type registry\n", "pt.register_custom_type(\"int32\", Int32())\n", "\n", "\n", "class Params(pt.Parameters):\n", "    defaults = {\n", "        \"small_int\": {\n", "            \"title\": \"Small integer\",\n", "            \"description\": \"Demonstrate how to define a custom type\",\n", "            \"type\": \"int32\",\n", "            \"value\": 2\n", "        }\n", "    }\n", "\n", "\n", "params = Params(array_first=True)\n", "\n", "\n", "print(f\"value: {params.small_int}, type: {type(params.small_int)}\")\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "One problem with this approach is that we can run into deserialization issues. 
Due to integer overflow, our deserialized result is not the number that we passed in--it's negative!\n" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "OrderedDict([('small_int', [OrderedDict([('value', -2147483648)])])])" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "params.adjust(dict(\n", "    # this number wasn't chosen randomly.\n", "    small_int=2147483647 + 1\n", "))\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Marshmallow Validator\n", "\n", "Fortunately, you can specify a custom validator with `marshmallow` or ParamTools. Making this work requires modifying the `_deserialize` method to check for overflow like this:\n" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "class Int32(ma.fields.Field):\n", "    \"\"\"\n", "    A custom type for np.int32.\n", "    https://numpy.org/devdocs/reference/arrays.dtypes.html\n", "    \"\"\"\n", "    # minor detail that makes this play nice with array_first\n", "    np_type = np.int32\n", "\n", "    def _serialize(self, value, *args, **kwargs):\n", "        \"\"\"Convert np.int32 to basic Python int.\"\"\"\n", "        return value.tolist()\n", "\n", "    def _deserialize(self, value, *args, **kwargs):\n", "        \"\"\"Cast value from JSON to NumPy Int32.\"\"\"\n", "        converted = np.int32(value)\n", "\n", "        # check for overflow and let range validator\n", "        # display the error message.\n", "        if converted != int(value):\n", "            return int(value)\n", "\n", "        return converted\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, let's see how to use a `marshmallow` validator to fix this problem:\n" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "ename": "ValidationError", "evalue": "{\n \"errors\": {\n \"small_int\": [\n \"Must be greater than or equal to -2147483648 and less than or equal to 2147483647.\"\n ]\n }\n}", "output_type": "error", "traceback": [ 
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mValidationError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 29\u001b[0m \u001b[0mparams\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mParams\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0marray_first\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 30\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 31\u001b[0;31m params.adjust(dict(\n\u001b[0m\u001b[1;32m 32\u001b[0m \u001b[0msmall_int\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mint64\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmax_int32\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 33\u001b[0m ))\n", "\u001b[0;32m~/ParamTools/paramtools/parameters.py\u001b[0m in \u001b[0;36madjust\u001b[0;34m(self, params_or_path, ignore_warnings, raise_errors, extend_adj, clobber)\u001b[0m\n\u001b[1;32m 205\u001b[0m \u001b[0mleast\u001b[0m \u001b[0mone\u001b[0m \u001b[0mexisting\u001b[0m \u001b[0mvalue\u001b[0m \u001b[0mitem\u001b[0m\u001b[0;31m'\u001b[0m\u001b[0ms\u001b[0m \u001b[0mcorresponding\u001b[0m \u001b[0mlabel\u001b[0m \u001b[0mvalues\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 206\u001b[0m \"\"\"\n\u001b[0;32m--> 207\u001b[0;31m return self._adjust(\n\u001b[0m\u001b[1;32m 208\u001b[0m \u001b[0mparams_or_path\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 209\u001b[0m \u001b[0mignore_warnings\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mignore_warnings\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/ParamTools/paramtools/parameters.py\u001b[0m in 
\u001b[0;36m_adjust\u001b[0;34m(self, params_or_path, ignore_warnings, raise_errors, extend_adj, is_deserialized, clobber)\u001b[0m\n\u001b[1;32m 333\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mignore_warnings\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mhas_warnings\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 334\u001b[0m ):\n\u001b[0;32m--> 335\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalidation_error\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 336\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 337\u001b[0m \u001b[0;31m# Update attrs for params that were adjusted.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mValidationError\u001b[0m: {\n \"errors\": {\n \"small_int\": [\n \"Must be greater than or equal to -2147483648 and less than or equal to 2147483647.\"\n ]\n }\n}" ] } ], "source": [ "import marshmallow as ma\n", "import paramtools as pt\n", "\n", "\n", "# get the minimum and maximum values for 32-bit integers.\n", "min_int32 = -2147483648  # = np.iinfo(np.int32).min\n", "max_int32 = 2147483647  # = np.iinfo(np.int32).max\n", "\n", "# add int32 type to the paramtools type registry\n", "pt.register_custom_type(\n", "    \"int32\",\n", "    Int32(validate=[\n", "        ma.validate.Range(min=min_int32, max=max_int32)\n", "    ])\n", ")\n", "\n", "\n", "class Params(pt.Parameters):\n", "    defaults = {\n", "        \"small_int\": {\n", "            \"title\": \"Small integer\",\n", "            \"description\": \"Demonstrate how to define a custom type\",\n", "            \"type\": \"int32\",\n", "            \"value\": 2\n", "        }\n", "    }\n", "\n", "\n", "params = Params(array_first=True)\n", "\n", "params.adjust(dict(\n", "    small_int=np.int64(max_int32) + 1\n", "))\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### ParamTools Validator\n", "\n", "Finally, we will use ParamTools to solve this problem. 
We need to modify how we create our custom `marshmallow` field so that it's wrapped by ParamTools's `PartialField`. This makes it clear that your field still needs to be initialized, and that your custom field is able to receive validation information from the `defaults` configuration:\n" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "ename": "ValidationError", "evalue": "{\n \"errors\": {\n \"small_int\": [\n \"small_int 2147483648 > max 2147483647 \"\n ]\n }\n}", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mValidationError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 25\u001b[0m \u001b[0mparams\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mParams\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0marray_first\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 26\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 27\u001b[0;31m params.adjust(dict(\n\u001b[0m\u001b[1;32m 28\u001b[0m \u001b[0msmall_int\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m2147483647\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 29\u001b[0m ))\n", "\u001b[0;32m~/ParamTools/paramtools/parameters.py\u001b[0m in \u001b[0;36madjust\u001b[0;34m(self, params_or_path, ignore_warnings, raise_errors, extend_adj, clobber)\u001b[0m\n\u001b[1;32m 205\u001b[0m \u001b[0mleast\u001b[0m \u001b[0mone\u001b[0m \u001b[0mexisting\u001b[0m \u001b[0mvalue\u001b[0m \u001b[0mitem\u001b[0m\u001b[0;31m'\u001b[0m\u001b[0ms\u001b[0m \u001b[0mcorresponding\u001b[0m \u001b[0mlabel\u001b[0m \u001b[0mvalues\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 206\u001b[0m \"\"\"\n\u001b[0;32m--> 207\u001b[0;31m return 
self._adjust(\n\u001b[0m\u001b[1;32m 208\u001b[0m \u001b[0mparams_or_path\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 209\u001b[0m \u001b[0mignore_warnings\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mignore_warnings\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/ParamTools/paramtools/parameters.py\u001b[0m in \u001b[0;36m_adjust\u001b[0;34m(self, params_or_path, ignore_warnings, raise_errors, extend_adj, is_deserialized, clobber)\u001b[0m\n\u001b[1;32m 333\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mignore_warnings\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mhas_warnings\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 334\u001b[0m ):\n\u001b[0;32m--> 335\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalidation_error\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 336\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 337\u001b[0m \u001b[0;31m# Update attrs for params that were adjusted.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mValidationError\u001b[0m: {\n \"errors\": {\n \"small_int\": [\n \"small_int 2147483648 > max 2147483647 \"\n ]\n }\n}" ] } ], "source": [ "import paramtools as pt\n", "\n", "\n", "# add int32 type to the paramtools type registry\n", "pt.register_custom_type(\n", "    \"int32\",\n", "    pt.PartialField(Int32)\n", ")\n", "\n", "\n", "class Params(pt.Parameters):\n", "    defaults = {\n", "        \"small_int\": {\n", "            \"title\": \"Small integer\",\n", "            \"description\": \"Demonstrate how to define a custom type\",\n", "            \"type\": \"int32\",\n", "            \"value\": 2,\n", "            \"validators\": {\n", "                \"range\": {\"min\": -2147483648, \"max\": 2147483647}\n", "            }\n", "        }\n", "    }\n", "\n", "\n", "params = Params(array_first=True)\n", "\n", "params.adjust(dict(\n", "    small_int=2147483647 + 1\n", "))\n", "\n" ] 
}, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.5" } }, "nbformat": 4, "nbformat_minor": 4 }