def train_one_epoch()

in downstream/votenet/lib/ddp_trainer.py [0:0]


    def train_one_epoch(self, epoch_cnt):
        """Run one full pass over the training dataloader.

        Adjusts the learning rate for this epoch, decays BN momentum, then for
        each batch performs forward/backward/optimizer-step. Running means of
        every statistic whose key contains 'loss', 'acc' or 'ratio' are logged
        (and written to TensorBoard) every `batch_interval` batches, on the
        master process only.

        Args:
            epoch_cnt: zero-based index of the current epoch, used for LR
                scheduling and for the TensorBoard global step.
        """
        running_stats = {}  # key -> scalar accumulated over the logging interval
        DetectionTrainer.adjust_learning_rate(self.optimizer, epoch_cnt, self.config)
        self.bnm_scheduler.step()  # decay BN momentum
        self.net.train()  # enable dropout / BN updates

        batch_interval = 10
        for step, batch in enumerate(self.train_dataloader):
            # Move every tensor field to the GPU; 'scan_name' is a non-tensor id.
            for key in batch:
                if key != 'scan_name':
                    batch[key] = batch[key].cuda()

            # Forward pass
            self.optimizer.zero_grad()
            inputs = {'point_clouds': batch['point_clouds']}
            if 'voxel_coords' in batch:
                inputs['voxel_coords'] = batch['voxel_coords']
                inputs['voxel_inds'] = batch['voxel_inds']
                inputs['voxel_feats'] = batch['voxel_feats']
            end_points = self.net(inputs)

            # Merge ground-truth labels into end_points so the criterion sees
            # predictions and targets together; keys must not collide.
            for key in batch:
                assert key not in end_points
                end_points[key] = batch[key]
            loss, end_points = criterion(end_points, self.dataset_config)
            loss.backward()
            self.optimizer.step()

            # Accumulate scalar statistics for periodic reporting.
            for key in end_points:
                if any(tag in key for tag in ('loss', 'acc', 'ratio')):
                    running_stats[key] = running_stats.get(key, 0) + end_points[key].item()

            if (step + 1) % batch_interval == 0 and self.is_master:
                logging.info(' ---- batch: %03d ----' % (step + 1))
                for key in running_stats:
                    self.writer.add_scalar('training/{}'.format(key), running_stats[key]/batch_interval,
                                          (epoch_cnt*len(self.train_dataloader)+step)*self.config.data.batch_size)
                for key in sorted(running_stats.keys()):
                    logging.info('mean %s: %f' % (key, running_stats[key]/batch_interval))
                    running_stats[key] = 0  # reset the interval accumulator