# JAX

## Grad

Jax's tracer can compute gradients! Let's try:

$$
\begin{aligned}
y &= x^3 + x^2 + x \\
y' &= 3x^2 + 2x + 1 \\
y'' &= 6x + 2
\end{aligned}
$$

In [None]:
import jax
import jax.numpy as jnp
import numpy as np

In [None]:
def f(x):
    return x**3 + x**2 + x

In [None]:
f(1.0)

In [None]:
fp = jax.grad(f)

In [None]:
fp(1.0)

In [None]:
fpp = jax.grad(fp)

In [None]:
fpp(1.0)
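As a sanity check, the automatic derivatives can be compared with the hand-derived $y'$ and $y''$ above at a few points. This is a minimal sketch; `fp_exact` and `fpp_exact` are illustrative helper names, not part of Jax. The inputs are floats on purpose: `jax.grad` only accepts floating-point (or complex) inputs, so calling `fp(1)` with an integer raises a `TypeError`.

```python
import jax

def f(x):
    return x**3 + x**2 + x

fp = jax.grad(f)    # y' via automatic differentiation
fpp = jax.grad(fp)  # y'' by differentiating the derivative again

# Hypothetical helpers: the closed-form derivatives from the equations above
def fp_exact(x):
    return 3 * x**2 + 2 * x + 1

def fpp_exact(x):
    return 6 * x + 2

for x in [-2.0, 0.0, 1.0, 3.5]:
    # float() turns the 0-d Jax array into a Python float for printing
    print(f"x={x}: grad={float(fp(x))}, exact={fp_exact(x)}; "
          f"second={float(fpp(x))}, exact={fpp_exact(x)}")
```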

## Tracer limitations


Let's watch the tracer:

In [None]:
def f(x):
    print(f"{x = }")
    y = x**2
    print(f"{y = }")
    return y

In [None]:
f_jit = jax.jit(f)

In [None]:
f_jit(2)

In [None]:
f_jit(2)

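Notice that the Python code, including the `print` calls, runs only on the first call, and what it prints for `x` is not the integer `2` but a tracer: a placeholder standing in for any integer scalar. On later calls the compiled function is reused and the Python body does not run again, as long as the inputs have the same types and shapes. A float input triggers a new trace, while another integer reuses the first one: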

In [None]:
f_jit(1.0)

In [None]:
f_jit(1.0)

In [None]:
f_jit(1)
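One way to make the caching visible without reading `print` output is to count how often the Python body runs. This sketch uses a global counter (`trace_count` and `square` are hypothetical names); since Python side effects only happen while Jax traces, the counter ends up equal to the number of traces. A new array shape also causes a retrace.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter, bumped only while Jax traces

def square(x):
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time only
    return x**2

square_jit = jax.jit(square)

square_jit(2)            # integer scalar: traced and compiled
square_jit(2)            # same type and shape: cached, no trace
square_jit(1.0)          # float scalar: new trace
square_jit(1)            # integer scalar again: reuses the first trace
square_jit(jnp.ones(3))  # float array of shape (3,): new trace
square_jit(jnp.ones(3))  # cached
print(trace_count)       # 3
```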

Because tracers stand in for values that are not known yet, you can't use Python control flow that depends on a traced value, and array shapes can't depend on the values inside an array:

In [None]:
@jax.jit
def broken(x):
    if x == 3:  # needs the concrete value of x, which a tracer doesn't have
        return x**3
    return x

In [None]:
broken(2)
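A few common workarounds, sketched under the assumption that you control the function being jitted (all function names here are hypothetical): `jnp.where` computes both branches and selects elementwise, `jax.lax.cond` picks a branch at run time, and `static_argnums` makes an argument a concrete Python value at the cost of recompiling for every distinct value. For value-dependent shapes, such as boolean-mask indexing `x[x > 0]`, masking keeps the output shape fixed.

```python
import jax
import jax.numpy as jnp

@jax.jit
def cube_if_three_where(x):
    # Both x**3 and x are computed; jnp.where selects based on the traced test
    return jnp.where(x == 3, x**3, x)

@jax.jit
def cube_if_three_cond(x):
    # lax.cond(pred, true_fun, false_fun, operand) chooses a branch at run time
    return jax.lax.cond(x == 3, lambda v: v**3, lambda v: v, x)

def cube_if_three(x):
    # Plain Python control flow, the same logic as `broken`
    if x == 3:
        return x**3
    return x

# x is static: a hashable Python value, so each new value triggers recompilation
cube_if_three_static = jax.jit(cube_if_three, static_argnums=0)

@jax.jit
def positive_sum(x):
    # x[x > 0] would have a value-dependent shape and fail under jit;
    # zeroing the unwanted entries keeps the shape fixed instead
    return jnp.sum(jnp.where(x > 0, x, 0.0))

print(cube_if_three_where(3), cube_if_three_where(2))
print(cube_if_three_cond(3), cube_if_three_cond(2))
print(cube_if_three_static(3), cube_if_three_static(2))
print(positive_sum(jnp.array([-1.0, 2.0, 3.0])))
```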

## Jax is functional

Unlike NumPy arrays, Jax arrays are immutable. You should also write pure functions: functions whose output depends only on their inputs and that have no side effects or hidden state.
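Here is one way impurity bites under `jax.jit`, as a minimal sketch (`scale`, `scaled_impure`, and `scaled_pure` are hypothetical names): a global read inside a jitted function is baked in as a constant when the function is traced, so the cached version silently ignores later changes to that global. Passing the value as an argument keeps the function pure.

```python
import jax

scale = 2.0  # global state read by an impure function

@jax.jit
def scaled_impure(x):
    return x * scale  # `scale` is captured as a constant at trace time

print(scaled_impure(1.0))  # 2.0
scale = 10.0
print(scaled_impure(1.0))  # still 2.0: the cached trace uses the old value

@jax.jit
def scaled_pure(x, scale):
    return x * scale  # state passed in explicitly, so it is traced like x

print(scaled_pure(1.0, 10.0))  # 10.0
```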

Because arrays are immutable, you can't assign into one in place; the cell below raises a `TypeError`:

In [None]:
jarr = jnp.zeros((3, 3))
jarr[np.diag(np.ones(3, dtype=bool))] = 1
jarr

Jax provides the `.at[...]` property instead: `.at[idx].set(value)` returns a new array with the update applied and leaves the original untouched:

In [None]:
j1 = jnp.zeros((3, 3))
j2 = j1.at[np.diag(np.ones(3, dtype=bool))].set(1)
j2
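A slightly longer sketch of `.at`: the original array stays unchanged, and besides `set` there are other out-of-place updates such as `add`. It uses `jnp.diag_indices` in place of the boolean mask; both select the main diagonal here.

```python
import jax.numpy as jnp

j1 = jnp.zeros((3, 3))
rows, cols = jnp.diag_indices(3)  # integer indices of the main diagonal

j2 = j1.at[rows, cols].set(1.0)   # new array with ones on the diagonal
j3 = j2.at[0, :].add(10.0)        # new array with 10 added to row 0

print(j1)  # unchanged: still all zeros
print(j2)  # equal to jnp.eye(3)
print(j3)
```

Returning a copy may look expensive, but inside `jax.jit` the compiler can often perform the update in place when the input array is not used again.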

## Further reading

See the Jax docs!

https://jax.readthedocs.io/en/latest/notebooks/quickstart.html