{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Assess predictions on multilabel image classification fridge data with a pytorch model"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook demonstrates the use of the `responsibleai` API to assess a multilabel image classification PyTorch model trained on the fridge dataset. It walks through the API calls necessary to create a widget with model analysis insights, then guides a visual analysis of the model."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Launch Responsible AI Toolbox"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following section examines the code necessary to create datasets and a model. It then generates insights using the `responsibleai` API that can be visually analyzed."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load Model and Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "from zipfile import ZipFile\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import datasets\n",
    "import matplotlib\n",
    "matplotlib.use('Agg')\n",
    "import matplotlib.pyplot as plt\n",
    "from responsibleai_vision.common.constants import ImageColumns\n",
    "import json\n",
    "from fastai.learner import load_learner\n",
    "from raiutils.common.retries import retry_function\n",
    "\n",
    "try:\n",
    "    from urllib import urlretrieve\n",
    "except ImportError:\n",
    "    from urllib.request import urlretrieve\n",
    "\n",
    "EPOCHS = 10\n",
    "LEARNING_RATE = 1e-4\n",
    "IM_SIZE = 300\n",
    "BATCH_SIZE = 16\n",
    "FRIDGE_MODEL_NAME = 'multilabel_fridge_model'\n",
    "FRIDGE_MODEL_WINDOWS_NAME = 'multilabel_fridge_model_windows'\n",
    "WIN = 'win'\n",
    "\n",
    "def load_fridge_dataset():\n",
    "    \"\"\"Download the multilabel fridge objects dataset and load its labels table.\n",
    "\n",
    "    The zip archive is downloaded into ``./data``, extracted there and then\n",
    "    deleted. ``labels.csv`` from the extracted folder is read into a DataFrame\n",
    "    whose ``filename`` and ``labels`` columns are renamed to\n",
    "    ``ImageColumns.IMAGE.value`` and ``ImageColumns.LABEL.value``.\n",
    "\n",
    "    Returns:\n",
    "        pandas.DataFrame: the labels table, with every image file name\n",
    "        prefixed by the path of the extracted ``images`` folder.\n",
    "    \"\"\"\n",
    "    # create data folder if it doesn't exist.\n",
    "    os.makedirs(\"data\", exist_ok=True)\n",
    "\n",
    "    # download data\n",
    "    download_url = (\"https://cvbp-secondary.z19.web.core.windows.net/\" +\n",
    "                    \"datasets/image_classification/multilabelFridgeObjects.zip\")\n",
    "    folder_path = './data/multilabelFridgeObjects'\n",
    "    data_file = folder_path + '.zip'\n",
    "    urlretrieve(download_url, filename=data_file)\n",
    "\n",
    "    # extract files\n",
    "    with ZipFile(data_file, \"r\") as zipfile:\n",
    "        zipfile.extractall(path=\"./data\")\n",
    "\n",
    "    # delete zip file\n",
    "    os.remove(data_file)\n",
    "    data = pd.read_csv(folder_path + '/labels.csv')\n",
    "    image_col_name = ImageColumns.IMAGE.value\n",
    "    label_col_name = ImageColumns.LABEL.value\n",
    "    data = data.rename(columns={'filename': image_col_name,\n",
    "                                'labels': label_col_name})\n",
    "    # Assign the whole column at once: writing element by element through a\n",
    "    # Series taken from the frame is chained assignment and is not guaranteed\n",
    "    # to update `data`.\n",
    "    data[image_col_name] = [folder_path + '/images/' + file_name\n",
    "                            for file_name in data[image_col_name]]\n",
    "    return data"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load Pretrained Faster RCNN model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import urllib.request as request_file\n",
    "\n",
    "# download fine-tuned recycling model from url\n",
    "def download_assets(filepath, force=False):\n",
    "    \"\"\"Download the ``fastrcnn.pt`` weights file to ``filepath``.\n",
    "\n",
    "    Args:\n",
    "        filepath: destination path of the downloaded file.\n",
    "        force: when True, download even if ``filepath`` already exists.\n",
    "\n",
    "    Returns:\n",
    "        The ``filepath`` argument, unchanged.\n",
    "    \"\"\"\n",
    "    if force or not os.path.exists(filepath):\n",
    "        request_file.urlretrieve(\n",
    "            \"https://publictestdatasets.blob.core.windows.net/models/fastrcnn.pt\",\n",
    "            filepath)\n",
    "    else:\n",
    "        print('Found ' + filepath)\n",
    "\n",
    "    return filepath"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Loading in our pretrained model\n",
    "import torchvision\n",
    "from torchvision.models.detection.faster_rcnn import FastRCNNPredictor\n",
    "import torch\n",
    "import os\n",
    "\n",
    "def get_instance_segmentation_model(num_classes):\n",
    "    \"\"\"Build a torchvision Faster R-CNN (ResNet-50 FPN) with a new box predictor.\n",
    "\n",
    "    Args:\n",
    "        num_classes: number of classes passed to the new ``FastRCNNPredictor``.\n",
    "\n",
    "    Returns:\n",
    "        The detection model with its box predictor head replaced.\n",
    "    \"\"\"\n",
    "    # load a Faster R-CNN detection model pre-trained on COCO\n",
    "    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)\n",
    "    in_features = model.roi_heads.box_predictor.cls_score.in_features\n",
    "    # replace the pre-trained head with a new one\n",
    "    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)\n",
    "    return model\n",
    "\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "print('Device: ', str(device))\n",
    "\n",
    "num_classes = 5\n",
    "model = get_instance_segmentation_model(num_classes)\n",
    "_ = download_assets('Recycling_finetuned_FastRCNN.pt')\n",
    "model.load_state_dict(torch.load('Recycling_finetuned_FastRCNN.pt', map_location=device))\n",
    "\n",
    "#if using the general torchvision pretrained model, comment above and uncomment below\n",
    "# model = detection.fasterrcnn_resnet50_fpn(pretrained=True)\n",
    "\n",
    "model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from fastai.data.transforms import Normalize\n",
    "from fastai.metrics import accuracy_multi\n",
    "from fastai.vision.data import ImageDataLoaders, imagenet_stats\n",
    "from fastai.vision.augment import Resize\n",
    "from fastai.vision import models as fastai_models\n",
    "from fastai.vision.learner import vision_learner\n",
    "from fastai.losses import BCEWithLogitsLossFlat\n",
    "import torch.nn as nn\n",
    "\n",
    "\n",
    "def train_fastai_image_classifier(df):\n",
    "    \"\"\"Train a fastai ResNet-18 multilabel image classifier on ``df``.\n",
    "\n",
    "    Labels are read as space-delimited strings and 20% of the rows are held\n",
    "    out for validation. Training runs for ``EPOCHS`` epochs at\n",
    "    ``LEARNING_RATE`` with all layers unfrozen.\n",
    "\n",
    "    Args:\n",
    "        df: DataFrame passed to ``ImageDataLoaders.from_df``.\n",
    "\n",
    "    Returns:\n",
    "        The fitted fastai learner.\n",
    "    \"\"\"\n",
    "    dataloaders = ImageDataLoaders.from_df(\n",
    "        df, valid_pct=0.2, seed=10, label_delim=' ', bs=BATCH_SIZE,\n",
    "        batch_tfms=[Resize(IM_SIZE), Normalize.from_stats(*imagenet_stats)])\n",
    "    learner = vision_learner(dataloaders, fastai_models.resnet18,\n",
    "                             metrics=[accuracy_multi],\n",
    "                             loss_func=BCEWithLogitsLossFlat())\n",
    "    learner.unfreeze()\n",
    "    learner.fit(EPOCHS, LEARNING_RATE)\n",
    "    return learner"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class FetchModel(object):\n",
    "    \"\"\"Helper whose ``fetch`` method is handed to ``retry_function``.\"\"\"\n",
    "\n",
    "    def fetch(self):\n",
    "        \"\"\"Download the pretrained fridge model and save it as ``FRIDGE_MODEL_NAME``.\n",
    "\n",
    "        The Windows-specific file is requested when ``sys.platform`` starts\n",
    "        with ``WIN``; the local file name is the same on every platform.\n",
    "        \"\"\"\n",
    "        if sys.platform.startswith(WIN):\n",
    "            model_name = FRIDGE_MODEL_WINDOWS_NAME\n",
    "        else:\n",
    "            model_name = FRIDGE_MODEL_NAME\n",
    "        url = ('https://publictestdatasets.blob.core.windows.net/models/' +\n",
    "               model_name)\n",
    "        urlretrieve(url, FRIDGE_MODEL_NAME)\n",
    "\n",
    "\n",
    "def retrieve_or_train_fridge_model(df, force_train=False):\n",
    "    \"\"\"Return a fridge model, either freshly trained or downloaded.\n",
    "\n",
    "    Args:\n",
    "        df: training DataFrame; only used when ``force_train`` is True.\n",
    "        force_train: when True, train a new model on ``df`` and export it to\n",
    "            ``FRIDGE_MODEL_NAME``; otherwise download the pretrained model\n",
    "            (with retries) and load it from ``FRIDGE_MODEL_NAME``.\n",
    "\n",
    "    Returns:\n",
    "        The trained or loaded fastai learner.\n",
    "    \"\"\"\n",
    "    if force_train:\n",
    "        model = train_fastai_image_classifier(df)\n",
    "        # Save model to disk\n",
    "        model.export(FRIDGE_MODEL_NAME)\n",
    "    else:\n",
    "        fetcher = FetchModel()\n",
    "        # These messages describe the model download performed by fetch().\n",
    "        action_name = \"Model download\"\n",
    "        err_msg = \"Failed to download model\"\n",
    "        max_retries = 4\n",
    "        retry_delay = 60\n",
    "        retry_function(fetcher.fetch, action_name, err_msg,\n",
    "                       max_retries=max_retries,\n",
    "                       retry_delay=retry_delay)\n",
    "        model = load_learner(FRIDGE_MODEL_NAME)\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = load_fridge_dataset()\n",
    "data.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = retrieve_or_train_fridge_model(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Update the labels to be in grid format\n",
    "CAN = 'can'\n",
    "CARTON = 'carton'\n",
    "MILK_BOTTLE = 'milk_bottle'\n",
    "WATER_BOTTLE = 'water_bottle'\n",
    "target_columns = [CAN, CARTON, MILK_BOTTLE, WATER_BOTTLE]\n",
    "label_col_name = ImageColumns.LABEL.value\n",
    "# The raw label column holds space-delimited label names; build one set per row.\n",
    "label_sets = [set(labels.split(' ')) for labels in data[label_col_name]]\n",
    "for target in target_columns:\n",
    "    # One float indicator column per target: 1.0 if the image has the label.\n",
    "    data[target] = [float(target in labels) for labels in label_sets]\n",
    "data = data.drop(columns=label_col_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_data = data\n",
    "test_data = data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create Model and Data Insights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from raiwidgets import ResponsibleAIDashboard\n",
    "from responsibleai_vision import ModelTask, RAIVisionInsights"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To use Responsible AI Toolbox, initialize a `RAIVisionInsights` object upon which different components can be loaded.\n",
    "\n",
    "`RAIVisionInsights` accepts the model, the test dataset, the list of target columns and the task type as its arguments."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rai_insights = RAIVisionInsights(model, test_data.sample(10, random_state=42),\n",
    "                                 target_columns,\n",
    "                                 ModelTask.MULTILABEL_IMAGE_CLASSIFICATION)\n",
    "\n",
    "rai_insights"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Add the components of the toolbox that are focused on model assessment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Interpretability\n",
    "rai_insights.explainer.add()\n",
    "# Error Analysis\n",
    "rai_insights.error_analysis.add()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once all the desired components have been loaded, compute insights on the test set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rai_insights.compute()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, visualize and explore the model insights. Use the resulting widget or follow the link to view this in a new tab."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ResponsibleAIDashboard(rai_insights)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "raivision",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "053dc3c540a9260e9464092636530e4f43e21211e2a5bed21a6f6ec172233695"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}