{ "cells": [ { "cell_type": "markdown", "id": "23539678", "metadata": {}, "source": [ "[![image](https://raw.githubusercontent.com/visual-layer/visuallayer/main/imgs/vl_horizontal_logo.png)](https://www.visual-layer.com)" ] }, { "cell_type": "markdown", "id": "1b5ed76a", "metadata": {}, "source": [ "# Hugging Face Datasets\n", "This notebook shows how you can load VL Datasets from Hugging Face Datasets and train in PyTorch.\n", "\n", "We will load the [`vl-food101`](https://huggingface.co/datasets/visual-layer/vl-food101) dataset - a sanitized version of the original [Food-101 dataset](https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101/). Learn more [here](https://docs.visual-layer.com/docs/available-datasets#vl-food101).\n", "\n", "The `vl-food101` is curated to minimize duplicates, outliers, blurry, overly dark and bright images.\n", "The following table summarizes the issues we found in the original Food101 dataset and were removed in in vl-food101.\n", "\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
CategoryPercentageCount
Duplicates
0.23%
235
Outliers
0.08%
77
Blur
0.18%
185
Dark
0.04%
43
Leakage
0.086%
87
Total
0.62%
627
\n" ] }, { "cell_type": "markdown", "id": "b1e53ed2", "metadata": {}, "source": [ "## Installation" ] }, { "cell_type": "code", "execution_count": 1, "id": "a27bd381", "metadata": {}, "outputs": [], "source": [ "!pip install -Uq datasets torchvision" ] }, { "cell_type": "markdown", "id": "3b631046", "metadata": {}, "source": [ "## Load Dataset\n", "\n", "Now load the vl-food101 dataset from Hugging Face Dataset. See the dataset card [here](https://huggingface.co/datasets/visual-layer/vl-food101)." ] }, { "cell_type": "code", "execution_count": 2, "id": "2361e6f4", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Found cached dataset parquet (/media/dnth/Active-Projects/vl-datasets/notebooks/images_dir/visual-layer___parquet/visual-layer--vl-food101-bd3d25b1793d94e4/0.0.0/14a00e99c0d15a23649d0db8944380ac81082d4b021f398733dd84f3a6c569a7)\n", "Found cached dataset parquet (/media/dnth/Active-Projects/vl-datasets/notebooks/images_dir/visual-layer___parquet/visual-layer--vl-food101-bd3d25b1793d94e4/0.0.0/14a00e99c0d15a23649d0db8944380ac81082d4b021f398733dd84f3a6c569a7)\n" ] } ], "source": [ "from datasets import load_dataset\n", "\n", "train_dataset = load_dataset(\"visual-layer/vl-food101\", split=\"train\", cache_dir='images_dir')\n", "valid_dataset = load_dataset(\"visual-layer/vl-food101\", split=\"test\", cache_dir='images_dir')" ] }, { "cell_type": "code", "execution_count": 3, "id": "4c45ee5e", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Dataset({\n", " features: ['image', 'label'],\n", " num_rows: 75284\n", "})" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_dataset" ] }, { "cell_type": "code", "execution_count": 4, "id": "43b4a39d", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'image': ,\n", " 'label': 0}" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_dataset[0]" ] }, { "cell_type": "markdown", "id": "6e8580d7", "metadata": {}, "source": [ "## Transform Dataset" ] }, { "cell_type": "code", "execution_count": 5, "id": "b33fb972", "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "import torch.optim as optim\n", "from torch.utils.data import DataLoader\n", "import torchvision\n", "\n", "import torchvision.transforms as transforms\n", "\n", "train_transforms = transforms.Compose(\n", " [\n", " transforms.RandomResizedCrop(64),\n", " transforms.RandomHorizontalFlip(),\n", " transforms.ToTensor(),\n", " transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n", " ]\n", ")\n", "\n", "valid_transform = transforms.Compose(\n", " [\n", " transforms.Resize((64, 64)),\n", " transforms.ToTensor(),\n", " transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n", " ]\n", ")\n", "\n", "def preprocess_train(example_batch):\n", " \"\"\"Apply train_transforms across a batch.\"\"\"\n", " example_batch[\"pixel_values\"] = [\n", " train_transforms(image.convert(\"RGB\")) for image in example_batch[\"image\"]\n", " ]\n", " return example_batch\n", "\n", "def preprocess_valid(example_batch):\n", " \"\"\"Apply valid_transforms across a batch.\"\"\"\n", " example_batch[\"pixel_values\"] = [\n", " valid_transform(image.convert(\"RGB\")) for image in example_batch[\"image\"]\n", " ]\n", " return example_batch" ] }, { "cell_type": "code", "execution_count": 6, "id": "bf954459", "metadata": {}, "outputs": [], "source": [ 
"train_dataset.set_transform(preprocess_train)\n", "valid_dataset.set_transform(preprocess_valid)" ] }, { "cell_type": "code", "execution_count": 7, "id": "c339921e", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'image': ,\n", " 'label': 0,\n", " 'pixel_values': tensor([[[ 2.1119, 2.0948, 2.0605, ..., 0.8618, 1.3755, 1.7523],\n", " [ 1.8893, 1.8379, 1.7523, ..., 0.8961, 1.2557, 1.3070],\n", " [ 1.9749, 1.8893, 1.8037, ..., 0.9132, 0.9988, 0.9646],\n", " ...,\n", " [-1.0048, -0.9877, -1.0219, ..., 0.4679, 0.8104, 0.6392],\n", " [-1.0390, -1.0390, -1.0219, ..., 1.4098, 0.8789, 0.3481],\n", " [-1.0733, -1.0219, -1.0390, ..., 1.2043, 0.5878, 0.3823]],\n", " \n", " [[ 2.3410, 2.3235, 2.2885, ..., -0.4601, -0.0574, 0.3627],\n", " [ 2.0784, 2.0259, 1.9384, ..., -0.4601, -0.1975, -0.1450],\n", " [ 2.2185, 2.0959, 1.9734, ..., -0.4426, -0.3901, -0.4426],\n", " ...,\n", " [-1.7031, -1.7206, -1.7556, ..., -0.2675, 0.0301, -0.0924],\n", " [-1.7031, -1.7381, -1.7381, ..., 0.8880, 0.1352, -0.3725],\n", " [-1.7206, -1.7206, -1.7206, ..., 0.5903, -0.2325, -0.3901]],\n", " \n", " [[ 2.4831, 2.4657, 2.4483, ..., -1.4210, -1.2119, -0.9678],\n", " [ 2.0823, 2.0300, 1.9603, ..., -1.4036, -1.2467, -1.2990],\n", " [ 2.1868, 2.0823, 1.9603, ..., -1.3687, -1.3513, -1.3861],\n", " ...,\n", " [-1.6127, -1.5953, -1.6127, ..., -0.8284, -0.6715, -0.7936],\n", " [-1.6127, -1.6127, -1.5953, ..., 0.1128, -0.6193, -1.0550],\n", " [-1.6302, -1.6127, -1.5953, ..., -0.2010, -0.8807, -1.0027]]])}" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_dataset[0]" ] }, { "cell_type": "code", "execution_count": 8, "id": "c56444d4", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([3, 64, 64])" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_dataset[0][\"pixel_values\"].shape" ] }, { "cell_type": "code", "execution_count": 9, "id": "1a0b21ea", "metadata": {}, "outputs": [], "source": [ "def collate_fn(examples):\n", " pixel_values = torch.stack([example[\"pixel_values\"] for example in examples])\n", " labels = torch.tensor([example[\"label\"] for example in examples])\n", " return {\"pixel_values\": pixel_values, \"labels\": labels}" ] }, { "cell_type": "code", "execution_count": 10, "id": "15213620", "metadata": {}, "outputs": [], "source": [ "train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=256, shuffle=True , collate_fn=collate_fn)\n", "valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=256, shuffle=True, collate_fn=collate_fn)" ] }, { "cell_type": "markdown", "id": "d27d9ba9", "metadata": {}, "source": [ "## Define Model and Hyperparameters" ] }, { "cell_type": "code", "execution_count": 11, "id": "60ededda", "metadata": {}, "outputs": [], "source": [ "model = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)\n", "num_ftrs = model.fc.in_features\n", "model.fc = nn.Linear(num_ftrs, len(train_dataset.features[\"label\"].names))" ] }, { "cell_type": "code", "execution_count": 12, "id": "bca6d219", "metadata": {}, "outputs": [], "source": [ "criterion = nn.CrossEntropyLoss()\n", "optimizer = optim.Adam(model.parameters(), lr=0.001)" ] }, { "cell_type": "markdown", "id": "c8ac5796", "metadata": {}, "source": [ "## Train and Evaluate" ] }, { "cell_type": "code", "execution_count": 13, "id": "0ee15349", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Using device: cuda\n" ] }, { "data": { 
"application/vnd.jupyter.widget-view+json": { "model_id": "9e175d78cb12413fa7745bb95f9749c5", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Epochs: 0%| | 0/5 [00:00