{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| default_exp vision_augmentation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "from nbdev.showdoc import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "from fastai.data.all import *\n",
    "from fastMONAI.vision_core import *\n",
    "import torchio as tio"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data augmentation\n",
    "> Data augmentation transforms for medical images, built on TorchIO"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Transforms wrapper"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class CustomDictTransform(ItemTransform):\n",
    "    \"\"\"A wrapper that performs an identical transformation on both the image\n",
    "    and the target (if it's a mask).\n",
    "    \"\"\"\n",
    "\n",
    "    # Only perform transformations on training data. Use TTA() for transformations on validation data.\n",
    "    split_idx = 0\n",
    "\n",
    "    def __init__(self, aug):\n",
    "        \"\"\"Constructs CustomDictTransform object.\n",
    "\n",
    "        Args:\n",
    "            aug (Callable): Augmentation that is called with a `tio.Subject`\n",
    "                holding the image (and the mask, when there is one).\n",
    "        \"\"\"\n",
    "        self.aug = aug\n",
    "\n",
    "    def encodes(self, x):\n",
    "        \"\"\"Apply the stored transformation to the image and, if the target is a\n",
    "        mask, the same random transformation to the target.\n",
    "\n",
    "        Args:\n",
    "            x (Tuple[MedImage, Union[MedMask, TensorCategory]]): A tuple containing\n",
    "                the image and the target.\n",
    "\n",
    "        Returns:\n",
    "            Tuple[MedImage, Union[MedMask, TensorCategory]]: The transformed image\n",
    "                and target. A mask target is transformed identically to the image;\n",
    "                any other target is returned as is.\n",
    "        \"\"\"\n",
    "        img, y_true = x\n",
    "\n",
    "        # Use the same affine in both branches so that spatial augmentations see\n",
    "        # the same geometry whether or not the target is a mask.\n",
    "        affine = MedImage.affine_matrix\n",
    "\n",
    "        if isinstance(y_true, MedMask):\n",
    "            subject = tio.Subject(img=tio.ScalarImage(tensor=img, affine=affine),\n",
    "                                  mask=tio.LabelMap(tensor=y_true, affine=affine))\n",
    "            aug = self.aug(subject)\n",
    "            return MedImage.create(aug['img'].data), MedMask.create(aug['mask'].data)\n",
    "\n",
    "        aug = self.aug(tio.Subject(img=tio.ScalarImage(tensor=img, affine=affine)))\n",
    "        return MedImage.create(aug['img'].data), y_true\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Vanilla transforms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def do_pad_or_crop(o, target_shape, padding_mode, mask_name, dtype=torch.Tensor):\n",
    "    \"\"\"Crop and/or pad `o` with TorchIO `CropOrPad` and wrap the result in `dtype`.\n",
    "\n",
    "    Args:\n",
    "        o: Input passed to the `CropOrPad` transform.\n",
    "        target_shape: Forwarded to `tio.CropOrPad`.\n",
    "        padding_mode: Forwarded to `tio.CropOrPad`.\n",
    "        mask_name: Forwarded to `tio.CropOrPad`.\n",
    "        dtype: Callable applied to the transformed output before returning it.\n",
    "\n",
    "    Returns:\n",
    "        The transformed input wrapped as `dtype(...)`.\n",
    "    \"\"\"\n",
    "    # TODO: refactor\n",
    "    pad_or_crop = tio.CropOrPad(target_shape=target_shape, padding_mode=padding_mode, mask_name=mask_name)\n",
    "    return dtype(pad_or_crop(o))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class PadOrCrop(DisplayedTransform):\n",
    "    \"\"\"Resize image using TorchIO `CropOrPad`.\"\"\"\n",
    "\n",
    "    order = 0\n",
    "\n",
    "    def __init__(self, size, padding_mode=0, mask_name=None):\n",
    "        \"\"\"Build the underlying `tio.CropOrPad` transform.\n",
    "\n",
    "        Args:\n",
    "            size: Target shape. A value that is not list-like is repeated\n",
    "                three times.\n",
    "            padding_mode: Forwarded to `tio.CropOrPad`.\n",
    "            mask_name: Forwarded to `tio.CropOrPad`.\n",
    "        \"\"\"\n",
    "        if not is_listy(size):\n",
    "            size = [size] * 3\n",
    "        self.pad_or_crop = tio.CropOrPad(target_shape=size,\n",
    "                                         padding_mode=padding_mode,\n",
    "                                         mask_name=mask_name)\n",
    "\n",
    "    def encodes(self, o: (MedImage, MedMask)):\n",
    "        \"\"\"Crop or pad `o` and return it as the same type it came in as.\"\"\"\n",
    "        return type(o)(self.pad_or_crop(o))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class ZNormalization(DisplayedTransform):\n",
    "    \"\"\"Apply TorchIO `ZNormalization`.\"\"\"\n",
    "\n",
    "    order = 0\n",
    "\n",
    "    def __init__(self, masking_method=None, channel_wise=True):\n",
    "        \"\"\"Build the underlying `tio.ZNormalization` transform.\n",
    "\n",
    "        Args:\n",
    "            masking_method: Forwarded to `tio.ZNormalization`.\n",
    "            channel_wise: If true, each entry along the first dimension of the\n",
    "                image is normalized separately.\n",
    "        \"\"\"\n",
    "        self.z_normalization = tio.ZNormalization(masking_method=masking_method)\n",
    "        self.channel_wise = channel_wise\n",
    "\n",
    "    def encodes(self, o: MedImage):\n",
    "        \"\"\"Normalize the image, per channel or as a whole.\"\"\"\n",
    "        if self.channel_wise:\n",
    "            # Normalize every channel on its own, then restack along dim 0.\n",
    "            o = torch.stack([self.z_normalization(c[None])[0] for c in o])\n",
    "        else:\n",
    "            o = self.z_normalization(o)\n",
    "\n",
    "        return MedImage.create(o)\n",
    "\n",
    "    def encodes(self, o: MedMask):\n",
    "        \"\"\"Masks are returned unchanged.\"\"\"\n",
    "        return o"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class BraTSMaskConverter(DisplayedTransform):\n",
    "    \"\"\"Convert BraTS masks.\"\"\"\n",
    "\n",
    "    order = 1\n",
    "\n",
    "    def encodes(self, o: MedImage):\n",
    "        \"\"\"Images are returned unchanged.\"\"\"\n",
    "        return o\n",
    "\n",
    "    def encodes(self, o: MedMask):\n",
    "        \"\"\"Replace every mask value equal to 4 with 3.\"\"\"\n",
    "        o = torch.where(o == 4, 3., o)\n",
    "        return MedMask.create(o)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class BinaryConverter(DisplayedTransform):\n",
    "    \"\"\"Convert to binary mask.\"\"\"\n",
    "\n",
    "    order = 1\n",
    "\n",
    "    def encodes(self, o: MedImage):\n",
    "        \"\"\"Images are returned unchanged.\"\"\"\n",
    "        return o\n",
    "\n",
    "    def encodes(self, o: MedMask):\n",
    "        \"\"\"Set every positive mask value to 1 and everything else to 0.\"\"\"\n",
    "        # Both branches are float scalars so `torch.where` never has to\n",
    "        # reconcile a float with an int value.\n",
    "        o = torch.where(o > 0, 1., 0.)\n",
    "        return MedMask.create(o)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class RandomGhosting(DisplayedTransform):\n",
    "    \"\"\"Apply TorchIO `RandomGhosting`.\"\"\"\n",
    "\n",
    "    split_idx, order = 0, 1\n",
    "\n",
    "    def __init__(self, intensity=(0.5, 1), p=0.5):\n",
    "        \"\"\"Build the underlying `tio.RandomGhosting` transform.\n",
    "\n",
    "        Args:\n",
    "            intensity: Forwarded to `tio.RandomGhosting`.\n",
    "            p: Forwarded to `tio.RandomGhosting`.\n",
    "        \"\"\"\n",
    "        self.add_ghosts = tio.RandomGhosting(intensity=intensity, p=p)\n",
    "\n",
    "    def encodes(self, o: MedImage):\n",
    "        \"\"\"Run the ghosting transform on the image.\"\"\"\n",
    "        return MedImage.create(self.add_ghosts(o))\n",
    "\n",
    "    def encodes(self, o: MedMask):\n",
    "        \"\"\"Masks are returned unchanged.\"\"\"\n",
    "        return o"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class RandomSpike(DisplayedTransform):\n",
    "    \"\"\"Apply TorchIO `RandomSpike`.\"\"\"\n",
    "\n",
    "    split_idx, order = 0, 1\n",
    "\n",
    "    def __init__(self, num_spikes=1, intensity=(1, 3), p=0.5):\n",
    "        \"\"\"Build the underlying `tio.RandomSpike` transform.\n",
    "\n",
    "        Args:\n",
    "            num_spikes: Forwarded to `tio.RandomSpike`.\n",
    "            intensity: Forwarded to `tio.RandomSpike`.\n",
    "            p: Forwarded to `tio.RandomSpike`.\n",
    "        \"\"\"\n",
    "        self.add_spikes = tio.RandomSpike(num_spikes=num_spikes, intensity=intensity, p=p)\n",
    "\n",
    "    def encodes(self, o: MedImage):\n",
    "        \"\"\"Run the spike transform on the image.\"\"\"\n",
    "        return MedImage.create(self.add_spikes(o))\n",
    "\n",
    "    def encodes(self, o: MedMask):\n",
    "        \"\"\"Masks are returned unchanged.\"\"\"\n",
    "        return o"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class RandomNoise(DisplayedTransform):\n",
    "    \"\"\"Apply TorchIO `RandomNoise`.\"\"\"\n",
    "\n",
    "    split_idx, order = 0, 1\n",
    "\n",
    "    def __init__(self, mean=0, std=(0, 0.25), p=0.5):\n",
    "        \"\"\"Build the underlying `tio.RandomNoise` transform.\n",
    "\n",
    "        Args:\n",
    "            mean: Forwarded to `tio.RandomNoise`.\n",
    "            std: Forwarded to `tio.RandomNoise`.\n",
    "            p: Forwarded to `tio.RandomNoise`.\n",
    "        \"\"\"\n",
    "        self.add_noise = tio.RandomNoise(mean=mean, std=std, p=p)\n",
    "\n",
    "    def encodes(self, o: MedImage):\n",
    "        \"\"\"Run the noise transform on the image.\"\"\"\n",
    "        return MedImage.create(self.add_noise(o))\n",
    "\n",
    "    def encodes(self, o: MedMask):\n",
    "        \"\"\"Masks are returned unchanged.\"\"\"\n",
    "        return o"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class RandomBiasField(DisplayedTransform):\n",
    "    \"\"\"Apply TorchIO `RandomBiasField`.\"\"\"\n",
    "\n",
    "    split_idx, order = 0, 1\n",
    "\n",
    "    def __init__(self, coefficients=0.5, order=3, p=0.5):\n",
    "        \"\"\"Build the underlying `tio.RandomBiasField` transform.\n",
    "\n",
    "        Args:\n",
    "            coefficients: Forwarded to `tio.RandomBiasField`.\n",
    "            order: Forwarded to `tio.RandomBiasField`.\n",
    "            p: Forwarded to `tio.RandomBiasField`.\n",
    "        \"\"\"\n",
    "        # The `order` argument is only passed on to TorchIO; it is not stored on\n",
    "        # the instance, so the class-level `order` above is left untouched.\n",
    "        self.add_biasfield = tio.RandomBiasField(coefficients=coefficients, order=order, p=p)\n",
    "\n",
    "    def encodes(self, o: MedImage):\n",
    "        \"\"\"Run the bias field transform on the image.\"\"\"\n",
    "        return MedImage.create(self.add_biasfield(o))\n",
    "\n",
    "    def encodes(self, o: MedMask):\n",
    "        \"\"\"Masks are returned unchanged.\"\"\"\n",
    "        return o"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class RandomBlur(DisplayedTransform):\n",
    "    \"\"\"Apply TorchIO `RandomBlur`.\"\"\"\n",
    "\n",
    "    split_idx, order = 0, 1\n",
    "\n",
    "    def __init__(self, std=(0, 2), p=0.5):\n",
    "        \"\"\"Build the underlying `tio.RandomBlur` transform.\n",
    "\n",
    "        Args:\n",
    "            std: Forwarded to `tio.RandomBlur`.\n",
    "            p: Forwarded to `tio.RandomBlur`.\n",
    "        \"\"\"\n",
    "        self.add_blur = tio.RandomBlur(std=std, p=p)\n",
    "\n",
    "    def encodes(self, o: MedImage):\n",
    "        \"\"\"Run the blur transform on the image.\"\"\"\n",
    "        return MedImage.create(self.add_blur(o))\n",
    "\n",
    "    def encodes(self, o: MedMask):\n",
    "        \"\"\"Masks are returned unchanged.\"\"\"\n",
    "        return o"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class RandomGamma(DisplayedTransform):\n",
    "    \"\"\"Apply TorchIO `RandomGamma`.\"\"\"\n",
    "\n",
    "    split_idx, order = 0, 1\n",
    "\n",
    "    def __init__(self, log_gamma=(-0.3, 0.3), p=0.5):\n",
    "        \"\"\"Build the underlying `tio.RandomGamma` transform.\n",
    "\n",
    "        Args:\n",
    "            log_gamma: Forwarded to `tio.RandomGamma`.\n",
    "            p: Forwarded to `tio.RandomGamma`.\n",
    "        \"\"\"\n",
    "        self.add_gamma = tio.RandomGamma(log_gamma=log_gamma, p=p)\n",
    "\n",
    "    def encodes(self, o: MedImage):\n",
    "        \"\"\"Run the gamma transform on the image.\"\"\"\n",
    "        return MedImage.create(self.add_gamma(o))\n",
    "\n",
    "    def encodes(self, o: MedMask):\n",
    "        \"\"\"Masks are returned unchanged.\"\"\"\n",
    "        return o"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class RandomMotion(DisplayedTransform):\n",
    "    \"\"\"Apply TorchIO `RandomMotion`.\"\"\"\n",
    "\n",
    "    split_idx, order = 0, 1\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        degrees=10,\n",
    "        translation=10,\n",
    "        num_transforms=2,\n",
    "        image_interpolation='linear',\n",
    "        p=0.5\n",
    "    ):\n",
    "        \"\"\"Build the underlying `tio.RandomMotion` transform.\n",
    "\n",
    "        Args:\n",
    "            degrees: Forwarded to `tio.RandomMotion`.\n",
    "            translation: Forwarded to `tio.RandomMotion`.\n",
    "            num_transforms: Forwarded to `tio.RandomMotion`.\n",
    "            image_interpolation: Forwarded to `tio.RandomMotion`.\n",
    "            p: Forwarded to `tio.RandomMotion`.\n",
    "        \"\"\"\n",
    "        self.add_motion = tio.RandomMotion(\n",
    "            degrees=degrees,\n",
    "            translation=translation,\n",
    "            num_transforms=num_transforms,\n",
    "            image_interpolation=image_interpolation,\n",
    "            p=p\n",
    "        )\n",
    "\n",
    "    def encodes(self, o: MedImage):\n",
    "        \"\"\"Run the motion transform on the image.\"\"\"\n",
    "        return MedImage.create(self.add_motion(o))\n",
    "\n",
    "    def encodes(self, o: MedMask):\n",
    "        \"\"\"Masks are returned unchanged.\"\"\"\n",
    "        return o"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Dictionary transforms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class RandomElasticDeformation(CustomDictTransform):\n",
    "    \"\"\"Apply TorchIO `RandomElasticDeformation`.\"\"\"\n",
    "\n",
    "    def __init__(self, num_control_points=7, max_displacement=7.5,\n",
    "                 image_interpolation='linear', p=0.5):\n",
    "        \"\"\"Wrap `tio.RandomElasticDeformation` in a `CustomDictTransform`.\n",
    "\n",
    "        Args:\n",
    "            num_control_points: Forwarded to `tio.RandomElasticDeformation`.\n",
    "            max_displacement: Forwarded to `tio.RandomElasticDeformation`.\n",
    "            image_interpolation: Forwarded to `tio.RandomElasticDeformation`.\n",
    "            p: Forwarded to `tio.RandomElasticDeformation`.\n",
    "        \"\"\"\n",
    "        super().__init__(tio.RandomElasticDeformation(\n",
    "            num_control_points=num_control_points,\n",
    "            max_displacement=max_displacement,\n",
    "            image_interpolation=image_interpolation,\n",
    "            p=p))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class RandomAffine(CustomDictTransform):\n",
    "    \"\"\"Apply TorchIO `RandomAffine`.\"\"\"\n",
    "\n",
    "    def __init__(self, scales=0, degrees=10, translation=0, isotropic=False,\n",
    "                 image_interpolation='linear', default_pad_value=0., p=0.5):\n",
    "        \"\"\"Wrap `tio.RandomAffine` in a `CustomDictTransform`.\n",
    "\n",
    "        Args:\n",
    "            scales: Forwarded to `tio.RandomAffine`.\n",
    "            degrees: Forwarded to `tio.RandomAffine`.\n",
    "            translation: Forwarded to `tio.RandomAffine`.\n",
    "            isotropic: Forwarded to `tio.RandomAffine`.\n",
    "            image_interpolation: Forwarded to `tio.RandomAffine`.\n",
    "            default_pad_value: Forwarded to `tio.RandomAffine`.\n",
    "            p: Forwarded to `tio.RandomAffine`.\n",
    "        \"\"\"\n",
    "        super().__init__(tio.RandomAffine(\n",
    "            scales=scales,\n",
    "            degrees=degrees,\n",
    "            translation=translation,\n",
    "            isotropic=isotropic,\n",
    "            image_interpolation=image_interpolation,\n",
    "            default_pad_value=default_pad_value,\n",
    "            p=p))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class RandomFlip(CustomDictTransform):\n",
    "    \"\"\"Apply TorchIO `RandomFlip`.\"\"\"\n",
    "\n",
    "    def __init__(self, axes='LR', p=0.5):\n",
    "        \"\"\"Wrap `tio.RandomFlip` in a `CustomDictTransform`.\n",
    "\n",
    "        Args:\n",
    "            axes: Forwarded to `tio.RandomFlip`.\n",
    "            p: Passed to `tio.RandomFlip` as `flip_probability`.\n",
    "        \"\"\"\n",
    "        super().__init__(tio.RandomFlip(axes=axes, flip_probability=p))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class OneOf(CustomDictTransform):\n",
    "    \"\"\"Apply only one of the given transforms using TorchIO `OneOf`.\"\"\"\n",
    "\n",
    "    def __init__(self, transform_dict, p=1):\n",
    "        \"\"\"Wrap `tio.OneOf` in a `CustomDictTransform`.\n",
    "\n",
    "        Args:\n",
    "            transform_dict: Transforms to choose from; forwarded to `tio.OneOf`.\n",
    "            p: Forwarded to `tio.OneOf`.\n",
    "        \"\"\"\n",
    "        super().__init__(tio.OneOf(transform_dict, p=p))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "#TorchIO has their own test methods: https://github.com/fepegar/torchio/tree/main/tests"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}