At the moment, the code is written for torch 1.4.

binary cross entropy loss

## using pytorch 1.4

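import torch

def logit_sanitation(val, min_val):
    # Clamp every element of val from below at min_val.
    # Pair each element with min_val along a new last dimension...
    unsqueezed_a = torch.unsqueeze(val, -1)
    limit = torch.ones_like(unsqueezed_a) * min_val
    a = torch.cat((unsqueezed_a, limit), -1)
    # ...then keep the larger of the two.
    values, _ = torch.max(a, -1)

    return values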
	
	
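def manual_bce_loss(pred_tensor, gt_tensor, epsilon = 1e-8):
    # pred_tensor holds probabilities in [0, 1] (e.g. sigmoid outputs), not raw logits.
    # Clamp both p and 1 - p from below so torch.log never receives 0.
    a = logit_sanitation(1-pred_tensor, epsilon)
    b = logit_sanitation(pred_tensor, epsilon)

    # Element-wise loss; no reduction (mean/sum) is applied here.
    loss = - ( (1- gt_tensor) * torch.log(a) + gt_tensor * torch.log(b))

    return loss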

Currently, torch 1.6 is out, and according to the PyTorch docs, torch.max can take two tensors and return their element-wise maximum. However, in 1.4 this feature is not yet supported, which is why I had to unsqueeze, concatenate and then apply torch.max in the snippet above. If you are using torch 1.6, you can refactor the logit_sanitation function to use the updated torch.max.
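For reference, here is a minimal sketch of what that refactor could look like. The names logit_sanitation_max and logit_sanitation_clamp are hypothetical and only used for this illustration. The first relies on the two-tensor form of torch.max described in the docs. The second swaps in torch.clamp with only min set, which is a different call from the one used above but should give the same values.

```python
import torch

def logit_sanitation_max(val, min_val):
    # Two-tensor torch.max: element-wise maximum of val and a tensor filled with min_val.
    return torch.max(val, torch.full_like(val, min_val))

def logit_sanitation_clamp(val, min_val):
    # Alternative (my assumption that it is an acceptable swap):
    # clamp with only `min` set raises anything below min_val up to min_val.
    return torch.clamp(val, min=min_val)

x = torch.tensor([0.0, 1e-12, 0.3, 1.0])
out_max = logit_sanitation_max(x, 1e-8)
out_clamp = logit_sanitation_clamp(x, 1e-8)
print(out_max)
print(out_clamp)
```

Both variants avoid building the extra concatenated tensor, so they should also use a little less memory than the unsqueeze and cat version.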

The binary cross entropy calculation above avoids NaN values caused by predictions that are exactly 0 or extremely close to it. Note that these are probabilities, not logits. For such inputs torch.log returns a very large negative number, or -inf for exactly 0, and multiplying -inf by a zero weight in the loss gives NaN. The epsilon value sets a lower bound on whatever is passed to torch.log.
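To make the NaN case concrete, here is a small sketch comparing the loss with and without the lower bound. It is self-contained, so torch.clamp stands in for logit_sanitation here (that substitution is my assumption, not the code above), and the example tensors are made up.

```python
import torch

pred = torch.tensor([0.0, 1.0, 0.7])  # made-up predictions, two of them saturated
gt = torch.tensor([0.0, 1.0, 1.0])    # made-up ground truth
eps = 1e-8

# No lower bound: log(0) is -inf, and 0 * -inf evaluates to nan.
naive = -((1 - gt) * torch.log(1 - pred) + gt * torch.log(pred))

# With the lower bound applied before the log (clamp used as a stand-in).
a = torch.clamp(1 - pred, min=eps)
b = torch.clamp(pred, min=eps)
safe = -((1 - gt) * torch.log(a) + gt * torch.log(b))

print(naive)
print(safe)
```

The two saturated predictions are both correct, yet the unbounded version still produces NaN for them, because the term with zero weight is the one that contains log(0).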

focal loss

using the functions defined above,

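def manual_focal_loss(pred_tensor, gt_tensor, gamma, epsilon = 1e-8):

    # Clamp both p and 1 - p from below so torch.log never receives 0.
    a = logit_sanitation(1-pred_tensor, epsilon)
    b = logit_sanitation(pred_tensor, epsilon)

    # Despite the name, this is the probability assigned to the true class:
    # p where gt is 1 and 1 - p where gt is 0.
    logit = (1-gt_tensor) * a + gt_tensor * b
    # Well-classified examples (logit close to 1) are down-weighted by (1 - logit) ** gamma.
    focal_loss = - (1-logit) ** gamma * torch.log(logit)

    return focal_loss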

Focal loss is also used quite frequently, so here it is.
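As a quick sanity check, the sketch below shows two properties one would expect from the focal loss above: with gamma = 0 it reduces to plain binary cross entropy, and with gamma = 2 each element is the BCE value scaled by (1 - p_t) squared. The function focal_loss_sketch is a hypothetical compact version written for this illustration, with torch.clamp standing in for logit_sanitation, and the input values are made up.

```python
import torch
import torch.nn.functional as F

def focal_loss_sketch(pred, gt, gamma, eps=1e-8):
    # Hypothetical compact version; clamp is a stand-in for logit_sanitation.
    p_t = (1 - gt) * torch.clamp(1 - pred, min=eps) + gt * torch.clamp(pred, min=eps)
    return -((1 - p_t) ** gamma) * torch.log(p_t)

pred = torch.tensor([0.9, 0.6, 0.1])  # made-up predictions
gt = torch.tensor([1.0, 1.0, 1.0])    # all positives, so p_t equals pred

# Element-wise BCE from PyTorch itself, used as the reference.
bce = F.binary_cross_entropy(pred, gt, reduction="none")
focal_g0 = focal_loss_sketch(pred, gt, gamma=0.0)
focal_g2 = focal_loss_sketch(pred, gt, gamma=2.0)

# How much each example is scaled relative to BCE when gamma = 2.
ratio = focal_g2 / bce
print(focal_g0, bce)
print(ratio)
```

The confident, correct prediction (0.9) keeps only about 1% of its BCE loss, while the badly wrong one (0.1) keeps about 81%, which is the re-weighting the gamma term is there for.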

