#| export
class CustomDictTransform(ItemTransform):
    """Wrapper that applies one TorchIO transform identically to the image
    and, when the target is a mask, to the target as well.

    Args:
        aug (Callable): TorchIO transform applied to the image (and mask).
    """

    # Only transform training data. Use TTA() for transforms on validation data.
    split_idx = 0

    def __init__(self, aug):
        self.aug = aug

    @property
    def tio_transform(self):
        """Return the underlying TorchIO transform.

        This property enables using fastMONAI wrappers in patch-based workflows
        where raw TorchIO transforms are needed for tio.Compose().
        """
        return self.aug

    def encodes(self, x):
        """Apply the stored transform to the image and, if the target is a mask,
        the identical random transform to the target. Non-mask targets are
        returned unchanged.

        Args:
            x (Tuple[MedImage, Union[MedMask, TensorCategory]]): Image and target.

        Returns:
            Tuple[MedImage, Union[MedMask, TensorCategory]]: Transformed image
            and target.
        """
        img, y_true = x

        # Use identity affine if MedImage.affine_matrix is not set.
        affine = MedImage.affine_matrix if MedImage.affine_matrix is not None else np.eye(4)

        if isinstance(y_true, MedMask):
            aug = self.aug(tio.Subject(img=tio.ScalarImage(tensor=img, affine=affine),
                                       mask=tio.LabelMap(tensor=y_true, affine=affine)))
            return MedImage.create(aug['img'].data), MedMask.create(aug['mask'].data)

        # Bug fix: the affine was previously omitted here, so spatial transforms
        # (whose parameters TorchIO interprets in mm via the affine) behaved
        # differently for classification targets than for masks.
        aug = self.aug(tio.Subject(img=tio.ScalarImage(tensor=img, affine=affine)))
        return MedImage.create(aug['img'].data), y_true


#| export
def do_pad_or_crop(o, target_shape, padding_mode, mask_name, dtype=torch.Tensor):
    """Pad and/or crop `o` to `target_shape` with TorchIO `CropOrPad` and wrap
    the result in `dtype`.

    Args:
        o: Input tensor-like object accepted by tio.CropOrPad.
        target_shape: Desired (D, H, W) output shape.
        padding_mode: Padding value or mode forwarded to CropOrPad.
        mask_name: Optional subject key used by CropOrPad to center the crop.
        dtype: Constructor applied to the result (default torch.Tensor).

    Returns:
        The padded/cropped object wrapped in `dtype`.
    """
    pad_or_crop = tio.CropOrPad(target_shape=target_shape,
                                padding_mode=padding_mode,
                                mask_name=mask_name)
    return dtype(pad_or_crop(o))
#| export
class PadOrCrop(DisplayedTransform):
    """Resize image using TorchIO `CropOrPad`."""

    order = 0

    def __init__(self, size, padding_mode=0, mask_name=None):
        # A scalar size means a cubic target shape.
        if not is_listy(size):
            size = [size, size, size]
        self.pad_or_crop = tio.CropOrPad(target_shape=size,
                                         padding_mode=padding_mode,
                                         mask_name=mask_name)

    @property
    def tio_transform(self):
        """Return the underlying TorchIO transform."""
        return self.pad_or_crop

    def encodes(self, o: (MedImage, MedMask)):
        # Preserve the incoming subclass (MedImage or MedMask).
        return type(o)(self.pad_or_crop(o))


