# is_enough_layers_for_pp — extracted from bench_cluster/create_configs.py


def is_enough_layers_for_pp(pp_size, config):
    """Return True iff every pipeline-parallel rank would receive at least one block.

    Simulates the cost-balanced block-to-rank assignment used for pipeline
    parallelism: blocks are walked in order and handed to successive ranks
    whenever the running (cumulative) compute cost crosses that rank's share
    of the total. If some rank ends up with no block at all, the model is too
    shallow for this ``pp_size``.

    Args:
        pp_size: Number of pipeline-parallel ranks to test.
        config: Nested config dict; reads ``config["model"]["model_config"]``
            with keys ``hidden_size``, ``num_attention_heads``,
            ``intermediate_size``, ``vocab_size``, ``num_hidden_layers``.

    Returns:
        bool: True when all ``pp_size`` ranks get assigned at least one block.
    """

    def _get_block_compute_costs(model_config):
        """Estimate the per-block forward compute cost for load balancing."""
        d_ff = model_config["intermediate_size"]
        d_qkv = model_config["hidden_size"] // model_config["num_attention_heads"]

        # Per-decoder cost is identical for every layer, so compute it once:
        # CausalSelfAttention (qkv proj + attn out) + MLP.
        attn_cost = 4 * model_config["num_attention_heads"] * d_qkv * model_config["hidden_size"]
        mlp_cost = 3 * d_ff * model_config["hidden_size"]

        block_compute_costs = {
            # This is the last lm_head
            "lm_head": model_config["vocab_size"] * model_config["hidden_size"],
        }
        for i in range(model_config["num_hidden_layers"]):
            block_compute_costs[f"decoder{i}"] = attn_cost + mlp_cost

        return block_compute_costs

    # Compute the PP block repartition over the full pipeline, including the
    # zero-cost bookkeeping blocks (embedding, norm, cast, loss).
    model_config = config["model"]["model_config"]
    block_compute_costs = _get_block_compute_costs(model_config)
    num_layers = model_config["num_hidden_layers"]
    pipeline_blocks = (
        ["token_embedding"]
        + [f"decoder{i}" for i in range(num_layers)]
        + ["final_layer_norm", "lm_head", "cast_to_fp32", "loss"]
    )
    # Blocks without an explicit cost entry contribute 0 to the cumulative sum.
    block_cumulative_costs = np.cumsum(
        [block_compute_costs.get(name, 0) for name in pipeline_blocks]
    )

    # Assign ranks to blocks: advance to the next rank each time the running
    # cost exceeds the current rank's proportional share of the total.
    block2rank = {block: 0 for block in pipeline_blocks}
    target_pp_ranks = list(range(pp_size))
    thresholds = [block_cumulative_costs[-1] * ((rank + 1) / pp_size) for rank in range(pp_size)]
    # Last threshold equals the total cost, so the walk can never run past it.
    assert thresholds[-1] >= block_cumulative_costs[-1]
    target_pp_rank_idx = 0

    for block, cumulative_cost in zip(pipeline_blocks, block_cumulative_costs):
        assert target_pp_rank_idx < pp_size
        block2rank[block] = target_pp_ranks[target_pp_rank_idx]

        if cumulative_cost > thresholds[target_pp_rank_idx]:
            target_pp_rank_idx += 1

    # Pin the embedding to the first rank and the loss to the last rank reached.
    block2rank["token_embedding"] = target_pp_ranks[0]
    block2rank["loss"] = target_pp_ranks[target_pp_rank_idx]

    # Check if all ranks have a block assigned to them.
    unique_ranks = sorted(set(block2rank.values()))
    expected_ranks = list(range(pp_size))

    return unique_ranks == expected_ranks