def _prepare_megatron_lm()

in src/accelerate/accelerator.py


    def _prepare_megatron_lm(self, *args):
        megatron_lm_plugin = self.state.megatron_lm_plugin
        micro_batch_size = None
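        # Infer the micro batch size: from the batch sizes of the dataloaders passed to `prepare()`
        # when a non-Megatron dataset is used, otherwise from the `MegatronLMDummyDataLoader` arguments.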
        if not megatron_lm_plugin.megatron_dataset_flag:
            batch_sizes = [obj.batch_size for obj in args if hasattr(obj, "batch_size")]
            if len(batch_sizes) == 0:
                raise ValueError(
                    "You must specify a training or evaluation dataloader in `accelerate.prepare()` when using Megatron-LM."
                )

            micro_batch_size = min(batch_sizes) if megatron_lm_plugin.is_train_batch_min else max(batch_sizes)
            if len(batch_sizes) > 1:
                logger.info(
                    "Since you passed both train and evaluation dataloader, `is_train_batch_min` (here "
                    f"{megatron_lm_plugin.is_train_batch_min} will decide the `train_batch_size` ({micro_batch_size})."
                )
        else:
            for obj in args:
                if isinstance(obj, MegatronLMDummyDataLoader):
                    micro_batch_size = obj.dataset_args["micro_batch_size"]
                    break
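        # The data-parallel degree is whatever remains of the world size after the tensor- and
        # pipeline-parallel splits.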
        if micro_batch_size is not None:
            dp_degree = self.num_processes // (megatron_lm_plugin.tp_degree * megatron_lm_plugin.pp_degree)
            megatron_lm_plugin.set_training_args(micro_batch_size, dp_degree)
        else:
            raise ValueError(
                "No dataloader was passed to `accelerate.prepare()`, so the `data_parallel_size`, "
                "`micro_batch_size`, and `global_batch_size` Megatron-LM parameters cannot be set."
            )
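        # Collect the model, optimizer, scheduler and one sample batch from the inputs; the sample
        # batch is only used below to derive the network size arguments.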
        model = None
        optimizer = None
        scheduler = None
        batch_data = None
        for obj in args:
            if isinstance(obj, torch.utils.data.DataLoader) and batch_data is None:
                batch_data = next(iter(obj))
            elif isinstance(obj, torch.nn.Module):
                model = obj
            elif isinstance(obj, torch.optim.Optimizer):
                optimizer = obj
            elif isinstance(obj, (LRScheduler, MegatronLMDummyScheduler)):
                scheduler = obj

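        # Hand the collected objects to the plugin so it can fill in the corresponding Megatron-LM arguments.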
        if model is not None:
            megatron_lm_plugin.set_network_size_args(model, batch_data)
        if optimizer is not None:
            megatron_lm_plugin.set_optimizer_type(optimizer)
        if scheduler is not None:
            if not isinstance(scheduler, MegatronLMDummyScheduler):
                raise ValueError(
                    "You can't use a custom scheduler with Megatron-LM. Please use the `accelerate.utils.MegatronLMDummyScheduler` instead."
                )
            megatron_lm_plugin.set_scheduler_args(scheduler)

        # Initialize Megatron-LM with the argument defaults collected by the plugin above.
        megatron_lm_initialize(self, args_defaults=megatron_lm_plugin.megatron_lm_default_args)

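        # Megatron-LM builds its own model, optimizer and scheduler, which replace the objects collected above.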
        (model, optimizer, scheduler) = megatron_lm_prepare_model_optimizer_scheduler(self)
        self.wait_for_everyone()

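        # Re-create the dataloaders: regular PyTorch dataloaders are prepared one by one, while a
        # `MegatronLMDummyDataLoader` expands into Megatron's own dataloaders, handed back one per position.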
        counter = 0
        result = []
        for obj in args:
            if isinstance(obj, torch.utils.data.DataLoader):
                result.append(megatron_lm_prepare_data_loader(self, obj))
                counter += 1
            elif isinstance(obj, MegatronLMDummyDataLoader):
                if counter == 0:
                    obj.set_megatron_data_args()
                    dataloaders = megatron_lm_prepare_data_loader(self, obj)
                result.append(dataloaders[counter])
                counter += 1
            else:
                result.append(obj)

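        # Wrap the Megatron-LM objects so they expose the interfaces the rest of Accelerate expects.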
        if model is not None:
            model = MegatronEngine(self, model, optimizer, scheduler)
        if optimizer is not None:
            optimizer = MegatronLMOptimizerWrapper(optimizer)
        if scheduler is not None:
            scheduler = MegatronLMSchedulerWrapper(scheduler, optimizer)

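        # Substitute the wrapped objects back into their original positions so the return order
        # mirrors the input order.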
        for i in range(len(result)):
            if isinstance(result[i], torch.nn.Module):
                result[i] = model
            elif isinstance(result[i], torch.optim.Optimizer):
                result[i] = optimizer
            elif isinstance(result[i], MegatronLMDummyScheduler):
                result[i] = scheduler

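        # Register the prepared objects with the accelerator; Megatron-LM supports only a single
        # model per `Accelerator` instance.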
        if model is not None:
            self._models.append(model)
            if len(self._models) > 1:
                raise AssertionError(
                    "You can't use same `Accelerator()` instance with multiple models when using Megatron-LM"
                )
        if optimizer is not None:
            self._optimizers.append(optimizer)
        if scheduler is not None:
            self._schedulers.append(scheduler)

        return tuple(result)
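
For orientation, here is a minimal calling sketch that would route through this method. It is only a sketch: it assumes Megatron-LM has been enabled through `accelerate config` (or the corresponding environment variables) and a distributed launch via `accelerate launch`; the toy dataset, the `gpt2` checkpoint and the `MegatronLMDummyScheduler` keyword names are illustrative, following the Accelerate Megatron-LM guide rather than anything in this file.

    import torch
    from transformers import AutoModelForCausalLM
    from accelerate import Accelerator
    from accelerate.utils import MegatronLMDummyScheduler

    # Toy tokenized dataset; a real run would use a properly tokenized corpus.
    dataset = [
        {"input_ids": torch.randint(0, 50257, (128,)), "labels": torch.randint(0, 50257, (128,))}
        for _ in range(64)
    ]
    train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=4)

    model = AutoModelForCausalLM.from_pretrained("gpt2")
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    # Assumes Megatron-LM was enabled via `accelerate config`; otherwise `prepare()` never
    # dispatches to `_prepare_megatron_lm`.
    accelerator = Accelerator()

    # Custom schedulers are rejected above, so the dummy scheduler placeholder is required.
    lr_scheduler = MegatronLMDummyScheduler(
        optimizer=optimizer,
        total_num_steps=1000,    # keyword names as in the Accelerate Megatron-LM guide
        warmup_num_steps=100,
    )

    # A single `prepare()` call routes everything through `_prepare_megatron_lm`, which derives
    # the micro/global batch sizes, data-parallel degree and network size arguments.
    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler
    )

The single call matters: the micro batch size and data-parallel degree are derived only from the objects passed together to `prepare()`.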