# | export
class ZNormalization(DisplayedTransform):
    """Apply TorchIO `ZNormalization` (zero mean, unit variance).

    Args:
        masking_method: Optional masking method forwarded to tio.ZNormalization.
        channel_wise (bool): Normalize each channel independently (default: True).
    """

    order = 0

    def __init__(self, masking_method=None, channel_wise=True):
        self.z_normalization = tio.ZNormalization(masking_method=masking_method)
        self.channel_wise = channel_wise

    @property
    def tio_transform(self):
        """Return the underlying TorchIO transform."""
        return self.z_normalization

    def encodes(self, o: MedImage):
        try:
            if self.channel_wise:
                # Normalize each channel independently: tio expects a 4D tensor,
                # so each channel is wrapped with a leading singleton dim.
                o = torch.stack([self.z_normalization(c[None])[0] for c in o])
            else:
                o = self.z_normalization(o)
        except RuntimeError as e:
            if "Standard deviation is 0" in str(e):
                # Calculate mean for debugging information.
                mean = float(o.mean())

                error_msg = (
                    f"Standard deviation is 0 for image (mean={mean:.3f}).\n"
                    f"This indicates uniform pixel values.\n\n"
                    f"Possible causes:\n"
                    f"• Corrupted or blank image\n"
                    f"• Oversaturated regions\n"
                    f"• Background-only regions\n"
                    f"• All-zero mask being processed as image\n\n"
                    f"Suggested solutions:\n"
                    f"• Check image quality and acquisition\n"
                    f"• Verify image vs mask data loading"
                )
                raise RuntimeError(error_msg) from e
            # Bug fix: other RuntimeErrors were previously swallowed silently,
            # returning a possibly partially-normalized tensor. Re-raise them.
            raise

        return MedImage.create(o)

    def encodes(self, o: MedMask):
        # Masks are label maps; intensity normalization must not touch them.
        return o


# | export
class RescaleIntensity(DisplayedTransform):
    """Apply TorchIO RescaleIntensity for robust intensity scaling.

    Args:
        out_min_max (tuple[float, float]): Output intensity range (min, max).
        in_min_max (tuple[float, float]): Input intensity range (min, max).

    Example for CT images:
        # Normalize CT from air (-1000 HU) to bone (1000 HU) into range (-1, 1)
        transform = RescaleIntensity(out_min_max=(-1, 1), in_min_max=(-1000, 1000))
    """

    order = 0

    def __init__(self, out_min_max: tuple[float, float], in_min_max: tuple[float, float]):
        self.rescale = tio.RescaleIntensity(out_min_max=out_min_max, in_min_max=in_min_max)

    @property
    def tio_transform(self):
        """Return the underlying TorchIO transform."""
        return self.rescale

    def encodes(self, o: MedImage):
        return MedImage.create(self.rescale(o))

    def encodes(self, o: MedMask):
        # Masks pass through untouched.
        return o


# | export
class _TioNormalizeIntensity(tio.IntensityTransform):
    """TorchIO-compatible wrapper for MONAI NormalizeIntensity.

    Enables NormalizeIntensity to work in patch-based workflows
    where raw TorchIO transforms are needed for tio.Compose().
    """

    def __init__(self, nonzero=True, channel_wise=True,
                 subtrahend=None, divisor=None, **kwargs):
        super().__init__(p=1, **kwargs)
        self.channel_wise = channel_wise
        # channel_wise is always False here: channel handling is done manually
        # in apply_transform so behavior matches the eager NormalizeIntensity.
        self.transform = MonaiNormalizeIntensity(
            nonzero=nonzero,
            channel_wise=False,
            subtrahend=subtrahend,
            divisor=divisor
        )

    def apply_transform(self, subject):
        for image in self.get_images(subject):
            data = image.data
            if self.channel_wise:
                result = torch.stack([self.transform(c[None])[0] for c in data])
            else:
                result = torch.Tensor(self.transform(data))
            image.set_data(result)
        return subject
#| export
class NormalizeIntensity(DisplayedTransform):
    """Normalize image intensities with MONAI `NormalizeIntensity`.

    Args:
        nonzero (bool): Only normalize non-zero values (default: True).
        channel_wise (bool): Apply normalization per channel (default: True).
        subtrahend (float, optional): Value to subtract.
        divisor (float, optional): Value to divide by.
    """

    order = 0

    def __init__(self, nonzero: bool = True, channel_wise: bool = True,
                 subtrahend: float = None, divisor: float = None):
        self.nonzero = nonzero
        self.channel_wise = channel_wise
        self.subtrahend = subtrahend
        self.divisor = divisor

        # Channel-wise handling is done manually in `encodes`, so the MONAI
        # transform itself is always configured with channel_wise=False.
        self.transform = MonaiNormalizeIntensity(
            nonzero=nonzero,
            channel_wise=False,
            subtrahend=subtrahend,
            divisor=divisor
        )
        self._tio_normalize = _TioNormalizeIntensity(
            nonzero=nonzero,
            channel_wise=channel_wise,
            subtrahend=subtrahend,
            divisor=divisor
        )

    @property
    def tio_transform(self):
        """Return TorchIO-compatible transform for patch-based workflows."""
        return self._tio_normalize

    def encodes(self, o: MedImage):
        if not self.channel_wise:
            return MedImage.create(torch.Tensor(self.transform(o)))
        # Normalize each channel independently and restack.
        normalized_channels = [self.transform(channel[None])[0] for channel in o]
        return MedImage.create(torch.stack(normalized_channels))

    def encodes(self, o: MedMask):
        # Label maps are never intensity-normalized.
        return o


