in torchbenchmark/models/attention_is_all_you_need_pytorch/preprocess.py [0:0]
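# NOTE: this excerpt shows main() only. The imports below are a best-guess
# reconstruction from the names used in the function body; the helper
# functions mkdir_if_needed / get_raw_files / compile_files / encode_files
# and the _TRAIN/_VAL/_TEST_DATA_SOURCES constants are assumed to be defined
# earlier in this same file.
import argparse
import codecs
import os
import pickle
import sys

from torchtext.data import Field                   # legacy (pre-0.9) torchtext API
from torchtext.datasets import TranslationDataset  # legacy (pre-0.9) torchtext API

from learn_bpe import learn_bpe  # vendored subword-nmt module (assumed location)
from apply_bpe import BPE        # vendored subword-nmt module (assumed location)
import transformer.Constants as Constants  # PAD/BOS/EOS tokens (assumed module path)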
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('-raw_dir', required=True)
    parser.add_argument('-data_dir', required=True)
    parser.add_argument('-codes', required=True)
    parser.add_argument('-save_data', required=True)
    parser.add_argument('-prefix', required=True)
    parser.add_argument('-max_len', type=int, default=100)
    parser.add_argument('--symbols', '-s', type=int, default=32000, help="Vocabulary size")
    parser.add_argument(
        '--min-frequency', type=int, default=6, metavar='FREQ',
        help='Stop if no symbol pair has frequency >= FREQ (default: %(default)s)')
    parser.add_argument('--dict-input', action="store_true",
                        help="If set, input file is interpreted as a dictionary where each line contains a word-count pair")
    parser.add_argument(
        '--separator', type=str, default='@@', metavar='STR',
        help="Separator between non-final subword units (default: '%(default)s')")
    parser.add_argument('--total-symbols', '-t', action="store_true")
    opt = parser.parse_args()
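
    # Illustrative invocation (hypothetical paths/names, not from the source):
    #   python preprocess.py -raw_dir ./.data/raw -data_dir ./.data \
    #       -codes bpe32k.codes -save_data processed.pkl -prefix wmt -s 32000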
    # Create folders if needed.
    mkdir_if_needed(opt.raw_dir)
    mkdir_if_needed(opt.data_dir)

    # Download and extract the raw data.
    raw_train = get_raw_files(opt.raw_dir, _TRAIN_DATA_SOURCES)
    raw_val = get_raw_files(opt.raw_dir, _VAL_DATA_SOURCES)
    raw_test = get_raw_files(opt.raw_dir, _TEST_DATA_SOURCES)

    # Merge the files of each split into one file per language side.
    train_src, train_trg = compile_files(opt.raw_dir, raw_train, opt.prefix + '-train')
    val_src, val_trg = compile_files(opt.raw_dir, raw_val, opt.prefix + '-val')
    test_src, test_trg = compile_files(opt.raw_dir, raw_test, opt.prefix + '-test')
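
    # At this point each split presumably lives in one merged file pair such
    # as '<raw_dir>/<prefix>-train.src' / '<prefix>-train.trg', one sentence
    # per line with source and target lines aligned.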
    # Learn the BPE codes from the training files if they do not exist yet.
    opt.codes = os.path.join(opt.data_dir, opt.codes)
    if not os.path.isfile(opt.codes):
        sys.stderr.write(f"Collect codes from training data and save to {opt.codes}.\n")
        learn_bpe(raw_train['src'] + raw_train['trg'], opt.codes, opt.symbols, opt.min_frequency, True)
    sys.stderr.write("BPE codes prepared.\n")

    sys.stderr.write("Building the BPE tokenizer.\n")
    with codecs.open(opt.codes, encoding='utf-8') as codes:
        bpe = BPE(codes, separator=opt.separator)
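
    # subword-nmt BPE splits out-of-vocabulary words into subword units and
    # marks every non-final piece with the separator, e.g. 'unrelated' ->
    # 'un@@ related' (illustrative), so the original text can be recovered by
    # deleting '@@ ' after decoding.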
sys.stderr.write(f"Encoding ...\n")
encode_files(bpe, train_src, train_trg, opt.data_dir, opt.prefix + '-train')
encode_files(bpe, val_src, val_trg, opt.data_dir, opt.prefix + '-val')
encode_files(bpe, test_src, test_trg, opt.data_dir, opt.prefix + '-test')
sys.stderr.write(f"Done.\n")

    field = Field(
        tokenize=str.split,
        lower=True,
        pad_token=Constants.PAD_WORD,
        init_token=Constants.BOS_WORD,
        eos_token=Constants.EOS_WORD)
    fields = (field, field)

    MAX_LEN = opt.max_len

    def filter_examples_with_length(x):
        # Keep only examples whose source and target sides are both at most
        # MAX_LEN tokens long.
        return len(vars(x)['src']) <= MAX_LEN and len(vars(x)['trg']) <= MAX_LEN
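
    # Note: the dataset below reads the BPE-encoded files, so these lengths
    # count subword tokens; max_len therefore bounds subword counts rather
    # than raw word counts.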
    enc_train_files_prefix = opt.prefix + '-train'
    train = TranslationDataset(
        fields=fields,
        path=os.path.join(opt.data_dir, enc_train_files_prefix),
        exts=('.src', '.trg'),
        filter_pred=filter_examples_with_length)

    # Build a shared source/target vocabulary over the training data,
    # keeping only tokens that occur at least twice.
    from itertools import chain
    field.build_vocab(chain(train.src, train.trg), min_freq=2)

    data = {'settings': opt, 'vocab': field}

    opt.save_data = os.path.join(opt.data_dir, opt.save_data)
    print('[Info] Dumping the processed data to pickle file', opt.save_data)
    with open(opt.save_data, 'wb') as f:
        pickle.dump(data, f)
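
# A downstream consumer (e.g. a training script) could reload the artifact
# along these lines (hypothetical sketch, not part of this file):
#   with open(os.path.join(data_dir, save_data), 'rb') as f:
#       data = pickle.load(f)
#   vocab = data['vocab'].vocab  # torchtext Vocab with stoi/itos lookups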