cross entropy loss / focal loss implementation in pytorch

At the moment, the code is written for torch 1.4.
binary cross entropy loss
```python
import torch

## using pytorch 1.4
def logit_sanitation(val, min_val):
    # element-wise max(val, min_val): stack val against a tensor filled with
    # min_val along a new trailing dim, then reduce with torch.max
    unsqueezed_a = torch.unsqueeze(val, -1)
    limit = torch.ones_like(unsqueezed_a) * min_val
    a = torch.cat((unsqueezed_a, limit), -1)
    values, _ = torch.max(a, -1)
    return values


def manual_bce_loss(pred_tensor, gt_tensor, epsilon=1e-8):
    # clamp both probability terms away from zero before taking the log
    a = logit_sanitation(1 - pred_tensor, epsilon)
    b = logit_sanitation(pred_tensor, epsilon)
    loss = -((1 - gt_tensor) * torch.log(a) + gt_tensor * torch.log(b))
    return loss
```
Currently, torch 1.6 is out, and according to the pytorch docs, torch.max can now take two tensors and return their element-wise maximum. In 1.4 this was not yet supported, which is why I had to unsqueeze, concatenate and then apply torch.max in the snippet above. If you are using torch 1.6, you can refactor logit_sanitation to use the updated torch.max directly.
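A minimal sketch of how that refactor could look, assuming torch 1.6 or later (torch.full_like is only used here to build the comparison tensor):

```python
import torch

## using pytorch 1.6: torch.max takes two tensors and compares element-wise
def logit_sanitation(val, min_val):
    return torch.max(val, torch.full_like(val, min_val))
```

torch.clamp(val, min=min_val) would do the same job in a single call and is available in older versions as well.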
The above binary cross entropy calculation tries to avoid NaN occurrences caused by excessively small inputs to torch.log: the log of a value at (or extremely close to) zero returns -inf (or an extremely large negative number), and when that term is multiplied by the zero coefficient on the other side of the sum, the result is NaN. The epsilon value sets a lower limit on the original value before the log is taken.
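As a quick illustration, here is manual_bce_loss on made-up predictions that hit exactly 0 and 1 (the tensor values below are only for demonstration):

```python
import torch

pred = torch.tensor([0.0, 0.3, 1.0])   # includes the problematic endpoints
gt = torch.tensor([0.0, 1.0, 1.0])

# naive version: log(0) gives -inf, and 0 * -inf gives NaN
naive = -((1 - gt) * torch.log(1 - pred) + gt * torch.log(pred))
print(naive)                            # NaN where pred hits an endpoint

# clamped version from above: epsilon keeps every log argument positive
print(manual_bce_loss(pred, gt))        # all values finite
```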
focal loss
Focal loss is also used quite frequently, so here it is, reusing the logit_sanitation function defined above:
```python
def manual_focal_loss(pred_tensor, gt_tensor, gamma, epsilon=1e-8):
    a = logit_sanitation(1 - pred_tensor, epsilon)
    b = logit_sanitation(pred_tensor, epsilon)
    # predicted probability of the ground-truth class
    logit = (1 - gt_tensor) * a + gt_tensor * b
    # down-weight well-classified examples by (1 - p) ** gamma
    focal_loss = -(1 - logit) ** gamma * torch.log(logit)
    return focal_loss
```
With gamma set to 0, the expression above reduces to the binary cross entropy loss from the previous section.
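A small sanity check of that claim, using the two functions above (the sample tensors are made up, and gamma=2 is just the common default from the focal loss paper):

```python
import torch

pred = torch.tensor([0.1, 0.8, 0.99])
gt = torch.tensor([0.0, 1.0, 1.0])

# easy, well-classified examples contribute very little with gamma > 0
print(manual_focal_loss(pred, gt, gamma=2.0))

# with gamma = 0 the focal loss matches the manual BCE loss
print(torch.allclose(manual_focal_loss(pred, gt, gamma=0.0),
                     manual_bce_loss(pred, gt)))   # True
```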