#| export
class BraTSMaskConverter(DisplayedTransform):
    '''Convert BraTS masks by relabeling class 4 as class 3.'''

    order = 1

    def encodes(self, o: MedImage):
        return o

    def encodes(self, o: MedMask):
        # BraTS labels tumor-enhancing tissue as 4; remap to 3 so the
        # labels form a contiguous range.
        relabeled = torch.where(o == 4, 3., o)
        return MedMask.create(relabeled)


#| export
class BinaryConverter(DisplayedTransform):
    '''Convert a multi-class mask to a binary (foreground/background) mask.'''

    order = 1

    def encodes(self, o: MedImage):
        return o

    def encodes(self, o: MedMask):
        # Any positive label becomes foreground (1), everything else 0.
        binary = torch.where(o > 0, 1., 0)
        return MedMask.create(binary)


#| export
class RandomGhosting(DisplayedTransform):
    """Simulate MRI ghosting artifacts with TorchIO `RandomGhosting`."""

    split_idx, order = 0, 1

    def __init__(self, intensity=(0.5, 1), p=0.5):
        # Keep the configured TorchIO transform on the instance so
        # `tio_transform` can expose it for patch-based pipelines.
        self.add_ghosts = tio.RandomGhosting(intensity=intensity, p=p)

    @property
    def tio_transform(self):
        """Return the underlying TorchIO transform."""
        return self.add_ghosts

    def encodes(self, o: MedImage):
        return MedImage.create(self.add_ghosts(o))

    def encodes(self, o: MedMask):
        # Intensity artifacts do not apply to label maps.
        return o


#| export
class RandomSpike(DisplayedTransform):
    '''Simulate MRI spike (herringbone) artifacts with TorchIO `RandomSpike`.'''

    split_idx, order = 0, 1

    def __init__(self, num_spikes=1, intensity=(1, 3), p=0.5):
        self.add_spikes = tio.RandomSpike(num_spikes=num_spikes, intensity=intensity, p=p)

    @property
    def tio_transform(self):
        """Return the underlying TorchIO transform."""
        return self.add_spikes

    def encodes(self, o: MedImage):
        return MedImage.create(self.add_spikes(o))

    def encodes(self, o: MedMask):
        # Intensity artifacts do not apply to label maps.
        return o


#| export
class RandomNoise(DisplayedTransform):
    '''Add Gaussian noise with TorchIO `RandomNoise`.'''

    split_idx, order = 0, 1

    def __init__(self, mean=0, std=(0, 0.25), p=0.5):
        self.add_noise = tio.RandomNoise(mean=mean, std=std, p=p)

    @property
    def tio_transform(self):
        """Return the underlying TorchIO transform."""
        return self.add_noise

    def encodes(self, o: MedImage):
        return MedImage.create(self.add_noise(o))

    def encodes(self, o: MedMask):
        # Noise is never added to label maps.
        return o
#| export
class RandomBiasField(DisplayedTransform):
    '''Simulate MRI bias-field inhomogeneity with TorchIO `RandomBiasField`.'''

    split_idx, order = 0, 1

    def __init__(self, coefficients=0.5, order=3, p=0.5):
        self.add_biasfield = tio.RandomBiasField(coefficients=coefficients, order=order, p=p)

    @property
    def tio_transform(self):
        """Return the underlying TorchIO transform."""
        return self.add_biasfield

    def encodes(self, o: MedImage):
        return MedImage.create(self.add_biasfield(o))

    def encodes(self, o: MedMask):
        # Bias fields are intensity-only; label maps pass through.
        return o


