{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "QquaYpB00hfs"
},
"source": [
"[Eden Library](http://edenlibrary.ai/)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "yp-OgP370hft"
},
"source": [
"## Instructions\n",
"To run any of Eden's notebooks, please check the guides on our [Wiki page](https://github.com/Eden-Library-AI/eden_library_notebooks/wiki). \n",
"There you will find instructions on how to deploy the notebooks on [your local system](https://github.com/Eden-Library-AI/eden_library_notebooks/wiki/Deploy-Notebooks-Locally), on [Google Colab](https://github.com/Eden-Library-AI/eden_library_notebooks/wiki/Deploy-Notebooks-on-GColab), or on [MyBinder](https://github.com/Eden-Library-AI/eden_library_notebooks/wiki/Deploy-Notebooks-on-MyBinder), as well as other useful links, troubleshooting tips, and more. \n",
"For this notebook you will need to download the **Cotton-100619-Healthy-zz-V1-20210225102300**, **Black nightsade-220519-Weed-zz-V1-20210225102034**, **Tomato-240519-Healthy-zz-V1-20210225103740** and **Velvet leaf-220519-Weed-zz-V1-20210225104123** datasets from [Eden Library](https://edenlibrary.ai/datasets), and you may want to use the **eden_pytorch_transfer_learning.yml** file to recreate a suitable conda environment.\n",
"\n",
"**Note:** If you find any issues while executing the notebook, don't hesitate to open an issue on GitHub. We will try to reply as soon as possible."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dwqUKc-L0hft"
},
"source": [
"## Background\n",
"\n",
"The Open Neural Network Exchange (ONNX) provides an open-source format for AI models, covering both deep learning and traditional machine learning. It defines an extensible computation graph model, as well as definitions of built-in operators and standard data types.\n",
"\n",
"ONNX is widely supported across frameworks, tools, and hardware. By enabling interoperability between frameworks and streamlining the path from research to production, it helps speed up innovation in the AI community.\n",
"\n",
"ONNX Runtime is a performance-focused inference engine for ONNX models that runs efficiently across multiple platforms (Windows, Linux, and macOS) and hardware (CPUs and GPUs). It has been shown to considerably improve inference performance for many models.\n",
"\n",
"For this tutorial, you will need to install ONNX and ONNX Runtime. You can get binary builds of both with `pip install onnx onnxruntime`. ONNX Runtime wheels are published for specific Python versions, so check that your Python version is supported by the release you install (this notebook was run with Python 3.8).\n",
"\n",
"In this notebook, we export our PyTorch model to the ONNX format and then use ONNX Runtime to run inference with it.\n",
"\n",
"We also cover a technique called **Transfer Learning**, which generally refers to a process where a machine learning model is trained on one problem and then reused in some way on a second, possibly related, problem (Bengio, 2012). In **deep learning**, this is typically done by training only some layers of a pre-trained network. The promise is that training will be more efficient and, in the best case, performance will be better than that of a model trained from scratch. In this example, we use the ResNet architecture and the PyTorch framework.\n",
"\n",
"Note that, in addition to ONNX, we also use the PyTorch framework to design and train our neural networks. This notebook extends the previous Eden notebooks:\n",
"1. https://github.com/Eden-Library-AI/eden_library_notebooks/blob/master/image_classification/weeds_identification-transfer_learning-1.ipynb\n",
"2. https://github.com/Eden-Library-AI/eden_library_notebooks/blob/master/image_classification/weeds_identification-transfer_learning-2.ipynb\n",
"3. https://github.com/Eden-Library-AI/eden_library_notebooks/blob/master/image_classification/weeds_identification-transfer_learning-3.ipynb\n",
"4. https://github.com/Eden-Library-AI/eden_library_notebooks/blob/master/image_classification/weeds_identification-transfer_learning-4.ipynb\n",
"5. https://github.com/Eden-Library-AI/eden_library_notebooks/blob/master/image_classification/weeds_identification-transfer_learning-5.ipynb"
]
},
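{
"cell_type": "markdown",
"metadata": {},
"source": [
"The export-and-inference workflow described above can be sketched as follows. This is a minimal example under stated assumptions, not the notebook's own code: the small `nn.Sequential` network is a hypothetical stand-in for the fine-tuned ResNet, and the file name `model_sketch.onnx` and the tensor names `input` and `output` are illustrative choices. The sketch exports the model with `torch.onnx.export`, validates the file with `onnx.checker.check_model`, and checks that ONNX Runtime reproduces the PyTorch output.\n",
"\n",
"```python\n",
"import numpy as np\n",
"import onnx\n",
"import onnxruntime\n",
"import torch\n",
"import torch.nn as nn\n",
"\n",
"# Hypothetical stand-in for the fine-tuned ResNet (4 output classes)\n",
"model = nn.Sequential(\n",
"    nn.Conv2d(3, 8, kernel_size=3),\n",
"    nn.AdaptiveAvgPool2d(1),\n",
"    nn.Flatten(),\n",
"    nn.Linear(8, 4),\n",
")\n",
"model.eval()  # export in inference mode\n",
"\n",
"# Dummy input with the shape the graph should expect: (batch, channels, height, width)\n",
"dummy_input = torch.randn(1, 3, 224, 224)\n",
"torch.onnx.export(\n",
"    model,\n",
"    dummy_input,\n",
"    \"model_sketch.onnx\",  # hypothetical output path\n",
"    input_names=[\"input\"],\n",
"    output_names=[\"output\"],\n",
"    # Mark dimension 0 as dynamic so that any batch size can be used at inference time\n",
"    dynamic_axes={\"input\": {0: \"batch\"}, \"output\": {0: \"batch\"}},\n",
")\n",
"\n",
"# Check that the exported file is a well-formed ONNX model\n",
"onnx.checker.check_model(onnx.load(\"model_sketch.onnx\"))\n",
"\n",
"# Run the same input with ONNX Runtime on the CPU and compare with PyTorch\n",
"session = onnxruntime.InferenceSession(\"model_sketch.onnx\", providers=[\"CPUExecutionProvider\"])\n",
"ort_output = session.run(None, {\"input\": dummy_input.numpy()})[0]\n",
"with torch.no_grad():\n",
"    torch_output = model(dummy_input).numpy()\n",
"np.testing.assert_allclose(torch_output, ort_output, rtol=1e-3, atol=1e-5)\n",
"print(\"ONNX Runtime output shape:\", ort_output.shape)\n",
"```\n",
"\n",
"Comparing the two outputs with a tolerance, rather than exact equality, allows for small floating-point differences between the two runtimes."
]
},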
{
"cell_type": "markdown",
"metadata": {
"id": "HEZ1kYtlIJIf"
},
"source": [
"### Importing Libraries"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"collapsed": true,
"id": "hhJqeMkfivXv",
"jupyter": {
"outputs_hidden": true
},
"outputId": "903c1b7e-6bac-4fdb-a3ba-7817a907f8b0",
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com\n",
"Collecting onnx\n",
" Downloading onnx-1.9.0-cp38-cp38-manylinux2010_x86_64.whl (12.2 MB)\n",
"\u001b[K |████████████████████████████████| 12.2 MB 3.2 MB/s eta 0:00:01\n",
"\u001b[?25hRequirement already satisfied: numpy>=1.16.6 in /home/air/anaconda3/envs/eden_pytorch_transfer/lib/python3.8/site-packages (from onnx) (1.20.2)\n",
"Requirement already satisfied: typing-extensions>=3.6.2.1 in /home/air/anaconda3/envs/eden_pytorch_transfer/lib/python3.8/site-packages (from onnx) (3.7.4.3)\n",
"Collecting protobuf\n",
" Downloading protobuf-3.17.3-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl (1.0 MB)\n",
"\u001b[K |████████████████████████████████| 1.0 MB 3.4 MB/s eta 0:00:01\n",
"\u001b[?25hRequirement already satisfied: six in /home/air/anaconda3/envs/eden_pytorch_transfer/lib/python3.8/site-packages (from onnx) (1.15.0)\n",
"Installing collected packages: protobuf, onnx\n",
"Successfully installed onnx-1.9.0 protobuf-3.17.3\n",
"Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com\n",
"Collecting onnxruntime\n",
" Downloading onnxruntime-1.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.5 MB)\n",
"\u001b[K |████████████████████████████████| 4.5 MB 2.7 MB/s eta 0:00:01\n",
"\u001b[?25hRequirement already satisfied: protobuf in /home/air/anaconda3/envs/eden_pytorch_transfer/lib/python3.8/site-packages (from onnxruntime) (3.17.3)\n",
"Collecting flatbuffers\n",
" Downloading flatbuffers-2.0-py2.py3-none-any.whl (26 kB)\n",
"Requirement already satisfied: numpy>=1.16.6 in /home/air/anaconda3/envs/eden_pytorch_transfer/lib/python3.8/site-packages (from onnxruntime) (1.20.2)\n",
"Requirement already satisfied: six>=1.9 in /home/air/anaconda3/envs/eden_pytorch_transfer/lib/python3.8/site-packages (from protobuf->onnxruntime) (1.15.0)\n",
"Installing collected packages: flatbuffers, onnxruntime\n",
"Successfully installed flatbuffers-2.0 onnxruntime-1.8.0\n"
]
}
],
"source": [
"# Install ONNX and ONNX Runtime if they are not already available (Google Colab does not include onnx by default)\n",
"!pip install onnx\n",
"!pip install onnxruntime"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Xvqku2sVIO0s"
},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"from torch.optim import lr_scheduler\n",
"import torchvision\n",
"from torchvision import datasets, models, transforms\n",
"\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import time\n",
"import os\n",
"import copy\n",
"import random\n",
"import shutil\n",
"\n",
"# onnx necessary packages\n",
"import torch.onnx\n",
"import onnx\n",
"import onnxruntime\n",
"\n",
"plt.ion() # interactive mode"
]
},
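{
"cell_type": "markdown",
"metadata": {},
"source": [
"Since ONNX Runtime wheels target specific Python versions, it can help to print the versions actually in use. The short sketch below is an optional check, not part of the original pipeline; `onnxruntime.get_available_providers()` lists the execution providers (for example `CPUExecutionProvider`) that the installed build can use.\n",
"\n",
"```python\n",
"import sys\n",
"\n",
"import onnx\n",
"import onnxruntime\n",
"import torch\n",
"\n",
"print(\"Python:\", sys.version.split()[0])\n",
"print(\"PyTorch:\", torch.__version__)\n",
"print(\"ONNX:\", onnx.__version__)\n",
"print(\"ONNX Runtime:\", onnxruntime.__version__)\n",
"# Execution providers available in this ONNX Runtime build\n",
"print(\"ONNX Runtime providers:\", onnxruntime.get_available_providers())\n",
"```"
]
},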
{
"cell_type": "markdown",
"metadata": {
"id": "-zO6c2eoAvt_"
},
"source": [
"### Folder structure\n",
"We are going to create a main data folder, `pytorch-onnx`, inside `DATA_PATH` that will contain the 4 datasets. We will also split the datasets into training and validation subsets."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "LxFPCQ91iatm"
},
"outputs": [],
"source": [
"# Change this path to correspond to your system. It needs to point to your eden-library-datasets folder\n",
"\n",
"DATA_PATH = '/home/air/Desktop/EDEN-REPO/eden_library_notebooks/eden-library-datasets'"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "hMp60uWfULoe",
"outputId": "991765e0-12af-47a0-ed0c-0c26f2ec34f0"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Moved DSC_0514.JPG to validation images\n",
"Moved DSC_0536.JPG to validation images\n",
"Moved DSC_0741.JPG to validation images\n",
"Moved DSC_0504.JPG to validation images\n",
"Moved DSC_0501.JPG to validation images\n",
"Moved DSC_0553.JPG to validation images\n",
"Moved DSC_0532.JPG to validation images\n",
"Moved DSC_0726.JPG to validation images\n",
"Moved DSC_0740.JPG to validation images\n",
"Moved DSC_0516.JPG to validation images\n",
"Moved DSC_0550.JPG to validation images\n",
"Moved DSC_0528.JPG to validation images\n",
"Moved DSC_0506.JPG to validation images\n",
"Moved DSC_0717.JPG to validation images\n",
"Moved DSC_0628.JPG to validation images\n",
"Moved DSC_0508.JPG to validation images\n",
"Moved DSC_0517.JPG to validation images\n",
"Moved DSC_0538.JPG to validation images\n",
"Moved DSC_0711.JPG to validation images\n",
"Moved DSC_0723.JPG to validation images\n",
"Moved DSC_0646.JPG to validation images\n",
"Moved DSC_0567.JPG to validation images\n",
"Moved DSC_0500.JPG to validation images\n",
"Moved DSC_0551.JPG to validation images\n",
"Moved DSC_0689.JPG to validation images\n",
"Moved DSC_0787.JPG to validation images\n",
"Moved DSC_0828.JPG to validation images\n",
"Moved DSC_0247.JPG to validation images\n",
"Moved DSC_0195.JPG to validation images\n",
"Moved DSC_0215.JPG to validation images\n",
"Moved DSC_0259.JPG to validation images\n",
"Moved DSC_0204.JPG to validation images\n",
"Moved DSC_0268.JPG to validation images\n",
"Moved DSC_0205.JPG to validation images\n",
"Moved DSC_0795.JPG to validation images\n",
"Moved DSC_0278.JPG to validation images\n",
"Moved DSC_0203.JPG to validation images\n",
"Moved DSC_0253.JPG to validation images\n",
"Moved DSC_0210.JPG to validation images\n",
"Moved DSC_0312.JPG to validation images\n",
"Moved DSC_0817.JPG to validation images\n",
"Moved DSC_0212.JPG to validation images\n",
"Moved DSC_0230.JPG to validation images\n",
"Moved DSC_0246.JPG to validation images\n",
"Moved DSC_0326.JPG to validation images\n",
"Moved DSC_0805.JPG to validation images\n",
"Moved DSC_0315.JPG to validation images\n",
"Moved DSC_0301.JPG to validation images\n",
"Moved DSC_0832.JPG to validation images\n",
"Moved DSC_0257.JPG to validation images\n",
"Moved DSC_0280.JPG to validation images\n",
"Moved DSC_0809.JPG to validation images\n",
"Moved DSC_0792.JPG to validation images\n",
"Moved DSC_0199.JPG to validation images\n",
"Moved DSC_0813.JPG to validation images\n",
"Moved DSC_0783.JPG to validation images\n",
"Moved DSC_0273.JPG to validation images\n",
"Moved DSC_0281.JPG to validation images\n",
"Moved DSC_0272.JPG to validation images\n",
"Moved DSC_0310.JPG to validation images\n",
"Moved DSC_0261.JPG to validation images\n",
"Moved DSC_0294.JPG to validation images\n",
"Moved DSC_0250.JPG to validation images\n",
"Moved DSC_0789.JPG to validation images\n",
"Moved DSC_0798.JPG to validation images\n",
"Moved DSC_0672.JPG to validation images\n",
"Moved DSC_0651.JPG to validation images\n",
"Moved DSC_0673.JPG to validation images\n",
"Moved DSC_0666.JPG to validation images\n",
"Moved DSC_0681.JPG to validation images\n",
"Moved DSC_0663.JPG to validation images\n",
"Moved DSC_0654.JPG to validation images\n",
"Moved DSC_0635.JPG to validation images\n",
"Moved DSC_0653.JPG to validation images\n",
"Moved DSC_0583.JPG to validation images\n",
"Moved DSC_0486.JPG to validation images\n",
"Moved DSC_0642.JPG to validation images\n",
"Moved DSC_0610 - Copy.JPG to validation images\n",
"Moved DSC_0607.JPG to validation images\n",
"Moved DSC_0493.JPG to validation images\n",
"Moved DSC_0737.JPG to validation images\n",
"Moved DSC_0638.JPG to validation images\n",
"Moved DSC_0498.JPG to validation images\n",
"Moved DSC_0705.JPG to validation images\n",
"Moved DSC_0644.JPG to validation images\n",
"Moved DSC_0562.JPG to validation images\n",
"Moved DSC_0612 - Copy.JPG to validation images\n",
"Moved DSC_0608.JPG to validation images\n",
"Moved DSC_0619.JPG to validation images\n",
"Moved DSC_0617.JPG to validation images\n",
"Moved DSC_0708.JPG to validation images\n",
"Moved DSC_0580.JPG to validation images\n",
"Moved DSC_0743.JPG to validation images\n",
"Moved DSC_0611.JPG to validation images\n",
"Moved DSC_0728.JPG to validation images\n",
"Moved DSC_0609.JPG to validation images\n",
"Moved DSC_0489.JPG to validation images\n",
"Moved DSC_0606 - Copy.JPG to validation images\n",
"Moved DSC_0602.JPG to validation images\n",
"Moved DSC_0579.JPG to validation images\n"
]
}
],
"source": [
"# WARNING: this cell MOVES the 4 datasets used here out of your eden-library-datasets directory\n",
"# and into a new folder created for this notebook. After it runs, your original\n",
"# eden-library-datasets folder will no longer contain these 4 datasets.\n",
"\n",
"# Change paths to suit your system\n",
"os.makedirs(DATA_PATH, exist_ok=True)\n",
"# Directory that will contain all of the data needed for training\n",
"notebook_dataset = os.path.join(DATA_PATH, \"pytorch-onnx\")\n",
"\n",
"# Create the train and val folders that will host the data\n",
"train_path = os.path.join(notebook_dataset, \"train\")\n",
"os.makedirs(train_path, exist_ok=True)\n",
"val_path = os.path.join(notebook_dataset, \"val\")\n",
"os.makedirs(val_path, exist_ok=True)\n",
"\n",
"# Names of the datasets we are going to use (one dataset per class)\n",
"classes = [\n",
"    \"Black nightsade-220519-Weed-zz-V1-20210225102034\",\n",
"    \"Tomato-240519-Healthy-zz-V1-20210225103740\",\n",
"    \"Cotton-100619-Healthy-zz-V1-20210225102300\",\n",
"    \"Velvet leaf-220519-Weed-zz-V1-20210225104123\",\n",
"]\n",
"num_classes = len(classes)  # we will need this later\n",
"for class_name in classes:\n",
"    # Path to the source folder of this class\n",
"    class_path = os.path.join(DATA_PATH, class_name)\n",
"\n",
"    # Create a subfolder for this class in the validation folder\n",
"    class_val_path = os.path.join(val_path, class_name)\n",
"    os.makedirs(class_val_path, exist_ok=True)\n",
"    # Move the original folder into the train folder created above\n",
"    class_train_path = os.path.join(train_path, class_name)\n",
"    shutil.move(class_path, train_path)\n",
"\n",
"    # List of all files of this class\n",
"    images = os.listdir(class_train_path)\n",
"\n",
"    # Randomly choose 20% of the files for validation\n",
"    # (change 0.2 to whatever validation fraction you want)\n",
"    valid_images = random.sample(images, int(round(len(images) * 0.2)))\n",
"    # Move the validation images to the validation folder\n",
"    for val_image in valid_images:\n",
"        shutil.move(\n",
"            os.path.join(class_train_path, val_image),\n",
"            os.path.join(class_val_path, val_image),\n",
"        )\n",
"        print(\"Moved\", val_image, \"to validation images\")"
]
},
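{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cell above uses `random.sample` without a seed, so every run produces a different train/validation split. If you want a reproducible split, one option is to sort the file names and sample with a seeded generator. The helper below is a hypothetical sketch (the name `split_files` and the seed `42` are illustrative assumptions) that returns the split as lists instead of moving files, so it can be tried without touching the disk.\n",
"\n",
"```python\n",
"import random\n",
"\n",
"\n",
"def split_files(file_names, val_fraction=0.2, seed=42):\n",
"    \"\"\"Return (train, val) lists of file names using a reproducible random split.\"\"\"\n",
"    rng = random.Random(seed)  # local generator, leaves the global random state untouched\n",
"    files = sorted(file_names)  # sorting removes any dependence on os.listdir order\n",
"    n_val = int(round(len(files) * val_fraction))\n",
"    val = set(rng.sample(files, n_val))\n",
"    train = [f for f in files if f not in val]\n",
"    return train, sorted(val)\n",
"\n",
"\n",
"train_files, val_files = split_files([\"DSC_{:04d}.JPG\".format(i) for i in range(10)])\n",
"print(len(train_files), len(val_files))  # 8 2\n",
"```"
]
},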
{
"cell_type": "markdown",
"metadata": {
"id": "TSz6HE1Fvj_2"
},
"source": [
"### Auxiliary functions"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "PAE-Uq355fpL"
},
"outputs": [],
"source": [
"def train_model(model, criterion, optimizer, scheduler, num_epochs=50):\n",
"    \"\"\"Train the input model based on the parameters given.\n",
"\n",
"    Input:\n",
"        model: model to train\n",
"        criterion: loss function to be used for training\n",
"        optimizer: optimizer\n",
"        scheduler: learning rate scheduler\n",
"        num_epochs: number of training epochs\n",
"\n",
"    Uses the global `dataloaders`, `dataset_sizes` and `device` defined further below.\n",
"\n",
"    Returns: the trained model, loaded with the weights that achieved the best validation accuracy\n",
"    \"\"\"\n",
"    since = time.time()\n",
"    best_model_wts = copy.deepcopy(model.state_dict())\n",
"    best_acc = 0.0\n",
"\n",
"    for epoch in range(num_epochs):\n",
"        print(\"Epoch {}/{}\".format(epoch, num_epochs - 1))\n",
"        print(\"-\" * 10)\n",
"\n",
"        # Each epoch has a training and a validation phase\n",
"        for phase in [\"train\", \"val\"]:\n",
"            if phase == \"train\":\n",
"                model.train()  # Set model to training mode\n",
"            else:\n",
"                model.eval()  # Set model to evaluation mode\n",
"\n",
"            # Reset the running statistics for this phase\n",
"            running_loss = 0.0\n",
"            running_corrects = 0\n",
"\n",
"            # Iterate over data\n",
"            for inputs, labels in dataloaders[phase]:\n",
"                inputs = inputs.to(device)\n",
"                labels = labels.to(device)\n",
"\n",
"                # Zero the parameter gradients\n",
"                optimizer.zero_grad()\n",
"\n",
"                # Forward pass; track gradient history only in the training phase\n",
"                with torch.set_grad_enabled(phase == \"train\"):\n",
"                    outputs = model(inputs)\n",
"                    _, preds = torch.max(outputs, 1)\n",
"                    loss = criterion(outputs, labels)\n",
"\n",
"                    # Backward pass + optimize only in the training phase\n",
"                    if phase == \"train\":\n",
"                        loss.backward()\n",
"                        optimizer.step()\n",
"\n",
"                # Statistics\n",
"                running_loss += loss.item() * inputs.size(0)\n",
"                running_corrects += torch.sum(preds == labels.data)\n",
"            if phase == \"train\":\n",
"                scheduler.step()\n",
"\n",
"            epoch_loss = running_loss / dataset_sizes[phase]\n",
"            epoch_acc = running_corrects.double() / dataset_sizes[phase]\n",
"\n",
"            print(\"{} Loss: {:.4f} Acc: {:.4f}\".format(phase, epoch_loss, epoch_acc))\n",
"\n",
"            # Deep copy the model weights if this is the best validation accuracy so far\n",
"            if phase == \"val\" and epoch_acc > best_acc:\n",
"                best_acc = epoch_acc\n",
"                best_model_wts = copy.deepcopy(model.state_dict())\n",
"\n",
"        print()\n",
"\n",
"    time_elapsed = time.time() - since\n",
"    print(\n",
"        \"Training complete in {:.0f}m {:.0f}s\".format(\n",
"            time_elapsed // 60, time_elapsed % 60\n",
"        )\n",
"    )\n",
"    print(\"Best val Acc: {:.4f}\".format(best_acc))\n",
"\n",
"    # Load best model weights\n",
"    model.load_state_dict(best_model_wts)\n",
"    return model"
]
},
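{
"cell_type": "markdown",
"metadata": {},
"source": [
"The training function expects a model, a loss, an optimizer and a learning-rate scheduler. A minimal sketch of how these could be set up for transfer learning with ResNet is shown below; the choice of ResNet-18, SGD with `lr=0.001` and `momentum=0.9`, and `StepLR(step_size=7, gamma=0.1)` are illustrative assumptions, not necessarily the values used later in this notebook. The backbone is frozen and only a new final layer, sized for our classes, is trained.\n",
"\n",
"```python\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"from torch.optim import lr_scheduler\n",
"from torchvision import models\n",
"\n",
"num_classes = 4  # one class per Eden dataset used in this notebook\n",
"\n",
"# Assumption: ResNet-18. To start from ImageNet weights, pass pretrained=True (older torchvision)\n",
"# or weights=models.ResNet18_Weights.DEFAULT (torchvision 0.13+); omitted here to avoid a download.\n",
"model = models.resnet18()\n",
"\n",
"# Freeze the backbone so that its weights are not updated\n",
"for param in model.parameters():\n",
"    param.requires_grad = False\n",
"\n",
"# Replace the final fully connected layer; new layers have requires_grad=True by default\n",
"model.fc = nn.Linear(model.fc.in_features, num_classes)\n",
"\n",
"criterion = nn.CrossEntropyLoss()\n",
"# Only the parameters of the new layer are passed to the optimizer\n",
"optimizer = optim.SGD(model.fc.parameters(), lr=0.001, momentum=0.9)\n",
"# Multiply the learning rate by 0.1 every 7 epochs\n",
"scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)\n",
"\n",
"# With the data loaders and device defined, training would then be:\n",
"# model = train_model(model.to(device), criterion, optimizer, scheduler, num_epochs=25)\n",
"```"
]
},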
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "fiSNg6MJarNj"
},
"outputs": [],
"source": [
"def visualize_predictions(model, num_images=6):\n",
"    \"\"\"Run inference on a number of validation images and plot them with the model's predictions.\n",
"\n",
"    Input:\n",
"        model: model to run inference with\n",
"        num_images: number of validation set images to make predictions on\n",
"            (should be even, since the images are laid out in two columns)\n",
"\n",
"    Returns: nothing; the images are plotted with their predicted and actual classes\n",
"    \"\"\"\n",
"    was_training = model.training\n",
"    model.eval()\n",
"    images_so_far = 0\n",
"    plt.figure()\n",
"\n",
"    with torch.no_grad():\n",
"        for inputs, labels in dataloaders[\"val\"]:\n",
"            inputs = inputs.to(device)\n",
"            labels = labels.to(device)\n",
"\n",
"            outputs = model(inputs)\n",
"            _, preds = torch.max(outputs, 1)\n",
"\n",
"            for j in range(inputs.size(0)):\n",
"                images_so_far += 1\n",
"                ax = plt.subplot(num_images // 2, 2, images_so_far)\n",
"                ax.axis(\"off\")\n",
"                ax.set_title(\n",
"                    \"Predicted: {}, actual class: {}\".format(\n",
"                        class_names[preds[j]], class_names[labels[j]]\n",
"                    )\n",
"                )\n",
"                imshow(inputs.cpu().data[j])\n",
"\n",
"                if images_so_far == num_images:\n",
"                    model.train(mode=was_training)\n",
"                    return\n",
"        model.train(mode=was_training)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "RmhCuICPg9os"
},
"outputs": [],
"source": [
"def imshow(inp, title=None):\n",
"    \"\"\"Plot an image tensor (C, H, W) after undoing the ImageNet normalization.\"\"\"\n",
"    inp = inp.numpy().transpose((1, 2, 0))\n",
"    mean = np.array([0.485, 0.456, 0.406])\n",
"    std = np.array([0.229, 0.224, 0.225])\n",
"    inp = std * inp + mean\n",
"    inp = np.clip(inp, 0, 1)\n",
"    plt.imshow(inp)\n",
"    if title is not None:\n",
"        plt.title(title)\n",
"    plt.pause(0.001)  # pause a bit so that plots are updated"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qb7d5siZIpuw"
},
"source": [
"### Data loading and augmentation\n",
"\n",
"First, we need to load our data into the pipeline. Since our dataset is not very big, we apply some data augmentation to improve the generalization of the network. Lastly, we normalize the training and validation data with the ImageNet statistics expected by the pre-trained ResNet, for better performance and accuracy.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "E_iZOQTFJOp6"
},
"outputs": [],
"source": [
"# Defining some augmentation techniques\n",
"data_transforms = {\n",
"    # transforms.Compose chains together multiple transformations\n",
"    \"train\": transforms.Compose(\n",
"        [\n",
"            transforms.RandomResizedCrop((224, 224)),\n",
"            transforms.RandomHorizontalFlip(),\n",
"            # Convert images to tensors, since PyTorch needs input in tensor form\n",
"            transforms.ToTensor(),\n",
"            # Normalize with the ImageNet mean and standard deviation expected by\n",
"            # torchvision's pre-trained ResNet models\n",
"            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
"        ]\n",
"    ),\n",
"    \"val\": transforms.Compose(\n",
"        [\n",
"            transforms.Resize((224, 224)),\n",
"            # Has no effect after resizing to exactly 224x224\n",
"            transforms.CenterCrop(224),\n",
"            # Convert images to tensors, since PyTorch needs input in tensor form\n",
"            transforms.ToTensor(),\n",
"            # Normalize with the ImageNet mean and standard deviation expected by\n",
"            # torchvision's pre-trained ResNet models\n",
"            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
"        ]\n",
"    ),\n",
"}"
]
},
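{
"cell_type": "markdown",
"metadata": {},
"source": [
"`transforms.Normalize` maps each channel value $x$ to $(x - \\mu)/\\sigma$ using the per-channel mean $\\mu$ and standard deviation $\\sigma$ listed above, and the `imshow` helper inverts this with $\\sigma x + \\mu$ before plotting. The NumPy sketch below, using a small random image as hypothetical data, checks that the two operations cancel out.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# ImageNet channel statistics used in data_transforms and imshow\n",
"mean = np.array([0.485, 0.456, 0.406])\n",
"std = np.array([0.229, 0.224, 0.225])\n",
"\n",
"# Hypothetical 2x2 RGB image with values in [0, 1], laid out (H, W, C) as in imshow after\n",
"# its transpose, so the (3,) statistics broadcast over the last (channel) axis\n",
"image = np.random.default_rng(0).random((2, 2, 3))\n",
"\n",
"# Per-channel normalization, as done by transforms.Normalize\n",
"normalized = (image - mean) / std\n",
"# Inverse transform, as done by imshow before plotting\n",
"restored = std * normalized + mean\n",
"\n",
"assert np.allclose(restored, image)\n",
"print(normalized.shape)  # (2, 2, 3)\n",
"```"
]
},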
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "p9d715AFOdRQ",
"outputId": "3856a9ab-bcfa-48b0-c75a-3129955e88c2"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Class names : ['Black nightsade-220519-Weed-zz-V1-20210225102034', 'Cotton-100619-Healthy-zz-V1-20210225102300', 'Tomato-240519-Healthy-zz-V1-20210225103740', 'Velvet leaf-220519-Weed-zz-V1-20210225104123']\n",
"Dataset_sizes : {'train': 399, 'val': 100}\n"
]
}
],
"source": [
"# Loading the datasets\n",
"data_dir = notebook_dataset\n",
"image_datasets = {\n",
"    x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])\n",
"    for x in [\"train\", \"val\"]\n",
"}\n",
"\n",
"dataloaders = {\n",
"    x: torch.utils.data.DataLoader(\n",
"        image_datasets[x], batch_size=4, shuffle=True, num_workers=1\n",
"    )\n",
"    for x in [\"train\", \"val\"]\n",
"}\n",
"\n",
"dataset_sizes = {x: len(image_datasets[x]) for x in [\"train\", \"val\"]}\n",
"class_names = image_datasets[\"train\"].classes\n",
"print(\"Class names :\", class_names)\n",
"print(\"Dataset_sizes : \", dataset_sizes)\n",
"# Set up the device: CUDA GPU if available, otherwise CPU\n",
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"
]
},
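{
"cell_type": "markdown",
"metadata": {},
"source": [
"`ImageFolder` assigns label indices following the sorted order of the class folder names, which is why `class_names` above is not in the same order as the `classes` list used to create the folders. The same order is needed to read the predictions of the exported model, since ONNX Runtime returns raw scores (logits) as NumPy arrays. The sketch below uses hypothetical logits for two images and converts them to class names and probabilities with a softmax.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# Order used when creating the folders in the folder-structuring cell\n",
"classes = [\n",
"    \"Black nightsade-220519-Weed-zz-V1-20210225102034\",\n",
"    \"Tomato-240519-Healthy-zz-V1-20210225103740\",\n",
"    \"Cotton-100619-Healthy-zz-V1-20210225102300\",\n",
"    \"Velvet leaf-220519-Weed-zz-V1-20210225104123\",\n",
"]\n",
"# ImageFolder sorts the class folders, so label i corresponds to sorted(classes)[i]\n",
"class_names = sorted(classes)\n",
"\n",
"# Hypothetical logits for a batch of 2 images, e.g. the first output of session.run(...)\n",
"logits = np.array([[2.0, 0.5, 0.1, -1.0], [0.0, 0.2, 3.0, 0.1]])\n",
"\n",
"# Numerically stable softmax over the class axis\n",
"exp = np.exp(logits - logits.max(axis=1, keepdims=True))\n",
"probs = exp / exp.sum(axis=1, keepdims=True)\n",
"\n",
"for row in probs:\n",
"    idx = int(row.argmax())\n",
"    print(class_names[idx], round(float(row[idx]), 3))\n",
"```"
]
},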
{
"cell_type": "markdown",
"metadata": {
"id": "5S0_-SbgCPx7"
},
"source": [
"#### Visualizing some of the augmented training data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 166
},
"id": "yeFgaKIkyVrG",
"outputId": "b27464de-9c5b-4710-cfd8-c96d0e569f47"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Images' size: torch.Size([4, 3, 224, 224])\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"