# One-shot Pod that runs an inline JAX script on two AMD GPUs without
# privileged mode. The script lists the visible JAX devices, runs one
# 1000x1000 matrix multiplication per device via jax.pmap, and prints the
# shape of the result.
apiVersion: v1
kind: Pod
metadata:
  name: jax-non-privileged-multi-gpu-pod
spec:
  # Run the script once; do not restart the container when it exits.
  restartPolicy: Never
  # Shares the host's IPC namespace with the container.
  # NOTE(review): presumably needed for multi-GPU communication in the
  # ROCm stack -- confirm it is required, since it weakens pod isolation.
  hostIPC: true
  containers:
  - name: jax-multi-gpu-container
    # NOTE(review): `latest` is a floating tag; consider pinning a specific
    # image version for reproducible runs.
    image: rocm/jax:latest
    # The script is passed to `python3 -c` as a single argument. It is a
    # YAML literal block scalar, so every line below (up to `resources:`) is
    # Python source and must not contain shell-style wrapping quotes.
    command:
    - python3
    - "-c"
    - |
      import jax
      import jax.numpy as jnp
      print('Available JAX devices:', jax.devices())

      # Create data to process in parallel
      n_devices = jax.device_count()
      print(f'Number of devices: {n_devices}')

      # Create matrices for each device (leading axis = one slice per device)
      x = jnp.ones((n_devices, 1000, 1000))
      y = jnp.ones((n_devices, 1000, 1000))

      # Define computation to run in parallel
      @jax.pmap
      def parallel_matmul(a, b):
          return jnp.matmul(a, b)

      # Run computation in parallel across GPUs
      result = parallel_matmul(x, y)

      print(f'Parallel computation complete across {n_devices} devices')
      print('Result shape:', result.shape)
      print('Device mapping:', jax.devices())
    resources:
      limits:
        amd.com/gpu: 2  # Request 2 AMD GPUs
    securityContext:
      # GPU access comes from the amd.com/gpu resource limit above rather
      # than from privileged mode.
      privileged: false
      allowPrivilegeEscalation: false
      # NOTE(review): Unconfined disables seccomp syscall filtering for this
      # container; presumably required by the ROCm runtime -- confirm.
      seccompProfile:
        type: Unconfined