in muss/mining/training.py [0:0]
def get_mbart_kwargs(dataset, language, use_access, use_short_name=False):
    """Build the experiment kwargs for fine-tuning the pretrained mBART model.

    Args:
        dataset: Dataset identifier, stored unchanged under the ``'dataset'`` key.
        language: Language forwarded to ``get_predict_files``, ``get_evaluate_kwargs``
            and (when ``use_access`` is true) ``get_access_preprocessors_kwargs``.
            It does NOT affect the fairseq source/target language names, which are
            the fixed labels ``'complex'`` / ``'simple'``.
        use_access: If true, the preprocessors from ``get_access_preprocessors_kwargs``
            are merged with the SentencePiece preprocessor config via ``add_dicts``
            (ACCESS entries passed first).
        use_short_name: Forwarded to ``get_access_preprocessors_kwargs``.

    Returns:
        dict with the keys ``'dataset'``, ``'metrics_coefs'``, ``'parametrization_budget'``,
        ``'predict_files'``, ``'preprocessors_kwargs'``, ``'preprocess_kwargs'``,
        ``'train_kwargs'``, ``'generate_kwargs'`` and ``'evaluate_kwargs'``.
    """
    mbart_dir = prepare_mbart_model()
    mbart_path = mbart_dir / 'model.pt'
    # Fixed fairseq language labels for the two sides of the parallel data (rather than
    # mBART codes such as f'{language}_XX'). NOTE(review): presumably these must match
    # the file suffixes of the prepared dataset -- confirm before changing.
    source_lang = 'complex'
    target_lang = 'simple'
    # Language list passed via --langs to both training and generation; defined once so
    # the two argument strings can never drift apart.
    mbart_langs = ','.join(
        [
            'ar_AR', 'cs_CZ', 'de_DE', 'en_XX', 'es_XX', 'et_EE', 'fi_FI', 'fr_XX', 'gu_IN',
            'hi_IN', 'it_IT', 'ja_XX', 'kk_KZ', 'ko_KR', 'lt_LT', 'lv_LV', 'my_MM', 'ne_NP',
            'nl_XX', 'ro_RO', 'ru_RU', 'si_LK', 'tr_TR', 'vi_VN', 'zh_CN',
        ]
    )
    # One flag per item, joined with single spaces: keeps lines short without needing
    # `noqa`, and avoids embedding newlines in the argument string.
    train_args = ' '.join(
        [
            f'--restore-file {mbart_path}',
            '--arch mbart_large',
            '--task translation_from_pretrained_bart',
            f'--source-lang {source_lang}',
            f'--target-lang {target_lang}',
            '--encoder-normalize-before',
            '--decoder-normalize-before',
            '--criterion label_smoothed_cross_entropy',
            '--label-smoothing 0.2',
            '--dataset-impl mmap',
            '--optimizer adam',
            '--adam-eps 1e-06',
            "--adam-betas '(0.9, 0.98)'",
            '--lr-scheduler polynomial_decay',
            '--lr 3e-05',
            '--min-lr -1',
            '--warmup-updates 2500',
            '--total-num-update 40000',
            '--dropout 0.3',
            '--attention-dropout 0.1',
            '--weight-decay 0.0',
            '--max-tokens 1024',
            '--update-freq 2',
            '--log-format simple',
            '--log-interval 2',
            '--reset-optimizer',
            '--reset-meters',
            '--reset-dataloader',
            '--reset-lr-scheduler',
            f'--langs {mbart_langs}',
            '--layernorm-embedding',
            '--ddp-backend no_c10d',
        ]
    )
    generate_args = ' '.join(
        [
            '--task translation_from_pretrained_bart',
            # Canonical dashed flag, consistent with the training args
            # (was `--source_lang`).
            f'--source-lang {source_lang}',
            f'--target-lang {target_lang}',
            '--batch-size 32',
            f'--langs {mbart_langs}',
        ]
    )
    kwargs = {
        'dataset': dataset,
        'metrics_coefs': [0, 1, 0],
        'parametrization_budget': 128,
        'predict_files': get_predict_files(language),
        'preprocessors_kwargs': {
            'SentencePiecePreprocessor': {
                'sentencepiece_model_path': mbart_dir / 'sentence.bpe.model',
                'tokenize_special_tokens': True,
            },
        },
        'preprocess_kwargs': {
            'dict_path': mbart_dir / 'dict.txt',
            'source_lang': source_lang,
            'target_lang': target_lang,
        },
        'train_kwargs': add_dicts({'ngpus': 8}, args_str_to_dict(train_args)),
        'generate_kwargs': args_str_to_dict(generate_args),
        'evaluate_kwargs': get_evaluate_kwargs(language),
    }
    if use_access:
        # ACCESS preprocessors are passed first, ahead of the SentencePiece config.
        kwargs['preprocessors_kwargs'] = add_dicts(
            get_access_preprocessors_kwargs(language, use_short_name=use_short_name),
            kwargs['preprocessors_kwargs'],
        )
    return kwargs