in agents/neural_agent.py
def train(output_dir,
          action_tier_name,
          task_ids,
          cache,
          train_batch_size,
          learning_rate,
          max_train_actions,
          updates,
          negative_sampling_prob,
          save_checkpoints_every,
          fusion_place,
          network_type,
          balance_classes,
          num_auccess_actions,
          eval_every,
          action_layers,
          action_hidden_size,
          cosine_scheduler,
          n_samples_per_task,
          use_sample_distance_aux_loss,
          aux_loss_hyperparams,
          checkpoint_dir,
          tensorboard_dir="",
          dev_tasks_ids=None,
          debug=False,
          **excess_kwargs):
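    """Trains an action-conditioned solved/unsolved classifier on cached PHYRE
    simulation data.

    Optionally adds a sample-distance auxiliary loss computed from rollouts
    produced by a parallel simulator. Cross-entropy (and, if
    num_auccess_actions > 0, AUCCESS) is evaluated every `eval_every` updates,
    checkpoints are written to `checkpoint_dir`, and the trained model is
    returned on CPU, unwrapped from DataParallel.
    """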
    pathlib.Path(checkpoint_dir).mkdir(parents=True, exist_ok=True)
    if tensorboard_dir:
        from torch.utils.tensorboard import SummaryWriter
        logging.info("Tensorboard dir: %s", tensorboard_dir)
        writer = SummaryWriter(tensorboard_dir)
    logging.info('Preprocessing train data')

    if debug:
        logging.warning("Debugging ON")
        logging.warning("Subsampling dataset")
        task_ids = task_ids[:10]
        if dev_tasks_ids:
            dev_tasks_ids = dev_tasks_ids[:10]
    training_data = cache.get_sample(task_ids, max_train_actions)
    (task_indices, is_solved, actions, simulator, observations,
     positive_in_task, negative_in_task) = (
         compact_simulation_data_to_trainset(action_tier_name, **training_data))

    logging.info('Creating eval subset from train')
    eval_train = create_balanced_eval_set(cache, simulator.task_ids,
                                          XE_EVAL_SIZE, action_tier_name)
    if dev_tasks_ids is not None:
        logging.info('Creating eval subset from dev')
        eval_dev = create_balanced_eval_set(cache, dev_tasks_ids, XE_EVAL_SIZE,
                                            action_tier_name)
    else:
        eval_dev = None

    aux_loss_eval = None
    aux_loss_eval_dev = None
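    # When the sample-distance auxiliary loss is enabled, build fixed
    # evaluation sets for it from the train (and, if provided, dev) task splits.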
    if use_sample_distance_aux_loss:
        logging.info("Creating eval set for auxiliary loss from train")
        aux_loss_eval = create_metric_eval_set(task_ids, action_tier_name,
                                               cache, AUX_LOSS_EVAL_TASKS,
                                               AUX_EVAL_ACTIONS_PER_TASK,
                                               simulator)
        if dev_tasks_ids is not None:
            logging.info("Creating eval set for auxiliary loss from dev")
            aux_loss_eval_dev = create_metric_eval_set(
                dev_tasks_ids,
                action_tier_name,
                cache,
                AUX_LOSS_EVAL_TASKS,
                AUX_EVAL_ACTIONS_PER_TASK,
                simulator=None)
    logging.info('Train set: size=%d, positive_ratio=%.2f%%', len(is_solved),
                 is_solved.float().mean().item() * 100)

    assert not balance_classes or (negative_sampling_prob == 1), (
        balance_classes, negative_sampling_prob)
    device = nets.DEVICE
    model_kwargs = dict(network_type=network_type,
                        action_space_dim=simulator.action_space_dim,
                        fusion_place=fusion_place,
                        action_hidden_size=action_hidden_size,
                        action_layers=action_layers)
    if use_sample_distance_aux_loss:
        n_regressor_outputs = (
            1 if aux_loss_hyperparams["regression_type"] == "mse" else
            aux_loss_hyperparams["n_regression_bins"])
        model_kwargs.update(
            dict(
                repr_merging_method=aux_loss_hyperparams["repr_merging_method"],
                embedding_dim=aux_loss_hyperparams["embedding_dim"],
                embeddor_type=aux_loss_hyperparams["embeddor_type"],
                n_regressor_outputs=n_regressor_outputs))
    model = build_model(**model_kwargs)
    model = nn.DataParallel(model)
    model.train()
    model.to(device)
    logging.info(model)
    logging.info(f"Model will use {torch.cuda.device_count()} GPUs to train")
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    if cosine_scheduler:
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                               T_max=updates)
    else:
        scheduler = None

    logging.info('Starting actual training for %d updates', updates)
    rng = np.random.RandomState(42)
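    # task_balanced_sampler draws `train_batch_size // n_samples_per_task`
    # tasks per batch and, for each task, equal numbers of positive and
    # negative actions (falling back to extra negatives when a task has no
    # cached solution).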
    def task_balanced_sampler():
        n_tasks_per_batch = train_batch_size // n_samples_per_task
        n_tasks_total = len(simulator.initial_scenes)
        n_samples_per_task_label = n_samples_per_task // 2
        while True:
            batch_task_indices = rng.choice(n_tasks_total,
                                            size=n_tasks_per_batch,
                                            replace=debug)
            indices = []
            for task_idx in batch_task_indices:
                indices.append(
                    rng.choice(negative_in_task[task_idx],
                               n_samples_per_task_label))
                try:
                    indices.append(
                        rng.choice(positive_in_task[task_idx],
                                   n_samples_per_task_label))
                except ValueError:
                    # No solutions for the given template in the dataset, so
                    # draw additional negative samples instead.
                    indices.append(
                        rng.choice(negative_in_task[task_idx],
                                   n_samples_per_task_label))
            indices = np.concatenate(indices)
            yield indices, batch_task_indices
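    # train_indices_sampler draws flat action indices: class-balanced batches
    # when `balance_classes` is set, otherwise optionally downweighting
    # negatives via `negative_sampling_prob`.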
    def train_indices_sampler():
        indices = np.arange(len(is_solved))
        if balance_classes:
            solved_mask = is_solved.numpy() > 0
            positive_indices = indices[solved_mask]
            negative_indices = indices[~solved_mask]
            positive_size = train_batch_size // 2
            while True:
                positives = rng.choice(positive_indices, size=positive_size)
                negatives = rng.choice(negative_indices,
                                       size=train_batch_size - positive_size)
                positive_size = train_batch_size - positive_size
                yield np.stack((positives, negatives), axis=1).reshape(-1), None
        elif negative_sampling_prob < 1:
            probs = (is_solved.numpy() * (1.0 - negative_sampling_prob) +
                     negative_sampling_prob)
            probs /= probs.sum()
            while True:
                yield rng.choice(indices, size=train_batch_size, p=probs), None
        else:
            while True:
                yield rng.choice(indices, size=train_batch_size), None
    last_checkpoint = get_latest_checkpoint(checkpoint_dir)
    batch_start = 0
    if last_checkpoint is not None:
        logging.info('Going to load from %s', last_checkpoint)
        last_checkpoint = torch.load(last_checkpoint)
        model.load_state_dict(last_checkpoint['model'])
        optimizer.load_state_dict(last_checkpoint['optim'])
        rng.set_state(last_checkpoint['rng'])
        batch_start = last_checkpoint['done_batches']
        if scheduler is not None:
            scheduler.load_state_dict(last_checkpoint['scheduler'])
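    # print_eval_stats evaluates cross-entropy (and, when enabled, AUCCESS and
    # auxiliary) losses on the eval sets built above and logs the results, also
    # to TensorBoard when a tensorboard_dir is given.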
    def print_eval_stats(batch_id):
        with torch.no_grad():
            logging.info('Start eval')
            eval_batch_size = train_batch_size * 2
            stats = {}
            stats['batch_id'] = batch_id + 1
            stats['train_loss'] = eval_loss(model, eval_train, eval_batch_size)
            if eval_dev:
                stats['dev_loss'] = eval_loss(model, eval_dev, eval_batch_size)
            if num_auccess_actions > 0:
                logging.info('Start AUCCESS eval')
                stats['train_auccess'] = _eval_and_score_actions(
                    cache, model, eval_train[3], num_auccess_actions,
                    eval_batch_size, eval_train[4])
                if eval_dev:
                    stats['dev_auccess'] = _eval_and_score_actions(
                        cache, model, eval_dev[3], num_auccess_actions,
                        eval_batch_size, eval_dev[4])
            if aux_loss_eval:
                logging.info("Starting aux loss eval")
                stats["aux_loss"] = compute_loss_on_eval_batch(
                    model, n_samples_per_task, aux_loss_hyperparams,
                    *aux_loss_eval)
            if aux_loss_eval_dev:
                logging.info("Starting aux loss dev eval")
                stats["aux_loss_dev"] = compute_loss_on_eval_batch(
                    model, n_samples_per_task, aux_loss_hyperparams,
                    *aux_loss_eval_dev)
            if tensorboard_dir:
                for stat in stats:
                    if stat == "batch_id":
                        continue
                    writer.add_scalar(stat, stats[stat], batch_id)
            logging.info('__log__: %s', stats)
    report_every = 5 if debug else 100
    logging.info('Report every %d; eval every %d', report_every, eval_every)
    if save_checkpoints_every > eval_every:
        save_checkpoints_every -= save_checkpoints_every % eval_every

    if batch_start == 0:
        print_eval_stats(0)

    losses = []
    aux_losses = []
    ce_losses = []
    last_time = time.time()
    observations = observations.to(device)
    # actions = actions.pin_memory()
    # is_solved = is_solved.pin_memory()
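    # Rollouts for the sample-distance auxiliary loss come from a pool of
    # simulator workers; weakref.finalize makes sure the pool is closed once
    # the simulator object is garbage collected.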
    if use_sample_distance_aux_loss:
        num_workers = 5 if debug else 20
        parallel_simulator = ParallelPhyreSimulator(task_ids,
                                                    action_tier_name,
                                                    num_workers,
                                                    MAX_LEN,
                                                    train_batch_size,
                                                    requires_imgs=False)
        logging.info(f"Starting parallel simulator with {num_workers} workers")
        weakref.finalize(parallel_simulator, parallel_simulator.close)

    if n_samples_per_task == 0:
        assert not use_sample_distance_aux_loss
        sampler = train_indices_sampler()
    else:
        assert n_samples_per_task % 2 == 0 and balance_classes
        sampler = task_balanced_sampler()
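    # Main training loop: one optimizer update per sampled batch, with periodic
    # checkpointing, evaluation and speed reporting.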
    for batch_id, (batch_indices, unique_task_indices) in enumerate(
            sampler, start=batch_start):
        if batch_id >= updates:
            break
        if scheduler is not None:
            scheduler.step()
        model.train()
        batch_task_indices = task_indices[batch_indices]
        batch_observations = observations[batch_task_indices]
        batch_actions = actions[batch_indices]
        batch_is_solved = is_solved[batch_indices].to(device, non_blocking=True)
        if use_sample_distance_aux_loss:
            _, _, batch_rollouts, batch_masks = parallel_simulator.simulate_parallel(
                batch_task_indices,
                batch_actions,
                need_images=False,
                need_featurized_objects=True)
            batch_rollouts = torch.from_numpy(batch_rollouts)
            batch_masks = torch.from_numpy(batch_masks)
        batch_actions = batch_actions.to(device, non_blocking=True)

        if use_sample_distance_aux_loss:
            logits, embeddings = model(
                batch_observations,
                batch_actions,
                get_embeddings=use_sample_distance_aux_loss)
            aux_loss = sample_distance_loss(model, embeddings,
                                            batch_task_indices, batch_rollouts,
                                            batch_masks, unique_task_indices,
                                            n_samples_per_task,
                                            aux_loss_hyperparams)
        else:
            logits = model(batch_observations, batch_actions)
            aux_loss = torch.tensor([0.0]).to(nets.DEVICE)
        # TODO: add tensorboard, detailed evaluator and parallel simulator.
        optimizer.zero_grad()
        classification_loss = model.module.ce_loss(logits,
                                                   batch_is_solved).mean()
        loss = classification_loss + aux_loss * aux_loss_hyperparams["weight"]
        loss.backward()
        optimizer.step()
        losses.append(loss.mean().item())
        aux_losses.append(aux_loss.mean().item())
        ce_losses.append(classification_loss.mean().item())
        if save_checkpoints_every > 0:
            if ((batch_id + 1) % save_checkpoints_every == 0
                    or (batch_id + 1) == updates):
                fpath = os.path.join(checkpoint_dir, 'ckpt.%08d' % (batch_id + 1))
                logging.info('Saving: %s', fpath)
                torch.save(
                    dict(
                        model_kwargs=model_kwargs,
                        model=model.state_dict(),
                        optim=optimizer.state_dict(),
                        done_batches=batch_id + 1,
                        rng=rng.get_state(),
                        scheduler=scheduler and scheduler.state_dict(),
                    ), fpath)
        if (batch_id + 1) % eval_every == 0:
            print_eval_stats(batch_id)
        if (batch_id + 1) % report_every == 0:
            speed = report_every / (time.time() - last_time)
            last_time = time.time()
            logging.debug(
                'Iter: %s, examples: %d, mean loss: %f, mean ce: %f, mean aux: %f, '
                'speed: %.1f batch/sec, lr: %f', batch_id + 1,
                (batch_id + 1) * train_batch_size,
                np.mean(losses[-report_every:]),
                np.mean(ce_losses[-report_every:]),
                np.mean(aux_losses[-report_every:]), speed, get_lr(optimizer))
    return model.module.cpu()