{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Custom Loss Function in TensorFlow 2\n", "\n", "> In this post, we will learn how to build custom loss functions, first as plain functions and then as classes. This is a summary of the lecture \"Custom Models, Layers and Loss functions with Tensorflow\" from DeepLearning.AI.\n", "\n", "- toc: true \n", "- badges: true\n", "- comments: true\n", "- author: Chanseok Kang\n", "- categories: [Python, Coursera, Tensorflow, DeepLearning.AI]\n", "- image: images/huber_loss_ex.png" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Packages" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import tensorflow as tf\n", "from tensorflow.keras.utils import plot_model\n", "from tensorflow.keras import backend as K\n", "\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import pandas as pd" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Part 1 - Huber Loss\n", "\n", "In this section, we'll walk through how to create custom loss functions. In particular, we'll code the [Huber Loss](https://en.wikipedia.org/wiki/Huber_loss) and use it to train the model." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Prepare the Data\n", "\n", "Our dummy dataset is just a pair of arrays `xs` and `ys` defined by the relationship $y = 2x - 1$. `xs` are the inputs while `ys` are the labels." 
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# inputs\n", "xs = np.array([-1.0, 0.0, 1.0, 2.0, 3.0, 4.0], dtype=float)\n", "# labels\n", "ys = np.array([-3.0, -1.0, 1.0, 3.0, 5.0, 7.0], dtype=float)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[Figure: scatter plot of ys against xs]" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.scatter(xs, ys);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Training the model\n", "\n", "Let's build a simple model and train using a built-in loss function like the `mean_squared_error`." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model = tf.keras.Sequential([\n", "    tf.keras.layers.Dense(1, input_shape=[1])\n", "])\n", "\n", "model.compile(optimizer='sgd', loss='mean_squared_error')\n", "model.fit(xs, ys, epochs=500, verbose=0)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[18.978544]], dtype=float32)" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y_mse = model.predict([10.0])\n", "y_mse" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[Figure: the training points with the MSE model's prediction at x = 10 in red]
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.scatter(xs, ys)\n", "plt.scatter(10.0, y_mse, c='r');" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Training with Custom Loss\n", "\n", "Now let's see how we can use a custom loss. We first define a function that accepts the ground truth labels (`y_true`) and model predictions (`y_pred`) as parameters. We then compute and return the loss value in the function definition.\n", "\n", "The definition of Huber Loss is like this:\n", "\n", "$$\n", "L_{\\delta}(a) = \n", "\\begin{cases}\n", " \\frac{1}{2} (y - f(x))^2 \\quad & \\text{ for } \\vert a \\vert \\le \\delta, \\\\\n", " \\delta (\\vert y - f(x) \\vert - \\frac{1}{2} \\delta) \\quad & \\text{ otherwise} \\\\\n", "\\end{cases}\n", "$$" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "def my_huber_loss(y_true, y_pred):\n", " threshold = 1.\n", " error = y_true - y_pred\n", " is_small_error = tf.abs(error) <= threshold\n", " small_error_loss = tf.square(error) / 2\n", " big_error_loss = threshold * (tf.abs(error) - threshold / 2)\n", " return tf.where(is_small_error, small_error_loss, big_error_loss)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Using the loss function is as simple as specifying the loss function in the `loss` argument of `model.compile()`." 
] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model = tf.keras.Sequential([\n", "    tf.keras.layers.Dense(units=1, input_shape=[1,])\n", "])\n", "\n", "model.compile(optimizer='sgd', loss=my_huber_loss)\n", "model.fit(xs, ys, epochs=500, verbose=0)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[18.722095]], dtype=float32)" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y_hl = model.predict([10.0])\n", "y_hl" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[Figure: the training points with the predictions at x = 10 labelled mse and huber_loss]
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.scatter(xs, ys);\n", "plt.scatter(10.0, y_mse, label='mse');\n", "plt.scatter(10.0, y_hl, label='huber_loss');\n", "plt.grid()\n", "plt.legend();" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Part 2 - Huber Loss Hyperparameter and Loss class\n", "\n", "In this section, we'll extend our previous Huber loss function and show how you can include hyperparameters in defining loss functions. We'll also look at how to implement a custom loss as an object by inheriting the [Loss](https://www.tensorflow.org/api_docs/python/tf/keras/losses/Loss) class." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Prepare the Data\n", "\n", "As before, this model will be trained on the `xs` and `ys` below where the relationship is $y = 2x-1$. Thus, later, when we test for `x=10`, whichever version of the model gets the closest answer to `19` will be deemed more accurate." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Custom loss with hyperparameter\n", "\n", "The `loss` argument in `model.compile()` only accepts functions that accepts two parameters: the ground truth (`y_true`) and the model predictions (`y_pred`). If we want to include a hyperparameter that we can tune, then we can define a wrapper function that accepts this hyperparameter." 
] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "# wrapper function that accepts the hyperparameter\n", "def my_huber_loss_with_threshold(threshold):\n", "    # function that accepts the ground truth and predictions\n", "    def my_huber_loss(y_true, y_pred):\n", "        error = y_true - y_pred\n", "        is_small_error = tf.abs(error) <= threshold\n", "        small_error_loss = tf.square(error) / 2\n", "        big_error_loss = threshold * (tf.abs(error) - (threshold / 2))\n", "        return tf.where(is_small_error, small_error_loss, big_error_loss)\n", "    # return the inner function tuned by the hyperparameter\n", "    return my_huber_loss" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can now specify the `loss` as the wrapper function above. Notice that we can now set the `threshold` value. Try varying this value and see the results you get." ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model = tf.keras.Sequential([\n", "    tf.keras.layers.Dense(1, input_shape=[1])\n", "])\n", "\n", "model.compile(optimizer='sgd', loss=my_huber_loss_with_threshold(threshold=1.2))\n", "model.fit(xs, ys, epochs=500, verbose=0)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[18.618975]], dtype=float32)" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y_hlt = model.predict([10.0])\n", "y_hlt" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[Figure: the training points with the predictions at x = 10 labelled mse, huber_loss and huber_loss_1.2]
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.scatter(xs, ys);\n", "plt.scatter(10.0, y_mse, label='mse');\n", "plt.scatter(10.0, y_hl, label='huber_loss');\n", "plt.scatter(10.0, y_hlt, label='huber_loss_1.2')\n", "plt.grid()\n", "plt.legend();" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Implement Custom Loss as a Class\n", "\n", "We can also implement our custom loss as a class. It inherits from the Keras Loss class and the syntax and required methods are shown below." ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "from tensorflow.keras.losses import Loss\n", "\n", "class MyHuberLoss(Loss):\n", " # initialize instance attributes\n", " def __init__(self, threshold=1):\n", " super(MyHuberLoss, self).__init__()\n", " self.threshold = threshold\n", " \n", " # Compute loss\n", " def call(self, y_true, y_pred):\n", " error = y_true - y_pred\n", " is_small_error = tf.abs(error) <= self.threshold\n", " small_error_loss = tf.square(error) / 2\n", " big_error_loss = self.threshold * (tf.abs(error) - self.threshold / 2)\n", " return tf.where(is_small_error, small_error_loss, big_error_loss)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can specify the loss by instantiating an object from your custom loss class." 
] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# single-neuron model, as in the earlier examples\n", "model = tf.keras.Sequential([\n", "    tf.keras.layers.Dense(1, input_shape=[1,])\n", "])\n", "\n", "# pass an instance of the class-based loss, with a custom threshold\n", "model.compile(optimizer='sgd', loss=MyHuberLoss(threshold=1.02))\n", "model.fit(xs, ys, epochs=500, verbose=0)" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[18.58202]], dtype=float32)" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# predict for x = 10; the exact relationship y = 2x - 1 gives 19\n", "y_hltc = model.predict([10.0])\n", "y_hltc" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.6" } }, "nbformat": 4, "nbformat_minor": 4 }