def __init__()

in chatlearn/models/base_module.py


    def __init__(self, name, args=None, replica_id=0):
        logger.info(f"{LOG_START} basemodule {name} init start")
        self.name = name
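        # use the explicitly provided args (and register them globally) or fall back to the global args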
        if args is None:
            global_args = get_args()
        else:
            global_args = args
            set_global_variables(args)
        self.global_args = global_args
        args = global_args.models[name]
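        # per-model resource settings: GPU/CPU budget, processes per GPU, trainability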
        self.total_gpu = args.num_gpu
        self.total_cpu = args.num_cpu
        self.gpu_per_process = args.gpu_per_process
        self.trainable = args.trainable
        self._runtime_args = self.global_args.runtime_args
        self._module_args = args
        self.replica_id = replica_id
        self.config_dir = args.config_dir
        self._is_colocate = False

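        # derive the GPU layout: how many GPUs one replica needs and how many replicas fit in the GPU budget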
        if self.total_gpu > 0:
            self._num_gpu_per_replica = (
                args.tensor_model_parallel_size
                * args.pipeline_model_parallel_size
                * args.expert_model_parallel_size
                * args.zero_size
                * args.fsdp_size
            )
            assert self._num_gpu_per_replica <= self.total_gpu, \
                f"_num_gpu_per_replica {self._num_gpu_per_replica} is larger than total_gpu {self.total_gpu}, " + \
                f"tp_size: {args.tensor_model_parallel_size} pp_size: {args.pipeline_model_parallel_size} " + \
                f"ep_size: {args.expert_model_parallel_size} zero_size: {args.zero_size} fsdp_size: {args.fsdp_size}"
            assert self.total_gpu % self._num_gpu_per_replica == 0, \
                f"total_gpu {self.total_gpu} must be divisible by _num_gpu_per_replica {self._num_gpu_per_replica}"
            if not self.trainable:
                self._num_replica = args.num_gpu // self._num_gpu_per_replica
            else:
                # For trainable models, data parallelism is handled inside the DistActor, so a single replica spans all GPUs
                self._num_replica = 1
                self._num_gpu_per_replica = self.total_gpu
        else:
            self._num_gpu_per_replica = 0
            self._num_replica = args.num_replica

        assert self._num_replica >= 1
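        # placeholders for runtime state that is populated later (parameters, dataloaders, timers, ...)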
        self._param_ranks = None
        self._named_parameters = None
        self._param_to_name = None
        self._parameters = None
        self._coalesced_parameters = None
        self.error_signal = None
        self._rank = None
        self._world_size = None
        self._group_names = []
        self._dataloader = None
        self._eval_dataloader = None
        self._kl_coef = None
        self._padding_config = {}
        self._storage = None
        self._timers = None
        self._data_iter = None
        self._eval_data_iter = None
        self.call_funcs = []
        self.trainable_funcs = []
        self._data_ckpt_manager = None
        self._peak_memory = 0
        self._parameters_to_sync = defaultdict(list)
        self._parameters_to_send = defaultdict(list)
        self._parameters_to_recv = defaultdict(list)
        self._parameters_shape = []
        # current compute iteration
        self._iteration = 0
        self._train_iteration = 0
        self._episode_id = 0
        self.enable_lora = self._module_args.lora.enable_lora
        self._finalized = False
        self._resume_training = False
        self._address = dlc_utils.get_addr() if dlc_utils.in_dlc_env() else get_host_addr()
        self._is_master_node = os.environ.get("RANK", '0') == '0'
        self._logger = setup_logger(model_name=self.name, ip_addr=self._address)
        # parameter sync from src_model
        self._src_parameter_model = None
        self.profiler = None
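        # buffers and tensor-parallel mappings used during parameter synchronization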
        self._buffer_num = {}
        self._tp_division = {}
        self._tp_num_mapping = 1
        self._sync_buffer = defaultdict(list)
        self._sync_dst_rank_to_src_ranks = {}
        self._expert_sync_buffer = {}
        self._synchronizer = None
        self._metric_prefix = ""
        self._metric_list = []
        self._stage_resume_done = False
        logger.info(f"{LOG_START} basemodule {name} init done")
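
For reference, the replica arithmetic in the GPU branch above can be summarized with a small stand-alone sketch. The helper name replica_layout and the numbers are hypothetical and only mirror the logic of this constructor; they are not part of the ChatLearn API.

    # Hypothetical sketch of the replica arithmetic above; not part of ChatLearn.
    def replica_layout(num_gpu, tp, pp, ep, zero, fsdp, trainable):
        gpu_per_replica = tp * pp * ep * zero * fsdp
        assert gpu_per_replica <= num_gpu and num_gpu % gpu_per_replica == 0
        if trainable:
            # trainable models keep a single replica and run data parallelism inside it
            return 1, num_gpu
        return num_gpu // gpu_per_replica, gpu_per_replica

    # e.g. 16 GPUs with tp=4, pp=2 and all other sizes 1 -> 2 replicas of 8 GPUs each
    print(replica_layout(16, tp=4, pp=2, ep=1, zero=1, fsdp=1, trainable=False))  # (2, 8)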