{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0",
   "metadata": {
    "papermill": {
     "duration": 0.005885,
     "end_time": "2024-06-17T18:16:32.123653",
     "exception": false,
     "start_time": "2024-06-17T18:16:32.117768",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "# Multimode simulations\n",
    "> SAX can handle multiple modes too!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1",
   "metadata": {
    "papermill": {
     "duration": 1.621421,
     "end_time": "2024-06-17T18:16:33.750679",
     "exception": false,
     "start_time": "2024-06-17T18:16:32.129258",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "from itertools import combinations_with_replacement, product\n",
    "\n",
    "import jax.numpy as jnp\n",
    "import sax"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2",
   "metadata": {
    "papermill": {
     "duration": 0.005148,
     "end_time": "2024-06-17T18:16:33.761216",
     "exception": false,
     "start_time": "2024-06-17T18:16:33.756068",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "## Ports and modes per port"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3",
   "metadata": {
    "papermill": {
     "duration": 0.005268,
     "end_time": "2024-06-17T18:16:33.771516",
     "exception": false,
     "start_time": "2024-06-17T18:16:33.766248",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "Let's denote a combination of a port and a mode by a string of the format `\"{port}@{mode}\"`. We can obtain all unordered pairs of port-mode combinations with `itertools.product` and `itertools.combinations_with_replacement`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4",
   "metadata": {
    "papermill": {
     "duration": 0.023678,
     "end_time": "2024-06-17T18:16:33.800511",
     "exception": false,
     "start_time": "2024-06-17T18:16:33.776833",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "ports = [\"in0\", \"out0\"]\n",
    "modes = [\"TE\", \"TM\"]\n",
    "portmodes = [\n",
    "    (f\"{p1}@{m1}\", f\"{p2}@{m2}\")\n",
    "    for (p1, m1), (p2, m2) in combinations_with_replacement(product(ports, modes), 2)\n",
    "]\n",
    "portmodes"
   ]
  },
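  {
   "cell_type": "markdown",
   "id": "4a",
   "metadata": {},
   "source": [
    "`combinations_with_replacement` yields each *unordered* pair once, including every port-mode paired with itself. That is enough for a reciprocal device, where the S-parameter from `p2` to `p1` equals the one from `p1` to `p2`. With $n$ port-modes this gives $n(n+1)/2$ pairs instead of $n^2$ ordered ones. A quick check of that count, using a hypothetical helper `num_portmode_pairs`:\n",
    "\n",
    "```python\n",
    "from itertools import combinations_with_replacement, product\n",
    "\n",
    "\n",
    "def num_portmode_pairs(num_ports, num_modes):\n",
    "    # n port-modes give n * (n + 1) / 2 unordered pairs (self-pairs included)\n",
    "    n = num_ports * num_modes\n",
    "    return n * (n + 1) // 2\n",
    "\n",
    "\n",
    "ports = [\"in0\", \"out0\"]\n",
    "modes = [\"TE\", \"TM\"]\n",
    "pairs = list(combinations_with_replacement(product(ports, modes), 2))\n",
    "print(len(pairs), num_portmode_pairs(len(ports), len(modes)))  # 10 10\n",
    "```\n",
    "\n",
    "For a four-port device with two modes (8 port-modes) the same formula already gives 36 pairs."
   ]
  },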
  {
   "cell_type": "markdown",
   "id": "5",
   "metadata": {
    "papermill": {
     "duration": 0.006864,
     "end_time": "2024-06-17T18:16:33.813076",
     "exception": false,
     "start_time": "2024-06-17T18:16:33.806212",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "If we disregard backreflection, i.e. drop every pair whose two port-modes sit on the same port, this list can be simplified further:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6",
   "metadata": {
    "papermill": {
     "duration": 0.01988,
     "end_time": "2024-06-17T18:16:33.838439",
     "exception": false,
     "start_time": "2024-06-17T18:16:33.818559",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "portmodes_without_backreflection = [\n",
    "    (p1, p2) for p1, p2 in portmodes if p1.split(\"@\")[0] != p2.split(\"@\")[0]\n",
    "]\n",
    "portmodes_without_backreflection"
   ]
  },
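  {
   "cell_type": "markdown",
   "id": "6a",
   "metadata": {},
   "source": [
    "The filter above compares only the port part of each string, so besides plain reflections such as `in0@TE` to `in0@TE` it also drops same-port mode conversion such as `in0@TE` to `in0@TM`. A sketch of the same filter with a small parsing helper (`split_portmode` is hypothetical, not a SAX function):\n",
    "\n",
    "```python\n",
    "from itertools import combinations_with_replacement, product\n",
    "\n",
    "\n",
    "def split_portmode(portmode):\n",
    "    # \"in0@TE\" -> (\"in0\", \"TE\"); rpartition splits on the last \"@\" only\n",
    "    port, _, mode = portmode.rpartition(\"@\")\n",
    "    return port, mode\n",
    "\n",
    "\n",
    "ports = [\"in0\", \"out0\"]\n",
    "modes = [\"TE\", \"TM\"]\n",
    "portmodes = [\n",
    "    (f\"{p1}@{m1}\", f\"{p2}@{m2}\")\n",
    "    for (p1, m1), (p2, m2) in combinations_with_replacement(product(ports, modes), 2)\n",
    "]\n",
    "# keep only pairs whose two port-modes sit on different ports\n",
    "no_backreflection = [\n",
    "    (a, b) for a, b in portmodes if split_portmode(a)[0] != split_portmode(b)[0]\n",
    "]\n",
    "print(no_backreflection)\n",
    "# [('in0@TE', 'out0@TE'), ('in0@TE', 'out0@TM'), ('in0@TM', 'out0@TE'), ('in0@TM', 'out0@TM')]\n",
    "```"
   ]
  },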
  {
   "cell_type": "markdown",
   "id": "7",
   "metadata": {
    "papermill": {
     "duration": 0.055422,
     "end_time": "2024-06-17T18:16:33.899178",
     "exception": false,
     "start_time": "2024-06-17T18:16:33.843756",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "Sometimes cross-polarization terms, i.e. pairs whose two port-modes use different modes, can also be ignored:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8",
   "metadata": {
    "papermill": {
     "duration": 0.02547,
     "end_time": "2024-06-17T18:16:33.930968",
     "exception": false,
     "start_time": "2024-06-17T18:16:33.905498",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "portmodes_without_crosspolarization = [\n",
    "    (p1, p2) for p1, p2 in portmodes if p1.split(\"@\")[1] == p2.split(\"@\")[1]\n",
    "]\n",
    "portmodes_without_crosspolarization"
   ]
  },
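  {
   "cell_type": "markdown",
   "id": "8a",
   "metadata": {},
   "source": [
    "Both simplifications can be combined. Below is a sketch of a hypothetical helper `filter_portmodes` with one flag per simplification; with both switched off, only the straight-through, polarization-preserving terms remain:\n",
    "\n",
    "```python\n",
    "from itertools import combinations_with_replacement, product\n",
    "\n",
    "\n",
    "def filter_portmodes(portmodes, backreflection=True, crosspolarization=True):\n",
    "    # drop reflection terms and/or mode-converting terms from a list of pairs\n",
    "    kept = []\n",
    "    for p1, p2 in portmodes:\n",
    "        port1, _, mode1 = p1.rpartition(\"@\")\n",
    "        port2, _, mode2 = p2.rpartition(\"@\")\n",
    "        if not backreflection and port1 == port2:\n",
    "            continue  # both port-modes on the same port: a reflection term\n",
    "        if not crosspolarization and mode1 != mode2:\n",
    "            continue  # different modes: a cross-polarization term\n",
    "        kept.append((p1, p2))\n",
    "    return kept\n",
    "\n",
    "\n",
    "ports = [\"in0\", \"out0\"]\n",
    "modes = [\"TE\", \"TM\"]\n",
    "portmodes = [\n",
    "    (f\"{p1}@{m1}\", f\"{p2}@{m2}\")\n",
    "    for (p1, m1), (p2, m2) in combinations_with_replacement(product(ports, modes), 2)\n",
    "]\n",
    "print(filter_portmodes(portmodes, backreflection=False, crosspolarization=False))\n",
    "# [('in0@TE', 'out0@TE'), ('in0@TM', 'out0@TM')]\n",
    "```"
   ]
  },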
  {
   "cell_type": "markdown",
   "id": "9",
   "metadata": {
    "papermill": {
     "duration": 0.006752,
     "end_time": "2024-06-17T18:16:33.943557",
     "exception": false,
     "start_time": "2024-06-17T18:16:33.936805",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "## Multimode waveguide"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "10",
   "metadata": {
    "papermill": {
     "duration": 0.005521,
     "end_time": "2024-06-17T18:16:33.954804",
     "exception": false,
     "start_time": "2024-06-17T18:16:33.949283",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "Let's create a waveguide with two ports (`\"in0\"`, `\"out0\"`) and two modes (`\"TE\"`, `\"TM\"`) without backreflection. Let's assume there is 5% cross-polarization and that the `\"TM\"`->`\"TM\"` transmission is 10 percentage points lower than the `\"TE\"`->`\"TE\"` transmission, both applied to the field amplitude. In more realistic waveguide models these percentages would be length-dependent, but this is just a dummy model serving as an example."
   ]
  },
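  {
   "cell_type": "markdown",
   "id": "10a",
   "metadata": {},
   "source": [
    "In the model below these percentages multiply the field amplitude, so the corresponding power fractions are their squares. A quick check, leaving out the shared propagation factor `transmission` (which has unit magnitude when `loss=0`):\n",
    "\n",
    "```python\n",
    "# amplitude factors used in the waveguide model below\n",
    "te_to_te, te_to_tm = 0.95, 0.05\n",
    "tm_to_tm, tm_to_te = 0.85, 0.05\n",
    "\n",
    "# power fractions are squared amplitudes\n",
    "print(round(te_to_tm**2, 4))  # 0.0025 -> 0.25% of the power changes polarization\n",
    "print(round(te_to_te**2 + te_to_tm**2, 4))  # 0.905 -> total TE-input power reaching out0\n",
    "print(round(tm_to_tm**2 + tm_to_te**2, 4))  # 0.725 -> total TM-input power reaching out0\n",
    "```"
   ]
  },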
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11",
   "metadata": {
    "papermill": {
     "duration": 0.10532,
     "end_time": "2024-06-17T18:16:34.065733",
     "exception": false,
     "start_time": "2024-06-17T18:16:33.960413",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "def waveguide(wl=1.55, wl0=1.55, neff=2.34, ng=3.4, length=10.0, loss=0.0):\n",
    "    \"\"\"A simple straight waveguide model.\n",
    "\n",
    "    Args:\n",
    "        wl: wavelength\n",
    "        wl0: center wavelength at which neff is defined\n",
    "        neff: waveguide effective index at wl0\n",
    "        ng: waveguide group index (used for linear neff dispersion)\n",
    "        length: waveguide length (same unit as wl)\n",
    "        loss: waveguide loss [dB per unit length]\n",
    "    \"\"\"\n",
    "    dwl = wl - wl0\n",
    "    dneff_dwl = (ng - neff) / wl0  # equals -dneff/dwl, from ng = neff - wl * dneff/dwl\n",
    "    neff = neff - dwl * dneff_dwl\n",
    "    phase = 2 * jnp.pi * neff * length / wl\n",
    "    transmission = 10 ** (-loss * length / 20) * jnp.exp(1j * phase)\n",
    "    sdict = sax.reciprocal(\n",
    "        {\n",
    "            (\"in0@TE\", \"out0@TE\"): 0.95 * transmission,  # 5% lost to cross-polarization\n",
    "            (\"in0@TE\", \"out0@TM\"): 0.05 * transmission,  # 5% cross-polarization\n",
    "            (\"in0@TM\", \"out0@TM\"): 0.85 * transmission,  # 10% worse tm->tm than te->te\n",
    "            (\"in0@TM\", \"out0@TE\"): 0.05 * transmission,  # 5% cross-polarization\n",
    "        }\n",
    "    )\n",
    "    return sdict\n",
    "\n",
    "\n",
    "waveguide()"
   ]
  },
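  {
   "cell_type": "markdown",
   "id": "11a",
   "metadata": {},
   "source": [
    "`sax.reciprocal` lets the model above list each transmission only once. A minimal plain-Python sketch of the idea, assuming the function adds the reversed key `(p2, p1)` with the same value for every `(p1, p2)` entry (`mirror_reciprocal` is hypothetical, not SAX's implementation):\n",
    "\n",
    "```python\n",
    "def mirror_reciprocal(sdict):\n",
    "    # add (p2, p1) -> value for every (p1, p2) -> value entry\n",
    "    mirrored = {(p2, p1): value for (p1, p2), value in sdict.items()}\n",
    "    # entries given explicitly take precedence over mirrored ones\n",
    "    return {**mirrored, **sdict}\n",
    "\n",
    "\n",
    "forward = {\n",
    "    (\"in0@TE\", \"out0@TE\"): 0.95,\n",
    "    (\"in0@TE\", \"out0@TM\"): 0.05,\n",
    "    (\"in0@TM\", \"out0@TM\"): 0.85,\n",
    "    (\"in0@TM\", \"out0@TE\"): 0.05,\n",
    "}\n",
    "both_ways = mirror_reciprocal(forward)\n",
    "print(len(both_ways))  # 8\n",
    "print(both_ways[(\"out0@TM\", \"in0@TE\")])  # 0.05\n",
    "```"
   ]
  },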
  {
   "cell_type": "markdown",
   "id": "12",
   "metadata": {
    "papermill": {
     "duration": 0.005208,
     "end_time": "2024-06-17T18:16:34.076402",
     "exception": false,
     "start_time": "2024-06-17T18:16:34.071194",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "## Multimode coupler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13",
   "metadata": {
    "papermill": {
     "duration": 0.016482,
     "end_time": "2024-06-17T18:16:34.098324",
     "exception": false,
     "start_time": "2024-06-17T18:16:34.081842",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "def coupler():\n",
    "    return {\n",
    "        (\"in0@TE\", \"out0@TE\"): 0.45**0.5,\n",
    "        (\"in0@TE\", \"out1@TE\"): 1j * 0.45**0.5,\n",
    "        (\"in1@TE\", \"out0@TE\"): 1j * 0.45**0.5,\n",
    "        (\"in1@TE\", \"out1@TE\"): 0.45**0.5,\n",
    "        (\"in0@TM\", \"out0@TM\"): 0.45**0.5,\n",
    "        (\"in0@TM\", \"out1@TM\"): 1j * 0.45**0.5,\n",
    "        (\"in1@TM\", \"out0@TM\"): 1j * 0.45**0.5,\n",
    "        (\"in1@TM\", \"out1@TM\"): 0.45**0.5,\n",
    "        (\"in0@TE\", \"out0@TM\"): 0.01**0.5,\n",
    "        (\"in0@TE\", \"out1@TM\"): 1j * 0.01**0.5,\n",
    "        (\"in1@TE\", \"out0@TM\"): 1j * 0.01**0.5,\n",
    "        (\"in1@TE\", \"out1@TM\"): 0.01**0.5,\n",
    "        (\"in0@TM\", \"out0@TE\"): 0.01**0.5,\n",
    "        (\"in0@TM\", \"out1@TE\"): 1j * 0.01**0.5,\n",
    "        (\"in1@TM\", \"out0@TE\"): 1j * 0.01**0.5,\n",
    "        (\"in1@TM\", \"out1@TE\"): 0.01**0.5,\n",
    "    }"
   ]
  },
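  {
   "cell_type": "markdown",
   "id": "13a",
   "metadata": {},
   "source": [
    "This dummy coupler does not conserve power: each input port-mode sends 45% of its power to each output port in the same mode and 1% to each output port in the other mode, 92% in total. A sketch that rebuilds the 16 entries above in a loop and sums the outgoing power per input (`coupler_sdict` and `output_power` are hypothetical helpers):\n",
    "\n",
    "```python\n",
    "from itertools import product\n",
    "\n",
    "\n",
    "def coupler_sdict():\n",
    "    # same 16 entries as the coupler model above\n",
    "    sdict = {}\n",
    "    for (i, m_in), (j, m_out) in product(product((0, 1), (\"TE\", \"TM\")), repeat=2):\n",
    "        power = 0.45 if m_in == m_out else 0.01  # same-mode vs. mode-converted power\n",
    "        phase = 1 if i == j else 1j  # cross-port paths pick up a 90 degree phase\n",
    "        sdict[(f\"in{i}@{m_in}\", f\"out{j}@{m_out}\")] = phase * power**0.5\n",
    "    return sdict\n",
    "\n",
    "\n",
    "def output_power(sdict, source):\n",
    "    # total power leaving the device for unit power entering at `source`\n",
    "    return sum(abs(v) ** 2 for (p_in, _), v in sdict.items() if p_in == source)\n",
    "\n",
    "\n",
    "print(round(output_power(coupler_sdict(), \"in0@TE\"), 4))  # 0.92\n",
    "```"
   ]
  },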
  {
   "cell_type": "markdown",
   "id": "14",
   "metadata": {
    "papermill": {
     "duration": 0.00622,
     "end_time": "2024-06-17T18:16:34.110226",
     "exception": false,
     "start_time": "2024-06-17T18:16:34.104006",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "## Multimode MZI"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "15",
   "metadata": {
    "papermill": {
     "duration": 0.005893,
     "end_time": "2024-06-17T18:16:34.122417",
     "exception": false,
     "start_time": "2024-06-17T18:16:34.116524",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "We can now combine these models into a circuit in much the same way as before. Note that no `modes=` keyword is passed to `sax.circuit`; the modes follow from the `port@mode` names used by the models:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "16",
   "metadata": {
    "papermill": {
     "duration": 1.178067,
     "end_time": "2024-06-17T18:16:35.306342",
     "exception": false,
     "start_time": "2024-06-17T18:16:34.128275",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "mzi, _ = sax.circuit(\n",
    "    netlist={\n",
    "        \"instances\": {\n",
    "            \"lft\": \"coupler\",  # already multimode; a single-mode model would be auto-converted without cross-polarization\n",
    "            \"top\": {\"component\": \"straight\", \"settings\": {\"length\": 25.0}},\n",
    "            \"btm\": {\"component\": \"straight\", \"settings\": {\"length\": 15.0}},\n",
    "            \"rgt\": \"coupler\",  # already multimode; a single-mode model would be auto-converted without cross-polarization\n",
    "        },\n",
    "        \"connections\": {\n",
    "            \"lft,out0\": \"btm,in0\",\n",
    "            \"btm,out0\": \"rgt,in0\",\n",
    "            \"lft,out1\": \"top,in0\",\n",
    "            \"top,out0\": \"rgt,in1\",\n",
    "        },\n",
    "        \"ports\": {\n",
    "            \"in0\": \"lft,in0\",\n",
    "            \"in1\": \"lft,in1\",\n",
    "            \"out0\": \"rgt,out0\",\n",
    "            \"out1\": \"rgt,out1\",\n",
    "        },\n",
    "    },\n",
    "    models={\n",
    "        \"coupler\": coupler,\n",
    "        \"straight\": waveguide,\n",
    "    },\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17",
   "metadata": {
    "papermill": {
     "duration": 0.612689,
     "end_time": "2024-06-17T18:16:35.925392",
     "exception": false,
     "start_time": "2024-06-17T18:16:35.312703",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "mzi()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "18",
   "metadata": {
    "papermill": {
     "duration": 0.006007,
     "end_time": "2024-06-17T18:16:35.937349",
     "exception": false,
     "start_time": "2024-06-17T18:16:35.931342",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "We can convert this multimode model to a single-mode model that keeps only the `\"TE\"` terms:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19",
   "metadata": {
    "papermill": {
     "duration": 0.095986,
     "end_time": "2024-06-17T18:16:36.039466",
     "exception": false,
     "start_time": "2024-06-17T18:16:35.943480",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "mzi_te = sax.singlemode(mzi, mode=\"TE\")\n",
    "mzi_te()"
   ]
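  },
  {
   "cell_type": "markdown",
   "id": "19a",
   "metadata": {},
   "source": [
    "Conceptually, going to a single mode keeps only the entries in which both port-modes use the requested mode and drops the `@mode` suffix. A plain-dict sketch of that idea (`to_singlemode` is hypothetical; `sax.singlemode` is applied to the model function `mzi` above rather than to a dict, and its internals may differ):\n",
    "\n",
    "```python\n",
    "def to_singlemode(sdict, mode=\"TE\"):\n",
    "    # keep entries whose two port-modes both end in \"@{mode}\", then strip that suffix\n",
    "    suffix = f\"@{mode}\"\n",
    "    return {\n",
    "        (p1.removesuffix(suffix), p2.removesuffix(suffix)): value\n",
    "        for (p1, p2), value in sdict.items()\n",
    "        if p1.endswith(suffix) and p2.endswith(suffix)\n",
    "    }\n",
    "\n",
    "\n",
    "multimode = {\n",
    "    (\"in0@TE\", \"out0@TE\"): 0.9,\n",
    "    (\"in0@TE\", \"out0@TM\"): 0.1,\n",
    "    (\"in0@TM\", \"out0@TM\"): 0.8,\n",
    "}\n",
    "print(to_singlemode(multimode, mode=\"TE\"))  # {('in0', 'out0'): 0.9}\n",
    "```"
   ]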
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}