{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "```@meta\n",
    "Draft = false\n",
    "```\n",
    "# Multi-threaded assembly"
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Introduction\n",
    "\n",
    "In this howto we will explore how to use task based multithreading (shared memory\n",
    "parallelism) to speed up the analysis. Some parts of a finite element simulation are\n",
    "trivially parallelizable such as the computation of the local element contributions since\n",
    "each element can be processed independently. However, two things need to be considered in\n",
    "order to parallelize safely:\n",
    "\n",
    " - **Modification of shared data**: Although the contributions from all the elements can\n",
    "   be computed independently, eventually they need to be assembled into the global\n",
    "   matrix and vector. Letting each task assemble their own contribution would lead to\n",
    "   race conditions since elements share degrees of freedom with each other. There are\n",
    "   various ways to remedy this, for example:\n",
    "    - **Locking**: By using a lock around the call to `assemble!` we can ensure that only\n",
    "      one task assembles at a time. This is simple to implement but can lead to lock\n",
    "      contention and thus poor performance. Another drawback is that the results will not\n",
    "      be deterministic since floating point operations are neither associative nor\n",
    "      commutative.\n",
    "    - **Assembler task**: By using a designated task for the assembling we (obviously)\n",
    "      ensure that only a single task assembles. The worker tasks (the tasks computing the\n",
    "      element contributions) would then hand off their results to the assemly task. This\n",
    "      can be a useful approach if computing the element contributions is much slower than\n",
    "      the assembly -- otherwise the assembler task can't keep up with the worker tasks.\n",
    "      There might also be some extra overhead because of task switching in the scheduler.\n",
    "      The problem with non-deterministic results still remains.\n",
    "    - **Grid coloring**: By \"coloring\" the grid such that, within each color, no two\n",
    "      elements share degrees of freedom, we can safely assemble each color in parallel.\n",
    "      Even if concurrently running tasks will write to the global matrix and vector they\n",
    "      will not write to the same memory locations. Note also that this procedure gives\n",
    "      predictable results because for a memory location which, for example, a \"red\",\n",
    "      a \"blue\", and a \"green\" element will contribute to we will always add the red first,\n",
    "      then the blue, and finally the green.\n",
    " - **Scratch data**: In order to speed up the computation of the element contributions we\n",
    "   typically pre-allocate some data structures that can be reused for every element. Such\n",
    "   scratch data include, for example, the local matrix and vector, and the CellValues.\n",
    "   Each task need their own copy of the scratch data since they will be modified for each\n",
    "   element."
   ],
   "metadata": {}
  },
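  {
   "cell_type": "markdown",
   "source": [
    "For reference, a minimal sketch of the locking approach mentioned above could look like\n",
    "the snippet below (this is *not* the approach used in the rest of this howto, and the\n",
    "function name `assemble_global_locked!` is hypothetical). It reuses the `assemble_cell!`\n",
    "and `create_material_stiffness` helpers and the OhMyThreads macros that are introduced\n",
    "further down, and serializes only the call to `assemble!` with a `ReentrantLock` while\n",
    "the element computations run in parallel. Note that no grid coloring is needed here --\n",
    "the lock alone prevents concurrent writes, at the cost of contention and\n",
    "non-deterministic results.\n",
    "\n",
    "```julia\n",
    "# Hypothetical locking variant (sketch only): `assemble_cell!` and\n",
    "# `create_material_stiffness` are defined later in this howto.\n",
    "function assemble_global_locked!(K, f, dh, grid, cellvalues)\n",
    "    assembler = start_assemble(K, f)\n",
    "    lck = ReentrantLock()\n",
    "    b = Vec{3}((0.0, 0.0, -1.0))\n",
    "    C = create_material_stiffness()\n",
    "    OhMyThreads.@tasks for cellidx in 1:getncells(grid)\n",
    "        # Each task needs its own buffers for the element computation\n",
    "        @local begin\n",
    "            cell_cache = CellCache(dh)\n",
    "            cv = copy(cellvalues)\n",
    "            Ke = zeros(ndofs_per_cell(dh), ndofs_per_cell(dh))\n",
    "            fe = zeros(ndofs_per_cell(dh))\n",
    "        end\n",
    "        reinit!(cell_cache, cellidx)\n",
    "        reinit!(cv, cell_cache)\n",
    "        assemble_cell!(Ke, fe, cv, C, b)\n",
    "        # Only the assembly into the global K and f is serialized by the lock\n",
    "        lock(lck) do\n",
    "            assemble!(assembler, celldofs(cell_cache), Ke, fe)\n",
    "        end\n",
    "    end\n",
    "    return K, f\n",
    "end\n",
    "```"
   ],
   "metadata": {}
  },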
  {
   "cell_type": "markdown",
   "source": [
    "## Grid coloring\n",
    "\n",
    "Ferrite include functionality to color the grid with the `create_coloring`\n",
    "function. Here we create a simple 2D grid, color it, and export the colors to a VTK file\n",
    "to visualize the result (see *Figure 1*.). Note that no cells with the same color has any\n",
    "shared nodes (dofs). This means that it is safe to assemble in parallel as long as we only\n",
    "assemble one color at a time.\n",
    "\n",
    "There are two coloring algorithms implemented: the \"workstream\" algorithm (from Turcksin\n",
    "et al. [Turcksin2016](@cite)) and a \"greedy\" algorithm. For this structured grid the\n",
    "greedy algorithm uses fewer colors, but both algorithms result in colors that contain\n",
    "roughly the same number of elements. The workstream algorithm is the default one since it\n",
    "in general results in more balanced colors. For unstructured grids the greedy algorithm\n",
    "can result in colors with very few elements, for example."
   ],
   "metadata": {}
  },
  {
   "outputs": [],
   "cell_type": "code",
   "source": [
    "using Ferrite, SparseArrays\n",
    "\n",
    "function create_example_2d_grid()\n",
    "    grid = generate_grid(Quadrilateral, (10, 10), Vec{2}((0.0, 0.0)), Vec{2}((10.0, 10.0)))\n",
    "    colors_workstream = create_coloring(grid; alg = ColoringAlgorithm.WorkStream)\n",
    "    colors_greedy = create_coloring(grid; alg = ColoringAlgorithm.Greedy)\n",
    "    VTKGridFile(\"colored\", grid) do vtk\n",
    "        Ferrite.write_cell_colors(vtk, grid, colors_workstream, \"workstream-coloring\")\n",
    "        Ferrite.write_cell_colors(vtk, grid, colors_greedy, \"greedy-coloring\")\n",
    "    end\n",
    "    return\n",
    "end\n",
    "\n",
    "create_example_2d_grid()"
   ],
   "metadata": {},
   "execution_count": 1
  },
  {
   "cell_type": "markdown",
   "source": [
    "![](coloring.png)\n",
    "\n",
    "*Figure 1*: Element coloring using the \"workstream\"-algorithm (left) and the \"greedy\"-\n",
    "algorithm (right)."
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Multithreaded assembly of a cantilever beam in 3D\n",
    "\n",
    "We will now look at an example where we assemble the stiffness matrix and right hand side\n",
    "using multiple threads. The problem setup is a cantilever beam in 3D with a linear elastic\n",
    "material behavior. For this exercise we only focus on the multithreading and are not\n",
    "bothered with boundary conditions. For more details refer to the [tutorial on linear\n",
    "elasticity](../tutorials/linear_elasticity.md)."
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "### Setup\n",
    "\n",
    "We define the element routine, material stiffness, grid and DofHandler just like in the\n",
    "[tutorial on linear elasticity](../tutorials/linear_elasticity.md) without discussing it\n",
    "further here."
   ],
   "metadata": {}
  },
  {
   "outputs": [],
   "cell_type": "code",
   "source": [
    "# Element routine\n",
    "function assemble_cell!(Ke::Matrix, fe::Vector, cellvalues::CellValues, C::SymmetricTensor, b::Vec)\n",
    "    fill!(Ke, 0)\n",
    "    fill!(fe, 0)\n",
    "    for q_point in 1:getnquadpoints(cellvalues)\n",
    "        dΩ = getdetJdV(cellvalues, q_point)\n",
    "        for i in 1:getnbasefunctions(cellvalues)\n",
    "            δui = shape_value(cellvalues, q_point, i)\n",
    "            fe[i] += (δui ⋅ b) * dΩ\n",
    "            ∇δui = shape_symmetric_gradient(cellvalues, q_point, i)\n",
    "            for j in 1:getnbasefunctions(cellvalues)\n",
    "                ∇uj = shape_symmetric_gradient(cellvalues, q_point, j)\n",
    "                Ke[i, j] += (∇δui ⊡ C ⊡ ∇uj) * dΩ\n",
    "            end\n",
    "        end\n",
    "    end\n",
    "    return Ke, fe\n",
    "end\n",
    "\n",
    "# Material stiffness\n",
    "function create_material_stiffness()\n",
    "    E = 200.0e9\n",
    "    ν = 0.3\n",
    "    λ = E * ν / ((1 + ν) * (1 - 2ν))\n",
    "    μ = E / (2(1 + ν))\n",
    "    δ(i, j) = i == j ? 1.0 : 0.0\n",
    "    C = SymmetricTensor{4, 3}() do i, j, k, l\n",
    "        return λ * δ(i, j) * δ(k, l) + μ * (δ(i, k) * δ(j, l) + δ(i, l) * δ(j, k))\n",
    "    end\n",
    "    return C\n",
    "end\n",
    "\n",
    "# Grid and grid coloring\n",
    "function create_cantilever_grid(n::Int)\n",
    "    xmin = Vec{3}((0.0, 0.0, 0.0))\n",
    "    xmax = Vec{3}((10.0, 1.0, 1.0))\n",
    "    grid = generate_grid(Hexahedron, (10 * n, n, n), xmin, xmax)\n",
    "    colors = create_coloring(grid)\n",
    "    return grid, colors\n",
    "end\n",
    "\n",
    "# DofHandler with displacement field u\n",
    "function create_dofhandler(grid::Grid, interpolation::VectorInterpolation)\n",
    "    dh = DofHandler(grid)\n",
    "    add!(dh, :u, interpolation)\n",
    "    close!(dh)\n",
    "    return dh\n",
    "end\n",
    "nothing # hide"
   ],
   "metadata": {},
   "execution_count": 2
  },
  {
   "cell_type": "markdown",
   "source": [
    "### Task local scratch data\n",
    "\n",
    "We group everything that needs to be duplicated for each task in the struct\n",
    "`ScratchData`:\n",
    " - `cell_cache::CellCache`: contains buffers for coordinates and (global) dofs which will\n",
    "   be `reinit!`ed for each cell.\n",
    " - `cellvalues::CellValues`: the cell values which will be `reinit!`ed for each cell using\n",
    "   the `cell_cache`\n",
    " - `Ke::Matrix`: the local matrix\n",
    " - `fe::Vector`: the local vector\n",
    " - `assembler`: the assembler (which needs to be duplicated because it contains buffers\n",
    "   that are modified during the call to `assemble!`)"
   ],
   "metadata": {}
  },
  {
   "outputs": [],
   "cell_type": "code",
   "source": [
    "struct ScratchData{CC, CV, T, A}\n",
    "    cell_cache::CC\n",
    "    cellvalues::CV\n",
    "    Ke::Matrix{T}\n",
    "    fe::Vector{T}\n",
    "    assembler::A\n",
    "end"
   ],
   "metadata": {},
   "execution_count": 3
  },
  {
   "cell_type": "markdown",
   "source": [
    "This constructor will be called within each task to create a independent `ScratchData`\n",
    "object. For `cell_cache`, `Ke`, and `fe` we simply call the constructors to allocate\n",
    "independent objects. For `cellvalues` we use `copy` which Ferrite defines for this\n",
    "purpose. Finally, for the assembler we call `start_assemble` to create a new assembler but\n",
    "note that we set `fillzero = false` because we don't want to risk that a task that starts\n",
    "a bit later will zero out data that another task have already assembled."
   ],
   "metadata": {}
  },
  {
   "outputs": [],
   "cell_type": "code",
   "source": [
    "function ScratchData(dh::DofHandler, K::SparseMatrixCSC, f::Vector, cellvalues::CellValues)\n",
    "    cell_cache = CellCache(dh)\n",
    "    n = ndofs_per_cell(dh)\n",
    "    Ke = zeros(n, n)\n",
    "    fe = zeros(n)\n",
    "    asm = start_assemble(K, f; fillzero = false)\n",
    "    return ScratchData(cell_cache, copy(cellvalues), Ke, fe, asm)\n",
    "end\n",
    "nothing # hide"
   ],
   "metadata": {},
   "execution_count": 4
  },
  {
   "cell_type": "markdown",
   "source": [
    "### Global assembly routine"
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "Finally we define the global assemble routine, which is where the parallelization happens.\n",
    "The main difference from all previous `assemble_global!` functions is that we now have an\n",
    "outer loop over the colors, and then the inner loop over the cells in each color, which\n",
    "can be parallelized.\n",
    "\n",
    "For the scheduling of parallel tasks we use the\n",
    "[OhMyThreads.jl](https://github.com/JuliaFolds2/OhMyThreads.jl) package. OhMyThreads\n",
    "provides a macro based and a functional API. Here we use the macro based API because it is\n",
    "slightly more convenient when using task local values since they can be defined with the\n",
    "`@local` macro.\n",
    "\n",
    "> **Schedulers and load balancing**\n",
    ">\n",
    "> OhMyThreads provides a number of different\n",
    "> [schedulers](https://juliafolds2.github.io/OhMyThreads.jl/stable/refs/api/#Schedulers).\n",
    "> In this example we use the `DynamicScheduler` (which is the default one). The\n",
    "> `DynamicScheduler` will spawn `ntasks` tasks where each task will process a chunk of\n",
    "> (roughly) equal number of cells (i.e. `length(color) ÷ ntasks`). This should be a good\n",
    "> choice for this example because we expect all cells to take the same time to process\n",
    "> and we don't need any load balancing.\n",
    ">\n",
    "> For a different problem setup where some cells might take longer to process (perhaps\n",
    "> they experience plastic deformation and we need to solve a local problem) we might\n",
    "> benefit from load balancing. The `DynamicScheduler` can also be used for load\n",
    "> balancing by specifying `nchunks` or `chunksize`. However, the `DynamicScheduler`\n",
    "> will always spawn `nchunks` tasks, which can become costly since we are allocating\n",
    "> scratch data for every task. To limit the number of tasks, while allowing for more\n",
    "> than `ntasks` chunks, we can use the `GreedyScheduler` *with chunking*. For example,\n",
    "> `scheduler = OhMyThreads.GreedyScheduler(; ntasks = ntasks, nchunks = 10 * ntasks)`\n",
    "> will split the work into `10 * ntasks` chunks and spawn `ntasks` tasks to process\n",
    "> them. Refer to the [OhMyThreads\n",
    "> documentation](https://juliafolds2.github.io/OhMyThreads.jl/stable/) for details."
   ],
   "metadata": {}
  },
  {
   "outputs": [],
   "cell_type": "code",
   "source": [
    "using OhMyThreads, TaskLocalValues\n",
    "\n",
    "function assemble_global!(\n",
    "        K::SparseMatrixCSC, f::Vector, dh::DofHandler, colors,\n",
    "        cellvalues_template::CellValues; ntasks = Threads.nthreads()\n",
    "    )\n",
    "    # Zero-out existing data in K and f\n",
    "    _ = start_assemble(K, f)\n",
    "    # Body force and material stiffness\n",
    "    b = Vec{3}((0.0, 0.0, -1.0))\n",
    "    C = create_material_stiffness()\n",
    "    # Loop over the colors\n",
    "    for color in colors\n",
    "        # Dynamic scheduler spawning `ntasks` tasks where each task will process a chunk of\n",
    "        # (roughly) equal number of cells (`length(color) ÷ ntasks`).\n",
    "        scheduler = OhMyThreads.DynamicScheduler(; ntasks)\n",
    "        # Parallelize the loop over the cells in this color\n",
    "        OhMyThreads.@tasks for cellidx in color\n",
    "            # Tell the @tasks loop to use the scheduler defined above\n",
    "            @set scheduler = scheduler\n",
    "            # Obtain a task local scratch and unpack it\n",
    "            @local scratch = ScratchData(dh, K, f, cellvalues_template)\n",
    "            (; cell_cache, cellvalues, Ke, fe, assembler) = scratch\n",
    "            # Reinitialize the cell cache and then the cellvalues\n",
    "            reinit!(cell_cache, cellidx)\n",
    "            reinit!(cellvalues, cell_cache)\n",
    "            # Compute the local contribution of the cell\n",
    "            assemble_cell!(Ke, fe, cellvalues, C, b)\n",
    "            # Assemble local contribution\n",
    "            assemble!(assembler, celldofs(cell_cache), Ke, fe)\n",
    "        end\n",
    "    end\n",
    "    return K, f\n",
    "end\n",
    "nothing # hide"
   ],
   "metadata": {},
   "execution_count": 5
  },
  {
   "cell_type": "markdown",
   "source": [
    "> **OhMyThreads functional API: OhMyThreads.tforeach**\n",
    ">\n",
    "> The `OhMyThreads.@tasks` block above corresponds to a call to `OhMyThreads.tforeach`.\n",
    "> Using the functional API directly would look like the following. The main difference is\n",
    "> that we need to manually create a `TaskLocalValue` for the scratch data.\n",
    "> ```julia\n",
    "> # using TaskLocalValues\n",
    "> scratches = TaskLocalValue() do\n",
    ">     ScratchData(dh, K, f, cellvalues)\n",
    "> end\n",
    "> OhMyThreads.tforeach(color; scheduler) do cellidx\n",
    ">     # Obtain a task local scratch and unpack it\n",
    ">     scratch = scratches[]\n",
    ">     (; cell_cache, cellvalues, Ke, fe, assembler) = scratch\n",
    ">     # Reinitialize the cell cache and then the cellvalues\n",
    ">     reinit!(cell_cache, cellidx)\n",
    ">     reinit!(cellvalues, cell_cache)\n",
    ">     # Compute the local contribution of the cell\n",
    ">     assemble_cell!(Ke, fe, cellvalues, C, b)\n",
    ">     # Assemble local contribution\n",
    ">     assemble!(assembler, celldofs(cell_cache), Ke, fe)\n",
    "> end\n",
    "> ```"
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "We define the main function to setup everything and then time the call to\n",
    "`assemble_global!`."
   ],
   "metadata": {}
  },
  {
   "outputs": [],
   "cell_type": "code",
   "source": [
    "function main(; n = 20, ntasks = Threads.nthreads())\n",
    "    # Interpolation, quadrature and cellvalues\n",
    "    interpolation = Lagrange{RefHexahedron, 1}()^3\n",
    "    quadrature = QuadratureRule{RefHexahedron}(2)\n",
    "    cellvalues = CellValues(quadrature, interpolation)\n",
    "    # Grid, colors and DofHandler\n",
    "    grid, colors = create_cantilever_grid(n)\n",
    "    dh = create_dofhandler(grid, interpolation)\n",
    "    # Global matrix and vector\n",
    "    K = allocate_matrix(dh)\n",
    "    f = zeros(ndofs(dh))\n",
    "    # Compile it\n",
    "    assemble_global!(K, f, dh, colors, cellvalues; ntasks = ntasks)\n",
    "    # Time it\n",
    "    @time assemble_global!(K, f, dh, colors, cellvalues; ntasks = ntasks)\n",
    "    return\n",
    "end\n",
    "nothing # hide"
   ],
   "metadata": {},
   "execution_count": 6
  },
  {
   "cell_type": "markdown",
   "source": [
    "On a machine with 4 cores, starting julia with `--threads=auto`, we obtain the following\n",
    "timings:\n",
    "```julia\n",
    "main(; ntasks = 1) # 1.970784 seconds (902 allocations: 816.172 KiB)\n",
    "main(; ntasks = 2) # 1.025065 seconds (1.64 k allocations: 1.564 MiB)\n",
    "main(; ntasks = 3) # 0.700423 seconds (2.38 k allocations: 2.332 MiB)\n",
    "main(; ntasks = 4) # 0.548356 seconds (3.12 k allocations: 3.099 MiB)\n",
    "```"
   ],
   "metadata": {}
  },
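  {
   "cell_type": "markdown",
   "source": [
    "The speedup is of course limited by the number of threads available to the Julia session.\n",
    "As a quick sanity check (not part of the timed example above) the available thread count\n",
    "can be queried with:\n",
    "\n",
    "```julia\n",
    "# Number of threads the session was started with, e.g. via `--threads=auto`\n",
    "# or the JULIA_NUM_THREADS environment variable\n",
    "Threads.nthreads()\n",
    "```"
   ],
   "metadata": {}
  },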
  {
   "cell_type": "markdown",
   "source": [
    "---\n",
    "\n",
    "*This notebook was generated using [Literate.jl](https://github.com/fredrikekre/Literate.jl).*"
   ],
   "metadata": {}
  }
 ],
 "nbformat_minor": 3,
 "metadata": {
  "language_info": {
   "file_extension": ".jl",
   "mimetype": "application/julia",
   "name": "julia",
   "version": "1.11.3"
  },
  "kernelspec": {
   "name": "julia-1.11",
   "display_name": "Julia 1.11.3",
   "language": "julia"
  }
 },
 "nbformat": 4
}