in torchmoji/finetuning.py [0:0]
def __init__(self, y_in, batch_size, epoch_size, upsample, seed):
    """Pre-draw the sample indices that make up one epoch.

    Indices are drawn with replacement, so ``epoch_size`` may exceed the
    number of available observations.

    Args:
        y_in: Labels to sample from. Without upsampling only ``len(y_in)``
            is used. With upsampling it must expose ``.shape`` and
            ``.numpy()`` (presumably a torch tensor -- confirm against
            callers), be one-dimensional and hold binary 0/1 labels.
        batch_size: Stored unchanged as ``self.batch_size``.
        epoch_size: Number of indices to draw. Must be even when
            ``upsample`` is true so both classes get the same share.
        upsample: If true, draw ``epoch_size / 2`` negatives (label 0) and
            as many positives (label 1), then shuffle them. If false,
            draw uniformly from all indices without any class balancing.
        seed: Passed to ``np.random.seed``. Note that this reseeds NumPy's
            global random state, not a private generator.

    Raises:
        AssertionError: With ``upsample``, if ``y_in`` is not
            one-dimensional or ``epoch_size`` is odd.
        ValueError: With ``upsample``, if samples are requested but one of
            the two classes has no examples.
    """
    self.batch_size = batch_size
    self.epoch_size = epoch_size
    self.upsample = upsample

    np.random.seed(seed)

    if not upsample:
        # Plain uniform draw over every index; nothing is balanced here.
        ind = range(len(y_in))
        self.sample_ind = np.random.choice(ind, epoch_size, replace=True)
    else:
        # Upsampling should only be used on binary class problems.
        # The input checks are raised explicitly instead of via `assert`
        # so they still run under `python -O`; AssertionError is kept so
        # existing callers see the same exception type.
        if len(y_in.shape) != 1:
            raise AssertionError(
                "upsampling expects a one-dimensional label vector")

        # Convert once and reuse for both the class lookup and the
        # sanity check below.
        labels = y_in.numpy()
        neg = np.where(labels == 0)[0]
        pos = np.where(labels == 1)[0]

        if epoch_size % 2 != 0:
            raise AssertionError(
                "epoch_size must be even when upsampling")
        # Floor division avoids a float round trip; int() still accepts
        # integral floats and NumPy integer scalars.
        samples_pr_class = int(epoch_size // 2)

        # Fail with a clear message rather than the generic error
        # np.random.choice raises when asked to draw from an empty array.
        if samples_pr_class > 0 and (neg.size == 0 or pos.size == 0):
            raise ValueError(
                "upsampling requires at least one example of each class "
                "(labels 0 and 1)")

        # Randomly sample observations in a balanced way. The order of the
        # np.random calls is significant: it fixes the result for a seed.
        sample_neg = np.random.choice(neg, samples_pr_class, replace=True)
        sample_pos = np.random.choice(pos, samples_pr_class, replace=True)
        concat_ind = np.concatenate((sample_neg, sample_pos), axis=0)

        # Shuffle to avoid labels being in specific order
        # (all negative then positive)
        p = np.random.permutation(len(concat_ind))
        self.sample_ind = concat_ind[p]

        # Internal sanity check: the drawn labels should be ~50% positive.
        label_dist = np.mean(labels[self.sample_ind])
        assert label_dist > 0.45
        assert label_dist < 0.55