def __init__()

in torchmoji/finetuning.py [0:0]


    def __init__(self, y_in, batch_size, epoch_size, upsample, seed):
        """Draw the dataset indices that make up one epoch.

        Args:
            y_in: Labels of the dataset. Only ``len(y_in)`` is used when
                ``upsample`` is False; when ``upsample`` is True it must also
                expose ``.shape`` and ``.numpy()`` (e.g. a CPU torch tensor)
                and hold 1-D labels where 0 = negative and 1 = positive.
            batch_size: Stored as ``self.batch_size``; not used here.
            epoch_size: Number of indices to draw for the epoch.
            upsample: If True, draw exactly ``epoch_size // 2`` indices from
                each of the label-0 and label-1 samples so the epoch is
                balanced. Labels other than 0 and 1 are never drawn.
                Intended for binary class problems only.
            seed: Passed to ``np.random.seed`` before sampling.

        Sets ``self.sample_ind``: a NumPy array of ``epoch_size`` indices into
        ``y_in``, drawn with replacement.

        Raises:
            AssertionError: when upsampling and ``y_in`` is not 1-D,
                ``epoch_size`` is odd, or the balance sanity check fails (by
                construction it only fails when ``epoch_size == 0``, because
                the mean of an empty selection is NaN).
            ValueError: raised by ``np.random.choice`` when there is nothing
                to sample from (empty dataset, or an empty class when
                upsampling) and ``epoch_size > 0``.
        """
        self.batch_size = batch_size
        self.epoch_size = epoch_size
        self.upsample = upsample

        # NOTE: this reseeds NumPy's *global* RNG, so it also affects any
        # other code in the process that draws from np.random.
        np.random.seed(seed)

        if not upsample:
            # Uniformly sample observations with replacement (no balancing).
            ind = range(len(y_in))
            self.sample_ind = np.random.choice(ind, epoch_size, replace=True)
            return

        # Upsampling is only meaningful for binary class problems.
        assert len(y_in.shape) == 1, "upsampling requires 1-D labels"
        labels = y_in.numpy()  # convert once instead of on every use
        neg = np.where(labels == 0)[0]
        pos = np.where(labels == 1)[0]
        assert epoch_size % 2 == 0, "upsampling requires an even epoch_size"
        samples_pr_class = epoch_size // 2

        # Draw the same number of indices from each class so the epoch is
        # balanced regardless of the class ratio in y_in.
        sample_neg = np.random.choice(neg, samples_pr_class, replace=True)
        sample_pos = np.random.choice(pos, samples_pr_class, replace=True)
        concat_ind = np.concatenate((sample_neg, sample_pos), axis=0)

        # Shuffle to avoid labels being in specific order
        # (all negative then positive)
        p = np.random.permutation(len(concat_ind))
        self.sample_ind = concat_ind[p]

        # Sanity check on the resulting label distribution.
        label_dist = np.mean(labels[self.sample_ind])
        assert 0.45 < label_dist < 0.55, "sampled labels are not balanced"