def validate()

in touch_charts/recon.py [0:0]


	def validate(self, data, writer):
		"""Evaluate the encoder on every per-class validation loader and log losses.

		Args:
			data: iterable of data loaders, one per object class. Each batch is
				indexed with the keys 'sim_touch', 'depth', 'ref', 'samples',
				'class' and 'empty'.
			writer: logger exposing ``add_scalars``; only used when
				``self.args.eval`` is false.

		Side effects:
			Puts ``self.encoder`` in eval mode (it is not switched back here),
			writes per-class and total losses to stdout, and stores the total
			in ``self.current_loss`` as a Python float.

		Raises:
			ValueError: if a loader yields no examples (its per-class mean
				would otherwise be a division by zero).
		"""
		import torch  # local import so the block does not rely on file-level imports

		total_loss = 0.0
		self.encoder.eval()

		# Gradients are never needed for validation. Without no_grad, every
		# accumulated loss would keep its autograd graph (and the activations
		# it references) alive until this method returns.
		with torch.no_grad():
			# loop through every class
			for v, valid_loader in enumerate(data):
				num_examples = 0
				class_loss = 0.0
				obj_class = None  # reset so a previous class's label is never reused

				# loop through every batch
				for batch in tqdm(valid_loader):

					# initialize data
					sim_touch = batch['sim_touch'].cuda()
					depth = batch['depth'].cuda()
					ref_frame = batch['ref']
					gt_points = batch['samples'].cuda()
					obj_class = batch['class'][0]
					batch_size = gt_points.shape[0]

					# inference (the predicted depth is not used for validation)
					_, pred_points = self.encoder(sim_touch, depth, ref_frame, empty=batch['empty'].cuda())

					# losses
					point_loss = self.args.loss_coeff * utils.point_loss(pred_points, gt_points)

					# Weight by batch size so the class loss is a per-example mean even
					# when the last batch is smaller. float() turns the (scalar) loss
					# into a plain Python number instead of holding a device tensor.
					num_examples += batch_size
					class_loss += float(point_loss) * batch_size

				if num_examples == 0:
					raise ValueError(f'validation loader {v} yielded no examples')

				# log
				class_loss = class_loss / num_examples
				message = f'Valid || Epoch: {self.epoch}, class: {obj_class}, loss: {class_loss:.5f}'
				message += f' ||  best_loss: {self.best_loss:.5f}'
				tqdm.write(message)
				# NOTE(review): the average is over len(self.classes), which presumes
				# `data` holds exactly one loader per entry of self.classes — confirm.
				total_loss += class_loss / float(len(self.classes))

		# log
		print('*******************************************************')
		print(f'Total validation loss: {total_loss}')
		print('*******************************************************')
		if not self.args.eval:
			writer.add_scalars('valid', {self.args.exp_id: total_loss}, self.epoch)
		self.current_loss = total_loss