def _gather_param_prefetch()

in workload_generator/generate_deepspeed_stage3_workload.py [0:0]


    def _gather_param_prefetch(self, param, step_id):
        """Record one all-gather for ``param`` plus any parameters prefetched with it.

        If ``param`` has not been gathered yet it is put in the bucket and the
        head of ``self._param_queue`` is consumed and compared against it; a
        mismatch (or an exhausted queue) only produces a warning.  Further
        queue entries are then pulled into the same bucket until the queue is
        empty, the bucket reaches ``self.prefetch_bucket_size`` elements, or
        ``self.current_live_parameters`` reaches ``self.max_live_parameters``.
        A non-empty bucket is appended to ``self.workload`` as a single
        ``all_gather`` ``LogItem`` on the DP group.

        Args:
            param: Parameter being requested.  Must expose ``has_been_allgather``,
                ``numel()`` and ``msg_size()``.
            step_id: Step at which ``param`` is requested; only used in the
                warning text.

        Returns:
            None.  Side effects: mutates ``has_been_allgather`` on gathered
            params, ``self.current_live_parameters``, ``self._param_queue``,
            the most-recent-step bookkeeping dict and ``self.workload``.
        """
        prefetch_bucket, prefetch_bucket_size = [], 0
        if not param.has_been_allgather:
            prefetch_bucket.append(param)
            prefetch_bucket_size += param.numel()
            # The prefetch loop below discards queue entries for params that
            # are already gathered, so the queue can be exhausted by the time
            # a param is demanded again.  Treat that like any other
            # queue/trace mismatch: warn and carry on instead of letting
            # deque.popleft() raise IndexError.
            if self._param_queue:
                queued_param, queued_step_id = self._param_queue.popleft()
                if queued_param != param:
                    print(
                        f"WARNING: expected {param.__dict__, step_id} but got {queued_param.__dict__, queued_step_id}"
                    )
            else:
                print(
                    f"WARNING: expected {param.__dict__, step_id} but the parameter queue is empty"
                )
            param.has_been_allgather = True
            self.current_live_parameters += param.numel()

        # Opportunistically pull upcoming params into the same all-gather while
        # both the bucket budget and the live-parameter budget allow it.
        while (
            self._param_queue
            and prefetch_bucket_size < self.prefetch_bucket_size
            and self.current_live_parameters < self.max_live_parameters
        ):
            # Distinct local names: do not clobber the `param` / `step_id`
            # arguments with queue entries.
            queued_param, queued_step_id = self._param_queue.popleft()
            self.__most_recent_step_id_param_fetched_for[queued_param.id] = max(
                queued_step_id,
                self.__most_recent_step_id_param_fetched_for[queued_param.id],
            )
            if queued_param.has_been_allgather:
                # Already resident: the queue entry is consumed but nothing is fetched.
                continue
            prefetch_bucket.append(queued_param)
            queued_param.has_been_allgather = True
            queued_numel = queued_param.numel()
            self.current_live_parameters += queued_numel
            prefetch_bucket_size += queued_numel

        if prefetch_bucket:
            # One collective for the whole bucket; its size is the sum of the
            # per-parameter message sizes.
            self.workload.append(
                LogItem(
                    comm_type=CommType.all_gather,
                    comm_group=CommGroup.dp_group,
                    comm_group_size=self.dp_world_size,
                    msg_size=sum(p.msg_size() for p in prefetch_bucket),
                    stage=f"{self.stage}.allgather_fn",
                )
            )
            if self.compute_enable:
                for gathered_param in prefetch_bucket:
                    self._compute_for_param(gathered_param)