in data.py [0:0]
def get_train_val_test_data(data_params, env_params, batch_size, device, sort_dict):
    # Build the corpus and record its vocabulary size for the caller.
    corpus = _build_corpus(data_params['data_path'], env_params, sort_dict)
    data_params['vocab_size'] = corpus.vocab_size
    train_data, val_data, test_data = _get_train_val_test_data(
        corpus=corpus, batch_size=batch_size)
    if env_params['distributed']:
        # Split the batch dimension into equal per-rank parts.
        assert batch_size % env_params['world_size'] == 0
        device_batch_size = batch_size // env_params['world_size']
        slice_data = slice(
            device_batch_size * env_params['rank'],
            device_batch_size * (env_params['rank'] + 1))
        train_data = train_data[slice_data]
        val_data = val_data[slice_data]
        test_data = test_data[slice_data]
    # Move each split to the target device (CPU or this rank's GPU).
    train_data = train_data.to(device)
    val_data = val_data.to(device)
    test_data = test_data.to(device)
    return train_data, val_data, test_data
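
A minimal call sketch, assuming PyTorch tensors (implied by the .to(device) calls); the data path, distributed settings, and sort_dict value below are hypothetical placeholders, not values from the repo:

import torch

data_params = {'data_path': 'data/enwik8'}  # hypothetical dataset location
env_params = {'distributed': True, 'world_size': 4, 'rank': 0}
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_data, val_data, test_data = get_train_val_test_data(
    data_params, env_params, batch_size=64, device=device, sort_dict=False)
# After the call, data_params['vocab_size'] is filled in from the corpus,
# and each rank holds its 64 // 4 = 16-row slice of the batch dimension.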