{ "cells": [ { "cell_type": "markdown", "id": "ab894499", "metadata": { "papermill": { "duration": 0.010359, "end_time": "2023-09-03T04:40:07.944293", "exception": false, "start_time": "2023-09-03T04:40:07.933934", "status": "completed" }, "tags": [] }, "source": [ "# SAX Quick Start\n", "> Let's go over the core functionality of SAX." ] }, { "cell_type": "markdown", "id": "3c91396d", "metadata": { "papermill": { "duration": 0.021667, "end_time": "2023-09-03T04:40:08.076325", "exception": false, "start_time": "2023-09-03T04:40:08.054658", "status": "completed" }, "tags": [] }, "source": [ "## Environment variables" ] }, { "cell_type": "markdown", "id": "79240f32", "metadata": { "papermill": { "duration": 0.00901, "end_time": "2023-09-03T04:40:08.094208", "exception": false, "start_time": "2023-09-03T04:40:08.085198", "status": "completed" }, "tags": [] }, "source": [ "SAX is based on JAX... here are some useful environment variables for working with JAX:" ] }, { "cell_type": "code", "execution_count": null, "id": "40ce30d2", "metadata": { "papermill": { "duration": 0.491056, "end_time": "2023-09-03T04:40:08.594137", "exception": false, "start_time": "2023-09-03T04:40:08.103081", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "# select float32 or float64 as default dtype\n", "%env JAX_ENABLE_X64=0\n", "\n", "# select cpu or gpu\n", "%env JAX_PLATFORM_NAME=cpu\n", "\n", "# set custom CUDA location for gpu:\n", "%env XLA_FLAGS=--xla_gpu_cuda_data_dir=/usr/lib/cuda\n", "\n", "# Using GPU?\n", "from jax.lib import xla_bridge\n", "print(xla_bridge.get_backend().platform)" ] }, { "cell_type": "markdown", "id": "4938b729-8b27-406f-b5fd-e118a9fb39f1", "metadata": { "papermill": { "duration": 0.017299, "end_time": "2023-09-03T04:40:08.621013", "exception": false, "start_time": "2023-09-03T04:40:08.603714", "status": "completed" }, "tags": [] }, "source": [ "## Imports" ] }, { "cell_type": "code", "execution_count": null, "id": "e0450fde", "metadata": { 
"papermill": { "duration": 1.389344, "end_time": "2023-09-03T04:40:10.020833", "exception": false, "start_time": "2023-09-03T04:40:08.631489", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "import jax\n", "import jax.example_libraries.optimizers as opt\n", "import jax.numpy as jnp\n", "import matplotlib.pyplot as plt # plotting\n", "import sax\n", "from tqdm.notebook import trange" ] }, { "cell_type": "markdown", "id": "8c4d4c4e", "metadata": { "papermill": { "duration": 0.009958, "end_time": "2023-09-03T04:40:10.041225", "exception": false, "start_time": "2023-09-03T04:40:10.031267", "status": "completed" }, "tags": [] }, "source": [ "## Scatter *dictionaries*\n", "The core datastructure for specifying scatter parameters in SAX is a dictionary... more specifically a dictionary which maps a port combination (2-tuple) to a scatter parameter (or an array of scatter parameters when considering multiple wavelengths for example). Such a specific dictionary mapping is called ann `SDict` in SAX (`SDict ≈ Dict[Tuple[str,str], float]`).\n", "\n", "Dictionaries are in fact much better suited for characterizing S-parameters than, say, (jax-)numpy arrays due to the inherent sparse nature of scatter parameters. Moreover, dictonaries allow for string indexing, which makes them much more pleasant to use in this context. 
For example, let's create an `SDict` for a 50/50 coupler:" ] }, { "cell_type": "markdown", "id": "a7f90205", "metadata": { "papermill": { "duration": 0.010042, "end_time": "2023-09-03T04:40:10.061689", "exception": false, "start_time": "2023-09-03T04:40:10.051647", "status": "completed" }, "tags": [] }, "source": [ "```\n", "in1          out1\n", "   \\        /\n", "    ========\n", "   /        \\\n", "in0          out0\n", "```" ] }, { "cell_type": "code", "execution_count": null, "id": "89f215e8", "metadata": { "papermill": { "duration": 0.023886, "end_time": "2023-09-03T04:40:10.095541", "exception": false, "start_time": "2023-09-03T04:40:10.071655", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "coupling = 0.5\n", "kappa = coupling ** 0.5\n", "tau = (1 - coupling) ** 0.5\n", "coupler_dict = {\n", "    (\"in0\", \"out0\"): tau,\n", "    (\"out0\", \"in0\"): tau,\n", "    (\"in0\", \"out1\"): 1j * kappa,\n", "    (\"out1\", \"in0\"): 1j * kappa,\n", "    (\"in1\", \"out0\"): 1j * kappa,\n", "    (\"out0\", \"in1\"): 1j * kappa,\n", "    (\"in1\", \"out1\"): tau,\n", "    (\"out1\", \"in1\"): tau,\n", "}\n", "coupler_dict" ] }, { "cell_type": "markdown", "id": "bb9cb2b6", "metadata": { "papermill": { "duration": 0.009826, "end_time": "2023-09-03T04:40:10.114856", "exception": false, "start_time": "2023-09-03T04:40:10.105030", "status": "completed" }, "tags": [] }, "source": [ "Only the non-zero port combinations need to be specified. Any non-existent port combination (for example `(\"in0\", \"in1\")`) is considered to be zero by SAX.\n", "\n", "Still, it can be tedious to specify every port combination manually. SAX therefore offers the `reciprocal` function, which auto-fills the reverse connection if the forward connection exists. 
For example:" ] }, { "cell_type": "code", "execution_count": null, "id": "494f445d", "metadata": { "papermill": { "duration": 0.018928, "end_time": "2023-09-03T04:40:10.145552", "exception": false, "start_time": "2023-09-03T04:40:10.126624", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "coupler_dict = sax.reciprocal(\n", "    {\n", "        (\"in0\", \"out0\"): tau,\n", "        (\"in0\", \"out1\"): 1j * kappa,\n", "        (\"in1\", \"out0\"): 1j * kappa,\n", "        (\"in1\", \"out1\"): tau,\n", "    }\n", ")\n", "\n", "coupler_dict" ] }, { "cell_type": "markdown", "id": "67e730f6", "metadata": { "papermill": { "duration": 0.010589, "end_time": "2023-09-03T04:40:10.166287", "exception": false, "start_time": "2023-09-03T04:40:10.155698", "status": "completed" }, "tags": [] }, "source": [ "## Parametrized Models" ] }, { "cell_type": "markdown", "id": "ec1e1214", "metadata": { "papermill": { "duration": 0.0097, "end_time": "2023-09-03T04:40:10.185857", "exception": false, "start_time": "2023-09-03T04:40:10.176157", "status": "completed" }, "tags": [] }, "source": [ "Constructing such an `SDict` is easy; usually, however, we're more interested in having parametrized models for our components. 
To parametrize the coupler `SDict`, just wrap it in a function to obtain a SAX `Model`, which is a keyword-only function mapping to an `SDict`:" ] }, { "cell_type": "code", "execution_count": null, "id": "33efb99f", "metadata": { "papermill": { "duration": 0.026393, "end_time": "2023-09-03T04:40:10.222059", "exception": false, "start_time": "2023-09-03T04:40:10.195666", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "def coupler(coupling=0.5) -> sax.SDict:\n", " kappa = coupling ** 0.5\n", " tau = (1 - coupling) ** 0.5\n", " coupler_dict = sax.reciprocal(\n", " {\n", " (\"in0\", \"out0\"): tau,\n", " (\"in0\", \"out1\"): 1j * kappa,\n", " (\"in1\", \"out0\"): 1j * kappa,\n", " (\"in1\", \"out1\"): tau,\n", " }\n", " )\n", " return coupler_dict\n", "\n", "\n", "coupler(coupling=0.3)" ] }, { "cell_type": "markdown", "id": "4ef93f65", "metadata": { "papermill": { "duration": 0.032631, "end_time": "2023-09-03T04:40:10.264877", "exception": false, "start_time": "2023-09-03T04:40:10.232246", "status": "completed" }, "tags": [] }, "source": [ "We can define a waveguide in the same way:" ] }, { "cell_type": "code", "execution_count": null, "id": "5eac9176", "metadata": { "papermill": { "duration": 0.018719, "end_time": "2023-09-03T04:40:10.310448", "exception": false, "start_time": "2023-09-03T04:40:10.291729", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "def waveguide(wl=1.55, wl0=1.55, neff=2.34, ng=3.4, length=10.0, loss=0.0) -> sax.SDict:\n", " dwl = wl - wl0\n", " dneff_dwl = (ng - neff) / wl0\n", " neff = neff - dwl * dneff_dwl\n", " phase = 2 * jnp.pi * neff * length / wl\n", " transmission = 10 ** (-loss * length / 20) * jnp.exp(1j * phase)\n", " sdict = sax.reciprocal(\n", " {\n", " (\"in0\", \"out0\"): transmission,\n", " }\n", " )\n", " return sdict" ] }, { "cell_type": "markdown", "id": "67713c18", "metadata": { "papermill": { "duration": 0.038693, "end_time": "2023-09-03T04:40:10.358565", "exception": false, 
"start_time": "2023-09-03T04:40:10.319872", "status": "completed" }, "tags": [] }, "source": [ "That's pretty straightforward. Let's now move on to parametrized circuits:" ] }, { "cell_type": "markdown", "id": "579a9ce1", "metadata": { "papermill": { "duration": 0.027416, "end_time": "2023-09-03T04:40:10.395287", "exception": false, "start_time": "2023-09-03T04:40:10.367871", "status": "completed" }, "tags": [] }, "source": [ "## Circuit Models" ] }, { "cell_type": "markdown", "id": "cc3c78a0", "metadata": { "papermill": { "duration": 0.010096, "end_time": "2023-09-03T04:40:10.415180", "exception": false, "start_time": "2023-09-03T04:40:10.405084", "status": "completed" }, "tags": [] }, "source": [ "Existing models can now be combined into a circuit using `sax.circuit`, which basically creates a new `Model` function:" ] }, { "cell_type": "code", "execution_count": null, "id": "d6885e1d", "metadata": { "papermill": { "duration": 0.831818, "end_time": "2023-09-03T04:40:11.257250", "exception": false, "start_time": "2023-09-03T04:40:10.425432", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "mzi, info = sax.circuit(\n", " netlist={\n", " \"instances\": {\n", " \"lft\": \"coupler\",\n", " \"top\": \"waveguide\",\n", " \"btm\": \"waveguide\",\n", " \"rgt\": \"coupler\",\n", " },\n", " \"connections\": {\n", " \"lft,out0\": \"btm,in0\",\n", " \"btm,out0\": \"rgt,in0\",\n", " \"lft,out1\": \"top,in0\",\n", " \"top,out0\": \"rgt,in1\",\n", " },\n", " \"ports\": {\n", " \"in0\": \"lft,in0\",\n", " \"in1\": \"lft,in1\",\n", " \"out0\": \"rgt,out0\",\n", " \"out1\": \"rgt,out1\",\n", " },\n", " },\n", " models={\n", " \"coupler\": coupler,\n", " \"waveguide\": waveguide,\n", " }\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "e95325c6", "metadata": { "papermill": { "duration": 0.045163, "end_time": "2023-09-03T04:40:11.308446", "exception": false, "start_time": "2023-09-03T04:40:11.263283", "status": "completed" }, "tags": [] }, 
"outputs": [], "source": [ "mzi?" ] }, { "cell_type": "markdown", "id": "d2c74cff", "metadata": { "papermill": { "duration": 0.009796, "end_time": "2023-09-03T04:40:11.327604", "exception": false, "start_time": "2023-09-03T04:40:11.317808", "status": "completed" }, "tags": [] }, "source": [ "The `circuit` function just creates a similar function as we created for the waveguide and the coupler, but in stead of taking parameters directly it takes parameter *dictionaries* for each of the instances in the circuit. The keys in these parameter dictionaries should correspond to the keyword arguments of each individual subcomponent. " ] }, { "cell_type": "markdown", "id": "ba1202d1", "metadata": { "papermill": { "duration": 0.009578, "end_time": "2023-09-03T04:40:11.346468", "exception": false, "start_time": "2023-09-03T04:40:11.336890", "status": "completed" }, "tags": [] }, "source": [ "Let's now do a simulation for the MZI we just constructed:" ] }, { "cell_type": "code", "execution_count": null, "id": "f2108f6a", "metadata": { "papermill": { "duration": 0.751356, "end_time": "2023-09-03T04:40:12.107394", "exception": false, "start_time": "2023-09-03T04:40:11.356038", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "%time mzi()" ] }, { "cell_type": "code", "execution_count": null, "id": "cb43b835", "metadata": { "papermill": { "duration": 0.011038, "end_time": "2023-09-03T04:40:12.124654", "exception": false, "start_time": "2023-09-03T04:40:12.113616", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "mzi2 = jax.jit(mzi)" ] }, { "cell_type": "code", "execution_count": null, "id": "61f2bb6e", "metadata": { "papermill": { "duration": 0.379308, "end_time": "2023-09-03T04:40:12.510922", "exception": false, "start_time": "2023-09-03T04:40:12.131614", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "%time mzi2()" ] }, { "cell_type": "code", "execution_count": null, "id": "6000f608-6814-436d-af02-a54dc7d0da85", "metadata": { 
"papermill": { "duration": 0.028438, "end_time": "2023-09-03T04:40:12.550611", "exception": false, "start_time": "2023-09-03T04:40:12.522173", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "%time mzi2()" ] }, { "cell_type": "markdown", "id": "1e0d41ea", "metadata": { "papermill": { "duration": 0.014407, "end_time": "2023-09-03T04:40:12.578298", "exception": false, "start_time": "2023-09-03T04:40:12.563891", "status": "completed" }, "tags": [] }, "source": [ "Or in the case we want an MZI with different arm lengths:" ] }, { "cell_type": "code", "execution_count": null, "id": "4f53bdc3", "metadata": { "papermill": { "duration": 0.123415, "end_time": "2023-09-03T04:40:12.713195", "exception": false, "start_time": "2023-09-03T04:40:12.589780", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "mzi(top={\"length\": 25.0}, btm={\"length\": 15.0})" ] }, { "cell_type": "markdown", "id": "a331252a", "metadata": { "papermill": { "duration": 0.006598, "end_time": "2023-09-03T04:40:12.730584", "exception": false, "start_time": "2023-09-03T04:40:12.723986", "status": "completed" }, "tags": [] }, "source": [ "## Simulating the parametrized MZI" ] }, { "cell_type": "markdown", "id": "4ba8626b", "metadata": { "papermill": { "duration": 0.006519, "end_time": "2023-09-03T04:40:12.743613", "exception": false, "start_time": "2023-09-03T04:40:12.737094", "status": "completed" }, "tags": [] }, "source": [ "We can simulate the above mzi for multiple wavelengths as well by specifying the wavelength at the top level of the circuit call. 
Each setting specified at the top level of the circuit call will be propagated to all subcomponents of the circuit that accept that setting:" ] }, { "cell_type": "code", "execution_count": null, "id": "d6c8b5c9", "metadata": { "papermill": { "duration": 0.536233, "end_time": "2023-09-03T04:40:13.286380", "exception": false, "start_time": "2023-09-03T04:40:12.750147", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "wl = jnp.linspace(1.51, 1.59, 1000)\n", "%time S = mzi(wl=wl, top={\"length\": 25.0}, btm={\"length\": 15.0})" ] }, { "cell_type": "markdown", "id": "28d8cc75", "metadata": { "papermill": { "duration": 0.010658, "end_time": "2023-09-03T04:40:13.308604", "exception": false, "start_time": "2023-09-03T04:40:13.297946", "status": "completed" }, "tags": [] }, "source": [ "Let's see what this gives:" ] }, { "cell_type": "code", "execution_count": null, "id": "f7de73dc", "metadata": { "papermill": { "duration": 0.527485, "end_time": "2023-09-03T04:40:13.847003", "exception": false, "start_time": "2023-09-03T04:40:13.319518", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "plt.plot(wl * 1e3, abs(S[\"in0\", \"out0\"]) ** 2)\n", "plt.xlabel(\"λ [nm]\")\n", "plt.ylabel(\"T\")\n", "plt.ylim(-0.05, 1.05)\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "eb9f0074", "metadata": { "papermill": { "duration": 0.011082, "end_time": "2023-09-03T04:40:13.869632", "exception": false, "start_time": "2023-09-03T04:40:13.858550", "status": "completed" }, "tags": [] }, "source": [ "## Optimization" ] }, { "cell_type": "markdown", "id": "431aee51", "metadata": { "papermill": { "duration": 0.011031, "end_time": "2023-09-03T04:40:13.891520", "exception": false, "start_time": "2023-09-03T04:40:13.880489", "status": "completed" }, "tags": [] }, "source": [ "We'd like to optimize an MZI such that one of the minima is at 1550 nm. To do this, we need to define a loss function for the circuit at 1550 nm. 
This function should take the parameters that you want to optimize as positional arguments:" ] }, { "cell_type": "code", "execution_count": null, "id": "7dabe735", "metadata": { "papermill": { "duration": 0.014642, "end_time": "2023-09-03T04:40:13.917117", "exception": false, "start_time": "2023-09-03T04:40:13.902475", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "@jax.jit\n", "def loss_fn(delta_length):\n", " S = mzi(wl=1.55, top={\"length\": 15.0 + delta_length}, btm={\"length\": 15.0})\n", " return (abs(S[\"in0\", \"out0\"]) ** 2).mean()" ] }, { "cell_type": "code", "execution_count": null, "id": "a60a8eab", "metadata": { "papermill": { "duration": 0.282056, "end_time": "2023-09-03T04:40:14.205914", "exception": false, "start_time": "2023-09-03T04:40:13.923858", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "%time loss_fn(10.0)" ] }, { "cell_type": "markdown", "id": "af1eb06b", "metadata": { "papermill": { "duration": 0.010349, "end_time": "2023-09-03T04:40:14.227615", "exception": false, "start_time": "2023-09-03T04:40:14.217266", "status": "completed" }, "tags": [] }, "source": [ "We can use this loss function to define a grad function which works on the parameters of the loss function:" ] }, { "cell_type": "code", "execution_count": null, "id": "16369cba", "metadata": { "papermill": { "duration": 0.013245, "end_time": "2023-09-03T04:40:14.247411", "exception": false, "start_time": "2023-09-03T04:40:14.234166", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "grad_fn = jax.jit(\n", " jax.grad(\n", " loss_fn,\n", " argnums=0, # JAX gradient function for the first positional argument, jitted\n", " )\n", ")" ] }, { "cell_type": "markdown", "id": "2cd2f2cf", "metadata": { "papermill": { "duration": 0.006893, "end_time": "2023-09-03T04:40:14.261155", "exception": false, "start_time": "2023-09-03T04:40:14.254262", "status": "completed" }, "tags": [] }, "source": [ "Next, we need to define a JAX optimizer, 
which on its own is nothing more than three functions: an initialization function that creates the optimizer state, an update function that updates the optimizer state (and with it the model parameters), and a third function that returns the model parameters for a given optimizer state." ] }, { "cell_type": "code", "execution_count": null, "id": "e553d69e", "metadata": { "papermill": { "duration": 0.013543, "end_time": "2023-09-03T04:40:14.281648", "exception": false, "start_time": "2023-09-03T04:40:14.268105", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "initial_delta_length = 10.0\n", "init_fn, update_fn, params_fn = opt.adam(step_size=0.1)\n", "state = init_fn(initial_delta_length)" ] }, { "cell_type": "markdown", "id": "e0a03fc2", "metadata": { "papermill": { "duration": 0.006686, "end_time": "2023-09-03T04:40:14.295129", "exception": false, "start_time": "2023-09-03T04:40:14.288443", "status": "completed" }, "tags": [] }, "source": [ "Given all this, a single training step can be defined:" ] }, { "cell_type": "code", "execution_count": null, "id": "eea4f7b8", "metadata": { "papermill": { "duration": 0.013079, "end_time": "2023-09-03T04:40:14.315667", "exception": false, "start_time": "2023-09-03T04:40:14.302588", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "def step_fn(step, state):\n", "    settings = params_fn(state)\n", "    loss = loss_fn(settings)\n", "    grad = grad_fn(settings)\n", "    state = update_fn(step, grad, state)\n", "    return loss, state" ] }, { "cell_type": "markdown", "id": "ba1612df", "metadata": { "papermill": { "duration": 0.006976, "end_time": "2023-09-03T04:40:14.329838", "exception": false, "start_time": "2023-09-03T04:40:14.322862", "status": "completed" }, "tags": [] }, "source": [ "And we can use this step function to start the training of the MZI:" ] }, { "cell_type": "code", "execution_count": null, "id": "61ef4479", "metadata": { "papermill": { "duration": 
5.76803, "end_time": "2023-09-03T04:40:20.104912", "exception": false, "start_time": "2023-09-03T04:40:14.336882", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "range_ = trange(300)\n", "for step in range_:\n", " loss, state = step_fn(step, state)\n", " range_.set_postfix(loss=f\"{loss:.6f}\")" ] }, { "cell_type": "code", "execution_count": null, "id": "eb9a1e20", "metadata": { "papermill": { "duration": 0.015812, "end_time": "2023-09-03T04:40:20.132271", "exception": false, "start_time": "2023-09-03T04:40:20.116459", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "delta_length = params_fn(state)\n", "delta_length" ] }, { "cell_type": "markdown", "id": "15d8754b", "metadata": { "papermill": { "duration": 0.006808, "end_time": "2023-09-03T04:40:20.146325", "exception": false, "start_time": "2023-09-03T04:40:20.139517", "status": "completed" }, "tags": [] }, "source": [ "Let's see what we've got over a range of wavelengths:" ] }, { "cell_type": "code", "execution_count": null, "id": "b36d10a3", "metadata": { "papermill": { "duration": 0.197916, "end_time": "2023-09-03T04:40:20.351213", "exception": false, "start_time": "2023-09-03T04:40:20.153297", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "S = mzi(wl=wl, top={\"length\": 15.0 + delta_length}, btm={\"length\": 15.0})\n", "plt.plot(wl * 1e3, abs(S[\"in1\", \"out1\"]) ** 2)\n", "plt.xlabel(\"λ [nm]\")\n", "plt.ylabel(\"T\")\n", "plt.ylim(-0.05, 1.05)\n", "plt.plot([1550, 1550], [0, 1])\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "a990da96", "metadata": { "papermill": { "duration": 0.007276, "end_time": "2023-09-03T04:40:20.366561", "exception": false, "start_time": "2023-09-03T04:40:20.359285", "status": "completed" }, "tags": [] }, "source": [ "The minimum of the MZI is perfectly located at 1550nm." 
] }, { "cell_type": "markdown", "id": "7d93aa85", "metadata": { "papermill": { "duration": 0.007199, "end_time": "2023-09-03T04:40:20.381013", "exception": false, "start_time": "2023-09-03T04:40:20.373814", "status": "completed" }, "tags": [] }, "source": [ "## MZI Chain" ] }, { "cell_type": "markdown", "id": "9c312e13", "metadata": { "papermill": { "duration": 0.007248, "end_time": "2023-09-03T04:40:20.396255", "exception": false, "start_time": "2023-09-03T04:40:20.389007", "status": "completed" }, "tags": [] }, "source": [ "Let's now create a chain of MZIs. For this, we first create a subcomponent: a directional coupler with arms:\n", "\n", "\n", "```\n", " top\n", " in0 ----- out0 -> out1\n", " in1 <- in1 out1 \n", " \\ dc / \n", " ====== \n", " / \\ \n", " in0 <- in0 out0 btm \n", " in0 ----- out0 -> out0\n", "```" ] }, { "cell_type": "code", "execution_count": null, "id": "1ebe80c2", "metadata": { "papermill": { "duration": 0.29058, "end_time": "2023-09-03T04:40:20.694616", "exception": false, "start_time": "2023-09-03T04:40:20.404036", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "dc_with_arms, info = sax.circuit(\n", " netlist={\n", " \"instances\": {\n", " \"lft\": \"coupler\",\n", " \"top\": \"waveguide\",\n", " \"btm\": \"waveguide\",\n", " },\n", " \"connections\": {\n", " \"lft,out0\": \"btm,in0\",\n", " \"lft,out1\": \"top,in0\",\n", " },\n", " \"ports\": {\n", " \"in0\": \"lft,in0\",\n", " \"in1\": \"lft,in1\",\n", " \"out0\": \"btm,out0\",\n", " \"out1\": \"top,out0\",\n", " },\n", " },\n", " models = {\n", " \"coupler\": coupler,\n", " \"waveguide\": waveguide,\n", " }\n", ")" ] }, { "cell_type": "markdown", "id": "1eda468f", "metadata": { "papermill": { "duration": 0.006912, "end_time": "2023-09-03T04:40:20.708703", "exception": false, "start_time": "2023-09-03T04:40:20.701791", "status": "completed" }, "tags": [] }, "source": [ "An MZI chain can now be created by cascading these directional couplers with arms:\n", "\n", 
"```\n", " _ _ _ _ _ _ \n", " \\/ \\/ \\/ \\/ ... \\/ \\/ \n", " /\\_ /\\_ /\\_ /\\_ /\\_ /\\_ \n", "```\n", "\n", "Let's create a *model factory* (`ModelFactory`) for this. In SAX, a *model factory* is any keyword-only function that generates a `Model`:" ] }, { "cell_type": "code", "execution_count": null, "id": "11c4ab07", "metadata": { "papermill": { "duration": 0.013119, "end_time": "2023-09-03T04:40:20.728750", "exception": false, "start_time": "2023-09-03T04:40:20.715631", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "def mzi_chain(num_mzis=1) -> sax.Model:\n", " chain, _ = sax.circuit(\n", " netlist={\n", " \"instances\": {f\"dc{i}\": \"dc_with_arms\" for i in range(num_mzis + 1)},\n", " \"connections\": {\n", " **{f\"dc{i},out0\": f\"dc{i+1},in0\" for i in range(num_mzis)},\n", " **{f\"dc{i},out1\": f\"dc{i+1},in1\" for i in range(num_mzis)},\n", " },\n", " \"ports\": {\n", " \"in0\": f\"dc0,in0\",\n", " \"in1\": f\"dc0,in1\",\n", " \"out0\": f\"dc{num_mzis},out0\",\n", " \"out1\": f\"dc{num_mzis},out1\",\n", " },\n", " },\n", " models={\"dc_with_arms\": dc_with_arms},\n", " )\n", " return chain" ] }, { "cell_type": "markdown", "id": "03f225e9", "metadata": { "papermill": { "duration": 0.007011, "end_time": "2023-09-03T04:40:20.742896", "exception": false, "start_time": "2023-09-03T04:40:20.735885", "status": "completed" }, "tags": [] }, "source": [ "Let's for example create a chain with 15 MZIs. 
We can also update the settings dictionary as follows:" ] }, { "cell_type": "code", "execution_count": null, "id": "576d4de8", "metadata": { "papermill": { "duration": 0.432693, "end_time": "2023-09-03T04:40:21.183532", "exception": false, "start_time": "2023-09-03T04:40:20.750839", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "chain = mzi_chain(num_mzis=15)\n", "settings = sax.get_settings(chain)\n", "for dc in settings:\n", " settings[dc][\"top\"][\"length\"] = 25.0\n", " settings[dc][\"btm\"][\"length\"] = 15.0\n", "settings = sax.update_settings(settings, wl=jnp.linspace(1.5, 1.6, 1000))" ] }, { "cell_type": "markdown", "id": "11e3a10f", "metadata": { "papermill": { "duration": 0.007573, "end_time": "2023-09-03T04:40:21.198796", "exception": false, "start_time": "2023-09-03T04:40:21.191223", "status": "completed" }, "tags": [] }, "source": [ "We can simulate this:" ] }, { "cell_type": "code", "execution_count": null, "id": "b6e1155a", "metadata": { "papermill": { "duration": 9.273413, "end_time": "2023-09-03T04:40:30.479689", "exception": false, "start_time": "2023-09-03T04:40:21.206276", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "%time S = chain(**settings) # time to evaluate the MZI\n", "func = jax.jit(chain)\n", "%time S = func(**settings) # time to jit the MZI\n", "%time S = func(**settings) # time to evaluate the MZI after jitting" ] }, { "cell_type": "markdown", "id": "70db7179", "metadata": { "papermill": { "duration": 0.008382, "end_time": "2023-09-03T04:40:30.496669", "exception": false, "start_time": "2023-09-03T04:40:30.488287", "status": "completed" }, "tags": [] }, "source": [ "Where we see that the unjitted evaluation of the MZI chain takes about a second, while the jitting of the MZI chain takes about two minutes (on a CPU). 
However, once the MZI chain has been jitted, each evaluation takes on the order of a few milliseconds!\n", "\n", "Anyway, let's see what this gives:" ] }, { "cell_type": "code", "execution_count": null, "id": "1e4a2aab", "metadata": { "papermill": { "duration": 0.160355, "end_time": "2023-09-03T04:40:30.665472", "exception": false, "start_time": "2023-09-03T04:40:30.505117", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "S = sax.sdict(S)\n", "plt.plot(1e3 * settings[\"dc0\"][\"top\"][\"wl\"], jnp.abs(S[\"in0\", \"out0\"]) ** 2)\n", "plt.xlabel(\"λ [nm]\")\n", "plt.ylabel(\"T\")\n", "plt.ylim(-0.05, 1.05)\n", "plt.show()" ] } ], "metadata": { "kernelspec": { "display_name": "sax", "language": "python", "name": "sax" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.18" }, "papermill": { "default_parameters": {}, "duration": 24.881158, "end_time": "2023-09-03T04:40:31.292856", "environment_variables": {}, "exception": null, "input_path": "./01_quick_start.ipynb", "output_path": "./01_quick_start.ipynb", "parameters": {}, "start_time": "2023-09-03T04:40:06.411698", "version": "2.4.0" } }, "nbformat": 4, "nbformat_minor": 5 }