{ "cells": [ { "cell_type": "markdown", "id": "13a3464b", "metadata": {}, "source": [ "# Typing\n", "\n", "> SAX types" ] }, { "cell_type": "code", "execution_count": null, "id": "0466717d-c49e-4ceb-ac5d-f55fd1f6c92f", "metadata": { "tags": [] }, "outputs": [], "source": [ "from typing import Callable\n", "\n", "import jax.numpy as jnp\n", "import numpy as np\n", "import pytest\n", "import sax" ] }, { "cell_type": "code", "execution_count": null, "id": "28374448", "metadata": {}, "outputs": [], "source": [ "assert sax.is_float(3.0)\n", "assert not sax.is_float(3)\n", "assert not sax.is_float(3.0 + 2j)\n", "assert not sax.is_float(jnp.array(3.0, dtype=complex))\n", "assert not sax.is_float(jnp.array(3, dtype=int))" ] }, { "cell_type": "code", "execution_count": null, "id": "d81a703d", "metadata": {}, "outputs": [], "source": [ "assert not sax.is_complex(3.0)\n", "assert not sax.is_complex(3)\n", "assert sax.is_complex(3.0 + 2j)\n", "assert sax.is_complex(jnp.array(3.0, dtype=complex))\n", "assert not sax.is_complex(jnp.array(3, dtype=int))" ] }, { "cell_type": "code", "execution_count": null, "id": "c533ff0d", "metadata": {}, "outputs": [], "source": [ "assert sax.is_complex_float(3.0)\n", "assert not sax.is_complex_float(3)\n", "assert sax.is_complex_float(3.0 + 2j)\n", "assert sax.is_complex_float(jnp.array(3.0, dtype=complex))\n", "assert not sax.is_complex_float(jnp.array(3, dtype=int))" ] }, { "cell_type": "code", "execution_count": null, "id": "b7f34b59-a4e7-4c98-8e91-4c5f8fc93658", "metadata": {}, "outputs": [], "source": [ "_sdict: sax.SDict = {\n", " (\"in0\", \"out0\"): 3.0,\n", "}" ] }, { "cell_type": "code", "execution_count": null, "id": "46590789-f484-4c46-a750-a6ed1b5b9553", "metadata": { "tags": [] }, "outputs": [], "source": [ "Si = jnp.arange(3, dtype=int)\n", "Sj = jnp.array([0, 1, 0], dtype=int)\n", "Sx = jnp.array([3.0, 4.0, 1.0])\n", "port_map = {\"in0\": 0, \"in1\": 2, \"out0\": 1}\n", "_scoo: sax.SCoo = (Si, Sj, Sx, port_map)" ] }, { 
"cell_type": "code", "execution_count": null, "id": "d9bdf2ae-e8b5-43a8-8082-8999b10127a4", "metadata": {}, "outputs": [], "source": [ "Sd = jnp.arange(9, dtype=float).reshape(3, 3)\n", "port_map = {\"in0\": 0, \"in1\": 2, \"out0\": 1}\n", "_sdense = Sd, port_map" ] }, { "cell_type": "code", "execution_count": null, "id": "ca61d288-3f7b-4fee-9ff1-7e9a0d431833", "metadata": {}, "outputs": [], "source": [ "assert not sax.is_sdict(object())\n", "assert sax.is_sdict(_sdict)\n", "assert not sax.is_sdict(_scoo)\n", "assert not sax.is_sdict(_sdense)" ] }, { "cell_type": "code", "execution_count": null, "id": "9ea0239a", "metadata": {}, "outputs": [], "source": [ "assert not sax.is_scoo(object)\n", "assert not sax.is_scoo(_sdict)\n", "assert sax.is_scoo(_scoo)\n", "assert not sax.is_scoo(_sdense)" ] }, { "cell_type": "code", "execution_count": null, "id": "2eddd63a", "metadata": {}, "outputs": [], "source": [ "assert not sax.is_sdense(object)\n", "assert not sax.is_sdense(_sdict)\n", "assert not sax.is_sdense(_scoo)\n", "assert sax.is_sdense(_sdense)" ] }, { "cell_type": "code", "execution_count": null, "id": "69afc7b0-fcfd-43cd-986d-1c4b3de6e03f", "metadata": {}, "outputs": [], "source": [ "def good_model(x=jnp.array(3.0), y=jnp.array(4.0)) -> sax.SDict:\n", " return {(\"in0\", \"out0\"): jnp.array(3.0)}" ] }, { "cell_type": "code", "execution_count": null, "id": "c12de3d7-9087-4724-966c-60bb7eb9f1de", "metadata": {}, "outputs": [], "source": [ "assert sax.is_model(good_model)" ] }, { "cell_type": "code", "execution_count": null, "id": "4cd2c9e5-85ff-49f6-afb1-da8dfe4ca592", "metadata": {}, "outputs": [], "source": [ "def bad_model(positional_argument, x=jnp.array(3.0), y=jnp.array(4.0)) -> sax.SDict:\n", " return {(\"in0\", \"out0\"): jnp.array(3.0)}" ] }, { "cell_type": "code", "execution_count": null, "id": "16bb13cb-03b8-4363-8604-83f8f00a5efd", "metadata": {}, "outputs": [], "source": [ "assert not sax.is_model(bad_model)" ] }, { "cell_type": "markdown", "id": 
"1f9715ed", "metadata": {}, "source": [ "> Note: For a `Callable` to be considered a `ModelFactory` in SAX, it **MUST** have a `Callable` or `Model` return annotation. Otherwise, SAX will treat it as a `Model`, and things might break!" ] }, { "cell_type": "code", "execution_count": null, "id": "fe9cdff0", "metadata": {}, "outputs": [], "source": [ "def func() -> sax.Model:\n", " ...\n", " \n", "assert sax.is_model_factory(func) # yes, we only check the annotation for now...\n", "\n", "def func():\n", " ...\n", " \n", "assert not sax.is_model_factory(func) # yes, we only check the annotation for now..." ] }, { "cell_type": "code", "execution_count": null, "id": "754399d5", "metadata": {}, "outputs": [], "source": [ "def good_model(x=jnp.array(3.0), y=jnp.array(4.0)) -> sax.SDict:\n", " return {(\"in0\", \"out0\"): jnp.array(3.0)}\n", "\n", "\n", "assert sax.validate_model(good_model) is None" ] }, { "cell_type": "code", "execution_count": null, "id": "181c72fa", "metadata": {}, "outputs": [], "source": [ "def bad_model(positional_argument, x=jnp.array(3.0), y=jnp.array(4.0)) -> sax.SDict:\n", " return {(\"in0\", \"out0\"): jnp.array(3.0)}\n", "\n", "\n", "with pytest.raises(ValueError):\n", " sax.validate_model(bad_model)" ] }, { "cell_type": "markdown", "id": "3afe685c", "metadata": {}, "source": [ "## SAX return type helpers\n", "\n", "> a.k.a. `SDict`, `SDense`, `SCoo` helpers" ] }, { "cell_type": "markdown", "id": "a875f149", "metadata": {}, "source": [ "Convert an `SDict`, `SCoo` or `SDense` into an `SDict` (or convert a model generating any of these types into a model generating an `SDict`):" ] }, { "cell_type": "code", "execution_count": null, "id": "9a372fbf", "metadata": {}, "outputs": [], "source": [ "assert sax.sdict(_sdict) is _sdict\n", "assert sax.sdict(_scoo) == {\n", " (\"in0\", \"in0\"): 3.0,\n", " (\"in1\", \"in0\"): 1.0,\n", " (\"out0\", \"out0\"): 4.0,\n", "}\n", "assert sax.sdict(_sdense) == {\n", " (\"in0\", \"in0\"): 0.0,\n", " (\"in0\", \"out0\"): 
1.0,\n", " (\"in0\", \"in1\"): 2.0,\n", " (\"out0\", \"in0\"): 3.0,\n", " (\"out0\", \"out0\"): 4.0,\n", " (\"out0\", \"in1\"): 5.0,\n", " (\"in1\", \"in0\"): 6.0,\n", " (\"in1\", \"out0\"): 7.0,\n", " (\"in1\", \"in1\"): 8.0,\n", "}" ] }, { "cell_type": "markdown", "id": "492c5cdd", "metadata": {}, "source": [ "Convert an `SDict`, `SCoo` or `SDense` into an `SCoo` (or convert a model generating any of these types into a model generating an `SCoo`):" ] }, { "cell_type": "code", "execution_count": null, "id": "2e409185-a9f6-4ea9-af09-1c184557f02d", "metadata": {}, "outputs": [], "source": [ "sax.scoo(_sdense)" ] }, { "cell_type": "code", "execution_count": null, "id": "1e97b31c", "metadata": {}, "outputs": [], "source": [ "assert sax.scoo(_scoo) is _scoo\n", "assert sax.scoo(_sdict) == (0, 1, 3.0, {\"in0\": 0, \"out0\": 1})\n", "Si, Sj, Sx, port_map = sax.scoo(_sdense) # type: ignore\n", "np.testing.assert_array_equal(Si, jnp.array([0, 0, 0, 1, 1, 1, 2, 2, 2]))\n", "np.testing.assert_array_equal(Sj, jnp.array([0, 1, 2, 0, 1, 2, 0, 1, 2]))\n", "np.testing.assert_array_almost_equal(Sx, jnp.array([0.0, 2.0, 1.0, 6.0, 8.0, 7.0, 3.0, 5.0, 4.0]))\n", "assert port_map == {\"in0\": 0, \"in1\": 1, \"out0\": 2}" ] }, { "cell_type": "markdown", "id": "5e58325b", "metadata": {}, "source": [ "Convert an `SDict`, `SCoo` or `SDense` into an `SDense` (or convert a model generating any of these types into a model generating an `SDense`):" ] }, { "cell_type": "code", "execution_count": null, "id": "084b7ddb", "metadata": {}, "outputs": [], "source": [ "assert sax.sdense(_sdense) is _sdense\n", "Sd, port_map = sax.sdense(_scoo) # type: ignore\n", "Sd_ = jnp.array([[3.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j],\n", " [0.0 + 0.0j, 4.0 + 0.0j, 0.0 + 0.0j],\n", " [1.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j]])\n", "\n", "np.testing.assert_array_almost_equal(Sd, Sd_)\n", "assert port_map == {\"in0\": 0, \"in1\": 2, \"out0\": 1}" ] } ], "metadata": { "kernelspec": { "display_name": "sax", "language": 
"python", "name": "sax" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 5 }