# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""data processing."""
import math
import random
import copy
import os
from typing import List, Dict
import ray
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import default_collate
from chatlearn.utils import future
from chatlearn.utils.constant import CHATLEARN_REGROUP_TAG
from chatlearn.utils.utils import regroup_by_concat_along_batch, map_reduce_metrics
def get_iter_keys(data):
"""
get iterator keys
"""
if isinstance(data, (list, tuple)):
return range(len(data))
if isinstance(data, dict):
return data.keys()
raise ValueError(f"Only list or dict type is accepted, but got type {type(data)}")
def create_from_type(data):
"""
create collection from data type
"""
if isinstance(data, (list, tuple)):
return [None] * len(data)
return type(data)()
def batching(tensors, padding_value=0.0, padding_type="right"):
"""
batch tensors
"""
if isinstance(tensors[0], torch.Tensor):
if tensors[0].dim() == 0:
return torch.stack(tensors)
if padding_type == "right":
return pad_sequence(tensors, batch_first=True, padding_value=padding_value)
return pad_sequence([elem.flip(0) for elem in tensors],
padding_value=padding_value,
batch_first=True).flip(1)
batch = create_from_type(tensors[0])
batch_size = len(tensors)
for key in get_iter_keys(tensors[0]):
        pad = padding_value.get(key, 0.0) if isinstance(padding_value, dict) else padding_value
ptype = padding_type.get(key, "right") if isinstance(padding_type, dict) else padding_type
batched = [tensors[j][key] for j in range(batch_size)]
if isinstance(batched[0], torch.Tensor):
batched = batching(batched, pad, ptype)
batch[key] = batched
return batch
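# Illustrative example of ``batching`` (comment only, not part of the module API):
# variable-length 1-D tensors are right-padded to a common length and stacked, e.g.
#   batching([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
#   -> tensor([[1, 2, 3],
#              [4, 5, 0]])
# Dict samples are batched key by key; ``padding_value`` / ``padding_type`` may themselves be
# dicts keyed by field name to control padding per field.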
def split_batch(batch):
"""
split batch into samples
"""
assert isinstance(batch, (list, tuple, dict)), \
"batch type {} is not supported".format(type(batch))
samples = []
if isinstance(batch, (list, tuple)):
batch_size = len(batch[0])
keys = range(len(batch))
else:
batch_size = len(next(iter(batch.values())))
keys = batch.keys()
for batch_index in range(batch_size):
if isinstance(batch, (list, tuple)):
sample = [batch[key][batch_index] for key in keys]
else:
sample = {key: batch[key][batch_index] for key in keys}
samples.append(sample)
return samples
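# Illustrative example of ``split_batch``, the per-sample inverse of ``batching``
# (the field names below are placeholders):
#   split_batch({"ids": torch.tensor([[1, 2], [3, 4]]), "len": torch.tensor([2, 2])})
#   -> [{"ids": tensor([1, 2]), "len": tensor(2)},
#       {"ids": tensor([3, 4]), "len": tensor(2)}]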
def batch_shuffle(data, batch_size):
    """
    Shuffle ``data`` at batch granularity: samples are grouped into consecutive batches of
    ``batch_size``, the batch order is shuffled, and any trailing partial batch is dropped.
    """
num_batches = len(data) // batch_size
batches = [data[batch_size*i:batch_size*(i+1)] for i in range(num_batches)]
random.shuffle(batches)
shuffled_data = [item for batch in batches for item in batch]
return shuffled_data
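# Illustrative example of ``batch_shuffle``: with batch_size=2, [s0, s1, s2, s3] may become
# [s2, s3, s0, s1] but never [s1, s0, s3, s2]; samples stay contiguous within their batch and
# only the order of whole batches is randomized.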
@ray.remote
class StreamDataset:
"""dataset built from queues"""
def __init__(self, data_loader_type, micro_batch_size, padding_config=None, max_relay_episode=0, relay_episode_offset=0, global_batch_size=-1):
"""
Args:
data_loader_type: fixed or dynamic
"""
if data_loader_type == "fixed":
self._dynamic_dataset = False
else:
self._dynamic_dataset = True
self.batch_size = micro_batch_size
        # guard against a ``None`` padding_config before building the per-key padding maps
        self._padding_config = padding_config if padding_config is not None else {}
        self._padding_value = {key: value["padding_value"]
                               for key, value in self._padding_config.items()}
        self._padding_type = {key: value["padding_type"]
                              for key, value in self._padding_config.items()}
if max_relay_episode < 0:
max_relay_episode = math.inf
self._max_relay_episode = max_relay_episode
self._relay_episode_offset = relay_episode_offset
self._episode_relay_buffers = []
self.relay_sample_manager = None
        # ChunkFlow: number of micro-batches to prefetch before sorting them by sample length
        self.prefetch_batch_cnt = micro_batch_size if global_batch_size < 0 else global_batch_size
def shuffle(self, batch_size=None):
"""
shuffle relay buffer
"""
self.relay_buffer.shuffle(batch_size)
        self.iter = iter(self)
self._has_next = True
def __iter__(self):
if self._dynamic_dataset and not self._read_data_complete:
return self.iter_dynamic()
return self.iter_fixed()
def _get_batch(self, start_index):
end_index = min(start_index + self.batch_size, len(self.relay_buffer))
data_to_batch = self.relay_buffer.get_samples(start_index, end_index)
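        # if the tail of the buffer holds fewer than batch_size samples, wrap around and top
        # the micro-batch up with samples from the head of the buffer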
if len(data_to_batch) < self.batch_size:
data_to_batch += self.relay_buffer.get_samples(0, self.batch_size - len(data_to_batch))
batched_data = batching(data_to_batch, self._padding_value, self._padding_type)
return batched_data
def iter_fixed(self):
"""
iteration with fixed batch size
"""
produce_index = 0
batch_count = 0
prefetched_batch_list = []
while produce_index < self._total_samples:
            # top up the relay buffer from the queue if it does not yet hold enough samples
if len(self.relay_buffer) < self._total_samples:
while len(self.relay_buffer) < self._total_samples and \
(len(self.relay_buffer) - produce_index) < self.batch_size:
self.relay_buffer.add_raw_batch()
prefetched_batch_list.append(self._get_batch(produce_index))
if len(prefetched_batch_list) == self.prefetch_batch_cnt:
                # ChunkFlow: sort prefetched batches by sample length to balance work
                # across data parallel ranks
                # TODO: remove the hardcoded "response_ids" key used as the sample-length proxy
if "response_ids" in prefetched_batch_list[0].keys():
prefetched_batch_list.sort(key=lambda x: len(x["response_ids"][0]))
for batched_data in prefetched_batch_list:
yield batched_data
batch_count += 1
prefetched_batch_list.clear()
produce_index += self.batch_size
if len(prefetched_batch_list) != 0:
for batched_data in prefetched_batch_list:
yield batched_data
batch_count += 1
assert batch_count == math.ceil(self._total_samples / self.batch_size)
assert produce_index >= len(self.relay_buffer), \
f"produce_index: {produce_index} < len(self.relay_buffer) {len(self.relay_buffer)}"
def iter_dynamic(self):
"""
iteration with dynamic batch size
"""
produce_index = 0
        if self._read_data_complete:
            # a bare ``return self.iter_fixed()`` inside a generator would yield nothing,
            # so delegate to the fixed-size iterator instead
            yield from self.iter_fixed()
            return
batch_count = 0
while self.relay_buffer.queue_not_empty():
while self.relay_buffer.queue_not_empty() and \
(len(self.relay_buffer) - produce_index) < self.batch_size:
# get from queue
self.relay_buffer.add_raw_batch()
batched_data = self._get_batch(produce_index)
yield batched_data
batch_count += 1
produce_index += self.batch_size
self._read_data_complete = True
assert len(self.relay_buffer) == self._total_samples
self._num_batches = batch_count
def next(self):
"""get next batch"""
try:
data = next(self.iter)
return data
except StopIteration:
self._has_next = False
return None
def has_next(self):
"""
has next batch
"""
return self._has_next
def set_dataset(self, queue, episode_id, relay_sample_manager=None, sample_per_episode=-1):
relay_buffer = EpisodeRelayBuffer(episode_id, queue=queue)
if self._max_relay_episode > 0 and episode_id >= self._relay_episode_offset:
self._episode_relay_buffers.append(relay_buffer)
if len(self._episode_relay_buffers) > self._max_relay_episode:
old_buffer = self._episode_relay_buffers.pop(0)
del old_buffer
            # sync() drains the queue and blocks until all rollout data for this episode has
            # been produced, so training waits here for the environment rollout to finish.
            if os.getenv("SKIP_GENERATION", None) is None:
                relay_buffer.sync()
            if relay_sample_manager is None:
                raise ValueError("relay is enabled but no relay_sample_manager was provided; "
                                 "a default relay sample function is not currently supported")
self.relay_sample_manager = relay_sample_manager
buffer = self.relay_sample_manager(self._episode_relay_buffers)
self.relay_buffer = EpisodeRelayBuffer(episode_id, buffer=buffer)
self._total_samples = len(self.relay_buffer)
self._read_data_complete = True
else:
num_rollout_batches = queue.qsize()
self.relay_buffer = relay_buffer
self.relay_buffer.add_raw_batch()
            assert sample_per_episode != -1, \
                "sample_per_episode must be set for StreamDataset when using a fixed batch size"
self._total_samples = sample_per_episode
self._read_data_complete = num_rollout_batches <= 1
self.iter = iter(self)
self._has_next = True
def episode_relay_buffers(self):
return self._episode_relay_buffers
def total_samples(self):
return self._total_samples
def batch_per_episode(self):
return math.ceil(self._total_samples / self.batch_size)
    def get_and_clear_metrics(self):
        # relay_sample_manager is only set when relay is enabled
        if self.relay_sample_manager is None:
            return "no relay", {}
        return self.relay_sample_manager.get_and_clear_metrics()
class EpisodeRelayBuffer:
"""EpisodeRelayBuffer"""
def __init__(self, episode_id, queue=None, buffer=None):
self._episode_id = episode_id
        assert (queue is None) != (buffer is None), \
            "exactly one of `queue` or `buffer` must be provided"
if buffer is not None:
assert queue is None
self._buffer = buffer
else:
assert queue is not None
self._buffer = []
self.queue = queue
self._rollout_batch_size = -1
def add_raw_batch(self):
if self.queue.qsize() == 0:
raise ValueError("WARN: data queue is empty")
# get from queue
data = self.queue.get()
merged_data = {}
for item in data:
local_data = future.get(item)
if CHATLEARN_REGROUP_TAG in local_data:
local_data = regroup_by_concat_along_batch(local_data[CHATLEARN_REGROUP_TAG])
merged_data.update(local_data)
samples = split_batch(merged_data)
if self._rollout_batch_size < 0:
self._rollout_batch_size = len(samples)
self._buffer += samples
return samples
def queue_not_empty(self):
return self.queue.qsize() > 0
def shuffle(self, batch_size):
if batch_size is None:
random.shuffle(self._buffer)
return
self._buffer = batch_shuffle(self._buffer, batch_size)
def get_samples(self, start_index, end_index):
return self._buffer[start_index: end_index]
def __len__(self):
return len(self._buffer)
def sync(self):
while self.queue_not_empty():
self.add_raw_batch()
@property
def buffer(self):
return self._buffer
@property
def episode_id(self):
return self._episode_id
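# Minimal sketch of the two EpisodeRelayBuffer construction modes (names are placeholders):
#   live = EpisodeRelayBuffer(episode_id, queue=rollout_queue)    # filled lazily via add_raw_batch()
#   merged = EpisodeRelayBuffer(episode_id, buffer=sample_list)   # wraps an existing sample list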
class RelaySampleManager:
"""
Relay sample Manager, users should inherit it to self-defined relay samples for trainer
"""
def __init__(self, global_args):
self.args = global_args
self._metric_prefix = "relay"
self._metric_list = []
def __call__(self, episode_relay_buffers: List[EpisodeRelayBuffer]) -> List[Dict]:
        raise NotImplementedError("subclasses of RelaySampleManager must implement __call__; "
                                  "a default relay sample function is not currently supported")
def get_and_clear_metrics(self):
if self._metric_list is None or len(self._metric_list) == 0:
return self._metric_prefix, {}
reduced_metrics = map_reduce_metrics(self._metric_list)
self._metric_list = []
return self._metric_prefix, reduced_metrics
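# Hedged sketch of a custom relay sample manager (kept as a comment; the metric format passed to
# map_reduce_metrics is an assumption, and this class is not part of the file's API):
#   class LatestEpisodeRelay(RelaySampleManager):
#       def __call__(self, episode_relay_buffers):
#           # replay only the newest episode's samples and record how many were relayed
#           latest = episode_relay_buffers[-1]
#           self._metric_list.append({"relay_samples": float(len(latest))})
#           return latest.buffer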
class RLHFDataLoader:
"""
RLHF data loader
"""
def __init__(
self,
datasets,
sampler,
collate_fn=None,
add_uid=False,
data_parallel_rank=0,
data_parallel_size=1,
num_inference_per_prompt=1,
vllm_prompt_key="prompt"):
"""generate prompts data loader"""
self.datasets = datasets
self.dataset_num = len(self.datasets)
self.sampler = sampler
self.collate_fn = collate_fn
self.add_uid = add_uid
self.data_parallel_rank = data_parallel_rank
self.data_parallel_size = data_parallel_size
self.num_inference_per_prompt = num_inference_per_prompt
self.vllm_prompt_key = vllm_prompt_key
def __iter__(self):
self.sampler_iter = iter(self.sampler)
while True:
try:
batch_idxes = next(self.sampler_iter)
batch = [self.datasets[dataset_idx][data_idx] for dataset_idx, data_idx, _ in batch_idxes]
                id_in_episode = [sample_id for _, _, sample_id in batch_idxes]
if self.add_uid:
batch = self.update_data_uid(batch, id_in_episode)
if self.collate_fn is not None:
yield self.collate_fn(batch)
else:
yield default_collate(batch)
            except StopIteration:
                # restart the sampler so the loader can keep serving batches across episodes
                self.sampler_iter = iter(self.sampler)
def update_data_uid(self, batch, id_in_episode):
updated_batch = []
for i, data in enumerate(batch):
if isinstance(data, dict) and self.vllm_prompt_key in data \
and isinstance(data[self.vllm_prompt_key], dict):
copy_data = copy.deepcopy(data)
copy_data[self.vllm_prompt_key]['uid'] = str(id_in_episode[i])
updated_batch.append(copy_data)
else:
updated_batch.append(data)
return updated_batch
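# Rough usage sketch for RLHFDataLoader (hedged: the prompt contents, sampler and collate_fn
# below are placeholders; the real sampler yields batches of
# (dataset_idx, data_idx, id_in_episode) triples):
#   prompts = [{"prompt": {"text": "hi"}}, {"prompt": {"text": "hello"}}]
#   sampler = ...  # any iterable whose items are lists of (dataset_idx, data_idx, id) tuples
#   loader = RLHFDataLoader([prompts], sampler, collate_fn=lambda batch: batch, add_uid=True)
#   first_batch = next(iter(loader))  # each prompt dict gains a "uid" taken from the sampler's id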