def flatten_data()

in DST/dataclass.py [0:0]


    def flatten_data(self, data, max_context_len=900):
        """Transform session-level data into a flat list of per-turn records.

        Each session in ``data`` is a list of turn dicts ordered by turn number.
        Each turn dict is expected to have the following keys:

            dial_id: data id
            user: user input at this turn;
                  e.g. '<sos_u> i am looking for an expensive restaurant in the centre . thank you . <eos_u>'
            usdx: delexicalized user input at this turn;
                  e.g. '<sos_u> i am looking for an expensive restaurant in the centre . thank you . <eos_u>'
            resp: delexicalized system response;
                  e.g. '<sos_r> there are several restaurant -s in the price range what type of food would you like to eat ? <eos_r>'
            bspn: belief state span;
                  e.g. '<sos_b> [restaurant] pricerange expensive area centre <eos_b>'
            bsdx: delexicalized belief state span;
                  e.g. '<sos_b> [restaurant] pricerange area <eos_b>'
            aspn: action span;
                  e.g. '<sos_a> [restaurant] [request] food <eos_a>'
            dspn: domain span;
                  e.g. '<sos_d> [restaurant] <eos_d>'
            pointer: e.g. [0, 0, 0, 1, 0, 0]
            turn_domain: e.g. ['[restaurant]']
            turn_num: the turn number in current session
            db: database result e.g. '<sos_db> [db_3] <eos_db>'
            bspn_reform: reformed belief state;
                  e.g. '<sos_b> [restaurant] pricerange = expensive , area = centre <eos_b>'
            bsdx_reform: reformed delexicalized belief state;
                  e.g. '<sos_b> [restaurant] pricerange , area <eos_b>'
            aspn_reform: reformed dialogue action;
                  e.g. '<sos_a> [restaurant] [request] food <eos_a>'

        Only dial_id, turn_num, user, usdx, resp, bspn, bspn_reform, bsdx and
        bsdx_reform are read here. ``user`` and ``resp`` are concatenated with
        ``+`` onto a list, so they are presumably token-id lists (the examples
        above show the decoded text) -- TODO confirm against the caller.

        Args:
            data: iterable of sessions; each session is a list of turn dicts.
                Empty sessions are skipped.
            max_context_len: keep only the last ``max_context_len`` items of
                the dialogue history (previous turns + current user input)
                when building ``bs_input``. ``None`` disables truncation.
                Defaults to 900.

        Returns:
            A list with one dict per turn, containing keys dial_id, turn_num,
            prev_context, user, usdx, resp, bspn, bspn_reform, bsdx,
            bsdx_reform, bs_input and bs_output. ``bs_input`` is
            ``self.bs_prefix_id + [sos_context] + truncated history + [eos_context]``
            and ``bs_output`` is the turn's ``bspn``.

        Raises:
            AssertionError: if a session's turns are not ordered 0, 1, 2, ...
        """
        data_list = []
        for session in data:
            if not session:
                # Nothing to flatten; also avoids IndexError on session[0].
                continue
            one_dial_id = session[0]['dial_id']
            # Concatenation of all previous user inputs and system responses.
            previous_context = []
            for turn_id, curr_turn in enumerate(session):
                assert curr_turn['turn_num'] == turn_id, (
                    'turns of dialogue %r are not arranged in order' % (one_dial_id,))
                curr_user_input = curr_turn['user']
                curr_sys_resp = curr_turn['resp']
                curr_bspn = curr_turn['bspn']

                # Construct the belief-state input: the (possibly truncated)
                # dialogue history wrapped in context delimiters.
                history = previous_context + curr_user_input
                if max_context_len is not None:
                    # Equivalent to history[-max_context_len:] for positive
                    # values, but also well-defined for 0 (empty history).
                    start = max(len(history) - max_context_len, 0)
                    history = history[start:]
                bs_input = (self.bs_prefix_id + [self.sos_context_token_id]
                            + history + [self.eos_context_token_id])

                data_list.append({
                    'dial_id': one_dial_id,
                    'turn_num': turn_id,
                    'prev_context': previous_context,
                    'user': curr_user_input,
                    'usdx': curr_turn['usdx'],
                    'resp': curr_sys_resp,
                    'bspn': curr_bspn,
                    'bspn_reform': curr_turn['bspn_reform'],
                    'bsdx': curr_turn['bsdx'],
                    'bsdx_reform': curr_turn['bsdx_reform'],
                    'bs_input': bs_input,
                    'bs_output': curr_bspn,
                })
                # Rebind (not +=) so the list stored in this turn's
                # 'prev_context' is not mutated by later turns.
                previous_context = previous_context + curr_user_input + curr_sys_resp
        return data_list