def train()

in agents/neural_agent.py
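
Trains an action-scoring model on cached simulation data: it assembles a train set from the cache, optionally balances classes or subsamples negatives, and periodically reports cross-entropy loss and AUCCESS while writing checkpoints that allow the run to resume.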


def train(output_dir,
          action_tier_name,
          task_ids,
          cache,
          train_batch_size,
          learning_rate,
          max_train_actions,
          updates,
          negative_sampling_prob,
          save_checkpoints_every,
          fusion_place,
          network_type,
          balance_classes,
          num_auccess_actions,
          eval_every,
          action_layers,
          action_hidden_size,
          cosine_scheduler,
          dev_tasks_ids=None):

    logging.info('Preprocessing train data')

    training_data = cache.get_sample(task_ids, max_train_actions)
    task_indices, is_solved, actions, simulator, observations = (
        compact_simulation_data_to_trainset(action_tier_name, **training_data))

    logging.info('Creating eval subset from train')
    eval_train = create_balanced_eval_set(cache, simulator.task_ids,
                                          XE_EVAL_SIZE, action_tier_name)
    if dev_tasks_ids is not None:
        logging.info('Creating eval subset from dev')
        eval_dev = create_balanced_eval_set(cache, dev_tasks_ids, XE_EVAL_SIZE,
                                            action_tier_name)
    else:
        eval_dev = None

    logging.info('Train set: size=%d, positive_ratio=%.2f%%', len(is_solved),
                 is_solved.float().mean().item() * 100)

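    # balance_classes already fixes the positive/negative ratio, so it cannot
    # be combined with additional negative subsampling.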
    assert not balance_classes or (negative_sampling_prob == 1), (
        balance_classes, negative_sampling_prob)

    device = nets.DEVICE
    model_kwargs = dict(network_type=network_type,
                        action_space_dim=simulator.action_space_dim,
                        fusion_place=fusion_place,
                        action_hidden_size=action_hidden_size,
                        action_layers=action_layers)
    model = build_model(**model_kwargs)
    model.train()
    model.to(device)
    logging.info(model)

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    if cosine_scheduler:
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                               T_max=updates)
    else:
        scheduler = None
    logging.info('Starting actual training for %d updates', updates)

    rng = np.random.RandomState(42)

    def train_indices_sampler():
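        """Yield batches of train indices: class-balanced, negative-subsampled, or uniform."""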
        indices = np.arange(len(is_solved))
        if balance_classes:
            solved_mask = is_solved.numpy() > 0
            positive_indices = indices[solved_mask]
            negative_indices = indices[~solved_mask]
            positive_size = train_batch_size // 2
            while True:
                positives = rng.choice(positive_indices, size=positive_size)
                negatives = rng.choice(negative_indices,
                                       size=train_batch_size - positive_size)
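                # Swap the split sizes each batch so odd batch sizes stay balanced on average.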
                positive_size = train_batch_size - positive_size
                yield np.concatenate((positives, negatives))
        elif negative_sampling_prob < 1:
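            # Positives get weight 1.0, negatives negative_sampling_prob; normalize to a distribution.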
            probs = (is_solved.numpy() * (1.0 - negative_sampling_prob) +
                     negative_sampling_prob)
            probs /= probs.sum()
            while True:
                yield rng.choice(indices, size=train_batch_size, p=probs)
        else:
            while True:
                yield rng.choice(indices, size=train_batch_size)

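    # If a checkpoint exists, resume model weights, optimizer, RNG stream, and batch counter.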
    last_checkpoint = get_latest_checkpoint(output_dir)
    batch_start = 0
    if last_checkpoint is not None:
        logging.info('Going to load from %s', last_checkpoint)
        last_checkpoint = torch.load(last_checkpoint)
        model.load_state_dict(last_checkpoint['model'])
        optimizer.load_state_dict(last_checkpoint['optim'])
        rng.set_state(last_checkpoint['rng'])
        batch_start = last_checkpoint['done_batches']
        if scheduler is not None:
            scheduler.load_state_dict(last_checkpoint['scheduler'])

    def print_eval_stats(batch_id):
        logging.info('Start eval')
        eval_batch_size = train_batch_size * 4
        stats = {}
        stats['batch_id'] = batch_id + 1
        stats['train_loss'] = eval_loss(model, eval_train, eval_batch_size)
        if eval_dev:
            stats['dev_loss'] = eval_loss(model, eval_dev, eval_batch_size)
        if num_auccess_actions > 0:
            logging.info('Start AUCCESS eval')
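            # Positions 3 and 4 of the eval tuples are the simulator and
            # observations, mirroring the trainset unpacking above.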
            stats['train_auccess'] = _eval_and_score_actions(
                cache, model, eval_train[3], num_auccess_actions,
                eval_batch_size, eval_train[4])
            if eval_dev:
                stats['dev_auccess'] = _eval_and_score_actions(
                    cache, model, eval_dev[3], num_auccess_actions,
                    eval_batch_size, eval_dev[4])

        logging.info('__log__: %s', stats)

    report_every = 125
    logging.info('Report every %d; eval every %d', report_every, eval_every)
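    # Round the checkpoint interval down to a multiple of eval_every so saves coincide with evals.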
    if save_checkpoints_every > eval_every:
        save_checkpoints_every -= save_checkpoints_every % eval_every

    print_eval_stats(0)

    losses = []
    last_time = time.time()
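    # Keep all observations on the device; pin actions and labels for async per-batch copies.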
    observations = observations.to(device)
    actions = actions.pin_memory()
    is_solved = is_solved.pin_memory()
    for batch_id, batch_indices in enumerate(train_indices_sampler(),
                                             start=batch_start):
        if batch_id >= updates:
            break
        model.train()
        batch_task_indices = task_indices[batch_indices]
        batch_observations = observations[batch_task_indices]
        batch_actions = actions[batch_indices].to(device, non_blocking=True)
        batch_is_solved = is_solved[batch_indices].to(device, non_blocking=True)

        optimizer.zero_grad()
        loss = model.ce_loss(model(batch_observations, batch_actions),
                             batch_is_solved)
        loss.backward()
        optimizer.step()
        if scheduler is not None:
            # PyTorch >= 1.1 expects the LR scheduler to step after the optimizer.
            scheduler.step()
        losses.append(loss.mean().item())
        if save_checkpoints_every > 0:
            if (batch_id + 1) % save_checkpoints_every == 0:
                fpath = os.path.join(output_dir, 'ckpt.%08d' % (batch_id + 1))
                logging.info('Saving: %s', fpath)
                torch.save(
                    dict(
                        model_kwargs=model_kwargs,
                        model=model.state_dict(),
                        optim=optimizer.state_dict(),
                        done_batches=batch_id + 1,
                        rng=rng.get_state(),
                        scheduler=scheduler and scheduler.state_dict(),
                    ), fpath)
        if (batch_id + 1) % eval_every == 0:
            print_eval_stats(batch_id)
        if (batch_id + 1) % report_every == 0:
            speed = report_every / (time.time() - last_time)
            last_time = time.time()
            logging.debug(
                'Iter: %s, examples: %d, mean loss: %f, speed: %.1f batch/sec,'
                ' lr: %f', batch_id + 1, (batch_id + 1) * train_batch_size,
                np.mean(losses[-report_every:]), speed, get_lr(optimizer))
    return model.cpu()
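
A minimal invocation sketch, assuming the public phyre package supplies the
simulation cache and task folds. Every literal below (tier, fold, paths,
hyperparameters) is a hypothetical placeholder rather than a recommended
configuration, and network_type/fusion_place must be values this module's
build_model actually accepts.

import phyre

# Hypothetical setup: the 100k-action cache for the ball tier and fold 0
# of a standard evaluation setup.
cache = phyre.get_default_100k_cache('ball')
train_ids, dev_ids, _ = phyre.get_fold('ball_cross_template', 0)

model = train(
    output_dir='/tmp/neural_agent',   # checkpoints are written here
    action_tier_name='ball',
    task_ids=train_ids,
    cache=cache,
    train_batch_size=64,
    learning_rate=3e-4,
    max_train_actions=None,           # use every cached action
    updates=100_000,
    negative_sampling_prob=1.0,       # keep all negatives; required with balancing
    save_checkpoints_every=10_000,
    fusion_place='last',              # assumption: accepted by build_model
    network_type='resnet18',          # assumption: accepted by build_model
    balance_classes=True,
    num_auccess_actions=10_000,
    eval_every=5_000,
    action_layers=1,
    action_hidden_size=256,
    cosine_scheduler=True,
    dev_tasks_ids=dev_ids,
)

Because each checkpoint bundles model_kwargs with the weights, a saved model
can later be rebuilt with build_model(**ckpt['model_kwargs']) followed by
model.load_state_dict(ckpt['model']).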