in chatlearn/models/base_module.py
def build_dataloader(self,
                     all_datasets,
                     sample_per_episode,
                     collate_fn=None,
                     is_eval=False,
                     consumed_samples=0,
                     data_ratio=None,
                     shuffle=True,
                     drop_last=False,
                     data_rerank=True):
"""
build the dataloader for the model
Args:
all_datasets: a list of torch.utils.data.Dataset objects
batch_size: how many samples per batch to load
collate_fn: set when loading from an map-style dataset (defulat: `None`)
is_eval: set to `True` to build a dataloader for evaluation (default: `False`)
consumed_samples: consumed samples (default: `0`)
data_ratio: ratio of samples for each dataset (default: `None`)
drop_last: whether to drop last samples (default: `False`)
:meta private:
"""
    log_rank_0(
        f"Creating DataLoader... consumed_samples: {consumed_samples}, "
        f"data_ratio: {data_ratio}",
        self._logger
    )
if "num_inference_per_prompt" in self.model_args:
num_inference_per_prompt = self.model_args["num_inference_per_prompt"]
else:
num_inference_per_prompt = 1
vllm_prompt_key = self.model_args["vllm_prompt_key"] \
if "vllm_prompt_key" in self.model_args else "prompt"
self._logger.info(f"====Data Rerank: {data_rerank}")
    if is_eval:
        # Evaluation: iterate the datasets without shuffling.
        batch_sampler = MultiDatasetSampler(
            dataset_sizes=[len(dataset) for dataset in all_datasets],
            sample_per_episode=sample_per_episode,
            shuffle=False,
            is_eval=True,
            data_parallel_rank=self.replica_id,
            data_parallel_size=self._num_replica
        )
    else:
        # Training: mix datasets by data_ratio and resume from consumed_samples.
        batch_sampler = MultiDatasetSampler(
            dataset_sizes=[len(dataset) for dataset in all_datasets],
            sample_per_episode=sample_per_episode,
            data_ratio=data_ratio,
            consumed_samples=consumed_samples,
            num_inference_per_prompt=num_inference_per_prompt,
            shuffle=shuffle,
            is_eval=False,
            data_parallel_rank=self.replica_id,
            data_parallel_size=self._num_replica,
            drop_last="drop" if drop_last else "cycle",
            data_rerank=data_rerank
        )
    return RLHFDataLoader(
        all_datasets,
        batch_sampler,
        collate_fn=collate_fn,
        add_uid=True,
        data_parallel_rank=self.replica_id,
        data_parallel_size=self._num_replica,
        num_inference_per_prompt=num_inference_per_prompt,
        vllm_prompt_key=vllm_prompt_key
    )
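
A minimal usage sketch follows. It assumes `module` is an initialized ChatLearn module that exposes build_dataloader; `PromptDataset` and the `[3, 1]` data_ratio format are illustrative assumptions, not part of the ChatLearn API.

from torch.utils.data import Dataset

class PromptDataset(Dataset):
    """Hypothetical map-style dataset yielding {"prompt": str} records."""
    def __init__(self, prompts):
        self.prompts = prompts

    def __len__(self):
        return len(self.prompts)

    def __getitem__(self, idx):
        return {"prompt": self.prompts[idx]}

math_set = PromptDataset(["1 + 1 = ?", "Differentiate x**2."])
code_set = PromptDataset(["Reverse a linked list."])

# Hypothetical ratio: draw 3 samples from math_set for every 1 from code_set.
loader = module.build_dataloader(
    [math_set, code_set],
    sample_per_episode=8,
    data_ratio=[3, 1],
    consumed_samples=0,   # start fresh rather than resuming
    shuffle=True,
)
for batch in loader:
    ...  # feed each episode batch to rollout / training

Because `vllm_prompt_key` defaults to "prompt", records only need a "prompt" field unless the model args override the key.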