def train()

in vision_charts/recon.py [0:0]


	def train(self, data, writer):
		"""Run one training epoch over `data` and log the mean loss.

		Args:
			data: iterable of batch dicts (e.g. a DataLoader). Each batch must
				contain 'img_occ', 'img_unocc', and 'gt_points' tensors.
			writer: TensorBoard-style summary writer; receives the epoch-mean
				training loss under the 'train_loss' tag keyed by exp_id.

		Side effects: updates `self.encoder` parameters via `self.optimizer`,
		prints a per-batch progress message via tqdm.
		"""
		total_loss = 0.0
		iterations = 0  # int batch counter (was a float; count is integral)
		self.encoder.train()
		for batch in tqdm(data):
			self.optimizer.zero_grad()

			# move batch tensors to GPU
			img_occ = batch['img_occ'].cuda()
			img_unocc = batch['img_unocc'].cuda()
			gt_points = batch['gt_points'].cuda()

			# forward pass: predict mesh vertices from the two image views
			verts = self.encoder(img_occ, img_unocc, batch)

			# chamfer distance between the predicted mesh surface and the
			# ground-truth point cloud, scaled by the configured coefficient
			loss = utils.chamfer_distance(verts, self.adj_info['faces'], gt_points, num=self.num_samples)
			loss = self.args.loss_coeff * loss.mean()

			# backprop
			loss.backward()
			self.optimizer.step()

			# log per-batch progress (b_ptp = best loss so far)
			message = f'Train || Epoch: {self.epoch}, loss: {loss.item():.2f}, b_ptp:  {self.best_loss:.2f}'
			tqdm.write(message)
			total_loss += loss.item()
			iterations += 1

		# Guard against an empty loader: dividing by `iterations` would raise
		# ZeroDivisionError if `data` yielded no batches.
		if iterations > 0:
			writer.add_scalars('train_loss', {self.args.exp_id: total_loss / iterations}, self.epoch)