in DST/dataclass.py [0:0]
def __init__(self, model_name, tokenizer, data_path_prefix, shuffle_mode='shuffle_session_level',
             data_mode='train', add_prefix=True, add_special_decoder_token=True, train_data_ratio=1.0):
    '''
    model_name: t5-small, t5-base, or t5-large
    tokenizer: the pretrained tokenizer matching model_name
    data_path_prefix: where the data is stored
    shuffle_mode: turn-level shuffle or session-level shuffle
    add_prefix: whether to prepend a task-specific prompt to drive the generation
    add_special_decoder_token: whether to add special decoder tokens for the generation of each subtask
        <sos_b>, <eos_b> for belief state tracking
        <sos_a>, <eos_a> for dialogue action prediction
        <sos_r>, <eos_r> for response generation
    '''
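    # minimal usage sketch (the class name, tokenizer choice, and data path are
    # illustrative, not fixed by this excerpt):
    #   from transformers import T5Tokenizer
    #   tokenizer = T5Tokenizer.from_pretrained('t5-small')
    #   data = DSTData('t5-small', tokenizer, './data/multiwoz',
    #                  data_mode='train', train_data_ratio=0.1)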
    self.add_prefix = add_prefix
    assert self.add_prefix in [True, False]
    self.add_special_decoder_token = add_special_decoder_token
    assert self.add_special_decoder_token in [True, False]
    self.tokenizer = tokenizer
    print('Original Tokenizer Size is %d' % len(self.tokenizer))
    # extend the tokenizer vocabulary with the task-specific special tokens
    # (note: "add_sepcial_tokens" matches the method name defined elsewhere in this class)
    self.special_token_list = self.add_sepcial_tokens()
    print('Tokenizer Size after extension is %d' % len(self.tokenizer))
    self.pad_token_id = self.tokenizer.convert_tokens_to_ids(['<_PAD_>'])[0]
    self.sos_context_token_id = self.tokenizer.convert_tokens_to_ids(['<sos_context>'])[0]
    self.eos_context_token_id = self.tokenizer.convert_tokens_to_ids(['<eos_context>'])[0]
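    # these ids index the special tokens registered just above; the exact integer
    # values depend on the base vocabulary size and the order of registration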
    # initialize bos_token_id, eos_token_id
    self.model_name = model_name
    assert self.model_name.startswith('t5')
    from transformers import T5Config
    t5config = T5Config.from_pretrained(model_name)
    # T5 has no dedicated bos token; decoder_start_token_id is the pad token for
    # the released t5-* checkpoints
    self.bos_token_id = t5config.decoder_start_token_id
    self.eos_token_id = self.tokenizer.eos_token_id
    self.bos_token = self.tokenizer.convert_ids_to_tokens([self.bos_token_id])[0]
    self.eos_token = self.tokenizer.convert_ids_to_tokens([self.eos_token_id])[0]
    print('bos token is {}, eos token is {}'.format(self.bos_token, self.eos_token))
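    # collect the vocabulary ids of every <sos_*>/<eos_*> marker listed in the
    # docstring; downstream decoding can use these lists to delimit subtask spans
    # in generated sequences (an assumption, the consuming code is outside this excerpt)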
    self.all_sos_token_id_list = []
    for token in all_sos_token_list:
        one_id = self.tokenizer.convert_tokens_to_ids([token])[0]
        self.all_sos_token_id_list.append(one_id)
        print(self.tokenizer.convert_ids_to_tokens([one_id]))
    print(len(self.all_sos_token_id_list))
    self.all_eos_token_id_list = []
    for token in all_eos_token_list:
        one_id = self.tokenizer.convert_tokens_to_ids([token])[0]
        self.all_eos_token_id_list.append(one_id)
        print(self.tokenizer.convert_ids_to_tokens([one_id]))
    print(len(self.all_eos_token_id_list))
    if self.add_prefix:
        bs_prefix_text = 'translate dialogue to belief state:'
        self.bs_prefix_id = self.tokenizer.convert_tokens_to_ids(tokenizer.tokenize(bs_prefix_text))
    else:
        self.bs_prefix_id = []
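    # illustrative tokenization of the prompt (exact pieces depend on the T5
    # sentencepiece vocabulary):
    #   tokenizer.tokenize('translate dialogue to belief state:')
    #   -> ['▁translate', '▁dialogue', '▁to', '▁belief', '▁state', ':']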
    import json
    import random  # random.shuffle is used for the few-shot subsampling below
    if data_mode == 'train':
        train_json_path = data_path_prefix + '/multiwoz-fine-processed-train.json'
        with open(train_json_path) as f:
            train_raw_data = json.load(f)
        self.train_data_ratio = train_data_ratio
        assert self.train_data_ratio > 0
        # few-shot learning
        if self.train_data_ratio < 1.0:
            print('Few-shot training setup.')
            # the +1 guarantees at least one session survives very small ratios
            few_shot_num = int(len(train_raw_data) * self.train_data_ratio) + 1
            random.shuffle(train_raw_data)
            # randomly select a subset of training sessions
            train_raw_data = train_raw_data[:few_shot_num]
            print('Number of training sessions is {}'.format(few_shot_num))
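            # worked example (assumes the standard MultiWOZ split of roughly 8438
            # training sessions): ratio 0.01 keeps int(8438 * 0.01) + 1 = 85 sessions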
        print('Tokenizing raw train data...')
        train_data_id_list = self.tokenize_raw_data(train_raw_data)
        self.train_data_list = self.flatten_data(train_data_id_list)
        # record training data: group the flattened turn-level items back by dialogue id
        self.train_id2session_dict = {}
        self.train_dial_id_list = []
        for item in self.train_data_list:
            one_item_id = item['dial_id']
            try:
                self.train_id2session_dict[one_item_id].append(item)
            except KeyError:  # first turn seen for this dialogue
                self.train_dial_id_list.append(one_item_id)
                self.train_id2session_dict[one_item_id] = [item]
        assert len(self.train_dial_id_list) == len(self.train_id2session_dict)
        self.train_num = len(self.train_data_list)
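        # resulting structure (illustrative): train_id2session_dict maps a dialogue id
        # to the list of its turn-level items, and train_dial_id_list preserves the
        # first-seen dialogue order for session-level shuffling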
    elif data_mode == 'test':
        train_raw_data = []
    else:
        raise ValueError("data_mode must be 'train' or 'test', got {!r}".format(data_mode))
    dev_json_path = data_path_prefix + '/multiwoz-fine-processed-dev.json'
    with open(dev_json_path) as f:
        dev_raw_data = json.load(f)
    print('Tokenizing raw dev data...')
    dev_data_id_list = self.tokenize_raw_data(dev_raw_data)
    self.dev_data_list = self.flatten_data(dev_data_id_list)
    test_json_path = data_path_prefix + '/multiwoz-fine-processed-test.json'
    with open(test_json_path) as f:
        test_raw_data = json.load(f)
    print('Tokenizing raw test data...')
    test_data_id_list = self.tokenize_raw_data(test_raw_data)
    self.test_data_list = self.flatten_data(test_data_id_list)
    print('The sizes of the raw train, dev and test sets are %d, %d and %d' %
          (len(train_raw_data), len(dev_raw_data), len(test_raw_data)))
    self.dev_num, self.test_num = len(self.dev_data_list), len(self.test_data_list)
    if data_mode == 'train':
        print('train turn number is %d, dev turn number is %d, test turn number is %d' %
              (len(self.train_data_list), len(self.dev_data_list), len(self.test_data_list)))
        self.shuffle_mode = shuffle_mode
        self.ordering_train_data()
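    # ordering_train_data (defined elsewhere in this class) is expected to reorder
    # self.train_data_list according to shuffle_mode: session-level shuffling keeps
    # the turns of one dialogue together, while turn-level shuffling mixes all turns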