in muss/mining/training.py
def get_bart_kwargs(dataset, language, use_access, use_short_name=False, bart_model='bart.large'):
    # Only English BART checkpoints are supported by this recipe.
    assert language == 'en'
    # Download/prepare the pretrained checkpoint and point at its weights.
    bart_path = prepare_bart_model(bart_model) / 'model.pt'
    # Map the fairseq checkpoint name to the corresponding --arch value
    # ('bart.large.cnn' shares the bart_large architecture).
    arch = {
        'bart.base': 'bart_base',
        'bart.large': 'bart_large',
        'bart.large.cnn': 'bart_large',
    }[bart_model]
    kwargs = {
        'dataset': dataset,
        'metrics_coefs': [0, 1, 0],  # weights for combining evaluation metrics into a single score
        'parametrization_budget': 128,  # search budget when optimizing preprocessor (control-token) values
        'predict_files': get_predict_files(language),
        'preprocessors_kwargs': {
            'GPT2BPEPreprocessor': {},
        },
        'preprocess_kwargs': {'dict_path': GPT2BPEPreprocessor().dict_path},
        # Fine-tuning hyperparameters forwarded to fairseq-train; they follow
        # fairseq's recommended BART fine-tuning settings.
        'train_kwargs': {
            'ngpus': 8,
            'arch': arch,
            'restore_file': bart_path,  # start from the pretrained BART weights
            'max_tokens': 4096,
            'lr': 3e-05,
            'warmup_updates': 500,
            'truncate_source': True,
            'layernorm_embedding': True,
            'share_all_embeddings': True,
            'share_decoder_input_output_embed': True,
            # Discard the checkpoint's optimizer/dataloader/meter state,
            # since we fine-tune on a new task.
            'reset_optimizer': True,
            'reset_dataloader': True,
            'reset_meters': True,
            'required_batch_size_multiple': 1,
            'criterion': 'label_smoothed_cross_entropy',
            'label_smoothing': 0.1,
            'dropout': 0.1,
            'attention_dropout': 0.1,
            'weight_decay': 0.01,
            'optimizer': 'adam',
            'adam_betas': '(0.9, 0.999)',  # fairseq parses the betas from a string
            'adam_eps': 1e-08,
            'clip_norm': 0.1,
            'lr_scheduler': 'polynomial_decay',
            'max_update': 20000,
            'skip_invalid_size_inputs_valid_test': True,
            'find_unused_parameters': True,
        },
        'evaluate_kwargs': get_evaluate_kwargs(language),
    }
    if use_access:
        # Merge the ACCESS control-token preprocessors in ahead of the BPE preprocessor.
        kwargs['preprocessors_kwargs'] = add_dicts(
            get_access_preprocessors_kwargs(language, use_short_name=use_short_name), kwargs['preprocessors_kwargs']
        )
    return kwargs
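
A minimal usage sketch of how these kwargs might be consumed. Hedged assumptions: `fairseq_train_and_evaluate_with_parametrization` from `muss.fairseq.main` is assumed to be the downstream consumer (as in the MUSS training script, not shown in this file), and the dataset name is a placeholder.

# Hypothetical driver script; dataset name and ngpus override are illustrative.
from muss.fairseq.main import fairseq_train_and_evaluate_with_parametrization  # assumed consumer
from muss.mining.training import get_bart_kwargs

dataset = 'my_mined_paraphrase_dataset'  # placeholder: a dataset produced by the mining pipeline
kwargs = get_bart_kwargs(dataset=dataset, language='en', use_access=True)
kwargs['train_kwargs']['ngpus'] = 1  # scale down from the default 8 GPUs if needed
result = fairseq_train_and_evaluate_with_parametrization(**kwargs)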