in agents/im_fwd_agent.py [0:0]
def train(cls, model, dataset, output_dir, summary_writer,
full_eval_from_model, cfg):
"""Main train function."""
updates = cfg.train.num_iter
report_every = cfg.train.report_every
save_checkpoints_every = cfg.train.save_checkpoints_every
full_eval_every = cfg.train.full_eval_every
train_batch_size = cfg.train.batch_size
max_frames_fwd = cfg.train.frames_per_clip
n_hist_frames = cfg.train.n_hist_frames # Frames used to predict the future
loss_cfg = cfg.train.loss
opt_params = cfg.opt
# action_tier_name = cfg.tier
n_fwd_times = cfg.train.n_fwd_times
n_fwd_times_incur_loss = cfg.train.n_fwd_times_incur_loss
run_decode = cfg.train.run_decode
train_modules_subset = cfg.train.modules_to_train
    # nslices: number of slices to cut the input into during training
num_slices = cfg.train.num_slices
if max_frames_fwd is not None and (max_frames_fwd <= n_hist_frames):
        logging.warning(
            'Cannot train the prediction model: max_frames_fwd is too low '
            '(must exceed n_hist_frames)')
    assert loss_cfg.wt_pix == 0 or run_decode is True, (
        'If the pixel loss weight is non-zero, the decoder must be run')
# logging.info('Creating eval subset from train')
# eval_train = create_balanced_eval_set(cache, dataset.task_ids,
# XE_EVAL_SIZE, action_tier_name)
# if dev_tasks_ids is not None:
# logging.info('Creating eval subset from dev')
# eval_dev = create_balanced_eval_set(cache, dev_tasks_ids, XE_EVAL_SIZE,
# action_tier_name)
# else:
# eval_dev = None
device = nets.DEVICE
model.train()
model.to(device)
logging.info("%s", model)
params_to_train = []
if train_modules_subset is not None:
mod_names = train_modules_subset.split('%')
        logging.warning(
            'Training only a few modules, listed below. NOTE: '
            'BNs/dropout will still be in train mode. Explicitly '
            "set those to eval mode if that's not desired.")
for mod_name in mod_names:
mod = getattr(model.module, mod_name)
logging.warning('Training %s: %s', mod_name, mod)
params_to_train.extend(mod.parameters())
else:
params_to_train = model.parameters()
optimizer = hydra.utils.instantiate(opt_params, params_to_train)
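    # Cosine-anneal the learning rate over the full set of updates.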
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
T_max=updates)
logging.info('Starting actual training for %d updates', updates)
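    # Resume from the most recent checkpoint in output_dir, if one exists.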
last_checkpoint = get_latest_checkpoint(output_dir)
    batch_start = 0  # Start from iteration 0, unless loading a checkpoint
if last_checkpoint is not None:
logging.info('Going to load from %s', last_checkpoint)
last_checkpoint = torch.load(last_checkpoint)
model.load_state_dict(last_checkpoint['model'])
optimizer.load_state_dict(last_checkpoint['optim'])
# Subtracting 1 since we store batch_id + 1 when calling save_agent
batch_start = last_checkpoint['done_batches'] - 1
if scheduler is not None:
scheduler.load_state_dict(last_checkpoint['scheduler'])
def run_full_eval(batch_id):
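        """Run the full eval, log metrics, and dump results to a json."""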
logging.info('Running full eval')
results = {} # To store to a json
eval_stats = full_eval_from_model(model)
metric = eval_stats.compute_all_metrics()
results['metrics'] = metric
        results['metrics_rollout'] = (
            eval_stats.compute_all_metrics_over_rollout())
        results['metrics_per_task'] = (
            eval_stats.compute_all_metrics_per_task())
max_test_attempts_per_task = (cfg.max_test_attempts_per_task
or phyre.MAX_TEST_ATTEMPTS)
results['parsed_args'] = dict(
# cfg=cfg, # Not json serializable, anyway will be stored in dir
main_kwargs=dict(
eval_setup_name=cfg.eval_setup_name,
fold_id=cfg.fold_id,
use_test_split=cfg.use_test_split,
agent_type=cfg.agent.type,
max_test_attempts_per_task=max_test_attempts_per_task,
output_dir=output_dir))
results['target_metric'] = (
results['metrics']['independent_solved_by_aucs']
[max_test_attempts_per_task])
results['target_metric_over_time'] = [
el['independent_solved_by_aucs'][max_test_attempts_per_task]
for el in results['metrics_rollout']
]
logging.info('Iter %d: %s; Over rollout: %s', (batch_id + 1),
results['target_metric'],
results['target_metric_over_time'])
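        # Headline full-eval score (AUCCESS), logged to the summary writer.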
score = metric['independent_solved_by_aucs'][-1]
summary_writer.add_scalar('FullEval/AUCCESS', score, batch_id + 1)
for solved_by_iter in metric['global_solved_by']:
summary_writer.add_scalar(
'FullEval/solved_by_{}'.format(solved_by_iter),
metric['global_solved_by'][solved_by_iter], batch_id + 1)
logging.info('Full eval perf @ %d: %s', batch_id + 1, score)
        for i, rollout_metric in enumerate(results['metrics_rollout']):
            summary_writer.add_scalar(
                'FullEvalRollout/AUCCESS/{}'.format(i + 1),
                rollout_metric['independent_solved_by_aucs'][-1],
                batch_id + 1)
            summary_writer.add_scalar(
                'FullEvalRollout/solved_by_100/{}'.format(i),
                rollout_metric['global_solved_by'][100], batch_id + 1)
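        # Dump the intermediate eval results to results_intermediate/<iter>.json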
respath = os.path.join(
output_dir,
'results_intermediate/{:08d}.json'.format(batch_id + 1))
os.makedirs(os.path.dirname(respath), exist_ok=True)
with open(respath, 'w') as fout:
json.dump(results, fout)
logging.info('Report every %d; full eval every %d', report_every,
full_eval_every)
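    # Round the checkpoint interval down to a multiple of the full-eval
    # interval so checkpoints line up with full evals.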
if save_checkpoints_every > full_eval_every:
save_checkpoints_every -= save_checkpoints_every % full_eval_every
losses_report = {}
last_time = time.time()
    assert train_batch_size > 1 and train_batch_size % 2 == 0, (
        'Batch size must be even and at least 2, since the data loader '
        'returns 2 elements per sample for class balancing')
for batch_data_id, batch_data in enumerate(
torch.utils.data.DataLoader(
dataset,
num_workers=get_num_workers(
cfg.train.data_loader.num_workers,
dataset.frames_per_clip),
pin_memory=False,
                # Ask for half the batch size since the dataset returns
                # 2 elements per sample (for class balancing)
batch_size=train_batch_size // 2)):
        # When training restarts, the data loader starts over from the
        # beginning, so offset the batch index by what was already done
batch_id = batch_data_id + batch_start
if (batch_id + 1) >= updates:
save_agent(output_dir, batch_id + 1, model, optimizer,
scheduler)
break
model.train()
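        # Each batch element packs 2 clips (for class balancing); fold that
        # pair dimension into the batch dimension.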
batch_is_solved = batch_data['is_solved']
batch_is_solved = batch_is_solved.to(device, non_blocking=True)
batch_is_solved = batch_is_solved.reshape((-1, ))
batch_vid_obs = batch_data['vid_obs']
batch_vid_obs = batch_vid_obs.reshape(
[-1] + list(batch_vid_obs.shape[2:]))
batch_vid_obs = batch_vid_obs.to(device)
# Run the forward image model on the video
_, batch_losses = model.forward(
batch_vid_obs,
batch_is_solved,
n_hist_frames=n_hist_frames,
n_fwd_times=n_fwd_times,
n_fwd_times_incur_loss=n_fwd_times_incur_loss,
run_decode=run_decode,
compute_losses=True,
need_intermediate=loss_cfg.on_intermediate,
autoenc_loss_ratio=loss_cfg.autoenc_loss_ratio,
nslices=num_slices)
optimizer.zero_grad()
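        # Accumulate the weighted loss terms; per-term weights come from
        # cfg.train.loss (wt_<loss_type>), and zero-weight terms are skipped.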
total_loss = 0
# Mean over each loss type from each replica
for loss_type in batch_losses:
loss_wt = getattr(loss_cfg, 'wt_' + loss_type)
if loss_wt <= 0:
continue
loss_val = loss_wt * torch.mean(batch_losses[loss_type], dim=0)
if loss_type not in losses_report:
losses_report[loss_type] = []
losses_report[loss_type].append(loss_val.item())
total_loss += loss_val
total_loss.backward()
optimizer.step()
if (save_checkpoints_every > 0
and (batch_id + 1) % save_checkpoints_every == 0):
save_agent(output_dir, batch_id + 1, model, optimizer,
scheduler)
        # Removed the intermediate eval since it doesn't seem very useful;
        # using the full eval for now.
# if (batch_id + 1) % eval_every == 0:
# print_eval_stats(batch_id)
if (batch_id + 1) % report_every == 0:
speed = report_every / (time.time() - last_time)
last_time = time.time()
loss_stats = {
typ: np.mean(losses_report[typ][-report_every:])
for typ in losses_report if len(losses_report[typ]) > 0
}
logging.info(
'Iter: %s, examples: %d, mean loss: %s, speed: %.1f batch/sec,'
' lr: %f', batch_id + 1, (batch_id + 1) * train_batch_size,
loss_stats, speed, get_lr(optimizer))
for typ in loss_stats:
summary_writer.add_scalar('Loss/{}'.format(typ),
loss_stats[typ], batch_id + 1)
summary_writer.add_scalar('Loss/Total',
sum(loss_stats.values()),
batch_id + 1)
summary_writer.add_scalar('LR', get_lr(optimizer),
batch_id + 1)
summary_writer.add_scalar('Speed', speed, batch_id + 1)
            # Add a histogram of the batch task templates, to make sure a
            # variety of tasks is being sampled
            batch_task_ids = np.array(
                dataset.task_ids)[batch_data['task_indices'].reshape(
                    (-1, ))].tolist()
            batch_templates = np.array(
                [int(el.split(':')[0]) for el in batch_task_ids])
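            # Log the peak GPU memory allocated across all visible devices.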
gpu_mem_max = max([
torch.cuda.max_memory_allocated(device=i)
for i in range(torch.cuda.device_count())
])
summary_writer.add_scalar('GPU/Mem/Max', gpu_mem_max,
batch_id + 1)
summary_writer.add_histogram('Templates',
batch_templates,
global_step=(batch_id + 1),
bins=25)
            # Visualize a couple of train videos, and actual rollouts if the
            # pixel loss is being trained.
            # Only visualize the first 256 videos in case the batch size is
            # larger; somehow the visualizations get corrupted (grey bg) for
            # more. Also no point filling up the memory.
            # Store this less frequently than the rest of the logs (takes a
            # lot of space)
if n_fwd_times > 0 and (batch_id + 1) % (report_every * 10) == 0:
summary_writer.add_video(
'InputAndRollout/train',
gen_vis_vid_preds(batch_vid_obs[:256],
model,
n_fwd_times=None,
run_decode=run_decode,
n_hist_frames=n_hist_frames),
(batch_id + 1))
if (batch_id + 1) % full_eval_every == 0:
run_full_eval(batch_id)
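        # Step the cosine LR schedule once per batch (T_max = total updates).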
if scheduler is not None:
scheduler.step()
return model.cpu()