in workload_generator/generate_deepspeed_stage3_workload.py [0:0]
def _gather_param_prefetch(self, param, step_id):
    """Emit one fused all-gather that fetches *param* plus a speculative
    prefetch bucket of upcoming parameters from the recorded trace queue.

    Mirrors DeepSpeed ZeRO stage-3 prefetching: the parameter needed at
    this step is always fetched first (budgets do not apply to it), then
    future parameters from ``self._param_queue`` are bundled into the same
    all-gather until either the prefetch byte budget
    (``self.prefetch_bucket_size``) or the live-parameter budget
    (``self.max_live_parameters``) is exhausted.

    Args:
        param: parameter object consumed at this step; assumed to expose
            ``has_been_allgather``, ``numel()``, ``msg_size()`` and ``id``
            (TODO confirm against the trace producer).
        step_id: trace step at which *param* is consumed; used only in the
            mismatch warning message.

    Side effects: pops from ``self._param_queue``, flips
    ``has_been_allgather`` flags, grows ``self.current_live_parameters``,
    and appends at most one ``LogItem`` to ``self.workload``.
    """
    prefetch_bucket, prefetch_bucket_size = [], 0

    # The currently-needed parameter is fetched unconditionally; only the
    # *extra* prefetched parameters below are subject to the budgets.
    if not param.has_been_allgather:
        prefetch_bucket.append(param)
        prefetch_bucket_size += param.numel()
        # The queue replays the recorded access order, so its head should
        # be exactly the parameter we are about to use; a mismatch means
        # the trace and execution diverged.
        future_param, future_step_id = self._param_queue.popleft()
        if future_param != param:
            print(
                f"WARNING: expected {param.__dict__, step_id} but got {future_param.__dict__, future_step_id}"
            )
        param.has_been_allgather = True
        self.current_live_parameters += param.numel()

    # Speculative prefetch: keep pulling future parameters while both the
    # bucket byte budget and the live-parameter budget allow.  Fresh names
    # (next_param / next_step_id) avoid the original code's silent
    # shadowing of the `param` / `step_id` arguments.
    while (
        self._param_queue
        and prefetch_bucket_size < self.prefetch_bucket_size
        and self.current_live_parameters < self.max_live_parameters
    ):
        next_param, next_step_id = self._param_queue.popleft()
        # Record the furthest-ahead step this parameter was prefetched
        # for, even if it turns out to be live already.
        self.__most_recent_step_id_param_fetched_for[next_param.id] = max(
            next_step_id, self.__most_recent_step_id_param_fetched_for[next_param.id]
        )
        if next_param.has_been_allgather:
            # Already live on this rank -- nothing to gather.
            continue
        prefetch_bucket.append(next_param)
        next_param.has_been_allgather = True
        self.current_live_parameters += next_param.numel()
        prefetch_bucket_size += next_param.numel()

    if prefetch_bucket:
        # One fused all-gather across the data-parallel group covering the
        # whole bucket.
        self.workload.append(
            LogItem(
                comm_type=CommType.all_gather,
                comm_group=CommGroup.dp_group,
                comm_group_size=self.dp_world_size,
                msg_size=sum(p.msg_size() for p in prefetch_bucket),
                stage=f"{self.stage}.allgather_fn",
            )
        )
        if self.compute_enable:
            for fetched in prefetch_bucket:
                self._compute_for_param(fetched)