in touch_charts/recon.py [0:0]
def train(self, data, writer):
    """Run one training pass over ``data`` and log the mean point loss.

    For every batch the encoder is run on the touch/depth inputs, the
    predicted points are scored against the ground-truth samples with
    ``utils.point_loss`` (scaled by ``self.args.loss_coeff``), and one
    optimizer step is taken. Once all batches are consumed, the mean of
    the per-batch loss values is written to ``writer`` under the
    ``'train'`` tag, keyed by ``self.args.exp_id`` at step ``self.epoch``.

    Args:
        data: Iterable of batch dicts providing the keys ``'sim_touch'``,
            ``'depth'``, ``'ref'``, ``'samples'`` and ``'empty'``
            (presumably a DataLoader -- confirm against the caller).
        writer: Object exposing ``add_scalars(tag, value_dict, step)``
            (presumably a TensorBoard ``SummaryWriter`` -- confirm).

    Returns:
        None. If ``data`` yields no batches, nothing is written to
        ``writer`` and a notice is printed instead.
    """
    total_loss = 0.0
    iterations = 0
    self.encoder.train()
    for batch in tqdm(data):
        self.optimizer.zero_grad()

        # initialize data ('ref' is passed to the encoder as-is, without .cuda())
        sim_touch = batch['sim_touch'].cuda()
        depth = batch['depth'].cuda()
        ref_frame = batch['ref']
        gt_points = batch['samples'].cuda()

        # inference (the first output, the predicted depth, is not used here)
        _, pred_points = self.encoder(sim_touch, depth, ref_frame, empty=batch['empty'].cuda())

        # losses
        loss = self.args.loss_coeff * utils.point_loss(pred_points, gt_points)
        # Read the scalar once and reuse it for both the running total and
        # the log line instead of calling .item() twice per step.
        loss_value = loss.item()
        total_loss += loss_value

        # backprop
        loss.backward()
        self.optimizer.step()

        # log
        message = f'Train || Epoch: {self.epoch}, loss: {loss_value:.5f} '
        message += f'|| best_loss: {self.best_loss :.5f}'
        tqdm.write(message)
        iterations += 1

    if iterations == 0:
        # An empty iterable would otherwise raise ZeroDivisionError below;
        # there is no mean loss to report, so skip the scalar write.
        tqdm.write(f'Train || Epoch: {self.epoch}, no batches to train on')
        return

    writer.add_scalars('train', {self.args.exp_id: total_loss / iterations}, self.epoch)