def generate_text()

in clutrr/relations/puzzle.py [0:0]


    def generate_text(self, stype='story', combination_length=1, templator: Templator = None, edges=None):
        """
        Render a set of edges as templated text and return one candidate row.

        The selected edges are handed to ``comb_indexes`` together with
        ``combination_length``. Every combination it yields is a sequence of
        edge groups; each group is rendered through
        ``templator.replace_template``. A combination is kept only when every
        one of its groups could be rendered (no ``None`` result). One of the
        surviving combinations is then picked with ``random.choice``.

        :param stype: can be story, fact, target, or query. Selects which
            edges of the puzzle are rendered when ``edges`` is not given.
            ``target`` and ``query`` both render the target edge. Note that
            the "no candidate" check for ``'story'`` is applied even when
            ``edges`` is supplied explicitly.
        :param combination_length: the max length of combining the edges for
            text replacement (passed through to ``comb_indexes``)
        :param templator: templator class; although it defaults to ``None``
            it is required as soon as there is at least one edge group to
            render, because ``replace_template`` is called on it
        :param edges: if provided, use these edges instead of stypes
        :return: a list with one rendered string per edge group of the
            chosen combination, or ``[]`` when no combination could be
            rendered completely
        :raises NotImplementedError: ``edges`` is ``None`` and ``stype`` is
            not one of the supported values
        :raises AssertionError: ``stype == 'story'`` and no combination could
            be rendered completely
        :raises TypeError: the templator returned something that is neither
            ``None`` nor a ``str`` for the chosen combination
        """
        if edges is not None:
            edges_to_convert = edges
        elif stype == 'story':
            edges_to_convert = copy.copy(self.story)
        elif stype == 'fact':
            # flatten the per-fact edge lists into a single list of edges
            edges_to_convert = [edge for fact in self.facts for edge in fact.fact_edges]
        elif stype in ('target', 'query'):
            # both the relation (solution) and the question are derived from
            # the target edge
            edges_to_convert = [copy.copy(self.target_edge)]
        else:
            raise NotImplementedError("stype not implemented")

        generated_rows = []
        for comb_group in comb_indexes(edges_to_convert, combination_length):
            # one template key per edge group: its relations joined by '-'
            relation_keys = ['-'.join([self.get_edge_relation(edge) for edge in edge_group])
                             for edge_group in comb_group]
            # presumably the template keys use the misspelling 'neice', so
            # map the correct spelling back onto it before the lookup
            relation_keys = [key.replace('niece', 'neice') for key in relation_keys]
            # flatten every edge group into the list of entities it mentions
            group_entities = [[ent for edge in edge_group for ent in edge]
                              for edge_group in comb_group]
            rendered = [templator.replace_template(key, entities)
                        for key, entities in zip(relation_keys, group_entities)]
            # if contains None, then reject this combination
            if all(text is not None for text in rendered):
                generated_rows.append(rendered)

        # Drop empty combinations and order the rest by number of groups.
        # The final pick is random, not the shortest one; the (stable) sort
        # only fixes the candidate order, which keeps seeded runs
        # reproducible.
        candidates = sorted((row for row in generated_rows if row), key=len)

        if not candidates:
            if stype == 'story':
                raise AssertionError(
                    "no combination of story edges could be rendered with the given templator")
            return []

        generated_row = random.choice(candidates)
        for text in generated_row:
            if not isinstance(text, str):
                # previously this dropped into an ipdb breakpoint; fail
                # loudly instead so non-interactive runs get a clear error
                raise TypeError(
                    "templator.replace_template returned %s, expected str"
                    % type(text).__name__)
        return generated_row