#| export
class RandomBlur(DisplayedTransform):
    '''Blur the image with a random-sigma Gaussian via TorchIO `RandomBlur`.'''

    split_idx, order = 0, 1

    def __init__(self, std=(0, 2), p=0.5):
        self.add_blur = tio.RandomBlur(std=std, p=p)

    @property
    def tio_transform(self):
        """Return the underlying TorchIO transform."""
        return self.add_blur

    def encodes(self, o: MedImage):
        return MedImage.create(self.add_blur(o))

    def encodes(self, o: MedMask):
        # Blurring a label map would corrupt its discrete labels.
        return o


#| export
class RandomGamma(DisplayedTransform):
    '''Randomly change image contrast with TorchIO `RandomGamma`.'''

    split_idx, order = 0, 1

    def __init__(self, log_gamma=(-0.3, 0.3), p=0.5):
        self.add_gamma = tio.RandomGamma(log_gamma=log_gamma, p=p)

    @property
    def tio_transform(self):
        """Return the underlying TorchIO transform."""
        return self.add_gamma

    def encodes(self, o: MedImage):
        return MedImage.create(self.add_gamma(o))

    def encodes(self, o: MedMask):
        # Gamma is an intensity transform; label maps pass through.
        return o


#| export
class _TioRandomIntensityScale(tio.IntensityTransform):
    """TorchIO-compatible RandomIntensityScale for patch-based workflows.

    Randomly scales image intensities by a multiplicative factor.
    Only applies to ScalarImage keys (not LabelMap), which is the standard
    TorchIO IntensityTransform behavior.
    """

    def __init__(self, scale_range=(0.5, 2.0), p=0.5, **kwargs):
        super().__init__(p=p, **kwargs)
        self.scale_range = scale_range

    def apply_transform(self, subject):
        low, high = self.scale_range
        for image in self.get_images(subject):
            # Draw an independent scale factor per image in the subject.
            factor = torch.empty(1).uniform_(low, high).item()
            image.set_data(image.data * factor)
        return subject


class RandomIntensityScale(DisplayedTransform):
    """Randomly scale image intensities by a multiplicative factor.

    Useful for domain generalization across different acquisition protocols
    with varying intensity ranges.

    Args:
        scale_range (tuple[float, float]): Range of scale factors (min, max).
            Values > 1 increase intensity, < 1 decrease intensity.
        p (float): Probability of applying the transform (default: 0.5)

    Example:
        # Scale intensities randomly between 0.5x and 2.0x
        transform = RandomIntensityScale(scale_range=(0.5, 2.0), p=0.3)
    """

    split_idx, order = 0, 1

    def __init__(self, scale_range: tuple[float, float] = (0.5, 2.0), p: float = 0.5):
        self.scale_range = scale_range
        self.p = p
        self._tio_intensity_scale = _TioRandomIntensityScale(
            scale_range=scale_range, p=p
        )

    @property
    def tio_transform(self):
        """Return TorchIO-compatible transform for patch-based workflows."""
        return self._tio_intensity_scale

    def encodes(self, o: MedImage):
        # Probability gate, then one uniform scale factor for the whole image.
        if torch.rand(1).item() > self.p:
            return o
        low, high = self.scale_range
        factor = torch.empty(1).uniform_(low, high).item()
        return MedImage.create(o * factor)

    def encodes(self, o: MedMask):
        # Label maps are never rescaled.
        return o
#| export
class RandomMotion(DisplayedTransform):
    """Simulate MRI motion artifacts with TorchIO `RandomMotion`."""

    split_idx, order = 0, 1

    def __init__(
        self, 
        degrees=10, 
        translation=10, 
        num_transforms=2, 
        image_interpolation='linear', 
        p=0.5
    ):
        self.add_motion = tio.RandomMotion(
            degrees=degrees, 
            translation=translation, 
            num_transforms=num_transforms, 
            image_interpolation=image_interpolation, 
            p=p
        )

    @property
    def tio_transform(self):
        """Return the underlying TorchIO transform."""
        return self.add_motion

    def encodes(self, o: MedImage):
        return MedImage.create(self.add_motion(o))

    def encodes(self, o: MedMask):
        # Motion artifacts are intensity-only; label maps pass through.
        return o


