def _update_dataloader_based_on_training_stages()

in src/nanotron/trainer.py
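
Selects and installs the dataloader for the current training stage: with no data_stages configured it falls back to the single dataloader provided, otherwise it switches stages at their start_training_step, clears the previous stage's dataloader from memory, handles resuming mid-stage from a checkpoint, and wraps the chosen dataloader with sanity_check_dataloader.

A minimal sketch of the inputs the staged path expects (DatasetStageArgs fields other than name and start_training_step are omitted for brevity; the dataloader values and the builder are hypothetical):

    data_stages = [
        DatasetStageArgs(name="general", start_training_step=1),
        DatasetStageArgs(name="annealing", start_training_step=1000),
    ]
    # `dataloaders` is keyed by stage name; a value may also be a zero-argument
    # callable that lazily builds the DataLoader (see the `callable` check below).
    dataloaders = {
        "general": general_dataloader,
        "annealing": lambda: build_annealing_dataloader(),  # hypothetical builder
    }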


    def _update_dataloader_based_on_training_stages(self, dataloaders: Union[Dict[str, DataLoader], DataLoader]):
        from collections.abc import Generator

        if not hasattr(self.config, "data_stages") or self.config.data_stages is None:
            if self.current_dataloader is None:
                if isinstance(dataloaders, tuple):
                    dataloader = dataloaders[0]
                else:
                    dataloader = dataloaders
                self.current_dataloader = sanity_check_dataloader(
                    dataloader=dataloader, parallel_context=self.parallel_context, config=self.config
                )
                self.current_base_dl = dataloader
            return
        elif isinstance(dataloaders, Generator):
            # TODO(xrsrke): this is a hacky way to handle DoReMi's dataloader
            # remove this in the next PR
            self.current_dataloader = dataloaders
            return

        assert len(dataloaders) > 0, "No dataloaders provided"

        def clear_dataloader_from_memory(dataloader: DataLoader, stage_name: str):
            import gc

            log_rank(
                f"[Training Stage: {stage_name}] Clearing the previous training stage's dataloader and datasets from memory",
                logger=logger,
                level=logging.INFO,
            )
            self.current_base_dl = None

            # NOTE: Drop references to the dataset and samplers so the previous
            # stage's objects can be garbage-collected
            del dataloader.dataset
            del dataloader.sampler
            del dataloader.batch_sampler

            gc.collect()
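            # gc.collect() forces a collection pass so the dataset memory is
            # reclaimed promptly rather than waiting for the next automatic GC.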

        dataloader = None

        def find_stage_idx_to_resume():
            # Scan stages from the latest start_training_step downwards and return
            # the original index of the first stage that has already started, or
            # None if training has not reached any stage yet.
            reversed_data_stages = sorted(self.config.data_stages, key=lambda x: x.start_training_step, reverse=True)
            for idx, stage in enumerate(reversed_data_stages):
                if self.iteration_step >= stage.start_training_step:
                    return len(self.config.data_stages) - idx - 1
            return None
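        # Worked example (illustrative numbers, not from the source): with stages
        # starting at steps 1, 1000, and 5000 and iteration_step == 3000, the
        # descending scan visits start steps 5000, 1000, 1 and stops at 1000,
        # i.e. original index 1. This relies on config.data_stages being ordered
        # by ascending start_training_step.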

        stage_idx_to_resume = find_stage_idx_to_resume()

        for stage_idx, stage in enumerate(self.config.data_stages):
            if stage_idx < self.metadata.last_stage_idx:
                continue

            stage = cast(DatasetStageArgs, stage)

            is_resume_from_training = self.current_dataloader is None and stage_idx_to_resume == stage_idx
            if (stage.start_training_step == self.iteration_step) or is_resume_from_training:
                if self.current_dataloader is not None:
                    prev_stage_name = self.config.data_stages[stage_idx - 1].name
                    prev_dataloader = dataloaders[prev_stage_name]

                    if isinstance(prev_dataloader, DataLoader):
                        # NOTE: we don't need to clear a dummy data generator from memory
                        clear_dataloader_from_memory(prev_dataloader, stage_name=stage.name)

                self.metadata.last_stage_idx = stage_idx

                if is_resume_from_training:
                    remaining_train_steps = compute_remain_train_steps_of_a_data_stage_from_ckp(
                        stage, self.config, self.metadata
                    )
                    (
                        consumed_train_samples,
                        consumed_tokens_per_dataset_folder,
                    ) = get_consumed_train_samples_of_a_data_stage_from_ckp(stage, self.metadata)
                    log_rank(
                        f"Resuming training from stage {stage.name}: it has consumed {consumed_train_samples} samples and has {remaining_train_steps} remaining train steps"
                        f"\nConsumed tokens per dataset folder: {pformat(consumed_tokens_per_dataset_folder)}",
                        logger=logger,
                        level=logging.INFO,
                        rank=0,
                    )

                dataloader = dataloaders[stage.name]
                # NOTE: if a dataloader is lazy initialized, we need to call it to initialize it
                dataloader = dataloader() if callable(dataloader) else dataloader
                break

        if dataloader is not None:
            self.current_dataloader = sanity_check_dataloader(
                dataloader=dataloader, parallel_context=self.parallel_context, config=self.config
            )
            self.current_base_dl = dataloader
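
A minimal sketch of the call site (the loop shape is an assumption, not copied from nanotron): the method runs once per iteration before a batch is drawn, so a stage switch takes effect exactly at its start_training_step; sanity_check_dataloader is assumed to return an iterator consumable with next():

    # Hypothetical trainer loop; only the method call and the two attributes it
    # sets (current_dataloader, current_base_dl) come from the source above.
    while self.iteration_step <= total_train_steps:  # total_train_steps is hypothetical
        self._update_dataloader_based_on_training_stages(dataloaders)
        batch = next(self.current_dataloader)
        self.training_step(batch)  # hypothetical step function
        self.iteration_step += 1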