in chatlearn/models/base_module.py [0:0]
def __init__(self, name, args=None, replica_id=0):
    logger.info(f"{LOG_START} basemodule {name} init start")
    self.name = name
    # Resolve global args: fall back to the globals already registered via
    # get_args(); otherwise register the args that were passed in.
    if args is None:
        global_args = get_args()
    else:
        global_args = args
        set_global_variables(args)
    self.global_args = global_args
    args = global_args.models[name]
    self.total_gpu = args.num_gpu
    self.total_cpu = args.num_cpu
    self.gpu_per_process = args.gpu_per_process
    self.trainable = args.trainable
    self._runtime_args = self.global_args.runtime_args
    self._module_args = args
    self.replica_id = replica_id
    self.config_dir = args.config_dir
    self._is_colocate = False
    if self.total_gpu > 0:
        # One replica spans the model's full parallel topology.
        self._num_gpu_per_replica = (
            args.tensor_model_parallel_size
            * args.pipeline_model_parallel_size
            * args.expert_model_parallel_size
            * args.zero_size
            * args.fsdp_size
        )
        assert self._num_gpu_per_replica <= self.total_gpu, \
            f"_num_gpu_per_replica {self._num_gpu_per_replica} larger than total_gpu {self.total_gpu} " + \
            f"tp_size: {args.tensor_model_parallel_size} pp_size: {args.pipeline_model_parallel_size} " + \
            f"ep_size: {args.expert_model_parallel_size} zero_size: {args.zero_size}"
        assert self.total_gpu % self._num_gpu_per_replica == 0
        if not self.trainable:
            self._num_replica = args.num_gpu // self._num_gpu_per_replica
        else:
            # For trainable models, perform the DP inside DistActor
            self._num_replica = 1
            self._num_gpu_per_replica = self.total_gpu
    else:
        # CPU-only model: the replica count comes directly from the config.
        self._num_gpu_per_replica = 0
        self._num_replica = args.num_replica
    assert self._num_replica >= 1
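    # Example (hypothetical numbers): a non-trainable model with num_gpu=16,
    # tp=4, pp=2 and ep=zero=fsdp=1 occupies 8 GPUs per replica and runs as
    # 2 replicas; the same model marked trainable becomes one 16-GPU replica.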
    self._param_ranks = None
    self._named_parameters = None
    self._param_to_name = None
    self._parameters = None
    self._coalesced_parameters = None
    self.error_signal = None
    self._rank = None
    self._world_size = None
    self._group_names = []
    self._dataloader = None
    self._eval_dataloader = None
    self._kl_coef = None
    self._padding_config = {}
    self._storage = None
    self._timers = None
    self._data_iter = None
    self._eval_data_iter = None
    self.call_funcs = []
    self.trainable_funcs = []
    self._data_ckpt_manager = None
    self._peak_memory = 0
    self._parameters_to_sync = defaultdict(list)
    self._parameters_to_send = defaultdict(list)
    self._parameters_to_recv = defaultdict(list)
    self._parameters_shape = []
    # current compute iteration
    self._iteration = 0
    self._train_iteration = 0
    self._episode_id = 0
    self.enable_lora = self._module_args.lora.enable_lora
    self._finalized = False
    self._resume_training = False
    self._address = dlc_utils.get_addr() if dlc_utils.in_dlc_env() else get_host_addr()
    self._is_master_node = os.environ.get("RANK", '0') == '0'
    self._logger = setup_logger(model_name=self.name, ip_addr=self._address)
    # parameter sync from src_model
    self._src_parameter_model = None
    self.profiler = None
    self._buffer_num = {}
    self._tp_division = {}
    self._tp_num_mapping = 1
    self._sync_buffer = defaultdict(list)
    self._sync_dst_rank_to_src_ranks = {}
    self._expert_sync_buffer = {}
    self._synchronizer = None
    self._metric_prefix = ""
    self._metric_list = []
    self._stage_resume_done = False
    logger.info(f"{LOG_START} basemodule {name} init done")