#| export
def _create_ellipsoid_mask(shape, center, radii):
    """Create a 3D ellipsoid mask.

    Args:
        shape: (D, H, W) shape of the volume.
        center: (z, y, x) center of ellipsoid.
        radii: (rz, ry, rx) radii along each axis.

    Returns:
        Boolean tensor of shape (D, H, W) where True = inside the ellipsoid.
        Voxels exactly on the surface (normalized distance == 1) are included.
    """
    z, y, x = torch.meshgrid(
        torch.arange(shape[0]),
        torch.arange(shape[1]),
        torch.arange(shape[2]),
        indexing='ij'
    )
    # Normalized squared distance from the center; <= 1 is inside.
    dist = ((z - center[0]) / radii[0]) ** 2 + \
           ((y - center[1]) / radii[1]) ** 2 + \
           ((x - center[2]) / radii[2]) ** 2
    return dist <= 1.0


#| export
class _TioRandomCutout(tio.IntensityTransform):
    """TorchIO-compatible RandomCutout for patch-based workflows.

    When mask_only=True, cutouts only affect voxels where the mask is positive.
    The mask should be available in the Subject as 'mask' key.
    """

    def __init__(self, holes=1, spatial_size=8, fill_value=None,
                 max_holes=None, max_spatial_size=None, mask_only=True, p=0.2, **kwargs):
        super().__init__(p=p, **kwargs)
        self.holes = holes                        # minimum number of cutout regions
        self.spatial_size = spatial_size          # minimum cutout diameter (voxels)
        self.fill_value = fill_value              # None -> per-image minimum is used
        self.max_holes = max_holes                # None -> exactly `holes` regions
        self.max_spatial_size = max_spatial_size  # None -> fixed `spatial_size`
        self.mask_only = mask_only

    def _apply_cutout(self, data, fill_val, mask_tensor=None):
        """Apply spherical cutout(s) to a tensor.

        Args:
            data: Input tensor of shape (C, D, H, W).
            fill_val: Value to fill cutout regions.
            mask_tensor: Optional mask tensor for mask-only cutouts.
                NOTE(review): only channel 0 of the mask is consulted —
                assumes a single-channel mask; confirm for multi-channel masks.

        Returns:
            Tensor with cutout applied (input is cloned, not mutated).
        """
        result = data.clone()
        # Number of holes sampled in [holes, max_holes]; falls back to exactly
        # `holes` when max_holes is None (randint high is exclusive).
        n_holes = torch.randint(self.holes, (self.max_holes or self.holes) + 1, (1,)).item()

        spatial_shape = data.shape[1:]  # (D, H, W)
        # Sizes may be given as int or sequence; only the first element of a
        # sequence is used (isotropic size range).
        min_size = self.spatial_size if isinstance(self.spatial_size, int) else self.spatial_size[0]
        max_size = self.max_spatial_size or self.spatial_size
        max_size = max_size if isinstance(max_size, int) else max_size[0]

        for _ in range(n_holes):
            # Random size for this hole
            size = torch.randint(min_size, max_size + 1, (3,))
            radii = size.float() / 2

            # Random center (ensure hole fits in volume); the max(...) guard
            # keeps randint's range non-empty when the radius exceeds half the
            # volume extent.
            center = [
                torch.randint(int(radii[i].item()),
                              max(spatial_shape[i] - int(radii[i].item()), int(radii[i].item()) + 1),
                              (1,)).item()
                for i in range(3)
            ]

            ellipsoid = _create_ellipsoid_mask(spatial_shape, center, radii)

            if self.mask_only and mask_tensor is not None:
                # INTERSECT with tumor mask - only affect tumor voxels
                tumor_mask = mask_tensor[0] > 0
                cutout_region = ellipsoid & tumor_mask
            else:
                cutout_region = ellipsoid

            # Fill the region across all channels.
            result[:, cutout_region] = fill_val

        return result

    def apply_transform(self, subject):
        # Get mask if available for mask-only cutouts
        mask_tensor = None
        if self.mask_only and 'mask' in subject:
            mask_tensor = subject['mask'].data

        # Skip if mask is empty — no tumor voxels means nothing to cut out.
        if mask_tensor is not None and not (mask_tensor > 0).any():
            return subject

        for image in self.get_images(subject):
            data = image.data
            # Default fill is the per-image minimum intensity.
            fill_val = self.fill_value if self.fill_value is not None else float(data.min())
            result = self._apply_cutout(data, fill_val, mask_tensor)
            image.set_data(result)
        return subject
