# get_mbart_kwargs()
#
# From muss/mining/training.py


def get_mbart_kwargs(dataset, language, use_access, use_short_name=False):
    """Build the keyword-argument dict used to fine-tune mBART for text simplification.

    Args:
        dataset: Identifier of the training dataset.
        language: Language code used to pick prediction files, evaluation kwargs and
            (when ``use_access`` is true) the ACCESS preprocessors.
        use_access: When true, merge the ACCESS control-token preprocessors into
            ``preprocessors_kwargs`` (ACCESS entries take precedence on key clashes,
            given ``add_dicts``'s merge order).
        use_short_name: Forwarded to ``get_access_preprocessors_kwargs``.

    Returns:
        dict: kwargs covering preprocessing, fairseq training, generation and evaluation.
    """
    mbart_dir = prepare_mbart_model()
    mbart_path = mbart_dir / 'model.pt'
    # Simplification is framed as monolingual "translation" from complex to simple
    # text, so fairseq language pairs are 'complex'/'simple' rather than mBART
    # language codes such as f'{language}_XX'.
    source_lang = 'complex'
    target_lang = 'simple'
    kwargs = {
        'dataset': dataset,
        'metrics_coefs': [0, 1, 0],
        'parametrization_budget': 128,
        'predict_files': get_predict_files(language),
        'preprocessors_kwargs': {
            'SentencePiecePreprocessor': {
                'sentencepiece_model_path': mbart_dir / 'sentence.bpe.model',
                'tokenize_special_tokens': True,
            },
        },
        'preprocess_kwargs': {
            'dict_path': mbart_dir / 'dict.txt',
            'source_lang': source_lang,
            'target_lang': target_lang,
        },
        'train_kwargs': add_dicts(
            {'ngpus': 8},
            args_str_to_dict(
                f'''--restore-file {mbart_path}  --arch mbart_large --task translation_from_pretrained_bart  --source-lang {source_lang} --target-lang {target_lang}  --encoder-normalize-before --decoder-normalize-before --criterion label_smoothed_cross_entropy --label-smoothing 0.2  --dataset-impl mmap --optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' --lr-scheduler polynomial_decay --lr 3e-05 --min-lr -1 --warmup-updates 2500 --total-num-update 40000 --dropout 0.3 --attention-dropout 0.1  --weight-decay 0.0 --max-tokens 1024 --update-freq 2 --log-format simple --log-interval 2 --reset-optimizer --reset-meters --reset-dataloader --reset-lr-scheduler --langs ar_AR,cs_CZ,de_DE,en_XX,es_XX,et_EE,fi_FI,fr_XX,gu_IN,hi_IN,it_IT,ja_XX,kk_KZ,ko_KR,lt_LT,lv_LV,my_MM,ne_NP,nl_XX,ro_RO,ru_RU,si_LK,tr_TR,vi_VN,zh_CN
     --layernorm-embedding  --ddp-backend no_c10d'''
            ),
        ),  # noqa: E501
        # Fix: was '--source_lang' (underscore), inconsistent with train_kwargs above
        # and with fairseq's hyphenated CLI option '--source-lang'.
        'generate_kwargs': args_str_to_dict(
            f'''--task translation_from_pretrained_bart --source-lang {source_lang} --target-lang {target_lang} --batch-size 32 --langs ar_AR,cs_CZ,de_DE,en_XX,es_XX,et_EE,fi_FI,fr_XX,gu_IN,hi_IN,it_IT,ja_XX,kk_KZ,ko_KR,lt_LT,lv_LV,my_MM,ne_NP,nl_XX,ro_RO,ru_RU,si_LK,tr_TR,vi_VN,zh_CN'''  # noqa: E501
        ),
        'evaluate_kwargs': get_evaluate_kwargs(language),
    }
    if use_access:
        kwargs['preprocessors_kwargs'] = add_dicts(
            get_access_preprocessors_kwargs(language, use_short_name=use_short_name), kwargs['preprocessors_kwargs']
        )
    return kwargs