in training/utils.py
from random import shuffle

from torch.utils.data import ConcatDataset


def get_data_splits(highlighted_dataset, non_highlighted_dataset, train_size, target_idx):
    # Index the highlighted dataset first and the non-highlighted dataset after it,
    # offsetting the latter by len(highlighted_dataset) so the indices address the
    # concatenated dataset.
    highlighted_idxs = list(range(len(highlighted_dataset)))
    shuffle(highlighted_idxs)
    other_idxs = [i + len(highlighted_dataset) for i in range(len(non_highlighted_dataset))]
    shuffle(other_idxs)
    concatenated = ConcatDataset((highlighted_dataset, non_highlighted_dataset))
    concatenated_idxs = highlighted_idxs + other_idxs
    # WARNING: the HateSpeech highlighted samples are all POSITIVE!
    # They are placed at the beginning of concatenated_idxs so that they are
    # always picked first when filling the splits. Both index lists are already
    # shuffled, so concatenated_idxs does not need another shuffle.
    # Compute the train split, capping each class at (train_size // 10) * 5
    # samples so the split is balanced 50/50.
    train_split = []
    no_pos, no_neg = 0, 0
    for i in concatenated_idxs:
        if concatenated[i][target_idx] == 1:
            if no_pos < (train_size // 10) * 5:
                train_split.append(i)
                no_pos += 1
        else:
            if no_neg < (train_size // 10) * 5:
                train_split.append(i)
                no_neg += 1
        if len(train_split) == train_size:
            break
    # Remove the train indices and compute the validation split from what is
    # left, with the same size and class balance as the train split.
    # diff is a list-difference helper assumed to live elsewhere in this module
    # (a minimal sketch follows the function).
    concatenated_idxs_no_train = diff(concatenated_idxs, train_split)
    shuffle(concatenated_idxs_no_train)
    valid_split = []
    no_pos, no_neg = 0, 0
    for i in concatenated_idxs_no_train:
        if concatenated[i][target_idx] == 1:
            if no_pos < (train_size // 10) * 5:
                valid_split.append(i)
                no_pos += 1
        else:
            if no_neg < (train_size // 10) * 5:
                valid_split.append(i)
                no_neg += 1
        if len(valid_split) == train_size:
            break
    # The remaining data points form the unlabelled pool.
    unlabelled_split = diff(concatenated_idxs_no_train, valid_split)
    # The three splits must be pairwise disjoint.
    assert not (set(train_split) & set(valid_split))
    assert not (set(train_split) & set(unlabelled_split))
    assert not (set(valid_split) & set(unlabelled_split))
    print(train_split, valid_split)
    print(len(train_split), len(valid_split), len(unlabelled_split))
    return concatenated, train_split, valid_split, unlabelled_split
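

# The function above relies on a `diff` helper that is not shown in this
# excerpt. A minimal sketch of what it is assumed to do: an order-preserving
# list difference.
def diff(first, second):
    """Return the items of `first` that are not in `second`, keeping order."""
    excluded = set(second)
    return [item for item in first if item not in excluded]


# Quick smoke test of get_data_splits with synthetic data (a sketch; the
# dataset shapes and target_idx=1 are assumptions, not taken from the real
# training pipeline). Items are (features, label) pairs, so the label sits
# at index 1.
if __name__ == "__main__":
    import torch
    from torch.utils.data import TensorDataset

    highlighted = TensorDataset(torch.randn(20, 8), torch.ones(20, dtype=torch.long))
    non_highlighted = TensorDataset(torch.randn(200, 8), torch.randint(0, 2, (200,)))

    data, train_split, valid_split, unlabelled_split = get_data_splits(
        highlighted, non_highlighted, train_size=20, target_idx=1
    )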