def learn()

in chatlearn/runtime/engine.py [0:0]


    def learn(self):
        """Run the end-to-end ChatLearn training flow.

        Steps, in order:

        1. Call ``self.setup()`` and ``setup()`` on every truthy executor in
           ``self._executors``, logging the setup timings and memory usage.
        2. Call ``self._resume_from_data_checkpoint()``.
        3. Create the remote ``StreamDataset`` that feeds the trainer.
        4. Perform the first parameter sync (optionally preceded by a dry-run
           warmup), then evaluate once if
           ``runtime_args.enable_eval_before_training`` is set.
        5. For every episode in ``range(self._start_episode,
           runtime_args.num_episode)``: make experiences, load them into the
           data loader, train and sync parameters (only when a trainer
           exists), then save a checkpoint, evaluate and log a summary.

        Environment variables read here:

        * ``ENABLE_CHUNKFLOW_OPTIMIZATION`` -- "True"/"true"/"1" passes
          ``train_global_batch_size`` instead of ``train_micro_batch_size``
          as the final ``StreamDataset`` argument.
        * ``DEBUG_SYNC_PARAMETERS_PATH`` -- when non-empty, the directory is
          removed, parameters are dumped before and after the first sync, and
          the method returns without entering the episode loop.
        * ``ENABLE_PARAM_SYNC_WARMUP`` -- exactly "true" runs a dry-run sync
          and ``warmup_collective_topology()`` before the first real sync.
        * ``SKIP_GENERATION`` -- when set (any value), experience generation
          is skipped and an empty queue is handed to the data loader.
        """
        self.timers("chatlearn").start()

        # ---- setup -------------------------------------------------------
        self.timers("setup").start()
        self.setup()
        self.timers("executor_setup").start()
        for executor in self._executors:
            if executor:
                executor.setup()
        self.timers("executor_setup").stop()
        logger.info(
            f"{LOG_START} {self._name} setup executors: {self.timers.log(names=['executor_setup'])}")
        self.timers("setup").stop()
        logger.info(
            f"{LOG_START} {self._name} setup summary {self.timers.log(names=['setup'])}")
        self.logging_memory()
        self._resume_from_data_checkpoint()

        # ---- data loader -------------------------------------------------
        # Enable chunkflow optimization. Environment values are always str,
        # so only string spellings are listed.
        enable_chunkflow_optimization = os.environ.get(
            "ENABLE_CHUNKFLOW_OPTIMIZATION", "False") in ("True", "true", "1")
        logger.info(f"{LOG_START} Check ENABLE_CHUNKFLOW_OPTIMIZATION={enable_chunkflow_optimization} for chunkflow optimization")
        data_loader = StreamDataset.remote(
            self.runtime_args.stream_data_loader_type,
            self.runtime_args.train_micro_batch_size,
            self.env._padding_config,
            self.runtime_args.max_relay_episode,
            self.runtime_args.relay_episode_offset,
            (self.runtime_args.train_global_batch_size
             if enable_chunkflow_optimization
             else self.runtime_args.train_micro_batch_size),
        )

        # ---- first parameter sync ----------------------------------------
        logger.info(f"{LOG_START} " + get_full_proc_memory_info('Before first param sync'))
        dump_root_path = os.getenv("DEBUG_SYNC_PARAMETERS_PATH", "")
        if dump_root_path:
            # Start from an empty dump directory.
            if os.path.exists(dump_root_path):
                shutil.rmtree(dump_root_path)
            logger.info(f"{LOG_START} dump parameters before synchronizing...")
            self.dump_parameters(os.path.join(dump_root_path, "before_sync_parameter"))
        self.timers("sync_parameters").start()
        if os.getenv("ENABLE_PARAM_SYNC_WARMUP", "false") == "true":
            self.timers("warmup_sync_parameters").start()
            self.model_manager.sync_parameters(requires_grad=False, validate=False, dryrun=True)
            self.model_manager.warmup_collective_topology()
            self.timers("warmup_sync_parameters").stop()
            logger.info(f"{LOG_START} finish warmup_sync_parameters {self.timers.log(names=['warmup_sync_parameters'])} ")
        self.model_manager.sync_parameters(requires_grad=False, validate=self.runtime_args.validate_param_sync)
        self.timers("sync_parameters").stop()
        if self.runtime_args.enable_eval_before_training:
            self.evaluate(-1)
        if dump_root_path:
            # Debug mode: only dump parameters around the first sync, then exit.
            logger.info(f"{LOG_START} dump parameters after synchronizing...")
            self.dump_parameters(os.path.join(dump_root_path, "after_sync_parameter"))
            logger.info(f"{LOG_START} finish dump parameters, ChatLearn will exit")
            return
        logger.info(get_full_proc_memory_info('After first param sync'))
        self.logging_summary(-1)

        # ---- episode loop ------------------------------------------------
        self._data_loader = data_loader
        num_episode = self.runtime_args.num_episode
        for episode_id in range(self._start_episode, num_episode):
            progress = f"{episode_id + 1}/{num_episode}"
            if self.runtime_args.nsys:
                # Profiler capture window: starts at episode index 4 and
                # stops when episode index 5 begins.
                if episode_id == 4:
                    torch.cuda.cudart().cudaProfilerStart()
                if episode_id == 5:
                    torch.cuda.cudart().cudaProfilerStop()
            self.timers("episode").start()
            self.before_episode()
            logger.info(f"{LOG_START} start train episode_id: {progress}")
            if self.env.timers is None:
                self.env.set_timers(self.timers)

            queue = []
            if os.getenv("SKIP_GENERATION", None) is None:
                logger.info(f"{LOG_START} start to make experience: {progress}")
                queue = self.env.make_experiences()
                logger.info(f"{LOG_START} complete to make experience: {progress}")
            else:
                logger.info(f"{LOG_START} Skip generation phase for episode_id: {progress}")

            # Start and stop on the same unconditional path so the calls stay
            # paired whether or not generation is skipped and whether or not
            # a trainer exists (previously start depended on SKIP_GENERATION
            # and stop depended on self.trainer).
            self.timers("set_train_dataset").start()
            refs = data_loader.set_dataset.remote(queue, episode_id, self._relay_sample_manager,
                                                  self.runtime_args.sample_per_episode)
            future.wait(refs, return_output=True)
            self.timers("set_train_dataset").stop()

            if self.trainer is not None:
                # validate parameter sync in the first two episodes
                validate = self.runtime_args.validate_param_sync and episode_id < 2
                self.trainer.set_data_loader(data_loader)
                logger.info(f"{LOG_START} set dataloader for trainer done")
                logger.info(get_full_proc_memory_info(f"{LOG_START} Before train {episode_id}"))
                if self.trainer.timers is None:
                    self.trainer.set_timers(self.timers)
                self.trainer.train(episode_id)
                logger.info(get_full_proc_memory_info(f"{LOG_START} After train {episode_id}"))
                self.timers("sync_parameters").start()
                self.model_manager.sync_parameters(episode_id + 1, validate=validate)
                self.timers("sync_parameters").stop()
                logger.info(f"{LOG_START} train episode_id: {progress} parameter sync done")
            logger.info(f"{LOG_START} train episode_id: {progress} done")
            self.timers("episode").stop()
            self.save_checkpoint(episode_id)
            self.evaluate(episode_id)
            self.after_episode()
            self.logging_summary(episode_id)

        self.timers("chatlearn").stop()
        logger.info(f"{LOG_START} {self._name} overall summary {self.timers.log(names=['chatlearn'])}")
        logger.info(f"{LOG_START} train {self._name} done")