in src/accelerate/utils/megatron_lm.py [0:0]
def get_train_valid_test_datasets_provider(self, accelerator):
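    """Return a train/valid/test datasets provider for Megatron-LM.

    Resolution order: a user-supplied provider from the plugin, the provider
    shipped with Megatron-LM's official pretrain_* scripts, then the generic
    fallback defined below.
    """
    # NOTE: `get_args` and `build_train_valid_test_datasets` are assumed to be
    # imported at module level from the Megatron-LM installation.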
    def train_valid_test_datasets_provider(train_val_test_num_samples):
        """Build train, valid, and test datasets."""
        args = get_args()
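        # Dataset arguments shared by every model type; model-specific keys are added below.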
        dataset_args = {
            "data_prefix": args.data_path if isinstance(args.data_path, (list, tuple)) else [args.data_path],
            "splits_string": args.split,
            "train_valid_test_num_samples": train_val_test_num_samples,
            "seed": args.seed,
        }
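        # Add the sequence-length (and, for T5, dataset-type) settings specific to each model type.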
if args.model_type_name == "bert":
dataset_args.update(
{
"max_seq_length": args.seq_length,
"binary_head": args.bert_binary_head,
}
)
elif args.model_type_name == "gpt":
dataset_args.update(
{
"max_seq_length": args.seq_length,
}
)
elif args.model_type_name == "t5":
dataset_args.update(
{
"max_seq_length": args.encoder_seq_length,
"max_seq_length_dec": args.decoder_seq_length,
"dataset_type": "t5",
}
)
else:
raise ValueError(f"Unsupported model type: {args.model_type_name}")
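        # Delegate the actual dataset construction to Megatron-LM.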
        train_ds, valid_ds, test_ds = build_train_valid_test_datasets(**dataset_args)
        return train_ds, valid_ds, test_ds
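
    # A provider supplied through the Megatron-LM plugin takes precedence.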
    if accelerator.state.megatron_lm_plugin.custom_megatron_datasets_provider_function is not None:
        return accelerator.state.megatron_lm_plugin.custom_megatron_datasets_provider_function
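    # Otherwise, try the providers shipped with Megatron-LM's official pretrain
    # scripts, which are importable when Megatron-LM is installed from source.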
    try:
        args = get_args()
        # Use '--no-use-pep517 -e' to pip install nvidia's megatron from source
if args.model_type_name == "bert":
from pretrain_bert import train_valid_test_datasets_provider
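
            # Setting `is_distributed = True` tells Megatron-LM's training loop to
            # build the datasets on every rank rather than only on rank 0.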
            train_valid_test_datasets_provider.is_distributed = True
            return train_valid_test_datasets_provider
elif args.model_type_name == "gpt":
from pretrain_gpt import train_valid_test_datasets_provider
train_valid_test_datasets_provider.is_distributed = True
return train_valid_test_datasets_provider
elif args.model_type_name == "t5":
from pretrain_t5 import train_valid_test_datasets_provider
train_valid_test_datasets_provider.is_distributed = True
return train_valid_test_datasets_provider
    except ImportError:
        pass
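    # Fall back to the generic provider defined above when the official
    # pretrain scripts are not importable.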
    return train_valid_test_datasets_provider
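

# A minimal usage sketch (hypothetical `loader` object, for illustration only);
# assumes Megatron-LM has been initialized so that `get_args()` works:
#
#     provider = loader.get_train_valid_test_datasets_provider(accelerator)
#     train_ds, valid_ds, test_ds = provider([1000, 100, 100])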