def _get_dataset()

in sagemaker/src/hf_train_deploy.py [0:0]


def _get_dataset(data_dir, data_file_name, text_column, label_column):
    """Generate the tokenized training dataset from a CSV file.

    The CSV at ``os.path.join(data_dir, data_file_name)`` is loaded as a
    single ``train`` split. Its label and text columns are renamed to
    ``labels`` and ``text``. The split is mapped through the module-level
    ``tokenize`` function in one batch, and the result is formatted as
    PyTorch tensors.

    Args:
        data_dir: Directory containing the CSV file.
        data_file_name: File name of the CSV inside ``data_dir``.
        text_column: Name of the input-text column. Ignored when the CSV
            already has a column named ``text``.
        label_column: Name of the label column. Ignored when the CSV
            already has a column named ``labels``.

    Returns:
        The processed ``train`` split, with torch format set on the
        ``labels``, ``attention_mask`` and ``input_ids`` columns.
        ``tokenize`` must produce ``attention_mask`` and ``input_ids``,
        otherwise ``set_format`` fails.
    """
    data_files = {'train': os.path.join(data_dir, data_file_name)}
    train = load_dataset('csv', data_files=data_files)['train']

    # Normalise column names. A pre-existing ``labels`` / ``text`` column
    # takes precedence over the names passed in.
    if 'labels' not in train.column_names:
        train = train.rename_column(label_column, 'labels')
    if 'text' not in train.column_names:
        train = train.rename_column(text_column, 'text')

    # Tokenize the whole split as one batch. ``len()`` must be taken on the
    # split itself: on the DatasetDict it returns the number of splits (1),
    # which would tokenize one row per batch.
    train = train.map(tokenize, batched=True, batch_size=len(train))
    train.set_format('torch', columns=['labels', 'attention_mask', 'input_ids'])

    return train