in vision_charts/recon.py [0:0]
def validate(self, data, writer):
    """Evaluate the encoder on the validation loaders and record the mean loss.

    For every loader in ``data`` the encoder predicts vertices from the batch
    images. The chamfer distance to ``batch['gt_points']`` is computed with
    ``utils.chamfer_distance`` and scaled by ``self.args.loss_coeff``. Each
    loader's loss is averaged over its examples. The per-loader averages are
    then summed with weight ``1 / len(self.classes)`` to give the total.

    Args:
        data: iterable of validation data loaders. Presumably one loader per
            entry in ``self.classes``, since the final average divides by
            ``len(self.classes)`` -- TODO confirm against the caller.
        writer: object exposing ``add_scalars(tag, values, step)``. The total
            is logged under the ``'valid_ptp'`` tag, keyed by
            ``self.args.exp_id``, at step ``self.epoch``.

    Side effects:
        Puts ``self.encoder`` in eval mode and does not switch it back.
        Prints a summary, logs to ``writer`` and stores the total in
        ``self.current_loss``.

    Raises:
        ValueError: if a loader in ``data`` yields no examples.
    """
    import torch  # local import: module-level imports are outside this block

    total_loss = 0
    self.encoder.eval()
    # Validation needs no gradients. Without no_grad every forward pass builds
    # an autograd graph. `class_loss += loss` would chain those graphs, and
    # `total_loss` / `self.current_loss` would keep them alive after return.
    with torch.no_grad():
        for v, valid_loader in enumerate(data):
            num_examples = 0
            class_loss = 0
            for batch in tqdm(valid_loader):
                # initialize data
                img_occ = batch['img_occ'].cuda()
                img_unocc = batch['img_unocc'].cuda()
                gt_points = batch['gt_points'].cuda()
                batch_size = img_occ.shape[0]
                obj_class = batch['class'][0]
                # model prediction
                verts = self.encoder(img_occ, img_unocc, batch)
                # Losses: multiply the batch mean by batch_size so that dividing
                # by num_examples below gives a per-example average.
                loss = utils.chamfer_distance(verts, self.adj_info['faces'], gt_points, num=self.num_samples)
                loss = self.args.loss_coeff * loss.mean() * batch_size
                # logs
                num_examples += float(batch_size)
                class_loss += loss
            if num_examples == 0:
                # Previously this surfaced as an opaque ZeroDivisionError.
                raise ValueError(f'validation loader {v} yielded no examples')
            print_loss = (class_loss / num_examples)
            # NOTE: the value labelled "f1" is the scaled chamfer loss above.
            # `obj_class` is the class of this loader's last batch.
            message = f'Valid || Epoch: {self.epoch}, class: {obj_class}, f1: {print_loss:.2f}'
            tqdm.write(message)
            total_loss += (print_loss / float(len(self.classes)))
    print('*******************************************************')
    print(f'Validation Accuracy: {total_loss}')
    print('*******************************************************')
    writer.add_scalars('valid_ptp', {self.args.exp_id: total_loss}, self.epoch)
    self.current_loss = total_loss