in seamseg/modules/losses.py [0:0]
def ohem_loss(loss, ohem=None):
if isinstance(loss, torch.Tensor):
loss = loss.view(loss.size(0), -1)
if ohem is None:
return loss.mean()
top_k = min(max(int(ohem * loss.size(1)), 1), loss.size(1))
if top_k != loss.size(1):
loss, _ = loss.topk(top_k, dim=1)
return loss.mean()
elif isinstance(loss, PackedSequence):
if ohem is None:
return sum(loss_i.mean() for loss_i in loss) / len(loss)
loss_out = loss.data.new_zeros(())
for loss_i in loss:
loss_i = loss_i.view(-1)
top_k = min(max(int(ohem * loss_i.numel()), 1), loss_i.numel())
if top_k != loss_i.numel():
loss_i, _ = loss_i.topk(top_k, dim=0)
loss_out += loss_i.mean()
return loss_out / len(loss)