in BERT/main.py [0:0]
def __init__(self, tokenizer: PreTrainedTokenizer, args, file_path: str, block_size=512):
    """Build (or load from cache) fixed-length token-id examples for LM training.

    The split is inferred from *file_path* ('train' / 'valid' / anything else
    = test). Every .txt / .npy file found for that split is tokenized, the
    token stream is chopped into chunks of ``block_size`` ids (the final short
    chunk is right-padded with ``tokenizer.pad_token_id``), and the resulting
    list of examples is pickled next to *file_path* for reuse.

    Args:
        tokenizer: provides ``vocab`` (token -> id) and ``pad_token_id``.
        args: namespace with ``model_type``, ``overwrite_cache``,
            ``train_dir``, ``eval_dir``, ``test_dir``.
        file_path: path whose name selects the split.
        block_size: number of token ids per example.
    """
    directory, _ = os.path.split(file_path)
    if not os.path.exists(directory):
        os.makedirs(directory)

    # Single cache-name construction keyed on the split (the original
    # repeated the os.path.join three times, differing only in the suffix).
    if 'train' in file_path:
        split, files_key = 'train', 'train'
    elif 'valid' in file_path:
        split, files_key = 'valid', 'eval'
    else:
        split, files_key = 'test', 'test'
    cached_features_file = os.path.join(
        directory, args.model_type + "_cached_lm_" + str(block_size) + "_" + split + '.pkl')

    if os.path.exists(cached_features_file) and not args.overwrite_cache:
        logger.info("Loading features from cached file %s", cached_features_file)
        with open(cached_features_file, "rb") as handle:
            self.examples = pickle.load(handle)
        return

    logger.info("Creating features from dataset file at %s", directory)
    self.file_dict = {
        'train': list(find_files_by_extensions(args.train_dir, ['.txt', '.npy'])),
        'eval': list(find_files_by_extensions(args.eval_dir, ['.txt', '.npy'])),
        'test': list(find_files_by_extensions(args.test_dir, ['.txt', '.npy'])),
    }
    self.examples = []
    for path in self.file_dict[files_key]:
        if path.endswith('.txt'):
            # One event token per line; map through the tokenizer vocab.
            with open(path, 'r', encoding='utf-8') as f:
                events = f.read().strip().splitlines()
            tokenized_text = [tokenizer.vocab[event] for event in events]
        else:
            tokenized_text = np.load(path).tolist()
        # Was a bare print(); use the module logger like the rest of the file.
        logger.info("processing %s", path)
        for i in range(0, len(tokenized_text), block_size):
            sample = tokenized_text[i: i + block_size]
            if len(sample) == block_size:
                self.examples.append(sample)
            else:
                # Right-pad the final short chunk with pad_token_id. Padding
                # with an int list fixes the original bug of appending a
                # float64 ndarray (np.ones * pad_id) among plain int lists,
                # which left self.examples with mixed element types.
                pad = [tokenizer.pad_token_id] * block_size
                pad[:len(sample)] = sample
                self.examples.append(pad)

    logger.info("Saving features into cached file %s", cached_features_file)
    with open(cached_features_file, "wb") as handle:
        pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)