{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "# Custom solvers\n",
    "In this example, we show how to define custom solvers for the SCF: a fixed-point\n",
    "solver, an eigensolver and a density-mixing scheme. Our system will again be\n",
    "silicon, because we are not very imaginative."
   ],
   "metadata": {}
  },
  {
   "outputs": [],
   "cell_type": "code",
   "source": [
    "using DFTK\n",
    "using LinearAlgebra\n",
    "using PseudoPotentialData\n",
    "using AtomsBuilder\n",
    "\n",
    "# We take very (very) crude parameters\n",
    "pseudopotentials = PseudoFamily(\"dojo.nc.sr.lda.v0_4_1.standard.upf\")\n",
    "model = model_DFT(bulk(:Si); functionals=LDA(), pseudopotentials)\n",
    "basis = PlaneWaveBasis(model; Ecut=5, kgrid=[1, 1, 1]);"
   ],
   "metadata": {},
   "execution_count": 1
  },
  {
   "cell_type": "markdown",
   "source": [
    "We define our custom fixed-point solver: simply a damped fixed-point iteration."
   ],
   "metadata": {}
  },
  {
   "outputs": [],
   "cell_type": "code",
   "source": [
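    "function my_fp_solver(f, x0, info0; maxiter)\n",
    "    mixing_factor = 0.7  # damping factor applied to each fixed-point step\n",
    "    x = x0\n",
    "    info = info0\n",
    "    for _ in 1:maxiter\n",
    "        fx, info = f(x, info)  # one application of the SCF fixed-point map\n",
    "        if info.converged || info.timedout\n",
    "            break  # termination is decided inside f via these flags\n",
    "        end\n",
    "        x = x + mixing_factor * (fx - x)  # damped update\n",
    "    end\n",
    "    (; fixpoint=x, info)\n",
    "end;"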
   ],
   "metadata": {},
   "execution_count": 2
  },
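  {
   "cell_type": "markdown",
   "source": [
    "To see why the update in `my_fp_solver` is damped, here is a self-contained toy sketch (plain Julia, not DFTK code). The scalar map `g(x) = -1.5x + 1` has slope $s = -1.5$, so the plain iteration `x ← g(x)` diverges because $|s| > 1$. The damped update `x ← x + α (g(x) - x)` multiplies the error by $1 - α(1 - s)$ per step, which is $-0.75$ for $α = 0.7$, so it converges. The map `g` and the helper `iterate_damped` are hypothetical names chosen for this illustration.\n",
    "\n",
    "```julia\n",
    "# Toy scalar map with slope s = -1.5 (hypothetical, for illustration only).\n",
    "# Its fixed point solves x = -1.5x + 1, i.e. x* = 1 / 2.5 = 0.4.\n",
    "g(x) = -1.5x + 1.0\n",
    "xstar = 0.4\n",
    "\n",
    "# Same update rule as in `my_fp_solver`: x <- x + α (g(x) - x).\n",
    "# α = 1 recovers the plain (undamped) fixed-point iteration.\n",
    "function iterate_damped(g, x0, α, nsteps)\n",
    "    x = x0\n",
    "    for _ in 1:nsteps\n",
    "        x = x + α * (g(x) - x)\n",
    "    end\n",
    "    x\n",
    "end\n",
    "\n",
    "x_plain  = iterate_damped(g, 0.0, 1.0, 50)  # error magnitude grows by 1.5x per step\n",
    "x_damped = iterate_damped(g, 0.0, 0.7, 50)  # error magnitude shrinks by 0.75x per step\n",
    "```\n",
    "\n",
    "In the SCF the scalar slope is replaced by the response of the density map, which is not known in advance. The damping factor is therefore a trade-off between speed and robustness, not a guaranteed fix."
   ],
   "metadata": {}
  },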
  {
   "cell_type": "markdown",
   "source": [
    "Note that the fixed-point map `f` takes and returns an auxiliary variable `info`\n",
    "used for state bookkeeping, so the solver has to pass it along on every call.\n",
    "The solver does not test convergence itself: early termination is signalled from\n",
    "inside `f` through the boolean flags `info.converged` and `info.timedout`.\n",
    "For control over these criteria, see the `is_converged` and `maxtime`\n",
    "keyword arguments of `self_consistent_field`."
   ],
   "metadata": {}
  },
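  {
   "cell_type": "markdown",
   "source": [
    "To make this contract concrete, here is a toy map with the same `(x, info) -> (fx, info)` shape that `my_fp_solver` expects. It is a hypothetical stand-in, not the map DFTK constructs. Its `info` only carries the two flags the solver reads, plus an iteration counter `n_iter` added for illustration. The map itself is `x ↦ cos.(x)`, whose fixed point is about 0.739.\n",
    "\n",
    "```julia\n",
    "using LinearAlgebra\n",
    "\n",
    "# Hypothetical toy fixed-point map following the (x, info) -> (fx, info) contract.\n",
    "# The `info` objects used by DFTK's SCF contain additional fields.\n",
    "function toy_map(x, info)\n",
    "    fx = cos.(x)\n",
    "    n_iter = info.n_iter + 1\n",
    "    converged = norm(fx - x) < 1e-8  # termination flag read by the solver\n",
    "    timedout = n_iter >= 500         # termination flag read by the solver\n",
    "    fx, merge(info, (; n_iter, converged, timedout))\n",
    "end\n",
    "\n",
    "info0 = (; n_iter=0, converged=false, timedout=false)\n",
    "```\n",
    "\n",
    "With `my_fp_solver` from above in scope, `my_fp_solver(toy_map, [1.0], info0; maxiter=100)` should stop well within `maxiter` with a `fixpoint` close to `[0.739]`. The convergence decision is made entirely inside `toy_map`."
   ],
   "metadata": {}
  },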
  {
   "cell_type": "markdown",
   "source": [
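    "Our eigenvalue solver just forms the dense matrix, diagonalizes it explicitly and\n",
    "returns the lowest `n` eigenpairs, where `n` is the number of columns of the initial\n",
    "guess `X0`. Since dense diagonalization scales cubically with the basis size, this\n",
    "is only feasible for very small systems."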
   ],
   "metadata": {}
  },
  {
   "outputs": [],
   "cell_type": "code",
   "source": [
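    "function my_eig_solver(A, X0; maxiter, tol, kwargs...)\n",
    "    n = size(X0, 2)          # number of requested eigenpairs\n",
    "    A = Array(A)             # form the dense Hamiltonian matrix\n",
    "    E = eigen(Hermitian(A))  # real eigenvalues, sorted in ascending order\n",
    "    λ = E.values[1:n]\n",
    "    X = E.vectors[:, 1:n]\n",
    "    # Direct diagonalization: no iterations or residual norms to report\n",
    "    (; λ, X, residual_norms=[], n_iter=0, converged=true, n_matvec=0)\n",
    "end;"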
   ],
   "metadata": {},
   "execution_count": 3
  },
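  {
   "cell_type": "markdown",
   "source": [
    "The slicing `E.values[1:n]` relies on `eigen` returning eigenvalues in ascending order. For a matrix wrapped in `Hermitian`, LinearAlgebra returns real eigenvalues sorted in ascending order together with orthonormal eigenvectors, so the first `n` columns are the lowest `n` eigenpairs. The sketch below checks this on a small random Hermitian matrix, a made-up stand-in for a Hamiltonian block. It also shows how residual norms could be computed if desired; `my_eig_solver` returns them as an empty array.\n",
    "\n",
    "```julia\n",
    "using LinearAlgebra\n",
    "\n",
    "# Small random Hermitian matrix: a hypothetical stand-in for a Hamiltonian block.\n",
    "B = randn(ComplexF64, 8, 8)\n",
    "A = Hermitian(B + B')\n",
    "n = 3  # number of requested eigenpairs, playing the role of size(X0, 2)\n",
    "\n",
    "E = eigen(A)           # real eigenvalues in ascending order for Hermitian input\n",
    "λ = E.values[1:n]      # the n lowest eigenvalues\n",
    "X = E.vectors[:, 1:n]  # the corresponding orthonormal eigenvectors\n",
    "\n",
    "# Residual norms ‖A xᵢ - λᵢ xᵢ‖ for each returned eigenpair\n",
    "residual_norms = [norm(A * X[:, i] - λ[i] * X[:, i]) for i in 1:n]\n",
    "```"
   ],
   "metadata": {}
  },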
  {
   "cell_type": "markdown",
   "source": [
    "Finally, we define our custom mixing scheme. It uses simple mixing for the first\n",
    "2 SCF steps and then switches to Kerker mixing. In the mixing interface, `δF` is\n",
    "$(ρ_\\text{out} - ρ_\\text{in})$, i.e. the difference between the output and input\n",
    "densities of the current SCF step, and the `mix_density` function returns $δρ$,\n",
    "which is added to $ρ_\\text{in}$ to yield $ρ_\\text{next}$, the density for the\n",
    "next SCF step."
   ],
   "metadata": {}
  },
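  {
   "cell_type": "markdown",
   "source": [
    "Written out, one SCF step with this interface reads as follows. The first line restates the interface, and the second is the simple-mixing branch used for the first `n_simple` iterations. The third is the textbook form of Kerker's preconditioner in reciprocal space, where $G$ is a reciprocal-lattice vector and $k_\\text{TF}$ a Thomas-Fermi screening wavevector. DFTK's `KerkerMixing` is based on this idea, but details of its implementation (for example the treatment of $G = 0$) may differ from the formula shown here.\n",
    "\n",
    "$$\n",
    "\\begin{aligned}\n",
    "δF &= ρ_\\text{out} - ρ_\\text{in}, \\qquad ρ_\\text{next} = ρ_\\text{in} + δρ, \\\\\n",
    "δρ &= δF \\qquad \\text{(simple mixing, } n_\\text{iter} \\le n_\\text{simple}\\text{)}, \\\\\n",
    "\\widehat{δρ}(G) &= \\frac{|G|^2}{|G|^2 + k_\\text{TF}^2}\\, \\widehat{δF}(G) \\qquad \\text{(Kerker, textbook form)}.\n",
    "\\end{aligned}\n",
    "$$"
   ],
   "metadata": {}
  },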
  {
   "outputs": [],
   "cell_type": "code",
   "source": [
    "struct MyMixing\n",
    "    n_simple::Int  # Number of initial SCF iterations that use simple mixing\n",
    "end\n",
    "MyMixing() = MyMixing(2)\n",
    "\n",
    "function DFTK.mix_density(mixing::MyMixing, basis, δF; n_iter, kwargs...)\n",
    "    if n_iter <= mixing.n_simple\n",
    "        return δF  # Simple mixing -> Do not modify update at all\n",
    "    else\n",
    "        # Use the default KerkerMixing from DFTK\n",
    "        DFTK.mix_density(KerkerMixing(), basis, δF; kwargs...)\n",
    "    end\n",
    "end"
   ],
   "metadata": {},
   "execution_count": 4
  },
  {
   "cell_type": "markdown",
   "source": [
    "That's it! Now we just run the SCF with these solvers:"
   ],
   "metadata": {}
  },
  {
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "n     Energy            log10(ΔE)   log10(Δρ)   Diag   Δtime\n",
      "---   ---------------   ---------   ---------   ----   ------\n",
      "  1   -7.844704367369                   -0.58    0.0    1.50s\n",
      "  2   -7.856031172216       -1.95       -1.01    0.0    1.33s\n",
      "  3   -7.857206749077       -2.93       -1.43    0.0   55.8ms\n",
      "  4   -7.857296337864       -4.05       -1.75    0.0   61.3ms\n",
      "  5   -7.857319136840       -4.64       -2.06    0.0    102ms\n",
      "  6   -7.857325084669       -5.23       -2.36    0.0   55.8ms\n",
      "  7   -7.857326696758       -5.79       -2.66    0.0   60.3ms\n",
      "  8   -7.857327153270       -6.34       -2.95    0.0    157ms\n",
      "  9   -7.857327288370       -6.87       -3.23    0.0   65.3ms\n",
      " 10   -7.857327330007       -7.38       -3.49    0.0   79.0ms\n",
      " 11   -7.857327343295       -7.88       -3.75    0.0    584ms\n",
      " 12   -7.857327347658       -8.36       -4.00    0.0   56.7ms\n"
     ]
    }
   ],
   "cell_type": "code",
   "source": [
    "scfres = self_consistent_field(basis;\n",
    "                               tol=1e-4,\n",
    "                               solver=my_fp_solver,\n",
    "                               eigensolver=my_eig_solver,\n",
    "                               mixing=MyMixing());"
   ],
   "metadata": {},
   "execution_count": 5
  },
  {
   "cell_type": "markdown",
   "source": [
    "Note that the default convergence criterion is based on the density residual\n",
    "$ρ_\\text{out} - ρ_\\text{in}$, whose size is reported in the `log10(Δρ)` column above:\n",
    "once it drops below `tol`, the fixed-point solver terminates. You can also customize\n",
    "this with the `is_converged` keyword argument to `self_consistent_field`, as shown below."
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Customizing the convergence criterion\n",
    "Here is an example of defining a custom convergence criterion, based on the change\n",
    "in total energy between two consecutive SCF steps, and passing it to\n",
    "`self_consistent_field` through the `is_converged` callback keyword."
   ],
   "metadata": {}
  },
  {
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "n     Energy            log10(ΔE)   log10(Δρ)   Diag   Δtime\n",
      "---   ---------------   ---------   ---------   ----   ------\n",
      "  1   -7.844704367369                   -0.58    0.0    141ms\n",
      "  2   -7.856031172216       -1.95       -1.01    0.0    138ms\n",
      "  3   -7.857206749077       -2.93       -1.43    0.0   57.7ms\n",
      "  4   -7.857296337864       -4.05       -1.75    0.0   56.7ms\n",
      "  5   -7.857319136840       -4.64       -2.06    0.0   56.7ms\n",
      "  6   -7.857325084669       -5.23       -2.36    0.0   56.9ms\n",
      "  7   -7.857326696758       -5.79       -2.66    0.0    145ms\n",
      "  8   -7.857327153270       -6.34       -2.95    0.0   55.7ms\n",
      "  9   -7.857327288370       -6.87       -3.23    0.0   60.1ms\n",
      " 10   -7.857327330007       -7.38       -3.49    0.0   69.6ms\n",
      " 11   -7.857327343295       -7.88       -3.75    0.0   68.7ms\n",
      " 12   -7.857327347658       -8.36       -4.00    0.0   92.1ms\n",
      " 13   -7.857327349123       -8.83       -4.25    0.0    123ms\n",
      " 14   -7.857327349624       -9.30       -4.48    0.0   56.1ms\n",
      " 15   -7.857327349797       -9.76       -4.72    0.0   66.1ms\n",
      " 16   -7.857327349857      -10.22       -4.95    0.0   69.4ms\n"
     ]
    }
   ],
   "cell_type": "code",
   "source": [
    "function my_convergence_criterion(info)\n",
    "    tol = 1e-10\n",
    "    length(info.history_Etot) < 2 && return false  # need two energies to compare\n",
    "    ΔE = abs(info.history_Etot[end-1] - info.history_Etot[end])\n",
    "    ΔE < tol  # converged once the energy change drops below tol\n",
    "end\n",
    "\n",
    "scfres2 = self_consistent_field(basis;\n",
    "                                solver=my_fp_solver,\n",
    "                                is_converged=my_convergence_criterion,\n",
    "                                eigensolver=my_eig_solver,\n",
    "                                mixing=MyMixing());"
   ],
   "metadata": {},
   "execution_count": 6
  }
 ],
 "nbformat_minor": 3,
 "metadata": {
  "language_info": {
   "file_extension": ".jl",
   "mimetype": "application/julia",
   "name": "julia",
   "version": "1.11.4"
  },
  "kernelspec": {
   "name": "julia-1.11",
   "display_name": "Julia 1.11.4",
   "language": "julia"
  }
 },
 "nbformat": 4
}