{ "cells": [ { "cell_type": "markdown", "id": "208285ec", "metadata": {}, "source": [ "# `datasets` and Remote URLs" ] }, { "cell_type": "markdown", "id": "990a65dd", "metadata": {}, "source": [ "Nowadays it's very usual to distribute large image datasets with links to the original images that were scraped, rather than the images themselves.\n", "\n", "Consider, for example, [Conceptual Captions](https://huggingface.co/datasets/conceptual_captions). Each record contains an URL to the image and a text caption.\n", "\n", "If we load it with `datasets`, images will not be downloaded." ] }, { "cell_type": "code", "execution_count": 1, "id": "93cb3fcc", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "No config specified, defaulting to: conceptual_captions/unlabeled\n", "Found cached dataset conceptual_captions (/home/pedro/.cache/huggingface/datasets/conceptual_captions/unlabeled/1.0.0/05266784888422e36944016874c44639bccb39069c2227435168ad8b02d600d8)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "b13e6745f78a4adf9015493157b24dc5", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/2 [00:00" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "get_image_http(valid[0][\"image_url\"])" ] }, { "cell_type": "markdown", "id": "3d1dcdf3", "metadata": {}, "source": [ "The `map` function can be used to apply transformations to a dataset. We could use it to process the dataset and download all the images.\n", "\n", "**Note**: because the dataset was loaded in normal (non _streaming_) mode, this would process all the items and cache them locally as parquet files. This may be convenient if we have the disk space, but it could be impractical in some situations.\n", "\n", "Let's verify by selecting a subset. 
The easiest way is to slice the split when loading it:" ] }, { "cell_type": "code", "execution_count": 7, "id": "48473d35", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "No config specified, defaulting to: conceptual_captions/unlabeled\n", "Found cached dataset conceptual_captions (/home/pedro/.cache/huggingface/datasets/conceptual_captions/unlabeled/1.0.0/05266784888422e36944016874c44639bccb39069c2227435168ad8b02d600d8)\n" ] } ], "source": [ "# small_dataset = load_dataset(\"conceptual_captions\", split=\"validation[:1%]\")\n", "small_dataset = load_dataset(\"conceptual_captions\", split=\"validation[:20]\")" ] }, { "cell_type": "code", "execution_count": 8, "id": "d432af35", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Dataset({\n", "    features: ['image_url', 'caption'],\n", "    num_rows: 20\n", "})" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "small_dataset" ] }, { "cell_type": "markdown", "id": "d9f0d067", "metadata": {}, "source": [ "We can't trust URLs from scraped datasets: many of them may be long gone. The processing function we apply during mapping does some very basic error handling, storing `None` when an image can't be retrieved." ] }, { "cell_type": "code", "execution_count": 9, "id": "41a6d5a0", "metadata": {}, "outputs": [], "source": [ "def add_image(sample):\n", "    try:\n", "        return {\"image\": get_image_http(sample[\"image_url\"])}\n", "    except Exception:\n", "        # Any failure (e.g. a dead link) yields None\n", "        return {\"image\": None}" ] }, { "cell_type": "code", "execution_count": 10, "id": "a446cfde", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Loading cached processed dataset at /home/pedro/.cache/huggingface/datasets/conceptual_captions/unlabeled/1.0.0/05266784888422e36944016874c44639bccb39069c2227435168ad8b02d600d8/cache-fe437448f3d39f72.arrow\n" ] } ], "source": [ "small_dataset = 
small_dataset.map(add_image)" ] }, { "cell_type": "markdown", "id": "4bc324f4", "metadata": {}, "source": [ "By default, `map` calls the function on one sample at a time. 
If the mapped function can operate on several samples at once, we can pass `batched=True` and it will receive a batch (a dict of lists) instead of a single sample.\n", "\n", "Next, we drop the samples whose image couldn't be retrieved (only 13 of the 20 survive here)." ] }, { "cell_type": "code", "execution_count": 11, "id": "744f49f3", "metadata": { "scrolled": false }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Loading cached processed dataset at /home/pedro/.cache/huggingface/datasets/conceptual_captions/unlabeled/1.0.0/05266784888422e36944016874c44639bccb39069c2227435168ad8b02d600d8/cache-6b4348a763894fee.arrow\n" ] }, { "data": { "text/plain": [ "Dataset({\n", "    features: ['image_url', 'caption', 'image'],\n", "    num_rows: 13\n", "})" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "small_dataset = small_dataset.filter(lambda x: x[\"image\"] is not None)\n", "small_dataset" ] }, { "cell_type": "code", "execution_count": 12, "id": "0f5c0e9e", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'image_url': 'https://i.pinimg.com/736x/66/01/6c/66016c3ba27c0e04f39e2bd81a934e3e--anita-ekberg-bob-hope.jpg',\n", " 'caption': 'author : a life in photography -- in pictures',\n", " 'image': }" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "next(iter(small_dataset))" ] }, { "cell_type": "markdown", "id": "19385393", "metadata": {}, "source": [ "## Streamed datasets" ] }, { "cell_type": "markdown", "id": "46e6f45f", "metadata": {}, "source": [ "We'll now load the dataset in streaming mode, which fetches the data lazily as we iterate instead of downloading and caching it up front. Just for fun, we'll use a version of the data with additional metadata fields (`labels`, `MIDs` and `confidence_scores`). This is called a \"configuration\", and it's exposed [on the Hub](https://huggingface.co/datasets/conceptual_captions) through a dropdown list labeled `Subset`." 
] }, { "cell_type": "code", "execution_count": 13, "id": "8b45d6cc", "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torchvision.transforms as T\n", "import torchvision.transforms.functional as TF\n", "from torchvision.transforms import InterpolationMode" ] }, { "cell_type": "code", "execution_count": 14, "id": "013252b2", "metadata": {}, "outputs": [], "source": [ "# Crop and tensorize\n", "\n", "def center_crop(image, max_size=512):\n", " if image is None: return torch.tensor([0.])\n", " s = min(image.size)\n", "\n", " # Note: this would upscale too\n", " r = max_size / s\n", " s = (round(r * image.size[1]), round(r * image.size[0]))\n", " image = TF.resize(image, s, interpolation=InterpolationMode.LANCZOS)\n", " image = TF.center_crop(image, output_size=2 * [max_size])\n", " image = torch.unsqueeze(T.ToTensor()(image), 0)\n", " return image\n", " \n", "def download_and_crop(url):\n", " try:\n", " image = get_image_http(url)\n", " if image is None: return None\n", " return center_crop(image)\n", " except:\n", " return None\n", "\n", "def download_sample(sample):\n", " return {\"image\": download_and_crop(sample[\"image_url\"])}" ] }, { "cell_type": "code", "execution_count": 15, "id": "49231ed8", "metadata": {}, "outputs": [], "source": [ "train_ds = load_dataset(\n", " \"conceptual_captions\", split=\"train\", streaming=True, name=\"labeled\"\n", ")" ] }, { "cell_type": "code", "execution_count": 16, "id": "a3657647", "metadata": {}, "outputs": [], "source": [ "train_ds = train_ds.filter(lambda x: x[\"confidence_scores\"][0] > 0.98)\n", "train_ds = train_ds.remove_columns([\"labels\", \"MIDs\", \"confidence_scores\"])\n", "train_ds = train_ds.map(download_sample).filter(lambda x: x[\"image\"] is not None)" ] }, { "cell_type": "markdown", "id": "c4963a0e", "metadata": {}, "source": [ "The default format is \"Python objects\". We need to change it to prevent datasets from returning nested lists instead of tensors." 
] }, { "cell_type": "code", "execution_count": 17, "id": "5d4ae141", "metadata": {}, "outputs": [], "source": [ "train_ds = train_ds.with_format(\"torch\")" ] }, { "cell_type": "code", "execution_count": 18, "id": "e8f326a7", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'image_url': 'https://thumb1.shutterstock.com/display_pic_with_logo/261388/223876810/stock-vector-christmas-tree-on-a-black-background-vector-223876810.jpg',\n", " 'caption': 'christmas tree on a black background .',\n", " 'image': tensor([[[[0.2549, 0.2549, 0.2510, ..., 0.0000, 0.0000, 0.0000],\n", " [0.2510, 0.2510, 0.2510, ..., 0.0000, 0.0000, 0.0000],\n", " [0.2471, 0.2471, 0.2471, ..., 0.0000, 0.0000, 0.0000],\n", " ...,\n", " [1.0000, 1.0000, 1.0000, ..., 1.0000, 1.0000, 1.0000],\n", " [1.0000, 1.0000, 1.0000, ..., 1.0000, 1.0000, 1.0000],\n", " [1.0000, 1.0000, 1.0000, ..., 1.0000, 1.0000, 1.0000]],\n", " \n", " [[0.0000, 0.0000, 0.0039, ..., 0.0000, 0.0000, 0.0000],\n", " [0.0000, 0.0000, 0.0039, ..., 0.0000, 0.0000, 0.0000],\n", " [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", " ...,\n", " [1.0000, 1.0000, 1.0000, ..., 1.0000, 1.0000, 1.0000],\n", " [1.0000, 1.0000, 1.0000, ..., 1.0000, 1.0000, 1.0000],\n", " [1.0000, 1.0000, 1.0000, ..., 1.0000, 1.0000, 1.0000]],\n", " \n", " [[0.1176, 0.1176, 0.1176, ..., 0.0000, 0.0000, 0.0000],\n", " [0.1176, 0.1176, 0.1176, ..., 0.0000, 0.0000, 0.0000],\n", " [0.1137, 0.1137, 0.1137, ..., 0.0000, 0.0000, 0.0000],\n", " ...,\n", " [1.0000, 1.0000, 1.0000, ..., 0.9961, 0.9961, 0.9961],\n", " [1.0000, 1.0000, 1.0000, ..., 0.9882, 0.9882, 0.9882],\n", " [1.0000, 1.0000, 1.0000, ..., 0.9843, 0.9843, 0.9843]]]])}" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "item = next(iter(train_ds))\n", "item" ] }, { "cell_type": "code", "execution_count": 19, "id": "5eee1da3", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([1, 3, 512, 512])" ] }, "execution_count": 19, 
"metadata": {}, "output_type": "execute_result" } ], "source": [ "item[\"image\"].shape" ] }, { "cell_type": "markdown", "id": "6279242f", "metadata": {}, "source": [ "**How does streaming actually work?**\n", "\n", "See the implementation for this particular dataset in [`conceptual_captions.py` here](https://huggingface.co/datasets/conceptual_captions/tree/main)" ] }, { "cell_type": "markdown", "id": "412841c1", "metadata": {}, "source": [ "### Parallel loading" ] }, { "cell_type": "markdown", "id": "45575fb1", "metadata": {}, "source": [ "As suggested in [the model card](https://huggingface.co/datasets/conceptual_captions), we can download the images using parallelization. This also works for streaming datasets, so we don't need to process the entire dataset beforehand.\n", "\n", "We'll use batched mapping and a `ThreadPoolExecutor`." ] }, { "cell_type": "code", "execution_count": 20, "id": "0ea22e57", "metadata": {}, "outputs": [], "source": [ "from concurrent.futures import ThreadPoolExecutor\n", "\n", "def fetch_images(batch, num_threads=20):\n", " with ThreadPoolExecutor(max_workers=num_threads) as executor:\n", " batch[\"image\"] = list(executor.map(download_and_crop, batch[\"image_url\"]))\n", " return batch" ] }, { "cell_type": "code", "execution_count": 21, "id": "2a70ffe9", "metadata": {}, "outputs": [], "source": [ "train_ds = load_dataset(\n", " \"conceptual_captions\", split=\"train\", streaming=True, name=\"labeled\"\n", ")" ] }, { "cell_type": "code", "execution_count": 22, "id": "a0df55b0", "metadata": {}, "outputs": [], "source": [ "download_bs = 16" ] }, { "cell_type": "code", "execution_count": 23, "id": "6cff10c6", "metadata": {}, "outputs": [], "source": [ "train_ds = train_ds.filter(lambda x: x[\"confidence_scores\"][0] > 0.98)\n", "train_ds = train_ds.remove_columns([\"labels\", \"MIDs\", \"confidence_scores\"])\n", "train_ds = train_ds.map(fetch_images, batched=True, batch_size=download_bs)\n", "train_ds = train_ds.filter(lambda x: 
x[\"image\"] is not None)\n", "train_ds = train_ds.with_format(\"torch\")" ] }, { "cell_type": "code", "execution_count": 24, "id": "e4ac48d0", "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "{'image_url': 'https://thumb1.shutterstock.com/display_pic_with_logo/261388/223876810/stock-vector-christmas-tree-on-a-black-background-vector-223876810.jpg',\n", " 'caption': 'christmas tree on a black background .',\n", " 'image': tensor([[[[0.2549, 0.2549, 0.2510, ..., 0.0000, 0.0000, 0.0000],\n", " [0.2510, 0.2510, 0.2510, ..., 0.0000, 0.0000, 0.0000],\n", " [0.2471, 0.2471, 0.2471, ..., 0.0000, 0.0000, 0.0000],\n", " ...,\n", " [1.0000, 1.0000, 1.0000, ..., 1.0000, 1.0000, 1.0000],\n", " [1.0000, 1.0000, 1.0000, ..., 1.0000, 1.0000, 1.0000],\n", " [1.0000, 1.0000, 1.0000, ..., 1.0000, 1.0000, 1.0000]],\n", " \n", " [[0.0000, 0.0000, 0.0039, ..., 0.0000, 0.0000, 0.0000],\n", " [0.0000, 0.0000, 0.0039, ..., 0.0000, 0.0000, 0.0000],\n", " [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", " ...,\n", " [1.0000, 1.0000, 1.0000, ..., 1.0000, 1.0000, 1.0000],\n", " [1.0000, 1.0000, 1.0000, ..., 1.0000, 1.0000, 1.0000],\n", " [1.0000, 1.0000, 1.0000, ..., 1.0000, 1.0000, 1.0000]],\n", " \n", " [[0.1176, 0.1176, 0.1176, ..., 0.0000, 0.0000, 0.0000],\n", " [0.1176, 0.1176, 0.1176, ..., 0.0000, 0.0000, 0.0000],\n", " [0.1137, 0.1137, 0.1137, ..., 0.0000, 0.0000, 0.0000],\n", " ...,\n", " [1.0000, 1.0000, 1.0000, ..., 0.9961, 0.9961, 0.9961],\n", " [1.0000, 1.0000, 1.0000, ..., 0.9882, 0.9882, 0.9882],\n", " [1.0000, 1.0000, 1.0000, ..., 0.9843, 0.9843, 0.9843]]]])}" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "next(iter(train_ds))" ] }, { "cell_type": "markdown", "id": "c82c5064", "metadata": {}, "source": [ "### Create `DataLoader`" ] }, { "cell_type": "code", "execution_count": 25, "id": "40146b99", "metadata": {}, "outputs": [], "source": [ "from torch.utils.data import DataLoader" ] }, { 
"cell_type": "code", "execution_count": 33, "id": "3b545f51", "metadata": {}, "outputs": [], "source": [ "train_ds = train_ds.shuffle(100) # This should be larger" ] }, { "cell_type": "code", "execution_count": 38, "id": "c19fc738", "metadata": {}, "outputs": [], "source": [ "loader = DataLoader(train_ds, batch_size=4)\n", "iter_loader = iter(loader)" ] }, { "cell_type": "code", "execution_count": 39, "id": "e3d831bd", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 2min 24s, sys: 10.8 s, total: 2min 35s\n", "Wall time: 3min 56s\n" ] } ], "source": [ "%%time\n", "batch = next(iter_loader)" ] }, { "cell_type": "code", "execution_count": 40, "id": "b65d6612", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 32.7 ms, sys: 0 ns, total: 32.7 ms\n", "Wall time: 2.69 ms\n" ] } ], "source": [ "%%time\n", "batch = next(iter_loader)" ] }, { "cell_type": "code", "execution_count": 41, "id": "d39e4fd8", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([4, 1, 3, 512, 512])" ] }, "execution_count": 41, "metadata": {}, "output_type": "execute_result" } ], "source": [ "batch[\"image\"].shape" ] }, { "cell_type": "markdown", "id": "b915e7d8", "metadata": {}, "source": [ "We didn't need the unsqueeze after all." 
] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.12" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 5 }