{ "cells": [ { "cell_type": "markdown", "id": "0", "metadata": { "papermill": { "duration": 0.022628, "end_time": "2025-02-27T06:29:35.365228", "exception": false, "start_time": "2025-02-27T06:29:35.342600", "status": "completed" }, "tags": [] }, "source": [ "# Layout Aware Monte Carlo with GDSFactory\n", "> Towards layout-aware optimization and monte-carlo simulations" ] }, { "cell_type": "code", "execution_count": null, "id": "1", "metadata": { "papermill": { "duration": 0.042821, "end_time": "2025-02-27T06:29:35.432773", "exception": false, "start_time": "2025-02-27T06:29:35.389952", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "# hide\n", "import os\n", "\n", "os.environ[\"LOGURU_LEVEL\"] = \"CRITICAL\"\n", "# import warnings\n", "# warnings.filterwarnings(\"ignore\")" ] }, { "cell_type": "code", "execution_count": null, "id": "2", "metadata": { "papermill": { "duration": 7.045413, "end_time": "2025-02-27T06:29:42.529062", "exception": false, "start_time": "2025-02-27T06:29:35.483649", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "import itertools\n", "import json\n", "import os\n", "import sys\n", "from functools import partial\n", "from typing import List\n", "\n", "import gdsfactory as gf # conda install gdsfactory\n", "import jax\n", "import jax.example_libraries.optimizers as opt\n", "import jax.numpy as jnp\n", "import matplotlib.pyplot as plt\n", "import meow as mw\n", "import numpy as np\n", "import sax\n", "from numpy.fft import fft2, fftfreq, fftshift, ifft2\n", "from tqdm.notebook import tqdm, trange" ] }, { "cell_type": "markdown", "id": "3", "metadata": { "papermill": { "duration": 0.017043, "end_time": "2025-02-27T06:29:42.562487", "exception": false, "start_time": "2025-02-27T06:29:42.545444", "status": "completed" }, "tags": [] }, "source": [ "## Simple MZI Layout" ] }, { "cell_type": "code", "execution_count": null, "id": "4", "metadata": { "papermill": { "duration": 0.029744, "end_time": 
"2025-02-27T06:29:42.607687", "exception": false, "start_time": "2025-02-27T06:29:42.577943", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "@gf.cell\n", "def simple_mzi():\n", " c = gf.Component()\n", "\n", " # components\n", " mmi_in = gf.components.mmi1x2()\n", " mmi_out = gf.components.mmi2x2()\n", " bend = gf.components.bend_euler()\n", " half_delay_straight = gf.components.straight(length=10.0)\n", "\n", " # references\n", " mmi_in = c.add_ref(mmi_in, name=\"mmi_in\")\n", " mmi_out = c.add_ref(mmi_out, name=\"mmi_out\")\n", " straight_top1 = c.add_ref(half_delay_straight, name=\"straight_top1\")\n", " straight_top2 = c.add_ref(half_delay_straight, name=\"straight_top2\")\n", " bend_top1 = c.add_ref(bend, name=\"bend_top1\")\n", " bend_top2 = c.add_ref(bend, name=\"bend_top2\").dmirror()\n", " bend_top3 = c.add_ref(bend, name=\"bend_top3\").dmirror()\n", " bend_top4 = c.add_ref(bend, name=\"bend_top4\")\n", " bend_btm1 = c.add_ref(bend, name=\"bend_btm1\").dmirror()\n", " bend_btm2 = c.add_ref(bend, name=\"bend_btm2\")\n", " bend_btm3 = c.add_ref(bend, name=\"bend_btm3\")\n", " bend_btm4 = c.add_ref(bend, name=\"bend_btm4\").dmirror()\n", "\n", " # connections\n", " bend_top1.connect(\"o1\", mmi_in.ports[\"o2\"])\n", " straight_top1.connect(\"o1\", bend_top1.ports[\"o2\"])\n", " bend_top2.connect(\"o1\", straight_top1.ports[\"o2\"])\n", " bend_top3.connect(\"o1\", bend_top2.ports[\"o2\"])\n", " straight_top2.connect(\"o1\", bend_top3.ports[\"o2\"])\n", " bend_top4.connect(\"o1\", straight_top2.ports[\"o2\"])\n", "\n", " bend_btm1.connect(\"o1\", mmi_in.ports[\"o3\"])\n", " bend_btm2.connect(\"o1\", bend_btm1.ports[\"o2\"])\n", " bend_btm3.connect(\"o1\", bend_btm2.ports[\"o2\"])\n", " bend_btm4.connect(\"o1\", bend_btm3.ports[\"o2\"])\n", "\n", " mmi_out.connect(\"o1\", bend_btm4.ports[\"o2\"])\n", "\n", " # ports\n", " c.add_port(\n", " \"o1\",\n", " port=mmi_in.ports[\"o1\"],\n", " )\n", " c.add_port(\"o2\", 
port=mmi_out.ports[\"o3\"])\n", " c.add_port(\"o3\", port=mmi_out.ports[\"o4\"])\n", " return c" ] }, { "cell_type": "code", "execution_count": null, "id": "5", "metadata": { "papermill": { "duration": 0.392171, "end_time": "2025-02-27T06:29:43.012997", "exception": false, "start_time": "2025-02-27T06:29:42.620826", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "mzi = simple_mzi()\n", "mzi" ] }, { "cell_type": "markdown", "id": "6", "metadata": { "papermill": { "duration": 0.028334, "end_time": "2025-02-27T06:29:43.063039", "exception": false, "start_time": "2025-02-27T06:29:43.034705", "status": "completed" }, "tags": [] }, "source": [ "## Simulate MZI\n", "\n", "We used the following components to construct the MZI circuit:\n", "\n", "- mmi1x2\n", "- mmi2x2\n", "- straight\n", "- bend_euler" ] }, { "cell_type": "markdown", "id": "7", "metadata": { "papermill": { "duration": 0.037767, "end_time": "2025-02-27T06:29:43.131459", "exception": false, "start_time": "2025-02-27T06:29:43.093692", "status": "completed" }, "tags": [] }, "source": [ "We need a model for each of those components to be able to simulate the circuit with SAX. Let's create some dummy models for now." 
] }, { "cell_type": "code", "execution_count": null, "id": "8", "metadata": { "papermill": { "duration": 0.023523, "end_time": "2025-02-27T06:29:43.172148", "exception": false, "start_time": "2025-02-27T06:29:43.148625", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "def mmi1x2():\n", " S = {\n", " (\"o1\", \"o2\"): 0.5**0.5,\n", " (\"o1\", \"o3\"): 0.5**0.5,\n", " }\n", " return sax.reciprocal(S)" ] }, { "cell_type": "code", "execution_count": null, "id": "9", "metadata": { "papermill": { "duration": 0.027245, "end_time": "2025-02-27T06:29:43.215921", "exception": false, "start_time": "2025-02-27T06:29:43.188676", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "def mmi2x2():\n", " S = {\n", " (\"o1\", \"o3\"): 0.5**0.5,\n", " (\"o1\", \"o4\"): 1j * 0.5**0.5,\n", " (\"o2\", \"o3\"): 1j * 0.5**0.5,\n", " (\"o2\", \"o4\"): 0.5**0.5,\n", " }\n", " return sax.reciprocal(S)" ] }, { "cell_type": "code", "execution_count": null, "id": "10", "metadata": { "papermill": { "duration": 0.025944, "end_time": "2025-02-27T06:29:43.260399", "exception": false, "start_time": "2025-02-27T06:29:43.234455", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "def straight(length=10.0, width=0.5):\n", " S = {(\"o1\", \"o2\"): 1.0} # we'll improve this model later!\n", " return sax.reciprocal(S)" ] }, { "cell_type": "code", "execution_count": null, "id": "11", "metadata": { "papermill": { "duration": 0.02596, "end_time": "2025-02-27T06:29:43.304942", "exception": false, "start_time": "2025-02-27T06:29:43.278982", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "def bend_euler(length=10.0, width=0.5, dy=10.0, radius_min=7, radius=10):\n", " return straight(length=length, width=width) # stub with straight for now" ] }, { "cell_type": "markdown", "id": "12", "metadata": { "papermill": { "duration": 0.018984, "end_time": "2025-02-27T06:29:43.342291", "exception": false, "start_time": "2025-02-27T06:29:43.323307", 
"status": "completed" }, "tags": [] }, "source": [ "Let's create a SAX circuit with our very simple placeholder models:" ] }, { "cell_type": "code", "execution_count": null, "id": "13", "metadata": { "papermill": { "duration": 1.443977, "end_time": "2025-02-27T06:29:44.805215", "exception": false, "start_time": "2025-02-27T06:29:43.361238", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "models = {\n", " \"mmi1x2\": mmi1x2,\n", " \"mmi2x2\": mmi2x2,\n", " \"straight\": straight,\n", " \"bend_euler\": bend_euler,\n", "}\n", "mzi1, _ = sax.circuit(mzi.get_netlist(recursive=True), models=models)\n", "?mzi1" ] }, { "cell_type": "markdown", "id": "14", "metadata": { "papermill": { "duration": 0.008625, "end_time": "2025-02-27T06:29:44.821666", "exception": false, "start_time": "2025-02-27T06:29:44.813041", "status": "completed" }, "tags": [] }, "source": [ "the resulting circuit is just a model function on its own! Hence, calling it will give the result:" ] }, { "cell_type": "code", "execution_count": null, "id": "15", "metadata": { "papermill": { "duration": 0.572123, "end_time": "2025-02-27T06:29:45.407163", "exception": false, "start_time": "2025-02-27T06:29:44.835040", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "mzi1()" ] }, { "cell_type": "markdown", "id": "16", "metadata": { "papermill": { "duration": 0.009592, "end_time": "2025-02-27T06:29:45.424162", "exception": false, "start_time": "2025-02-27T06:29:45.414570", "status": "completed" }, "tags": [] }, "source": [ "## Waveguide Model\n", "\n", "Our waveguide model is not very good (it just has 100% transmission and no phase). Let's do something about the phase calculation. To do this, we need to find the effective index of the waveguide in relation to its parameters. We can use [meow](https://github.com/flaport/meow) to obtain the waveguide effective index. 
Let's first create a `find_waveguide_modes`:" ] }, { "cell_type": "code", "execution_count": null, "id": "17", "metadata": { "papermill": { "duration": 0.029697, "end_time": "2025-02-27T06:29:45.462334", "exception": false, "start_time": "2025-02-27T06:29:45.432637", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "def find_waveguide_modes(\n", "    wl: float = 1.55,\n", "    n_box: float = 1.4,\n", "    n_clad: float = 1.4,\n", "    n_core: float = 3.4,\n", "    t_slab: float = 0.1,\n", "    t_soi: float = 0.22,\n", "    w_core: float = 0.45,\n", "    du=0.02,\n", "    n_modes: int = 10,\n", "    cache_path: str = \"modes\",\n", "    replace_cached: bool = False,\n", "):\n", "    \"\"\"Compute the modes of a waveguide cross-section with meow, cached as JSON.\n", "\n", "    The cross-section is a core (width ``w_core``, height ``t_soi``) on a slab of\n", "    thickness ``t_slab``, with a box below and cladding above. Results are cached\n", "    in ``cache_path`` under a file name encoding all parameters; pass\n", "    ``replace_cached=True`` to force a recomputation.\n", "\n", "    Returns:\n", "        the modes from ``mw.compute_modes(..., num_modes=n_modes)`` (or their cached copy).\n", "    \"\"\"\n", "    length = 10.0\n", "    delta = 10 * du\n", "    env = mw.Environment(wl=wl)\n", "    cache_path = os.path.abspath(cache_path)\n", "    os.makedirs(cache_path, exist_ok=True)\n", "    fn = f\"{wl=:.2f}-{n_box=:.2f}-{n_clad=:.2f}-{n_core=:.2f}-{t_slab=:.3f}-{t_soi=:.3f}-{w_core=:.3f}-{du=:.3f}-{n_modes=}.json\"\n", "    path = os.path.join(cache_path, fn)\n", "    if not replace_cached and os.path.exists(path):\n", "        # context manager so the cache file handle is closed deterministically\n", "        with open(path, \"r\") as file:\n", "            return [mw.Mode.model_validate(mode) for mode in json.load(file)]\n", "\n", "    # fmt: off\n",
"    m_core = mw.SampledMaterial(name=\"slab\", n=np.asarray([n_core, n_core]), params={\"wl\": np.asarray([1.0, 2.0])}, meta={\"color\": (0.9, 0, 0, 0.9)})\n", "    m_clad = mw.SampledMaterial(name=\"clad\", n=np.asarray([n_clad, n_clad]), params={\"wl\": np.asarray([1.0, 2.0])})\n", "    m_box = mw.SampledMaterial(name=\"box\", n=np.asarray([n_box, n_box]), params={\"wl\": np.asarray([1.0, 2.0])})\n", "    box = mw.Structure(material=m_box, geometry=mw.Box(x_min=-2 * w_core - delta, x_max=2 * w_core + delta, y_min=-2 * t_soi - delta, y_max=0.0, z_min=0.0, z_max=length))\n", "    slab = mw.Structure(material=m_core, geometry=mw.Box(x_min=-2 * w_core - delta, x_max=2 * w_core + delta, y_min=0.0, y_max=t_slab, z_min=0.0, z_max=length))\n", "    clad = mw.Structure(material=m_clad, geometry=mw.Box(x_min=-2 * w_core - delta, x_max=2 * w_core + delta, y_min=0, y_max=3 * t_soi + delta, z_min=0.0, z_max=length))\n", "    core = mw.Structure(material=m_core, geometry=mw.Box(x_min=-w_core / 2, x_max=w_core / 2, y_min=0.0, y_max=t_soi, z_min=0.0, z_max=length))\n", "\n", "    cell = mw.Cell(\n", "        structures=[box, clad, slab, core],\n", "        mesh=mw.Mesh2D(\n", "            x=np.arange(-2 * w_core, 2 * w_core, du),\n", "            y=np.arange(-2 * t_soi, 3 * t_soi, du),\n", "        ),\n", "        z_min=0.0,\n", "        z_max=length,\n", "    )\n", "    cross_section = mw.CrossSection.from_cell(cell=cell, env=env)\n", "    modes = mw.compute_modes(cross_section, num_modes=n_modes)\n", "    # fmt: on\n", "\n", "    # model_dump_json is the pydantic v2 replacement of the deprecated .json()\n", "    with open(path, \"w\") as file:\n", "        json.dump([json.loads(mode.model_dump_json()) for mode in modes], file)\n", "\n", "    return modes" ] }, { "cell_type": "markdown", "id": "18", "metadata": { "papermill": { "duration": 0.008601, "end_time": "2025-02-27T06:29:45.517951", "exception": false, "start_time": "2025-02-27T06:29:45.509350", "status": "completed" }, "tags": [] }, "source": [ "We can also create a rudimentary model for the silicon refractive index:" ] }, { "cell_type": "code", "execution_count": null, "id": "19", "metadata": { "papermill": { "duration": 0.018172, "end_time": "2025-02-27T06:29:45.547141", "exception": false, "start_time": "2025-02-27T06:29:45.528969", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "def silicon_index(wl):\n", "    \"\"\"a rudimentary silicon refractive index model\"\"\"\n", "    a, b = 0.2411478522088102, 3.3229394315868976\n", "    return a / wl + b" ] }, { "cell_type": "markdown", "id": "20", "metadata": { "papermill": { "duration": 0.009357, "end_time": "2025-02-27T06:29:45.571308", "exception": false, "start_time": "2025-02-27T06:29:45.561951", "status": "completed" }, "tags": [] }, "source": [ "We can now easily calculate the modes of a strip waveguide:" ] }, { "cell_type": "code", "execution_count": null, "id": "21", "metadata": { "papermill": { "duration": 0.469329, "end_time": "2025-02-27T06:29:46.050795", "exception": false, "start_time": "2025-02-27T06:29:45.581466", "status": 
"completed" }, "tags": [] }, "outputs": [], "source": [ "modes = find_waveguide_modes(wl=1.5, n_core=silicon_index(wl=1.5))" ] }, { "cell_type": "markdown", "id": "22", "metadata": { "papermill": { "duration": 0.01447, "end_time": "2025-02-27T06:29:46.079882", "exception": false, "start_time": "2025-02-27T06:29:46.065412", "status": "completed" }, "tags": [] }, "source": [ "The fundamental mode is the mode with index 0:" ] }, { "cell_type": "code", "execution_count": null, "id": "23", "metadata": { "papermill": { "duration": 0.538399, "end_time": "2025-02-27T06:29:46.632781", "exception": false, "start_time": "2025-02-27T06:29:46.094382", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "mw.visualize(modes[0])" ] }, { "cell_type": "code", "execution_count": null, "id": "24", "metadata": { "papermill": { "duration": 66.968676, "end_time": "2025-02-27T06:30:53.617537", "exception": false, "start_time": "2025-02-27T06:29:46.648861", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "wavelengths, widths = np.mgrid[1.5:1.6:10j, 0.4:0.6:5j]\n", "neffs = np.zeros_like(wavelengths)\n", "neffs_ = neffs.ravel()\n", "\n", "for i, (wl, w) in enumerate(zip(tqdm(wavelengths.ravel()), widths.ravel())):\n", " modes = find_waveguide_modes(\n", " wl=wl, n_core=silicon_index(wl), w_core=w, replace_cached=False\n", " )\n", " neffs_[i] = np.real(modes[0].neff)" ] }, { "cell_type": "markdown", "id": "25", "metadata": { "papermill": { "duration": 0.007532, "end_time": "2025-02-27T06:30:53.632557", "exception": false, "start_time": "2025-02-27T06:30:53.625025", "status": "completed" }, "tags": [] }, "source": [ "This results in the following effective indices:" ] }, { "cell_type": "code", "execution_count": null, "id": "26", "metadata": { "papermill": { "duration": 0.125793, "end_time": "2025-02-27T06:30:53.766195", "exception": false, "start_time": "2025-02-27T06:30:53.640402", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "_wls = 
np.unique(wavelengths.ravel())\n", "_widths = np.unique(widths.ravel())\n", "plt.figure(figsize=(8, 3))\n", "plt.plot(_wls * 1000, neffs)\n", "plt.ylabel(\"neff\")\n", "plt.xlabel(\"λ [nm]\")\n", "plt.title(\"Effective Index\")\n", "plt.grid(True)\n", "# one legend column per distinct width: use the 1D `_widths`, not the 2D `widths` grid\n", "plt.figlegend(\n", "    [f\"{w=:.2f}um\" for w in _widths], ncol=len(_widths), bbox_to_anchor=(0.95, -0.05)\n", ")\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "27", "metadata": { "papermill": { "duration": 0.007852, "end_time": "2025-02-27T06:30:53.782352", "exception": false, "start_time": "2025-02-27T06:30:53.774500", "status": "completed" }, "tags": [] }, "source": [ "We can do a grid interpolation on those effective indices:" ] }, { "cell_type": "code", "execution_count": null, "id": "28", "metadata": { "papermill": { "duration": 0.557346, "end_time": "2025-02-27T06:30:54.348116", "exception": false, "start_time": "2025-02-27T06:30:53.790770", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "_grid = [jnp.sort(jnp.unique(wavelengths)), jnp.sort(jnp.unique(widths))]\n", "_data = jnp.asarray(neffs)\n", "\n", "\n", "@jax.jit\n", "def _get_coordinate(arr1d: jnp.ndarray, value: jnp.ndarray):\n", "    return jnp.interp(value, arr1d, jnp.arange(arr1d.shape[0]))\n", "\n", "\n", "@jax.jit\n", "def _get_coordinates(arrs1d: List[jnp.ndarray], values: jnp.ndarray):\n", "    # don't use vmap as arrays in arrs1d could have different shapes...\n", "    return jnp.array([_get_coordinate(a, v) for a, v in zip(arrs1d, values)])\n", "\n", "\n", "@jax.jit\n", "def neff(wl=1.55, width=0.5):\n", "    params = jnp.stack(jnp.broadcast_arrays(jnp.asarray(wl), jnp.asarray(width)), 0)\n", "    coords = _get_coordinates(_grid, params)\n", "    return jax.scipy.ndimage.map_coordinates(_data, coords, 1, mode=\"nearest\")\n", "\n", "\n", "neff(wl=[1.52, 1.58], width=[0.5, 0.55])" ] }, { "cell_type": "code", "execution_count": null, "id": "29", "metadata": { "papermill": { "duration": 0.349859, "end_time": "2025-02-27T06:30:54.706739",
"exception": false, "start_time": "2025-02-27T06:30:54.356880", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "wavelengths_ = np.linspace(wavelengths.min(), wavelengths.max(), 100)\n", "widths_ = np.linspace(widths.min(), widths.max(), 100)\n", "wavelengths_, widths_ = np.meshgrid(wavelengths_, widths_)\n", "neffs_ = neff(wavelengths_, widths_)\n", "\n", "plt.figure(figsize=(8, 3))\n", "plt.pcolormesh(wavelengths_ * 1000, widths_, neffs_)\n", "plt.ylabel(\"neff\")\n", "plt.xlabel(\"λ [nm]\")\n", "plt.title(\"Effective Index\")\n", "plt.grid(True)\n", "plt.figlegend(\n", " [f\"{w=:.2f}um\" for w in _widths], ncol=len(_widths), bbox_to_anchor=(0.95, -0.05)\n", ")\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "id": "30", "metadata": { "papermill": { "duration": 0.014129, "end_time": "2025-02-27T06:30:54.729575", "exception": false, "start_time": "2025-02-27T06:30:54.715446", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "def straight(wl=1.55, length=10.0, width=0.5):\n", " S = {\n", " (\"o1\", \"o2\"): jnp.exp(2j * np.pi * neff(wl=wl, width=width) / wl * length),\n", " }\n", " return sax.reciprocal(S)" ] }, { "cell_type": "markdown", "id": "31", "metadata": { "papermill": { "duration": 0.008795, "end_time": "2025-02-27T06:30:54.747833", "exception": false, "start_time": "2025-02-27T06:30:54.739038", "status": "completed" }, "tags": [] }, "source": [ "Even though this still is lossless transmission, we're at least modeling the phase correctly." 
] }, { "cell_type": "code", "execution_count": null, "id": "32", "metadata": { "papermill": { "duration": 0.21718, "end_time": "2025-02-27T06:30:54.974103", "exception": false, "start_time": "2025-02-27T06:30:54.756923", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "straight()" ] }, { "cell_type": "markdown", "id": "33", "metadata": { "papermill": { "duration": 0.00931, "end_time": "2025-02-27T06:30:54.993131", "exception": false, "start_time": "2025-02-27T06:30:54.983821", "status": "completed" }, "tags": [] }, "source": [ "## Simulate MZI again" ] }, { "cell_type": "code", "execution_count": null, "id": "34", "metadata": { "papermill": { "duration": 0.213078, "end_time": "2025-02-27T06:30:55.215355", "exception": false, "start_time": "2025-02-27T06:30:55.002277", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "models[\"straight\"] = straight\n", "mzi2, _ = sax.circuit(mzi.get_netlist(recursive=True), models=models)\n", "mzi2()" ] }, { "cell_type": "code", "execution_count": null, "id": "35", "metadata": { "papermill": { "duration": 0.886234, "end_time": "2025-02-27T06:30:56.110859", "exception": false, "start_time": "2025-02-27T06:30:55.224625", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "wl = jnp.linspace(1.51, 1.59, 1000)\n", "S = mzi2(wl=wl)\n", "plt.plot(wl, abs(S[\"o1\", \"o2\"]) ** 2)\n", "plt.ylim(-0.05, 1.05)\n", "plt.xlabel(\"λ [μm]\")\n", "plt.ylabel(\"T\")\n", "plt.ylim(-0.05, 1.05)\n", "plt.grid(True)\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "36", "metadata": { "papermill": { "duration": 0.009973, "end_time": "2025-02-27T06:30:56.130580", "exception": false, "start_time": "2025-02-27T06:30:56.120607", "status": "completed" }, "tags": [] }, "source": [ "## Optimize MZI" ] }, { "cell_type": "markdown", "id": "37", "metadata": { "papermill": { "duration": 0.009859, "end_time": "2025-02-27T06:30:56.150411", "exception": false, "start_time": "2025-02-27T06:30:56.140552", "status": 
"completed" }, "tags": [] }, "source": [ "We'd like to optimize an MZI such that one of the minima is at 1550nm. To do this, we need to define a loss function for the circuit at 1550nm. This function should take the parameters that you want to optimize as positional arguments:" ] }, { "cell_type": "code", "execution_count": null, "id": "38", "metadata": { "papermill": { "duration": 0.015471, "end_time": "2025-02-27T06:30:56.175772", "exception": false, "start_time": "2025-02-27T06:30:56.160301", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "@jax.jit\n", "def loss_fn(delta_length):\n", " S = mzi2(\n", " wl=1.55,\n", " straight_top1={\"length\": delta_length / 2},\n", " straight_top2={\"length\": delta_length / 2},\n", " )\n", " return jnp.mean(jnp.abs(S[\"o1\", \"o2\"]) ** 2)" ] }, { "cell_type": "markdown", "id": "39", "metadata": { "papermill": { "duration": 0.00979, "end_time": "2025-02-27T06:30:56.194980", "exception": false, "start_time": "2025-02-27T06:30:56.185190", "status": "completed" }, "tags": [] }, "source": [ "We can use this loss function to define a grad function which works on the parameters of the loss function:" ] }, { "cell_type": "code", "execution_count": null, "id": "40", "metadata": { "papermill": { "duration": 0.015154, "end_time": "2025-02-27T06:30:56.220047", "exception": false, "start_time": "2025-02-27T06:30:56.204893", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "grad_fn = jax.jit(\n", " jax.grad(\n", " loss_fn,\n", " argnums=0, # JAX gradient function for the first positional argument, jitted\n", " )\n", ")" ] }, { "cell_type": "markdown", "id": "41", "metadata": { "papermill": { "duration": 0.009737, "end_time": "2025-02-27T06:30:56.239690", "exception": false, "start_time": "2025-02-27T06:30:56.229953", "status": "completed" }, "tags": [] }, "source": [ "Next, we need to define a JAX optimizer, which on its own is nothing more than three more functions: an initialization function with which 
to initialize the optimizer state, an update function which will update the optimizer state (and with it the model parameters). The third function that's being returned will give the model parameters given the optimizer state." ] }, { "cell_type": "code", "execution_count": null, "id": "42", "metadata": { "papermill": { "duration": 0.350362, "end_time": "2025-02-27T06:30:56.600103", "exception": false, "start_time": "2025-02-27T06:30:56.249741", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "loss_fn(20.0)" ] }, { "cell_type": "code", "execution_count": null, "id": "43", "metadata": { "papermill": { "duration": 0.017903, "end_time": "2025-02-27T06:30:56.629478", "exception": false, "start_time": "2025-02-27T06:30:56.611575", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "initial_delta_length = 10.0\n", "init_fn, update_fn, params_fn = opt.adam(step_size=0.1)\n", "state = init_fn(initial_delta_length)" ] }, { "cell_type": "markdown", "id": "44", "metadata": { "papermill": { "duration": 0.019199, "end_time": "2025-02-27T06:30:56.662566", "exception": false, "start_time": "2025-02-27T06:30:56.643367", "status": "completed" }, "tags": [] }, "source": [ "Given all this, a single training step can be defined:" ] }, { "cell_type": "code", "execution_count": null, "id": "45", "metadata": { "papermill": { "duration": 0.023889, "end_time": "2025-02-27T06:30:56.698708", "exception": false, "start_time": "2025-02-27T06:30:56.674819", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "def step_fn(step, state):\n", " params = params_fn(state)\n", " loss = loss_fn(params)\n", " grad = grad_fn(params)\n", " state = update_fn(step, grad, state)\n", " return loss, state" ] }, { "cell_type": "markdown", "id": "46", "metadata": { "papermill": { "duration": 0.012545, "end_time": "2025-02-27T06:30:56.737993", "exception": false, "start_time": "2025-02-27T06:30:56.725448", "status": "completed" }, "tags": [] }, "source": [ "And we 
can use this step function to start the training of the MZI:" ] }, { "cell_type": "code", "execution_count": null, "id": "47", "metadata": { "papermill": { "duration": 7.337535, "end_time": "2025-02-27T06:31:04.087878", "exception": false, "start_time": "2025-02-27T06:30:56.750343", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "for step in (\n", " pb := trange(300)\n", "): # the first two iterations take a while because the circuit is being jitted...\n", " loss, state = step_fn(step, state)\n", " pb.set_postfix(loss=f\"{loss:.6f}\")" ] }, { "cell_type": "code", "execution_count": null, "id": "48", "metadata": { "papermill": { "duration": 0.016168, "end_time": "2025-02-27T06:31:04.114897", "exception": false, "start_time": "2025-02-27T06:31:04.098729", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "delta_length = params_fn(state)\n", "delta_length" ] }, { "cell_type": "markdown", "id": "49", "metadata": { "papermill": { "duration": 0.014745, "end_time": "2025-02-27T06:31:04.140420", "exception": false, "start_time": "2025-02-27T06:31:04.125675", "status": "completed" }, "tags": [] }, "source": [ "Let's see what we've got over a range of wavelengths:" ] }, { "cell_type": "code", "execution_count": null, "id": "50", "metadata": { "papermill": { "duration": 0.166027, "end_time": "2025-02-27T06:31:04.315845", "exception": false, "start_time": "2025-02-27T06:31:04.149818", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "S = mzi2(\n", " wl=wl,\n", " straight_top1={\"length\": delta_length / 2},\n", " straight_top2={\"length\": delta_length / 2},\n", ")\n", "plt.plot(wl * 1e3, abs(S[\"o1\", \"o2\"]) ** 2)\n", "plt.xlabel(\"λ [nm]\")\n", "plt.ylabel(\"T\")\n", "plt.plot([1550, 1550], [-1, 2], ls=\":\", color=\"black\")\n", "plt.ylim(-0.05, 1.05)\n", "plt.grid(True)\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "51", "metadata": { "papermill": { "duration": 0.0092, "end_time": "2025-02-27T06:31:04.334686", 
"exception": false, "start_time": "2025-02-27T06:31:04.325486", "status": "completed" }, "tags": [] }, "source": [ "Note that we could've just as well optimized the waveguide width:" ] }, { "cell_type": "code", "execution_count": null, "id": "52", "metadata": { "papermill": { "duration": 2.816105, "end_time": "2025-02-27T06:31:07.160592", "exception": false, "start_time": "2025-02-27T06:31:04.344487", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "@jax.jit\n", "def loss_fn(width):\n", " S = mzi2(\n", " wl=1.55,\n", " straight_top1={\"width\": width},\n", " straight_top2={\"width\": width},\n", " )\n", " return jnp.mean(jnp.abs(S[\"o1\", \"o2\"]) ** 2)\n", "\n", "\n", "grad_fn = jax.jit(\n", " jax.grad(\n", " loss_fn,\n", " argnums=0, # JAX gradient function for the first positional argument, jitted\n", " )\n", ")\n", "initial_width = 0.5\n", "init_fn, update_fn, params_fn = opt.adam(step_size=0.01)\n", "state = init_fn(initial_width)\n", "for step in (\n", " pb := trange(300)\n", "): # the first two iterations take a while because the circuit is being jitted...\n", " loss, state = step_fn(step, state)\n", " pb.set_postfix(loss=f\"{loss:.6f}\")\n", "\n", "optim_width = params_fn(state)\n", "S = Sw = mzi2(\n", " wl=wl,\n", " straight_top1={\"width\": optim_width},\n", " straight_top2={\"width\": optim_width},\n", ")\n", "plt.plot(wl * 1e3, abs(S[\"o1\", \"o2\"]) ** 2)\n", "plt.xlabel(\"λ [nm]\")\n", "plt.ylabel(\"T\")\n", "plt.plot([1550, 1550], [-1, 2], color=\"black\", ls=\":\")\n", "plt.ylim(-0.05, 1.05)\n", "plt.grid(True)\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "53", "metadata": { "papermill": { "duration": 0.010732, "end_time": "2025-02-27T06:31:07.181930", "exception": false, "start_time": "2025-02-27T06:31:07.171198", "status": "completed" }, "tags": [] }, "source": [ "## Layout-aware Monte Carlo\n", "\n", "Let's assume the waveguide width changes with a certain correlation length. 
We can create a 'wafermap' of width variations by randomly varying the width and low pass filtering with a spatial frequency being the inverse of the correlation length (there are probably better ways to do this, but this works for this tutorial)." ] }, { "cell_type": "code", "execution_count": null, "id": "54", "metadata": { "papermill": { "duration": 0.02461, "end_time": "2025-02-27T06:31:07.217540", "exception": false, "start_time": "2025-02-27T06:31:07.192930", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "def create_wafermaps(\n", "    placements, correlation_length=1.0, num_maps=1, mean=0.0, std=1.0, seed=None\n", "):\n", "    \"\"\"Generate spatially correlated random maps covering the given placements.\n", "\n", "    White noise is low-pass filtered in the spatial-frequency domain (cutoff\n", "    ``1 / correlation_length``), its magnitude is squared and every map is then\n", "    normalized to the requested ``mean`` and ``std``.\n", "\n", "    Args:\n", "        placements: dict of placement dicts with \"x\" and \"y\" entries\n", "            (e.g. ``component.get_netlist()[\"placements\"]``).\n", "        correlation_length: correlation length, in the same unit as x and y.\n", "        num_maps: number of independent maps to generate.\n", "        mean: target mean of each map.\n", "        std: target standard deviation of each map.\n", "        seed: seed for a private RandomState; the global numpy RNG is used if None.\n", "\n", "    Returns:\n", "        (x, y, W): 1D grid coordinates and maps of shape (num_maps, len(x), len(y)).\n", "    \"\"\"\n", "    dx = dy = correlation_length / 200\n", "    xs = [p[\"x\"] for p in placements.values()]\n", "    ys = [p[\"y\"] for p in placements.values()]\n", "    xmin, xmax, ymin, ymax = min(xs), max(xs), min(ys), max(ys)\n", "    wx, wy = xmax - xmin, ymax - ymin\n", "    # pad the grid by the layout extent on every side\n", "    xmin, xmax, ymin, ymax = xmin - wx, xmax + wx, ymin - wy, ymax + wy\n", "    x, y = np.arange(xmin, xmax + dx, dx), np.arange(ymin, ymax + dy, dy)\n", "    r = np.random if seed is None else np.random.RandomState(seed=seed)\n", "    W0 = r.randn(num_maps, x.shape[0], y.shape[0])\n", "\n", "    # frequency grids aligned with an fftshift-ed spectrum\n", "    fx = fftshift(fftfreq(x.shape[0], d=dx))\n", "    fy = fftshift(fftfreq(y.shape[0], d=dy))\n", "    fY, fX = np.meshgrid(fy, fx)\n", "    f_radius = np.sqrt(fX**2 + fY**2)\n", "    passband = f_radius <= 1 / correlation_length\n", "\n", "    if not np.any(passband & (f_radius > 0)):\n", "        # only the DC component survives the low-pass filter: the maps are flat\n", "        return x, y, np.full(W0.shape, mean, dtype=float)\n", "\n", "    # only shift/transform the two spatial axes; axis 0 indexes independent maps\n", "    spatial = (-2, -1)\n", "    fW = np.where(passband[None], fftshift(fft2(W0), axes=spatial), 0)\n", "    W = np.abs(ifft2(np.fft.ifftshift(fW, axes=spatial))) ** 2\n", "\n", "    # normalize every map over its whole 2D extent (std of the full map,\n", "    # not a std of per-column stds); guard against perfectly flat maps\n", "    mean_ = W.mean(axis=(1, 2), keepdims=True)\n", "    std_ = W.std(axis=(1, 2), keepdims=True)\n", "    std_ = np.where(std_ == 0, 1.0, std_)\n", "\n", "    W = (W - mean_) / std_\n", "    W = W * std + mean\n", "\n", "    return x, y, W" ] }, { "cell_type": "code", "execution_count": null, "id": "55", "metadata": { "papermill": { 
"duration": 2.460188, "end_time": "2025-02-27T06:31:09.688154", "exception": false, "start_time": "2025-02-27T06:31:07.227966", "status": "completed" }, "tags": [] }, "outputs": [], "source": [
"placements = mzi.get_netlist()[\"placements\"]\n",
"xm, ym, wmaps = create_wafermaps(\n",
"    placements,\n",
"    correlation_length=100,\n",
"    mean=0.5,\n",
"    std=0.002,\n",
"    num_maps=100,\n",
"    seed=42,\n",
")\n",
"\n",
"# show the first two maps; transpose so that x runs horizontally and y vertically\n",
"for wmap in wmaps[:2]:\n",
"    plt.imshow(\n",
"        wmap.T, origin=\"lower\", extent=(xm[0], xm[-1], ym[0], ym[-1]), cmap=\"RdBu\"\n",
"    )\n",
"    plt.colorbar(label=\"width [µm]\")\n",
"    plt.xlabel(\"x [µm]\")\n",
"    plt.ylabel(\"y [µm]\")\n",
"    plt.show()"
] }, { "cell_type": "code", "execution_count": null, "id": "56", "metadata": { "papermill": { "duration": 0.054387, "end_time": "2025-02-27T06:31:09.754722", "exception": false, "start_time": "2025-02-27T06:31:09.700335", "status": "completed" }, "tags": [] }, "outputs": [], "source": [
"def widths(xw, yw, wmaps, x, y):\n",
"    \"\"\"Sample every wafer map at the position(s) ``(x, y)``.\n",
"\n",
"    Args:\n",
"        xw, yw: 1D grid coordinates of the maps (as returned by ``create_wafermaps``).\n",
"        wmaps: stack of maps; the leading axis indexes the maps.\n",
"        x, y: position(s) to sample; broadcast against each other.\n",
"\n",
"    Returns:\n",
"        The linearly interpolated values, one leading entry per map. Points\n",
"        outside the grid take the nearest edge value (``mode=\"nearest\"``).\n",
"    \"\"\"\n",
"    _wmap_grid = [xw, yw]\n",
"    params = jnp.stack(jnp.broadcast_arrays(jnp.asarray(x), jnp.asarray(y)), 0)\n",
"    # `_get_coordinates` is defined earlier in the notebook; presumably it turns the\n",
"    # physical positions into fractional indices of the (xw, yw) grid, which is\n",
"    # what map_coordinates expects - confirm\n",
"    coords = _get_coordinates(_wmap_grid, params)\n",
"\n",
"    map_coordinates = partial(\n",
"        jax.scipy.ndimage.map_coordinates, coordinates=coords, order=1, mode=\"nearest\"\n",
"    )\n",
"    # vmap over the leading (map) axis of `wmaps`\n",
"    w = jax.vmap(map_coordinates)(wmaps)\n",
"    return w"
] }, { "cell_type": "markdown", "id": "57", "metadata": { "papermill": { "duration": 0.018257, "end_time": "2025-02-27T06:31:09.817099", "exception": false, "start_time": "2025-02-27T06:31:09.798842", "status": "completed" }, "tags": [] }, "source": [ "Let's now sample the MZI width variation on the wafer map (let's assume a single width variation per point):" ] }, { "cell_type": "code", "execution_count": null, "id": "58", "metadata": { "papermill": { "duration": 3.203972, "end_time": "2025-02-27T06:31:13.035998", "exception": false, "start_time": "2025-02-27T06:31:09.832026", "status": "completed" }, "tags": [] }, "outputs": [], "source": [
"mzi_params = sax.get_settings(mzi2)\n",
"placements = mzi.get_netlist()[\"placements\"]\n",
"# one width per wafer map for every instance that has a `width` setting\n",
"width_params = {\n",
"    k: {\"width\": widths(xm, ym, wmaps, v[\"x\"], v[\"y\"])}\n",
"    for k, v in placements.items()\n",
"    if \"width\" in mzi_params[k]\n",
"}\n",
"\n",
"S0 = mzi2(wl=wl)\n",
"# wl[:, None] lets the wavelengths broadcast against the per-map widths\n",
"S = mzi2(\n",
"    wl=wl[:, None],\n",
"    **width_params,\n",
")\n",
"ps = plt.plot(wl * 1e3, abs(S[\"o1\", \"o2\"]) ** 2, color=\"C0\", lw=1, alpha=0.1)\n",
"nps = plt.plot(wl * 1e3, abs(S0[\"o1\", \"o2\"]) ** 2, color=\"C1\", lw=2, alpha=1)\n",
"plt.xlabel(\"λ [nm]\")\n",
"plt.ylabel(\"T\")\n",
"plt.plot([1550, 1550], [-1, 2], color=\"black\", ls=\":\")\n",
"plt.ylim(-0.05, 1.05)\n",
"plt.grid(True)\n",
"plt.figlegend([*ps[-1:], *nps], [\"MC\", \"nominal\"], bbox_to_anchor=(1.1, 0.9))\n",
"# root-mean-square deviation of the MC transmission from the nominal one\n",
"rmse = jnp.sqrt(\n",
"    jnp.mean((jnp.abs(S[\"o1\", \"o2\"]) ** 2 - jnp.abs(S0[\"o1\", \"o2\"][:, None]) ** 2) ** 2)\n",
")\n",
"plt.title(f\"RMSE = {float(rmse):.3e}\")\n",
"plt.show()"
] }, { "cell_type": "markdown", "id": "59", "metadata": { "papermill": { "duration": 0.579178, "end_time": "2025-02-27T06:31:13.629880", "exception": false, "start_time": "2025-02-27T06:31:13.050702", "status": "completed" }, "tags": [] }, "source": [ "## Compact MZI\n", "\n", "Let's see if we can improve variability (i.e. the RMSE w.r.t. 
nominal) by making the MZI more compact:" ] }, { "cell_type": "code", "execution_count": null, "id": "60", "metadata": { "papermill": { "duration": 0.593916, "end_time": "2025-02-27T06:31:14.690710", "exception": false, "start_time": "2025-02-27T06:31:14.096794", "status": "completed" }, "tags": [] }, "outputs": [], "source": [
"@gf.cell\n",
"def compact_mzi():\n",
"    \"\"\"Compact MZI: a 1x2 MMI splitter and a 2x2 MMI combiner joined by two folded arms.\n",
"\n",
"    The top arm is bend-straight-bend-straight-bend-straight-bend and the bottom\n",
"    arm is straight-bend-bend-bend-bend-straight. Ports: ``o1`` (``mmi_in`` o1),\n",
"    ``o2`` and ``o3`` (``mmi_out`` o3 and o4).\n",
"    \"\"\"\n",
"    c = gf.Component()\n",
"\n",
"    # components\n",
"    mmi_in = gf.components.mmi1x2()\n",
"    mmi_out = gf.components.mmi2x2()\n",
"    bend = gf.components.bend_euler()\n",
"    half_delay_straight = gf.components.straight()\n",
"    middle_straight = gf.components.straight(length=6.0)\n",
"    half_middle_straight = gf.components.straight(length=3.0)\n",
"\n",
"    # references (the names become the instance names in the netlist)\n",
"    mmi_in = c.add_ref(mmi_in, name=\"mmi_in\")\n",
"\n",
"    bend_top1 = c.add_ref(bend, name=\"bend_top1\")\n",
"    straight_top1 = c.add_ref(half_delay_straight, name=\"straight_top1\")\n",
"    bend_top2 = c.add_ref(bend, name=\"bend_top2\").dmirror()\n",
"    straight_top2 = c.add_ref(middle_straight, name=\"straight_top2\")\n",
"    bend_top3 = c.add_ref(bend, name=\"bend_top3\").dmirror()\n",
"    straight_top3 = c.add_ref(half_delay_straight, name=\"straight_top3\")\n",
"    bend_top4 = c.add_ref(bend, name=\"bend_top4\")\n",
"\n",
"    straight_btm1 = c.add_ref(half_middle_straight, name=\"straight_btm1\")\n",
"    bend_btm1 = c.add_ref(bend, name=\"bend_btm1\")\n",
"    bend_btm2 = c.add_ref(bend, name=\"bend_btm2\").dmirror()\n",
"    bend_btm3 = c.add_ref(bend, name=\"bend_btm3\").dmirror()\n",
"    bend_btm4 = c.add_ref(bend, name=\"bend_btm4\")\n",
"    straight_btm2 = c.add_ref(half_middle_straight, name=\"straight_btm2\")\n",
"\n",
"    mmi_out = c.add_ref(mmi_out, name=\"mmi_out\")\n",
"\n",
"    # connections: the top arm starts at mmi_in o2\n",
"    bend_top1.connect(\"o1\", mmi_in.ports[\"o2\"])\n",
"    straight_top1.connect(\"o1\", bend_top1.ports[\"o2\"])\n",
"    bend_top2.connect(\"o1\", straight_top1.ports[\"o2\"])\n",
"    straight_top2.connect(\"o1\", bend_top2.ports[\"o2\"])\n",
"    bend_top3.connect(\"o1\", straight_top2.ports[\"o2\"])\n",
"    straight_top3.connect(\"o1\", bend_top3.ports[\"o2\"])\n",
"    bend_top4.connect(\"o1\", straight_top3.ports[\"o2\"])\n",
"\n",
"    # the bottom arm runs from mmi_in o3 and positions mmi_out\n",
"    straight_btm1.connect(\"o1\", mmi_in.ports[\"o3\"])\n",
"    bend_btm1.connect(\"o1\", straight_btm1.ports[\"o2\"])\n",
"    bend_btm2.connect(\"o1\", bend_btm1.ports[\"o2\"])\n",
"    bend_btm3.connect(\"o1\", bend_btm2.ports[\"o2\"])\n",
"    bend_btm4.connect(\"o1\", bend_btm3.ports[\"o2\"])\n",
"    straight_btm2.connect(\"o1\", bend_btm4.ports[\"o2\"])\n",
"\n",
"    # NOTE(review): bend_top4 is never connected to mmi_out explicitly; presumably it\n",
"    # lands on mmi_out o2 and the netlist infers the link from port positions - confirm\n",
"    mmi_out.connect(\"o1\", straight_btm2.ports[\"o2\"])\n",
"\n",
"    # ports\n",
"    c.add_port(\"o1\", port=mmi_in.ports[\"o1\"])\n",
"    c.add_port(\"o2\", port=mmi_out.ports[\"o3\"])\n",
"    c.add_port(\"o3\", port=mmi_out.ports[\"o4\"])\n",
"    return c"
] }, { "cell_type": "code", "execution_count": null, "id": "61", "metadata": { "papermill": { "duration": 0.631483, "end_time": "2025-02-27T06:31:15.536726", "exception": false, "start_time": "2025-02-27T06:31:14.905243", "status": "completed" }, "tags": [] }, "outputs": [], "source": [
"compact_mzi1 = compact_mzi()\n",
"compact_mzi1"
] }, { "cell_type": "code", "execution_count": null, "id": "62", "metadata": { "papermill": { "duration": 1.723087, "end_time": "2025-02-27T06:31:17.560265", "exception": false, "start_time": "2025-02-27T06:31:15.837178", "status": "completed" }, "tags": [] }, "outputs": [], "source": [
"mzi3, _ = sax.circuit(compact_mzi1.get_netlist(recursive=True), models=models)\n",
"mzi3()"
] }, { "cell_type": "code", "execution_count": null, "id": "63", "metadata": { "papermill": { "duration": 2.611084, "end_time": "2025-02-27T06:31:20.585361", "exception": false, "start_time": "2025-02-27T06:31:17.974277", "status": "completed" }, "tags": [] }, "outputs": [], "source": [
"mzi_params = sax.get_settings(mzi3)\n",
"placements = compact_mzi1.get_netlist()[\"placements\"]\n",
"# one width per wafer map for every instance that has a `width` setting\n",
"width_params = {\n",
"    k: {\"width\": widths(xm, ym, wmaps, v[\"x\"], v[\"y\"])}\n",
"    for k, v in placements.items()\n",
"    if \"width\" in mzi_params[k]\n",
"}\n",
"\n",
"S0 = mzi3(wl=wl)\n",
"# wl[:, None] lets the wavelengths broadcast against the per-map widths\n",
"S = mzi3(\n",
"    wl=wl[:, None],\n",
"    **width_params,\n",
")\n",
"ps = plt.plot(wl * 1e3, abs(S[\"o1\", \"o2\"]) ** 2, color=\"C0\", lw=1, alpha=0.1)\n",
"nps = plt.plot(wl * 1e3, abs(S0[\"o1\", \"o2\"]) ** 2, color=\"C1\", lw=2, alpha=1)\n",
"plt.xlabel(\"λ [nm]\")\n",
"plt.ylabel(\"T\")\n",
"plt.plot([1550, 1550], [-1, 2], color=\"black\", ls=\":\")\n",
"plt.ylim(-0.05, 1.05)\n",
"plt.grid(True)\n",
"plt.figlegend([*ps[-1:], *nps], [\"MC\", \"nominal\"], bbox_to_anchor=(1.1, 0.9))\n",
"# root-mean-square deviation of the MC transmission from the nominal one\n",
"rmse = jnp.sqrt(\n",
"    jnp.mean((jnp.abs(S[\"o1\", \"o2\"]) ** 2 - jnp.abs(S0[\"o1\", \"o2\"][:, None]) ** 2) ** 2)\n",
")\n",
"plt.title(f\"RMSE = {float(rmse):.3e}\")\n",
"plt.show()"
] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.6" } }, "nbformat": 4, "nbformat_minor": 5 }