in chatlearn/runtime/engine.py [0:0]
def learn(self):
    self.timers("chatlearn").start()
    self.timers("setup").start()
    self.setup()
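    # Set up all registered executors and record how long executor setup takes.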
self.timers("executor_setup").start()
for executor in self._executors:
if executor:
executor.setup()
self.timers("executor_setup").stop()
logger.info(
f"{LOG_START} {self._name} setup executors: {self.timers.log(names=['executor_setup'])}")
self.timers("setup").stop()
logger.info(
f"{LOG_START} {self._name} setup summary {self.timers.log(names=['setup'])}")
self.logging_memory()
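    # Resume data-loading progress from a previously saved data checkpoint, if one exists.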
    self._resume_from_data_checkpoint()
    # Enable chunkflow optimization
    enable_chunkflow_optimization = os.environ.get("ENABLE_CHUNKFLOW_OPTIMIZATION", "False") in ["True", "true", "1", 1]
    logger.info(f"{LOG_START} Check ENABLE_CHUNKFLOW_OPTIMIZATION={enable_chunkflow_optimization} for chunkflow optimization")
    data_loader = StreamDataset.remote(
        self.runtime_args.stream_data_loader_type,
        self.runtime_args.train_micro_batch_size,
        self.env._padding_config,
        self.runtime_args.max_relay_episode,
        self.runtime_args.relay_episode_offset,
        self.runtime_args.train_global_batch_size
        if enable_chunkflow_optimization
        else self.runtime_args.train_micro_batch_size
    )
logger.info(f"{LOG_START} " + get_full_proc_memory_info('Before first param sync'))
    dump_root_path = os.getenv("DEBUG_SYNC_PARAMETERS_PATH", "")
    if dump_root_path:
        if os.path.exists(dump_root_path):
            shutil.rmtree(dump_root_path)
        logger.info(f"{LOG_START} dump parameters before synchronizing...")
        self.dump_parameters(os.path.join(dump_root_path, "before_sync_parameter"))
self.timers("sync_parameters").start()
if os.getenv("ENABLE_PARAM_SYNC_WARMUP", "false") == "true":
self.timers("warmup_sync_parameters").start()
self.model_manager.sync_parameters(requires_grad=False, validate=False, dryrun=True)
self.model_manager.warmup_collective_topology()
self.timers("warmup_sync_parameters").stop()
logger.info(f"{LOG_START} finish warmup_sync_parameters {self.timers.log(names=['warmup_sync_parameters'])} ")
self.model_manager.sync_parameters(requires_grad=False, validate=self.runtime_args.validate_param_sync)
self.timers("sync_parameters").stop()
    if self.runtime_args.enable_eval_before_training:
        self.evaluate(-1)
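    # In the parameter-sync debug mode, dump the synced parameters and exit without training.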
    if dump_root_path:
        logger.info(f"{LOG_START} dump parameters after synchronizing...")
        self.dump_parameters(os.path.join(dump_root_path, "after_sync_parameter"))
        logger.info(f"{LOG_START} finish dump parameters, ChatLearn will exit")
        return
    logger.info(get_full_proc_memory_info('After first param sync'))
    self.logging_summary(-1)
    self._data_loader = data_loader
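    # Main loop over training episodes, starting from the resumed episode when a data checkpoint was loaded.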
    for episode_id in range(self._start_episode, self.runtime_args.num_episode):
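        # With nsys profiling enabled, capture a single episode (start profiling at episode 4, stop at 5).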
        if self.runtime_args.nsys:
            if episode_id == 4:
                torch.cuda.cudart().cudaProfilerStart()
            if episode_id == 5:
                torch.cuda.cudart().cudaProfilerStop()
self.timers("episode").start()
self.before_episode()
logger.info(f"{LOG_START} start train episode_id: {episode_id + 1}/{self.runtime_args.num_episode}")
if self.env.timers is None:
self.env.set_timers(self.timers)
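        # Rollout phase: make experiences with the environment unless SKIP_GENERATION is set.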
        queue = []
        if os.getenv("SKIP_GENERATION", None) is None:
            logger.info(f"{LOG_START} start to make experience: {episode_id + 1}/{self.runtime_args.num_episode}")
            queue = self.env.make_experiences()
            logger.info(f"{LOG_START} complete to make experience: {episode_id + 1}/{self.runtime_args.num_episode}")
            self.timers("set_train_dataset").start()
        else:
            logger.info(f"{LOG_START} Skip generation phase for episode_id: {episode_id + 1}/{self.runtime_args.num_episode}")
        refs = data_loader.set_dataset.remote(
            queue, episode_id, self._relay_sample_manager, self.runtime_args.sample_per_episode)
        future.wait(refs, return_output=True)
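        # Training phase: train on the freshly set dataset, then sync the updated parameters.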
        if self.trainer is not None:
            # validate parameter sync in the first two episodes
            validate = self.runtime_args.validate_param_sync and episode_id < 2
            self.timers("set_train_dataset").stop()
            self.trainer.set_data_loader(data_loader)
            logger.info(f"{LOG_START} set dataloader for trainer done")
            logger.info(get_full_proc_memory_info(f"{LOG_START} Before train {episode_id}"))
            if self.trainer.timers is None:
                self.trainer.set_timers(self.timers)
            self.trainer.train(episode_id)
            logger.info(get_full_proc_memory_info(f"{LOG_START} After train {episode_id}"))
            self.timers("sync_parameters").start()
            self.model_manager.sync_parameters(episode_id + 1, validate=validate)
            self.timers("sync_parameters").stop()
            logger.info(f"{LOG_START} train episode_id: {episode_id + 1}/{self.runtime_args.num_episode} parameter sync done")
logger.info(f"{LOG_START} train episode_id: {episode_id + 1}/{self.runtime_args.num_episode} done")
self.timers("episode").stop()
self.save_checkpoint(episode_id)
self.evaluate(episode_id)
self.after_episode()
self.logging_summary(episode_id)
self.timers("chatlearn").stop()
logger.info(f"{LOG_START} {self._name} overall summary {self.timers.log(names=['chatlearn'])}")
logger.info(f"{LOG_START} train {self._name} done")