# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Environment"""
import math
from itertools import cycle
from chatlearn.data.ranking import batch_generation_ranking
from chatlearn.models.vllm_module_v2 import VLLMModuleV2
from chatlearn.utils import future
from chatlearn.utils.logger import logger
from chatlearn.utils.utils import execute_in_parallel
from .executor import Executor
from .utils import encode_data
# pylint: disable=not-callable
class Environment(Executor):
"""BaseEnv"""
def __init__(self, model_flow):
"""
Environment
Args
----
        model_flow
            the model computation flow executed by this environment
"""
super().__init__(model_flow)
self._batch_size = None
self._batch_per_episode = None
self._all_datasets = None
self.data_iter = None
self._padding_config = {}
def set_dataset(self, dataset):
"""Set dataset for the environment.
Args:
            dataset (list): a list of prompt strings
        Returns:
            Environment: the environment instance, to allow chaining
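        Example (illustrative sketch; `env` below is assumed to be an Environment instance):
            env.set_dataset(["prompt one", "prompt two"])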
"""
assert isinstance(dataset, list), (
f"expect the dataset to be a list of prompts, got {type(dataset)}"
)
assert not isinstance(dataset[0], list), (
"expect only one dataset to be set, if you want to use more "
"than one dataset, please try `set_multiple_datasets`"
)
self._all_datasets = [dataset]
return self
def set_multiple_datasets(self, all_datasets):
"""Set multiple datasets for the environment.
Args:
            all_datasets (list): a list of datasets, each being a list of prompt strings
        Returns:
            Environment: the environment instance, to allow chaining
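        Example (illustrative sketch; `env` below is assumed to be an Environment instance):
            env.set_multiple_datasets([
                ["math prompt 1", "math prompt 2"],
                ["coding prompt 1"],
            ])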
"""
# sanity check
assert len(all_datasets) >= 1, (
f"expect at least one dataset, got {len(all_datasets)} datasets."
)
assert isinstance(all_datasets, list), (
f"expect datasets to be a list, got {type(all_datasets)}"
)
for dataset in all_datasets:
assert isinstance(dataset, list), (
f"expect each dataset to be a list of prompts, got {type(dataset)}"
)
self._all_datasets = all_datasets
return self
def setup_dataset(self):
self.data_producer = self.models[0]
assert self.sample_per_episode % len(self.data_producer.replicas) == 0, \
"replica number of data producer model must be divisible by sample_per_episode"
logger.info("start set dataset for data_producer")
refs = []
if self.models[0].module_args.batch_generation.ranking:
for i, dataset in enumerate(self._all_datasets):
episode_per_epoch = math.ceil(len(dataset) / self.sample_per_episode)
self._all_datasets[i] = batch_generation_ranking(
dataset, episode_per_epoch, self.sample_per_episode
)
for policy_replica in self.data_producer.replicas:
ref = policy_replica.master._build_dataloader.remote(self._all_datasets,
self.sample_per_episode)
refs.append(ref)
future.get(refs)
logger.info("set dataset for data_producer done")
def setup(self):
super().setup()
self.setup_dataset()
for model_node in self.model_flow.model_nodes:
model = model_node.model.replicas[0]
config = future.get(model.master.padding_config.remote())
self._padding_config.update(config)
if isinstance(model.model, VLLMModuleV2):
logger.info(
f"setup vllm engine for model {model.model}")
refs = []
for replica in model_node.model.replicas:
refs.append(replica.vllm_engine.setup_vllm.remote(
replica.all_actors))
future.wait(refs, return_output=True)
@property
def sample_per_episode(self):
return self.args.sample_per_episode
    def batch_size(self, model=None):
        """Return the generation batch size for a model (default: the first model).
        For a vLLM-backed model this is the per-replica share of sample_per_episode;
        otherwise it is the model's generation_batch_size.
        """
if model is None:
model = self.models[0]
if model.use_vllm_backend:
num_replica = len(model.replicas)
batch_size = self.sample_per_episode // num_replica
else:
batch_size = model.module_args.generation_batch_size
return batch_size
    def batch_per_episode(self, model=None):
        """Return the number of batches dispatched per episode for a model.
        If model is None, the first model is used.
        """
if model is None:
model = self.models[0]
num_replica = len(model.replicas)
if self.sample_per_episode >= num_replica:
_batch_per_episode = num_replica
else:
_batch_per_episode = self.sample_per_episode
dp_size = len(model.replicas[0].dp_rank_to_actors)
if dp_size > 1:
_batch_per_episode *= dp_size
return _batch_per_episode
def num_iteration(self, model=None):
"""Calculate the number of iterations for a model in the environment.
Args:
            model: a model in the environment. If None, the first model is used. Default: None.
        Returns:
            The number of iterations for the model in the environment.
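        Example (illustrative, hypothetical numbers):
            with sample_per_episode=32, 2 replicas, dp_size=4 and zero_size=1,
            batch_per_episode is 2 * 4 = 8 and num_iteration is 8 // 4 = 2.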
"""
if model is None:
model = self.models[0]
_batch_per_episode = self.batch_per_episode(model)
dp_size = len(model.replicas[0].dp_rank_to_actors)
if model.module_args.zero_size > 1:
assert _batch_per_episode % model.module_args.zero_size == 0
return _batch_per_episode // model.module_args.zero_size
elif dp_size > 1: # for trainable model or ep model
if _batch_per_episode < dp_size:
raise NotImplementedError(
"Currently ChaLearn requires batch_per_episode >= len(dp_rank_to_actors), "
f"got {_batch_per_episode} and {dp_size}. "
f"Please allocate more replicas to inference model {model.name} to walk-around the issue."
)
assert _batch_per_episode % dp_size == 0, (
"Inner loop in Executor.generate_step_one_model_internal() depends on dp_size of each replica."
)
return _batch_per_episode // dp_size
else:
return _batch_per_episode
    def execute(self, is_eval):
        """Run one episode of the model flow.
        Batches are fetched from the data producer replicas in a round-robin manner,
        pushed to every data queue, and consumed by the compute loop.
        Returns the output queue of the episode.
        """
data_queues, out_queue = self.setup_queues()
data_producer_iter = cycle(iter(self.models[0].replicas))
# prepare batches for all model replicas
for mb in range(self.batch_per_episode(self.models[0])):
current_data_producer = next(data_producer_iter)
query = current_data_producer.master.next_batch.remote(is_eval=is_eval)
encoded_data = encode_data(mb, query)
for data_queue in data_queues:
data_queue.put(encoded_data)
self.compute_loop(out_queue)
return out_queue
def make_experiences(self):
"""
Generate a collection of experiences for one episode
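        Example (illustrative sketch; `env` is assumed to be a fully set-up Environment instance):
            out_queue = env.make_experiences()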
"""
return self.execute(is_eval=False)
class MCTSEnv(Environment):
"""MCTS Env"""
def __init__(self, model_flow, mcts):
super().__init__(model_flow)
self.max_iteration_per_sample = self.args.max_iteration_per_sample
self.mcts = mcts
        assert self.args.sample_per_episode == mcts.module_args.num_cpu, \
            "sample_per_episode must equal num_cpu of the MCTS module"
    def mcts_loop(self, max_iteration, encoded_data, data_queues, mb, replica_data_list, mcts):
        """Run the MCTS loop for one sample: initialize the tree, then repeatedly feed
        the encoded batch through the flow until the MCTS replica signals a stop or
        max_iteration is reached.
        """
future.wait(mcts.init_tree())
for i in range(max_iteration):
for data_queue in data_queues:
data_queue.put(encoded_data)
for replica, model_node in replica_data_list:
in_queue = model_node.get_input_queues()
func_name = model_node.func_name
# TODO: we will consider colocation/offload later
to_empty_cache = False
to_onload = False
to_offload = False
self.generate_step_one_model(model_node, replica, in_queue, model_node.out_queues, i, func_name, to_empty_cache,
is_eval=self.is_eval, to_onload=to_onload, to_offload=to_offload, micro_batch_index=mb)
should_stop = future.get(mcts.should_stop())
assert len(should_stop) == 1
if should_stop[0]:
break
def execute(self, is_eval):
data_queues, out_queue = self.setup_queues()
data_producer_iter = cycle(iter(self.models[0].replicas))
args = []
for mb in range(self.batch_per_episode()):
current_data_producer = next(data_producer_iter)
query = current_data_producer.master.next_batch.remote(is_eval=is_eval)
encoded_data = encode_data(mb, query)
replica_data_list = []
model_to_replica = {}
for model_group in self.model_flow.flow_topology:
for model_node in model_group:
model = model_node.model
assert not model.is_colocate, "colocation is currently not supported in MCTSEnv"
assert not model.enable_offload, "offload is currently not supported in MCTSEnv"
if model in model_to_replica:
replica = model_to_replica[model]
else:
replica = self._next_model(model)
model_to_replica[model] = replica
replica_data_list.append((replica, model_node))
mcts = [replica_data[0] for replica_data in replica_data_list if replica_data[0].model is self.mcts]
assert len(mcts) > 0
mcts = mcts[0]
args.append((self.max_iteration_per_sample, encoded_data, data_queues, mb, replica_data_list, mcts))
if self.args.debug:
for arg in args:
self.mcts_loop(*arg)
else:
execute_in_parallel(self.mcts_loop, args)
data = [None] * len(self.model_flow.return_model_nodes)
for model_node in self.model_flow.model_nodes:
if model_node in self.model_flow.return_model_nodes:
                # keep the order of results consistent with the model_node order
data[self.model_flow.return_model_nodes.index(model_node)] = model_node.out_queues[-1]
if data:
self.get_all_merged_data(data, out_queue, encode=False)
return out_queue
class SPRLEnv(Environment):
"""SPRL(Self-Play Reinforcement Learning) Env"""
def __init__(self, model_flow, sprl):
super().__init__(model_flow)
self.max_iteration_per_sample = self.args.max_iteration_per_sample
self.sprl = sprl
    def sprl_loop(self, max_iteration, encoded_data, data_queues, mb, replica_data_list, sprl):
        """Run the self-play loop for one sample: reset the SPRL replica, then repeatedly
        feed the encoded batch through the flow until the SPRL replica signals a stop or
        max_iteration is reached.
        """
future.wait(sprl.reset())
for i in range(max_iteration):
for data_queue in data_queues:
data_queue.put(encoded_data)
for replica, model_node in replica_data_list:
in_queue = model_node.get_input_queues()
func_name = model_node.func_name
# TODO: we will consider colocation/offload later
to_empty_cache = False
to_onload = False
to_offload = False
self.generate_step_one_model(model_node, replica, in_queue, model_node.out_queues, i, func_name, to_empty_cache,
is_eval=self.is_eval, to_onload=to_onload, to_offload=to_offload, micro_batch_index=mb)
should_stop = future.get(sprl.should_stop())
assert len(should_stop) == 1
if should_stop[0]:
break
def execute(self, is_eval):
data_queues, out_queue = self.setup_queues()
data_producer_iter = cycle(iter(self.models[0].replicas))
args = []
for mb in range(self.batch_per_episode()):
current_data_producer = next(data_producer_iter)
query = current_data_producer.master.next_batch.remote(is_eval=is_eval)
encoded_data = encode_data(mb, query)
replica_data_list = []
model_to_replica = {}
for model_group in self.model_flow.flow_topology:
for model_node in model_group:
model = model_node.model
assert not model.is_colocate, "colocation is currently not supported in SPRLEnv"
assert not model.enable_offload, "offload is currently not supported in SPRLEnv"
if model in model_to_replica:
replica = model_to_replica[model]
else:
replica = self._next_model(model)
model_to_replica[model] = replica
replica_data_list.append((replica, model_node))
sprl = [replica_data[0] for replica_data in replica_data_list if replica_data[0].model is self.sprl]
assert len(sprl) > 0
sprl = sprl[0]
args.append((self.max_iteration_per_sample, encoded_data, data_queues, mb, replica_data_list, sprl))
        # execute_in_parallel is not supported yet.
for arg in args:
self.sprl_loop(*arg)
data = [None] * len(self.model_flow.return_model_nodes)
for model_node in self.model_flow.model_nodes:
if model_node in self.model_flow.return_model_nodes:
                # keep the order of results consistent with the model_node order
data[self.model_flow.return_model_nodes.index(model_node)] = model_node.out_queues[-1]
if data:
self.get_all_merged_data(data, out_queue, encode=False)
return out_queue
# pylint: disable=not-callable