#| export
class RandomCutout(ItemTransform):
    """Randomly erase spherical regions in 3D medical images with mask-aware placement.

    Simulates post-operative surgical cavities by filling random ellipsoid
    volumes with specified values. When mask_only=True (default), cutouts only
    affect voxels inside the segmentation mask, ensuring no healthy tissue is modified.

    Args:
        holes: Minimum number of cutout regions. Default: 1.
        max_holes: Maximum number of regions. Default: 3.
        spatial_size: Minimum cutout diameter in voxels. Default: 8.
        max_spatial_size: Maximum cutout diameter. Default: 16.
        fill: Fill value - 'min', 'mean', 'random', or float. Default: 'min'.
        mask_only: If True, cutouts only affect mask-positive voxels (tumor tissue).
            If False, cutouts can affect any voxel (original behavior). Default: True.
        p: Probability of applying transform. Default: 0.2.

    Example:
        >>> # Simulate post-op cavities only within tumor regions
        >>> tfm = RandomCutout(holes=1, max_holes=2, spatial_size=10,
        ...                    max_spatial_size=25, fill='min', mask_only=True, p=0.2)

        >>> # Original behavior - cutouts anywhere in the volume
        >>> tfm = RandomCutout(mask_only=False, p=0.2)
    """

    split_idx, order = 0, 1

    def __init__(self, holes=1, max_holes=3, spatial_size=8,
                 max_spatial_size=16, fill='min', mask_only=True, p=0.2):
        self.holes = holes
        self.max_holes = max_holes
        self.spatial_size = spatial_size
        self.max_spatial_size = max_spatial_size
        self.fill = fill
        self.mask_only = mask_only
        self.p = p

        # Mirror configuration as a raw TorchIO transform for patch pipelines.
        # String fill modes ('min'/'mean'/'random') cannot be forwarded; the
        # TorchIO version then falls back to its own per-image minimum.
        self._tio_cutout = _TioRandomCutout(
            holes=holes, spatial_size=spatial_size,
            fill_value=None if isinstance(fill, str) else fill,
            max_holes=max_holes, max_spatial_size=max_spatial_size,
            mask_only=mask_only, p=p
        )

    @property
    def tio_transform(self):
        """Return TorchIO-compatible transform for patch-based workflows.

        Note: For mask-aware cutouts in patch workflows, the mask must be
        available in the TorchIO Subject as 'mask' key.
        """
        return self._tio_cutout

    def _get_fill_value(self, tensor):
        # Resolve the configured fill mode against the concrete image:
        # 'min'/'mean' are image statistics, 'random' is uniform in [min, max],
        # anything else is used verbatim as a numeric fill value.
        if self.fill == 'min': return float(tensor.min())
        elif self.fill == 'mean': return float(tensor.mean())
        elif self.fill == 'random':
            return torch.empty(1).uniform_(float(tensor.min()), float(tensor.max())).item()
        else: return self.fill

    def encodes(self, x):
        """Apply mask-aware cutout to image.

        Args:
            x: Tuple of (MedImage, target) where target is MedMask or TensorCategory

        Returns:
            Tuple of (transformed MedImage, unchanged target)
        """
        img, y_true = x

        # Probability check
        if torch.rand(1).item() > self.p:
            return img, y_true

        # Get mask data if available (as numpy for safe boolean operations)
        mask_np = None
        tumor_coords = None
        if isinstance(y_true, MedMask):
            mask_np = y_true.numpy()
            # Get coordinates of tumor voxels for mask-aware center placement.
            # NOTE(review): only channel 0 is consulted — assumes a
            # single-channel mask; confirm for multi-channel masks.
            if self.mask_only and (mask_np > 0).any():
                tumor_coords = np.argwhere(mask_np[0] > 0)  # Shape: (N, 3) for z, y, x

        # Skip cutout if mask_only=True but no mask or empty mask
        if self.mask_only:
            if mask_np is None or tumor_coords is None or len(tumor_coords) == 0:
                return img, y_true

        # Work with numpy array to avoid tensor subclass issues
        result_np = img.numpy().copy()
        spatial_shape = img.shape[1:]  # (D, H, W)
        fill_val = self._get_fill_value(img)

        # Number of holes sampled in [holes, max_holes].
        n_holes = torch.randint(self.holes, self.max_holes + 1, (1,)).item()

        # Sizes may be int or sequence; only the first element of a sequence is used.
        min_size = self.spatial_size if isinstance(self.spatial_size, int) else self.spatial_size[0]
        max_size = self.max_spatial_size or self.spatial_size
        max_size = max_size if isinstance(max_size, int) else max_size[0]

        for _ in range(n_holes):
            # Random size for this hole
            size = torch.randint(min_size, max_size + 1, (3,))
            radii = size.float() / 2

            if self.mask_only and tumor_coords is not None:
                # Pick center from within tumor region to ensure intersection
                idx = torch.randint(0, len(tumor_coords), (1,)).item()
                center = tumor_coords[idx].tolist()  # [z, y, x]
            else:
                # Random center anywhere (ensure hole fits in volume); the
                # max(...) guard keeps randint's range non-empty for large radii.
                center = [
                    torch.randint(int(radii[i].item()),
                                  max(spatial_shape[i] - int(radii[i].item()), int(radii[i].item()) + 1),
                                  (1,)).item()
                    for i in range(3)
                ]

            # Create ellipsoid mask (numpy)
            ellipsoid = _create_ellipsoid_mask(spatial_shape, center, radii).numpy()

            if self.mask_only and mask_np is not None:
                # INTERSECT with tumor mask - only affect tumor voxels
                tumor_mask = mask_np[0] > 0
                cutout_region = ellipsoid & tumor_mask
            else:
                cutout_region = ellipsoid

            # Fill the region across all channels.
            result_np[:, cutout_region] = fill_val

        return MedImage.create(torch.from_numpy(result_np)), y_true


