# chatlearn/runtime/trainer.py
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Trainer"""
import math
import ray
from chatlearn.utils import future
from chatlearn.utils.constant import TrainingShffuleMode
from chatlearn.utils.logger import logger
from .executor import Executor
from .utils import encode_data
class Trainer(Executor):
"""Trainer"""

    def __init__(self, model_flow):
        """
        Trainer

        Args
        ----
        model_flow : callable
            the computation flow that defines how the training models are invoked
        """
        super().__init__(model_flow)
        # Functions routed through the trainer update model parameters,
        # so register them as trainable on their owning models.
        for model, func_names in self.model_to_call_funcs.items():
            model.trainable_funcs += func_names
        self.iteration = 0
        self._data_parallel_size = None

    def setup(self):
        super().setup()
        # Every model node in the training flow holds trainable parameters.
        for model_node in self.model_flow.model_nodes:
            model_node.trainable = True

    def set_data_loader(self, data_loader):
        self._data_loader = data_loader

    def next_batch(self):
        batches = []
        for _ in range(self.num_micro_batch_per_dp):
            data = self._data_loader.next.remote()
            if future.get(self._data_loader.has_next.remote()):
                batches.append(data)
        if not batches:
            return None
        # Pad a short final step by repeating leading batches, so every
        # data-parallel rank still receives num_micro_batch_per_dp micro-batches.
        if len(batches) < self.num_micro_batch_per_dp:
            batches += batches[:self.num_micro_batch_per_dp - len(batches)]
        return batches
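
    # Illustrative padding example (batch handles assumed, not from the source):
    # with num_micro_batch_per_dp = 4 and only 3 batches left in the loader,
    # next_batch() returns [b0, b1, b2, b0], reusing b0 as padding.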

    # pylint: disable=unused-argument
    def num_iteration(self, model=None):
        # With replay-buffer support and dynamic reward outputs, the number of
        # training samples can differ across episodes, so the number of batches
        # per episode is derived from the data loader at runtime.
        _sample_per_episode = ray.get(self._data_loader.total_samples.remote())
        return math.ceil(_sample_per_episode / self.args.train_global_batch_size)
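
    # Illustrative arithmetic (numbers assumed): total_samples = 2048 and
    # train_global_batch_size = 512 yield ceil(2048 / 512) = 4 iterations
    # per episode.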

    @property
    def data_parallel_size(self):
        # Resolve lazily and verify that every training model agrees on the size.
        if self._data_parallel_size is None:
            self._data_parallel_size = self.first_model.replicas[0].data_parallel_size
            for model in self.models[1:]:
                assert model.replicas[0].data_parallel_size == self._data_parallel_size, \
                    "Currently, all training models are assumed to have the same data_parallel_size"
        return self._data_parallel_size
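
    # Illustrative example (sizes assumed): if every training model reports
    # data_parallel_size = 4, this property returns 4; a model reporting a
    # different size would trip the assertion above.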

    def train(self, episode):
        # Micro-batches each data-parallel rank consumes per optimizer step.
        self.num_micro_batch_per_dp = (self.args.train_global_batch_size
                                       // self.args.train_micro_batch_size
                                       // self.data_parallel_size)
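        # Illustrative numbers (assumed, not from any shipped config):
        # train_global_batch_size = 512, train_micro_batch_size = 8, and
        # data_parallel_size = 4 give 512 // 8 // 4 = 16 micro-batches per rank.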
_num_training_iteration = self.num_iteration()
self._batch_per_episode = _num_training_iteration
assert self.args.training_shuffle_mode in list(TrainingShffuleMode), \
f"Unsupported training shuffle mode {self.args.training_shuffle_mode}, only {list(TrainingShffuleMode)} allowed."
logger.info(f"Set training shuffle mode {self.args.training_shuffle_mode}.")
        for epoch in range(self.args.num_training_epoch):
            if epoch > 0:
                # Reshuffle between epochs: BATCH shuffles at micro-batch
                # granularity, SAMPLE at per-sample granularity.
                if self.args.training_shuffle_mode == TrainingShffuleMode.BATCH:
                    ret = self._data_loader.shuffle.remote(self.args.train_micro_batch_size)
                elif self.args.training_shuffle_mode == TrainingShffuleMode.SAMPLE:
                    ret = self._data_loader.shuffle.remote()
                future.wait(ret)
            data_queues, out_queue = self.setup_queues()
            # Enqueue one encoded micro-batch group per data-parallel rank per iteration.
            for mb in range(_num_training_iteration * self.data_parallel_size):
                batch = encode_data(mb, self.next_batch())
                for data_queue in data_queues:
                    data_queue.put(batch)
            self.compute_loop(out_queue, _num_training_iteration)
            self.iteration += _num_training_iteration
            logger.info(f"train episode: {episode + 1}, epoch {epoch} num_step {_num_training_iteration} done")