{ "cells": [ { "cell_type": "markdown", "id": "382bbe19", "metadata": {}, "source": [ "---\n", "layout: post\n", "title: \"Unbalanced Validation Data Losses\"\n", "description: \"How does having an unbalanced dataset affect your loss function?\"\n", "feature-img: \"assets/img/rainbow.jpg\"\n", "thumbnail: \"assets/img/anzac_hill_lights.jpg\"\n", "tags: [Computer Vision, Data Augmentation, Deep Learning, Python]\n", "---" ] }, { "cell_type": "markdown", "id": "3bb6889e", "metadata": {}, "source": [ "Most people would say you should use a validation set (and test set) that is representative of the real-world data. In many cases, I think this is a good idea. However, we have to be careful when doing this." ] }, { "cell_type": "markdown", "id": "a498bd27", "metadata": {}, "source": [ "Let's look at an example: imagine you're doing image classification for a disease that 1 in 100 members of the population has." ] }, { "cell_type": "markdown", "id": "7ef4af8d", "metadata": {}, "source": [ "So you create a train, val, and test set to train your model and evaluate your results. There are tradeoffs in how to make up the train set, but this post is going to focus on the validation set." ] }, { "cell_type": "markdown", "id": "f8f16b38", "metadata": {}, "source": [ "Let's start with the validation set being representative, as that is the commonly provided guidance. It also lets you know how well your model will do in the real world." ] }, { "cell_type": "code", "execution_count": 1, "id": "fa51b9a1", "metadata": {}, "outputs": [], "source": [ "from tensorflow.keras.losses import BinaryCrossentropy\n", "import tensorflow_addons as tfa\n", "import numpy as np" ] }, { "cell_type": "markdown", "id": "818aa0a2", "metadata": {}, "source": [ "## Representative Data" ] }, { "cell_type": "code", "execution_count": 2, "id": "9c1aeee1", "metadata": {}, "outputs": [], "source": [ "num_0s = 1000000\n", "num_1s = 10000" ] }, { "cell_type": "markdown", "id": "b80dc8cc", "metadata": {}, "source": [ "Let's use binary cross entropy, as that's the most common loss function in these cases."
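] }, { "cell_type": "markdown", "id": "1a2b3c4d", "metadata": {}, "source": [ "As a quick reminder, binary cross entropy averages the per-example log loss over every sample in the set, so each sample counts equally regardless of which class it belongs to. Below is a minimal sketch of that computation; the manual_bce helper and the toy batch are made up purely for illustration and are simply checked against the Keras implementation." ] }, { "cell_type": "code", "execution_count": null, "id": "2b3c4d5e", "metadata": {}, "outputs": [], "source": [ "# Minimal sketch of binary cross entropy, assuming probability inputs (not logits).\n", "# The toy y_true/y_pred values are made up purely for illustration.\n", "def manual_bce(y_true, y_pred, eps=1e-7):\n", "    y_pred = np.clip(y_pred, eps, 1 - eps)  # avoid log(0)\n", "    return -np.mean(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))\n", "\n", "toy_true = np.array([0.0, 0.0, 0.0, 1.0])\n", "toy_pred = np.array([0.1, 0.2, 0.05, 0.7])\n", "print(manual_bce(toy_true, toy_pred))\n", "print(BinaryCrossentropy()(toy_true, toy_pred).numpy())"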
] }, { "cell_type": "code", "execution_count": 3, "id": "7899250c", "metadata": {}, "outputs": [], "source": [ "bce = BinaryCrossentropy()" ] }, { "cell_type": "code", "execution_count": 4, "id": "370f34f3", "metadata": {}, "outputs": [], "source": [ "def calc_precision(tp, fp):\n", "    # eps keeps the division stable when there are no positive predictions\n", "    return (tp + np.finfo(np.float64).eps) / (tp + fp + np.finfo(np.float64).eps)\n", "\n", "def calc_recall(tp, fn):\n", "    return (tp + np.finfo(np.float64).eps) / (tp + fn + np.finfo(np.float64).eps)\n", "\n", "def calc_f1(pre, rec):\n", "    return 2 * (pre * rec) / (pre + rec)\n", "\n", "def get_pred_0s(num_0s, acc_0s):\n", "    # Simulate scores for the negatives: uniform values scaled so that a\n", "    # fraction acc_0s of them round to the correct label (0)\n", "    y_true_0s = np.zeros(num_0s)\n", "    max_rand_0s = (1 - 0.5) / acc_0s\n", "    y_pred_0s = np.random.uniform(low=0, high=max_rand_0s, size=num_0s)\n", "    y_label_pred_0s = np.round(y_pred_0s)\n", "    return y_true_0s, y_pred_0s, y_label_pred_0s\n", "\n", "def get_pred_1s(num_1s, acc_1s):\n", "    # Simulate scores for the positives: uniform values scaled so that a\n", "    # fraction acc_1s of them round to the correct label (1)\n", "    y_true_1s = np.ones(num_1s)\n", "    min_rand_1s = 1 - (1 - 0.5) / acc_1s\n", "    y_pred_1s = np.random.uniform(low=min_rand_1s, high=1, size=num_1s)\n", "    y_label_pred_1s = np.round(y_pred_1s)\n", "    return y_true_1s, y_pred_1s, y_label_pred_1s\n", "\n", "def calc_values(num_0s, acc_0s, num_1s, acc_1s, loss_func):\n", "    # Build a synthetic validation set and report the loss alongside\n", "    # precision, recall, and F1\n", "    y_true_0s, y_pred_0s, y_label_pred_0s = get_pred_0s(num_0s, acc_0s)\n", "    y_true_1s, y_pred_1s, y_label_pred_1s = get_pred_1s(num_1s, acc_1s)\n", "    y_true = np.concatenate((y_true_0s, y_true_1s))\n", "    y_pred = np.concatenate((y_pred_0s, y_pred_1s))\n", "    loss = loss_func(y_true.astype(np.float32), y_pred.astype(np.float32)).numpy()\n", "    tp = sum(y_label_pred_1s == y_true_1s)\n", "    fp = sum(y_label_pred_0s - y_true_0s == 1)\n", "    fn = sum(y_true_1s - y_label_pred_1s)\n", "    pre = calc_precision(tp, fp)\n", "    recall = calc_recall(tp, fn)\n", "    f1 = calc_f1(pre, recall)\n", "\n", "    return loss, pre, recall, f1\n", "\n", "def print_metrics(loss, pre, recall, f1):\n", "    print(\"Loss: \", loss)\n", "    print(\"Precision: \", pre)\n", "    print(\"Recall: \", recall)\n", "    print(\"F1 Score: \", f1)" ] }, { "cell_type": "code", "execution_count": 5, "id": "628f6f05", "metadata": {}, "outputs": [], "source": [ "acc_0s = 0.999\n", "acc_1s = 0.5" ] }, { "cell_type": "code", "execution_count": 6, "id": "4dec99b6", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loss: 0.313947\n", "Precision: 0.8396806918343589\n", "Recall: 0.5049\n", "F1 Score: 0.6306126272403673\n" ] } ], "source": [ "l, p, r, f = calc_values(num_0s, acc_0s, num_1s, acc_1s, bce)\n", "print_metrics(l, p, r, f)" ] }, { "cell_type": "code", "execution_count": 7, "id": "c08cc209", "metadata": {}, "outputs": [], "source": [ "acc_0s = 0.98\n", "acc_1s = 0.9" ] }, { "cell_type": "code", "execution_count": 8, "id": "b565cda9", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loss: 0.3151517\n", "Precision: 0.30972900010383136\n", "Recall: 0.8949\n", "F1 Score: 0.4601856375183195\n" ] } ], "source": [ "l, p, r, f = calc_values(num_0s, acc_0s, num_1s, acc_1s, bce)\n", "print_metrics(l, p, r, f)" ] }, { "cell_type": "markdown", "id": "0f48b2a6", "metadata": {}, "source": [ "Which model is better? For an initial screen, you would probably be better off with the high-recall model. But if you're just following the loss, the low-recall model comes out slightly ahead (0.314 vs. 0.315), so the loss would push you *away* from the model you want and towards the model with much lower recall. This isn't what you want at all. The cell below breaks the loss down by class to show why."
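] }, { "cell_type": "markdown", "id": "3c4d5e6f", "metadata": {}, "source": [ "Because binary cross entropy is just the average per-example loss, the one million negatives almost completely drown out the ten thousand positives under a representative split. The new cell below (added for illustration; it reuses the helper functions above) computes the loss on each class separately and then recombines them as a sample-weighted average, which makes the tiny weight on the positives explicit." ] }, { "cell_type": "code", "execution_count": null, "id": "4d5e6f7a", "metadata": {}, "outputs": [], "source": [ "# Sketch: decompose the overall BCE into per-class contributions.\n", "# With num_0s >> num_1s, the combined loss is dominated by the negatives,\n", "# so even a big change in performance on the positives barely moves it.\n", "y_true_0s, y_pred_0s, _ = get_pred_0s(num_0s, acc_0s)\n", "y_true_1s, y_pred_1s, _ = get_pred_1s(num_1s, acc_1s)\n", "loss_0s = bce(y_true_0s.astype(np.float32), y_pred_0s.astype(np.float32)).numpy()\n", "loss_1s = bce(y_true_1s.astype(np.float32), y_pred_1s.astype(np.float32)).numpy()\n", "combined = (num_0s * loss_0s + num_1s * loss_1s) / (num_0s + num_1s)\n", "print(\"Loss on the 0s: \", loss_0s)\n", "print(\"Loss on the 1s: \", loss_1s)\n", "print(\"Weight on the 1s: \", num_1s / (num_0s + num_1s))\n", "print(\"Sample-weighted total: \", combined)"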
] }, { "cell_type": "markdown", "id": "c9c5cb04", "metadata": {}, "source": [ "## Balanced Data" ] }, { "cell_type": "markdown", "id": "5f0fc11a", "metadata": {}, "source": [ "Now let's use a balanced split of the data." ] }, { "cell_type": "code", "execution_count": 9, "id": "02e930f7", "metadata": {}, "outputs": [], "source": [ "num_0s = 10000\n", "num_1s = 10000\n" ] }, { "cell_type": "code", "execution_count": 10, "id": "212bfbca", "metadata": {}, "outputs": [], "source": [ "acc_0s = 0.999\n", "acc_1s = 0.5" ] }, { "cell_type": "code", "execution_count": 11, "id": "68e1ed6e", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loss: 0.65621394\n", "Precision: 0.9976128903918838\n", "Recall: 0.5015\n", "F1 Score: 0.6674652292540094\n" ] } ], "source": [ "l, p, r, f = calc_values(num_0s, acc_0s, num_1s, acc_1s, bce)\n", "print_metrics(l, p, r, f)" ] }, { "cell_type": "code", "execution_count": 12, "id": "d2a45cf0", "metadata": {}, "outputs": [], "source": [ "acc_0s = 0.98\n", "acc_1s = 0.9" ] }, { "cell_type": "code", "execution_count": 13, "id": "4ba6d514", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loss: 0.33528897\n", "Precision: 0.9771986970684039\n", "Recall: 0.9\n", "F1 Score: 0.9370119729307653\n" ] } ], "source": [ "l, p, r, f = calc_values(num_0s, acc_0s, num_1s, acc_1s, bce)\n", "print_metrics(l, p, r, f)" ] }, { "cell_type": "markdown", "id": "71a957e2", "metadata": {}, "source": [ "Now the loss is much more in line with what you would want: the high-recall model has the clearly lower loss (0.335 vs. 0.656)." ] }, { "cell_type": "markdown", "id": "c9f37eea", "metadata": {}, "source": [ "## Focal Loss" ] }, { "cell_type": "markdown", "id": "304c6b0a", "metadata": {}, "source": [ "Focal loss, which was introduced in the [RetinaNet paper](https://arxiv.org/pdf/1708.02002.pdf), was specifically designed for situations with unbalanced data. Let's look at what it does and then try substituting it in."
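] }, { "cell_type": "markdown", "id": "5e6f7a8b", "metadata": {}, "source": [ "As a refresher, focal loss scales the usual cross entropy by a modulating factor of (1 - p_t)^gamma, optionally with a class weight alpha_t, where p_t is the predicted probability of the true class. That factor shrinks the contribution of easy, well-classified examples, which in our case are mostly the negatives. Below is a minimal NumPy sketch of the per-example loss using the gamma = 2, alpha = 0.25 settings from the paper; it's just for intuition, not a drop-in replacement for the tfa implementation." ] }, { "cell_type": "code", "execution_count": null, "id": "6f7a8b9c", "metadata": {}, "outputs": [], "source": [ "# Minimal sketch of per-example focal loss, assuming probability inputs.\n", "# For intuition only; not meant to exactly reproduce tfa's implementation.\n", "def manual_focal_loss(y_true, y_pred, gamma=2.0, alpha=0.25, eps=1e-7):\n", "    y_pred = np.clip(y_pred, eps, 1 - eps)\n", "    p_t = np.where(y_true == 1, y_pred, 1 - y_pred)    # prob of the true class\n", "    alpha_t = np.where(y_true == 1, alpha, 1 - alpha)  # class weighting\n", "    return -alpha_t * (1 - p_t) ** gamma * np.log(p_t)\n", "\n", "# A confident, correct negative contributes far less than a borderline positive\n", "print(manual_focal_loss(np.array([0.0]), np.array([0.05])))\n", "print(manual_focal_loss(np.array([1.0]), np.array([0.55])))"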
] }, { "cell_type": "code", "execution_count": 14, "id": "73df09f6", "metadata": {}, "outputs": [], "source": [ "sfce = tfa.losses.SigmoidFocalCrossEntropy()" ] }, { "cell_type": "code", "execution_count": 15, "id": "543643d8", "metadata": {}, "outputs": [], "source": [ "num_0s = 1000000\n", "num_1s = 10000" ] }, { "cell_type": "code", "execution_count": 16, "id": "4c70697f", "metadata": {}, "outputs": [], "source": [ "acc_0s = 0.999\n", "acc_1s = 0.5" ] }, { "cell_type": "code", "execution_count": 17, "id": "61b775ff", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loss: 31629.947\n", "Precision: 0.835135584761271\n", "Recall: 0.502\n", "F1 Score: 0.6270688901380302\n" ] } ], "source": [ "l, p, r, f = calc_values(num_0s, acc_0s, num_1s, acc_1s, sfce)\n", "print_metrics(l, p, r, f)" ] }, { "cell_type": "code", "execution_count": 18, "id": "c8f9e6ac", "metadata": {}, "outputs": [], "source": [ "acc_0s = 0.98\n", "acc_1s = 0.9" ] }, { "cell_type": "code", "execution_count": 19, "id": "32f56bdd", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loss: 32311.969\n", "Precision: 0.31202799140857757\n", "Recall: 0.9007\n", "F1 Score: 0.46348993979313535\n" ] } ], "source": [ "l, p, r, f = calc_values(num_0s, acc_0s, num_1s, acc_1s, sfce)\n", "print_metrics(l, p, r, f)" ] }, { "cell_type": "markdown", "id": "67bb4418", "metadata": {}, "source": [ "Result: using a focal loss might help to mitigate some of the differences, but in this simulation the low-recall model still ends up with the slightly lower loss, so the ranking problem from the representative split doesn't go away. (Note that the absolute values are on a very different scale than the binary cross entropy numbers, since with these inputs the focal loss is effectively summed over the examples rather than averaged; only the comparison between the two models is meaningful.)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.8" } }, "nbformat": 4, "nbformat_minor": 5 }