def batch_generate()

in E2E_TOD/inference_utlis.py [0:0]


def batch_generate(model, one_inference_batch, ref_bs, ref_act, ref_db, input_contain_db, data):
    '''
        e2e evaluation: 
            ref_bs: False
            ref_act: False
            ref_db: False
            input_contain_db: True or False
            
            ************************************************************************************************
            In this case, bspn_gen, aspn_gen, and resp_gen all need to be generated; if input_contain_db
                          is True, the db result is queried from the database using the generated belief state
            ************************************************************************************************
            
        oracle evaluation:
            ref_bs: True
            ref_act: True
            ref_db: True
            input_contain_db: True or False
            
            ************************************************************************************************
            In this case, only resp_gen needs to be generated; if input_contain_db is True, the oracle db
                          result is used as input
            ************************************************************************************************
                          
        policy evaluation:
            ref_bs: True
            ref_act: False
            ref_db: True
            input_contain_db: True or False
            
            ************************************************************************************************
            In this case, aspn_gen and resp_gen need to be generated; if input_contain_db is True, the
                          oracle db result is used as input
            ************************************************************************************************

        This function performs batch generation. To take full advantage of batch inference, each model
        call generates only one type of output, e.g. given a batch of dialogue histories, we generate the
        corresponding belief states, dialogue actions, or system responses for the whole batch. The
        specific output type is selected by the "generate_mode" argument passed to model.batch_generate.
    '''

    is_cuda = next(model.parameters()).is_cuda
    if is_cuda: 
        #device = next(model.parameters()).device
        device = torch.device('cuda')
        if torch.cuda.device_count() > 1: # multi-gpu: unwrap the parallel wrapper (e.g. DataParallel)
            model = model.module
        else: # single gpu
            pass
    else:
        device = torch.device('cpu') # only used by the .cuda(device) calls below, which are skipped on CPU

    reform_bs_and_act = False
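    # Note: reform_bs_and_act is hard-coded to False above, so the restore_text() "reform"
    # branches further below are effectively skipped during evaluation.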

    max_span_len, max_response_len = 80, 120 # max decode lengths for act spans and for belief-state/response sequences
    tokenizer = data.tokenizer
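    # Each inference batch is assumed to hold four aligned lists: token-id inputs for belief
    # state generation (bs_batch), for dialogue act generation (da_batch), for response
    # generation (nlg_batch), and the per-turn parse/reference dictionaries (parse_dict_batch).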
    bs_batch, da_batch, nlg_batch, parse_dict_batch = one_inference_batch
    batch_size = len(parse_dict_batch)
    res_batch_parse_dict = parse_dict_batch # generated outputs are written into these dicts and returned

    if ref_bs == False and ref_act == False and ref_db == False:
        evaluation_setup = 'e2e'
    elif ref_bs == True and ref_act == True and ref_db == True:
        evaluation_setup = 'oracle'
    elif ref_bs == True and ref_act == False and ref_db == True:
        evaluation_setup = 'policy'
    else:
        raise Exception('Wrong Evaluation Setup.')

    if evaluation_setup == 'e2e':
        # first generate belief state
        bs_tensor, bs_mask = data.pad_batch(bs_batch)
        if is_cuda:
            bs_tensor = bs_tensor.cuda(device)
            bs_mask = bs_mask.cuda(device)
        # the belief state sequence can be long, so the larger response-length limit is used for decoding
        batch_bs_text = model.batch_generate(bs_tensor, bs_mask, generate_mode='bs', max_decode_len=max_response_len)
        batch_bs_restore_text = []
        for idx in range(batch_size):
            if reform_bs_and_act:
                one_bs_text = batch_bs_text[idx]
                res_batch_parse_dict[idx]['bspn_gen_reform'] = one_bs_text
                one_bs_restore_text = restore_text(one_bs_text, mode='bs')
                res_batch_parse_dict[idx]['bspn_gen'] = one_bs_restore_text
                batch_bs_restore_text.append(one_bs_restore_text)
            else:
                one_bs_text = batch_bs_text[idx]
                res_batch_parse_dict[idx]['bspn_gen'] = one_bs_text
        if reform_bs_and_act:
            batch_bs_text = batch_bs_restore_text
        else:
            pass

        if input_contain_db:
            # we need to query the database with the generated belief state
            batch_db_input_id_list = []
            for idx in range(batch_size):
                one_queried_db_result = \
                data.reader.bspan_to_DBpointer(batch_bs_text[idx], res_batch_parse_dict[idx]['turn_domain'])
                one_db_text = '<sos_db> ' + one_queried_db_result + ' <eos_db>' 
                #print (db_text)
                #print (one_db_text)
                one_db_token_id_input = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(one_db_text))
                batch_db_input_id_list.append(one_db_token_id_input)
        else:
            # the input does not contain db results, so use empty placeholders
            batch_db_input_id_list = []
            for _ in range(batch_size):
                batch_db_input_id_list.append([])

        # then we generate the dialogue action
        da_batch_input_id_list = []
        for idx in range(batch_size):
            one_da_input_id_list = da_batch[idx] + batch_db_input_id_list[idx]
            da_batch_input_id_list.append(one_da_input_id_list)
        da_tensor, da_mask = data.pad_batch(da_batch_input_id_list)
        if is_cuda:
            da_tensor = da_tensor.cuda(device)
            da_mask = da_mask.cuda(device)
        batch_da_text = model.batch_generate(da_tensor, da_mask, generate_mode='da', max_decode_len=max_span_len)
        for idx in range(batch_size):
            if reform_bs_and_act:
                res_batch_parse_dict[idx]['aspn_gen_reform'] = batch_da_text[idx]
                res_batch_parse_dict[idx]['aspn_gen'] = restore_text(batch_da_text[idx], mode='da')
            else:
                res_batch_parse_dict[idx]['aspn_gen'] = batch_da_text[idx]            

        # finally, we generate the response
        nlg_batch_input_id_list = []
        for idx in range(batch_size):
            one_nlg_input_id_list = nlg_batch[idx] + batch_db_input_id_list[idx]
            nlg_batch_input_id_list.append(one_nlg_input_id_list)
        nlg_tensor, nlg_mask = data.pad_batch(nlg_batch_input_id_list)
        if is_cuda:
            nlg_tensor = nlg_tensor.cuda(device)
            nlg_mask = nlg_mask.cuda(device)
        batch_nlg_text = model.batch_generate(nlg_tensor, nlg_mask, generate_mode='nlg', max_decode_len=max_response_len)
        for idx in range(batch_size):
            res_batch_parse_dict[idx]['resp_gen'] = batch_nlg_text[idx]

    elif evaluation_setup == 'policy': 
        # we need to generate the dialogue action and dialogue response
        # the da input already contains the ref db result
        da_tensor, da_mask = data.pad_batch(da_batch)
        if is_cuda:
            da_tensor = da_tensor.cuda(device)
            da_mask = da_mask.cuda(device)
        batch_da_text = model.batch_generate(da_tensor, da_mask, generate_mode='da', max_decode_len=max_span_len)
        for idx in range(batch_size):
            res_batch_parse_dict[idx]['aspn_gen'] = batch_da_text[idx]

        nlg_tensor, nlg_mask = data.pad_batch(nlg_batch)
        if is_cuda:
            nlg_tensor = nlg_tensor.cuda(device)
            nlg_mask = nlg_mask.cuda(device)
        batch_nlg_text = model.batch_generate(nlg_tensor, nlg_mask, generate_mode='nlg', max_decode_len=max_response_len)
        for idx in range(batch_size):
            res_batch_parse_dict[idx]['resp_gen'] = batch_nlg_text[idx]

    elif evaluation_setup == 'oracle':
        # we only need to generate the response
        # nlg_batch already contains the ref db result
        nlg_tensor, nlg_mask = data.pad_batch(nlg_batch)
        if is_cuda:
            nlg_tensor = nlg_tensor.cuda(device)
            nlg_mask = nlg_mask.cuda(device)
        batch_nlg_text = model.batch_generate(nlg_tensor, nlg_mask, generate_mode='nlg', max_decode_len=max_response_len)
        for idx in range(batch_size):
            res_batch_parse_dict[idx]['resp_gen'] = batch_nlg_text[idx]
    else:
        raise Exception('Wrong Evaluation Setup.')
    return res_batch_parse_dict
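
For reference, a minimal calling sketch for the e2e setup is shown below. The names model, data, and
all_inference_batches are assumed to be prepared elsewhere (the batches must follow the four-part format
unpacked at the top of the function); they are illustrative and not defined in this file.

# Illustrative only: `all_inference_batches` is an assumed iterable of prepared inference batches.
all_results = []
for one_inference_batch in all_inference_batches:
    # e2e evaluation: generate bspn_gen, aspn_gen, and resp_gen; with input_contain_db=True
    # the db result is queried using the generated belief state
    batch_parse_dicts = batch_generate(model, one_inference_batch,
        ref_bs=False, ref_act=False, ref_db=False, input_contain_db=True, data=data)
    all_results += batch_parse_dicts
# every returned dict now carries 'bspn_gen', 'aspn_gen', and 'resp_gen' entries;
# for the oracle setup pass ref_bs=ref_act=ref_db=True, and for the policy setup
# pass ref_bs=True, ref_act=False, ref_db=True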