In [None]:
from fastai import *
from fastai.vision import *

PyTorch loss functions (those marked in bold take raw model outputs, so the indicated activation must be applied to those outputs to get predictions):

- torch.nn.L1Loss
- torch.nn.MSELoss
- torch.nn.CrossEntropyLoss **outputs need a softmax**
- torch.nn.CTCLoss
- torch.nn.NLLLoss **outputs need an exp**
- torch.nn.PoissonNLLLoss **outputs need an exp if `log_input` is True (the default)**
- torch.nn.KLDivLoss **outputs need an exp**
- torch.nn.BCELoss
- torch.nn.BCEWithLogitsLoss **outputs need a sigmoid**
- torch.nn.MarginRankingLoss
- torch.nn.HingeEmbeddingLoss
- torch.nn.MultiLabelMarginLoss
- torch.nn.SmoothL1Loss
- torch.nn.SoftMarginLoss
- torch.nn.MultiLabelSoftMarginLoss
- torch.nn.CosineEmbeddingLoss
- torch.nn.MultiMarginLoss
- torch.nn.TripletMarginLoss

Functional (`F`) equivalents of the losses above that need an activation:

- torch.nn.CrossEntropyLoss <-> F.cross_entropy
- torch.nn.NLLLoss <-> F.nll_loss
- torch.nn.PoissonNLLLoss <-> F.poisson_nll_loss
- torch.nn.KLDivLoss <-> F.kl_div
- torch.nn.BCEWithLogitsLoss <-> F.binary_cross_entropy_with_logits

Custom fastai loss functions:

- CrossEntropyFlat **outputs need a softmax**
- MSELossFlat

Loss classes whose outputs need an activation applied:

In [None]:
# Loss modules that consume raw model outputs, paired BY POSITION with the
# activation that turns those outputs into predictions (see `activs` below).
class_need_activ = [nn.CrossEntropyLoss(), nn.NLLLoss(), nn.PoissonNLLLoss(), nn.KLDivLoss(), nn.BCEWithLogitsLoss()]
class_need_activ += [CrossEntropyFlat()]
# Keys are the snake_cased class names of the instances above.
class_names = [camel2snake(c.__class__.__name__) for c in class_need_activ]
# torch.sigmoid instead of the deprecated F.sigmoid (identical values, no deprecation warning).
activs = [partial(F.softmax, dim=1), torch.exp, torch.exp, torch.exp, torch.sigmoid, partial(F.softmax, dim=1)]
# zip() silently drops unmatched entries, so make a length mismatch fail loudly instead.
assert len(class_names) == len(activs), 'each loss in class_need_activ needs exactly one activation in activs'
loss_func_name2activ = dict(zip(class_names, activs))
loss_func_name2activ

In [None]:
# Functional counterparts of the losses above, listed in the same order as `activs`.
# zip() stops at the shorter list, so only the first len(F_need_activ) activations are used.
F_need_activ = [F.cross_entropy, F.nll_loss, F.poisson_nll_loss, F.kl_div, F.binary_cross_entropy_with_logits]
for f, a in zip(F_need_activ, activs):
    # Register under the function's __name__, never overwriting a key that already exists.
    loss_func_name2activ.setdefault(f.__name__, a)

In [None]:
# Show the mapping again, now also holding the keys added for the F.* function names.
loss_func_name2activ

In [None]:
def loss_func2activ(loss_func):
    "Return the activation that maps the raw outputs fed to `loss_func` to predictions, or `noop` if none is registered."
    cls_name = camel2snake(loss_func.__class__.__name__)
    # A MixUpLoss keeps the real criterion in `.crit`; resolve that one instead.
    if cls_name == 'mix_up_loss': return loss_func2activ(loss_func.crit)
    # nn.Module losses are looked up by their snake_cased class name.
    if cls_name in loss_func_name2activ:
        # With log_input=False the Poisson loss does not take log-values, so no exp is needed.
        if cls_name == 'poisson_nll_loss' and not getattr(loss_func, 'log_input', True): return noop
        return loss_func_name2activ[cls_name]
    # Wrappers such as functools.partial expose the wrapped callable as `.func`. That callable
    # may be a plain function or an nn.Module instance (which has no `__name__`), so every
    # attribute read is guarded and the wrapped object is resolved recursively.
    inner = getattr(loss_func, 'func', None)
    if inner is not None:
        keywords = getattr(loss_func, 'keywords', None) or {}
        if getattr(inner, '__name__', '') == 'poisson_nll_loss' and not keywords.get('log_input', True): return noop
        return loss_func2activ(inner)
    # Plain functions (e.g. F.cross_entropy) are looked up by their __name__.
    return loss_func_name2activ.get(getattr(loss_func, '__name__', ''), noop)

In [None]:
# nn.Module instance, found by its snake_cased class name -> partial(F.softmax, dim=1).
loss_func2activ(nn.CrossEntropyLoss())

In [None]:
# Found by class name -> torch.exp.
loss_func2activ(nn.NLLLoss())

In [None]:
# Found by class name -> torch.exp.
loss_func2activ(nn.KLDivLoss())

In [None]:
# log_input=False is read from the instance; expected to give noop.
loss_func2activ(nn.PoissonNLLLoss(log_input=False))

In [None]:
# Default log_input -> torch.exp.
loss_func2activ(nn.PoissonNLLLoss())

In [None]:
# MSELoss is not registered in loss_func_name2activ -> noop.
loss_func2activ(nn.MSELoss())

In [None]:
# Found by class name -> the sigmoid activation registered for BCEWithLogitsLoss.
loss_func2activ(nn.BCEWithLogitsLoss())

In [None]:
# BCELoss is not registered -> noop.
loss_func2activ(nn.BCELoss())

In [None]:
# Plain function, found by its __name__ 'cross_entropy' -> partial(F.softmax, dim=1).
loss_func2activ(F.cross_entropy)

In [None]:
# functools.partial: unwrapped via `.func`, then found by name -> partial(F.softmax, dim=1).
loss_func2activ(partial(F.cross_entropy, reduce=True))

In [None]:
# log_input=False is read from the partial's keywords -> noop.
loss_func2activ(partial(F.poisson_nll_loss, log_input=False))

In [None]:
# Plain function with default log_input -> torch.exp.
loss_func2activ(F.poisson_nll_loss)