{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# One-image, four-mask DataBunch\n", "\n", "Kaggle's [Understanding Clouds from Satellite Images](https://www.kaggle.com/c/understanding_cloud_organization) competition asks us to mark out the regions of a satellite image that contain certain types of clouds.\n", "\n", "Since each image can contain up to four cloud types, we can take the image as the independent variable and four masks, one per cloud type, as the dependent variable. Each mask is binary: 0 means the pixel does not belong to that cloud type, and 1 means it does.\n", "\n", "The dataloader therefore needs to produce mini-batches in which each item consists of an image (the independent variable) plus four masks (the dependent variable)." ] }, { "cell_type": "code", "execution_count": 47, "metadata": {}, "outputs": [], "source": [ "import sys\n", "# Path to a local checkout of the fastai_dev repository; adjust for your machine.\n", "sys.path.append('/Users/jack/git_repos/fastai_dev/dev')\n", "from local.data.all import *\n", "from local.vision.all import *\n", "from local.vision.core import *\n", "from local.vision.augment import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The four cloud types are:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "CATS = ['fish', 'flower', 'gravel', 'sugar']" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Load the images' file paths and the annotations" ] }, { "cell_type": "code", "execution_count": 49, "metadata": {}, "outputs": [], "source": [ "SOURCE = Path('data')" ] }, { "cell_type": "code", "execution_count": 50, "metadata": {}, "outputs": [], "source": [ "items = get_image_files(SOURCE/'train_images')" ] }, { "cell_type": "code", "execution_count": 51, "metadata": {}, "outputs": [], "source": [ "# (height, width) of one training image; all images are assumed to share this shape.\n", "IMG_SHAPE = PILImage.create(items[12]).shape" ] }, { 
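"cell_type": "markdown", "metadata": {}, "source": [ "Each row of `train.csv` identifies an image and a cloud type through a single `Image_Label` string. The value below is illustrative only: it assumes the competition's `<image id>.jpg_<Cloud type>` naming pattern. It shows the two splits that `load_train_annotation` applies to recover the image id and the lowercase label." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Illustrative Image_Label value (assumed '<image id>.jpg_<Cloud type>' pattern)\n", "image_label = '61d6640.jpg_Gravel'\n", "# Image id: everything before the file extension\n", "image_id = image_label.split('.')[0]\n", "# Cloud type: the part after the underscore, lowercased to match CATS\n", "label = image_label.split('_')[1].lower()\n", "image_id, label  # ('61d6640', 'gravel')" ] }, { 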
"cell_type": "code", "execution_count": 52, "metadata": {}, "outputs": [], "source": [ "def load_train_annotation(fpath):\n", "    df = pd.read_csv(fpath)\n", "    # 'Image_Label' combines the image file name and the cloud type\n", "    df['Image'] = df.Image_Label.apply(lambda o: o.split('.')[0])\n", "    df['Label'] = df.Image_Label.apply(lambda o: o.split('_')[1].lower())\n", "    df.drop('Image_Label', axis=1, inplace=True)\n", "    df = df[['Image', 'Label', 'EncodedPixels']]\n", "    return df" ] }, { "cell_type": "code", "execution_count": 53, "metadata": {}, "outputs": [], "source": [ "annots = load_train_annotation(SOURCE/'train.csv')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Decode run-length encoding" ] }, { "cell_type": "code", "execution_count": 54, "metadata": {}, "outputs": [], "source": [ "def rle_decode(mask_rle: str = '', shape: tuple = (1400, 2100)):\n", "    '''\n", "    Decode rle encoded mask.\n", "\n", "    :param mask_rle: run-length as string formatted (start length)\n", "    :param shape: (height, width) of array to return\n", "    Returns numpy array, 1 - mask, 0 - background\n", "\n", "    Copied from https://www.kaggle.com/artgor/segmentation-in-pytorch-using-convenient-tools\n", "    '''\n", "    s = mask_rle.split()\n", "    # Alternating values are (1-indexed start, run length) pairs\n", "    starts, lengths = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])]\n", "    starts -= 1\n", "    ends = starts + lengths\n", "    img = np.zeros(shape[0] * shape[1], dtype=np.uint8)\n", "    for lo, hi in zip(starts, ends):\n", "        img[lo:hi] = 1\n", "    # Pixels are numbered top to bottom, then left to right, hence order='F'\n", "    return img.reshape(shape, order='F')" ] }, { "cell_type": "code", "execution_count": 55, "metadata": {}, "outputs": [], "source": [ "class RLE_Decode(Transform):\n", "    '''\n", "    Image file path -> binary mask (numpy array) for one cloud type\n", "    '''\n", "    def __init__(self, cat, annots, img_shape):\n", "        self.cat, self.annots, self.img_shape = cat, annots, img_shape\n", "\n", "    def encodes(self, o):\n", "        # Rows for this image; a missing mask (NaN) becomes '', which decodes to all zeros\n", "        df = self.annots[self.annots.Image == o.stem].fillna('')\n", "        px_rle = df[df.Label == self.cat].EncodedPixels.values[0]\n", "        return rle_decode(px_rle, self.img_shape)" ] }, { "cell_type": 
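"markdown", "metadata": {}, "source": [ "To make the four-mask target concrete, here is a minimal sketch on hypothetical toy data (`toy_shape` and `toy_rles` are made up for illustration, not taken from `train.csv`): decode one RLE string per cloud type with `rle_decode`, then stack the four binary masks into a single `(4, height, width)` array ordered as in `CATS`. Stacking along a new leading axis is one possible representation of the dependent variable; it is an assumption here, not necessarily how the final DataBunch stores the masks." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Hypothetical toy data: a 3x5 image and one RLE string per cloud type.\n", "# RLE starts are 1-indexed; pixels are numbered top to bottom, then left to right.\n", "toy_shape = (3, 5)\n", "toy_rles = {'fish': '1 2', 'flower': '', 'gravel': '4 3', 'sugar': '13 3'}\n", "\n", "# One binary mask per cloud type, stacked in CATS order.\n", "masks = np.stack([rle_decode(toy_rles[c], toy_shape) for c in CATS])\n", "print(masks.shape, masks.dtype)  # (4, 3, 5) uint8\n", "\n", "# Number of pixels marked for each cloud type (an empty RLE gives an empty mask).\n", "dict(zip(CATS, masks.sum(axis=(1, 2))))" ] }, { "cell_type": 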
"code", "execution_count": 56, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", " | Image | \n", "Label | \n", "EncodedPixels | \n", "
---|---|---|---|
8600 | \n", "61d6640 | \n", "fish | \n", "NaN | \n", "
8601 | \n", "61d6640 | \n", "flower | \n", "NaN | \n", "
8602 | \n", "61d6640 | \n", "gravel | \n", "1349079 387 1350479 387 1351879 387 1353279 387 1354679 387 1356079 387 1357479 387 1358879 387 1360279 387 1361679 387 1363079 387 1364479 387 1365879 387 1367279 387 1368679 387 1370079 387 1371479 387 1372879 387 1374279 387 1375679 387 1377079 387 1378479 387 1379879 387 1381279 387 1382679 387 1384079 387 1385479 387 1386879 387 1388279 387 1389679 387 1391079 387 1392479 387 1393879 387 1395279 387 1396679 387 1398079 387 1399479 387 1400879 387 1402279 387 1403679 387 1405079 387 1406479 387 1407879 387 1409279 387 1410679 387 1412079 387 1413479 387 1414879 387 1416279 387 1417679 ... | \n", "
8603 | \n", "61d6640 | \n", "sugar | \n", "373839 334 375239 334 376639 334 378039 334 379439 334 380839 334 382239 334 383639 334 385039 334 386439 334 387839 334 389239 334 390639 334 392039 334 393439 334 394839 334 396239 334 397639 334 399039 334 400439 334 401839 334 403239 334 404639 334 406039 334 407439 334 408839 334 410239 334 411639 334 413039 334 414439 334 415839 334 417239 334 418639 334 420039 334 421439 334 422839 334 424239 334 425639 334 427039 334 428439 334 429839 334 431239 334 432639 334 434039 334 435439 334 436839 334 438239 334 439639 334 441039 334 442439 334 443839 334 445239 334 446639 334 448039 334 44... | \n", "