{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "\n", "\n", " \n", " \n", " \n", " \n", "
\n", " \n", " \n", " Try in Google Colab\n", " \n", " \n", " \n", " \n", " Share via nbviewer\n", " \n", " \n", " \n", " \n", " View on GitHub\n", " \n", " \n", " \n", " \n", " Download notebook\n", " \n", "
\n" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Clean Google's Conceptual Captions Dataset\n", "\n", "This notebook walks you through how to download Google's Conceptual Captions Dataset, and then clean and curate the data. Once you have a refined dataset, you can use this to train your own state-of-the-art ControlNet model, or to train a model for image captioning tasks!" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Setup" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "The only libraries we will need to clean and curate this data are [pandas](https://pandas.pydata.org/) (for tabular data) and [FiftyOne](https://github.com/voxel51/fiftyone) (for unstructured image data):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!pip install pandas fiftyone" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Additionally, you will need [hashlib](https://docs.python.org/3/library/hashlib.html) for helper functions, and you will probably want [tqdm](https://github.com/tqdm/tqdm) to track progress while downloading images.\n", "\n", "You can import all of the required modules as follows:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import hashlib\n", "import pandas as pd\n", "from tqdm.notebook import tqdm\n", "\n", "import fiftyone as fo\n", "import fiftyone.zoo as foz\n", "import fiftyone.brain as fob\n", "from fiftyone import ViewField as F" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Download the Dataset" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Download the tab-separated variable (`.tsv`) file by clicking the “Download” button at the bottom of Google’s Conceptual Captions webpage, or by clicking on [this link](https://ai.google.com/research/ConceptualCaptions/download).\n" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "We can load the tsv file as a pandas DataFrame in similar fashion to a csv, by passing in `sep=\\t` to specify that the separator is a tab." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "df = pd.read_csv(\"Train_GCC-training.tsv\", sep='\\t')" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Give the columns of the `DataFrame` descriptive names:" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "df.columns =['caption', 'url']" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "And then hash the url for each entry to generate a unique ID:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def hash_url(url):\n", " return hashlib.md5(url.encode()).hexdigest()[:12]\n", "\n", "df['url_hash'] = df['url'].apply(hash_url)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "The `DataFrame` looks like this:" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "
\n",
    "caption\turl\turl_hash\n",
    "0\tsierra looked stunning in this top and this sk...\thttp://78.media.tumblr.com/3b133294bdc7c7784b7...\te7023a8dfcd2\n",
    "1\tyoung confused girl standing in front of a war...\thttps://media.gettyimages.com/photos/young-con...\t92679c323fc6\n",
    "2\tinterior design of modern living room with fir...\thttps://thumb1.shutterstock.com/display_pic_wi...\t74c4fa5539f4\n",
    "3\tcybernetic scene isolated on white background .\thttps://thumb1.shutterstock.com/display_pic_wi...\tf1ea388e05e1\n",
    "4\tgangsta rap artist attends sports team vs play...\thttps://media.gettyimages.com/photos/jayz-atte...\t9a6f8026f593\n",
    "...\t...\t...\t...\n",
    "3318327\tthe teams line up for a photo after kick - off\thttps://i0.wp.com/i.dailymail.co.uk/i/pix/2015...\t6aec77a477f9\n",
    "3318328\tstickers given to delegates at the convention .\thttp://cdn.radioiowa.com/wp-content/uploads/20...\t7d42aea90652\n",
    "3318329\tthis is my very favourite design that i recent...\thttps://i.pinimg.com/736x/96/f0/77/96f07728efe...\tf6dd151121c0\n",
    "3318330\tman driving a car through the mountains\thttps://www.quickenloans.com/blog/wp-content/u...\tee4244df5c55\n",
    "3318331\ta longtail boat with a flag goes by spectacula...\thttp://l7.alamy.com/zooms/338c4740f7b2480dbb72...\t7625946297b7\n",
    "3318332 rows × 3 columns\n",
    "
" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "We will use these IDs to specify the download locations (filepaths) of images, so that we can associate captions to the corresponding images." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "If we want to download the images in batches, we can do so as follows:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def download_batch(df, batch_size=10000, start_index=0):\n", " batch = df.iloc[start_index:start_index+batch_size]\n", " for j in tqdm(range(batch_size)):\n", " url, uh = batch.iloc[j][['url', 'url_hash']]\n", " !curl -s --connect-timeout 3 --max-time 3 \"{url}\" -o images/{uh}.jpg" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Here we download `batch_size` images starting from `start_index` into the folder `images`, with filename specified by the url hash we generated above. We use `curl` to execute the download operation, and set limits for the time spent attempting to download each image, because some of the links are no longer valid. " ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "To download a total of `num_images` images, run the following:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "def download_images(df, batch_size=10000, num_images = 100000):\n", " for i in range(num_images//batch_size):\n", " download_batch(df, batch_size=batch_size, start_index=i*batch_size)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "## download all images\n", "num_images = len(df)\n", "download_images(df, batch_size=10000, num_images=num_images)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Load and Visualize the Data" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Once we have the images downloaded into a `images` folder, we can load the images and their captions as a `Dataset` in FiftyOne:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dataset = fo.Dataset(name=\"gcc\", persistent=True)\n", "dataset.add_sample_field(\"caption\", fo.StringField)\n", "\n", "samples = []\n", "\n", "for i in tqdm(range(num_images)):\n", " caption, uh = df.iloc[i]['caption'], df.iloc[i]['url_hash']\n", " filepath = f\"images/{uh}.jpg\"\n", " sample = fo.Sample(\n", " filepath=filepath,\n", " caption=caption\n", " )\n", " samples.append(sample)\n", "dataset.add_samples(samples)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "This code creates a `Dataset` named “gcc”, which is persisted to the underlying database, and then iterates through the first `num_images` rows of the pandas `DataFrame`, creating a `Sample` with the appropriate filepath and caption.\n", "\n", "For this walkthrough, I downloaded the first roughly 310,000 images.\n", "\n", "The first step we should take when inspecting a new computer vision dataset is to visualize it! 
{ "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [
    "## Load and Visualize the Data"
] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [
    "Once we have the images downloaded into an `images` folder, we can load the images and their captions as a `Dataset` in FiftyOne:"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
    "dataset = fo.Dataset(name=\"gcc\", persistent=True)\n",
    "dataset.add_sample_field(\"caption\", fo.StringField)\n",
    "\n",
    "samples = []\n",
    "\n",
    "for i in tqdm(range(num_images)):\n",
    "    caption, uh = df.iloc[i]['caption'], df.iloc[i]['url_hash']\n",
    "    filepath = f\"images/{uh}.jpg\"\n",
    "    sample = fo.Sample(\n",
    "        filepath=filepath,\n",
    "        caption=caption\n",
    "    )\n",
    "    samples.append(sample)\n",
    "dataset.add_samples(samples)"
] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [
    "This code creates a `Dataset` named “gcc”, which is persisted to the underlying database, and then iterates through the first `num_images` rows of the pandas `DataFrame`, creating a `Sample` with the appropriate filepath and caption.\n",
    "\n",
    "For this walkthrough, I downloaded roughly the first 310,000 images.\n",
    "\n",
    "The first step we should take when inspecting a new computer vision dataset is to visualize it! We can do this by launching the FiftyOne App:"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
    "session = fo.launch_app(dataset)"
] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [
    "## Remove Corrupted Samples"
] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [
    "When we look at the data, we can immediately see that some of the images are not valid. This may be due to links which are no longer working, interruptions during downloading, or some other issue entirely.\n",
    "\n",
    "Fortunately, we can filter out these invalid images easily. In FiftyOne, the `compute_metadata()` method computes media-type-specific metadata for each sample. For image-based samples, this includes image width, height, and size in bytes.\n",
    "\n",
    "When the media file is nonexistent or corrupted, the metadata will be left as null. We can thus filter out the corrupted images by running `compute_metadata()` and matching for samples where the metadata exists:"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
    "dataset.compute_metadata()\n",
    "\n",
    "## view containing only valid images\n",
    "valid_image_view = dataset.exists(\"metadata\")\n",
    "\n",
    "session = fo.launch_app(valid_image_view)"
] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [
    "## Filter by Aspect Ratio"
] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [
    "The next step we may want to take is filtering out samples with unusual aspect ratios. If our goal is to control the outputs of a diffusion model, we will likely only be working with images within a certain range of reasonable aspect ratios.\n",
    "\n",
    "We can do this using FiftyOne’s `ViewField`, which allows us to apply arbitrary expressions to attributes of our samples, and then filter based on these. For instance, if we want to discard all images that are more than twice as large in either dimension as they are in the other dimension, we can do so with the following code:"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
    "from fiftyone import ViewField as F\n",
    "\n",
    "long_filter = F(\"metadata.width\") > 2*F(\"metadata.height\")\n",
    "tall_filter = F(\"metadata.height\") > 2*F(\"metadata.width\")\n",
    "aspect_ratio_filter = (~long_filter) & (~tall_filter)\n",
    "\n",
    "good_aspect_view = valid_image_view.match(aspect_ratio_filter)"
] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [
    "For the sake of clarity, this is what the discarded samples look like:"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
    "bad_aspect_view = valid_image_view.match(~aspect_ratio_filter)\n",
    "session = fo.launch_app(bad_aspect_view)"
] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [
    "If you so choose, you can use a more or less stringent aspect ratio filter!"
] },
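{ "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [
    "For instance, here is a sketch of an equivalent formulation in which the strictness is a single tunable bound (the name `custom_aspect_filter` is just for illustration); you could swap it into the `match()` call above:"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
    "## keep only images whose aspect ratio lies between 1/2 and 2\n",
    "ratio = F(\"metadata.width\") / F(\"metadata.height\")\n",
    "custom_aspect_filter = (ratio > 0.5) & (ratio < 2.0)\n",
    "\n",
    "## e.g. valid_image_view.match(custom_aspect_filter)"
] },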
{ "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [
    "## Filter by Resolution"
] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [
    "In a similar vein, we might want to remove the low resolution images. We want to generate *stunning, photorealistic* images, so there is no sense including low resolution images in the training data.\n",
    "\n",
    "This filter is similar to the aspect ratio filter. If we select 300 pixels as our lowest allowed width and height, the filter takes the form:"
] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [
    "hires_filter = (F(\"metadata.width\") > 300) & (F(\"metadata.height\") > 300)\n",
    "view = good_aspect_view.match(hires_filter)"
] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [
    "Once again, you can choose whatever thresholds you like. For clarity, here is a representative view of the discarded images:"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
    "lowres_view = good_aspect_view.match(~hires_filter)\n",
    "session = fo.launch_app(lowres_view)"
] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [
    "## Ensure Color Palette"
] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [
    "Looking at the low resolution images, we also might be reminded that some of the images in our dataset are grayscale. We likely want to generate images that are as vibrant as possible, so we should discard the black-and-white images.\n",
    "\n",
    "In FiftyOne, one of the attributes logged in image metadata is the number of channels: color images have three channels (RGB), whereas grayscale images only have one channel. Removing grayscale images is as simple as matching for images with three channels!"
] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [
    "## gray images to discard\n",
    "gray_view = view.match(F(\"metadata.num_channels\") == 1)\n",
    "## color images to keep\n",
    "view = view.match(F(\"metadata.num_channels\") == 3)\n",
    "session = fo.launch_app(gray_view)"
] },
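{ "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [
    "Before moving on, it can be handy to snapshot which samples have survived every filter so far. One option, sketched below with an arbitrary tag name, is to tag the samples in the current view so that the subset can be recovered later:"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
    "## tag the samples that have passed all of the filters so far\n",
    "view.tag_samples(\"curated\")\n",
    "\n",
    "## the tagged subset can be recovered later with dataset.match_tags(\"curated\")"
] },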
{ "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [
    "## Deduplicate the Dataset"
] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [
    "Our next task in our data curation quest is to remove duplicate images. When an image is exactly or approximately duplicated in a training dataset, the resulting model may be biased by this small set of overrepresented samples, not to mention the added training costs.\n",
    "\n",
    "We can find approximate duplicates in our dataset by using a model to generate embeddings for our images (we will use a [CLIP model](https://github.com/openai/CLIP) for illustration):"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
    "## load CLIP model from the FiftyOne Model Zoo\n",
    "model = foz.load_zoo_model(\"clip-vit-base32-torch\")\n",
    "\n",
    "## compute embeddings and store them in the image_clip_embedding field\n",
    "view.compute_embeddings(\n",
    "    model,\n",
    "    embeddings_field=\"image_clip_embedding\"\n",
    ")"
] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [
    "Then we create a similarity index based on these embeddings:"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
    "results = fob.compute_similarity(view, embeddings=\"image_clip_embedding\")"
] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [
    "Finally, we can set a numerical threshold on the embedding distance at which we will consider images approximate duplicates (here we choose 0.3), and only retain one representative from each group of approximate duplicates:"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
    "results.find_duplicates(thresh=0.3)\n",
    "\n",
    "# view the duplicates, paired up\n",
    "dup_view = results.duplicates_view()\n",
    "session = fo.launch_app(dup_view, auto=False)"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
    "# get one image from each group of duplicates\n",
    "dup_rep_ids = list(results.neighbors_map.keys())\n",
    "\n",
    "# get ids of non-duplicates\n",
    "non_dup_ids = view.exclude(\n",
    "    dup_view.values(\"id\")\n",
    ").values(\"id\")\n",
    "\n",
    "# ids to keep\n",
    "ids = dup_rep_ids + non_dup_ids\n",
    "\n",
    "# create view from ids\n",
    "view = view[ids]"
] },
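{ "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [
    "If you are curious how much near-duplication there was, here is a quick sketch of the bookkeeping, using the `results` and `dup_view` objects from above:"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
    "## number of duplicate groups, and samples appearing in the duplicates view\n",
    "print(len(results.neighbors_map), \"groups of near-duplicates\")\n",
    "print(len(dup_view), \"samples in the duplicates view\")"
] },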
{ "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [
    "## Validate Image-Caption Alignment"
] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [
    "Okay, now you’re in luck, because we saved the coolest step for last!\n",
    "\n",
    "Google’s Conceptual Captions Dataset consists of image-caption pairs from the internet. More precisely, “the raw descriptions are harvested from the Alt-text HTML attribute associated with web images”. This is great as an initial pass, but there are bound to be some low-quality captions in there.\n",
    "\n",
    "We may not be able to ensure that all of our captions perfectly describe their images, but we can certainly filter out some poorly aligned image-caption pairs!\n",
    "\n",
    "We will do so using [CLIPScore](https://arxiv.org/pdf/2104.08718.pdf), which is a “reference-free evaluation metric for image captioning”. In other words, you just need the image and the caption. CLIPScore is easy to implement. First, we use SciPy’s cosine distance method to define a cosine similarity function:"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
    "from scipy.spatial.distance import cosine as cosine_distance\n",
    "\n",
    "def cosine(vector1, vector2):\n",
    "    return 1. - cosine_distance(vector1, vector2)"
] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [
    "Then we define a function which takes in a `Sample`, and computes the CLIPScore between the image embedding and the caption embedding stored on the sample:"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
    "def compute_clip_score(sample):\n",
    "    image_embedding = sample[\"image_clip_embedding\"]\n",
    "    caption_embedding = sample[\"caption_clip_embedding\"]\n",
    "    return max(100.*cosine(image_embedding, caption_embedding), 0.)"
] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [
    "Essentially, this expression just lower bounds the score at zero. The scaling factor 100 is the same as used by [PyTorch](https://torchmetrics.readthedocs.io/en/stable/multimodal/clip_score.html).\n",
    "\n",
    "We can then compute the CLIPScore - our measure of alignment between images and captions - by adding the fields to our dataset and iterating over our samples:"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
    "dataset.add_sample_field(\"caption_clip_embedding\", fo.VectorField)\n",
    "dataset.add_sample_field(\"clip_score\", fo.FloatField)\n",
    "\n",
    "for sample in view.iter_samples(autosave=True, progress=True):\n",
    "    sample[\"caption_clip_embedding\"] = model.embed_prompt(sample[\"caption\"])\n",
    "    sample[\"clip_score\"] = compute_clip_score(sample)\n",
    "view.save()"
] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [
    "If we want to see the “least aligned” samples, we can sort by “clip_score”."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
    "## 100 least aligned samples\n",
    "least_aligned_view = view.sort_by(\"clip_score\")[:100]"
] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [
    "To see the most aligned samples, we can do the same, but passing in `reverse=True`:"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
    "## 100 most aligned samples\n",
    "most_aligned_view = view.sort_by(\"clip_score\", reverse=True)[:100]"
] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [
    "We can then set a CLIPScore threshold depending on how aligned we demand the image-caption pairs to be. To my taste, a threshold of 21.8 seemed good enough:"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
    "view = view.match(F(\"clip_score\") > 21.8)\n",
    "gcc_clean = view.clone(name=\"gcc_clean\", persistent=True)"
] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [
    "The second line clones the view into a new persistent `Dataset` named “gcc_clean”."
] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [
    "## Conclusion"
] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "After our data cleaning and curation, we have turned a relatively mediocre initial dataset into a high-quality dataset that is ready for training a ControlNet model! We surely haven’t created a perfect dataset — a perfect dataset does not exist.
What we have done is addressed all of the data quality issues that plagued ControlNet 1.0, plus a few more, just for good measure :)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 4 }