in sagemaker/src/hf_train_deploy.py [0:0]
def _get_dataset(data_dir, data_file_name, text_column, label_column):
    """Load a CSV file and return its tokenized 'train' split for model training.

    The file at ``os.path.join(data_dir, data_file_name)`` is loaded as the
    'train' split. If the data has no 'labels' column, ``label_column`` is
    renamed to 'labels'; if it has no 'text' column, ``text_column`` is
    renamed to 'text'. The module-level ``tokenize`` function is then applied
    in batched mode and the output format is set to torch tensors restricted
    to the 'labels', 'attention_mask' and 'input_ids' columns.

    Args:
        data_dir: Directory containing the CSV file.
        data_file_name: Name of the CSV file inside ``data_dir``.
        text_column: Name of the column holding the input text; only used
            when the data has no 'text' column.
        label_column: Name of the column holding the targets; only used
            when the data has no 'labels' column.

    Returns:
        The 'train' split of the loaded, tokenized dataset.
    """
    csv_path = os.path.join(data_dir, data_file_name)
    dataset = load_dataset('csv', data_files={'train': csv_path})

    if 'labels' not in dataset['train'].column_names:
        dataset = dataset.rename_column(label_column, 'labels')
    if 'text' not in dataset['train'].column_names:
        dataset = dataset.rename_column(text_column, 'text')

    # `dataset` is keyed by split name, so len(dataset) would be the number
    # of splits (1 here), not the number of rows. Take the row count from the
    # 'train' split so the whole split is tokenized as a single batch.
    # NOTE(review): presumably `tokenize` pads per batch, which is why one
    # full-size batch is wanted -- confirm against its definition.
    num_rows = len(dataset['train'])
    dataset = dataset.map(tokenize, batched=True, batch_size=num_rows)

    dataset.set_format('torch', columns=['labels', 'attention_mask', 'input_ids'])
    return dataset['train']