def get_data_splits()

in training/utils.py [0:0]


def _select_balanced(dataset, candidate_idxs, target_idx, size):
    """Greedily pick up to ``size`` indices from ``candidate_idxs`` with a 50/50 label balance.

    Walks ``candidate_idxs`` in order and keeps an index while its class quota
    is not yet full: ``size // 2`` samples whose ``dataset[i][target_idx] == 1``
    (positives) and ``size - size // 2`` samples with any other label
    (negatives; an odd ``size`` therefore gets one extra negative). Stops as
    soon as ``size`` indices have been selected. Returns fewer than ``size``
    indices when the candidates run out of one class.
    """
    pos_quota = size // 2
    neg_quota = size - pos_quota

    selected = []
    n_pos = n_neg = 0
    for i in candidate_idxs:
        # Fetch the sample once: indexing the dataset may be expensive.
        label = dataset[i][target_idx]
        if label == 1:
            if n_pos < pos_quota:
                selected.append(i)
                n_pos += 1
        elif n_neg < neg_quota:
            selected.append(i)
            n_neg += 1

        if len(selected) == size:
            break
    return selected


def get_data_splits(highlighted_dataset, non_highlighted_dataset, train_size, target_idx):
    """Split highlighted + non-highlighted data into train / valid / unlabelled index lists.

    Both datasets are concatenated into one ``ConcatDataset``. Highlighted
    samples occupy indices ``[0, len(highlighted_dataset))`` and the
    non-highlighted samples follow.

    Args:
        highlighted_dataset: dataset whose samples are considered first when
            filling the training split.
        non_highlighted_dataset: the remaining samples.
        train_size: target size of the train split and of the validation split.
        target_idx: position of the label inside each sample. A label equal to
            1 counts as positive; anything else counts as negative.

    Returns:
        ``(concatenated, train_split, valid_split, unlabelled_split)``, where the
        three splits are pairwise-disjoint lists of indices into
        ``concatenated``. Train and validation are each class-balanced
        (``train_size // 2`` positives, the rest negatives). Either may be
        smaller than ``train_size`` when a class runs out of samples.
    """
    n_highlighted = len(highlighted_dataset)

    highlighted_idxs = list(range(n_highlighted))
    shuffle(highlighted_idxs)

    non_highlighted_idxs = list(
        range(n_highlighted, n_highlighted + len(non_highlighted_dataset))
    )
    shuffle(non_highlighted_idxs)

    concatenated = ConcatDataset((highlighted_dataset, non_highlighted_dataset))

    # Highlighted indices come first (each group is shuffled internally), so
    # they are preferred when filling the training quotas.
    # WARNING: HateSpeech highlighted samples are all POSITIVE, so for that
    # dataset they fill the positive quota before any non-highlighted positive.
    concatenated_idxs = highlighted_idxs + non_highlighted_idxs

    train_split = _select_balanced(concatenated, concatenated_idxs, target_idx, train_size)

    concatenated_idxs_no_train = diff(concatenated_idxs, train_split)
    shuffle(concatenated_idxs_no_train)

    # The validation split has the same target size and class balance as train.
    valid_split = _select_balanced(
        concatenated, concatenated_idxs_no_train, target_idx, train_size
    )

    # Everything not used for train or validation becomes the unlabelled pool.
    unlabelled_split = diff(concatenated_idxs_no_train, valid_split)

    # Pairwise checks: a single triple intersection would be trivially empty.
    train_set, valid_set, unlabelled_set = (
        set(train_split),
        set(valid_split),
        set(unlabelled_split),
    )
    assert not (train_set & valid_set), "train and valid splits overlap"
    assert not (train_set & unlabelled_set), "train and unlabelled splits overlap"
    assert not (valid_set & unlabelled_set), "valid and unlabelled splits overlap"

    print(train_split, valid_split)
    print(len(train_split), len(valid_split), len(unlabelled_split))
    return concatenated, train_split, valid_split, unlabelled_split