# [Distributed statistical inference with `pyhf`](https://indico.cern.ch/event/1019958/contributions/4418598/)

## Cursory introduction to `pyhf`

For the sake of brevity and time, we won't go into a full discussion of what `pyhf` is and what you can do with it. For now we'll point you to the [latest `pyhf` tutorial for `pyhf` `v0.6.2`](https://github.com/pyhf/pyhf-tutorial/tree/786702385e003511bbce27773c48df8769dfcfcb) as well as our vCHEP 2021 talk: [Distributed statistical inference with `pyhf` enabled through `funcX`](https://indico.cern.ch/event/948465/contributions/4324013/).

Very briefly, though, `pyhf` is a pure-Python implementation of the `HistFactory` family of statistical models that, through optional computational backends like JAX, provides automatic differentiation and hardware acceleration on GPUs. `pyhf` is part of Scikit-HEP and is designed to have a clear Pythonic API, with the goal of making it easier and clearer to produce and interpret binned models.

Taking an example from the `pyhf` project `README`, this is all the code that is needed to build a simple 1-bin model, perform a hypothesis test scan across multiple values of the parameter of interest (POI), plot those results, and, by inverting the scan, determine the 95% CL upper limit on the POI value.
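For reference, this is the standard $\mathrm{CL}_{s}$ construction behind the scan: at each tested POI value $\mu$ the $\mathrm{CL}_{s}$ value is the ratio below, where $p_{s+b}$ and $p_{b}$ are the signal-plus-background and background-only $p$-values, and the 95% CL upper limit $\mu_{\mathrm{up}}$ is the POI value at which $\mathrm{CL}_{s}$ falls to the test size $\alpha = 0.05$.

$$
\mathrm{CL}_{s}(\mu) = \frac{\mathrm{CL}_{s+b}(\mu)}{\mathrm{CL}_{b}(\mu)} = \frac{p_{s+b}(\mu)}{1 - p_{b}(\mu)},
\qquad
\mathrm{CL}_{s}(\mu_{\mathrm{up}}) = \alpha = 0.05
$$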

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pyhf
from pyhf.contrib.viz import brazil

In [None]:
pyhf.set_backend("numpy")
model = pyhf.simplemodels.uncorrelated_background(
    signal=[10.0], bkg=[50.0], bkg_uncertainty=[7.0]
)
data = [55.0] + model.config.auxdata

poi_vals = np.linspace(0, 5, 41)
results = [
    pyhf.infer.hypotest(
        test_poi, data, model, test_stat="qtilde", return_expected_set=True
    )
    for test_poi in poi_vals
]

fig, ax = plt.subplots()
fig.set_size_inches(7, 5)
plot = brazil.plot_results(poi_vals, results, ax=ax)

In [None]:
obs_limit, exp_limits, (scan, results) = pyhf.infer.intervals.upperlimit(
    data, model, poi_vals, return_results=True
)
print(f"observed limit: {obs_limit}")
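Conceptually, turning a $\mathrm{CL}_{s}$ scan into an upper limit comes down to finding where the scanned values cross the test size of 0.05. The sketch below shows that idea in plain Python on a made-up, monotonically decreasing scan. `interpolate_upper_limit` is a hypothetical helper that linearly interpolates between the two scan points bracketing the crossing; it is not `pyhf`'s own implementation.

```python
def interpolate_upper_limit(poi_values, cls_values, level=0.05):
    """Return the POI value where a decreasing CLs scan crosses `level`.

    Hypothetical helper: linear interpolation between the two neighbouring
    scan points whose CLs values bracket `level`.
    """
    points = list(zip(poi_values, cls_values))
    for (mu_lo, cls_lo), (mu_hi, cls_hi) in zip(points, points[1:]):
        # The crossing lies between neighbours on either side of `level`
        if cls_lo >= level >= cls_hi:
            if cls_lo == cls_hi:
                return mu_lo
            fraction = (cls_lo - level) / (cls_lo - cls_hi)
            return mu_lo + fraction * (mu_hi - mu_lo)
    raise ValueError("CLs scan never crosses the requested level; widen the POI range")


# Illustrative, made-up scan values (not pyhf output)
toy_pois = [0.0, 1.0, 2.0, 3.0]
toy_cls = [1.0, 0.5, 0.1, 0.01]
print(interpolate_upper_limit(toy_pois, toy_cls))  # about 2.56
```

With a real scan you would pass `poi_vals` and the corresponding observed $\mathrm{CL}_{s}$ values; a finer POI grid makes the linear interpolation less sensitive to curvature between scan points.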

The important point to emphasize for the purposes of this notebook talk, though, is simply that `pyhf` allows for **statistical modelling of binned data** and **fast fitting through Pythonic APIs**.

## Introduction to [`funcX`](https://funcx.readthedocs.io/en/latest/) - High Performance Function Serving

* `funcX` is a high-performance Function as a Service (FaaS) platform
* Designed to orchestrate scientific workloads across **heterogeneous computing resources** (clusters, clouds, and supercomputers) and **task execution providers** (HTCondor, Slurm, Torque, and Kubernetes)
* Leverages [Parsl](https://parsl.readthedocs.io/en/stable/) (flexible and scalable parallel programming library for Python) for efficient parallelism and managing concurrent task execution

**`funcX` endpoints** are logical entities that represent a specified compute resource.

* Managed by an agent process allowing the `funcX` service to dispatch **user defined functions** to resources for execution

The agent handles:
- Authentication (through Globus) and authorization
- Provisioning of nodes on the compute resource
- Monitoring and management

We'll see more of this shortly.

## Demo of `funcX`

In [None]:
from time import sleep

import funcx
from funcx.sdk.client import FuncXClient

### Endpoint Creation (On execution machine)

Using the `funcx-endpoint` CLI

In [None]:
! funcx-endpoint --help

you need to create a template environment for your endpoint.

```
$ funcx-endpoint configure pyhf
```

This will create a default `funcX` configuration file at `~/.funcx/pyhf/config.py`.

1. Note that `funcX` requires the use of [Globus](https://www.globus.org/), so you'll first need to log in to a Globus account to use the `funcx-sdk`. Globus allows authentication through existing organizational logins, Google accounts, or an [ORCID iD](https://orcid.org/), so this shouldn't be a barrier to use.
<br><br>
![globus_login_page](figures/globus_login_page.png)
<br><br>
2. Once you authenticate with Globus, you'll need to approve the `funcx-sdk`'s required permissions, after which you'll be given a time-limited authorization code.
3. Copy this code and paste it back into the terminal where you ran `funcx-endpoint configure pyhf`, at the prompt "Please Paste your Auth Code Below".

Upon success you'll see

```
A default profile has been create for <pyhf> at /home/jovyan/.funcx/pyhf/config.py
Configure this file and try restarting with:
    $ funcx-endpoint start pyhf
```

> If you're following along, you'll want to switch over to a terminal to make this part easier.

In [None]:
! echo "funcx-endpoint configure pyhf"
! ls -l ~/.funcx/pyhf/config.py

In [None]:
! cat ~/.funcx/pyhf/config.py

We'll go a step further, though, and use a prepared `funcX` configuration found under `funcX/binder-config.py`.

In [None]:
! cp funcX/binder-config.py ~/.funcx/pyhf/config.py

and look at it again

In [None]:
! cat ~/.funcx/pyhf/config.py

Let's break down some relevant terms from Parsl:

* [`block`](https://parsl.readthedocs.io/en/1.1.0/userguide/execution.html#blocks): Basic unit of resources acquired from a provider
* [`max_blocks`](https://parsl.readthedocs.io/en/1.1.0/userguide/execution.html#elasticity): Maximum number of blocks that can be active per executor
* [`nodes_per_block`](https://parsl.readthedocs.io/en/1.1.0/userguide/execution.html#blocks): Number of nodes requested per block
* [`parallelism`](https://parsl.readthedocs.io/en/1.1.0/userguide/execution.html#parallelism): Ratio of task execution capacity to the sum of running tasks and available tasks

And let's quickly consider this example from the [Parsl docs](https://parsl.readthedocs.io/en/1.1.0/userguide/execution.html#configuration), whose configuration model `funcX` extends:

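```python
from parsl.config import Config
# The Parsl docs' example imports `Local` from the old standalone `libsubmit` package;
# current Parsl provides the same local provider as `parsl.providers.LocalProvider`
from parsl.providers import LocalProvider
from parsl.executors import HighThroughputExecutor

config = Config(
    executors=[
        HighThroughputExecutor(
            label="local_htex",
            # In Parsl 1.x (the version linked above) the per-node worker cap is `max_workers`
            max_workers=2,
            provider=LocalProvider(
                min_blocks=1,  # keep at least one block provisioned
                init_blocks=1,  # blocks requested when the executor starts
                max_blocks=2,  # upper bound on elastic scaling
                nodes_per_block=1,
                parallelism=0.5,  # target roughly one task slot per two outstanding tasks
            ),
        )
    ]
)
```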

[![parsl_parallelism](figures/parsl_parallelism.gif)](https://parsl.readthedocs.io/en/1.1.0/userguide/execution.html#configuration)

<br><br>
**What's happening in the GIF above**:

* `9` tasks to compute
* Tasks are allocated to the first block until its task capacity (here `4` tasks) is reached
* Task `5`: the first block is full and `5/9 > parallelism`, so Parsl provisions a new block for executing the remaining tasks
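To make the elasticity idea concrete, here is a deliberately simplified model of how many blocks a provider would want for a given backlog. `blocks_needed` is a hypothetical helper, not Parsl's actual scaling strategy: it assumes the wanted number of task slots is the outstanding task count times `parallelism` (rounded up), packs those slots into blocks of `tasks_per_block`, and clamps the result to `[min_blocks, max_blocks]`.

```python
import math


def blocks_needed(outstanding_tasks, tasks_per_block, parallelism,
                  min_blocks=0, max_blocks=1):
    """Simplified model of elastic block scaling (not Parsl's exact strategy).

    Assumes the wanted slot count is ceil(outstanding_tasks * parallelism)
    and that each block provides `tasks_per_block` slots.
    """
    wanted_slots = math.ceil(outstanding_tasks * parallelism)
    wanted_blocks = math.ceil(wanted_slots / tasks_per_block)
    # Elasticity is bounded by the provider's min_blocks/max_blocks settings
    return max(min_blocks, min(max_blocks, wanted_blocks))


# Numbers from the GIF and the configuration above
print(blocks_needed(9, 4, 0.5, min_blocks=1, max_blocks=2))  # 2
```

Under these assumptions `parallelism=0` keeps only `min_blocks` alive, while `parallelism=1` would want three blocks for nine tasks but is capped at `max_blocks=2`.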

Okay, now we'll start the endpoint

In [None]:
! funcx-endpoint start pyhf

and you can verify that it is registered and up

In [None]:
! funcx-endpoint list

**N.B.**: You'll want to take careful note of this `uuid`, as this is the endpoint ID that your `funcX` code will use.

A good way to deal with this is to save it in an `endpoint_id.txt` file that is ignored by version control.

In [None]:
! funcx-endpoint list | grep pyhf | awk '{print $(NF-1)}' > endpoint_id.txt
! cat endpoint_id.txt
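If you'd rather do this from Python than with `grep` and `awk`, a small parser works too. This sketch assumes the `funcx-endpoint list` rows are pipe-delimited as `| <name> | <status> | <uuid> |`, which is the layout the `awk '{print $(NF-1)}'` command relies on; the helper name, the sample listing, and its UUID are all hypothetical. Unlike `grep pyhf`, it matches the endpoint name exactly and validates the UUID.

```python
import uuid

# Made-up listing in the assumed table format (not real CLI output)
SAMPLE_LISTING = """\
+---------------+--------+--------------------------------------+
| Endpoint Name | Status |             Endpoint ID              |
+===============+========+======================================+
| pyhf          | Active | 12345678-1234-5678-1234-567812345678 |
+---------------+--------+--------------------------------------+
"""


def endpoint_id_from_listing(listing, endpoint_name):
    """Return the UUID on the row whose first cell equals `endpoint_name`."""
    for line in listing.splitlines():
        cells = [cell.strip() for cell in line.strip().strip("|").split("|")]
        # Border lines have no "|" separators and are skipped here
        if len(cells) >= 3 and cells[0] == endpoint_name:
            # uuid.UUID raises ValueError if the last cell is not a valid UUID
            return str(uuid.UUID(cells[-1]))
    raise LookupError(f"no endpoint named {endpoint_name!r} in listing")


print(endpoint_id_from_listing(SAMPLE_LISTING, "pyhf"))
```

With real output you could capture the listing via `subprocess.run(["funcx-endpoint", "list"], capture_output=True, text=True).stdout` and save the ID with `pathlib.Path("endpoint_id.txt").write_text(endpoint_id)`.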

## Using funcX for (Fitting) Functions as a Service (FaaS)

To keep this as easy as possible to follow along with, we've done something that isn't very practical: we set up our `funcx` endpoint locally. This is probably not where your dedicated compute will be, but for demonstration purposes we'll pretend that our `funcx-endpoint` lives on another machine or cluster someplace.

### Prepare Functions (On your local submission machine)

Locally, we can now write the code that we'd like `funcX` to run for us **as functions** (remember: FaaS).

In [None]:
def simple_example(backend="numpy", test_poi=1.0):
    import time

    import pyhf

    pyhf.set_backend(backend)

    tick = time.time()
    model = pyhf.simplemodels.uncorrelated_background(
        signal=[12.0, 11.0], bkg=[50.0, 52.0], bkg_uncertainty=[3.0, 7.0]
    )

    data = model.expected_data(model.config.suggested_init())
    return {
        "cls_obs": float(
            pyhf.infer.hypotest(test_poi, data, model, test_stat="qtilde")
        ),
        "fit-time": time.time() - tick,
    }

The return value is just a `dict` of the observed $\mathrm{CL}_{s}$ value and the elapsed wall time, which covers building the model as well as the fit.

In [None]:
simple_example()
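Notice that `simple_example` imports `time` and `pyhf` inside its own body. The function is executed by the endpoint's workers rather than in this notebook, so there's no guarantee that the notebook's module-level imports exist where it runs. The toy below is only an analogy (funcX's real serialization is more involved): it `exec`s two hypothetical versions of a function in an empty namespace, standing in for a fresh worker, and only the self-contained one succeeds.

```python
# Two versions of the same hypothetical task, as source a worker might receive
SELF_CONTAINED_SOURCE = '''
def timed_sum(values):
    import time  # imported inside the function, so it is available wherever it runs

    tick = time.time()
    return {"total": sum(values), "elapsed": time.time() - tick}
'''

RELIES_ON_GLOBALS_SOURCE = '''
def timed_sum(values):
    tick = time.time()  # "time" only exists in the submitter's namespace
    return {"total": sum(values), "elapsed": time.time() - tick}
'''


def run_in_fresh_namespace(source, *args):
    """Define `timed_sum` in an empty namespace and call it, like a clean worker."""
    namespace = {}
    exec(source, namespace)
    return namespace["timed_sum"](*args)


print(run_in_fresh_namespace(SELF_CONTAINED_SOURCE, [1.0, 2.0, 3.0])["total"])  # 6.0
try:
    run_in_fresh_namespace(RELIES_ON_GLOBALS_SOURCE, [1.0, 2.0, 3.0])
except NameError as error:
    print(f"worker-side failure: {error}")
```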

We can then initialize our local `funcX` client and **register** our function with it for execution:

In [None]:
# Initialize funcX client
fxc = FuncXClient()
fxc.max_requests = 200

In [None]:
# register functions
infer_func = fxc.register_function(simple_example)

With our function registered, we can now have the `funcX` client serialize it and send it to the `funcX` endpoint (which can be on any machine, anywhere!) to be dispatched to the `funcX` worker nodes on the execution machine.
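The register, run, and retrieve cycle can be mimicked locally to see its shape. `ToyFunctionService` below is a hypothetical, in-process stand-in built on `concurrent.futures.ThreadPoolExecutor`, not the `funcX` SDK: registration returns an ID, `run` returns a task ID immediately, and `get_result` raises while the task is still pending. Unlike `FuncXClient.run`, it takes no `endpoint_id`, because everything executes in the local process.

```python
import uuid
from concurrent.futures import ThreadPoolExecutor


class TaskPending(Exception):
    """Raised when a result is requested before the task has finished."""


class ToyFunctionService:
    """In-process stand-in for a FaaS client (hypothetical; not the funcX SDK)."""

    def __init__(self, max_workers=2):
        self._executor = ThreadPoolExecutor(max_workers=max_workers)
        self._functions = {}
        self._tasks = {}

    def register_function(self, function):
        # Registration hands back an opaque ID that later calls refer to
        function_id = str(uuid.uuid4())
        self._functions[function_id] = function
        return function_id

    def run(self, *args, function_id, **kwargs):
        # Submission returns a task ID right away; the work runs on a worker thread
        task_id = str(uuid.uuid4())
        self._tasks[task_id] = self._executor.submit(
            self._functions[function_id], *args, **kwargs
        )
        return task_id

    def get_result(self, task_id):
        future = self._tasks[task_id]
        if not future.done():
            raise TaskPending(f"task {task_id} is still pending")
        return future.result()  # re-raises any exception the function raised

    def shutdown(self):
        # Block until all submitted tasks have finished
        self._executor.shutdown(wait=True)


service = ToyFunctionService()
square_id = service.register_function(lambda x: x * x)
task = service.run(7, function_id=square_id)
service.shutdown()
print(service.get_result(task))  # 49
```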

In [None]:
with open("endpoint_id.txt") as endpoint_file:
    pyhf_endpoint = str(endpoint_file.read().rstrip())

In [None]:
# Serialize and send to funcX endpoint to run
task_id = fxc.run(
    backend="numpy", test_poi=1.0, endpoint_id=pyhf_endpoint, function_id=infer_func
)

While that runs, we can send queries from our local submission machine to the (remote) execution machine to check whether the tasks we've submitted have finished execution.

In [None]:
# wait for it to run. Here this is super fast, but you'd want to set up a loop to check periodically
sleep(1)

In [None]:
# retrieve output
result = fxc.get_result(task_id)
result

In [None]:
# Run a different test POI
task_id = fxc.run(
    backend="numpy", test_poi=2.0, endpoint_id=pyhf_endpoint, function_id=infer_func
)
sleep(0.01)
try:
    result = fxc.get_result(task_id)
except Exception as excep:
    print(f"inference: {excep}")
    sleep(2)

result = fxc.get_result(task_id)
result
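The cell above retries only once. A more robust pattern is a polling loop with a deadline, sketched below. `wait_for_result` is a hypothetical helper: it treats any exception type listed in `pending_exceptions` as "not finished yet" (defaulting to `Exception`, matching the broad `except` above) and re-raises the last one once the timeout expires. `flaky_get_result` simulates a client that reports a pending task twice before succeeding.

```python
import time


def wait_for_result(get_result, task_id, timeout=60.0, interval=0.5,
                    pending_exceptions=(Exception,)):
    """Poll `get_result(task_id)` until it succeeds or `timeout` seconds pass."""
    deadline = time.monotonic() + timeout
    while True:
        try:
            return get_result(task_id)
        except pending_exceptions:
            if time.monotonic() >= deadline:
                raise  # give up and surface the last "pending" error
            time.sleep(interval)


attempts = {"count": 0}


def flaky_get_result(task_id):
    # Stand-in for a client call that reports "pending" twice before succeeding
    attempts["count"] += 1
    if attempts["count"] < 3:
        raise RuntimeError(f"{task_id} is pending")
    return "done"


print(wait_for_result(flaky_get_result, "toy-task", interval=0.01))  # done
```

In the notebook this would be called as `wait_for_result(fxc.get_result, task_id)`. Narrowing `pending_exceptions` to the SDK's actual pending-task exception avoids silently retrying on genuine failures until the timeout.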

## funcX endpoint shutdown (On execution machine)

To stop a funcX endpoint from running, simply use the `funcx-endpoint` CLI again:

In [None]:
! funcx-endpoint stop pyhf
! funcx-endpoint list