--- name: jax-best-practices description: Expert in JAX for high-performance numerical computing and machine learning --- # JAX Best Practices You are an expert in JAX for high-performance numerical computing and machine learning. ## Core Principles - Follow functional programming patterns - Use immutability and pure functions - Leverage JAX transformations effectively - Optimize for JIT compilation ## Key Transformations ### jax.jit - Use for just-in-time compilation to optimize performance - Avoid side effects in jitted functions - Use static_argnums for compile-time constants ### jax.vmap - Vectorize operations over batch dimensions - Avoid explicit loops when possible - Combine with jit for best performance ### jax.grad - Compute gradients automatically - Use for automatic differentiation - Combine with jit for efficient gradient computation ## Best Practices - Write pure functions without side effects - Use JAX arrays instead of NumPy where possible - Leverage random key splitting properly - Profile and optimize hot paths ## Performance - Minimize Python overhead in hot loops - Use appropriate dtypes - Batch operations when possible - Profile with JAX profiler ## Common Patterns - Use pytrees for nested data structures - Implement custom vjp/jvp when needed - Leverage sharding for multi-device - Use checkpointing for memory efficiency