def train_validate_test_split()

in use-cases/model-fine-tuning-pipeline/data-preparation/gemma-it/src/dataprep.py [0:0]


def train_validate_test_split(df, *, train_frac=0.8, val_frac=0.1, random_state=42):
    """Randomly split ``df`` into train/validation/test sets and save them to GCS.

    ``int(train_frac * len(df))`` rows are sampled for training. From the
    remaining rows, ``int(val_frac * len(df))`` are sampled for validation, and
    every leftover row becomes the test split. Each split is converted with
    ``Dataset.from_pandas`` and written with ``save_to_disk`` to
    ``gs://{BUCKET}/{DATASET_OUTPUT}/`` under ``training/``, ``validation/`` and
    ``test/`` respectively.

    Args:
        df: pandas DataFrame holding the prepared examples.
        train_frac: Fraction of all rows assigned to the training split.
        val_frac: Fraction of all rows assigned to the validation split.
        random_state: Seed passed to ``DataFrame.sample`` so splits are
            reproducible across runs.

    Returns:
        The ``DatasetDict`` with ``"train"``, ``"validation"`` and ``"test"``
        splits (the same objects that were saved).

    Raises:
        ValueError: If a fraction is negative, or the train and validation
            sizes together exceed the number of rows.
    """
    total = len(df)
    logger.info(f"Total Data Size: {total}")
    train_size = int(train_frac * total)
    val_size = int(val_frac * total)
    if train_frac < 0 or val_frac < 0 or train_size + val_size > total:
        raise ValueError(
            f"Invalid split fractions: train_frac={train_frac}, val_frac={val_frac} "
            f"for {total} rows"
        )

    # Sample on an index-only copy relabelled 0..n-1 instead of using
    # df.drop(sampled.index): with duplicate index labels, drop() removes every
    # row sharing a sampled label, silently losing data. DataFrame.sample picks
    # rows by position from (length, seed) alone, so this selects exactly the
    # same rows as sampling df directly.
    positional = df.iloc[:, :0].reset_index(drop=True)
    train_pos = positional.sample(n=train_size, random_state=random_state).index
    remaining = positional.loc[~positional.index.isin(train_pos)]
    val_pos = remaining.sample(n=val_size, random_state=random_state).index
    test_pos = remaining.index[~remaining.index.isin(val_pos)]

    # Index back into the original frame by position, keeping original labels.
    train_df = df.iloc[train_pos]
    val_df = df.iloc[val_pos]
    test_df = df.iloc[test_pos]
    logger.info(
        f"Training data size: {len(train_df)}, Validation data size: {len(val_df)}, Test data size: {len(test_df)}"
    )
    # Create DatasetDict with splits.
    # NOTE(review): from_pandas stores a non-RangeIndex (as these sampled
    # frames have) as an extra "__index_level_0__" column; kept for
    # compatibility -- confirm whether downstream consumers need it.
    dataset = DatasetDict(
        {
            "train": Dataset.from_pandas(train_df),
            "validation": Dataset.from_pandas(val_df),
            "test": Dataset.from_pandas(test_df),
        }
    )
    dataset["train"].save_to_disk(f"gs://{BUCKET}/{DATASET_OUTPUT}/training/")
    dataset["validation"].save_to_disk(f"gs://{BUCKET}/{DATASET_OUTPUT}/validation/")
    dataset["test"].save_to_disk(f"gs://{BUCKET}/{DATASET_OUTPUT}/test/")
    return dataset