def get_goals_and_targets()

in pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py [0:0]


def _goals_and_targets_from_frame(frame, start, stop):
    """Return ``(goals, targets)`` lists for rows ``[start:stop]`` of *frame*.

    ``target`` is read from the required ``"target"`` column. ``goal`` is read from the
    ``"goal"`` column when present; otherwise every goal is the empty string. This keeps
    both lists the same length.
    """
    targets = frame["target"].tolist()[start:stop]
    if "goal" in frame.columns:
        goals = frame["goal"].tolist()[start:stop]
    else:
        goals = [""] * len(targets)
    return goals, targets


def get_goals_and_targets(params):
    """Collect train/test goals and targets from a config object and optional CSV files.

    Attributes read from *params*:
        goals, targets, test_goals, test_targets: optional in-config lists. They are used as
            defaults when the corresponding CSV is not given (missing attributes default to ``[]``).
        train_data: path to a training CSV (falsy to skip). The CSV must have a ``"target"``
            column and may have a ``"goal"`` column. Its rows are shuffled with
            ``random_seed`` before slicing.
        test_data: optional path to a test CSV with the same column layout.
        n_train_data: number of (shuffled) training rows to keep.
        n_test_data: number of test rows to keep. When ``train_data`` is set but ``test_data``
            is not, the test split is the ``n_test_data`` shuffled training rows that follow
            the first ``n_train_data``.
        random_seed: seed passed to ``DataFrame.sample`` for the shuffle.

    Returns:
        ``(train_goals, train_targets, test_goals, test_targets)``.

    Raises:
        ValueError: if a goals list and its targets list have different lengths.
    """
    train_goals = getattr(params, "goals", [])
    train_targets = getattr(params, "targets", [])
    test_goals = getattr(params, "test_goals", [])
    test_targets = getattr(params, "test_targets", [])

    if params.train_data:
        train_data = pd.read_csv(params.train_data)
        # Shuffle all rows reproducibly (seeded) so the train slice, and any held-out test
        # slice taken from the same file, are random but repeatable.
        train_data = train_data.sample(frac=1, random_state=params.random_seed).reset_index(drop=True)
        train_goals, train_targets = _goals_and_targets_from_frame(train_data, 0, params.n_train_data)

        if params.test_data and params.n_test_data > 0:
            test_goals, test_targets = _goals_and_targets_from_frame(
                pd.read_csv(params.test_data), 0, params.n_test_data
            )
        elif params.n_test_data > 0:
            # No separate test file: hold out the rows right after the training slice.
            stop = params.n_train_data + params.n_test_data
            test_goals, test_targets = _goals_and_targets_from_frame(train_data, params.n_train_data, stop)
    elif getattr(params, "test_data", None) and params.n_test_data > 0:
        # A test CSV must be honored even when training goals come from the config itself;
        # previously it was silently ignored in this case.
        test_goals, test_targets = _goals_and_targets_from_frame(
            pd.read_csv(params.test_data), 0, params.n_test_data
        )

    # Explicit checks (not `assert`) so validation survives `python -O`.
    if len(train_goals) != len(train_targets):
        raise ValueError(
            f"Mismatched train data: {len(train_goals)} goals vs {len(train_targets)} targets"
        )
    if len(test_goals) != len(test_targets):
        raise ValueError(
            f"Mismatched test data: {len(test_goals)} goals vs {len(test_targets)} targets"
        )
    print(f"Loaded {len(train_goals)} train goals")
    print(f"Loaded {len(test_goals)} test goals")

    return train_goals, train_targets, test_goals, test_targets