def split_train_test()

in clutrr/utils/utils.py [0:0]


def split_train_test(args, rows):
    """Randomly split ``rows`` into a training and a testing partition.

    ``int(len(rows) * args.train_test_split)`` row positions are drawn
    without replacement with ``np.random.choice``; those rows form the
    training set and every remaining row forms the testing set. Both
    partitions keep the rows in their original relative order, and the
    last element of every row is dropped from the returned rows.

    Args:
        args: Object exposing a ``train_test_split`` attribute, the fraction
            of rows to place in the training set.
        rows: Indexable sequence of sliceable rows.

    Returns:
        A ``(train_rows, test_rows)`` tuple of lists holding ``row[:-1]``
        for each row in the respective partition.
    """
    indices = range(len(rows))
    num_train = int(len(indices) * args.train_test_split)
    # Keep the exact same random draw so results stay reproducible under a
    # fixed seed; build the membership set once so the partition is O(n).
    train_positions = set(
        np.random.choice(indices, num_train, replace=False).tolist()
    )

    train_rows, test_rows = [], []
    for i in indices:
        # Strip the final column from every returned row.
        (train_rows if i in train_positions else test_rows).append(rows[i][:-1])

    return train_rows, test_rows