in depth_upsampling/train.py [0:0]
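
# NOTE: this excerpt shows main() only; the preamble below is an assumed
# reconstruction of the imports the function body needs. Project-local names
# (ARKitScenesDataset, MultiEpochSampler, get_network, get_loss, batch_to_cuda,
# train_log, eval_log, Compose, the transfroms module, and TENSORBOARD_DIR)
# are assumed to come from the surrounding depth_upsampling package.
import os
import shlex
import subprocess
import time

import numpy as np
import torch
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
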
def main(args):
    batch_size = args.batch_size
    num_iter = args.num_iter
    upsample_factor = args.upsample_factor
    start_itr = 0
    # 256-pixel crops for the 2x model, 512-pixel crops for higher factors.
    patch_size = 256 if args.upsample_factor == 2 else 512

    print('loading train dataset')
    transform = Compose([transfroms.RandomCrop(height=patch_size, width=patch_size, upsample_factor=upsample_factor),
                         transfroms.RandomFilpLR(),
                         transfroms.ValidDepthMask(gt_low_limit=0.01),
                         transfroms.AsContiguousArray()])
    train_dataset = ARKitScenesDataset(root=args.data_path, split='train',
                                       upsample_factor=upsample_factor, transform=transform)
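
    # MultiEpochSampler (assumed, from its signature here) yields enough
    # shuffled indices to cover num_iter optimizer steps of batch_size samples
    # each, cycling through the dataset over as many epochs as needed, so the
    # training loop below is a single pass over the dataloader.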
    sampler = MultiEpochSampler(train_dataset, num_iter, start_itr, batch_size)
    train_dataloader = DataLoader(train_dataset,
                                  batch_size,
                                  sampler=sampler,
                                  num_workers=8 * int(torch.cuda.is_available()),
                                  pin_memory=torch.cuda.is_available(),
                                  drop_last=True)

    print('loading validation dataset')
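    # Validation runs on full frames at batch size 1; ModCrop presumably trims
    # each frame to dimensions divisible by 32 so it passes cleanly through the
    # network's downsampling stages (an assumption based on the transform name).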
    transform = Compose([transfroms.ModCrop(modulo=32),
                         transfroms.ValidDepthMask(gt_low_limit=0.01)])
    val_dataset = ARKitScenesDataset(root=args.data_path, split='val',
                                     upsample_factor=upsample_factor, transform=transform)
    val_dataloader = DataLoader(val_dataset,
                                batch_size=1,
                                num_workers=8 * int(torch.cuda.is_available()),
                                pin_memory=torch.cuda.is_available())

    print('building the network')
    model = get_network(args.network, upsample_factor)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
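    # Wrap the model in DataParallel when several GPUs are visible;
    # cudnn.benchmark lets cuDNN auto-tune convolution kernels, which pays off
    # because the training crops have a fixed size.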
    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)
    if torch.cuda.is_available():
        model.cuda()
    model.train()
    cudnn.benchmark = True

    # init logs
    if args.tbp is not None:
        print('starting tensorboard')
        tensorboard_path = os.path.join(args.log_dir, TENSORBOARD_DIR)
        command = f'tensorboard --logdir {tensorboard_path} --port {args.tbp}'
        tensorboard_process = subprocess.Popen(shlex.split(command), env=os.environ.copy())
        train_tensorboard_writer = SummaryWriter(os.path.join(tensorboard_path, 'train'), flush_secs=30)
        val_tensorboard_writer = SummaryWriter(os.path.join(tensorboard_path, 'val'), flush_secs=30)
    else:
        print('no tensorboard')
        tensorboard_process = None
        train_tensorboard_writer = None
        val_tensorboard_writer = None

    loss_fn = get_loss(args.network)
    start_time = time.time()
    step = 1
    duration = 0
    current_lr = -1  # no LR scheduler in this function; passed through to train_log
    print("start training")
    for input_batch in train_dataloader:
        before_op_time = time.time()
        input_batch = batch_to_cuda(input_batch)
        optimizer.zero_grad()
        output_batch = model(input_batch)
        loss = loss_fn(output_batch, input_batch)
        # Abort early if the loss diverges.
        if np.isnan(loss.cpu().item()):
            exit('NaN in loss occurred. Aborting training.')
        loss.backward()
        optimizer.step()
        duration += time.time() - before_op_time
        train_log(step=step, loss=loss, input_batch=input_batch, output_batch=output_batch,
                  tensorboard_writer=train_tensorboard_writer, current_lr=current_lr)
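
        # Periodic side work: validation every eval_freq steps, throughput
        # reporting every log_freq steps, checkpointing every save_freq steps.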
        if step % args.eval_freq == 0:
            eval_log(step, model, val_dataloader, val_tensorboard_writer)
        if step and step % args.log_freq == 0:
            # duration accumulates over the last log_freq steps, so throughput
            # is batch_size * log_freq / duration examples per second.
            examples_per_sec = args.batch_size / duration * args.log_freq
            time_sofar = (time.time() - start_time) / 3600
            training_time_left = (num_iter / step - 1.0) * time_sofar
            print_string = 'examples/s: {:4.2f} | time elapsed: {:.2f}h | time left: {:.2f}h'
            print(print_string.format(examples_per_sec, time_sofar, training_time_left))
            duration = 0
        if step % args.save_freq == 0:
            # Note: under DataParallel the saved state_dict keys carry a
            # 'module.' prefix.
            checkpoint = {'step': step,
                          'model': model.state_dict(),
                          'optimizer': optimizer.state_dict()}
            save_file = os.path.join(args.log_dir, 'checkpoint_step-{}'.format(step))
            torch.save(checkpoint, save_file)
        step += 1

    print('finished training')
    if tensorboard_process is not None:
        tensorboard_process.terminate()
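

# A minimal sketch of a script entry point, assuming argparse-style CLI flags.
# The flag names mirror the attributes main() reads (batch_size, num_iter,
# upsample_factor, data_path, network, learning_rate, tbp, log_dir, eval_freq,
# log_freq, save_freq); the defaults are illustrative, not the repository's.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='Depth upsampling training')
    parser.add_argument('--data_path', required=True)
    parser.add_argument('--log_dir', required=True)
    parser.add_argument('--network', required=True)
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--num_iter', type=int, default=100000)
    parser.add_argument('--upsample_factor', type=int, default=2)
    parser.add_argument('--learning_rate', type=float, default=1e-4)
    parser.add_argument('--eval_freq', type=int, default=1000)
    parser.add_argument('--log_freq', type=int, default=100)
    parser.add_argument('--save_freq', type=int, default=5000)
    parser.add_argument('--tbp', type=int, default=None,
                        help='tensorboard port; omit to disable tensorboard')
    main(parser.parse_args())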