# get_comp_out() — defined in utils/utils.py

def get_comp_out(args):
    """Run one forward pass of a mocked Megatron model to collect compute
    measurements, returning the path of the file that run produced.

    Args:
        args: parsed run configuration. Reads ``micro_batch``,
            ``seq_length``, ``tensor_model_parallel_size``,
            ``padded_vocab_size``, ``frame``, and ``dtype``, and is passed
            through to ``MegatronModel``.

    Returns:
        The filepath returned by the mocked model's forward pass, or
        ``None`` when ``args.frame`` does not contain ``"Megatron"``.
    """
    batch_size = args.micro_batch
    seq_len = args.seq_length
    tp = args.tensor_model_parallel_size
    # NOTE(review): the original read args.vocab_size and immediately
    # overwrote it with padded_vocab_size; only the padded size is used.
    vocab_size = args.padded_vocab_size
    if "Megatron" in args.frame:
        device = torch.cuda.current_device()
        # Local import: the mocked model pulls in heavy dependencies, so it
        # is only loaded on the Megatron path.
        from workload_generator.mocked_model.AiobMegatron import MegatronModel

        measure_model = MegatronModel(args)
        measure_model.train()
        # NOTE(review): the original also mapped args.dtype to a torch dtype
        # here, but the result was never used (it only fed a commented-out
        # tensor), so that dead code has been removed.
        # Token ids are drawn from this rank's vocab shard, hence the
        # ceil(vocab_size / tp) upper bound.
        masked_input = torch.randint(
            0,
            math.ceil(vocab_size / tp),
            (batch_size, seq_len),
            device=device,
            dtype=torch.int64,
        )
        filepath = measure_model(masked_input)
        return filepath
    return None