def main()

in model/generate.py [0:0]


import os

import numpy as np
import torch
import torch.nn.functional as F

# Project-specific helpers (load_vocab, BaseVocab, get_default_cfg_training,
# TransformerGAN, get_duration_from_token) are assumed to be imported from
# elsewhere in this repository; their exact module paths are not shown in this excerpt.


def main(inference_cfg):
    if inference_cfg.EVENT.event_representation == "magenta":
        empty_bar_symbol = "TIME_SHIFT_100"
        if inference_cfg.SAMPLING.technique == "topk":
            if inference_cfg.SAMPLING.threshold:
                topk = int(inference_cfg.SAMPLING.threshold)
            else:
                topk = 32
        elif inference_cfg.SAMPLING.technique == "nucleus":
            if inference_cfg.SAMPLING.threshold:
                p = inference_cfg.SAMPLING.threshold
            else:
                p = 0.95

        elif inference_cfg.SAMPLING.technique == "random":
            topk = 310
    else:
        raise NotImplementedError(
            "Newevent representation generations are yet to be implemented"
        )

    model_fp = os.path.join(inference_cfg.MODEL.model_directory,
                            inference_cfg.MODEL.checkpoint_name)  # Path to the model checkpoint

    cfg_fp = os.path.join(inference_cfg.MODEL.model_directory, "config.yml")

    if not os.path.isdir(inference_cfg.OUTPUT.output_txt_directory):
        os.makedirs(inference_cfg.OUTPUT.output_txt_directory)
    ext = ".txt"
    device = torch.device("cuda" if inference_cfg.MODEL.device else "cpu")


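    # Load the token vocabulary: tokens_list maps indices to token strings and token2index maps strings back to indices.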
    tokens_list, token2index = load_vocab(inference_cfg.EVENT.vocab_file_path)
    perform_vocab = BaseVocab(tokens_list)

    nvocab = len(perform_vocab)

    # Encode empty bar token
    empty_bar_token = token2index[empty_bar_symbol]
    # Parameters for generation

    cfg = get_default_cfg_training()
    # Merge the training configuration from the saved yaml file. Old config files may
    # contain empty values that yacs loads as None and cannot merge; in that case the
    # block below prints guidance on fixing the file before re-raising the error.
    try:
        cfg.merge_from_file(cfg_fp)
    except Exception as e:
        print('*' * 100)
        print("Note, if you are loading an old config.yml file which includes None inside,\n"
              " please change it to a string 'None' to make sure you can do cfg.merge_from_file.\n"
              "e.g. cfg.DISCRIMINATOR.type , cfg.TRAIN.pad_type and cfg.TRAIN.load_from_previous.\n"
              "and please note DISCRIMINATOR.temperature is DISCRIMINATOR.beta_max\n")
        print('*' * 100)
        raise e
    cfg.defrost()
    cfg.DISCRIMINATOR.type = "Null"  # cnn for cnn distriminator or Null for no discriminator or 'bert' for BERT
    # discriminator
    cfg.MODEL.same_length = True  # Needed for same_length =True during evaluation
    cfg.freeze()

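    # If the model was trained with note-status vectors appended, build the note mapping on the vocabulary.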
    if cfg.TRAIN.append_note_status:
        perform_vocab.notes_mapping()

    model = TransformerGAN(cfg, perform_vocab)

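    # The checkpoint stores the full GAN; keep only the generator weights by stripping the 'generator.' prefix from the parameter names.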
    checkpoint = torch.load(model_fp, map_location=device)
    trimmed_checkpoint = {}
    for key, val in checkpoint["model"].items():
        if 'generator' in key:
            new_key = key.replace('generator.', '')
            trimmed_checkpoint[new_key] = val
    model.generator.load_state_dict(trimmed_checkpoint, strict=False)

    # checkpoint = torch.load(model_fp)
    # model.load_state_dict(checkpoint["model"])

    model = model.to(device)
    model.eval()
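    # Generate one token per forward pass while attending over a memory of inference_cfg.MODEL.memory_length previous tokens.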
    model.generator.reset_length(1, inference_cfg.MODEL.memory_length)

    # Load a conditional file for time_extension
    if inference_cfg.INPUT.time_extension:
        conditional_data = np.load(inference_cfg.INPUT.conditional_input_melody).tolist()  # inference_cfg.prefix
        print('* Loaded conditional file {}'.format(inference_cfg.INPUT.conditional_input_melody))
        num_conditional_tokens = inference_cfg.INPUT.num_conditional_tokens
        if inference_cfg.GENERATION.duration_based:
            duration = 0
            num_conditional_tokens = 0
            for conditional_index in conditional_data:
                num_conditional_tokens += 1
                token_duration = get_duration_from_token(inference_cfg.EVENT.event_representation, conditional_index,
                                                         tokens_list)
                if token_duration:
                    duration += token_duration  # 10 ms
                if duration >= inference_cfg.INPUT.conditional_duration:
                    break
            # Note: when the requested conditional duration exceeds the duration of the
            # conditional file, the whole conditional file is used.
            print('* Total number of tokens used for condition is {} for duration {}'.format(num_conditional_tokens,
                                                                                             duration))
        else:
            num_conditional_tokens = min(num_conditional_tokens, len(conditional_data))
            print('* Total number of tokens used for condition is {}'.format(num_conditional_tokens))

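        # Save the conditioning prefix and the full conditional sequence as token text files for reference.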
        with open(os.path.join(inference_cfg.OUTPUT.output_txt_directory, 'prefix' + ext), "w") as f:
            f.write("\n".join(tokens_list[t] for t in conditional_data[:num_conditional_tokens]))
        with open(os.path.join(inference_cfg.OUTPUT.output_txt_directory, 'full' + ext), "w") as f:
            f.write("\n".join(tokens_list[t] for t in conditional_data[:]))


    for midi_file in range(inference_cfg.INPUT.num_midi_files):
        out_fn = str(midi_file) + ext
        out_fp = os.path.join(inference_cfg.OUTPUT.output_txt_directory, out_fn)
        if cfg.TRAIN.replace_start_with_pad:
            seq = [token2index['<PAD>']]  # Pad Token
        else:
            seq = [token2index['<S>']]  # Start Token

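        # mems caches the model's hidden-state memories between steps; status_vec is only used when cfg.TRAIN.append_note_status is set.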
        mems = None
        status_vec = None
        with torch.no_grad():
            print("Generating the Midi File Number: " + str(midi_file + 1))
            if inference_cfg.INPUT.time_extension and num_conditional_tokens >= 1:
                # Time extension based on a conditional file: the prefix is fed through the model
                # only when inference_cfg.INPUT.time_extension is True and at least one
                # conditional token is available.
                context = np.array(seq + conditional_data[:num_conditional_tokens - 1], dtype=np.int32)[:, np.newaxis]
                context = torch.from_numpy(context).to(device).type(torch.long)
                if cfg.TRAIN.append_note_status:
                    status_vec = context.new_zeros((context.shape[0], 1, perform_vocab.vec_len), dtype=torch.bool)
                    perform_vocab.update_status_vec(context, status_vec)
                ret = model.generator.forward_generate(context, mems, status_vec=status_vec)
                _, mems = ret
                seq = seq + conditional_data[:num_conditional_tokens]

            if inference_cfg.GENERATION.duration_based:
                duration, generation_length = 0, inference_cfg.GENERATION.max_generation_length
            else:
                generation_length = inference_cfg.GENERATION.generation_length

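            # Autoregressive sampling loop: feed only the most recently generated token and reuse the cached memories from previous steps.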
            for _ in range(generation_length):
                if inference_cfg.GENERATION.duration_based:
                    token_duration = get_duration_from_token(inference_cfg.EVENT.event_representation, seq[-1],
                                                             tokens_list)
                    if token_duration:
                        duration += token_duration
                    if duration >= inference_cfg.GENERATION.generation_duration:
                        break
                # Create input array
                inp = np.array([seq[-1]], dtype=np.int32)[:, np.newaxis]
                inp = torch.from_numpy(inp).to(device).type(torch.long)
                if cfg.TRAIN.append_note_status:
                    bptt, batch_size = inp.shape
                    if status_vec is None:
                        status_vec = inp.new_zeros((bptt, batch_size, perform_vocab.vec_len), dtype=torch.bool)
                    else:
                        status_vec = status_vec[-1:, :, :]
                    perform_vocab.update_status_vec(inp, status_vec)
                ret = model.generator.forward_generate(inp, mems, status_vec=status_vec)
                all_logits, mems = ret
                # Select last timestep, only batch item
                logits = all_logits[-1, 0]

                if inference_cfg.INPUT.exclude_bos_token:
                    logits = logits[1:]

                if inference_cfg.INPUT.num_empty_tokens_to_ignore:
                    # If the last num_empty_tokens_to_ignore generated tokens are all the empty-bar
                    # token (TIME_SHIFT_100), stop sampling it by removing it from the logits.
                    if np.all(np.asarray(seq[-inference_cfg.INPUT.num_empty_tokens_to_ignore:]) == empty_bar_token):
                        if inference_cfg.INPUT.exclude_bos_token:
                            logits = torch.cat(
                                [logits[:empty_bar_token - 1], logits[empty_bar_token:]], 0
                            )
                        else:
                            logits = torch.cat(
                                [logits[:empty_bar_token], logits[empty_bar_token + 1:]], 0
                            )

                # Handle temp 0 (argmax) case
                if inference_cfg.SAMPLING.temperature == 0:
                    probs = torch.zeros_like(logits)
                    probs[logits.argmax()] = 1.0
                else:
                    # Apply temperature scaling
                    logits /= inference_cfg.SAMPLING.temperature

                    # Compute softmax
                    probs = F.softmax(logits, dim=-1)

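                # Re-align the sampling distribution with the full vocabulary by re-inserting zero-probability entries for any logits removed above.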
                if inference_cfg.INPUT.exclude_bos_token:
                    probs = F.pad(probs, [1, 0])

                if inference_cfg.INPUT.num_empty_tokens_to_ignore:
                    if np.all(np.asarray(seq[-inference_cfg.INPUT.num_empty_tokens_to_ignore:]) == empty_bar_token):
                        probs = torch.cat([probs[:empty_bar_token], F.pad(probs[empty_bar_token:], [1, 0])], 0)

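                # Top-k sampling: renormalise over the k most probable tokens; "random" keeps the full distribution.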
                if inference_cfg.SAMPLING.technique == "topk" or inference_cfg.SAMPLING.technique == "random":

                    if inference_cfg.SAMPLING.technique == "random":
                        pass

                    elif topk is not None:
                        _, top_idx = torch.topk(probs, topk)
                        mask = torch.zeros_like(probs)
                        mask[top_idx] = 1.0
                        probs *= mask
                        probs /= probs.sum()

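                # Nucleus (top-p) sampling: keep the smallest set of tokens whose cumulative probability reaches p and renormalise.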
                elif inference_cfg.SAMPLING.technique == "nucleus":
                    if p > 0:
                        sorted_probs, sorted_indices = torch.sort(
                            probs, descending=True
                        )
                        cumulative_probs = torch.cumsum(sorted_probs, dim=0)
                        # Remove tokens with cumulative probability above the threshold
                        sorted_indices_to_remove = cumulative_probs >= p
                        # Shift the indices to the right to keep also the first token above the threshold

                        sorted_indices_to_remove[1:] = sorted_indices_to_remove[
                                                       :-1
                                                       ].clone()
                        sorted_indices_to_remove[0] = 0
                        # scatter sorted tensors to original indexing
                        indices_to_remove = sorted_indices_to_remove.scatter(
                            dim=0, index=sorted_indices, src=sorted_indices_to_remove
                        )
                        probs[indices_to_remove] = 0
                        probs /= probs.sum()

                else:
                    raise NotImplementedError(
                        "Other Sampling strategies are yet to be implemented"
                    )
                # Sample from probabilities
                token = torch.multinomial(probs, 1)
                token = int(token.item())
                seq.append(token)

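            # Write the generated token sequence (excluding the initial start/pad token) to a text file.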
            with open(out_fp, "w") as f:
                f.write("\n".join(tokens_list[t] for t in seq[1:]))

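            # Optional debug checks: recompute the memories from the full generated sequence and, with time extension, score the conditioning prefix.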
            if inference_cfg.MODEL.debug:
                # Recompute the memories from the full sequence, excluding the last token so the
                # memory length matches, and check they agree with the incrementally built mems.
                status_vec = None
                data = np.array(seq[:-1], dtype=np.int32)[:, np.newaxis]
                data = torch.from_numpy(data).to(device).type(torch.long)

                if cfg.TRAIN.append_note_status:
                    status_vec = data.new_zeros((data.shape[0], 1, perform_vocab.vec_len), dtype=torch.bool)
                    perform_vocab.update_status_vec(data, status_vec)
                ret = model.generator.forward_generate(data, None, status_vec=status_vec)
                _, new_mems = ret

                assert all(
                    [
                        torch.allclose(i, j, atol=1e-4)
                        for i, j in zip(new_mems, mems)
                    ]
                )
                print("Mem same")
                # This prime-scoring debug must run after the memory check above.
                if inference_cfg.INPUT.time_extension and num_conditional_tokens >= 1:
                    # Scoring of the conditioning prefix is only done when
                    # inference_cfg.INPUT.time_extension is True and at least one
                    # conditional token is available.
                    if cfg.TRAIN.replace_start_with_pad:
                        input_index = token2index['<PAD>']
                    else:
                        input_index = token2index['<S>']
                    nll = 0.
                    status_vec = None
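                    # Teacher-forced pass over the conditioning prefix to report its NLL and perplexity.
                    # Note that mems are not carried between steps here, so each token is scored
                    # given only the immediately preceding token.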
                    for i in range(num_conditional_tokens):
                        target = conditional_data[i]
                        input_index = np.array([input_index], dtype=np.int32)[:, np.newaxis]
                        input_index = torch.from_numpy(input_index).to(device).type(torch.long)
                        if cfg.TRAIN.append_note_status:
                            if status_vec is None:
                                bptt, batch_size = input_index.shape
                                status_vec = input_index.new_zeros((bptt, batch_size, perform_vocab.vec_len), dtype=torch.bool)
                            else:
                                status_vec = status_vec[-1:, :, :]
                            perform_vocab.update_status_vec(input_index, status_vec)
                        ret = model.generator.forward_generate(input_index, None,
                                                               status_vec=status_vec)
                        all_logits, _ = ret
                        logits = all_logits[-1, 0]
                        probs = F.softmax(logits, dim=-1)
                        target_prob = probs[target].cpu().item()
                        nll += -np.log(target_prob)
                        input_index = target

                    print('Prime NLL: {}, Prime PPL: {}'.format(nll / num_conditional_tokens,
                                                                np.exp(nll / num_conditional_tokens)))

                with open(os.path.join(inference_cfg.OUTPUT.output_txt_directory, "inference.yml"), "w") as f:
                    f.write(str(inference_cfg))