in train.py
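`train_one_epoch()` makes a single pass over `TRAIN_DATALOADER`: it steps the learning-rate and BN-momentum schedules, then for each batch runs forward, loss, backward, and an optimizer update, logging running means of every `loss`/`acc`/`ratio` statistic every 10 batches. It reads module-level state (`net`, `optimizer`, `criterion`, `device`, `EPOCH_CNT`, `TRAIN_VISUALIZER`) that is set up elsewhere in train.py.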
def train_one_epoch():
    stat_dict = {}  # collect statistics
    adjust_learning_rate(optimizer, EPOCH_CNT)
    bnm_scheduler.step()  # decay BN momentum
    net.train()  # set model to training mode
    for batch_idx, batch_data_label in enumerate(TRAIN_DATALOADER):
        # Move every tensor in the batch to the training device
        for key in batch_data_label:
            batch_data_label[key] = batch_data_label[key].to(device)
        # Forward pass
        optimizer.zero_grad()
        inputs = {'point_clouds': batch_data_label['point_clouds']}
        end_points = net(inputs)
        # Compute loss and gradients, update parameters.
        # Merge the ground-truth labels into end_points so the criterion sees
        # predictions and targets together; the assert guards against key collisions.
        for key in batch_data_label:
            assert key not in end_points
            end_points[key] = batch_data_label[key]
        loss, end_points = criterion(end_points, DATASET_CONFIG)
        loss.backward()
        optimizer.step()
        # Accumulate statistics and print out
        for key in end_points:
            if 'loss' in key or 'acc' in key or 'ratio' in key:
                if key not in stat_dict: stat_dict[key] = 0
                stat_dict[key] += end_points[key].item()
        batch_interval = 10
        if (batch_idx + 1) % batch_interval == 0:
            log_string(' ---- batch: %03d ----' % (batch_idx + 1))
            # Log running means to the visualizer; the step is the cumulative
            # number of training samples seen so far.
            TRAIN_VISUALIZER.log_scalars({key: stat_dict[key] / batch_interval for key in stat_dict},
                (EPOCH_CNT * len(TRAIN_DATALOADER) + batch_idx) * BATCH_SIZE)
            for key in sorted(stat_dict.keys()):
                log_string('mean %s: %f' % (key, stat_dict[key] / batch_interval))
                stat_dict[key] = 0  # reset running sums every batch_interval batches
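
For context, the helpers and the loop that drive `train_one_epoch()` live elsewhere in train.py. Below is a minimal sketch, assuming a step-style LR schedule and a per-epoch checkpoint; the config globals (`BASE_LEARNING_RATE`, `LR_DECAY_STEPS`, `LR_DECAY_RATES`, `MAX_EPOCH`, `LOG_DIR`) and the exact checkpoint format are assumptions for illustration, not necessarily what the repository uses.

import os
import numpy as np
import torch

# Hypothetical config globals, named here only for illustration:
# BASE_LEARNING_RATE, LR_DECAY_STEPS, LR_DECAY_RATES, MAX_EPOCH, LOG_DIR

def adjust_learning_rate(optimizer, epoch):
    # Step-style decay: multiply the base LR by each decay rate whose
    # scheduled epoch has passed, then write the result into every param group.
    lr = BASE_LEARNING_RATE
    for i, decay_epoch in enumerate(LR_DECAY_STEPS):
        if epoch >= decay_epoch:
            lr *= LR_DECAY_RATES[i]
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

def train(start_epoch):
    global EPOCH_CNT
    for epoch in range(start_epoch, MAX_EPOCH):
        EPOCH_CNT = epoch
        log_string('**** EPOCH %03d ****' % epoch)
        np.random.seed()  # re-seed so DataLoader workers reshuffle each epoch
        train_one_epoch()
        torch.save({'epoch': epoch,
                    'model_state_dict': net.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict()},
                   os.path.join(LOG_DIR, 'checkpoint.tar'))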