# Excerpt: `in_contexter()` from vision/m4/evaluation/evaluators/in_contexter.py [0:0]


def in_contexter(task, accelerator, model, args):
    """Evaluate ``model`` on an in-context (zero-/few-shot) ``task`` and return the metric score.

    Steps:
      1. When shots are not selected at random and ``num_shots != 0``, load a vision encoder so that
         ``_get_datasets`` can produce ``support_dataset_vision_encoder_embeddings``.
      2. Build the support/query datasets, on the main process first and then on the other processes.
      3. Build the evaluation dataloader; stop there if ``args.hparams.only_load_datasets`` is set.
      4. Feed every mini-batch to ``task.add_batch_metric`` and compute the final metric.

    Args:
        task: Evaluation task. Its class name selects how mini-batches are scored (see
            ``scoring_markers`` / ``generation_markers`` below). It must also expose ``metric_name``,
            ``metric_kwargs`` and ``add_batch_metric``.
        accelerator: Handle of the distributed run. Its ``is_main_process``, ``wait_for_everyone``,
            ``num_processes`` and ``process_index`` are used; presumably an ``accelerate.Accelerator``.
        model: Model forwarded to ``task.add_batch_metric``.
        args: Run configuration. Reads ``args.tasks`` (``in_context_params``,
            ``text_generation_params``, ``save_generations``) and ``args.hparams`` (``timeout``,
            ``batch_size_per_gpu``, ``only_load_datasets``).

    Returns:
        The result of ``metric.compute()`` on this process, or ``None`` when
        ``args.hparams.only_load_datasets`` is set.

    Raises:
        ValueError: If the task class name matches none of the supported task families. This is
            checked once, before any mini-batch is scored.
    """
    vision_encoder, vision_encoder_processor, dummy_accelerator = None, None, None
    in_context_params = args.tasks.in_context_params
    if (in_context_params.shot_selection_mode != ShotSelectionMode.random) and (in_context_params.num_shots != 0):
        # Non-random shot selection needs a vision encoder: `_get_datasets` uses it to produce
        # `support_dataset_vision_encoder_embeddings`.
        vision_encoder, vision_encoder_processor = load_vision_encoder(args)

        # A throwaway Accelerator prepares the encoder, which is then unwrapped. The evaluation
        # `accelerator` is left untouched. The dummy dataloader is prepared alongside the encoder
        # and never used afterwards.
        kwargs_handlers = [InitProcessGroupKwargs(timeout=timedelta(seconds=args.hparams.timeout))]
        rngs_types = ["torch", "cuda", "generator"] if torch.cuda.is_available() else ["torch", "generator"]
        dummy_accelerator = Accelerator(rng_types=rngs_types, kwargs_handlers=kwargs_handlers)
        dummy_dataloader = torch.utils.data.DataLoader([0] * 10, batch_size=args.hparams.batch_size_per_gpu)
        vision_encoder, dummy_dataloader = dummy_accelerator.prepare(vision_encoder, dummy_dataloader)
        vision_encoder = dummy_accelerator.unwrap_model(vision_encoder)

    # Build the datasets on the main process first, then on the remaining processes. Presumably this
    # lets the main process prepare/cache them before the others read them.
    # NOTE: if computing the embeddings hangs forever in few-shot mode (DeepSpeed, several processes),
    # either remove the `if accelerator.is_main_process` guard, or first compute the embeddings with a
    # setting that works (e.g. a single process, or plain accelerate).
    if accelerator.is_main_process:
        support_dataset, query_dataset, support_dataset_vision_encoder_embeddings = _get_datasets(
            task, args, vision_encoder, vision_encoder_processor
        )
    accelerator.wait_for_everyone()

    if not accelerator.is_main_process:
        support_dataset, query_dataset, support_dataset_vision_encoder_embeddings = _get_datasets(
            task, args, vision_encoder, vision_encoder_processor
        )
    accelerator.wait_for_everyone()
    logger.warning(f"support_dataset: {support_dataset}")
    logger.warning(f"query_dataset: {query_dataset}")

    # The encoder is only needed while building the datasets; drop the references before evaluating.
    del vision_encoder
    del vision_encoder_processor
    del dummy_accelerator

    show_gpu_mem_util(args)

    data_loader = build_dataloader(
        task, model, args, support_dataset, query_dataset, support_dataset_vision_encoder_embeddings, accelerator
    )

    if args.hparams.only_load_datasets:
        return

    # Decide once how this task is scored. Families are matched as substrings of the task class name.
    task_class_name = task.__class__.__name__
    # Families scored from the model and the mini-batch alone.
    scoring_markers = (
        "ClassificationInContext",
        "ClassificationVQAInContext",
        "PerplexityInContext",
        "ImageCaptionMatching",
    )
    # Families that generate text, so they also receive the generation hyper-parameters.
    generation_markers = ("OpenEndedVQAInContext", "ImageCaptioningInContext")
    if any(marker in task_class_name for marker in scoring_markers):
        extra_kwargs = {}
    elif any(marker in task_class_name for marker in generation_markers):
        generation_params = args.tasks.text_generation_params
        extra_kwargs = {
            "num_beams": generation_params.num_beams,
            "no_repeat_ngram_size": generation_params.no_repeat_ngram_size,
            "max_new_tokens": generation_params.max_new_tokens,
        }
    else:
        # Build the list from the markers themselves so the message cannot drift from the dispatch.
        raise ValueError(
            f"Task class ({task_class_name}) is not supported. Expected it to be among"
            f" {list(scoring_markers + generation_markers)}."
        )

    metric_class = getattr(custom_metrics, task.metric_name)
    metric_kwargs = task.metric_kwargs if task.metric_kwargs is not None else {}
    # All processes use the id returned by `broadcast_object_list`. Presumably this lets the metric
    # combine the per-process results under one experiment.
    experiment_id = broadcast_object_list([str(uuid.uuid4())])[0]
    metric = metric_class(
        experiment_id=experiment_id,
        num_process=accelerator.num_processes,
        process_id=accelerator.process_index,
        save_generations=args.tasks.save_generations,
        **metric_kwargs,
    )
    for batch in tqdm(data_loader, desc="Compute scores:"):
        # The data collator may enlarge batches (mostly useful for classification tasks), so split
        # them back into chunks of `batch_size_per_gpu`.
        mini_batches = split_batch(batch, chunk_size=args.hparams.batch_size_per_gpu)
        show_gpu_mem_util(args)
        for mini_batch in mini_batches:
            # A dict merge (not keyword arguments) so mini-batch keys override earlier ones, as before.
            kwargs = {"model": model, **extra_kwargs, **mini_batch}
            # NOTE(review): this barrier runs once per mini-batch, so every process must produce the same
            # number of mini-batches or the run will hang. Confirm that `split_batch` guarantees this.
            accelerator.wait_for_everyone()
            metric = task.add_batch_metric(metric, **kwargs)

    # Compute on the non-main processes first and on the main process last. Trick suggested here:
    # https://huggingface.slack.com/archives/C02UAKD75L7/p1664475037694469?thread_ts=1664461500.952079&cid=C02UAKD75L7
    if not accelerator.is_main_process:
        score = metric.compute()
    accelerator.wait_for_everyone()

    if accelerator.is_main_process:
        score = metric.compute()
    return score