At the moment, the code is written for torch 1.4.
Binary cross entropy loss
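```python
## using pytorch 1.4
import torch

def logit_sanitation(val, min_val):
    # Floor every element of val at min_val, i.e. element-wise max(val, min_val)
    unsqueezed_a = torch.unsqueeze(val, -1)
    # Tensor of the same shape, filled with the floor value
    limit = torch.ones_like(unsqueezed_a) * min_val
    # Pair each value with the floor along a new last dimension...
    a = torch.cat((unsqueezed_a, limit), -1)
    # ...and keep the larger of the two
    values, _ = torch.max(a, -1)
    return values

def manual_bce_loss(pred_tensor, gt_tensor, epsilon=1e-8):
    # Floor both log arguments at epsilon so torch.log never receives 0
    a = logit_sanitation(1 - pred_tensor, epsilon)
    b = logit_sanitation(pred_tensor, epsilon)
    # Element-wise BCE: -[(1 - y) * log(1 - p) + y * log(p)]
    loss = -((1 - gt_tensor) * torch.log(a) + gt_tensor * torch.log(b))
    return loss
```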
Torch 1.6 has since been released, and according to the PyTorch docs, the torch.max function can receive two tensors and return their element-wise maximum. However, in 1.4 this feature is not yet supported, which is why I had to unsqueeze, concatenate, and then apply torch.max along the last dimension in the snippet above. If you are using torch 1.6, you can refactor the logit_sanitation function to use the updated torch.max.
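For torch 1.6, a minimal sketch of that refactor might look like this. It assumes the two-tensor form torch.max(input, other) described in the docs and builds the floor tensor with torch.full_like. The second function, logit_sanitation_clamp, is a hypothetical alternative I am adding here, not the approach used above: torch.clamp(val, min=min_val) floors the values in the same way without building a second tensor.

```python
import torch

## for torch >= 1.6 (sketch): two-tensor, element-wise torch.max
def logit_sanitation(val, min_val):
    # Floor tensor with the same shape, dtype and device as val
    limit = torch.full_like(val, min_val)
    # Element-wise maximum of val and the floor
    return torch.max(val, limit)

# Hypothetical alternative (not from the snippet above): clamp from below
def logit_sanitation_clamp(val, min_val):
    return torch.clamp(val, min=min_val)

x = torch.tensor([0.0, 1e-12, 0.5, 1.0])
print(logit_sanitation(x, 1e-8))
print(logit_sanitation_clamp(x, 1e-8))
```

Both calls should replace the two tiny values with 1e-8 and leave 0.5 and 1.0 untouched.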
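The binary cross entropy calculation above avoids NaN values caused by predicted probabilities that are zero (or round to zero). torch.log(0) evaluates to -inf, and multiplying that -inf by a zero weight (the target when the target is 0, or 1 - target when the target is 1) gives NaN, which then propagates through the loss and its gradients. The epsilon value sets a floor on the values passed to torch.log, so every log term stays finite (about -18.4 for the default epsilon of 1e-8).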
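A small demonstration of the failure mode, assuming made-up saturated predictions. To keep it self-contained, torch.clamp stands in for logit_sanitation as the epsilon floor. For comparison it also calls PyTorch's built-in torch.nn.functional.binary_cross_entropy, which handles the same problem differently: per its docs it clamps its log outputs to be at least -100.

```python
import torch
import torch.nn.functional as F

pred = torch.tensor([0.0, 0.3, 1.0])  # predicted probabilities, two of them saturated
gt = torch.tensor([0.0, 1.0, 1.0])    # binary targets

# Naive BCE: torch.log(0.) is -inf, and 0 * -inf is nan
naive = -((1 - gt) * torch.log(1 - pred) + gt * torch.log(pred))

# Same formula with both log arguments floored at eps
# (torch.clamp stands in for logit_sanitation here)
eps = 1e-8
safe = -((1 - gt) * torch.log(torch.clamp(1 - pred, min=eps))
         + gt * torch.log(torch.clamp(pred, min=eps)))

# Built-in BCE, which clamps its log outputs at -100 instead
builtin = F.binary_cross_entropy(pred, gt, reduction='none')

print(naive)    # nan in positions 0 and 2
print(safe)     # finite everywhere
print(builtin)  # agrees with safe for these inputs
```

Only the middle element, which has no saturated probability, is finite in the naive version.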
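Focal loss is also used quite frequently, so here it is, implemented with the functions defined above:

```python
def manual_focal_loss(pred_tensor, gt_tensor, gamma, epsilon=1e-8):
    # Floored probabilities of the negative and positive class
    a = logit_sanitation(1 - pred_tensor, epsilon)
    b = logit_sanitation(pred_tensor, epsilon)
    # p_t: the (floored) probability assigned to the true class
    logit = (1 - gt_tensor) * a + gt_tensor * b
    # Focal loss: -(1 - p_t)^gamma * log(p_t)
    focal_loss = -(1 - logit) ** gamma * torch.log(logit)
    return focal_loss
```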
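The focal loss computes -(1 - p_t)^γ · log(p_t), where p_t = y·p + (1 - y)·(1 - p) is the probability given to the true class. Two properties follow and make a handy sanity check: with γ = 0 it reduces to binary cross entropy, and with γ > 0 each element's BCE is scaled by (1 - p_t)^γ, which down-weights well-classified examples. The sketch below checks both on made-up inputs. To stay self-contained, it re-expresses the two losses as the hypothetical helpers bce_sketch and focal_sketch, with torch.clamp standing in for logit_sanitation.

```python
import torch

def bce_sketch(pred, gt, eps=1e-8):
    # Same formula as manual_bce_loss, with torch.clamp as the epsilon floor
    a = torch.clamp(1 - pred, min=eps)
    b = torch.clamp(pred, min=eps)
    return -((1 - gt) * torch.log(a) + gt * torch.log(b))

def focal_sketch(pred, gt, gamma, eps=1e-8):
    # Same formula as manual_focal_loss: p_t, then -(1 - p_t)^gamma * log(p_t)
    p_t = (1 - gt) * torch.clamp(1 - pred, min=eps) + gt * torch.clamp(pred, min=eps)
    return -(1 - p_t) ** gamma * torch.log(p_t)

pred = torch.tensor([0.9, 0.6, 0.1])  # easy, medium and hard positives
gt = torch.tensor([1.0, 1.0, 1.0])

# gamma = 0: focal loss and BCE coincide
print(torch.allclose(focal_sketch(pred, gt, gamma=0.0), bce_sketch(pred, gt)))

# gamma = 2: per-element ratio to BCE is (1 - p_t) ** 2
ratio = focal_sketch(pred, gt, gamma=2.0) / bce_sketch(pred, gt)
print(ratio)
```

With these inputs the ratio should come out close to 0.01, 0.16 and 0.81, so the easy example (p = 0.9) contributes far less than the hard one (p = 0.1).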