{
"cells": [
{
"cell_type": "markdown",
"id": "03540da9-b942-49bc-8c86-54006d833652",
"metadata": {},
"source": [
"# PyTorch interoperability\n",
"In this notebook we test the interoperability of PyTorch tensors with clesperanto."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "026234d9-43ea-4d00-9064-d3125feb65b4",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "d33e8b72-7b22-47e5-b3eb-1265b2d64a0e",
"metadata": {},
"outputs": [],
"source": [
"import pyclesperanto_prototype as cle"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "23a7f1a9-0c59-4ecf-ae3e-e4dd27bdf298",
"metadata": {},
"outputs": [],
"source": [
"# Create a 10x10 test image containing two separate 2x2 squares\n",
"tensor = torch.zeros((10, 10))\n",
"tensor[1:3, 1:3] = 1\n",
"tensor[5:7, 5:7] = 1"
]
},
{
"cell_type": "markdown",
"id": "21754dfd-4d0c-4acf-a4c0-5cc685298914",
"metadata": {},
"source": [
"## Pushing tensors to the GPU"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "fe927938-a952-4203-946d-c7564bcb1c3f",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<b>cle._ image</b>\n",
"<table>\n",
"<tr><td>shape</td><td>(10, 10)</td></tr>\n",
"<tr><td>dtype</td><td>float32</td></tr>\n",
"<tr><td>size</td><td>400.0 B</td></tr>\n",
"<tr><td>min</td><td>0.0</td><td>max</td><td>1.0</td></tr>\n",
"</table>"
],
"text/plain": [
"cl.OCLArray([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
" [0., 1., 1., 0., 0., 0., 0., 0., 0., 0.],\n",
" [0., 1., 1., 0., 0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 1., 1., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 1., 1., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"cle_tensor = cle.push(tensor)\n",
"cle_tensor"
]
},
{
"cell_type": "markdown",
"id": "9f982a9c-76f2-4bd5-a8d2-65c63ace5ba4",
"metadata": {},
"source": [
"`cle.push()` turns the tensor into an OpenCL array:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "63499a9d-9a75-4776-a9bb-ab220b7b2ad4",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"pyclesperanto_prototype._tier0._pycl.OCLArray"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"type(cle_tensor)"
]
},
{
"cell_type": "markdown",
"id": "2823d95a-00cf-4b7a-abbe-836db1e191b8",
"metadata": {},
"source": [
"## Passing tensors as arguments\n",
"You can also pass a tensor directly as an argument to clesperanto functions. The tensor is then pushed to the GPU implicitly."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "a6c62f23-ee53-488b-b986-fced10de5dfe",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<b>cle._ image</b>\n",
"<table>\n",
"<tr><td>shape</td><td>(10, 10)</td></tr>\n",
"<tr><td>dtype</td><td>uint32</td></tr>\n",
"<tr><td>size</td><td>400.0 B</td></tr>\n",
"<tr><td>min</td><td>0.0</td><td>max</td><td>2.0</td></tr>\n",
"</table>"
],
"text/plain": [
"cl.OCLArray([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 1, 1, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 1, 1, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 2, 2, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 2, 2, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], dtype=uint32)"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"cle_labels = cle.label(tensor)\n",
"cle_labels"
]
},
{
"cell_type": "markdown",
"id": "b9d36430-7af9-4879-a449-5aa52257eec7",
"metadata": {},
"source": [
"## Converting results back to a tensor\n",
"To turn the OpenCL image into a tensor, call its `get()` method, which pulls the data into a NumPy array, and pass the result to `torch.tensor()`. Label images are stored as unsigned 32-bit integers, so you first need to convert them into a pixel type that PyTorch accepts, for example signed 32-bit integer."
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "b8b88eec-ebe2-4403-aa66-ac09e8fb3318",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Tensor"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
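],
"source": [
"# Cast the labels to int32, pull them into a NumPy array with get(), then wrap them in a tensor\n",
"labels_tensor = torch.tensor(cle_labels.astype(np.int32).get())\n",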
"type(labels_tensor)"
]
},
{
"cell_type": "markdown",
"id": "257afd2a-a7c9-488f-9b9f-4eeba6b960bc",
"metadata": {},
"source": [
"## GPU tensors\n",
"Tensors that are stored on the GPU and managed by PyTorch need to be transferred back to CPU memory before they can be pushed to OpenCL/GPU memory. This happens transparently under the hood, but the additional memory transfers may reduce performance."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "bede46b0-1436-42d6-ba8e-82750702c7b7",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.cuda.is_available()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "c5fa3935-6f60-43a6-a201-267f1428ba09",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"cuda:0\n"
]
}
],
"source": [
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
"print(device)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "22f846e3-5bf1-4534-8c55-409e03d7b7de",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"False"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tensor.is_cuda"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "61d82b73-2501-4479-b6d3-b86f04794aa1",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"cuda_tensor = tensor.to(device)\n",
"cuda_tensor.is_cuda"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "c6bb0c5f-e16b-4911-863c-290c44375291",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<b>cle._ image</b>\n",
"<table>\n",
"<tr><td>shape</td><td>(10, 10)</td></tr>\n",
"<tr><td>dtype</td><td>float32</td></tr>\n",
"<tr><td>size</td><td>400.0 B</td></tr>\n",
"<tr><td>min</td><td>0.0</td><td>max</td><td>1.0</td></tr>\n",
"</table>"
],
"text/plain": [
"cl.OCLArray([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
" [0., 1., 1., 0., 0., 0., 0., 0., 0., 0.],\n",
" [0., 1., 1., 0., 0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 1., 1., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 1., 1., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# The CUDA tensor is first copied back to CPU memory, then pushed to OpenCL memory\n",
"cle.push(cuda_tensor)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "5f43e6dd-87e0-4fe8-9357-873adcc0e9f7",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<b>cle._ image</b>\n",
"<table>\n",
"<tr><td>shape</td><td>(10, 10)</td></tr>\n",
"<tr><td>dtype</td><td>uint32</td></tr>\n",
"<tr><td>size</td><td>400.0 B</td></tr>\n",
"<tr><td>min</td><td>0.0</td><td>max</td><td>2.0</td></tr>\n",
"</table>"
],
"text/plain": [
"cl.OCLArray([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 1, 1, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 1, 1, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 2, 2, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 2, 2, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], dtype=uint32)"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"cle.label(cuda_tensor)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "925044c5-26d9-4379-b167-4009f7ca167d",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
}
},
"nbformat": 4,
"nbformat_minor": 5
}