def build_dataloader()

in chatlearn/models/base_module.py [0:0]


    def build_dataloader(self,
                         all_datasets,
                         sample_per_episode,
                         collate_fn=None,
                         is_eval=False,
                         consumed_samples=0,
                         data_ratio=None,
                         shuffle=True,
                         drop_last=False,
                         data_rerank=True):
        """
        build the dataloader for the model

        Args:
            all_datasets: a list of torch.utils.data.Dataset objects; each one must
                support `len()`, which is used to size the sampler
            sample_per_episode: number of samples per episode, forwarded to the sampler
            collate_fn: set when loading from a map-style dataset (default: `None`)
            is_eval: set to `True` to build a dataloader for evaluation (default: `False`).
                In eval mode the sampler is built with `shuffle=False`, and
                `consumed_samples`, `data_ratio`, `shuffle`, `drop_last` and
                `data_rerank` are NOT forwarded to it.
            consumed_samples: consumed samples (default: `0`); training mode only
            data_ratio: ratio of samples for each dataset (default: `None`); training mode only
            shuffle: whether the sampler shuffles samples (default: `True`); training mode only
            drop_last: whether to drop last samples (default: `False`); training mode only.
                Passed to the sampler as `"drop"` when `True`, otherwise `"cycle"`.
            data_rerank: forwarded to the sampler (default: `True`); training mode only

        Returns:
            An RLHFDataLoader over `all_datasets` driven by the constructed
            MultiDatasetSampler, with `add_uid=True`.

        Model args read (with defaults):
            num_inference_per_prompt (1), vllm_prompt_key ("prompt")

        :meta private:
        """
        log_rank_0(
            f"Creating DataLoader... consumed_samples: {consumed_samples}, "
            f"data_ratio: {data_ratio}",
            self._logger
        )
        # Read the model config once; both keys are optional with fallbacks.
        model_args = self.model_args
        num_inference_per_prompt = model_args["num_inference_per_prompt"] \
            if "num_inference_per_prompt" in model_args else 1
        vllm_prompt_key = model_args["vllm_prompt_key"] \
            if "vllm_prompt_key" in model_args else "prompt"
        self._logger.info(f"====Data Rerank: {data_rerank}")
        # Both sampler variants need the per-dataset sizes.
        dataset_sizes = [len(dataset) for dataset in all_datasets]
        if is_eval:
            # Evaluation: deterministic order; resume/ratio/rerank options are ignored.
            batch_sampler = MultiDatasetSampler(
                dataset_sizes=dataset_sizes,
                sample_per_episode=sample_per_episode,
                shuffle=False,
                is_eval=True,
                data_parallel_rank=self.replica_id,
                data_parallel_size=self._num_replica
            )
        else:
            batch_sampler = MultiDatasetSampler(
                dataset_sizes=dataset_sizes,
                sample_per_episode=sample_per_episode,
                data_ratio=data_ratio,
                consumed_samples=consumed_samples,
                num_inference_per_prompt=num_inference_per_prompt,
                shuffle=shuffle,
                is_eval=False,
                data_parallel_rank=self.replica_id,
                data_parallel_size=self._num_replica,
                # The sampler takes a mode string rather than a bool.
                drop_last="drop" if drop_last else "cycle",
                data_rerank=data_rerank
            )
        # The replica id / replica count act as data-parallel rank / world size.
        return RLHFDataLoader(
            all_datasets,
            batch_sampler,
            collate_fn=collate_fn,
            add_uid=True,
            data_parallel_rank=self.replica_id,
            data_parallel_size=self._num_replica,
            num_inference_per_prompt=num_inference_per_prompt,
            vllm_prompt_key=vllm_prompt_key
        )