def compute_tflops()

in src/hyperpod_nemo_adapter/utils/train_utils.py [0:0]

Estimates training throughput in TFLOP/s per GPU from the model architecture, the number of samples processed in the step, the step time, and the world size, following the Megatron-LM FLOPs formula referenced in the code below.

def compute_tflops(cfg, model_config, sample_processed, step_time, world_size):
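    """Estimate training throughput in TFLOP/s per GPU for one step.

    cfg is the training config (read via cfg.get for "moe",
    "num_experts_per_tok" and "max_context_width"); model_config must
    expose hidden_size, num_attention_heads, num_key_value_heads,
    num_hidden_layers, intermediate_size and vocab_size.
    sample_processed is the number of sequences processed in the step,
    step_time the step duration in seconds, and world_size the number
    of participating GPUs (ranks).
    """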
    # Based on
    # https://github.com/NVIDIA/Megatron-LM/blob/ba773259dbe5735fbd91ca41e7f4ded60b335c52/megatron/training/training.py#L65
    hidden_width = model_config.hidden_size
    num_heads = model_config.num_attention_heads
    num_key_value_heads = model_config.num_key_value_heads
    moe = cfg.get("moe", 0)
    num_experts_per_tok = cfg.get("num_experts_per_tok")
    max_context_width = cfg.get("max_context_width")
    num_layers = model_config.num_hidden_layers
    intermediate_size = model_config.intermediate_size
    vocab_size = model_config.vocab_size

    kv_channels = hidden_width // num_heads
    query_projection_size = kv_channels * num_heads
    query_projection_to_hidden_size_ratio = query_projection_size / hidden_width

    # Group Query Attention.
    if not num_key_value_heads:
        num_key_value_heads = num_heads

    # MoE.
    num_experts_routed_to = 1 if moe == 0 else num_experts_per_tok
    gated_linear_multiplier = 3 / 2 if moe > 0 else 1

    # Compute the number of floating point operations
    num_flops = (
        12
        * sample_processed
        * max_context_width
        * num_layers
        * hidden_width
        * hidden_width
        * (
            # Attention.
            (
                (1 + (num_key_value_heads / num_heads) + (max_context_width / hidden_width))
                * query_projection_to_hidden_size_ratio
            )
            # MLP.
            + ((intermediate_size / hidden_width) * num_experts_routed_to * gated_linear_multiplier)
            # Logit.
            + (vocab_size / (2 * num_layers * hidden_width))
        )
    )

    # Convert to TFLOP/s per GPU: divide by step time (seconds) and the number of GPUs
    tflops_per_gpu = num_flops / (step_time * 10**12 * world_size)

    return tflops_per_gpu
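
As a rough usage sketch (not part of the module), the snippet below calls compute_tflops with Llama-2-7B-like shapes. The SimpleNamespace model config, the plain dict standing in for the training config, and all numeric values are assumptions chosen only to illustrate the expected inputs; it assumes compute_tflops above is in scope.

from types import SimpleNamespace

# Hypothetical dense-model shapes (Llama-2-7B-like); values are illustrative only.
model_config = SimpleNamespace(
    hidden_size=4096,
    num_attention_heads=32,
    num_key_value_heads=32,
    num_hidden_layers=32,
    intermediate_size=11008,
    vocab_size=32000,
)

# A plain dict works here because compute_tflops only reads these keys via cfg.get().
cfg = {
    "moe": 0,                    # dense model: one routed "expert", no gated-linear multiplier
    "num_experts_per_tok": None,
    "max_context_width": 4096,   # sequence length in tokens
}

# 64 sequences processed in a 5-second step across 8 GPUs (one rank per GPU).
tflops = compute_tflops(cfg, model_config, sample_processed=64, step_time=5.0, world_size=8)
print(f"{tflops:.1f} TFLOP/s per GPU")

Because the formula contains no recomputation term, the reported value reflects model FLOPs for the step rather than the hardware FLOPs actually executed.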