in sampling_rcv2.py [0:0]
def main():
    """Parse command-line options and sample every file under ``--input-dir``.

    Walks ``--input-dir`` recursively and calls ``generate_samples`` once per
    file found, passing the file name as the dialect name. Samples are written
    to ``--output-dir``. A uniform class prior over the labels ``C``, ``E``,
    ``G`` and ``M`` is used.

    Raises:
        ValueError: if ``--input-dir`` or ``--output-dir`` is not provided.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--threshold',
        default=0.8,
        type=float,
        dest='threshold',
        help='ratio of expected number of examples if uniform prior',
    )
    parser.add_argument(
        '--input-dir',
        dest='input_dir',
        help='directory of rcv2 stories to sample from',
    )
    parser.add_argument(
        '--output-dir',
        dest='output_dir',
        help='directory to store samples',
    )
    parser.add_argument(
        '--num-test',
        default=4000,
        type=int,
        dest='num_test',
        help='number of test examples',
    )
    parser.add_argument(
        '--num-dev',
        default=1000,
        type=int,
        dest='num_dev',
        help='number of dev examples',
    )
    parser.add_argument(
        '--min-num-train',
        default=1000,
        type=int,
        dest='min_num_train',
        help='minimal number of train examples',
    )
    args = parser.parse_args()

    labels = ['C', 'E', 'G', 'M']
    class_prior = [0.25, 0.25, 0.25, 0.25]
    class_prior_dict = dict(zip(labels, class_prior))

    if args.input_dir is None or args.output_dir is None:
        # ValueError subclasses Exception, so callers catching Exception
        # still catch this.
        raise ValueError(
            'Need to provide directory of RCV2 data and output directory.')

    # argparse has already converted threshold/num_* via `type=`, so the
    # parsed values are passed through unchanged.
    for current_path, dirnames, dialects in os.walk(args.input_dir):
        # Sorting in place makes os.walk descend in a stable order; combined
        # with sorting the file names this gives a reproducible processing
        # order regardless of filesystem listing order.
        dirnames.sort()
        for dialect in sorted(dialects):
            generate_samples(
                os.path.join(current_path, dialect),
                args.output_dir,
                dialect,
                class_prior_dict,
                args.threshold,
                args.num_test,
                args.num_dev,
                args.min_num_train,
            )
            logger.info("Finished sampling %s", dialect)