# def generate_rows()
#
# in clutrr/generator.py [0:0]


def generate_rows(args, store, task_name, split=0.8, prev_patterns=None):
    """Generate puzzles, split them into train/test and render one row each.

    The function runs in three phases:

    1. Build puzzles with a ``RelationBuilder`` until at least
       ``args.num_rows`` puzzles have been collected.
    2. Group the puzzles by ``relation_comb`` (the "pattern") and assign every
       puzzle id to either the train or the test split.
    3. Render the text of every puzzle with the templator of its split and
       append one row per puzzle.

    Args:
        args: Configuration object. The attributes read here are
            ``combination_length``, ``relation_length``, ``use_mturk_template``,
            ``template_split``, ``template_file``, ``num_rows``, ``equal``,
            ``unique_test_pattern`` and ``holdout``.
        store: Object exposing ``relations_store`` (nested mapping whose leaf
            entries carry ``'rel'`` and ``'p'`` keys) and
            ``question_store['relational']``.
        task_name: Value copied verbatim into the ``task_name`` column.
        split: Fraction of the patterns assigned to train. ``split == 0`` is
            treated as "test task" by the ``unique_test_pattern`` pruning.
        prev_patterns: Optional mapping indexed as
            ``prev_patterns[args.relation_length]['test']``; only consulted
            when ``args.unique_test_pattern`` is set and ``split == 0``.
            ``None`` disables the pruning.

    Returns:
        A tuple ``(columns, rows, all_puzzles, train_patterns, test_patterns)``
        where ``columns`` names the entries of each row in ``rows``.

    Raises:
        NotImplementedError: If the requested combination length is not
            supported by the selected templating mode.
        AssertionError: If a generated puzzle ends up in neither split.
    """
    # pre-flight checks
    combination_length = min(args.combination_length, args.relation_length)
    if not args.use_mturk_template:
        if combination_length > 1:
            raise NotImplementedError("combination of two or more relations not implemented in Synthetic templating")
    elif combination_length > 3:
        raise NotImplementedError("combinations of > 3 not implemented in AMT Templating")

    # generate
    print(args.relation_length)
    print("Loading templates...")
    all_puzzles = {}
    if args.template_split:
        train_template_path = args.template_file + '.train.json'
        test_template_path = args.template_file + '.test.json'
    else:
        train_template_path = test_template_path = args.template_file + '.json'
    # NOTE(review): the files are read even in synthetic mode, where the
    # result is discarded below; kept so a missing file still fails early.
    # Each split gets its own freshly parsed object, and the handles are
    # closed deterministically.
    with open(train_template_path) as template_fp:
        train_templates = json.load(template_fp)
    with open(test_template_path) as template_fp:
        test_templates = json.load(template_fp)

    if args.use_mturk_template:
        templatorClass = TemplatorAMT
    else:
        synthetic_templates_per_rel = {}
        for val in store.relations_store.values():
            for gv in val.values():
                synthetic_templates_per_rel[gv['rel']] = gv['p']
        templatorClass = TemplatorSynthetic
        train_templates = synthetic_templates_per_rel
        test_templates = synthetic_templates_per_rel

    # Build a mapping from ANY relation to the SAME list of sentences for asking queries
    query_templates = {}
    for val in store.relations_store.values():
        for gv in val.values():
            query_templates[gv['rel']] = store.question_store['relational']
    query_templator_class = TemplatorSynthetic

    pb = tqdm(total=args.num_rows)
    stories_left = args.num_rows
    columns = ['id', 'story', 'query', 'text_query', 'target', 'text_target', 'clean_story', 'proof_state', 'f_comb',
               'task_name', 'story_edges', 'edge_types', 'query_edge', 'genders', 'syn_story', 'node_mapping',
               'task_split']
    f_comb_count = {}
    rows = []
    # A single Ancestry object is created and re-used for every puzzle.
    anc_num = 1
    anc = Ancestry(args, store)
    rb = RelationBuilder(args, store, anc)

    # The pruning below only applies for the test task and only when a
    # non-empty set of previous test patterns is available. Guarding against
    # ``None`` avoids a TypeError with the default argument value.
    restrict_to_prev_test = (
        args.unique_test_pattern
        and split == 0
        and prev_patterns is not None
        and len(prev_patterns) > 0
        and len(prev_patterns[args.relation_length]['test']) > 0
    )

    while stories_left > 0:
        status = rb.build()
        if not status:
            rb.reset_puzzle()
            rb.anc.next_flip()
            continue
        rb.add_facts()
        # keeping a count of generated patterns to make sure we have homogenous distribution
        if f_comb_count and args.equal:
            min_c = min(f_comb_count.values())
            weight = {k: (min_c / v) for k, v in f_comb_count.items()}
            rb.generate_puzzles(weight)
        else:
            rb.generate_puzzles()
        if restrict_to_prev_test:
            # only keep puzzles whose pattern was already a test pattern
            allowed_patterns = prev_patterns[args.relation_length]['test']
            todel = [pid for pid, puzzle in rb.puzzles.items()
                     if puzzle.relation_comb not in allowed_patterns]
            for pid in todel:
                del rb.puzzles[pid]
        # account for every puzzle that survived
        for puzzle in rb.puzzles.values():
            f_comb_count[puzzle.relation_comb] = f_comb_count.get(puzzle.relation_comb, 0) + 1
            pb.update(1)
            stories_left -= 1
        # store the puzzles
        all_puzzles.update(rb.puzzles)
        rb.reset_puzzle()
        rb.anc.next_flip()
    pb.close()

    print("Puzzles created. Now splitting train and test on pattern level")
    print("Number of unique puzzles : {}".format(len(all_puzzles)))
    # pattern -> list of puzzle ids, in generation order
    pattern_puzzles = {}
    for pid, pz in all_puzzles.items():
        pattern_puzzles.setdefault(pz.relation_comb, []).append(pid)
    print("Number of unique patterns : {}".format(len(pattern_puzzles)))
    train_puzzles = []
    test_puzzles = []
    sp = int(len(pattern_puzzles) * split)
    all_patterns = list(pattern_puzzles.keys())

    # NOTE(review): the flag name reads inverted. When it is False (holdout
    # requested and relation_length != 2) the split is strictly
    # pattern-disjoint; when it is True the test patterns also appear in train.
    no_pattern_overlap = not args.holdout
    # if k=2, then set no_pattern_overlap=True
    if args.relation_length == 2:
        no_pattern_overlap = True

    if not no_pattern_overlap:
        # strict split on pattern level: a pattern lives in exactly one split
        train_patterns = all_patterns[:sp]
        test_patterns = all_patterns[sp:]
        train_puzzles.extend(pid for pattern in train_patterns for pid in pattern_puzzles[pattern])
        test_puzzles.extend(pid for pattern in test_patterns for pid in pattern_puzzles[pattern])
    else:
        # Patterns overlap between the splits (the templators may still
        # differ): puzzles of the trailing patterns are divided between train
        # and test, all other patterns go to train entirely.
        train_patterns = all_patterns
        test_patterns = all_patterns[sp:]
        shared_patterns = set(test_patterns)
        for pattern in all_patterns:
            pz = pattern_puzzles[pattern]
            if pattern in shared_patterns:
                # now split - hacky way
                # NOTE(review): for split < 0.2 the factor is negative, so
                # ``sz`` becomes a negative slice index; confirm this is intended.
                sz = int(len(pz) * (split - 0.2))
                train_puzzles.extend(pz[:sz])
                test_puzzles.extend(pz[sz:])
            else:
                train_puzzles.extend(pz)

    print("# Train puzzles : {}".format(len(train_puzzles)))
    print("# Test puzzles : {}".format(len(test_puzzles)))
    # Sets give O(1) membership tests in the rendering loop below.
    train_ids = set(train_puzzles)
    test_ids = set(test_puzzles)
    pb = tqdm(total=len(all_puzzles))
    # saving in csv
    for pid, puzzle in all_puzzles.items():
        if pid in train_ids:
            task_split = 'train'
            split_templates = train_templates
        elif pid in test_ids:
            task_split = 'test'
            split_templates = test_templates
        else:
            # The exception must actually be raised; otherwise rendering would
            # continue with an undefined (or stale) templator.
            raise AssertionError("pid must be either in train or test")
        templator = templatorClass(templates=split_templates, family=puzzle.anc.family_data)
        story_text = puzzle.generate_text(stype='story', combination_length=combination_length, templator=templator)
        fact_text = puzzle.generate_text(stype='fact', combination_length=combination_length, templator=templator)
        # shuffle story and fact sentences together for the final story
        story = story_text + fact_text
        story = random.sample(story, len(story))
        story = ' '.join(story)
        clean_story = ' '.join(story_text)
        target_text = puzzle.generate_text(stype='target', combination_length=1, templator=templator)

        story_key_edges = puzzle.get_story_relations(stype='story') + puzzle.get_story_relations(stype='fact')
        # Build query text
        query_templator = query_templator_class(templates=query_templates, family=puzzle.anc.family_data)
        query_text = puzzle.generate_text(stype='query', combination_length=1, templator=query_templator)
        query_text = ' '.join(query_text)
        query_text = query_text.replace('?.', '?')  # remove trailing '.'
        # The sorted edges are read only after the node ids were converted.
        puzzle.convert_node_ids(stype='story')
        puzzle.convert_node_ids(stype='fact')
        story_keys_changed_ids = (puzzle.get_sorted_story_edges(stype='story')
                                  + puzzle.get_sorted_story_edges(stype='fact'))
        query_edge = puzzle.get_sorted_query_edge()

        genders = puzzle.get_name_gender_string()

        rows.append([pid, story, puzzle.query_text, query_text, puzzle.target_edge_rel, target_text,
                     clean_story, puzzle.proof_trace, puzzle.relation_comb, task_name, story_keys_changed_ids,
                     story_key_edges, query_edge, genders, '', puzzle.story_sort_dict, task_split])
        pb.update(1)
    pb.close()

    print("{} ancestries created".format(anc_num))
    print("Number of unique patterns : {}".format(len(f_comb_count)))
    return columns, rows, all_puzzles, train_patterns, test_patterns