{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Writing Custom Dataset Importers\n",
    "\n",
    "This recipe demonstrates how to write a [custom DatasetImporter](https://voxel51.com/docs/fiftyone/user_guide/dataset_creation/datasets.html#custom-formats) and use it to load a dataset from disk in your custom format into FiftyOne."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup\n",
    "\n",
    "If you haven't already, install FiftyOne:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install fiftyone"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this recipe we'll use the [FiftyOne Dataset Zoo](https://voxel51.com/docs/fiftyone/user_guide/dataset_creation/zoo_datasets.html) to download the [CIFAR-10 dataset](https://www.cs.toronto.edu/~kriz/cifar.html) to use as sample data to feed our custom importer.\n",
    "\n",
    "Behind the scenes, FiftyOne either uses the\n",
    "[TensorFlow Datasets](https://www.tensorflow.org/datasets) or\n",
    "[TorchVision Datasets](https://pytorch.org/vision/stable/datasets.html) libraries to wrangle the datasets, depending on which ML library you have installed.\n",
    "\n",
    "You can, for example, install PyTorch as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install torch torchvision"
   ]
  },
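  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you're unsure which of the two backends you have, the optional check below (a minimal sketch; it is not part of the FiftyOne API) simply verifies that at least one of them can be imported:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sanity check: verify that at least one of the supported backends\n",
    "# is installed before downloading zoo datasets\n",
    "try:\n",
    "    import torchvision\n",
    "\n",
    "    print(\"torchvision is installed\")\n",
    "except ImportError:\n",
    "    import tensorflow_datasets\n",
    "\n",
    "    print(\"tensorflow-datasets is installed\")"
   ]
  },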
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Writing a DatasetImporter\n",
    "\n",
    "FiftyOne provides a [DatasetImporter](https://voxel51.com/docs/fiftyone/api/fiftyone.utils.data.html#fiftyone.utils.data.importers.DatasetImporter) interface that defines how it imports datasets from disk when methods such as [Dataset.from_importer()](https://voxel51.com/docs/fiftyone/api/fiftyone.core.html#fiftyone.core.dataset.Dataset.from_importer) are used.\n",
    "\n",
    "`DatasetImporter` itself is an abstract interface; the concrete interface that you should implement is determined by the type of dataset that you are importing. See [writing a custom DatasetImporter](https://voxel51.com/docs/fiftyone/user_guide/dataset_creation/datasets.html#custom-formats) for full details.\n",
    "\n",
    "In this recipe, we'll write a custom [LabeledImageDatasetImporter](https://voxel51.com/docs/fiftyone/api/fiftyone.utils.data.html#fiftyone.utils.data.importers.LabeledImageDatasetImporter) that can import an image classification dataset whose image metadata and labels are stored in a `labels.csv` file in the dataset directory with the following format:\n",
    "\n",
    "```\n",
    "filepath,size_bytes,mime_type,width,height,num_channels,label\n",
    "<filepath>,<size_bytes>,<mime_type>,<width>,<height>,<num_channels>,<label>\n",
    "<filepath>,<size_bytes>,<mime_type>,<width>,<height>,<num_channels>,<label>\n",
    "...\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here's the complete definition of the `DatasetImporter`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import csv\n",
    "import os\n",
    "\n",
    "import fiftyone as fo\n",
    "import fiftyone.utils.data as foud\n",
    "\n",
    "\n",
    "class CSVImageClassificationDatasetImporter(foud.LabeledImageDatasetImporter):\n",
    "    \"\"\"Importer for image classification datasets whose filepaths and labels\n",
    "    are stored on disk in a CSV file.\n",
    "\n",
    "    Datasets of this type should contain a ``labels.csv`` file in their\n",
    "    dataset directories in the following format::\n",
    "\n",
    "        filepath,size_bytes,mime_type,width,height,num_channels,label\n",
    "        <filepath>,<size_bytes>,<mime_type>,<width>,<height>,<num_channels>,<label>\n",
    "        <filepath>,<size_bytes>,<mime_type>,<width>,<height>,<num_channels>,<label>\n",
    "        ...\n",
    "\n",
    "    Args:\n",
    "        dataset_dir: the dataset directory\n",
    "        shuffle (False): whether to randomly shuffle the order in which the\n",
    "            samples are imported\n",
    "        seed (None): a random seed to use when shuffling\n",
    "        max_samples (None): a maximum number of samples to import. By default,\n",
    "            all samples are imported\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        dataset_dir,\n",
    "        shuffle=False,\n",
    "        seed=None,\n",
    "        max_samples=None,\n",
    "    ):\n",
    "        super().__init__(\n",
    "            dataset_dir=dataset_dir,\n",
    "            shuffle=shuffle,\n",
    "            seed=seed,\n",
    "            max_samples=max_samples\n",
    "        )\n",
    "        self._labels_file = None\n",
    "        self._labels = None\n",
    "        self._iter_labels = None\n",
    "\n",
    "    def __iter__(self):\n",
    "        self._iter_labels = iter(self._labels)\n",
    "        return self\n",
    "\n",
    "    def __next__(self):\n",
    "        \"\"\"Returns information about the next sample in the dataset.\n",
    "\n",
    "        Returns:\n",
    "            an  ``(image_path, image_metadata, label)`` tuple, where\n",
    "\n",
    "            -   ``image_path``: the path to the image on disk\n",
    "            -   ``image_metadata``: an\n",
    "                :class:`fiftyone.core.metadata.ImageMetadata` instances for the\n",
    "                image, or ``None`` if :meth:`has_image_metadata` is ``False``\n",
    "            -   ``label``: an instance of :meth:`label_cls`, or a dictionary\n",
    "                mapping field names to :class:`fiftyone.core.labels.Label`\n",
    "                instances, or ``None`` if the sample is unlabeled\n",
    "\n",
    "        Raises:\n",
    "            StopIteration: if there are no more samples to import\n",
    "        \"\"\"\n",
    "        (\n",
    "            filepath,\n",
    "            size_bytes,\n",
    "            mime_type,\n",
    "            width,\n",
    "            height,\n",
    "            num_channels,\n",
    "            label,\n",
    "        ) = next(self._iter_labels)\n",
    "\n",
    "        image_metadata = fo.ImageMetadata(\n",
    "            size_bytes=size_bytes,\n",
    "            mime_type=mime_type,\n",
    "            width=width,\n",
    "            height=height,\n",
    "            num_channels=num_channels,\n",
    "        )\n",
    "\n",
    "        label = fo.Classification(label=label)\n",
    "        return filepath, image_metadata, label\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"The total number of samples that will be imported.\n",
    "\n",
    "        Raises:\n",
    "            TypeError: if the total number is not known\n",
    "        \"\"\"\n",
    "        return len(self._labels)\n",
    "\n",
    "    @property\n",
    "    def has_dataset_info(self):\n",
    "        \"\"\"Whether this importer produces a dataset info dictionary.\"\"\"\n",
    "        return False\n",
    "\n",
    "    @property\n",
    "    def has_image_metadata(self):\n",
    "        \"\"\"Whether this importer produces\n",
    "        :class:`fiftyone.core.metadata.ImageMetadata` instances for each image.\n",
    "        \"\"\"\n",
    "        return True\n",
    "\n",
    "    @property\n",
    "    def label_cls(self):\n",
    "        \"\"\"The :class:`fiftyone.core.labels.Label` class(es) returned by this\n",
    "        importer.\n",
    "\n",
    "        This can be any of the following:\n",
    "\n",
    "        -   a :class:`fiftyone.core.labels.Label` class. In this case, the\n",
    "            importer is guaranteed to return labels of this type\n",
    "        -   a list or tuple of :class:`fiftyone.core.labels.Label` classes. In\n",
    "            this case, the importer can produce a single label field of any of\n",
    "            these types\n",
    "        -   a dict mapping keys to :class:`fiftyone.core.labels.Label` classes.\n",
    "            In this case, the importer will return label dictionaries with keys\n",
    "            and value-types specified by this dictionary. Not all keys need be\n",
    "            present in the imported labels\n",
    "        -   ``None``. In this case, the importer makes no guarantees about the\n",
    "            labels that it may return\n",
    "        \"\"\"\n",
    "        return fo.Classification\n",
    "\n",
    "    def setup(self):\n",
    "        \"\"\"Performs any necessary setup before importing the first sample in\n",
    "        the dataset.\n",
    "\n",
    "        This method is called when the importer's context manager interface is\n",
    "        entered, :func:`DatasetImporter.__enter__`.\n",
    "        \"\"\"\n",
    "        labels_path = os.path.join(self.dataset_dir, \"labels.csv\")\n",
    "\n",
    "        labels = []\n",
    "        with open(labels_path, \"r\") as f:\n",
    "            reader = csv.DictReader(f)\n",
    "            for row in reader:\n",
    "                labels.append((\n",
    "                    row[\"filepath\"],\n",
    "                    row[\"size_bytes\"],\n",
    "                    row[\"mime_type\"],\n",
    "                    row[\"width\"],\n",
    "                    row[\"height\"],\n",
    "                    row[\"num_channels\"],\n",
    "                    row[\"label\"],\n",
    "                ))\n",
    "\n",
    "        # The `_preprocess_list()` function is provided by the base class\n",
    "        # and handles shuffling/max sample limits\n",
    "        self._labels = self._preprocess_list(labels)\n",
    "\n",
    "    def close(self, *args):\n",
    "        \"\"\"Performs any necessary actions after the last sample has been\n",
    "        imported.\n",
    "\n",
    "        This method is called when the importer's context manager interface is\n",
    "        exited, :func:`DatasetImporter.__exit__`.\n",
    "\n",
    "        Args:\n",
    "            *args: the arguments to :func:`DatasetImporter.__exit__`\n",
    "        \"\"\"\n",
    "        pass\n"
   ]
  },
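  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Under the hood, methods like `Dataset.from_importer()` consume the importer via its context manager and iterator interfaces: `setup()` and `close()` are invoked on entry/exit, and iteration yields `(image_path, image_metadata, label)` tuples. The snippet below is a minimal, purely illustrative sketch of that consumption pattern (the `import_samples()` helper is hypothetical, not FiftyOne's actual implementation):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def import_samples(importer, label_field=\"ground_truth\"):\n",
    "    \"\"\"Illustrative sketch of how a :class:`LabeledImageDatasetImporter` is\n",
    "    consumed when building a dataset.\n",
    "    \"\"\"\n",
    "    samples = []\n",
    "\n",
    "    # Entering the context manager calls `setup()`; exiting calls `close()`\n",
    "    with importer:\n",
    "        for image_path, image_metadata, label in importer:\n",
    "            sample = fo.Sample(filepath=image_path, metadata=image_metadata)\n",
    "            if label is not None:\n",
    "                sample[label_field] = label\n",
    "\n",
    "            samples.append(sample)\n",
    "\n",
    "    return samples"
   ]
  },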
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generating a sample dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In order to use `CSVImageClassificationDatasetImporter`, we need to generate a sample dataset in the required format.\n",
    "\n",
    "Let's first write a small utility to populate a `labels.csv` file in the required format."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def write_csv_labels(samples, csv_path, label_field=\"ground_truth\"):\n",
    "    \"\"\"Writes a labels CSV format for the given samples in the format expected\n",
    "    by :class:`CSVImageClassificationDatasetImporter`.\n",
    "\n",
    "    Args:\n",
    "        samples: an iterable of :class:`fiftyone.core.sample.Sample` instances\n",
    "        csv_path: the path to write the CSV file\n",
    "        label_field (\"ground_truth\"): the label field of the samples to write\n",
    "    \"\"\"\n",
    "    # Ensure base directory exists\n",
    "    basedir = os.path.dirname(csv_path)\n",
    "    if basedir and not os.path.isdir(basedir):\n",
    "        os.makedirs(basedir)\n",
    "\n",
    "    # Write the labels\n",
    "    with open(csv_path, \"w\") as f:\n",
    "        writer = csv.writer(f)\n",
    "        writer.writerow([\n",
    "            \"filepath\",\n",
    "            \"size_bytes\",\n",
    "            \"mime_type\",\n",
    "            \"width\",\n",
    "            \"height\",\n",
    "            \"num_channels\",\n",
    "            \"label\",\n",
    "        ])\n",
    "        for sample in samples:\n",
    "            filepath = sample.filepath\n",
    "            metadata = sample.metadata\n",
    "            if metadata is None:\n",
    "                metadata = fo.ImageMetadata.build_for(filepath)\n",
    "\n",
    "            label = sample[label_field].label\n",
    "            writer.writerow([\n",
    "                filepath,\n",
    "                metadata.size_bytes,\n",
    "                metadata.mime_type,\n",
    "                metadata.width,\n",
    "                metadata.height,\n",
    "                metadata.num_channels,\n",
    "                label,\n",
    "            ])\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let's populate a directory with a `labels.csv` file in the format required by `CSVImageClassificationDatasetImporter` with some samples from the test split of CIFAR-10:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Split 'test' already downloaded\n",
      "Loading existing dataset 'cifar10-test'. To reload from disk, first delete the existing dataset\n",
      "Computing metadata for samples\n",
      " 100% |█████| 1000/1000 [421.2ms elapsed, 0s remaining, 2.4K samples/s]      \n",
      "Writing labels for 1000 samples to '/tmp/fiftyone/custom-dataset-importer/labels.csv'\n"
     ]
    }
   ],
   "source": [
    "import fiftyone.zoo as foz\n",
    "\n",
    "\n",
    "dataset_dir = \"/tmp/fiftyone/custom-dataset-importer\"\n",
    "num_samples = 1000\n",
    "\n",
    "#\n",
    "# Load `num_samples` from CIFAR-10\n",
    "#\n",
    "# This command will download the test split of CIFAR-10 from the web the first\n",
    "# time it is executed, if necessary\n",
    "#\n",
    "cifar10_test = foz.load_zoo_dataset(\"cifar10\", split=\"test\")\n",
    "samples = cifar10_test.limit(num_samples)\n",
    "\n",
    "# This dataset format requires samples to have their `metadata` fields populated\n",
    "print(\"Computing metadata for samples\")\n",
    "samples.compute_metadata()\n",
    "\n",
    "# Write labels to disk in CSV format\n",
    "csv_path = os.path.join(dataset_dir, \"labels.csv\")\n",
    "print(\"Writing labels for %d samples to '%s'\" % (num_samples, csv_path))\n",
    "write_csv_labels(samples, csv_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's inspect the contents of the labels CSV to ensure they're in the correct format:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "filepath,size_bytes,mime_type,width,height,num_channels,label\r\n",
      "~/fiftyone/cifar10/test/data/000001.jpg,1422,image/jpeg,32,32,3,cat\r\n",
      "~/fiftyone/cifar10/test/data/000002.jpg,1285,image/jpeg,32,32,3,ship\r\n",
      "~/fiftyone/cifar10/test/data/000003.jpg,1258,image/jpeg,32,32,3,ship\r\n",
      "~/fiftyone/cifar10/test/data/000004.jpg,1244,image/jpeg,32,32,3,airplane\r\n",
      "~/fiftyone/cifar10/test/data/000005.jpg,1388,image/jpeg,32,32,3,frog\r\n",
      "~/fiftyone/cifar10/test/data/000006.jpg,1311,image/jpeg,32,32,3,frog\r\n",
      "~/fiftyone/cifar10/test/data/000007.jpg,1412,image/jpeg,32,32,3,automobile\r\n",
      "~/fiftyone/cifar10/test/data/000008.jpg,1218,image/jpeg,32,32,3,frog\r\n",
      "~/fiftyone/cifar10/test/data/000009.jpg,1262,image/jpeg,32,32,3,cat\r\n"
     ]
    }
   ],
   "source": [
    "!head -n 10 /tmp/fiftyone/custom-dataset-importer/labels.csv"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Importing a dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With our dataset and `DatasetImporter` in-hand, loading the data as a FiftyOne dataset is as simple as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Importing dataset from '/tmp/fiftyone/custom-dataset-importer'\n",
      " 100% |█████| 1000/1000 [780.7ms elapsed, 0s remaining, 1.3K samples/s]      \n"
     ]
    }
   ],
   "source": [
    "# Import the dataset\n",
    "print(\"Importing dataset from '%s'\" % dataset_dir)\n",
    "importer = CSVImageClassificationDatasetImporter(dataset_dir)\n",
    "dataset = fo.Dataset.from_importer(importer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Name:           2020.07.14.22.33.01\n",
      "Persistent:     False\n",
      "Num samples:    1000\n",
      "Tags:           []\n",
      "Sample fields:\n",
      "    filepath:     fiftyone.core.fields.StringField\n",
      "    tags:         fiftyone.core.fields.ListField(fiftyone.core.fields.StringField)\n",
      "    metadata:     fiftyone.core.fields.EmbeddedDocumentField(fiftyone.core.metadata.Metadata)\n",
      "    ground_truth: fiftyone.core.fields.EmbeddedDocumentField(fiftyone.core.labels.Classification)\n"
     ]
    }
   ],
   "source": [
    "# Print summary information about the dataset\n",
    "print(dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<Sample: {\n",
      "    'dataset_name': '2020.07.14.22.33.01',\n",
      "    'id': '5f0e6add1dfd5f8c299ac528',\n",
      "    'filepath': '~/fiftyone/cifar10/test/data/000001.jpg',\n",
      "    'tags': BaseList([]),\n",
      "    'metadata': <ImageMetadata: {\n",
      "        'size_bytes': 1422,\n",
      "        'mime_type': 'image/jpeg',\n",
      "        'width': 32,\n",
      "        'height': 32,\n",
      "        'num_channels': 3,\n",
      "    }>,\n",
      "    'ground_truth': <Classification: {'label': 'cat', 'confidence': None, 'logits': None}>,\n",
      "}>\n"
     ]
    }
   ],
   "source": [
    "# Print a sample\n",
    "print(dataset.first())"
   ]
  },
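  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because `CSVImageClassificationDatasetImporter` passes its `shuffle`, `seed`, and `max_samples` arguments to the base class and routes its labels through `_preprocess_list()`, those options work out-of-the-box. For example, this illustrative snippet imports a small random subset of the dataset:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import a random subset of 10 samples by leveraging the `shuffle` and\n",
    "# `max_samples` options handled by the base class\n",
    "subset_importer = CSVImageClassificationDatasetImporter(\n",
    "    dataset_dir, shuffle=True, seed=51, max_samples=10\n",
    ")\n",
    "subset = fo.Dataset.from_importer(subset_importer)\n",
    "\n",
    "print(len(subset))  # 10"
   ]
  },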
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cleanup\n",
    "\n",
    "You can cleanup the files generated by this recipe by running:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "!rm -rf /tmp/fiftyone"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}