chatlearn/runtime/trainer.py (65 lines of code) (raw):

# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Trainer""" import math import ray from chatlearn.utils import future from chatlearn.utils.constant import TrainingShffuleMode from chatlearn.utils.logger import logger from .executor import Executor from .utils import encode_data class Trainer(Executor): """Trainer""" def __init__(self, model_flow): """ Trainer Args ---- models : List[BaseModule] a list of modules """ super().__init__(model_flow) for model, func_names in self.model_to_call_funcs.items(): model.trainable_funcs += func_names self.iteration = 0 self._data_parallel_size = None def setup(self): super().setup() for model_node in self.model_flow.model_nodes: model_node.trainable = True def set_data_loader(self, data_loader): self._data_loader = data_loader def next_batch(self): batches = [] for _ in range(self.num_micro_batch_per_dp): data = self._data_loader.next.remote() if future.get(self._data_loader.has_next.remote()): batches.append(data) if not batches: return else: if len(batches) < self.num_micro_batch_per_dp: batches += batches[:self.num_micro_batch_per_dp - len(batches)] return batches # pylint: disable=unused-argument def num_iteration(self, model=None): # Given that we have incorporated support for relay buffer and dynamic reward outputs, # the number of training data batches per episode may differ, hence we dynamically determine the total number of batches per episode. _sample_per_episode = ray.get(self._data_loader.total_samples.remote()) return math.ceil(_sample_per_episode / self.args.train_global_batch_size) @property def data_parallel_size(self): if self._data_parallel_size is None: self._data_parallel_size = self.first_model.replicas[0].data_parallel_size for model in self.models[1:]: assert model.replicas[0].data_parallel_size == self._data_parallel_size, \ "Currently, all training models are assumed to have the same data_parallel_size" return self._data_parallel_size def train(self, episode): self.num_micro_batch_per_dp = self.args.train_global_batch_size // self.args.train_micro_batch_size // self.data_parallel_size _num_training_iteration = self.num_iteration() self._batch_per_episode = _num_training_iteration assert self.args.training_shuffle_mode in list(TrainingShffuleMode), \ f"Unsupported training shuffle mode {self.args.training_shuffle_mode}, only {list(TrainingShffuleMode)} allowed." logger.info(f"Set training shuffle mode {self.args.training_shuffle_mode}.") for epoch in range(self.args.num_training_epoch): if epoch > 0: if self.args.training_shuffle_mode == TrainingShffuleMode.BATCH: ret = self._data_loader.shuffle.remote(self.args.train_micro_batch_size) elif self.args.training_shuffle_mode == TrainingShffuleMode.SAMPLE: ret = self._data_loader.shuffle.remote() future.wait(ret) data_queues, out_queue = self.setup_queues() for mb in range(_num_training_iteration * self.data_parallel_size): batch = encode_data(mb, self.next_batch()) for data_queue in data_queues: data_queue.put(batch) self.compute_loop(out_queue, _num_training_iteration) self.iteration = self.iteration + _num_training_iteration logger.info(f"train episode: {episode+1}, epoch {epoch} num_step {_num_training_iteration} done")