---
# One-shot Pod that runs a small JAX pmap matmul across two AMD GPUs
# without granting the container privileged mode.
apiVersion: v1
kind: Pod
metadata:
  name: jax-non-privileged-multi-gpu-pod
spec:
  # Run the script once; do not restart the container after it exits.
  restartPolicy: Never
  # Share the host IPC namespace with the container.
  # NOTE(review): presumably needed for inter-GPU communication via shared
  # memory -- confirm, and drop it if the workload runs without it.
  hostIPC: true
  containers:
    - name: jax-multi-gpu-container
      # NOTE(review): ':latest' is a moving tag; consider pinning a specific
      # image tag so runs are reproducible.
      image: rocm/jax:latest
      command:
        - python3
        - "-c"
        # Inline Python program passed to `python3 -c`. The literal block
        # scalar keeps newlines, which Python's indentation rules require.
        - |
          import jax
          import jax.numpy as jnp
          print('Available JAX devices:', jax.devices())

          # Create data to process in parallel
          n_devices = jax.device_count()
          print(f'Number of devices: {n_devices}')

          # Create matrices for each device
          x = jnp.ones((n_devices, 1000, 1000))
          y = jnp.ones((n_devices, 1000, 1000))

          # Define computation to run in parallel
          @jax.pmap
          def parallel_matmul(a, b):
              return jnp.matmul(a, b)

          # Run computation in parallel across GPUs
          result = parallel_matmul(x, y)
          print(f'Parallel computation complete across {n_devices} devices')
          print('Result shape:', result.shape)
          print('Device mapping:', jax.devices())
      resources:
        limits:
          amd.com/gpu: 2  # Request 2 AMD GPUs
      securityContext:
        # Explicitly non-privileged: GPU access comes from the amd.com/gpu
        # resource limit above rather than from privileged mode.
        privileged: false
        allowPrivilegeEscalation: false
        # Unconfined turns off seccomp syscall filtering for this container.
        # NOTE(review): presumably required by the ROCm runtime -- confirm
        # whether RuntimeDefault is sufficient before keeping this.
        seccompProfile:
          type: Unconfined