{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Customizing datasets in fastai" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "hide_input": true }, "outputs": [], "source": [ "from fastai.gen_doc.nbdoc import *\n", "from fastai.vision import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this tutorial, we'll see how to create custom subclasses of [`ItemBase`](/core.html#ItemBase) or [`ItemList`](/data_block.html#ItemList) while retaining everything the fastai library has to offer. To allow basic functions to work consistently across various applications, the fastai library delegates several tasks to one of those specific objects, and we'll see here which methods you have to implement for everything to work properly. But first let's take a step back to see where you'll use your end result." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Links with the data block API" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The data block API works by allowing you to pick a class that is responsible for getting your items and another class that is in charge of getting your targets. Combined together, they create a PyTorch [`Dataset`](https://pytorch.org/docs/stable/data.html#torch.utils.data.Dataset) that is then wrapped inside a [`DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader). The training set, validation set and maybe test set are then all put in a [`DataBunch`](/basic_data.html#DataBunch).\n", "\n", "The data block API allows you to mix and match what class your inputs have, what class your targets have, how to do the split between train and validation set, then how to create the [`DataBunch`](/basic_data.html#DataBunch), but if you have a very specific kind of input/target, the fastai classes might not be sufficient for you. 
This tutorial explains what is needed to create a new class of items and which methods are important to implement or override.\n", "\n", "It proceeds in two phases: first we focus on what you need to create a custom [`ItemBase`](/core.html#ItemBase) class (which is the type of your inputs/targets), then on how to create your custom [`ItemList`](/data_block.html#ItemList) (which is basically a set of [`ItemBase`](/core.html#ItemBase)) while highlighting which methods are called by the library." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Creating a custom [`ItemBase`](/core.html#ItemBase) subclass" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The fastai library contains three basic types of [`ItemBase`](/core.html#ItemBase) that you might want to subclass:\n", "- [`Image`](/vision.image.html#Image) for vision applications\n", "- [`Text`](/text.data.html#Text) for text applications\n", "- [`TabularLine`](/tabular.data.html#TabularLine) for tabular applications\n", "\n", "Whether you decide to create your own item class or to subclass one of the above, here is what you need to implement:" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Basic attributes" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "These are the most important attributes your custom [`ItemBase`](/core.html#ItemBase) needs, as they're used everywhere in the fastai library:\n", "- `ItemBase.data` is what is passed to PyTorch when you want to create a [`DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader). This is what needs to be fed to your model. 
Note that it might differ from the displayed representation of your item, since you might want that representation to be more understandable.\n", "- `__str__` representation: if applicable, this is what will be displayed when the fastai library has to show your item.\n", "\n", "If we take the example of a [`MultiCategory`](/core.html#MultiCategory) object `o` for instance:\n", "- `o.data` is a tensor where the tags are one-hot encoded\n", "- `str(o)` returns the tags separated by `;`\n", "\n", "If you want to code the way data augmentation should be applied to your custom `Item`, you should write an `apply_tfms` method. This is what will be called if you apply a [`transform`](/vision.transform.html#vision.transform) block in the data block API." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Example: ImageTuple" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For CycleGANs, we need to create a custom type of item since we feed the model tuples of images. Let's look at how to code this. The first step is to code the [`data`](/vision.data.html#vision.data) attribute, which is what will be given to the model. Note that we still keep track of the initial object (usually in an `obj` attribute) to be able to show nice representations later on. Here the object is the tuple of images and the data is their underlying tensors, normalized between -1 and 1." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class ImageTuple(ItemBase):\n", "    def __init__(self, img1, img2):\n", "        self.img1,self.img2 = img1,img2\n", "        self.obj,self.data = (img1,img2),[-1+2*img1.data,-1+2*img2.data]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then we want to apply data augmentation to our tuple of images. That's done by writing an `apply_tfms` method as we saw before. Here we pass that call to the two underlying images then update the data."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "    def apply_tfms(self, tfms, **kwargs):\n", "        self.img1 = self.img1.apply_tfms(tfms, **kwargs)\n", "        self.img2 = self.img2.apply_tfms(tfms, **kwargs)\n", "        self.data = [-1+2*self.img1.data,-1+2*self.img2.data]\n", "        return self" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We define a last method to stack the two images next to each other, which we will use later for a customized `show_batch` / `show_results` behavior." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "    def to_one(self): return Image(0.5+torch.cat(self.data,2)/2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is all you need to create your custom [`ItemBase`](/core.html#ItemBase). You won't be able to use it until you have put it inside your custom [`ItemList`](/data_block.html#ItemList) though, so you should continue reading the next section." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Creating a custom [`ItemList`](/data_block.html#ItemList) subclass" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is the main class that allows you to group your inputs or your targets in the data block API. You can then use any of the splitting or labelling methods before creating a [`DataBunch`](/basic_data.html#DataBunch). To make sure everything works properly, here is what you need to know."
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Class variables" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Whether you're directly subclassing [`ItemList`](/data_block.html#ItemList) or one of the particular fastai ones, make sure you know the content of the following three variables, as you may need to adjust them:\n", "- `_bunch` contains the name of the class that will be used to create a [`DataBunch`](/basic_data.html#DataBunch) \n", "- `_processor` contains a class (or a list of classes) of [`PreProcessor`](/data_block.html#PreProcessor) that will then be used as the default to create the processor for this [`ItemList`](/data_block.html#ItemList)\n", "- `_label_cls` contains the class that will be used to create the labels by default\n", "\n", "`_label_cls` is the first to be used in the data block API, in the labelling function. If this variable is set to `None`, the label class will be set to [`CategoryList`](/data_block.html#CategoryList), [`MultiCategoryList`](/data_block.html#MultiCategoryList) or [`FloatList`](/data_block.html#FloatList) depending on the type of the first item. The default can be overridden by passing a `label_cls` in the kwargs of the labelling function.\n", "\n", "`_processor` is the second to be used. The processors are called at the end of the labelling to apply some kind of function on your items. The default processor of the inputs can be overridden by passing a `processor` in the kwargs when creating the [`ItemList`](/data_block.html#ItemList); the default processor of the targets can be overridden by passing a `processor` in the kwargs of the labelling function. \n", "\n", "Processors are useful for pre-processing some data, but you also need to put in their state any variable you want to save for the call of `data.export()` before creating a [`Learner`](/basic_train.html#Learner) object for inference: the state of the [`ItemList`](/data_block.html#ItemList) isn't saved there, only their processors. 
For instance `SegmentationProcessor`'s only reason to exist is to save the dataset classes, and during the process call, it doesn't do anything apart from setting the `classes` and `c` attributes on its dataset.\n", "``` python\n", "class SegmentationProcessor(PreProcessor):\n", "    def __init__(self, ds:ItemList): self.classes = ds.classes\n", "    def process(self, ds:ItemList): ds.classes,ds.c = self.classes,len(self.classes)\n", "```\n", "\n", "`_bunch` is the last class variable used in the data block. When you type the final `databunch()`, the data block API calls the `_bunch.create` method with the `_bunch` of the inputs. " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Keeping \\_\\_init\\_\\_ arguments" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you pass additional arguments in your `__init__` call that you save in the state of your [`ItemList`](/data_block.html#ItemList), you have to make sure they are also passed along in the `new` method, as this one is used to create your training and validation sets when splitting. To do that, you just have to add their names to the `copy_new` attribute of your custom [`ItemList`](/data_block.html#ItemList), preferably during the `__init__`. Here we will need two collections of filenames (for the two types of images) so we make sure the second one is copied like this:\n", "\n", "```python\n", "def __init__(self, items, itemsB=None, **kwargs):\n", "    super().__init__(items, **kwargs)\n", "    self.itemsB = itemsB\n", "    self.copy_new.append('itemsB')\n", "```\n", "\n", "Be sure to keep the kwargs as is, as they contain all the additional stuff you can pass to an [`ItemList`](/data_block.html#ItemList)."
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Important methods" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### - get" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The most important method you have to implement is `get`: this one will enable your custom [`ItemList`](/data_block.html#ItemList) to generate an [`ItemBase`](/core.html#ItemBase) from the thing stored in its `items` array. For instance an [`ImageList`](/vision.data.html#ImageList) has the following `get` method:\n", "``` python\n", "def get(self, i):\n", "    fn = super().get(i)\n", "    res = self.open(fn)\n", "    self.sizes[i] = res.size\n", "    return res\n", "```\n", "The first line basically looks at `self.items[i]` (which is a filename). The second line opens it since the `open` method is just\n", "``` python\n", "def open(self, fn): return open_image(fn)\n", "```\n", "The third line is there for [`ImagePoints`](/vision.image.html#ImagePoints) or [`ImageBBox`](/vision.image.html#ImageBBox) targets that require the size of the input [`Image`](/vision.image.html#Image) to be created. Note that if you are building a custom target class and you need the size of an image, you should call `self.x.sizes[i]`. " ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "