{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Mapping loss functions to activations\n",
    "\n",
    "Some losses expect raw model outputs (logits or log-values) rather than final predictions. To get predictions from a model trained with such a loss, an activation has to be applied to its output. This notebook builds `loss_func2activ`, which maps a loss function to that activation (or to `noop` when none is needed)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from fastai import *\n",
    "from fastai.vision import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "PyTorch loss functions (a note in bold gives the activation to apply to the model output):\n",
    "\n",
    "- torch.nn.L1Loss\n",
    "- torch.nn.MSELoss\n",
    "- torch.nn.CrossEntropyLoss **needs to be softmax-ed**\n",
    "- torch.nn.CTCLoss\n",
    "- torch.nn.NLLLoss **needs to be exp-ed**\n",
    "- torch.nn.PoissonNLLLoss **needs to be exp-ed if `log_input` is True (the default)**\n",
    "- torch.nn.KLDivLoss **needs to be exp-ed**\n",
    "- torch.nn.BCELoss\n",
    "- torch.nn.BCEWithLogitsLoss **needs to be sigmoid-ed**\n",
    "- torch.nn.MarginRankingLoss\n",
    "- torch.nn.HingeEmbeddingLoss\n",
    "- torch.nn.MultiLabelMarginLoss\n",
    "- torch.nn.SmoothL1Loss\n",
    "- torch.nn.SoftMarginLoss\n",
    "- torch.nn.MultiLabelSoftMarginLoss\n",
    "- torch.nn.CosineEmbeddingLoss\n",
    "- torch.nn.MultiMarginLoss\n",
    "- torch.nn.TripletMarginLoss"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For the losses above that need an activation, the functional (`F`) equivalents are:\n",
    "\n",
    "- torch.nn.CrossEntropyLoss <-> F.cross_entropy\n",
    "- torch.nn.NLLLoss <-> F.nll_loss\n",
    "- torch.nn.PoissonNLLLoss <-> F.poisson_nll_loss\n",
    "- torch.nn.KLDivLoss <-> F.kl_div\n",
    "- torch.nn.BCEWithLogitsLoss <-> F.binary_cross_entropy_with_logits"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Custom fastai loss functions:\n",
    "- CrossEntropyFlat **needs to be softmax-ed**\n",
    "- MSELossFlat"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Loss classes whose output needs an activation applied:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class_need_activ = [nn.CrossEntropyLoss(), nn.NLLLoss(), nn.PoissonNLLLoss(), nn.KLDivLoss(), nn.BCEWithLogitsLoss()]\n",
    "class_need_activ += [CrossEntropyFlat()]\n",
    "class_names = [camel2snake(c.__class__.__name__) for c in class_need_activ]\n",
    "# One activation per entry of `class_need_activ`, in the same order.\n",
    "# `torch.sigmoid` replaces the deprecated `F.sigmoid`.\n",
    "activs = [partial(F.softmax, dim=1), torch.exp, torch.exp, torch.exp, torch.sigmoid, partial(F.softmax, dim=1)]\n",
    "assert len(activs) == len(class_need_activ), 'each loss needs exactly one activation'\n",
    "loss_func_name2activ = dict(zip(class_names, activs))\n",
    "loss_func_name2activ"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Register the functional (`F`) versions under their function names:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "F_need_activ = [F.cross_entropy, F.nll_loss, F.poisson_nll_loss, F.kl_div, F.binary_cross_entropy_with_logits]\n",
    "# `F_need_activ` lines up with the first entries of `activs`; `zip` stops at the shorter list.\n",
    "# `setdefault` keeps the existing entry when a function name equals a class key (e.g. 'nll_loss').\n",
    "for f,a in zip(F_need_activ, activs):\n",
    "    loss_func_name2activ.setdefault(f.__name__, a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "loss_func_name2activ"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Look up the activation for a loss given as a class instance, an `F` function, or a `partial` of one:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def loss_func2activ(loss_func):\n",
    "    \"\"\"Return the activation registered for `loss_func` in `loss_func_name2activ`.\n",
    "\n",
    "    `loss_func` may be a loss object (optionally wrapped in a mixup loss that keeps the\n",
    "    real criterion in `.crit`), a function such as `F.cross_entropy`, or a `partial`\n",
    "    of such a function. Returns `noop` when nothing is registered for it, or when it\n",
    "    is a Poisson NLL loss configured with `log_input=False`.\n",
    "    \"\"\"\n",
    "    cls_name = camel2snake(loss_func.__class__.__name__)\n",
    "    if cls_name == 'mix_up_loss':\n",
    "        # Look through the wrapper at the criterion it holds.\n",
    "        loss_func = loss_func.crit\n",
    "        cls_name = camel2snake(loss_func.__class__.__name__)\n",
    "    if cls_name in loss_func_name2activ:\n",
    "        if cls_name == 'poisson_nll_loss' and not getattr(loss_func, 'log_input', True): return noop\n",
    "        return loss_func_name2activ[cls_name]\n",
    "    # Not a registered class: treat it as a function, looking through a `partial` if present.\n",
    "    # `getattr` defaults avoid an AttributeError when an object exposes `.func` whose target\n",
    "    # has no `__name__`, or when it has no `.keywords`.\n",
    "    keywords = getattr(loss_func, 'keywords', None) or {}\n",
    "    loss_func = getattr(loss_func, 'func', loss_func)\n",
    "    func_name = getattr(loss_func, '__name__', '')\n",
    "    if func_name == 'poisson_nll_loss' and not keywords.get('log_input', True): return noop\n",
    "    return loss_func_name2activ.get(func_name, noop)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Examples:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "loss_func2activ(nn.CrossEntropyLoss())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "loss_func2activ(nn.NLLLoss())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "loss_func2activ(nn.KLDivLoss())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "loss_func2activ(nn.PoissonNLLLoss(log_input=False))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "loss_func2activ(nn.PoissonNLLLoss())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "loss_func2activ(nn.MSELoss())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "loss_func2activ(nn.BCEWithLogitsLoss())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "loss_func2activ(nn.BCELoss())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "loss_func2activ(F.cross_entropy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "loss_func2activ(partial(F.cross_entropy, reduction='mean'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "loss_func2activ(partial(F.poisson_nll_loss, log_input=False))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "loss_func2activ(F.poisson_nll_loss)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}