def get_solvers(
    model_constructor,
    pdf_transform=False,
    default_rtol=1e-10,
    default_atol=1e-10,
    default_max_iter=int(1e7),
    learning_rate=0.01,
):
    """Build a global fitter and a constrained fitter on top of fax's two-phase solver.

    Parameters
    ----------
    model_constructor : callable
        Called with ``hyper_pars[1]``; must return a pair ``(m, bonlypars)``
        where ``m`` exposes ``m.config.suggested_bounds``.
    pdf_transform : bool
        When True, parameters handed to the objectives are passed through
        ``to_bounded_vec(pars, bounds)`` before ``logpdf`` is evaluated, and
        the fixed ``mu`` (``hyper_pars[0]``) is passed through
        ``to_inf(mu, bounds[0])``.
    default_rtol, default_atol, default_max_iter
        Forwarded unchanged to ``twophase.two_phase_solver`` for both solvers.
    learning_rate : float
        NOTE(review): accepted but never read -- the Adam step size below is
        hard-coded to ``1e-6``. Confirm whether it should be wired through.

    Returns
    -------
    (g_fitter, c_fitter)
        Two callables with signature ``(init, hyper_pars)``; each runs its
        solver and returns the ``.value`` attribute of the result.
    """
    opt_init, opt_update, opt_get_params = optimizers.adam(1e-6)

    def adam_step(i, grads, current):
        # The optimizer state is rebuilt from the current parameters on every
        # call, so nothing is carried over from one iteration to the next.
        return opt_get_params(opt_update(i, grads, opt_init(current)))

    def make_model(hyper_pars):
        """Return ``(mu, global objective, constrained objective, bounds)``."""
        fixed_mu = hyper_pars[0]
        model, bonly_pars = model_constructor(hyper_pars[1])

        bounds = model.config.suggested_bounds
        if pdf_transform:
            fixed_mu = to_inf(fixed_mu, bounds[0])

        expected_bonly = expected_data(model, bonly_pars, include_auxdata=True)

        def expected_logpdf(pars):
            # With pdf_transform the incoming parameters are mapped through
            # to_bounded_vec before the pdf is evaluated.
            if pdf_transform:
                pars = to_bounded_vec(pars, bounds)
            return logpdf(model, pars, expected_bonly)

        def global_fit_objective(pars):
            # Negative log-likelihood over the full parameter vector.
            return -expected_logpdf(pars)[0]

        def constrained_fit_objective(nuis_par):
            # Negative log-likelihood with the first entry pinned to fixed_mu.
            mu_entry = jax.numpy.asarray([fixed_mu])
            full_pars = jax.numpy.concatenate([mu_entry, nuis_par])
            return -expected_logpdf(full_pars)[0]

        return fixed_mu, global_fit_objective, constrained_fit_objective, bounds

    def global_bestfit_minimized(hyper_param):
        objective = make_model(hyper_param)[1]

        def bestfit_via_grad_descent(i, param):
            # One Adam update of the whole parameter vector.
            return adam_step(i, jax.grad(objective)(param), param)

        return bestfit_via_grad_descent

    def constrained_bestfit_minimized(hyper_param):
        fixed_mu, _, objective, _ = make_model(hyper_param)

        def bestfit_via_grad_descent(i, param):
            # Only the entries after the first are updated; the first entry is
            # overwritten with the fixed mu on the way out.
            nuisance = param[1:]
            nuisance = adam_step(i, jax.grad(objective)(nuisance), nuisance)
            return jax.numpy.concatenate(
                [jax.numpy.asarray([fixed_mu]), nuisance]
            )

        return bestfit_via_grad_descent

    solver_defaults = dict(
        default_rtol=default_rtol,
        default_atol=default_atol,
        default_max_iter=default_max_iter,
    )
    global_solve = twophase.two_phase_solver(
        param_func=global_bestfit_minimized, **solver_defaults
    )
    constrained_solver = twophase.two_phase_solver(
        param_func=constrained_bestfit_minimized, **solver_defaults
    )

    def g_fitter(init, hyper_pars):
        return global_solve(init, hyper_pars).value

    def c_fitter(init, hyper_pars):
        return constrained_solver(init, hyper_pars).value

    return g_fitter, c_fitter
"source": [ "#### transformed fit test" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 4 }