def main()

in sampling_rcv2.py [0:0]


def main():
    """Parse command-line options and sample every file under the input dir.

    Walks ``--input-dir`` recursively and calls ``generate_samples`` once for
    each file found, passing the file's full path, the output directory, the
    file's base name (used as the dialect name), a uniform class prior over
    the labels ``C``, ``E``, ``G`` and ``M``, and the numeric options from
    the command line.

    Raises:
        ValueError: if ``--input-dir`` or ``--output-dir`` was not supplied.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--threshold',
        default=0.8,
        type=float,
        dest='threshold',
        help='ratio of expected number of examples if uniform prior',
    )
    parser.add_argument(
        '--input-dir',
        dest='input_dir',
        help='directory of rcv2 stories to sample from',
    )
    parser.add_argument(
        '--output-dir',
        dest='output_dir',
        help='directory to store samples',
    )
    parser.add_argument(
        '--num-test',
        default=4000,
        type=int,
        dest='num_test',
        help='number of test examples',
    )
    parser.add_argument(
        '--num-dev',
        default=1000,
        type=int,
        dest='num_dev',
        help='number of dev examples',
    )
    parser.add_argument(
        '--min-num-train',
        default=1000,
        type=int,
        dest='min_num_train',
        help='minimal number of train examples',
    )
    args = parser.parse_args()

    # Both directories are mandatory; fail before doing any other work.
    if args.input_dir is None or args.output_dir is None:
        raise ValueError(
            'Need to provide directory of RCV2 data and output directory.')

    # Uniform prior: each of the four labels gets the same weight.
    labels = ['C', 'E', 'G', 'M']
    class_prior = [0.25, 0.25, 0.25, 0.25]
    class_prior_dict = dict(zip(labels, class_prior))

    # os.walk yields (dirpath, dirnames, filenames); only the files matter
    # here, and each file name doubles as the dialect name.
    for current_path, _, filenames in os.walk(args.input_dir):
        for dialect in filenames:
            generate_samples(
                # os.path.join avoids a doubled separator when the walked
                # directory already ends with one (e.g. "--input-dir data/").
                os.path.join(current_path, dialect),
                args.output_dir,
                dialect,
                class_prior_dict,
                # argparse already converted these via ``type=``.
                args.threshold,
                args.num_test,
                args.num_dev,
                args.min_num_train,
            )
            logger.info("Finished sampling %s", dialect)