{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "# Saving SCF results on disk and SCF checkpoints\n",
    "\n",
    "For longer DFT calculations it is pretty standard to run them on a cluster\n",
    "in advance and to perform postprocessing (band structure calculation,\n",
    "plotting of density, etc.) at a later point and potentially on a different\n",
    "machine.\n",
    "\n",
    "To support such workflows DFTK offers the two functions `save_scfres`\n",
    "and `load_scfres`, which allow to save the data structure returned\n",
    "by `self_consistent_field` on disk or retrieve it back into memory,\n",
    "respectively. For this purpose DFTK uses the\n",
    "[JLD2.jl](https://github.com/JuliaIO/JLD2.jl) file format and Julia package.\n",
    "\n",
    "> **Availability of `load_scfres`, `save_scfres` and checkpointing**\n",
    ">\n",
 As JLD2 is an optional">
    "> As JLD2 is an optional dependency of DFTK, these features are only\n",
    "> available once one has *both* imported DFTK and JLD2 (`using DFTK`\n",
    "> and `using JLD2`).\n",
    "\n",
    "To illustrate the use of the functions in practice we will compute\n",
    "the total energy of the O₂ molecule at PBE level. To get the triplet\n",
    "ground state we use a collinear spin polarisation\n",
    "(see Collinear spin and magnetic systems for details)\n",
    "and a bit of temperature to ease convergence:"
   ],
   "metadata": {}
  },
  {
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "n     Energy            log10(ΔE)   log10(Δρ)   Magnet   Diag   Δtime\n",
      "---   ---------------   ---------   ---------   ------   ----   ------\n",
      "  1   -27.64800912096                   -0.13    0.001    7.5    426ms\n",
      "  2   -28.92456851650        0.11       -0.83    0.688    2.0    629ms\n",
      "  3   -28.93111803647       -2.18       -1.13    1.192    2.0    110ms\n",
      "  4   -28.93774962933       -2.18       -1.19    1.776    1.5   84.2ms\n",
      "  5   -28.93924836688       -2.82       -1.06    2.000    2.0   91.0ms\n",
      "  6   -28.93957812528       -3.48       -1.87    1.975    1.5   81.3ms\n",
      "  7   -28.93961156694       -4.48       -2.86    1.984    1.0   81.3ms\n"
     ]
    }
   ],
   "cell_type": "code",
   "source": [
    "using DFTK\n",
    "using LinearAlgebra\n",
    "using JLD2\n",
    "using PseudoPotentialData\n",
    "\n",
    "d = 2.079  # oxygen-oxygen bondlength\n",
    "a = 9.0    # size of the simulation box\n",
    "lattice = a * I(3)\n",
    "pseudopotentials = PseudoFamily(\"cp2k.nc.sr.pbe.v0_1.semicore.gth\")\n",
    "O = ElementPsp(:O, pseudopotentials)\n",
    "atoms     = [O, O]\n",
    "positions = d / 2a * [[0, 0, 1], [0, 0, -1]]\n",
    "magnetic_moments = [1., 1.]\n",
    "\n",
    "Ecut  = 10  # Far too small to be converged\n",
    "model = model_DFT(lattice, atoms, positions;\n",
    "                  functionals=PBE(),\n",
    "                  temperature=0.02, smearing=Smearing.Gaussian(),\n",
    "                  magnetic_moments)\n",
    "basis = PlaneWaveBasis(model; Ecut, kgrid=[1, 1, 1])\n",
    "\n",
    "scfres = self_consistent_field(basis, tol=1e-2, ρ=guess_density(basis, magnetic_moments))\n",
    "save_scfres(\"scfres.jld2\", scfres);"
   ],
   "metadata": {},
   "execution_count": 1
  },
  {
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": "Energy breakdown (in Ha):\n    Kinetic             16.7728512\n    AtomicLocal         -58.4960399\n    AtomicNonlocal      4.7090966 \n    Ewald               -4.8994689\n    PspCorrection       0.0044178 \n    Hartree             19.3618046\n    Xc                  -6.3913832\n    Entropy             -0.0008897\n\n    total               -28.939611566936"
     },
     "metadata": {},
     "execution_count": 2
    }
   ],
   "cell_type": "code",
   "source": [
    "scfres.energies"
   ],
   "metadata": {},
   "execution_count": 2
  },
  {
   "cell_type": "markdown",
   "source": [
    "The `scfres.jld2` file could now be transferred to a different computer,\n",
    "Where one could fire up a REPL to inspect the results of the above\n",
    "calculation:"
   ],
   "metadata": {}
  },
  {
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": "(:α, :history_Δρ, :converged, :occupation, :occupation_threshold, :algorithm, :basis, :runtime_ns, :n_iter, :n_matvec, :history_Etot, :εF, :energies, :ρ, :timedout, :n_bands_converge, :τ, :eigenvalues, :ψ, :ham)"
     },
     "metadata": {},
     "execution_count": 3
    }
   ],
   "cell_type": "code",
   "source": [
    "using DFTK\n",
    "using JLD2\n",
    "loaded = load_scfres(\"scfres.jld2\")\n",
    "propertynames(loaded)"
   ],
   "metadata": {},
   "execution_count": 3
  },
  {
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": "Energy breakdown (in Ha):\n    Kinetic             16.7728512\n    AtomicLocal         -58.4960399\n    AtomicNonlocal      4.7090966 \n    Ewald               -4.8994689\n    PspCorrection       0.0044178 \n    Hartree             19.3618046\n    Xc                  -6.3913832\n    Entropy             -0.0008897\n\n    total               -28.939611566936"
     },
     "metadata": {},
     "execution_count": 4
    }
   ],
   "cell_type": "code",
   "source": [
    "loaded.energies"
   ],
   "metadata": {},
   "execution_count": 4
  },
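  {
   "cell_type": "markdown",
   "source": [
    "As a quick sanity check of the round trip one can compare the total energy\n",
    "of the reloaded result with the one still in memory. A minimal sketch, assuming\n",
    "`scfres` and `loaded` from the cells above are still defined; the helper\n",
    "`same_total_energy` is a made-up name for this illustration and not part of DFTK:\n",
    "\n",
    "```julia\n",
    "# Hypothetical helper (not part of DFTK): compare the total energies of two\n",
    "# SCF results, e.g. the in-memory `scfres` and the reloaded `loaded`.\n",
    "same_total_energy(a, b; atol=1e-10) =\n",
    "    isapprox(a.energies.total, b.energies.total; atol)\n",
    "\n",
    "# Usage with the objects from the cells above:\n",
    "# same_total_energy(scfres, loaded)\n",
    "```"
   ],
   "metadata": {}
  },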
  {
   "cell_type": "markdown",
   "source": [
    "Since the loaded data contains exactly the same data as the `scfres` returned by the\n",
    "SCF calculation one could use it to plot a band structure, e.g.\n",
    "`plot_bandstructure(load_scfres(\"scfres.jld2\"))` directly from the stored data.\n",
    "\n",
    "Notice that both `load_scfres` and `save_scfres` work by transferring all data\n",
    "to/from the master process, which performs the IO operations without parallelisation.\n",
    "Since this can become slow, both functions support optional arguments to speed up\n",
    "the processing. An overview:\n",
    "- `save_scfres(\"scfres.jld2\", scfres; save_ψ=false)` avoids saving\n",
    "  the Bloch wave, which is usually faster and saves storage space.\n",
    "- `load_scfres(\"scfres.jld2\", basis)` avoids reconstructing the basis from the file,\n",
    "  but uses the passed basis instead. This save the time of constructing the basis\n",
    "  twice and allows to specify parallelisation options (via the passed basis). Usually\n",
    "  this is useful for continuing a calculation on a supercomputer or cluster.\n",
    "\n",
    "See also the discussion on Input and output formats on JLD2 files."
   ],
   "metadata": {}
  },
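  {
   "cell_type": "markdown",
   "source": [
    "The two options listed above can be combined as in the following sketch.\n",
    "It assumes an `scfres` and a matching `basis` such as the ones computed earlier;\n",
    "the function name `save_and_reload_light` and the file name `scfres_light.jld2`\n",
    "are illustrative choices and not part of DFTK:\n",
    "\n",
    "```julia\n",
    "using DFTK\n",
    "using JLD2\n",
    "\n",
    "# Hypothetical helper: store an SCF result without the Bloch waves and read\n",
    "# it back, reusing an existing basis instead of rebuilding it from the file.\n",
    "function save_and_reload_light(scfres, basis; filename=\"scfres_light.jld2\")\n",
    "    save_scfres(filename, scfres; save_ψ=false)  # skip the Bloch waves\n",
    "    load_scfres(filename, basis)                 # reuse the passed basis\n",
    "end\n",
    "\n",
    "# Usage with the objects from the cells above:\n",
    "# light = save_and_reload_light(scfres, basis)\n",
    "```\n",
    "\n",
    "The file written by this sketch is not removed automatically, so it should be\n",
    "cleaned up by hand once it is no longer needed."
   ],
   "metadata": {}
  },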
  {
   "cell_type": "markdown",
   "source": [
    "## Checkpointing of SCF calculations\n",
    "A related feature, which is very useful especially for longer calculations with DFTK\n",
    "is automatic checkpointing, where the state of the SCF is periodically written to disk.\n",
    "The advantage is that in case the calculation errors or gets aborted due\n",
    "to overrunning the walltime limit one does not need to start from scratch,\n",
    "but can continue the calculation from the last checkpoint.\n",
    "\n",
    "The easiest way to enable checkpointing is to use the `kwargs_scf_checkpoints`\n",
    "function, which does two things. (1) It sets up checkpointing using the\n",
    "`ScfSaveCheckpoints` callback and (2) if a checkpoint file is detected,\n",
    "the stored density is used to continue the calculation instead of the usual\n",
    "atomic-orbital based guess. In practice this is done by modifying the keyword arguments\n",
    "passed to # `self_consistent_field` appropriately, e.g. by using the density\n",
    "or orbitals from the checkpoint file. For example:"
   ],
   "metadata": {}
  },
  {
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "n     Energy            log10(ΔE)   log10(Δρ)   Magnet   α      Diag   Δtime\n",
      "---   ---------------   ---------   ---------   ------   ----   ----   ------\n",
      "  1   -27.64750293488                   -0.13    0.001   0.80    7.0    246ms\n",
      "  2   -28.92453454523        0.11       -0.83    0.685   0.80    2.0    879ms\n",
      "  3   -28.93108291898       -2.18       -1.13    1.188   0.80    2.0    113ms\n",
      "  4   -28.93772996838       -2.18       -1.18    1.775   0.80    1.5   89.3ms\n",
      "  5   -28.93929943042       -2.80       -1.09    2.000   0.80    1.5   90.5ms\n",
      "  6   -28.93958281278       -3.55       -1.89    1.976   0.80    1.5    122ms\n",
      "  7   -28.93961188486       -4.54       -3.00    1.985   0.80    1.0   89.6ms\n"
     ]
    }
   ],
   "cell_type": "code",
   "source": [
    "checkpointargs = kwargs_scf_checkpoints(basis; ρ=guess_density(basis, magnetic_moments))\n",
    "scfres = self_consistent_field(basis; tol=1e-2, checkpointargs...);"
   ],
   "metadata": {},
   "execution_count": 5
  },
  {
   "cell_type": "markdown",
   "source": [
    "Notice that the `ρ` argument is now passed to kwargs_scf_checkpoints instead.\n",
    "If we run in the same folder the SCF again (here using a tighter tolerance),\n",
    "the calculation just continues."
   ],
   "metadata": {}
  },
  {
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "n     Energy            log10(ΔE)   log10(Δρ)   Magnet   α      Diag   Δtime\n",
      "---   ---------------   ---------   ---------   ------   ----   ----   ------\n",
      "  1   -28.93961282874                   -3.34    1.985   0.80   10.5    190ms\n"
     ]
    }
   ],
   "cell_type": "code",
   "source": [
    "checkpointargs = kwargs_scf_checkpoints(basis; ρ=guess_density(basis, magnetic_moments))\n",
    "scfres = self_consistent_field(basis; tol=1e-3, checkpointargs...);"
   ],
   "metadata": {},
   "execution_count": 6
  },
  {
   "cell_type": "markdown",
   "source": [
    "Since only the density is stored in a checkpoint\n",
    "(and not the Bloch waves), the first step needs a slightly elevated number\n",
    "of diagonalizations. Notice, that reconstructing the `checkpointargs` in this second\n",
    "call is important as the `checkpointargs` now contain different data,\n",
    "such that the SCF continues from the checkpoint.\n",
    "By default checkpoint is saved in the file `dftk_scf_checkpoint.jld2`, which can be changed\n",
    "using the `filename` keyword argument of `kwargs_scf_checkpoints`. Note that the\n",
    "file is not deleted by DFTK, so it is your responsibility to clean it up. Further note\n",
    "that warnings or errors will arise if you try to use a checkpoint, which is incompatible\n",
    "with your calculation.\n",
    "\n",
    "We can also inspect the checkpoint file manually using the `load_scfres` function\n",
    "and use it manually to continue the calculation:"
   ],
   "metadata": {}
  },
  {
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "n     Energy            log10(ΔE)   log10(Δρ)   Magnet   Diag   Δtime\n",
      "---   ---------------   ---------   ---------   ------   ----   ------\n",
      "  1   -28.93866808195                   -2.32    1.985    5.5    123ms\n",
      "  2   -28.93945065159       -3.11       -2.78    1.985    1.0   87.0ms\n",
      "  3   -28.93961082794       -3.80       -2.60    1.985    4.5    167ms\n",
      "  4   -28.93961209896       -5.90       -2.74    1.985    1.0   76.8ms\n",
      "  5   -28.93961268109       -6.23       -2.89    1.985    1.0   72.3ms\n",
      "  6   -28.93961285717       -6.75       -3.00    1.985    1.0   75.3ms\n",
      "  7   -28.93961307921       -6.65       -3.28    1.985    1.5   79.0ms\n",
      "  8   -28.93961312238       -7.36       -3.42    1.985    1.0   73.4ms\n",
      "  9   -28.93961316530       -7.37       -3.87    1.985    2.0   84.3ms\n",
      " 10   -28.93961317227       -8.16       -5.14    1.985    2.0   84.4ms\n"
     ]
    }
   ],
   "cell_type": "code",
   "source": [
    "oldstate = load_scfres(\"dftk_scf_checkpoint.jld2\")\n",
    "scfres   = self_consistent_field(oldstate.basis, ρ=oldstate.ρ, ψ=oldstate.ψ, tol=1e-4);"
   ],
   "metadata": {},
   "execution_count": 7
  },
  {
   "cell_type": "markdown",
   "source": [
    "Some details on what happens under the hood in this mechanism: When using the\n",
    "`kwargs_scf_checkpoints` function, the `ScfSaveCheckpoints` callback is employed\n",
    "during the SCF, which causes the density to be stored to the JLD2 file in every iteration.\n",
    "When reading the file, the `kwargs_scf_checkpoints` transparently patches away the `ψ`\n",
    "and `ρ` keyword arguments and replaces them by the data obtained from the file.\n",
    "For more details on using callbacks with DFTK's `self_consistent_field` function\n",
    "see Monitoring self-consistent field calculations."
   ],
   "metadata": {}
  },
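  {
   "cell_type": "markdown",
   "source": [
    "The restart decision described above can be pictured with the following\n",
    "conceptual sketch. It is not DFTK's implementation and only covers the density:\n",
    "the helper `initial_guess_kwargs` is hypothetical, and the loader is passed in\n",
    "as a function so that the sketch itself does not depend on any file format.\n",
    "With DFTK one could, for instance, pass `file -> load_scfres(file, basis)`.\n",
    "\n",
    "```julia\n",
    "# Conceptual sketch only (hypothetical helper, not DFTK's implementation):\n",
    "# take the SCF starting density from a checkpoint if one exists, otherwise\n",
    "# keep the guess supplied by the caller.\n",
    "function initial_guess_kwargs(load, filename; ρ)\n",
    "    if isfile(filename)\n",
    "        state = load(filename)  # e.g. file -> load_scfres(file, basis)\n",
    "        (; ρ=state.ρ)           # continue from the stored density\n",
    "    else\n",
    "        (; ρ)                   # no checkpoint: fall back to the passed guess\n",
    "    end\n",
    "end\n",
    "```\n",
    "\n",
    "The returned named tuple can be splatted into a call in the same way as\n",
    "`checkpointargs...` in the examples above."
   ],
   "metadata": {}
  },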
  {
   "cell_type": "markdown",
   "source": [
    "(Cleanup files generated by this notebook)"
   ],
   "metadata": {}
  },
  {
   "outputs": [],
   "cell_type": "code",
   "source": [
    "rm(\"dftk_scf_checkpoint.jld2\")\n",
    "rm(\"scfres.jld2\")"
   ],
   "metadata": {},
   "execution_count": 8
  }
 ],
 "nbformat_minor": 3,
 "metadata": {
  "language_info": {
   "file_extension": ".jl",
   "mimetype": "application/julia",
   "name": "julia",
   "version": "1.11.4"
  },
  "kernelspec": {
   "name": "julia-1.11",
   "display_name": "Julia 1.11.4",
   "language": "julia"
  }
 },
 "nbformat": 4
}