def train()

in touch_charts/recon.py [0:0]


	def train(self, data, writer):
		"""Run one training epoch of ``self.encoder`` over ``data``.

		For each batch this moves ``sim_touch``, ``depth``, ``samples`` and
		``empty`` to the GPU with ``.cuda()``, runs the encoder, computes
		``self.args.loss_coeff * utils.point_loss(pred_points, gt_points)``,
		backpropagates and steps ``self.optimizer``. A progress line is
		printed per batch via ``tqdm.write``. After the loop, the mean loss
		over the epoch is logged with ``writer.add_scalars`` under the
		'train' tag, keyed by ``self.args.exp_id``, at step ``self.epoch``.

		Args:
			data: iterable of batch dicts with keys 'sim_touch', 'depth',
				'ref', 'samples' and 'empty' (presumably a DataLoader -- confirm).
			writer: object exposing ``add_scalars`` (presumably a TensorBoard
				``SummaryWriter`` -- confirm).
		"""
		total_loss = 0.
		iterations = 0
		self.encoder.train()
		for batch in tqdm(data):
			self.optimizer.zero_grad()

			# initialize data
			sim_touch = batch['sim_touch'].cuda()
			depth = batch['depth'].cuda()
			# 'ref' is passed through without .cuda(), exactly as before
			ref_frame = batch['ref']
			gt_points = batch['samples'].cuda()
			empty = batch['empty'].cuda()

			# inference; the depth prediction is not used by this loss
			_pred_depth, pred_points = self.encoder(sim_touch, depth, ref_frame, empty=empty)

			# losses
			loss = self.args.loss_coeff * utils.point_loss(pred_points, gt_points)
			# read the scalar once per batch and reuse it for both the running
			# total and the log line (avoids a second .item() host transfer)
			loss_value = loss.item()
			total_loss += loss_value

			# backprop
			loss.backward()
			self.optimizer.step()

			# log
			message = f'Train || Epoch: {self.epoch},  loss: {loss_value:.5f} '
			message += f'|| best_loss:  {self.best_loss :.5f}'
			tqdm.write(message)
			iterations += 1

		# an empty `data` would otherwise raise ZeroDivisionError here
		if iterations > 0:
			writer.add_scalars('train', {self.args.exp_id: total_loss / iterations}, self.epoch)