def compute_loop_one_model()

in chatlearn/runtime/executor.py [0:0]


    def compute_loop_one_model(self, model_node, num_batch=None):
        logger.info(f"{LOG_START} start compute_loop for {model_node}, is_eval={self.is_eval}")
        model = model_node.model
        is_eval = self.is_eval

        if num_batch is None:
            num_batch = self.num_iteration(model)

        func_name = model_node.func_name
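        # If colocated models still have remote work in flight, block until they
        # finish so their GPU memory is released before this model computes.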
        if model_node.remote_objects_to_wait:
            logger.info(f"{LOG_START} start to wait colocate models to finish for {model_node}")
            model_node.wait_colocate_models_to_finish(self.timers, func_name)
            logger.info(f"{LOG_START} complete to wait colocate models to finish for {model_node}")
        replica_num = len(model.replicas)
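        # Steps in [last_step_start, num_batch) form the tail window in which each
        # replica runs its final batch, after which its memory can be released.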
        last_step_start = max(num_batch - replica_num, 0)
        in_queue = model_node.get_input_queues()

        logger.info(f"{LOG_START} start to regroup in_queue for {model_node}")
        in_queue = self.regroup_inqueue(model_node, in_queue, is_eval=is_eval)
        logger.info(f"{LOG_START} complete to regroup in_queue for {model_node}")

        if isinstance(in_queue, list) and len(in_queue) == 1:
            in_queue = in_queue[0]
        results = []
        logger.info(f"{LOG_START} start to generate_step_one_model for {model_node}")
        self.timers(f"{model.name}").start()
        for step in range(num_batch):
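            # Flag windows: onload during the first `replica_num` steps (each
            # replica's first batch); empty cache / offload during the last
            # `replica_num` steps (each replica's final batch), whenever
            # colocation or force_free_memory requires releasing GPU memory.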
            to_empty_cache = step >= last_step_start and (model.is_colocate or model.module_args.force_free_memory)
            to_onload = step < replica_num and ((model.is_colocate and model.enable_offload) or model.module_args.force_free_memory)
            to_offload = step >= last_step_start and ((model.is_colocate and model.enable_offload) or model.module_args.force_free_memory)
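            # Advance to the next replica, so replicas take batches in turn.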
            replica = self._next_model(model)
            _, data = self.generate_step_one_model(model_node, replica, in_queue, model_node.out_queues, step, func_name, to_empty_cache,
                                                   is_eval=is_eval, to_onload=to_onload, to_offload=to_offload)
            results.append(data)
        self.timers(f"{model.name}").stop()
        if model_node.next_colocate_node:
            # Wait before the next colocate model executes, since we want this
            # model's cache emptied first.
            logger.info(
                f"{LOG_START} Model {model_node.next_colocate_node} will wait for model {model} to finish since they are colocated")
            self._models_and_results_to_wait = model_node.next_colocate_node.add_dependent_colocate_model_results(
                model_node, results, self._models_and_results_to_wait)
        elif model.colocate_models or model.trainable:
            # 1. The model may be colocated with training/inference models, so we should wait until the end of compute_loop.
            # 2. The model is trainable and has no next_colocate_node, so it must finish before parameter_sync.
            # In either case, append the results to a temporary list to be waited on later.
            logger.info(f"{LOG_START} Sync {model} at the end of {self.__class__.__name__}")
            self._models_and_results_to_wait.append((model_node, results))
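
The interplay of `last_step_start`, `to_onload`, `to_empty_cache`, and `to_offload` above is easiest to see in isolation. Below is a minimal, self-contained sketch (not part of ChatLearn; the helper name `step_flags` and the hard-coded arguments are illustrative only) that reproduces the window computation from the loop above:

    def step_flags(num_batch, replica_num, is_colocate, enable_offload, force_free_memory):
        # Same window arithmetic as compute_loop_one_model: the first
        # `replica_num` steps onload, the last `replica_num` steps release memory.
        last_step_start = max(num_batch - replica_num, 0)
        free_memory = (is_colocate and enable_offload) or force_free_memory
        flags = []
        for step in range(num_batch):
            to_empty_cache = step >= last_step_start and (is_colocate or force_free_memory)
            to_onload = step < replica_num and free_memory
            to_offload = step >= last_step_start and free_memory
            flags.append((step, to_onload, to_empty_cache, to_offload))
        return flags

    # Example: 5 batches across 2 colocated replicas with offload enabled.
    # Steps 0-1 onload (each replica's first batch); steps 3-4 empty cache
    # and offload (each replica's last batch).
    for step, onload, empty, offload in step_flags(5, 2, True, True, False):
        print(f"step={step} onload={onload} empty_cache={empty} offload={offload}")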