def train_one_epoch()

in downstream/votenet_det_new/lib/train.py [0:0]


def train_one_epoch(net, train_dataloader, optimizer, bnm_scheduler, epoch_cnt, dataset_config, writer, config,
                    batch_interval=10):
    """Run one training epoch over ``train_dataloader``.

    Before iterating, the learning rate is set via ``adjust_learning_rate`` and
    ``bnm_scheduler.step()`` is called once. For each batch, every entry of the
    batch dict is moved to the GPU and the network is run on the point clouds
    (plus voxel inputs when present). The batch entries are then merged into
    the network's output dict, and the module-level ``criterion`` computes the
    loss, which is back-propagated before an optimizer step.

    Every ``batch_interval`` batches, the running sums of all output entries
    whose key contains 'loss', 'acc' or 'ratio' are divided by
    ``batch_interval``. Each mean is written to ``writer`` as
    ``training/<key>`` and logged, and the sums are then reset. Sums
    accumulated after the last full interval of the epoch are not reported.

    Args:
        net: model called as ``net(inputs)``; must return a dict.
        train_dataloader: iterable of dicts whose values all support ``.cuda()``;
            ``len()`` is taken when reporting.
        optimizer: optimizer with ``zero_grad()`` / ``step()``.
        bnm_scheduler: object with ``step()`` (presumably a BN-momentum scheduler).
        epoch_cnt: epoch index, used for the LR schedule and the logging step
            (appears to be zero-based given the step formula — confirm with caller).
        dataset_config: forwarded unchanged to ``criterion``.
        writer: object exposing ``add_scalar(tag, value, step)``.
        config: config object; ``config.data.batch_size`` scales the logging step.
        batch_interval: number of batches between statistic reports (default 10).

    Raises:
        ValueError: if ``batch_interval`` is not positive, or if a batch key
            collides with a key already present in the network outputs.
    """
    if batch_interval <= 0:
        raise ValueError('batch_interval must be positive, got {}'.format(batch_interval))

    stat_dict = {}  # running sums of loss/acc/ratio entries since the last report
    adjust_learning_rate(optimizer, epoch_cnt, config)
    bnm_scheduler.step()  # decay BN momentum
    net.train()  # set model to training mode
    for batch_idx, batch_data_label in enumerate(train_dataloader):
        for key in batch_data_label:
            batch_data_label[key] = batch_data_label[key].cuda()

        # Forward pass
        optimizer.zero_grad()
        inputs = {'point_clouds': batch_data_label['point_clouds']}
        if 'voxel_coords' in batch_data_label:
            # Voxel inputs are forwarded only when the loader provides them.
            inputs.update({
                'voxel_coords': batch_data_label['voxel_coords'],
                'voxel_inds':   batch_data_label['voxel_inds'],
                'voxel_feats':  batch_data_label['voxel_feats']})

        end_points = net(inputs)

        # Merge labels into the network outputs so the criterion sees both.
        # Explicit check rather than `assert`: under `python -O` an assert is
        # stripped and a collision would silently overwrite a network output.
        for key, value in batch_data_label.items():
            if key in end_points:
                raise ValueError(
                    'batch key {!r} collides with a network output key'.format(key))
            end_points[key] = value

        # Compute loss and gradients, update parameters.
        loss, end_points = criterion(end_points, dataset_config)
        loss.backward()
        optimizer.step()

        # Accumulate statistics.
        for key, value in end_points.items():
            if 'loss' in key or 'acc' in key or 'ratio' in key:
                stat_dict[key] = stat_dict.get(key, 0) + value.item()

        # Report the mean over the last `batch_interval` batches, then reset.
        if (batch_idx + 1) % batch_interval == 0:
            logging.info(' ---- batch: %03d ----', batch_idx + 1)
            global_step = (epoch_cnt * len(train_dataloader) + batch_idx) * config.data.batch_size
            for key in sorted(stat_dict):
                mean_value = stat_dict[key] / batch_interval
                writer.add_scalar('training/{}'.format(key), mean_value, global_step)
                logging.info('mean %s: %f', key, mean_value)
                stat_dict[key] = 0