# | export
class RandomElasticDeformation(CustomDictTransform):
    """Apply TorchIO `RandomElasticDeformation` identically to image and mask."""

    def __init__(self, num_control_points=7, max_displacement=7.5,
                 image_interpolation='linear', p=0.5):
        
        super().__init__(tio.RandomElasticDeformation(
            num_control_points=num_control_points,
            max_displacement=max_displacement,
            image_interpolation=image_interpolation,
            p=p))


# | export
class RandomAffine(CustomDictTransform):
    """Apply TorchIO `RandomAffine` identically to image and mask."""

    def __init__(self, scales=0, degrees=10, translation=0, isotropic=False,
                 image_interpolation='linear', default_pad_value=0., p=0.5):
        
        super().__init__(tio.RandomAffine(
            scales=scales,
            degrees=degrees,
            translation=translation,
            isotropic=isotropic,
            image_interpolation=image_interpolation,
            default_pad_value=default_pad_value,
            p=p))


# | export
class RandomFlip(CustomDictTransform):
    """Apply TorchIO `RandomFlip` identically to image and mask."""

    def __init__(self, axes='LR', p=0.5):
        super().__init__(tio.RandomFlip(axes=axes, flip_probability=p))
#| export
class OneOf(CustomDictTransform):
    """Apply only one of the given transforms using TorchIO `OneOf`."""

    def __init__(self, transform_dict, p=1):
        # NOTE(review): tio.OneOf accepts a dict mapping transform ->
        # probability weight (or a sequence) — confirm callers' argument shape.
        super().__init__(tio.OneOf(transform_dict, p=p))


# Test .tio_transform property: every wrapper must expose its raw TorchIO
# transform for patch-based workflows.
# CustomDictTransform-based wrappers
test_eq(type(RandomAffine(degrees=10).tio_transform), tio.RandomAffine)
test_eq(type(RandomFlip(p=0.5).tio_transform), tio.RandomFlip)
test_eq(type(RandomElasticDeformation(p=0.5).tio_transform), tio.RandomElasticDeformation)

