def _validate_params()

in chatlearn/utils/arguments.py


    def _validate_params(self):
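        # Default train_global_batch_size to train_micro_batch_size, then check
        # divisibility and that one episode yields enough samples for a step.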
        if self.runtime_args.train_global_batch_size is None:
            self.runtime_args.train_global_batch_size = self.runtime_args.train_micro_batch_size
        assert self.runtime_args.train_global_batch_size % self.runtime_args.train_micro_batch_size == 0, \
            f"train_global_batch_size should be times of train_micro_batch_size," \
            f"but got {self.runtime_args.train_global_batch_size}/{self.runtime_args.train_micro_batch_size}"
        assert self.runtime_args.train_global_batch_size <= self.runtime_args.sample_per_episode, \
            "train_global_batch_size should be less than or equal to sample_per_episode, " \
            f"got {self.runtime_args.train_global_batch_size} and {self.runtime_args.sample_per_episode}"
        assert self.runtime_args.stream_data_loader_type.lower() in ["fixed", "dynamic"]
        assert self.runtime_args.cpu_schedule_strategy in [strategy.value for strategy in RAY_PG_STRATEGY]
        assert self.runtime_args.param_sync_comm_type in list(PARAM_SYNC_COMM_TYPE)
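        # With multiple datasets, data_ratio must be a list matching data_path one-to-one.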
        if isinstance(self.runtime_args.data_path, list):
            assert self.runtime_args.data_ratio is not None and isinstance(self.runtime_args.data_ratio, list), (
                f"expect data_ratio to be list when data_path is list, got {self.runtime_args.data_ratio}"
            )
            assert len(self.runtime_args.data_path) == len(self.runtime_args.data_ratio), (
                "expect data_path and data_ratio to have same length, "
                f"got {len(self.runtime_args.data_path)} and {len(self.runtime_args.data_ratio)}"
            )
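        # Per-model validation: resources, parallelism sizes, and batch sizes.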
        for model_name, model_args in self.models.items():
            if model_args.num_gpu >= 1:
                if model_args.gpu_per_process is None:
                    model_args.gpu_per_process = 1
                else:
                    assert model_args.gpu_per_process <= model_args.num_gpu, \
                        f"{model_name}: gpu_per_process: {model_args.gpu_per_process}, num_cpu: {model_args.num_gpu}"
            elif model_args.num_cpu >= 1:
                if model_args.cpu_per_process is None:
                    model_args.cpu_per_process = 1
                else:
                    assert model_args.cpu_per_process <= model_args.num_cpu, \
                        f"{model_name}: cpu_per_process: {model_args.cpu_per_process}, num_cpu: {model_args.num_cpu}"
            if model_args.generation_batch_size is not None and model_args.generation_batch_size <= 0:
                model_args.generation_batch_size = DYNAMIC_BATCH_SIZE
            if model_args.generation_batch_size is None:
                if self.runtime_args.generation_batch_size:
                    model_args.generation_batch_size = self.runtime_args.generation_batch_size
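            # Parallelism sizes: prefer values from args_dict, default to 1.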
            for key in ["pipeline_model_parallel_size", "tensor_model_parallel_size", "zero_size"]:
                if model_args.args_dict.get(key) is not None:
                    setattr(model_args, key, model_args.args_dict.get(key))
                    assert getattr(model_args, key) >= 1
                elif getattr(model_args, key) is None:
                    setattr(model_args, key, 1)

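            # fsdp_size of -1 means "use all GPUs of this model"; default to 1.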
            for key in ["fsdp_size"]:
                if getattr(model_args, key) is not None:
                    setattr(model_args, key, getattr(model_args, key))
                    if getattr(model_args, key) == -1:
                        print(f"set_fsdp_size {getattr(model_args, key)} to num_gpu: {model_args.num_gpu}")
                        setattr(model_args, key, model_args.num_gpu)
                    assert getattr(model_args, key) >= 1
                elif getattr(model_args, key) is None:
                    setattr(model_args, key, 1)

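            # Reconcile the two spellings of expert parallelism: if both are set
            # they must agree; otherwise take whichever is present, defaulting to 1.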
            ep_size = model_args.args_dict.get("expert_model_parallel_size")
            moe_ep_size = model_args.args_dict.get("moe_expert_model_parallel_size")
            if ep_size is not None and moe_ep_size is not None:
                assert ep_size == moe_ep_size, (
                    f"{model_name}: if you set moe_expert_model_parallel_size ({moe_ep_size}), "
                    f"it must be equal to expert_model_parallel_size ({ep_size})"
                )
                finalized_ep_size = ep_size
            elif ep_size is not None:
                finalized_ep_size = ep_size
            elif moe_ep_size is not None:
                finalized_ep_size = moe_ep_size
            else:
                finalized_ep_size = 1
            assert finalized_ep_size >= 1
            setattr(model_args, "expert_model_parallel_size", finalized_ep_size)

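            # TP/PP/EP is mutually exclusive with ZeRO and FSDP, and the GPU count
            # must cover a whole number of model replicas.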
            if model_args.tensor_model_parallel_size > 1 or model_args.pipeline_model_parallel_size > 1 or model_args.expert_model_parallel_size > 1:
                assert model_args.zero_size == 1 or model_args.zero_size is None
                assert model_args.fsdp_size == 1 or model_args.fsdp_size is None
                assert model_args.num_gpu % (
                    model_args.tensor_model_parallel_size * model_args.pipeline_model_parallel_size * model_args.expert_model_parallel_size) == 0, \
                    f"{model_name}: num_gpu must be divisible by tensor_model_parallel_size * pipeline_model_parallel_size * " \
                    f"expert_model_parallel_size, but got num_gpu = {model_args.num_gpu}, " \
                    f"tensor_model_parallel_size = {model_args.tensor_model_parallel_size}, " \
                    f"pipeline_model_parallel_size = {model_args.pipeline_model_parallel_size}, and "\
                    f"expert_model_parallel_size = {model_args.expert_model_parallel_size}."
            assert model_args.num_gpu > 0 or model_args.num_cpu > 0, \
                f"{model_name} num_gpu: {model_args.num_gpu}, num_cpu: {model_args.num_cpu}, at least one of them should be set"

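            # Derive num_replica from the resources left after sharding.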
            if model_args.num_gpu >= 1:
                if model_args.zero_size > 1:
                    assert model_args.num_gpu % model_args.zero_size == 0
                    model_args.num_replica = model_args.num_gpu // model_args.zero_size
                elif model_args.fsdp_size > 1:
                    # For FSDP, num_gpu must be divisible by fsdp_size
                    assert model_args.num_gpu % model_args.fsdp_size == 0
                    model_args.num_replica = model_args.num_gpu // (
                        model_args.tensor_model_parallel_size * model_args.pipeline_model_parallel_size \
                            * model_args.expert_model_parallel_size * model_args.fsdp_size)
                else:
                    model_args.num_replica = model_args.num_gpu // (
                        model_args.tensor_model_parallel_size * model_args.pipeline_model_parallel_size * model_args.expert_model_parallel_size)

            elif model_args.num_cpu >= 1:
                model_args.num_replica = model_args.num_cpu // model_args.cpu_per_process
            assert model_args.num_replica * model_args.generation_batch_size <= self.runtime_args.sample_per_episode, \
                f"{model_name}: num_replica * batch_size {model_args.num_replica}*{model_args.generation_batch_size} " + \
                f"should be less than or equal to sample_per_episode {self.runtime_args.sample_per_episode}"
            if model_args.batch_generation.min_prompt_length:
                logger.info(f"Enable batch generation: \
                    min_prompt_length = {model_args.batch_generation.min_prompt_length}")
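            # free_memory enables weight offloading; trainable models additionally
            # free gradient buffers and offload optimizer states.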
            if model_args.free_memory:
                model_args.offload_weights = True
                if model_args.trainable:
                    model_args.free_grad_buffers = True
                    model_args.offload_optimizer_states = True
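        # Each model may appear in at most one colocation group.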
        if self.runtime_args.colocation and len(self.runtime_args.colocation) > 0:
            model_set = set()
            for colocate_models in self.runtime_args.colocation:
                for model_name in colocate_models:
                    assert model_name not in model_set, f"Model {model_name} should only appear once in colocation group"
                    model_set.add(model_name)
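        # Namespace outputs by experiment name unless it is already in the path.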
        if self.runtime_args.exp_name not in self.runtime_args.output_dir:
            self.runtime_args.output_dir = f"{self.runtime_args.output_dir}/{self.runtime_args.exp_name}"
        logger.info(f"Env Config: \n{self.env_args}")
        logger.info(f"Runtime Config: \n{self.runtime_args}")
        for name, model_args in self.models.items():
            logger.info(f"Model({name}) Config: \n{model_args}")