def build_model_gbuf_param_range_map()

in workload_generator/generate_deepspeed_stage1_2_workload.py [0:0]


    def build_model_gbuf_param_range_map(self, model: MockedModel, dp_world_size: int):
        gbuf_size = sum(param.numel() for param in model.parameters())

        gbuf_partition_size = int(math.ceil(gbuf_size / dp_world_size))
        gbuf_world_all_ranges = []
        for r in range(dp_world_size):
            gbuf_world_start = r * gbuf_partition_size
            gbuf_world_end = min(gbuf_size, gbuf_world_start + gbuf_partition_size)
            gbuf_world_all_ranges.append((gbuf_world_start, gbuf_world_end))
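        # For example (hypothetical numbers, not from the source): with gbuf_size = 10
        # and dp_world_size = 4, gbuf_partition_size = 3 and
        # gbuf_world_all_ranges = [(0, 3), (3, 6), (6, 9), (9, 10)];
        # the last rank owns a smaller slice.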

        start_idx, r = 0, 0
        gbuf_world_start, gbuf_world_end = gbuf_world_all_ranges[r]
        # record which rank(s) each param should be reduced to
        # param_id: int -> List[(rank: int, param_start_idx: int, param_end_idx: int)]
        param_range_map = {}
        for param in self.all_params:
            # the current param occupies the [start_idx, end_idx) range of the gbuf
            param_id = id(param)
            param_range_map[param_id] = []
            end_idx = start_idx + param.numel()

            # the current rank is in charge of [gbuf_world_start, gbuf_world_end) of the gbuf
            param_start_idx = start_idx
            # if the current rank cannot fully cover this param, split it and move on to the next rank
            while gbuf_world_end < end_idx:
                param_range_map[param_id].append((r, param_start_idx, gbuf_world_end))
                param_start_idx = gbuf_world_end
                r += 1
                gbuf_world_start, gbuf_world_end = gbuf_world_all_ranges[r]
            # the remainder of this param falls within the current rank's range
            param_range_map[param_id].append((r, param_start_idx, end_idx))

            # for next param
            start_idx = end_idx
        return param_range_map
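
For intuition, below is a minimal standalone sketch (not part of the original file) that applies the same splitting logic to plain integer sizes; the helper name split_params_across_ranks and the example sizes are hypothetical.

    import math

    def split_params_across_ranks(param_sizes, dp_world_size):
        # Partition a flat buffer of sum(param_sizes) elements into dp_world_size
        # contiguous chunks and record, for each param, the (rank, start, end)
        # pieces it maps onto -- mirroring build_model_gbuf_param_range_map above.
        gbuf_size = sum(param_sizes)
        partition_size = int(math.ceil(gbuf_size / dp_world_size))
        ranges = [(r * partition_size, min(gbuf_size, (r + 1) * partition_size))
                  for r in range(dp_world_size)]

        range_map, start_idx, r = {}, 0, 0
        _, world_end = ranges[r]
        for i, size in enumerate(param_sizes):
            end_idx = start_idx + size
            range_map[i] = []
            piece_start = start_idx
            # spill over to the next rank(s) while this param is not fully covered
            while world_end < end_idx:
                range_map[i].append((r, piece_start, world_end))
                piece_start = world_end
                r += 1
                _, world_end = ranges[r]
            range_map[i].append((r, piece_start, end_idx))
            start_idx = end_idx
        return range_map

    # A 10-element buffer over 4 ranks (chunks of 3): the 6-element param
    # with index 1 spans ranks 1, 2, and 3.
    print(split_params_across_ranks([4, 6], dp_world_size=4))
    # {0: [(0, 0, 3), (1, 3, 4)], 1: [(1, 4, 6), (2, 6, 9), (3, 9, 10)]}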