in touch_charts/recon.py [0:0]
def validate(self, data, writer):
    """Evaluate the encoder on every validation loader and log the result.

    For each loader in ``data`` the batch-size-weighted mean of the scaled
    point loss (``self.args.loss_coeff * utils.point_loss``) is computed and
    reported with ``tqdm.write``. The per-class means are then combined into
    a total loss, which is printed, sent to ``writer`` (unless
    ``self.args.eval`` is set) and stored on ``self.current_loss``.

    Args:
        data: iterable of data loaders, one per object class (presumably in
            the same order/size as ``self.classes`` -- TODO confirm).
        writer: object exposing ``add_scalars(tag, values, step)``; only used
            when ``self.args.eval`` is falsy.
    """
    import torch  # local import so the method does not depend on file-level imports

    total_loss = 0
    self.encoder.eval()
    # Validation needs no gradients. Without no_grad, each batch's autograd
    # graph stays referenced through ``class_loss`` until the class is done,
    # which needlessly grows GPU memory use.
    with torch.no_grad():
        # loop through every class
        for valid_loader in data:
            num_examples = 0
            class_loss = 0
            obj_class = None
            # loop through every batch
            for batch in tqdm(valid_loader):
                # initialize data
                sim_touch = batch['sim_touch'].cuda()
                depth = batch['depth'].cuda()
                ref_frame = batch['ref']
                gt_points = batch['samples'].cuda()
                obj_class = batch['class'][0]
                batch_size = gt_points.shape[0]
                # inference (the predicted depth output is not used here)
                _, pred_points = self.encoder(
                    sim_touch, depth, ref_frame, empty=batch['empty'].cuda()
                )
                # losses
                point_loss = self.args.loss_coeff * utils.point_loss(pred_points, gt_points)
                # accumulate a batch-size-weighted sum so the class mean is per example
                num_examples += float(batch_size)
                class_loss += point_loss * float(batch_size)
            if num_examples == 0:
                # An empty loader would otherwise raise ZeroDivisionError here
                # (and NameError on obj_class); report it and move on.
                tqdm.write(f'Valid || Epoch: {self.epoch}, skipping empty validation loader')
                continue
            # log
            class_loss = class_loss / num_examples
            message = f'Valid || Epoch: {self.epoch}, class: {obj_class}, loss: {class_loss:.5f}'
            message += f' || best_loss: {self.best_loss:.5f}'
            tqdm.write(message)
            # NOTE(review): averages over len(self.classes), not over the number
            # of loaders in ``data``; these are presumably equal -- confirm.
            total_loss += class_loss / float(len(self.classes))
    # log
    print('*******************************************************')
    print(f'Total validation loss: {total_loss}')
    print('*******************************************************')
    if not self.args.eval:
        writer.add_scalars('valid', {self.args.exp_id: total_loss}, self.epoch)
    self.current_loss = total_loss