Page Menu Home c4science

losses.py
No One Temporary

File Metadata

Created
Mon, Jun 3, 02:38

losses.py

import torch
from torch import nn
class CrossEntropyLoss(nn.Module):
    """Cross-entropy loss with optional per-element weighting.

    Wraps ``nn.CrossEntropyLoss`` with ``reduction='none'`` so the unreduced
    loss can be multiplied by caller-supplied weights before being averaged.

    Args:
        class_weights: Per-class rescaling weights forwarded as
            ``nn.CrossEntropyLoss(weight=...)``. Anything accepted by
            ``torch.as_tensor`` (list, array, tensor) is converted to
            ``float32``; ``None`` disables class weighting.
        ignore_index: Target value whose positions contribute zero loss.
            Defaults to 255.
    """

    def __init__(self, class_weights, ignore_index=255):
        super().__init__()
        weight = (
            None
            if class_weights is None
            else torch.as_tensor(class_weights, dtype=torch.float32)
        )
        # Forward the caller's ignore_index instead of a hard-coded constant,
        # otherwise the constructor argument would have no effect.
        self.loss = nn.CrossEntropyLoss(
            weight=weight,
            ignore_index=ignore_index,
            reduction='none',
        )

    def forward(self, pred, target, weights=None):
        """Return the (optionally weighted) mean cross-entropy.

        Args:
            pred: Logits, passed unchanged to ``nn.CrossEntropyLoss``.
            target: Class-index targets, passed unchanged to
                ``nn.CrossEntropyLoss``.
            weights: Optional tensor multiplied element-wise with the
                unreduced loss; expected to match (or broadcast to) its shape.

        Returns:
            Scalar tensor holding the plain mean of the per-element loss.
        """
        loss = self.loss(pred, target)
        if weights is not None:
            # Out-of-place so the autograd output of self.loss is not mutated.
            loss = loss * weights
        # NOTE: plain mean over every element, so positions equal to
        # ignore_index (zero loss) still count in the denominator. This
        # differs from nn.CrossEntropyLoss(reduction='mean'), which divides
        # by the summed class weights of the non-ignored targets.
        return loss.mean()

Event Timeline