in train.py [0:0]
def display_val(model, loss_criterion, writer, index, dataset_val, opt):
    """Evaluate the model on a prefix of the validation set and report the mean loss.

    Runs ``model.forward`` on at most ``opt.validation_batches`` batches drawn
    from ``dataset_val`` with gradient tracking disabled. Each output is scored
    with ``loss_criterion(output['binaural_spectrogram'], output['audio_gt'])``,
    and the per-batch ``loss.item()`` values are averaged.

    Args:
        model: object whose ``forward(batch)`` returns a mapping containing the
            keys ``'binaural_spectrogram'`` and ``'audio_gt'``.
        loss_criterion: callable ``(prediction, target)`` returning an object
            with an ``.item()`` method (e.g. a scalar tensor).
        writer: logger with ``add_scalar(tag, value, step)``; only used when
            ``opt.tensorboard`` is truthy.
        index: step value passed to ``writer.add_scalar``.
        dataset_val: iterable of validation batches.
        opt: options object; reads ``opt.validation_batches`` (maximum number
            of batches to evaluate) and ``opt.tensorboard`` (whether to log
            the result to ``writer``).

    Returns:
        The mean validation loss as a float, or ``float('nan')`` when no batch
        was evaluated (empty ``dataset_val`` or non-positive
        ``opt.validation_batches``).
    """
    # NOTE(review): this function does not switch the model between train and
    # eval mode; callers presumably handle model.eval()/model.train() -- confirm.
    limit = opt.validation_batches
    losses = []
    if limit > 0:
        with torch.no_grad():
            for i, val_data in enumerate(dataset_val):
                output = model.forward(val_data)
                loss = loss_criterion(output['binaural_spectrogram'], output['audio_gt'])
                losses.append(loss.item())
                # Stop right after the last wanted batch so the loader is not
                # asked to produce one extra batch only to have it discarded.
                if i + 1 >= limit:
                    break

    if not losses:
        # Averaging an empty list would raise ZeroDivisionError and abort the
        # run; report NaN instead so the caller can decide what to do.
        print('val loss: nan (no validation batches evaluated)')
        return float('nan')

    avg_loss = sum(losses) / len(losses)
    if opt.tensorboard:
        writer.add_scalar('data/val_loss', avg_loss, index)
    print('val loss: %.3f' % avg_loss)
    return avg_loss