in vision_charts/recon.py [0:0]
def train(self, data, writer):
    """Run one training pass over ``data`` and log the mean loss.

    For each batch this moves ``img_occ``, ``img_unocc`` and ``gt_points``
    to the GPU and predicts vertices with ``self.encoder``. It then computes
    ``utils.chamfer_distance`` against ``gt_points``, using the faces in
    ``self.adj_info['faces']`` and ``num=self.num_samples``. The mean of that
    distance is scaled by ``self.args.loss_coeff`` and one ``self.optimizer``
    step is taken. A progress line is printed per batch via ``tqdm.write``.

    After the pass, the average per-batch loss is written with
    ``writer.add_scalars('train_loss', {self.args.exp_id: mean}, self.epoch)``.
    If ``data`` yields no batches, there is no loss to average and nothing is
    logged.

    Args:
        data: Iterable of batches. Each batch is indexed with the keys
            ``'img_occ'``, ``'img_unocc'`` and ``'gt_points'``, whose values
            must support ``.cuda()``. The whole batch is also passed to the
            encoder. Presumably a torch ``DataLoader``; confirm against the
            caller.
        writer: Object exposing ``add_scalars(tag, dict, step)``. Presumably
            a TensorBoard ``SummaryWriter``; TODO confirm.

    Returns:
        None.
    """
    total_loss = 0.0
    iterations = 0
    # Put the encoder in training mode (assumes a torch nn.Module).
    self.encoder.train()
    for batch in tqdm(data):
        self.optimizer.zero_grad()

        # initialize data
        img_occ = batch['img_occ'].cuda()
        img_unocc = batch['img_unocc'].cuda()
        gt_points = batch['gt_points'].cuda()

        # inference
        verts = self.encoder(img_occ, img_unocc, batch)

        # losses
        loss = utils.chamfer_distance(
            verts, self.adj_info['faces'], gt_points, num=self.num_samples
        )
        loss = self.args.loss_coeff * loss.mean()

        # backprop
        loss.backward()
        self.optimizer.step()

        # log: read the scalar once. On a CUDA tensor every .item() call is
        # a device-to-host copy that forces a synchronisation.
        loss_value = loss.item()
        tqdm.write(
            f'Train || Epoch: {self.epoch}, loss: {loss_value:.2f}, '
            f'b_ptp: {self.best_loss:.2f}'
        )
        total_loss += loss_value
        iterations += 1

    if iterations == 0:
        # Empty loader: there is no loss to average, so skip the scalar log
        # instead of dividing by zero.
        return
    writer.add_scalars(
        'train_loss', {self.args.exp_id: total_loss / iterations}, self.epoch
    )