def train_single_proc()

in agents/neural_agent_contrastive.py [0:0]
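Per-rank training worker for the contrastive PHYRE agent: it sets up the distributed process group, builds the training set from the simulation cache, wraps the model in DistributedDataParallel, and runs the update loop with optional task-balanced sampling and an auxiliary sample-distance loss.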


def train_single_proc(
          rank,
          world_size,
          output_dir,
          action_tier_name,
          task_ids,
          cache,
          train_batch_size,
          learning_rate,
          max_train_actions,
          updates,
          negative_sampling_prob,
          save_checkpoints_every,
          fusion_place,
          network_type,
          balance_classes,
          num_auccess_actions,
          eval_every,
          action_layers,
          action_hidden_size,
          cosine_scheduler,
          n_samples_per_task,
          use_sample_distance_aux_loss,
          framewise_contrastive_n_frames,
          aux_loss_hyperparams,
          checkpoint_dir,
          tensorboard_dir="",
          dev_tasks_ids=None,
          debug=False,
          **excess_kwargs):
    logging.info(f"Starting traing rank={rank} world={world_size}")
    pathlib.Path(checkpoint_dir).mkdir(parents=True, exist_ok=True)
    mp_setup(rank, world_size)
    if torch.cuda.device_count() > 0:
        torch.cuda.set_device(rank)
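    # By convention the highest rank acts as the master process: only it writes
    # tensorboard logs, evaluates on the dev tasks, and saves checkpoints.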
    is_master = rank + 1 == world_size
    if not is_master:
        tensorboard_dir = None
        dev_tasks_ids = None

    if tensorboard_dir:
        from torch.utils.tensorboard import SummaryWriter
        logging.info("Tensorboard dir :" + tensorboard_dir)
        writer = SummaryWriter(tensorboard_dir)

    logging.info('Preprocessing train data')
    if debug:
        logging.warning("Debugging ON")
        logging.warning("Subsampling dataset")
        task_ids = task_ids[:10]
        if dev_tasks_ids:
            dev_tasks_ids = dev_tasks_ids[:10]
    training_data = cache.get_sample(task_ids, max_train_actions)
    task_indices, is_solved, actions, simulator, observations, positive_in_task, negative_in_task = \
        compact_simulation_data_to_trainset(action_tier_name, **training_data)
    logging.info('Creating eval subset from train')
    eval_train = create_balanced_eval_set(cache, simulator.task_ids,
                                          XE_EVAL_SIZE, action_tier_name)
    if dev_tasks_ids is not None:
        logging.info('Creating eval subset from dev')
        eval_dev = create_balanced_eval_set(cache, dev_tasks_ids, XE_EVAL_SIZE,
                                            action_tier_name)
    else:
        eval_dev = None

    aux_loss_eval = None
    aux_loss_eval_dev = None
    if use_sample_distance_aux_loss:
        logging.info("Creating eval set for auxiliary loss from train")
        aux_loss_eval = create_metric_eval_set(task_ids, action_tier_name,
                                               cache, AUX_LOSS_EVAL_TASKS,
                                               AUX_EVAL_ACTIONS_PER_TASK,
                                               simulator)
        if dev_tasks_ids is not None:
            logging.info("Creating eval set for auxiliary loss from dev")
            aux_loss_eval_dev = create_metric_eval_set(
                dev_tasks_ids,
                action_tier_name,
                cache,
                AUX_LOSS_EVAL_TASKS,
                AUX_EVAL_ACTIONS_PER_TASK,
                simulator=None)
    logging.info('Train set: size=%d, positive_ratio=%.2f%%', len(is_solved),
                 is_solved.float().mean().item() * 100)

    assert not balance_classes or (negative_sampling_prob == 1), (
        balance_classes, negative_sampling_prob)

    if torch.cuda.device_count() > 0:
        assert nets.DEVICE != torch.device("cpu")
        device = f"cuda:{rank}"
    else:
        device = "cpu"
    model_kwargs = dict(network_type=network_type,
                        action_space_dim=simulator.action_space_dim,
                        fusion_place=fusion_place,
                        action_hidden_size=action_hidden_size,
                        action_layers=action_layers)
    if use_sample_distance_aux_loss:
        model_kwargs.update(
            dict(
                embedding_dim=aux_loss_hyperparams["embedding_dim"],
                embeddor_type=aux_loss_hyperparams["embeddor_type"],
                repr_merging_method=None))

    if network_type == "framewise_contrastive":
        model_kwargs["n_frames"] = framewise_contrastive_n_frames

    model = build_model(**model_kwargs)
    model.to(device)
    logging.debug("net {} DistributedDataParallel".format(rank))

    if torch.cuda.device_count() > 0:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[rank],
            find_unused_parameters=True,
            broadcast_buffers=False)
    else:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            find_unused_parameters=True,
            broadcast_buffers=False)
    model.train()
    logging.info(model)
    logging.info(f"Model will use {device} to train")
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    if cosine_scheduler:
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                               T_max=updates)
    else:
        scheduler = None
    logging.info('Starting actual training for %d updates', updates)

    rng = np.random.RandomState(42 + rank)

    def task_balanced_sampler():
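        # Yield batches covering train_batch_size // n_samples_per_task distinct
        # tasks, with up to n_samples_per_task // 2 negative and positive actions
        # per task; tasks without recorded solutions contribute only negatives.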
        n_tasks_per_batch = train_batch_size // n_samples_per_task
        n_tasks_total = len(simulator.initial_scenes)
        n_samples_per_task_label = n_samples_per_task // 2
        while True:
            batch_task_indices = rng.choice(n_tasks_total,
                                            size=n_tasks_per_batch,
                                            replace=debug)
            indices = []
            for task_idx in batch_task_indices:
                indices.append(
                    rng.choice(negative_in_task[task_idx],
                               n_samples_per_task_label))
                try:
                    indices.append(
                        rng.choice(positive_in_task[task_idx],
                                   n_samples_per_task_label))
                except ValueError:  # No solutions for given template in dataset
                    pass
            indices = np.concatenate(indices)
            yield indices, batch_task_indices

    def train_indices_sampler():
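        # Yield flat index batches in one of three modes: roughly 50/50
        # class-balanced, negatives down-weighted by negative_sampling_prob,
        # or uniform sampling over the whole train set.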
        indices = np.arange(len(is_solved))
        if balance_classes:
            solved_mask = is_solved.numpy() > 0
            positive_indices = indices[solved_mask]
            negative_indices = indices[~solved_mask]
            positive_size = train_batch_size // 2
            while True:
                positives = rng.choice(positive_indices, size=positive_size)
                negatives = rng.choice(negative_indices,
                                       size=train_batch_size - positive_size)
                positive_size = train_batch_size - positive_size
                yield np.stack((positives, negatives), axis=1).reshape(-1), None
        elif negative_sampling_prob < 1:
            probs = (is_solved.numpy() * (1.0 - negative_sampling_prob) +
                     negative_sampling_prob)
            probs /= probs.sum()
            while True:
                yield rng.choice(indices, size=train_batch_size, p=probs), None
        else:
            while True:
                yield rng.choice(indices, size=train_batch_size), None


    last_checkpoint = get_latest_checkpoint(checkpoint_dir)
    batch_start = 0
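    # Resume from the most recent checkpoint, if any: restore model, optimizer,
    # scheduler, and RNG state, and continue the batch count where it stopped.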
    if last_checkpoint is not None:
        logging.info('Going to load from %s', last_checkpoint)
        last_checkpoint = torch.load(last_checkpoint)
        model.load_state_dict(last_checkpoint['model'])
        optimizer.load_state_dict(last_checkpoint['optim'])
        rng.set_state(last_checkpoint['rng'])
        # create a difference between rng across processes
        for i in range(rank):
            rng.random()
        batch_start = last_checkpoint['done_batches']
        if scheduler is not None:
            scheduler.load_state_dict(last_checkpoint['scheduler'])

    def print_eval_stats(batch_id):
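        # Detailed evaluation is currently disabled; this stub keeps the call
        # sites below intact so it can be re-enabled later.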
        return


    report_every = 125 if not debug else 5
    logging.info('Report every %d; eval every %d', report_every, eval_every)
    if save_checkpoints_every > eval_every:
        save_checkpoints_every -= save_checkpoints_every % eval_every

    if batch_start <= 1:
        print_eval_stats(0)

    losses = []
    aux_losses = []
    ce_losses = []
#    is_solved = is_solved.pin_memory()

    # The simulator is re-created here so that rollouts can be generated on the
    # fly inside the training loop below.
    simulator = phyre.initialize_simulator(task_ids, action_tier_name)

    if n_samples_per_task == 0:
        sampler = train_indices_sampler()
    else:
        assert (n_samples_per_task % 2 == 0 and balance_classes)
        sampler = task_balanced_sampler()

    start = time.time()
    x = 0  # batches processed since `start` (used for the speed printout)
    n_examples = 0  # examples processed since `start`
    tt = TimingCtx()
    last_time = time.time()
    for batch_id, (batch_indices,
                   unique_task_indices) in enumerate(sampler,
                                                     start=batch_start):
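        # One update per batch: simulate rollouts for the sampled actions, run
        # the model, combine cross-entropy with the optional auxiliary loss,
        # and take an optimizer step.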
        if batch_id >= updates:
            break
        if scheduler is not None:
            scheduler.step()
        x += 1
        n_examples += len(batch_indices)
        if batch_id < 32 or (batch_id & (batch_id + 1)) == 0:
            elapsed = time.time() - start
            print("Speed %s: batches=%.2f, eps=%.2f %s" %
                  (rank, x / elapsed, n_examples / elapsed,
                   {k: v / x for k, v in tt.items()}))
        model.train()
        batch_task_indices = task_indices[batch_indices]
        batch_actions = actions[batch_indices]
        batch_is_solved = is_solved[batch_indices].to(device, non_blocking=True)

        tt.start("sim")

        # Simulate each (task, action) pair to obtain a short frame sequence:
        # the first frame is the observation, the remaining frames are used as
        # contrastive targets.
        batch_videos = []
        for idx, action in zip(batch_task_indices.tolist(), batch_actions.numpy()):
            # print(idx, action)
            simulation = simulator.simulate_action(
                idx,
                action,
                need_images=True,
                get_random_images=framewise_contrastive_n_frames,
            )

            batch_videos.append(simulation.images)

        batch_videos = np.stack(batch_videos, 0)

        tt.start("rest")
        batch_videos = torch.from_numpy(batch_videos).to(device)
        batch_observations = batch_videos[:, 0]
        batch_contrastive_targets = batch_videos[:, 1:]

        tt.start("model")
        if use_sample_distance_aux_loss:
            logits, embeddings1, embeddings2 = model(batch_observations, batch_contrastive_targets)
            aux_loss = sample_distance_loss(model, embeddings1, embeddings2, aux_loss_hyperparams, rng)

        else:
            logits = model(batch_observations)
            aux_loss = torch.tensor([0.0]).to(device)

        # TODO: add tensorboard, detailed evaluator and parallel simulator
        classification_loss = nets.FramewiseResnetModel.ce_loss(logits, batch_is_solved).mean()
        loss = classification_loss + aux_loss * aux_loss_hyperparams["weight"]
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.mean().item())
        aux_losses.append(aux_loss.mean().item())
        ce_losses.append(classification_loss.mean().item())
        tt.start("rest")
        if save_checkpoints_every > 0 and is_master:
            is_last_update = (batch_id + 1) == updates
            if (batch_id + 1) % save_checkpoints_every == 0 or is_last_update:
                fpath = os.path.join(checkpoint_dir, 'ckpt.%08d' % (batch_id + 1))
                logging.info('Saving: %s', fpath)
                torch.save(
                    dict(
                        model_kwargs=model_kwargs,
                        model=model.state_dict(),
                        optim=optimizer.state_dict(),
                        done_batches=batch_id + 1,
                        rng=rng.get_state(),
                        scheduler=scheduler and scheduler.state_dict(),
                    ), fpath)
        if (batch_id + 1) % eval_every == 0 and is_master:
            print_eval_stats(batch_id)
        if (batch_id + 1) % report_every == 0 and is_master:
            speed = report_every / (time.time() - last_time)
            last_time = time.time()
            logging.debug(
                'Iter: %s, examples: %d, mean loss: %f, mean ce: %f, mean aux: %f, '
                'speed: %.1f batch/sec, lr: %f', batch_id + 1,
                (batch_id + 1) * train_batch_size,
                np.mean(losses[-report_every:]),
                np.mean(ce_losses[-report_every:]),
                np.mean(aux_losses[-report_every:]), speed, get_lr(optimizer))
    return model.cpu()
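
For reference, a minimal launch sketch for this per-rank entry point. The `launch` helper and its keyword arguments are illustrative assumptions rather than part of this file; it relies on `mp_setup` inside `train_single_proc` to initialize the process group for the given rank and world size.

import functools

import torch.multiprocessing as mp


def launch(world_size, **train_kwargs):
    # Bind everything except `rank`; mp.spawn supplies the rank as the first
    # positional argument of the worker function.
    worker = functools.partial(train_single_proc,
                               world_size=world_size,
                               **train_kwargs)
    mp.spawn(worker, nprocs=world_size, join=True)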