{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Jensen-Shannon Divergence & Cross-Entropy Loss" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import timm\n", "import torch\n", "import torch.nn.functional as F\n", "from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy\n", "from timm.loss import JsdCrossEntropy\n", "from timm.data.mixup import mixup_target\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's create an example of a model's `output` and our `labels`. Note that we have 3 output predictions but only 1 label." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "output = F.one_hot(torch.tensor([0,9,0])).float()\n", "labels=torch.tensor([0])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If we set label `smoothing` and `alpha` to 0, the loss reduces to the regular `cross_entropy` loss computed on only the first element of our output and labels." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "jsd = JsdCrossEntropy(smoothing=0,alpha=0)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(1.4612)" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "jsd(output,labels)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(1.4612)" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "base_loss = F.cross_entropy(output[0,None],labels[0,None])\n", "base_loss" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "jsd = JsdCrossEntropy(num_splits=1,smoothing=0,alpha=0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can also change the number of splits, which changes the size of each group. In `AugMix` this corresponds to the number of views of each image (the clean image plus its augmented mixtures)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "jsd = JsdCrossEntropy(num_splits=2,smoothing=0,alpha=0)\n", "output = F.one_hot(torch.tensor([0,9,1,0])).float()\n", "labels=torch.tensor([0,9])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor(1.4612), tensor(1.4612))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "jsd(output,labels),F.cross_entropy(output[[0,1]],labels)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "By default we have 1 label for every 3 predictions. This is a two-part loss that measures both cross-entropy and Jensen-Shannon divergence. The Jensen-Shannon divergence does not need a label; instead, it measures how much the 3 predictions differ from one another."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "jsd = JsdCrossEntropy(smoothing=0)\n", "output = F.one_hot(torch.tensor([0,0,0]),num_classes=10).float()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([-0.1000, 0.1000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "deltas = torch.cat((torch.zeros([2,10]),torch.tensor([[-1,1,0,0,0,0,0,0,0,0]])))*0.1\n", "deltas[2]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "deltas=(torch.arange(-10,11))[...,None,None]*deltas" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "losses = [jsd((output+delta),labels)-base_loss for delta in deltas]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The graph below shows how changing one of the model's outputs (predictions) within a group affects the Jensen-Shannon divergence." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.plot([ .1*i-1 for i in range(len(losses))],[loss for loss in losses])\n", "plt.ylabel('JS Divergence')\n", "plt.xlabel('Change in output')\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Converted 00_model_architectures.ipynb.\n", "Converted 01_training_scripts.ipynb.\n", "Converted 02_dataset.ipynb.\n", "Converted 03_loss.cross_entropy.ipynb.\n", "Converted 04_models.ipynb.\n", "Converted 05_loss.jsd_cross_entropy.ipynb.\n", "Converted index.ipynb.\n" ] } ], "source": [ "#hide\n", "from nbdev.export import notebook2script\n", "notebook2script()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.5" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": true } }, "nbformat": 4, "nbformat_minor": 5 }