# DisplayedTransform-based wrappers
test_eq(type(PadOrCrop([64, 64, 64]).tio_transform), tio.CropOrPad)
test_eq(type(ZNormalization().tio_transform), tio.ZNormalization)
test_eq(type(RescaleIntensity((-1, 1), (-1000, 1000)).tio_transform), tio.RescaleIntensity)
test_eq(type(RandomGamma(p=0.5).tio_transform), tio.RandomGamma)
test_eq(type(RandomNoise(p=0.5).tio_transform), tio.RandomNoise)
test_eq(type(RandomBiasField(p=0.5).tio_transform), tio.RandomBiasField)
test_eq(type(RandomBlur(p=0.5).tio_transform), tio.RandomBlur)
test_eq(type(RandomGhosting(p=0.5).tio_transform), tio.RandomGhosting)
test_eq(type(RandomSpike(p=0.5).tio_transform), tio.RandomSpike)
test_eq(type(RandomMotion(p=0.5).tio_transform), tio.RandomMotion)

# Custom TorchIO wrappers (isinstance check since these are custom subclasses)
test_eq(isinstance(RandomIntensityScale(p=0.5).tio_transform, tio.IntensityTransform), True)
test_eq(isinstance(NormalizeIntensity().tio_transform, tio.IntensityTransform), True)


# Test RandomCutout (ItemTransform - expects tuple input)
import numpy as np

# Create test data: a 32^3 single-channel image with a cubic "tumor" region.
test_img = MedImage(torch.randn(1, 32, 32, 32))
test_mask = MedMask(torch.zeros(1, 32, 32, 32))
test_mask[0, 10:20, 10:20, 10:20] = 1.0  # Tumor region

# Test mask_only=True (default): only tumor voxels affected
cutout = RandomCutout(holes=1, spatial_size=8, fill='min', mask_only=True, p=1.0)
result_img, result_mask = cutout.encodes((test_img, test_mask))
test_eq(type(result_img), MedImage)
test_eq(type(result_mask), MedMask)
test_eq(result_img.shape, test_img.shape)
# Verify: healthy tissue (mask==0) unchanged (use numpy for comparison)
healthy_region = test_mask.numpy()[0] == 0
test_eq(np.array_equal(result_img.numpy()[0, healthy_region], test_img.numpy()[0, healthy_region]), True)
# Verify: mask unchanged
test_eq(torch.equal(result_mask, test_mask), True)

# Test empty mask skips cutout (mask_only=True)
empty_mask = MedMask(torch.zeros(1, 32, 32, 32))
result_img, _ = cutout.encodes((test_img, empty_mask))
test_eq(torch.equal(result_img, test_img), True)  # Unchanged

# Test mask_only=False: cutouts can affect any voxel
cutout_any = RandomCutout(mask_only=False, p=1.0)
result_img, _ = cutout_any.encodes((test_img, test_mask))
test_eq(result_img.shape, test_img.shape)

# Test with TensorCategory target (classification task with mask_only=False)
test_label = TensorCategory(1)
cutout_cls = RandomCutout(mask_only=False, p=1.0)
result_img, result_label = cutout_cls.encodes((test_img, test_label))
test_eq(type(result_img), MedImage)
test_eq(result_label, test_label)

# Test TensorCategory with mask_only=True skips cutout (no mask available)
cutout_mask_only = RandomCutout(mask_only=True, p=1.0)
result_img, result_label = cutout_mask_only.encodes((test_img, test_label))
test_eq(torch.equal(result_img, test_img), True)  # Unchanged - no mask to intersect

# tio_transform property
test_eq(isinstance(cutout.tio_transform, tio.IntensityTransform), True)

# Test fill modes with mask_only=False (shape is preserved for each mode)
for fill_mode in ['min', 'mean', 'random', 0.0]:
    cutout = RandomCutout(fill=fill_mode, mask_only=False, p=1.0)
    result_img, _ = cutout.encodes((test_img, test_mask))
    test_eq(result_img.shape, test_img.shape)