"""
Copyright (c) 2021, Alibaba Group;
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
   http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import workload_generator.mocked_model.MockedDeepspeed
from workload_generator.mocked_model.MockedMegatron import *
from workload_generator.mocked_model.MockedModel import MockedParam, MockedModel
from utils.utils import CommType, get_params, get_comp_out, extract_averages
import os
from typing import List, Tuple
from collections import deque
import dataclasses
from enum import Enum

try:
    import torch
except ImportError as e:
    torch = None
    print("Failed to import 'torch'.")
import math
import re





@dataclasses.dataclass
class Work_Item:
    """One entry of the generated SimAI workload.

    Each item carries a compute time, a collective name and a collective size
    for three phases: forward, backward and data-parallel (``dp_``).
    Collective names are strings such as ``"ALLREDUCE"``; ``"NONE"`` means the
    phase has no communication. Size units follow whatever the generator
    computes (presumably bytes -- confirm against the SimAI consumer).
    """

    name: str = "none"
    placeholder: int = -1
    # Forward phase.
    forward_compute_time: int = 0
    forward_comm: str = "NONE"
    forward_comm_size: int = 0
    # Backward phase.
    backward_compute_time: int = 0
    backward_comm: str = "NONE"
    backward_comm_size: int = 0
    # Data-parallel (gradient synchronisation) phase.
    dp_compute_time: int = 0
    dp_comm: str = "NONE"
    dp_comm_size: int = 0
    process_time: int = 100



def _get_aiob_compute_time(compute_cache, forward_or_backward, stage):
    compute_time_map = compute_cache
    if stage == "grad":
        prefix = stage + "_" + forward_or_backward
    elif stage == "embedding":
        prefix = "Emb"
    elif stage == "final":
        prefix = "attention" + "_" + forward_or_backward
    else:
        prefix = stage + "_" + forward_or_backward

    for key, value in compute_time_map.items():
        if prefix == key:

            compute_time = compute_time_map.get(key)
            return compute_time

    print("[warn] can't match any stage", stage)
    return 1


class LayerInfo:
    """Plain record for one model layer: its id, its name and its parameter count."""

    def __init__(self, layer_id, layer_name, param_count):
        # Store the three fields verbatim; no validation is performed.
        self.layer_id, self.layer_name, self.param_count = (
            layer_id,
            layer_name,
            param_count,
        )


class SIMAI_workload:
    def __init__(self, model, args, compute_cache=None):
        self.model = model
        self.args = args
        self.compute_cache = compute_cache
        self.workload = []
        self.seq_len = args.seq_length
        self.tp = args.tensor_model_parallel_size
        self.mbs = args.micro_batch
        if args.moe_enable:
            self.expert_model_parallel_size = args.expert_model_parallel_size
            self.num_experts = args.num_experts
            self.topk = args.moe_router_topk

    def get_model_details(self):
        """Return a LayerInfo for each module of interest in ``self.model``.

        The module graph is walked depth-first through ``child_modules()``;
        every object is visited at most once (tracked by ``id``). Which
        modules are recorded depends on ``self.args.enable_sequence_parallel``:

        * enabled: ``MegatronColumnLinear``, ``MegatronRowLinear``,
          ``MegatronEmbedding``, ``FusedLayernorm`` and ``MOEMLP``;
        * disabled: ``MegatronAttention``, ``MegatronMlp`` and
          ``MegatronEmbedding``.

        Each record's ``param_count`` is the sum of ``numel()`` over the
        module's ``parameters()``.

        Returns:
            list[LayerInfo] in traversal order.
        """
        layers = []
        visited = set()

        def record(module):
            # Summarise one module as (layer_id, name, total parameter count).
            param_count = sum(p.numel() for p in module.parameters())
            layers.append(LayerInfo(module.layer_id, module.name, param_count))

        def traverse_model(module):
            if id(module) in visited:
                return
            visited.add(id(module))

            if self.args.enable_sequence_parallel:
                if isinstance(
                    module,
                    (
                        MegatronColumnLinear,
                        MegatronRowLinear,
                        MegatronEmbedding,
                        FusedLayernorm,
                    ),
                ):
                    record(module)
                # Checked independently of the group above, so a module
                # matching both would be recorded twice.
                if isinstance(module, MOEMLP):
                    record(module)
            elif isinstance(
                module, (MegatronAttention, MegatronMlp, MegatronEmbedding)
            ):
                record(module)

            for child in module.child_modules():
                traverse_model(child)

        # Start from the model stored on this instance. The bare name `model`
        # would resolve to a module-level global, if one exists at all.
        traverse_model(self.model)

        return layers

    def _get_total_params(self):
        total_params = 0
        moe_param_count = 0
        layers = self.get_model_details()
        for layer in layers:
            total_params += layer.param_count
            if "moe" in layer.layer_name:
                moe_param_count += layer.param_count

        return total_params, moe_param_count

    def workload_generate_aiob(self):
        """Populate ``self.workload`` using compute times from the AIOB cache.

        Items are appended in this order:

        1. Global gradient items: ``grad_gather`` (dp ALLGATHER over the
           non-MoE parameters), ``grad_param_comm`` (dp REDUCESCATTER over the
           non-MoE parameters) and ``grad_param_compute``. Then a
           ``layernorm`` backward ALLREDUCE (only when sequence parallelism is
           off), ``embedding_grads``, and, when
           ``expert_model_parallel_size != dp_num``, the two
           ``moe_grad_norm`` items over the MoE parameters.
        2. For each of ``self.ga_num`` gradient-accumulation steps, the items
           for every layer returned by :meth:`get_model_details`.
        3. Three ``cross_entropy`` items and four ``optimizer`` items.

        Compute times come from ``_get_aiob_compute_time(self.compute_cache,
        ...)``. Communication sizes are derived from ``self.args``. Also sets
        ``self.ga_num``.
        """
        # Read configuration only through this instance's arguments; a bare
        # `args` would otherwise resolve to a module-level global, if any.
        args = self.args
        self.ga_num = args.global_batch // (args.micro_batch * args.dp_num)
        if self.ga_num < 1:
            print(
                "[WARN]: ga num < 1, please confirm global_batch num and micro_batch num"
            )
        default_compute_time = 1
        # Compute time used for the trailing cross_entropy / optimizer items.
        compute_time = 0
        tp_enabled = args.tensor_model_parallel_size != 1
        tp_comm_size = 2 * args.micro_batch * args.seq_length * args.hidden_size
        layers = self.get_model_details()
        total_params, moe_param_count = self._get_total_params()
        dense_param_count = total_params - moe_param_count

        forward_compute_time = _get_aiob_compute_time(
            self.compute_cache, "forward", "grad"
        )
        backward_compute_time = _get_aiob_compute_time(
            self.compute_cache, "backward", "grad"
        )
        self.workload.append(
            Work_Item(
                name="grad_gather",
                forward_compute_time=default_compute_time,
                forward_comm="NONE",
                forward_comm_size=0,
                backward_compute_time=default_compute_time,
                backward_comm="NONE",
                backward_comm_size=0,
                dp_compute_time=default_compute_time,
                dp_comm="ALLGATHER",
                dp_comm_size=2 * dense_param_count,
            )
        )
        self.workload.append(
            Work_Item(
                name="grad_param_comm",
                forward_compute_time=default_compute_time,
                forward_comm="NONE",
                forward_comm_size=0,
                backward_compute_time=default_compute_time,
                backward_comm="NONE",
                backward_comm_size=0,
                dp_compute_time=default_compute_time,
                dp_comm="REDUCESCATTER",
                dp_comm_size=4 * dense_param_count,
            )
        )
        self.workload.append(
            Work_Item(
                name="grad_param_compute",
                forward_compute_time=default_compute_time,
                forward_comm="NONE",
                forward_comm_size=0,
                backward_compute_time=forward_compute_time + backward_compute_time,
                backward_comm="NONE",
                backward_comm_size=0,
                dp_compute_time=default_compute_time,
                dp_comm="NONE",
                dp_comm_size=0,
            )
        )

        if not args.enable_sequence_parallel:
            self.workload.append(
                Work_Item(
                    name="layernorm",
                    forward_compute_time=default_compute_time,
                    forward_comm="NONE",
                    forward_comm_size=0,
                    backward_compute_time=default_compute_time,
                    backward_comm="ALLREDUCE",
                    backward_comm_size=2 * total_params,
                    dp_compute_time=default_compute_time,
                    dp_comm="NONE",
                    dp_comm_size=0,
                )
            )
        self.workload.append(
            Work_Item(
                name="embedding_grads",
                forward_compute_time=default_compute_time,
                forward_comm="NONE",
                forward_comm_size=0,
                backward_compute_time=default_compute_time,
                backward_comm="ALLREDUCE" if tp_enabled else "NONE",
                backward_comm_size=tp_comm_size,
                dp_compute_time=default_compute_time,
                dp_comm="NONE",
                dp_comm_size=0,
            )
        )
        if args.expert_model_parallel_size != args.dp_num:
            for moe_item_name, moe_dp_comm, moe_factor in (
                ("moe_grad_norm1", "ALLGATHER_DP_EP", 2),
                ("moe_grad_norm2", "REDUCESCATTER_DP_EP", 4),
            ):
                self.workload.append(
                    Work_Item(
                        name=moe_item_name,
                        forward_compute_time=default_compute_time,
                        forward_comm="NONE",
                        forward_comm_size=0,
                        backward_compute_time=default_compute_time,
                        backward_comm="NONE",
                        backward_comm_size=0,
                        dp_compute_time=default_compute_time,
                        dp_comm=moe_dp_comm,
                        dp_comm_size=moe_factor * moe_param_count,
                    )
                )

        for _ in range(self.ga_num):
            for layer in layers:
                name = layer.layer_name
                stage = name.split("_")[0]
                if args.enable_sequence_parallel:
                    if "embedding" in name:
                        emb_compute_time = _get_aiob_compute_time(
                            self.compute_cache, "", "embedding"
                        )
                        self.workload.append(
                            Work_Item(
                                name=name,
                                forward_compute_time=emb_compute_time,
                                forward_comm="ALLREDUCE" if tp_enabled else "NONE",
                                forward_comm_size=tp_comm_size,
                                backward_compute_time=default_compute_time,
                                backward_comm="NONE",
                                backward_comm_size=0,
                                # NOTE(review): backward_compute_time holds whatever the
                                # previous layer (or the initial "grad" lookup) left
                                # behind -- confirm this is intended.
                                dp_compute_time=backward_compute_time,
                                dp_comm="NONE",
                                dp_comm_size=0,
                            )
                        )
                    if "row" in name or "column" in name:
                        forward_compute_time = _get_aiob_compute_time(
                            self.compute_cache, "forward", stage
                        )
                        backward_compute_time = _get_aiob_compute_time(
                            self.compute_cache, "backward", stage
                        )
                        # Presumably models the extra forward pass of activation
                        # recomputation for attention.
                        if args.recompute_activations and "attention" in name:
                            forward_compute_time *= 2
                        # The stage time is halved -- presumably because each stage
                        # is emitted as both a column and a row linear item.
                        forward_compute_time = int(forward_compute_time / 2)
                        backward_compute_time = int(backward_compute_time / 2)
                        if "row" in name:
                            sp_forward_comm, sp_backward_comm = "REDUCESCATTER", "ALLGATHER"
                            sp_backward_size = tp_comm_size  # sp overlap allgather
                        else:
                            sp_forward_comm, sp_backward_comm = "ALLGATHER", "REDUCESCATTER"
                            # NOTE(review): column layers report a backward size of 0
                            # even though backward_comm is REDUCESCATTER -- confirm.
                            sp_backward_size = 0
                        if not tp_enabled:
                            sp_forward_comm = sp_backward_comm = "NONE"
                        self.workload.append(
                            Work_Item(
                                name=name,
                                forward_compute_time=forward_compute_time,
                                forward_comm=sp_forward_comm,
                                forward_comm_size=tp_comm_size,
                                backward_compute_time=backward_compute_time,
                                backward_comm=sp_backward_comm,
                                backward_comm_size=sp_backward_size,
                                dp_compute_time=backward_compute_time,
                                dp_comm="NONE",
                                dp_comm_size=0,
                            )
                        )
                    elif "moelayer" in name:
                        forward_compute_time = _get_aiob_compute_time(
                            self.compute_cache, "forward", stage
                        )
                        backward_compute_time = _get_aiob_compute_time(
                            self.compute_cache, "backward", stage
                        )
                        self._append_moe_layer_items(
                            name,
                            forward_compute_time,
                            backward_compute_time,
                            tp_comm_size,
                            default_compute_time,
                        )
                else:
                    forward_comm = "ALLREDUCE" if tp_enabled else "NONE"
                    if "embedding" in name:
                        emb_compute_time = _get_aiob_compute_time(
                            self.compute_cache, "", "embedding"
                        )
                        self.workload.append(
                            Work_Item(
                                name=name,
                                forward_compute_time=emb_compute_time,
                                forward_comm=forward_comm,
                                forward_comm_size=tp_comm_size,
                                backward_compute_time=default_compute_time,
                                backward_comm="NONE",
                                backward_comm_size=0,
                                # NOTE(review): carries backward_compute_time from the
                                # previous layer / "grad" lookup -- confirm intent.
                                dp_compute_time=backward_compute_time,
                                dp_comm="NONE",
                                dp_comm_size=0,
                            )
                        )
                    else:
                        forward_compute_time = _get_aiob_compute_time(
                            self.compute_cache, "forward", stage
                        )
                        backward_compute_time = _get_aiob_compute_time(
                            self.compute_cache, "backward", stage
                        )
                        # Applied after the lookup so the recompute doubling is not
                        # overwritten (matches the sequence-parallel branch).
                        if args.recompute_activations and "attention" in name:
                            forward_compute_time *= 2
                        self.workload.append(
                            Work_Item(
                                name=name,
                                forward_compute_time=forward_compute_time,
                                forward_comm=forward_comm,
                                forward_comm_size=tp_comm_size,
                                backward_compute_time=backward_compute_time,
                                backward_comm="NONE",
                                backward_comm_size=0,
                                dp_compute_time=backward_compute_time,
                                dp_comm="NONE",
                                dp_comm_size=0,
                            )
                        )

        for i in range(3):
            self.workload.append(
                Work_Item(
                    name="cross_entropy" + str(i + 1),
                    forward_compute_time=compute_time,
                    forward_comm="ALLREDUCE",
                    forward_comm_size=args.seq_length * args.micro_batch * 4,
                    backward_compute_time=compute_time,
                    backward_comm="NONE",
                    backward_comm_size=0,
                    dp_compute_time=compute_time,
                    dp_comm="NONE",
                    dp_comm_size=0,
                )
            )

        for i in range(4):
            self.workload.append(
                Work_Item(
                    name="optimizer" + str(i + 1),
                    forward_compute_time=compute_time,
                    forward_comm="ALLREDUCE",
                    forward_comm_size=4,
                    backward_compute_time=compute_time,
                    backward_comm="NONE",
                    backward_comm_size=0,
                    dp_compute_time=compute_time,
                    dp_comm="NONE",
                    dp_comm_size=0,
                )
            )

    def _append_moe_layer_items(
        self,
        name,
        forward_compute_time,
        backward_compute_time,
        tp_comm_size,
        default_compute_time,
    ):
        """Append the seven Work_Items modelling one MoE layer (sequence-parallel path).

        Only the first item carries the layer's compute times; the rest use
        ``default_compute_time``. Every item has no dp communication. The
        collectives depend on whether tensor parallelism is enabled. The
        expert-parallel all-to-all sizes depend on whether
        ``expert_model_parallel_size`` is 1.
        """
        if self.args.tensor_model_parallel_size == 1:
            comms = ("NONE", "NONE", "ALLTOALL_EP", "NONE", "NONE", "ALLTOALL_EP", "NONE")
        else:
            comms = (
                "ALLGATHER",
                "ALLTOALL",
                "ALLTOALL_EP",
                "ALLGATHER",
                "REDUCESCATTER",
                "ALLTOALL_EP",
                "ALLTOALL",
            )
        comm1, comm2, comm3, comm4, comm5, comm6, comm7 = comms

        first_size = 2 * self.mbs * self.seq_len * self.num_experts
        per_tp_size = tp_comm_size // self.tp
        routed_size = tp_comm_size * self.topk
        if self.args.expert_model_parallel_size != 1:
            ep_size = tp_comm_size * self.topk // self.tp
            # Forward ALLGATHER is paired with a backward REDUCESCATTER.
            comm4_backward = comm5
        else:
            ep_size = 1
            # NOTE(review): with EP == 1 both routed items use comm4 for their
            # backward collective (unlike the EP != 1 path) -- confirm intent.
            comm4_backward = comm4

        d = default_compute_time
        # (forward time, forward comm, backward time, backward comm, size);
        # forward and backward sizes are equal for every item.
        specs = (
            (forward_compute_time, comm1, backward_compute_time, comm1, first_size),
            (d, comm2, d, comm2, per_tp_size),
            (d, comm3, d, comm3, ep_size),
            (d, comm4, d, comm4_backward, routed_size),
            (d, comm5, d, comm4, routed_size),
            (d, comm6, d, comm6, ep_size),
            (d, comm7, d, comm7, per_tp_size),
        )
        for fwd_time, fwd_comm, bwd_time, bwd_comm, size in specs:
            self.workload.append(
                Work_Item(
                    name=name,
                    forward_compute_time=fwd_time,
                    forward_comm=fwd_comm,
                    forward_comm_size=size,
                    backward_compute_time=bwd_time,
                    backward_comm=bwd_comm,
                    backward_comm_size=size,
                    dp_compute_time=d,
                    dp_comm="NONE",
                    dp_comm_size=0,
                )
            )

    def workload_generate(self):
        # args.world_size --> total gpus number
        self.ga_num = self.args.global_batch // (self.args.micro_batch * self.args.dp_num)
        if self.ga_num < 1:
            print(
                "[WARN]: ga num < 1, please confirm global_batch num and micro_batch num"
            )
        default_compute_time = 1
        compute_time = 0
        tp_comm_size = (
            2 * self.args.micro_batch * self.args.seq_length * self.args.hidden_size
        )
        layers = self.get_model_details()
        total_params, moe_param_count = self._get_total_params()
        # print(f"Total params is {total_params}, moe params is {moe_param_count}")
        # self.workload.append(Work_Item(name="norm", forward_compute_time=0,
        #                         forward_comm = "BROADCAST", forward_comm_size= 8*self.args.micro_batch*self.args.seq_length,
        #                         backward_compute_time=default_compute_time, backward_comm="NONE", backward_comm_size=0,
        #                         dp_compute_time=default_compute_time, dp_comm="NONE", dp_comm_size=0
        #                         ))
        forward_compute_time = default_compute_time
        backward_compute_time = default_compute_time
        self.workload.append(
            Work_Item(
                name="grad_norm",
                forward_compute_time=forward_compute_time,
                forward_comm="ALLGATHER",
                forward_comm_size=2 * total_params,
                backward_compute_time=backward_compute_time,
                backward_comm="NONE",
                backward_comm_size=0,
                dp_compute_time=default_compute_time,
                dp_comm="REDUCESCATTER",
                dp_comm_size=4 * total_params,
            )
        )
        if not self.args.enable_sequence_parallel:
            self.workload.append(
                Work_Item(
                    name="layernorm",
                    forward_compute_time=default_compute_time,
                    forward_comm="NONE",
                    forward_comm_size=0,
                    backward_compute_time=default_compute_time,
                    backward_comm="ALLREDUCE",
                    backward_comm_size=2 * total_params,
                    dp_compute_time=default_compute_time,
                    dp_comm="NONE",
                    dp_comm_size=0,
                )
            )
        if args.expert_model_parallel_size != args.dp_num:
            self.workload.append(Work_Item(name="moe_grad_norm1", forward_compute_time=default_compute_time,
                                    forward_comm = "NONE", forward_comm_size= 0,
                                    backward_compute_time=default_compute_time, backward_comm="NONE", backward_comm_size=0,
                                    dp_compute_time=default_compute_time, dp_comm="ALLGATHER_DP_EP", dp_comm_size=2*moe_param_count
                                    ))
            self.workload.append(Work_Item(name="moe_grad_norm2", forward_compute_time=default_compute_time,
                                    forward_comm = "NONE", forward_comm_size= 0,
                                    backward_compute_time=default_compute_time, backward_comm="NONE", backward_comm_size=0,
                                    dp_compute_time=default_compute_time, dp_comm="REDUCESCATTER_DP_EP", dp_comm_size=4*moe_param_count
                                    ))
        for _ in range(self.ga_num):
            for layer in layers:
                name = layer.layer_name
                forward_comm = backward_comm = backward_comm_2 = "NONE"
                forward_comm_size = tp_comm_size
                backward_comm_size = tp_comm_size
                dp_comm = "NONE"
                dp_comm_size = 0
                if self.args.enable_sequence_parallel:
                    if "embedding" in name:
                        self.workload.append(
                            Work_Item(
                                name=name,
                                forward_compute_time=default_compute_time,
                                forward_comm=forward_comm,
                                forward_comm_size=forward_comm_size,
                                backward_compute_time=default_compute_time,
                                backward_comm=backward_comm,
                                backward_comm_size=backward_comm_size,
                                dp_compute_time=backward_compute_time,
                                dp_comm=dp_comm,
                                dp_comm_size=dp_comm_size,
                            )
                        )

                    if "row" in name:
                        if self.args.recompute_activations and 'attention' in name:
                            forward_comm_size *= 2
                        forward_comm = "REDUCESCATTER"
                        backward_comm = "ALLGATHER"
                        self.workload.append(Work_Item(name=name, forward_compute_time=default_compute_time,
                                    forward_comm = forward_comm, forward_comm_size= forward_comm_size,
                                    backward_compute_time=default_compute_time, backward_comm="NONE", backward_comm_size=tp_comm_size,
                                    dp_compute_time=default_compute_time, dp_comm=dp_comm, dp_comm_size=dp_comm_size
                                    ))
                    if "column" in name:
                        if self.args.recompute_activations and 'attention' in name:
                            forward_comm_size *= 2
                        forward_comm = "ALLGATHER"
                        forward_comm2 = "NONE"
                        backward_comm = "REDUCESCATTER"
                        backward_comm_2 = "ALLGATHER"
                        self.workload.append(Work_Item(name=name, forward_compute_time=default_compute_time,
                                    forward_comm = forward_comm, forward_comm_size= forward_comm_size,
                                    backward_compute_time=default_compute_time, backward_comm="NONE", backward_comm_size=0,
                                    dp_compute_time=default_compute_time, dp_comm=dp_comm, dp_comm_size=dp_comm_size
                                    ))
                    if "moelayer" in name:
                        forward_comm1 = "ALLGATHER"
                        forward_comm2 = "ALLTOALL"
                        forward_comm3 = "ALLTOALL_EP"
                        forward_comm4 = "ALLGATHER"
                        forward_comm5 = "REDUCESCATTER"
                        forward_comm6 = "ALLTOALL_EP"
                        forward_comm7 = "ALLTOALL"
                        self.workload.append(Work_Item(name=name, forward_compute_time=default_compute_time,
                                    forward_comm = forward_comm1, forward_comm_size= 2*self.seq_len*self.num_experts,
                                    backward_compute_time=default_compute_time, backward_comm=forward_comm1, backward_comm_size=2*self.seq_len*self.num_experts,
                                    dp_compute_time=default_compute_time, dp_comm=dp_comm, dp_comm_size=dp_comm_size
                                    ))
                        self.workload.append(Work_Item(name=name, forward_compute_time=default_compute_time,
                                    forward_comm = forward_comm2, forward_comm_size= tp_comm_size//self.tp,
                                    backward_compute_time=default_compute_time, backward_comm=forward_comm2, backward_comm_size=tp_comm_size//self.tp,
                                    dp_compute_time=default_compute_time, dp_comm=dp_comm, dp_comm_size=dp_comm_size
                                    ))
                        self.workload.append(Work_Item(name=name, forward_compute_time=default_compute_time,
                                    forward_comm = forward_comm3, forward_comm_size= tp_comm_size*self.topk//self.tp,
                                    backward_compute_time=default_compute_time, backward_comm=forward_comm3, backward_comm_size=tp_comm_size*self.topk//self.tp,
                                    dp_compute_time=default_compute_time, dp_comm=dp_comm, dp_comm_size=dp_comm_size
                                    ))
                        self.workload.append(Work_Item(name=name, forward_compute_time=default_compute_time,
                                    forward_comm = forward_comm4, forward_comm_size= tp_comm_size*self.topk//self.tp,
                                    backward_compute_time=default_compute_time, backward_comm=forward_comm4, backward_comm_size=tp_comm_size*self.topk,
                                    dp_compute_time=default_compute_time, dp_comm=dp_comm, dp_comm_size=dp_comm_size
                                    ))
                        self.workload.append(Work_Item(name=name, forward_compute_time=default_compute_time,
                                    forward_comm = forward_comm5, forward_comm_size= tp_comm_size*self.topk//self.tp,
                                    backward_compute_time=default_compute_time, backward_comm=forward_comm4, backward_comm_size=tp_comm_size*self.topk//self.tp,
                                    dp_compute_time=default_compute_time, dp_comm=dp_comm, dp_comm_size=dp_comm_size
                                    ))
                        self.workload.append(Work_Item(name=name, forward_compute_time=default_compute_time,
                                    forward_comm = forward_comm6, forward_comm_size= tp_comm_size*self.topk//self.tp,
                                    backward_compute_time=default_compute_time, backward_comm=forward_comm6, backward_comm_size=tp_comm_size*self.topk//self.tp,
                                    dp_compute_time=default_compute_time, dp_comm=dp_comm, dp_comm_size=dp_comm_size
                                    ))
                        self.workload.append(Work_Item(name=name, forward_compute_time=default_compute_time,
                                    forward_comm = forward_comm7, forward_comm_size= tp_comm_size//self.tp,
                                    backward_compute_time=default_compute_time, backward_comm=forward_comm7, backward_comm_size=tp_comm_size//self.tp,
                                    dp_compute_time=default_compute_time, dp_comm=dp_comm, dp_comm_size=dp_comm_size
                                    ))
                        
                else:
                    forward_comm = "ALLREDUCE"
                    backward_comm = "ALLREDUCE"
                    if self.args.recompute_activations and 'attention' in name:
                        forward_comm_size *= 2
                    if "embedding" in name:
                        self.workload.append(
                            Work_Item(
                                name=name,
                                forward_compute_time=default_compute_time,
                                forward_comm=forward_comm,
                                forward_comm_size=forward_comm_size,
                                backward_compute_time=default_compute_time,
                                backward_comm=backward_comm,
                                backward_comm_size=backward_comm_size,
                                dp_compute_time=backward_compute_time,
                                dp_comm=dp_comm,
                                dp_comm_size=dp_comm_size,
                            )
                        )
                    else:
                        self.workload.append(
                            Work_Item(
                                name=name,
                                forward_compute_time=default_compute_time,
                                forward_comm=forward_comm,
                                forward_comm_size=forward_comm_size,
                                backward_compute_time=default_compute_time,
                                backward_comm=backward_comm,
                                backward_comm_size=backward_comm_size,
                                dp_compute_time=default_compute_time,
                                dp_comm=dp_comm,
                                dp_comm_size=dp_comm_size,
                            )
                        )
            self.workload.append(
                Work_Item(
                    name="embedding_norm",
                    forward_compute_time=default_compute_time,
                    forward_comm="ALLREDUCE",
                    forward_comm_size=self.args.vocab_size * self.args.hidden_size * 2,
                    backward_compute_time=default_compute_time,
                    backward_comm="NONE",
                    backward_comm_size=0,
                    dp_compute_time=default_compute_time,
                    dp_comm="NONE",
                    dp_comm_size=0,
                )
            )
        for i in range(3):
            self.workload.append(
                Work_Item(
                    name="cross_entropy" + str(i + 1),
                    forward_compute_time=compute_time,
                    forward_comm="ALLREDUCE",
                    forward_comm_size=self.args.seq_length * self.args.micro_batch * 4,
                    backward_compute_time=compute_time,
                    backward_comm="NONE",
                    backward_comm_size=0,
                    dp_compute_time=compute_time,
                    dp_comm="NONE",
                    dp_comm_size=0,
                )
            )

        for i in range(4):
            self.workload.append(
                Work_Item(
                    name="optimizer" + str(i + 1),
                    forward_compute_time=compute_time,
                    forward_comm="ALLREDUCE",
                    forward_comm_size=4,
                    backward_compute_time=compute_time,
                    backward_comm="NONE",
                    backward_comm_size=0,
                    dp_compute_time=compute_time,
                    dp_comm="NONE",
                    dp_comm_size=0,
                )
            )

    def dump_file(self, filename):
        """Write ``self.workload`` to ``<filename>.txt`` in the SimAI text format.

        File layout:
          * line 1 -- header with the parallel configuration (tp, ep, pp, vpp,
            ga, world size) followed by the pipeline-parallel message size
            ``pp_comm``;
          * line 2 -- the number of work items;
          * one line per work item, its attribute values tab-separated in the
            order of the item's ``__dict__`` (field declaration order for a
            dataclass such as ``Work_Item``).

        Args:
            filename: output path without extension; ``.txt`` is appended.
        """
        filename = filename + ".txt"
        args = self.args

        # Size of the tensor exchanged between pipeline stages. The factor 2
        # is presumably bytes per element (fp16/bf16) -- TODO confirm.
        pp_comm_value = 2 * args.micro_batch * args.seq_length * args.hidden_size
        if args.enable_sequence_parallel:
            # With sequence parallelism the tensor is sharded across TP ranks.
            # Integer division keeps this a whole byte count ("16", not the
            # "16.0" true division produced), matching every other size field.
            pp_comm_value //= args.tensor_model_parallel_size

        # With a single pipeline stage (pp == 1) pp_comm is reported as 0.
        pp_comm_size = pp_comm_value if args.pipeline_model_parallel != 1 else 0

        # Build the header before opening the file so a missing attribute does
        # not leave a truncated/empty output file behind.
        # NOTE(review): "vpp" is populated from num_layers -- confirm intended.
        header = (
            f"HYBRID_TRANSFORMER_FWD_IN_BCKWD model_parallel_NPU_group: {args.tensor_model_parallel_size} "
            f"ep: {args.expert_model_parallel_size} "
            f"pp: {args.pipeline_model_parallel} "
            f"vpp: {args.num_layers} "
            f"ga: {self.ga_num} all_gpus: {args.world_size} "
            f"checkpoints: 0 checkpoint_initiates: 0 "
            f"pp_comm: {pp_comm_size}\n"
        )

        with open(filename, "w") as f:
            f.write(header)
            f.write(f"{len(self.workload)}\n")
            for item in self.workload:
                f.write("\t".join(str(value) for value in vars(item).values()) + "\n")


class simAI_MicroTest:
    """Generate a single-collective micro-benchmark workload for SimAI.

    ``workload_generator`` emits one ``Work_Item`` per message size, starting
    at ``args.begin_size`` and doubling while the size is <= ``args.end_size``.
    The collective is placed in the ``dp_comm`` slot; the forward/backward
    slots carry no communication. ``dump_file`` writes the items as text.
    """

    # Accepted (lower-cased) spellings of ``args.test_comm`` -> collective name.
    _COMM_TYPE_MAP = {
        "all_reduce": "ALLREDUCE",
        "allreduce": "ALLREDUCE",
        "all_gather": "ALLGATHER",
        "allgather": "ALLGATHER",
        "reduce_scatter": "REDUCESCATTER",
        "reducescatter": "REDUCESCATTER",
        "all_to_all": "ALLTOALL",
        "alltoall": "ALLTOALL",
    }

    def __init__(self, args):
        self.args = args
        self.workload = []
        # dump_file's HYBRID header reports "ga"; it was previously never set,
        # which made that branch raise AttributeError. A micro-test has no
        # gradient accumulation loop of its own, so default to 1 (callers may
        # override). NOTE(review): confirm 1 is the value SimAI expects.
        self.ga_num = 1

    def _simAI_microtest_convert(self, comm_type):
        """Map a collective name such as ``"all_reduce"`` to its SimAI spelling.

        Matching is case-insensitive. Returns ``None`` for non-string or
        unsupported values.
        """
        if not isinstance(comm_type, str):
            return None
        return self._COMM_TYPE_MAP.get(comm_type.lower())

    def workload_generator(self):
        """Append one micro-test ``Work_Item`` per message size to ``self.workload``.

        Raises:
            ValueError: if ``args.begin_size`` is not positive (the doubling
                loop would never terminate) or ``args.test_comm`` is not a
                supported collective.
        """
        curr_size = self.args.begin_size
        if curr_size <= 0:
            raise ValueError(f"begin_size must be positive, got {curr_size!r}")
        # Resolve the collective once instead of on every iteration, and fail
        # loudly rather than writing "None" into the workload file.
        comm_type = self._simAI_microtest_convert(self.args.test_comm)
        if comm_type is None:
            raise ValueError(f"unsupported test_comm: {self.args.test_comm!r}")

        default_compute_time = 1
        while curr_size <= self.args.end_size:
            self.workload.append(
                Work_Item(
                    name="micro_test",
                    forward_compute_time=default_compute_time,
                    forward_comm="NONE",
                    forward_comm_size=0,
                    backward_compute_time=default_compute_time,
                    backward_comm="NONE",
                    backward_comm_size=0,
                    dp_compute_time=default_compute_time,
                    dp_comm=comm_type,
                    dp_comm_size=curr_size,
                    process_time=1,
                )
            )
            curr_size *= 2

    def dump_file(self, filename):
        """Write ``self.workload`` to ``<filename>.txt``.

        The first line is ``MICRO`` or, when ``args.multi_all_reduce_enable``
        is set, a ``HYBRID_TRANSFORMER_FWD_IN_BCKWD`` header describing the
        parallel configuration. It is followed by the item count and one line
        per item with its attribute values tab-separated.

        Args:
            filename: output path without extension; ``.txt`` is appended.
        """
        filename = filename + ".txt"
        if not self.args.multi_all_reduce_enable:
            header = "MICRO\n"
        else:
            # Single spaces between fields; the previous backslash-continued
            # f-string embedded long runs of indentation whitespace here.
            header = (
                "HYBRID_TRANSFORMER_FWD_IN_BCKWD\t"
                f"model_parallel_NPU_group: {self.args.tensor_model_parallel_size} "
                f"expert_parallel_npu_group: {self.args.expert_model_parallel_size} "
                f"pp: {self.args.pipeline_model_parallel} "
                f"ga: {self.ga_num} all_gpus: {self.args.world_size} "
                "checkpoints: 0 checkpoint_initiates: 0\n"
            )
        with open(filename, "w") as f:
            f.write(header)
            f.write(f"{len(self.workload)}\n")
            for item in self.workload:
                f.write("\t".join(str(value) for value in vars(item).values()) + "\n")


if __name__ == "__main__":
    # Script entry point: build the mocked Megatron model from CLI args,
    # generate its SimAI workload (AIOB-profiled compute times when
    # --aiob_enable is set, defaults otherwise) and write it under
    # results/workload/.
    args = get_params()
    print(args)
    model = MegatronModel(args)

    result_dir = "results/workload/"
    # exist_ok avoids the check-then-create race of isdir() + makedirs().
    os.makedirs(result_dir, exist_ok=True)
    filename = f"{args.gpu_type}-{args.model_name}-world_size{args.world_size}-tp{args.tensor_model_parallel_size}-pp{args.pipeline_model_parallel}-ep{args.expert_model_parallel_size}-gbs{args.global_batch}-mbs{args.micro_batch}-seq{args.seq_length}-MOE-{args.moe_enable}-GEMM-{args.moe_grouped_gemm}-flash_attn-{args.use_flash_attn}"
    filepath = os.path.join(result_dir, filename)

    # Total element count over all parameters of the mocked model.
    total_params = sum(p.numel() for p in model.parameters())
    print(total_params)

    compute_cache = None
    if args.aiob_enable:
        args.model_param = total_params
        if args.comp_filepath is None:
            # No profiling file supplied: obtain one via get_comp_out.
            comp_filepath = get_comp_out(args)
        else:
            print("comp_filepath:", args.comp_filepath)
            comp_filepath = args.comp_filepath
        compute_cache = extract_averages(comp_filepath, args)

        print("compute_cache = {")
        for key, value in compute_cache.items():
            print(f"    '{key}' : {value},")
        print("}")

    work = SIMAI_workload(model, args, compute_cache)
    if args.aiob_enable:
        work.workload_generate_aiob()
    else:
        work.workload_generate()
    work.dump_file(filepath)
    # dump_file appends ".txt"; report the path that was actually written.
    print(f"workload save in : {filepath}.txt")
