{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Using Aug_DiskCachedDataset for efficient caching of augmented copies\n", "- `Aug_DiskCachedDataset` is a modified version of `DiskCachedDataset` that is useful while applying deterministic augmentations on data samples. \n", "\n", "- This is the case when the parameter space of augmentation is desceret, for instance applying `pitchshift` on audio data in which shift parameter (semitone) can only take N values.\n", "\n", "- Using `DiskCachedDataset` and setting `num_copies` to N is likely to cause 2 issues:\n", "\n", " - Copies might not be unique, as copy_index is not linked to the augmentation parameter \n", " - And there is no guarantee that copies cover the desired augmentation space\n", " \n", "\n", "\n", "- `Aug_DiskCachedDataset` resolves this limitation by mapping and linking copy index to augmentation parameter. Following considerations need to be takes into account:\n", "\n", " - The user needs to pass `all_transforms` dict as input with seperated transforms `pre_aug`, `aug`, `post_aug` (spesifying transforms that are applied before and after augmentations, also augmentation transforms). \n", " \n", " - The augmentation class receives `aug_index` (aug_index = copy) as initialization parameter also `caching=True` needs to be set (please see `tonic.audio_augmentations`)\n", "\n", "- Follwing is a simple example to show function of `Aug_DiskCachedDataset`" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### A simple dataset " ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# %%writefile mini_dataset.py\n", "import warnings\n", "\n", "warnings.filterwarnings(\"ignore\")\n", "import numpy as np\n", "from torch.utils.data import Dataset\n", "\n", "\n", "class mini_dataset(Dataset):\n", " def __init__(self) -> None:\n", " super().__init__()\n", " np.random.seed(0)\n", " self.data = np.random.rand(10, 16000)\n", " self.transform = None\n", " self.target_transform = None\n", "\n", " def __getitem__(self, index):\n", " sample = self.data[index]\n", " label = 1\n", " if sample.ndim == 1:\n", " sample = sample[None, ...]\n", " if self.transform is not None:\n", " sample = self.transform(sample)\n", " if self.target_transform is not None:\n", " label = self.target_transform(label)\n", "\n", " return sample, label" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Initializing `Aug_DiskCachedDataset` with transforms" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from tonic.audio_augmentations import RandomPitchShift\n", "from tonic.audio_transforms import AmplitudeScale, FixLength\n", "from tonic.cached_dataset import Aug_DiskCachedDataset, load_from_disk_cache\n", "\n", "all_transforms = {}\n", "all_transforms[\"pre_aug\"] = [AmplitudeScale(max_amplitude=0.150)]\n", "all_transforms[\"augmentations\"] = [RandomPitchShift(samplerate=16000, caching=True)]\n", "all_transforms[\"post_aug\"] = [FixLength(16000)]\n", "\n", "# number of copies is set to number of augmentation params (factors)\n", "n = len(RandomPitchShift(samplerate=16000, caching=True).factors)\n", "Aug_cach = Aug_DiskCachedDataset(\n", " dataset=mini_dataset(),\n", " cache_path=\"cache/\",\n", " all_transforms=all_transforms,\n", " num_copies=n,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Generating all copies of a data sample\n", " - 10 augmented versions of data sample with index = 0 are generated" ] }, { 
"cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "sample_index = 0\n", "Aug_cach.generate_all(sample_index)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### To verify\n", " - loading the saved copies \n", " - and comparing them with the ones generated out of `Aug_DiskCacheDataset` with the same transforms and matching augmentation parameter \n", " - they are equal\n" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "True\n", "True\n", "True\n", "True\n", "True\n", "True\n", "True\n", "True\n", "True\n", "True\n" ] } ], "source": [ "from torchvision.transforms import Compose\n", "\n", "for i in range(n):\n", " transform = Compose(\n", " [\n", " AmplitudeScale(max_amplitude=0.150),\n", " RandomPitchShift(samplerate=16000, caching=True, aug_index=i),\n", " FixLength(16000),\n", " ]\n", " )\n", " ds = mini_dataset()\n", " ds.transform = transform\n", " sample = ds[sample_index][0]\n", " data, targets = load_from_disk_cache(\"cache/\" + \"0_\" + str(i) + \".hdf5\")\n", " print((sample == data).all())" ] } ], "metadata": { "kernelspec": { "display_name": "py_310", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.2" } }, "nbformat": 4, "nbformat_minor": 2 }