{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "6ecfd145",
   "metadata": {},
   "source": [
    "# Utils\n",
    "\n",
    "> General SAX utilities"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "697d19f7-0648-4b6f-b4a3-2f7795aed9f0",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import jax.numpy as jnp\n",
    "import matplotlib.pyplot as plt\n",
    "import pytest\n",
    "import sax"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e9b15b3",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "arr1 = 1 * jnp.ones((1, 2, 2))\n",
    "arr2 = 2 * jnp.ones((1, 3, 3))\n",
    "\n",
    "expected = jnp.array(\n",
    "    [\n",
    "        [\n",
    "            [1.0, 1.0, 0.0, 0.0, 0.0],\n",
    "            [1.0, 1.0, 0.0, 0.0, 0.0],\n",
    "            [0.0, 0.0, 2.0, 2.0, 2.0],\n",
    "            [0.0, 0.0, 2.0, 2.0, 2.0],\n",
    "            [0.0, 0.0, 2.0, 2.0, 2.0],\n",
    "        ]\n",
    "    ]\n",
    ")\n",
    "\n",
    "# jnp.array_equal also compares shapes: an elementwise `==` followed by\n",
    "# `.all()` would broadcast and silently accept e.g. a (5, 5) result.\n",
    "assert jnp.array_equal(sax.block_diag(arr1, arr2), expected)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cfe8bf00",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "assert sax.clean_string(\"Hello, string 1.0\") == \"Hello__string_1p0\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c7a72e4",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "orig_settings = {\"a\": 3, \"c\": jnp.array([9.0, 10.0, 11.0])}\n",
    "new_settings = sax.copy_settings(orig_settings)\n",
    "\n",
    "# the copy starts out equal to the original\n",
    "assert orig_settings[\"a\"] == new_settings[\"a\"]\n",
    "assert jnp.all(orig_settings[\"c\"] == new_settings[\"c\"])\n",
    "\n",
    "# rebinding a key in the copy leaves the original dict untouched\n",
    "new_settings[\"a\"] = jnp.array(5.0)\n",
    "assert orig_settings[\"a\"] == 3\n",
    "assert new_settings[\"a\"] == 5\n",
    "\n",
    "# array values are not duplicated: both dicts share the same array object\n",
    "assert orig_settings[\"c\"] is new_settings[\"c\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc454e7d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "nested_dict = {\n",
    "    \"a\": 3.0,\n",
    "    \"b\": {\"c\": 4.0},\n",
    "}\n",
    "\n",
    "flat_dict = sax.flatten_dict(nested_dict, sep=\",\")\n",
    "assert flat_dict == {\"a\": 3.0, \"b,c\": 4.0}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd8b221b",
"metadata": { "tags": [] }, "outputs": [], "source": [ "assert sax.unflatten_dict(flat_dict, sep=\",\") == nested_dict" ] }, { "cell_type": "code", "execution_count": null, "id": "c70af810", "metadata": { "tags": [] }, "outputs": [], "source": [ "sax.scoo({(\"in0\", \"out0\"): 1.0})" ] }, { "cell_type": "code", "execution_count": null, "id": "a61a5387-87fc-4f75-b415-1232d54f3fa2", "metadata": { "tags": [] }, "outputs": [], "source": [ "def coupler(coupling=0.5):\n", " return {\n", " (\"in0\", \"out0\"): coupling**0.5,\n", " (\"in0\", \"out1\"): 1j*coupling**0.5,\n", " (\"in1\", \"out0\"): 1j*coupling**0.5,\n", " (\"in1\", \"out1\"): coupling**0.5,\n", " }" ] }, { "cell_type": "code", "execution_count": null, "id": "dc2eecbe", "metadata": { "tags": [] }, "outputs": [], "source": [ "model = coupler\n", "assert sax.get_ports(model) == (\"in0\", \"in1\", \"out0\", \"out1\")\n", "\n", "sdict_ = coupler()\n", "assert sax.get_ports(sdict_) == (\"in0\", \"in1\", \"out0\", \"out1\")\n", "\n", "scoo_ = sax.scoo(sdict_)\n", "assert sax.get_ports(scoo_) == (\"in0\", \"in1\", \"out0\", \"out1\")\n", "\n", "sdense_ = sax.sdense(sdict_)\n", "assert sax.get_ports(sdense_) == (\"in0\", \"in1\", \"out0\", \"out1\")" ] }, { "cell_type": "code", "execution_count": null, "id": "cf00a0c0", "metadata": { "tags": [] }, "outputs": [], "source": [ "model = coupler\n", "assert sax.get_port_combinations(model) == (\n", " (\"in0\", \"out0\"),\n", " (\"in0\", \"out1\"),\n", " (\"in1\", \"out0\"),\n", " (\"in1\", \"out1\"),\n", ")\n", "\n", "sdict_ = coupler()\n", "assert sax.get_port_combinations(sdict_) == (\n", " (\"in0\", \"out0\"),\n", " (\"in0\", \"out1\"),\n", " (\"in1\", \"out0\"),\n", " (\"in1\", \"out1\"),\n", ")\n", "\n", "scoo_ = sax.scoo(sdict_)\n", "assert sax.get_port_combinations(scoo_) == (\n", " (\"in0\", \"out0\"),\n", " (\"in0\", \"out1\"),\n", " (\"in1\", \"out0\"),\n", " (\"in1\", \"out1\"),\n", ")\n", "\n", "sdense_ = sax.sdense(sdict_)\n", "assert 
sax.get_port_combinations(sdense_) == (\n", " (\"in0\", \"in0\"),\n", " (\"in0\", \"in1\"),\n", " (\"in0\", \"out0\"),\n", " (\"in0\", \"out1\"),\n", " (\"in1\", \"in0\"),\n", " (\"in1\", \"in1\"),\n", " (\"in1\", \"out0\"),\n", " (\"in1\", \"out1\"),\n", " (\"out0\", \"in0\"),\n", " (\"out0\", \"in1\"),\n", " (\"out0\", \"out0\"),\n", " (\"out0\", \"out1\"),\n", " (\"out1\", \"in0\"),\n", " (\"out1\", \"in1\"),\n", " (\"out1\", \"out0\"),\n", " (\"out1\", \"out1\"),\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "2deefdc8", "metadata": { "tags": [] }, "outputs": [], "source": [ "assert sax.get_settings(coupler) == {'coupling': 0.5}" ] }, { "cell_type": "code", "execution_count": null, "id": "8e06bfe4", "metadata": { "tags": [] }, "outputs": [], "source": [ "# hide\n", "\n", "wls = jnp.array([2.19999, 2.20001, 2.22499, 2.22501, 2.24999, 2.25001, 2.27499, 2.27501, 2.29999, 2.30001, 2.32499, 2.32501, 2.34999, 2.35001, 2.37499, 2.37501, 2.39999, 2.40001, 2.42499, 2.42501, 2.44999, 2.45001])\n", "phis = jnp.array([5.17317336, 5.1219654, 4.71259842, 4.66252492, 5.65699608, 5.60817922, 2.03697377, 1.98936119, 6.010146, 5.96358061, 4.96336733, 4.91777933, 5.13912198, 5.09451137, 0.22347545, 0.17979684, 2.74501894, 2.70224092, 0.10403192, 0.06214664, 4.83328794, 4.79225525])\n", "wl = jnp.array([2.21, 2.27, 1.31, 2.424])\n", "phi = jnp.array(sax.grouped_interp(wl, wls, phis))\n", "phi_ref = jnp.array([-1.4901831, 1.3595749, -1.110012 , 2.1775336])\n", "\n", "assert ((phi-phi_ref)**2 < 1e-5).all()" ] }, { "cell_type": "code", "execution_count": null, "id": "27cf74d7", "metadata": { "tags": [] }, "outputs": [], "source": [ "d = sax.merge_dicts({\"a\": 3}, {\"b\": 4})\n", "assert d[\"a\"] == 3\n", "assert d[\"b\"] == 4\n", "assert tuple(sorted(d)) == (\"a\", \"b\")\n", "\n", "d = sax.merge_dicts({\"a\": 3}, {\"a\": 4})\n", "assert d[\"a\"] == 4\n", "assert tuple(d) == (\"a\",)\n", "\n", "d = sax.merge_dicts({\"a\": 3}, {\"a\": {\"b\": 5}})\n", "assert 
d[\"a\"][\"b\"] == 5\n", "assert tuple(d) == (\"a\",)\n", "\n", "d = sax.merge_dicts({\"a\": {\"b\": 5}}, {\"a\": 3})\n", "assert d[\"a\"] == 3\n", "assert tuple(d) == (\"a\",)" ] }, { "cell_type": "code", "execution_count": null, "id": "5010a982", "metadata": { "tags": [] }, "outputs": [], "source": [ "assert sax.mode_combinations(modes=[\"te\", \"tm\"]) == (('te', 'te'), ('tm', 'tm'))\n", "assert sax.mode_combinations(modes=[\"te\", \"tm\"], cross=True) == (('te', 'te'), ('te', 'tm'), ('tm', 'te'), ('tm', 'tm'))" ] }, { "cell_type": "code", "execution_count": null, "id": "30a304a7", "metadata": { "tags": [] }, "outputs": [], "source": [ "sdict_ = {(\"in0\", \"out0\"): 1.0}\n", "assert sax.reciprocal(sdict_) == {(\"in0\", \"out0\"): 1.0, (\"out0\", \"in0\"): 1.0}" ] }, { "cell_type": "code", "execution_count": null, "id": "7b041b23", "metadata": { "tags": [] }, "outputs": [], "source": [ "def model(x=jnp.array(3.0), y=jnp.array(4.0), z=jnp.array([3.0, 4.0])) -> sax.SDict:\n", " return {(\"in0\", \"out0\"): jnp.array(3.0)}\n", "\n", "renamings = {\"x\": \"a\", \"y\": \"z\", \"z\": \"y\"}\n", "new_model = sax.rename_params(model, renamings)\n", "settings = sax.get_settings(new_model)\n", "assert settings[\"a\"] == 3.0\n", "assert settings[\"z\"] == 4.0\n", "assert jnp.all(settings[\"y\"] == jnp.array([3.0, 4.0]))" ] }, { "cell_type": "code", "execution_count": null, "id": "dd9c42ac", "metadata": { "tags": [] }, "outputs": [], "source": [ "d = sax.reciprocal({(\"p0\", \"p1\"): 0.1, (\"p1\", \"p2\"): 0.2})\n", "origports = sax.get_ports(d)\n", "renamings = {\"p0\": \"in0\", \"p1\": \"out0\", \"p2\": \"in1\"}\n", "d_ = sax.rename_ports(d, renamings)\n", "assert tuple(sorted(sax.get_ports(d_))) == tuple(sorted(renamings[p] for p in origports))\n", "d_ = sax.rename_ports(sax.scoo(d), renamings)\n", "assert tuple(sorted(sax.get_ports(d_))) == tuple(sorted(renamings[p] for p in origports))\n", "d_ = sax.rename_ports(sax.sdense(d), renamings)\n", "assert 
tuple(sorted(sax.get_ports(d_))) == tuple(sorted(renamings[p] for p in origports))" ] }, { "cell_type": "markdown", "id": "a7642618", "metadata": {}, "source": [ "Assuming you have a settings dictionary for a `circuit` containing a directional coupler `\"dc\"` and a waveguide `\"wg\"`:" ] }, { "cell_type": "code", "execution_count": null, "id": "85293a23", "metadata": { "tags": [] }, "outputs": [], "source": [ "settings = {\"wl\": 1.55, \"dc\": {\"coupling\": 0.5}, \"wg\": {\"wl\": 1.56, \"neff\": 2.33}}" ] }, { "cell_type": "markdown", "id": "68c1cd3b", "metadata": {}, "source": [ "You can update this settings dictionary with some global settings as follows. When updating settings globally like this, each subdictionary of the settings dictionary will be updated with these values (if the key exists in the subdictionary):" ] }, { "cell_type": "code", "execution_count": null, "id": "4a1aaf3c", "metadata": { "tags": [] }, "outputs": [], "source": [ "settings = sax.update_settings(settings, wl=1.3, coupling=0.3, neff=3.0)\n", "assert settings == {\"wl\": 1.3, \"dc\": {\"coupling\": 0.3}, \"wg\": {\"wl\": 1.3, \"neff\": 3.0}}" ] }, { "cell_type": "markdown", "id": "73737b1f", "metadata": {}, "source": [ "Alternatively, you can set certain settings for a specific component (e.g. 'wg' in this case) as follows:" ] }, { "cell_type": "code", "execution_count": null, "id": "87c32103", "metadata": { "tags": [] }, "outputs": [], "source": [ "settings = sax.update_settings(settings, \"wg\", wl=2.0)\n", "assert settings == {\"wl\": 1.3, \"dc\": {\"coupling\": 0.3}, \"wg\": {\"wl\": 2.0, \"neff\": 3.0}}" ] }, { "cell_type": "markdown", "id": "47d3b775", "metadata": {}, "source": [ "note that only the `\"wl\"` belonging to `\"wg\"` has changed." 
]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b47e00d4",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# no mode suffixes anywhere: accepted\n",
    "sdict = {(\"in0\", \"out0\"): 1.0, (\"out0\", \"in0\"): 1.0}\n",
    "sax.validate_not_mixedmode(sdict)\n",
    "\n",
    "# every port carries a mode suffix: accepted\n",
    "sdict = {(\"in0@te\", \"out0@te\"): 1.0, (\"out0@tm\", \"in0@tm\"): 1.0}\n",
    "sax.validate_not_mixedmode(sdict)\n",
    "\n",
    "# ports with and without a mode suffix mixed together: rejected\n",
    "sdict = {(\"in0@te\", \"out0@te\"): 1.0, (\"out0\", \"in0@tm\"): 1.0}\n",
    "with pytest.raises(ValueError):\n",
    "    sax.validate_not_mixedmode(sdict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41fb0142",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# no mode suffixes: not multimode, rejected\n",
    "sdict = {(\"in0\", \"out0\"): 1.0, (\"out0\", \"in0\"): 1.0}\n",
    "with pytest.raises(ValueError):\n",
    "    sax.validate_multimode(sdict)\n",
    "\n",
    "# every port carries a mode suffix: accepted\n",
    "sdict = {(\"in0@te\", \"out0@te\"): 1.0, (\"out0@tm\", \"in0@tm\"): 1.0}\n",
    "sax.validate_multimode(sdict)\n",
    "\n",
    "# mixed suffixed / unsuffixed ports: rejected\n",
    "sdict = {(\"in0@te\", \"out0@te\"): 1.0, (\"out0\", \"in0@tm\"): 1.0}\n",
    "with pytest.raises(ValueError):\n",
    "    sax.validate_multimode(sdict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe25c362",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "good_sdict = sax.reciprocal(\n",
    "    {\n",
    "        (\"p0\", \"p1\"): 0.1,\n",
    "        (\"p1\", \"p2\"): 0.2,\n",
    "    }\n",
    ")\n",
    "assert sax.validate_sdict(good_sdict) is None\n",
    "\n",
    "# a comma-joined string key instead of a (port, port) tuple is rejected\n",
    "bad_sdict = {\n",
    "    \"p0,p1\": 0.1,\n",
    "    (\"p1\", \"p2\"): 0.2,\n",
    "}\n",
    "with pytest.raises(ValueError):\n",
    "    sax.validate_sdict(bad_sdict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9ba032d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# each call returns an (inputs, outputs) pair of port-name tuples\n",
    "assert sax.get_inputs_outputs([\"in0\", \"out0\"]) == (('in0',), ('out0',))\n",
    "assert sax.get_inputs_outputs([\"in0\", \"in1\"]) == (('in0', 'in1'), ())\n",
    "assert sax.get_inputs_outputs([\"out0\", \"out1\"]) == ((), ('out0', 'out1'))\n",
    "assert sax.get_inputs_outputs([\"out0\", \"dc0\"]) == (('dc0',), ('out0',))\n",
    "assert sax.get_inputs_outputs([\"dc0\", \"in0\"]) == (('in0',), ('dc0',))"
   ]
  }
 ],
"metadata": { "kernelspec": { "display_name": "sax", "language": "python", "name": "sax" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 5 }