{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "# Polarizability using automatic differentiation\n",
    "\n",
    "A simple example of computing properties using (forward-mode)\n",
    "automatic differentiation.\n",
    "For a more classical approach and more details about computing polarizabilities,\n",
    "see the *Polarizability by linear response* example."
   ],
   "metadata": {}
  },
  {
   "outputs": [],
   "cell_type": "code",
   "source": [
    "using DFTK\n",
    "using LinearAlgebra\n",
    "using ForwardDiff\n",
    "using PseudoPotentialData\n",
    "\n",
    "# Construct a PlaneWaveBasis for a given electric field strength ε.\n",
    "# Again we take the example of a helium atom.\n",
    "function make_basis(ε::T; a=10., Ecut=30) where {T}\n",
    "    lattice = T(a) * I(3)  # cubic cell with side length a (in Bohr)\n",
    "    # Helium at the center of the box\n",
    "    pseudopotentials = PseudoFamily(\"cp2k.nc.sr.lda.v0_1.semicore.gth\")\n",
    "    atoms     = [ElementPsp(:He, pseudopotentials)]\n",
    "    positions = [[1/2, 1/2, 1/2]]\n",
    "\n",
    "    model = model_DFT(lattice, atoms, positions;\n",
    "                      functionals=[:lda_x, :lda_c_vwn],\n",
    "                      extra_terms=[ExternalFromReal(r -> -ε * (r[1] - a/2))],\n",
    "                      symmetries=false)\n",
    "    PlaneWaveBasis(model; Ecut, kgrid=[1, 1, 1])  # No k-point sampling on isolated system\n",
    "end\n",
    "\n",
    "# x-component of the dipole moment of the density ρ, relative to the box center\n",
    "# (assumes the diagonal lattice set up by make_basis)\n",
    "function dipole(basis, ρ)\n",
    "    @assert isdiag(basis.model.lattice)\n",
    "    a  = basis.model.lattice[1, 1]\n",
    "    rr = [a * (r[1] - 1/2) for r in r_vectors(basis)]\n",
    "    sum(rr .* ρ) * basis.dvol\n",
    "end\n",
    "\n",
    "# Function to compute the dipole for a given field strength\n",
    "function compute_dipole(ε; tol=1e-8, kwargs...)\n",
    "    scfres = self_consistent_field(make_basis(ε; kwargs...); tol)\n",
    "    dipole(scfres.basis, scfres.ρ)\n",
    "end;"
   ],
   "metadata": {},
   "execution_count": 1
  },
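  {
   "cell_type": "markdown",
   "source": [
    "Here `make_basis` adds the linear external potential $V_\\varepsilon(r) = -\\varepsilon\\,(x - a/2)$,\n",
    "which models a homogeneous electric field along the $x$ axis, and `dipole` evaluates the\n",
    "$x$-component of the dipole moment of the density relative to the box center.\n",
    "The polarizability is the linear response of this dipole to the field strength:\n",
    "\n",
    "$$\n",
    "\\alpha = \\left.\\frac{\\mathrm{d}\\mu}{\\mathrm{d}\\varepsilon}\\right|_{\\varepsilon=0},\n",
    "\\qquad\n",
    "\\mu(\\varepsilon) = \\int_\\Omega \\left(x - \\frac{a}{2}\\right) \\rho_\\varepsilon(r)\\,\\mathrm{d}r,\n",
    "$$\n",
    "\n",
    "where $\\Omega$ is the simulation cell and $\\rho_\\varepsilon$ the self-consistent density at field strength $\\varepsilon$.\n",
    "A forward finite difference with step $\\varepsilon$ approximates this as\n",
    "$\\alpha \\approx \\bigl(\\mu(\\varepsilon) - \\mu(0)\\bigr)/\\varepsilon$."
   ],
   "metadata": {}
  },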
  {
   "cell_type": "markdown",
   "source": [
    "With this in place we can compute the polarizability from a forward finite difference\n",
    "with field step ε = 0.01 (just like in the previous example):"
   ],
   "metadata": {}
  },
  {
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "n     Energy            log10(ΔE)   log10(Δρ)   Diag   Δtime\n",
      "---   ---------------   ---------   ---------   ----   ------\n",
      "  1   -2.770859815902                   -0.53    9.0    177ms\n",
      "  2   -2.772140634348       -2.89       -1.31    1.0    114ms\n",
      "  3   -2.772169899236       -4.53       -2.57    1.0    110ms\n",
      "  4   -2.772170706547       -6.09       -3.53    2.0    127ms\n",
      "  5   -2.772170722269       -7.80       -4.12    1.0    114ms\n",
      "  6   -2.772170722951       -9.17       -4.93    1.0    161ms\n",
      "  7   -2.772170723014      -10.20       -5.41    2.0    496ms\n",
      "  8   -2.772170723015      -12.04       -6.25    1.0    114ms\n",
      "  9   -2.772170723015      -14.31       -6.79    1.0    139ms\n",
      " 10   -2.772170723015      -14.12       -7.47    1.0    115ms\n",
      " 11   -2.772170723015      -14.07       -8.22    2.0    136ms\n",
      "n     Energy            log10(ΔE)   log10(Δρ)   Diag   Δtime\n",
      "---   ---------------   ---------   ---------   ----   ------\n",
      "  1   -2.770755627493                   -0.53    8.0    155ms\n",
      "  2   -2.772053846938       -2.89       -1.31    1.0    109ms\n",
      "  3   -2.772082926712       -4.54       -2.55    1.0    108ms\n",
      "  4   -2.772083368510       -6.35       -3.45    1.0    111ms\n",
      "  5   -2.772083416412       -7.32       -3.96    2.0    127ms\n",
      "  6   -2.772083417755       -8.87       -5.19    1.0    117ms\n",
      "  7   -2.772083417807      -10.28       -5.28    2.0    131ms\n",
      "  8   -2.772083417811      -11.50       -6.30    1.0    121ms\n",
      "  9   -2.772083417811      -13.23       -6.56    2.0    539ms\n",
      " 10   -2.772083417811      -14.18       -7.40    1.0    114ms\n",
      " 11   -2.772083417811   +  -14.10       -8.00    1.0    116ms\n",
      " 12   -2.772083417811      -13.82       -8.66    2.0    154ms\n"
     ]
    },
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": "1.7735579730013822"
     },
     "metadata": {},
     "execution_count": 2
    }
   ],
   "cell_type": "code",
   "source": [
    "polarizability_fd = let\n",
    "    ε = 0.01\n",
    "    (compute_dipole(ε) - compute_dipole(0.0)) / ε\n",
    "end"
   ],
   "metadata": {},
   "execution_count": 2
  },
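  {
   "cell_type": "markdown",
   "source": [
    "In general, a forward difference has a truncation error of order $h$ in the step size $h$,\n",
    "while the central difference $\\bigl(\\mu(h) - \\mu(-h)\\bigr)/(2h)$ cancels the even-order terms\n",
    "and is accurate to order $h^2$. The sketch below illustrates this on a hypothetical cubic toy\n",
    "curve `μ_toy` with made-up coefficients (not helium data, no DFTK calculation involved),\n",
    "using `ForwardDiff.derivative` for the exact reference slope. With these coefficients the\n",
    "forward-difference error is about $5 \\times 10^{-3}$ and the central-difference error about $10^{-5}$."
   ],
   "metadata": {}
  },
  {
   "outputs": [],
   "cell_type": "code",
   "source": [
    "using ForwardDiff\n",
    "\n",
    "# Hypothetical toy dipole curve μ(h) = α h + β h² + γ h³ (made-up coefficients)\n",
    "μ_toy(h) = 1.77 * h + 0.5 * h^2 + 0.1 * h^3\n",
    "\n",
    "h = 0.01  # finite-difference step (plays the role of ε above)\n",
    "fd_forward = (μ_toy(h) - μ_toy(0.0)) / h        # truncation error O(h)\n",
    "fd_central = (μ_toy(h) - μ_toy(-h)) / (2h)      # truncation error O(h²)\n",
    "exact      = ForwardDiff.derivative(μ_toy, 0.0)  # dual numbers: no truncation error\n",
    "\n",
    "println(\"forward-difference error: \", abs(fd_forward - exact))\n",
    "println(\"central-difference error: \", abs(fd_central - exact))"
   ],
   "metadata": {},
   "execution_count": null
  },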
  {
   "cell_type": "markdown",
   "source": [
    "We now compute the same derivative using automatic differentiation. Under the hood DFTK uses\n",
    "custom differentiation rules to implicitly differentiate through the self-consistent\n",
    "field fixed-point problem. This leads to a density-functional perturbation\n",
    "theory problem, which is automatically set up and solved in the background."
   ],
   "metadata": {}
  },
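  {
   "cell_type": "markdown",
   "source": [
    "To see why differentiating a converged fixed point leads to a linear (response) problem,\n",
    "consider a scalar toy analogue: a map $x = g(x, \\varepsilon)$ standing in for the SCF iteration.\n",
    "At the fixed point the implicit function theorem gives\n",
    "\n",
    "$$\n",
    "\\frac{\\mathrm{d}x}{\\mathrm{d}\\varepsilon} = \\left(1 - \\frac{\\partial g}{\\partial x}\\right)^{-1} \\frac{\\partial g}{\\partial \\varepsilon},\n",
    "$$\n",
    "\n",
    "so the iterations themselves never need to be differentiated. The sketch below uses a\n",
    "hypothetical map `g` and helper `solve_fixed_point` (not DFTK code) and checks this formula\n",
    "against naively pushing `ForwardDiff` dual numbers through every iteration. In DFTK the analogous\n",
    "linear problem is posed for the density response and, judging from the log below, solved\n",
    "iteratively with GMRES rather than by a scalar division."
   ],
   "metadata": {}
  },
  {
   "outputs": [],
   "cell_type": "code",
   "source": [
    "using ForwardDiff\n",
    "\n",
    "# Hypothetical toy fixed-point map standing in for the SCF map (not DFTK code)\n",
    "g(x, ε) = 0.5 * cos(x) + ε\n",
    "\n",
    "# Plain fixed-point iteration; converges since |∂g/∂x| ≤ 0.5\n",
    "function solve_fixed_point(ε; tol=1e-14, maxiter=200)\n",
    "    x = zero(ε)\n",
    "    for _ in 1:maxiter\n",
    "        xnew = g(x, ε)\n",
    "        abs(xnew - x) < tol && return xnew\n",
    "        x = xnew\n",
    "    end\n",
    "    x\n",
    "end\n",
    "\n",
    "ε0 = 0.1\n",
    "x0 = solve_fixed_point(ε0)\n",
    "\n",
    "# Implicit differentiation: only partial derivatives at the fixed point are needed\n",
    "∂g_∂x = ForwardDiff.derivative(x -> g(x, ε0), x0)\n",
    "∂g_∂ε = ForwardDiff.derivative(ε -> g(x0, ε), ε0)\n",
    "dx_implicit = ∂g_∂ε / (1 - ∂g_∂x)\n",
    "\n",
    "# Reference: push dual numbers through every iteration of the solver\n",
    "dx_unrolled = ForwardDiff.derivative(solve_fixed_point, ε0)\n",
    "\n",
    "println(\"implicit: $dx_implicit, unrolled: $dx_unrolled\")"
   ],
   "metadata": {},
   "execution_count": null
  },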
  {
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "n     Energy            log10(ΔE)   log10(Δρ)   Diag   Δtime\n",
      "---   ---------------   ---------   ---------   ----   ------\n",
      "  1   -2.770657324401                   -0.53    9.0    180ms\n",
      "  2   -2.772053993343       -2.85       -1.30    1.0    106ms\n",
      "  3   -2.772082856401       -4.54       -2.65    1.0    110ms\n",
      "  4   -2.772083415438       -6.25       -4.03    2.0    127ms\n",
      "  5   -2.772083417616       -8.66       -4.45    2.0    128ms\n",
      "  6   -2.772083417803       -9.73       -5.42    1.0    600ms\n",
      "  7   -2.772083417810      -11.12       -5.99    2.0    148ms\n",
      "  8   -2.772083417811      -13.01       -6.50    1.0    152ms\n",
      "  9   -2.772083417811      -13.64       -7.01    2.0    145ms\n",
      " 10   -2.772083417811   +  -15.35       -7.92    1.0    116ms\n",
      " 11   -2.772083417811   +  -14.35       -8.56    2.0    130ms\n",
      "Solving response problem\n",
      "[ Info: GMRES linsolve starts with norm of residual = 4.19e+00\n",
      "[ Info: GMRES linsolve in iteration 1; step 1: normres = 2.49e-01\n",
      "[ Info: GMRES linsolve in iteration 1; step 2: normres = 3.76e-03\n",
      "[ Info: GMRES linsolve in iteration 1; step 3: normres = 2.84e-04\n",
      "[ Info: GMRES linsolve in iteration 1; step 4: normres = 4.67e-06\n",
      "[ Info: GMRES linsolve in iteration 1; step 5: normres = 1.08e-08\n",
      "┌ Info: GMRES linsolve converged at iteration 1, step 6:\n",
      "│ * norm of residual = 7.68e-10\n",
      "└ * number of operations = 8\n",
      "\n",
      "Polarizability via ForwardDiff:       1.7725349649275122\n",
      "Polarizability via finite difference: 1.7735579730013822\n"
     ]
    }
   ],
   "cell_type": "code",
   "source": [
    "polarizability = ForwardDiff.derivative(compute_dipole, 0.0)\n",
    "println()\n",
    "println(\"Polarizability via ForwardDiff:       $polarizability\")\n",
    "println(\"Polarizability via finite difference: $polarizability_fd\")"
   ],
   "metadata": {},
   "execution_count": 3
  }
 ],
 "nbformat_minor": 3,
 "metadata": {
  "language_info": {
   "file_extension": ".jl",
   "mimetype": "application/julia",
   "name": "julia",
   "version": "1.11.4"
  },
  "kernelspec": {
   "name": "julia-1.11",
   "display_name": "Julia 1.11.4",
   "language": "julia"
  }
 },
 "nbformat": 4
}