def get_datasets()

in sagemaker/07_tensorflow_distributed_training_data_parallelism/scripts/train.py
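
Builds the tf.data training and evaluation pipelines: loads the IMDB dataset, tokenizes both splits, converts them to dense tensors, and shards them across workers when SageMaker Distributed Data Parallel is enabled.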


# tokenizer, args, SDP_ENABLED, and sdp (smdistributed.dataparallel.tensorflow)
# are defined at module level elsewhere in train.py
import tensorflow as tf
from datasets import load_dataset


def get_datasets():
    # Load the IMDB train and test splits
    train_dataset, test_dataset = load_dataset("imdb", split=["train", "test"])

    # Tokenize the training split; truncation plus padding to max_length gives
    # every example a fixed length of tokenizer.model_max_length tokens
    train_dataset = train_dataset.map(
        lambda e: tokenizer(e["text"], truncation=True, padding="max_length"), batched=True
    )
    train_dataset.set_format(type="tensorflow", columns=["input_ids", "attention_mask", "label"])

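    # With the tensorflow format, indexing a column yields a ragged tensor;
    # to_tensor pads it into a dense [None, model_max_length] batch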
    train_features = {
        x: train_dataset[x].to_tensor(default_value=0, shape=[None, tokenizer.model_max_length])
        for x in ["input_ids", "attention_mask"]
    }
    tf_train_dataset = tf.data.Dataset.from_tensor_slices((train_features, train_dataset["label"]))

    # Preprocess test dataset
    test_dataset = test_dataset.map(
        lambda e: tokenizer(e["text"], truncation=True, padding="max_length"), batched=True
    )
    test_dataset.set_format(type="tensorflow", columns=["input_ids", "attention_mask", "label"])

    test_features = {
        x: test_dataset[x].to_tensor(default_value=0, shape=[None, tokenizer.model_max_length])
        for x in ["input_ids", "attention_mask"]
    }
    tf_test_dataset = tf.data.Dataset.from_tensor_slices((test_features, test_dataset["label"]))

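    # Under SageMaker Distributed Data Parallel, give each of the sdp.size()
    # workers a distinct shard of the data, selected by its rank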
    if SDP_ENABLED:
        tf_train_dataset = tf_train_dataset.shard(sdp.size(), sdp.rank())
        tf_test_dataset = tf_test_dataset.shard(sdp.size(), sdp.rank())
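    # drop_remainder=True keeps batch shapes static, so every worker runs the
    # same number of identically shaped steps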
    tf_train_dataset = tf_train_dataset.batch(args.train_batch_size, drop_remainder=True)
    tf_test_dataset = tf_test_dataset.batch(args.eval_batch_size, drop_remainder=True)

    return tf_train_dataset, tf_test_dataset
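

A minimal usage sketch follows; the checkpoint name, hyperparameter values, and the Keras compile/fit flow are illustrative assumptions rather than the wiring used in the original train.py.

# Hypothetical caller, for illustration only. The checkpoint and the
# hyperparameter values below are placeholders, not taken from train.py.
from transformers import TFAutoModelForSequenceClassification

tf_train_dataset, tf_test_dataset = get_datasets()

model = TFAutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased", num_labels=2
)
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=5e-5),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)
model.fit(tf_train_dataset, validation_data=tf_test_dataset, epochs=1)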