{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import jax\n", "import jax.numpy as jnp\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "from openff.toolkit.topology import Molecule, Topology\n", "from openff.toolkit.typing.engines.smirnoff import ForceField\n", "\n", "from openff.interchange import Interchange" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Construct a single-molecule system from toolkit classes\n", "mol = Molecule.from_smiles(\"CCO\")\n", "mol.generate_conformers(n_conformers=1)\n", "top = Topology.from_molecules([mol])\n", "parsley = ForceField(\"openff-1.0.0.offxml\")\n", "\n", "off_sys = Interchange.from_smirnoff(force_field=parsley, topology=top)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "bonds = off_sys.handlers[\"Bonds\"]\n", "\n", "# Transform parameters into matrix representations\n", "p = bonds.get_force_field_parameters()\n", "mapping = bonds.get_mapping()\n", "q = bonds.get_system_parameters()\n", "m = bonds.get_param_matrix()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# force field parameters, each row is something like [k (kcal/mol/A), length (A)]\n", "p" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# system parameters, a.k.a. force field parameters as they exist in a parametrized system\n", "q" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# m is the parametrization matrix, which can be dotted with p to get out q\n", "assert np.allclose(m.dot(p.flatten()).reshape((-1, 2)), q)\n", "\n", "m" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# save and set initial values\n", "q0 = q\n", "p0 = p\n", "\n", "# set learning rate\n", "a = 0.1" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from copy import deepcopy" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# let jax run with autodiff\n", "_, f_vjp_bonds = jax.vjp(bonds.parametrize, jnp.asarray(p0)) # d/dp\n", "\n", "# jax.jvp( ..., has_aux=True) is another approach, but requires that bonds.parametrize returns the indices as well" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "q_target = deepcopy(q0)\n", "p_target = deepcopy(p0)\n", "\n", "# modify a few of the force field targets to arbitrary values;\n", "# this mimic some \"true\" values we wish to tune to, despite\n", "# these values not being known in real-world fitting\n", "p_target[:, 1] = 0.5 + np.random.rand(4)\n", "\n", "# obtain the target _sytem_ parameters by dotting the parametrization\n", "# matrix with target force field values\n", "q_target = m.dot(p_target.flatten()).reshape((-1, 2))\n", "\n", "\n", "# create a dummy loss function via faking known target parameters;\n", "# in practice this could be the result of an MD run, FE calculation, etc.\n", "def loss(p):\n", " return jnp.linalg.norm(bonds.parametrize(p) - q_target)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "out, f_vjp_bonds = jax.vjp(loss, p0) # composes a jax.grad" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# This also returns loss(p0), which we do not 
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "# jax.vjp also returns loss(p0) itself, which we do not need to store\n",
  "out == loss(p0)"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "# evaluate the VJP at 1.0 to recover dL/dp\n",
  "f_vjp_bonds(1.0)"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "# this does the same as the jax.vjp above\n",
  "grad_loss = jax.grad(loss)\n",
  "\n",
  "grad_loss(p0)  # dL/dp"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "f_vjp_bonds(1.0)[0] == grad_loss(p0)  # dL/dp via either route"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "# derivative of the loss function evaluated at the original force field\n",
  "# parameters; note that column 0 (the force constants) already matches the\n",
  "# target values, so the derivative there is zero\n",
  "f_vjp_bonds(loss(p0))  # loss(p0) * dL/dp (!) can be used as a gradient in fitting"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "fig, ax = plt.subplots()\n",
  "\n",
  "# label target values\n",
  "ax.hlines(p_target[0, 1], 0, 100, color=\"k\", ls=\"--\", label=\"[#6X4:1]-[#6X4:2]\")\n",
  "ax.hlines(p_target[1, 1], 0, 100, color=\"r\", ls=\"--\", label=\"[#6X4:1]-[#1:2]\")\n",
  "ax.hlines(p_target[2, 1], 0, 100, color=\"g\", ls=\"--\", label=\"[#6:1]-[#8:2]\")\n",
  "ax.hlines(p_target[3, 1], 0, 100, color=\"b\", ls=\"--\", label=\"[#8:1]-[#1:2]\")\n",
  "\n",
  "for i in range(100):\n",
  "    if i % 10 == 0:\n",
  "        print(f\"step {i}\\tloss: {loss(p)}\")\n",
  "    ax.plot(i, p[0][1], \"k.\")\n",
  "    ax.plot(i, p[1][1], \"r.\")\n",
  "    ax.plot(i, p[2][1], \"g.\")\n",
  "    ax.plot(i, p[3][1], \"b.\")\n",
  "\n",
  "    # use JAX to get the gradient\n",
  "    _, f_vjp_bonds = jax.vjp(loss, p)\n",
  "    grad = f_vjp_bonds(1.0)[0]\n",
  "    # update the force field parameters\n",
  "    p -= a * grad\n",
  "    # use the parametrization matrix to propagate the new force field\n",
  "    # parameters into new system parameters\n",
  "    q = m.dot(p.flatten()).reshape((-1, 2))\n",
  "\n",
  "\n",
  "ax.legend(loc=0)\n",
  "ax.set_xlabel(\"iteration\")\n",
  "ax.set_ylabel(\"parameter value (bond length-ish)\")\n",
  "ax.set_xlim((0, 100))\n",
  "ax.set_ylim((0, 1.5))"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "# We can do it all over again with angles, almost identically\n",
  "angles = off_sys.handlers[\"Angles\"]\n",
  "q0 = angles.get_system_parameters()\n",
  "p0 = angles.get_force_field_parameters()\n",
  "mapping = angles.get_mapping()\n",
  "m = angles.get_param_matrix()\n",
  "q = q0.copy()\n",
  "p = p0.copy()\n",
  "a = 0.1\n",
  "\n",
  "p_target = deepcopy(p0)\n",
  "p_target[:, 1] = np.random.randint(100, 120, 3)\n",
  "\n",
  "# parametrize is equivalent to dotting with the parametrization matrix\n",
  "q_target = angles.parametrize(p_target)\n",
  "\n",
  "\n",
  "def loss(p):\n",
  "    return jnp.linalg.norm(angles.parametrize(p) - q_target)"
 ] },
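 { "cell_type": "markdown", "metadata": {}, "source": [
  "As a quick sanity check (a sketch, not part of the original workflow), we can take a single gradient-descent step on the angles loss with `jax.value_and_grad` and confirm the loss decreases, without touching `p` itself. For a small enough learning rate `a` the loss should go down:"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "# sketch: one explicit gradient-descent step on the angles loss\n",
  "value, grad = jax.value_and_grad(loss)(jnp.asarray(p))\n",
  "p_trial = p - a * grad  # same update rule as the loop below\n",
  "print(f\"loss before: {value}\\tloss after one step: {loss(p_trial)}\")"
 ] },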
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "fig, ax = plt.subplots()\n",
  "\n",
  "# label target values\n",
  "ax.hlines(p_target[0, 1], 0, 100, color=\"k\", ls=\"--\", label=\"[*:1]~[#6X4:2]-[*:3]\")\n",
  "ax.hlines(p_target[1, 1], 0, 100, color=\"r\", ls=\"--\", label=\"[*:1]-[#8:2]-[*:3]\")\n",
  "ax.hlines(p_target[2, 1], 0, 100, color=\"g\", ls=\"--\", label=\"[#1:1]-[#6X4:2]-[#1:3]\")\n",
  "\n",
  "for i in range(100):\n",
  "    if i % 10 == 0:\n",
  "        print(f\"step {i}\\tloss: {loss(p)}\")\n",
  "    ax.plot(i, p[0][1], \"k.\")\n",
  "    ax.plot(i, p[1][1], \"r.\")\n",
  "    ax.plot(i, p[2][1], \"g.\")\n",
  "\n",
  "    # use JAX to get the gradient\n",
  "    _, f_vjp_angles = jax.vjp(loss, p)\n",
  "    grad = f_vjp_angles(1.0)[0]\n",
  "    # update the force field parameters\n",
  "    p -= a * grad\n",
  "    # propagate the new force field parameters into new system parameters\n",
  "    q = m.dot(p.flatten()).reshape((-1, 2))\n",
  "\n",
  "ax.legend(loc=0)\n",
  "ax.set_xlabel(\"iteration\")\n",
  "ax.set_ylabel(\"parameter value (angle-ish)\")\n",
  "ax.set_xlim((0, 100))\n",
  "ax.set_ylim((100, 120))"
 ] }
 ],
 "metadata": {
  "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" },
  "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.9" }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}