def template()

in expanded_checklist/checklist/editor.py [0:0]


    def template(self, templates, nsamples=None,
                 product=True, remove_duplicates=False, mask_only=False,
                 unroll=False, labels=None, meta=False, save=False,
                 **kwargs):
        """Fills in templates

        Steps: find every {tag} in the templates (and non-int labels), look up
        fill-in values in kwargs / editor.lexicons, fill any {mask} tags with a
        masked language model, then generate the filled-in strings.

        Parameters
        ----------
        templates : str, list, tuple, or dict
            On leaves: templates with {tags}, which will be substituted for mapping in **kwargs
            Can have {mask} tags, which will be replaced by a masked language model.
            Other tags can be numbered for distinction, e.g. {person} and {person1} will be considered
            separate tags, but both will use fill-ins for 'person'
        nsamples : int
            Number of samples. If given, every fill-in list is sampled independently
            (via wrapped_random_choice) and the samples are zipped together.
        product : bool
            Only used when nsamples is None. If True, take the cartesian product of
            the fill-in lists; if False, zip them (stopping at the shortest list).
        remove_duplicates : bool
            If True, will not generate any strings where two or more fill-in values are duplicates.
        mask_only : bool
            If True, return only fill-in values for {mask} tokens: the raw options
            returned by ``self.tg.unmask_multiple`` for the first {mask} tag
            processed, truncated to ``nsamples`` (a list, not a MunchWithAdd).
        unroll : bool
            If True, returns list of strings regardless of template type (i.e. unrolls).
            ret.labels and ret.meta (when present) are unrolled in lockstep with ret.data.
        labels : int or object with strings on leaves
            If int, all generated strings will have the same label. Otherwise, can refer
            to tags, or be strings, etc. (labels are filled in with the same values as
            the templates). Output will be in ret.labels
        meta : bool
            If True, ret.meta will contain a dict of fill in values for each item in ret.data
        save : bool
            If True, ret.templates will contain all parameters and fill-in lists

        **kwargs
            Must include fill-in lists for every tag not in editor.lexicons.
            When templates contain {mask} tags, the whole dict is also forwarded to
            ``self.tg.unmask_multiple`` (with ``beam_size`` defaulting to 100).

        Returns
        -------
        MunchWithAdd
            Returns ret, a glorified dict, which will have the filled in templates in ret.data.
            It may contain ret.labels, ret.templates and ret.meta (depending on parameters as noted above)
            You can add or += two MunchWithAdd, which will concatenate values
            (With mask_only=True a plain list of options is returned instead.)

        Notes
        -----
        Every {mask} tag processed calls ``np.random.seed(1)``, i.e. it reseeds
        NumPy's global random state.
        """
        # locals() must be captured before any other local is bound so that
        # `params` holds exactly the call arguments (stored when save=True).
        params = locals()
        ret = MunchWithAdd()
        del params['kwargs']
        del params['self']
        templates = copy.deepcopy(templates)
        added_labels = False
        if labels is not None and type(labels) != int:
            # Non-int labels are filled in together with the templates (so they can
            # reference the same tags) and split back out after generation.
            added_labels = True
            templates = (templates, labels)
        all_keys = find_all_keys(templates)
        items = self._get_fillin_items(all_keys, **kwargs)
        mask_index, mask_options = get_mask_index(templates)

        # Fill every {mask} tag with candidate words from the masked language model.
        for mask, strings in mask_index.items():
            tok = 'VERYLONGTOKENTHATWILLNOTEXISTEVER'
            a_tok = 'thisisaratherlongtokenthatwillnotexist'
            # The current mask tag is formatted into a placeholder token (swapped for
            # the tokenizer's mask token below); the article tag is kept as a tag.
            ks = {mask: tok, a_tok: '{%s}' % a_tok}
            # A 't<N>' mask option keeps only the top N candidates (default 100).
            top = 100
            find_top = re.search(r't(\d+)', mask_options[mask])
            if find_top:
                top = int(find_top.group(1))
            # Tags whose prefix contains an 'a' (e.g. {a:<mask>}) become
            # '{<a_tok>} {<mask>}' so the context sentences are generated with both
            # 'a' and 'an' in front of the masked word.
            sub_a = lambda x: re.sub(r'{[^:}]*a[^:}]*:(%s)}' % mask, r'{%s} {\1}' % a_tok, x)
            strings = recursive_apply(strings, sub_a)
            ts = recursive_format(strings, ks, ignore_missing=True)
            np.random.seed(1)
            samp = self.template(ts, nsamples=5, remove_duplicates=remove_duplicates,
                                 thisisaratherlongtokenthatwillnotexist=['a'], **kwargs).data
            samp += self.template(ts, nsamples=5, remove_duplicates=remove_duplicates,
                                  thisisaratherlongtokenthatwillnotexist=['an'], **kwargs).data
            samp = [x.replace(tok, self.tg.tokenizer.mask_token) for y in samp for x in y][:20]
            # Deduplicate keeping first-seen order: set() order depends on string
            # hash randomisation and would make the LM input vary between runs.
            samp = list(dict.fromkeys(samp))
            kwargs.setdefault('beam_size', 100)
            options = self.tg.unmask_multiple(samp, **kwargs)
            items[mask] = [x[0] for x in options][:top]
            if mask_only:
                return options[:nsamples]
        if save:
            ret.templates = [(params, items)]
        templates = recursive_apply(templates, replace_mask)

        keys = list(items)
        # A scalar fill-in is treated as a single-element list.
        vals = [val if type(val) in (list, tuple) else [val] for val in items.values()]
        if nsamples is not None:
            sampled = [wrapped_random_choice(x, nsamples) for x in vals]
            # With no tags at all, generate exactly one (empty) mapping.
            vals = zip(*sampled) if sampled else [[]]
        elif product:
            vals = itertools.product(*vals)
        else:
            vals = zip(*vals)

        use_meta = meta
        data = []
        metas = []
        for v in vals:
            if remove_duplicates and len(v) != len({str(x) for x in v}):
                continue
            mapping = dict(zip(keys, v))
            data.append(recursive_format(templates, mapping))
            metas.append(mapping)

        label_rows = None
        if added_labels:
            # Each row is (filled_templates, filled_labels). Split them apart BEFORE
            # unrolling so the (template, label) pair itself is never flattened.
            if data:
                data, label_rows = map(list, zip(*data))
            else:
                label_rows = []

        if unroll and data and type(data[0]) in [list, np.ndarray, tuple]:
            flat_data, flat_meta, flat_labels = [], [], []
            for i, row in enumerate(data):
                n = len(row)
                flat_data.extend(row)
                # Every string unrolled from a row shares that row's fill-in mapping.
                flat_meta.extend([metas[i]] * n)
                if label_rows is not None:
                    row_label = label_rows[i]
                    if type(row_label) in (list, tuple) and len(row_label) == n:
                        flat_labels.extend(row_label)  # one label per template
                    else:
                        flat_labels.extend([row_label] * n)  # one label for the row
            data, metas = flat_data, flat_meta
            if label_rows is not None:
                label_rows = flat_labels

        if use_meta:
            ret.meta = metas
        if added_labels:
            ret.labels = label_rows
        elif labels is not None and type(labels) == int:
            ret.labels = [labels] * len(data)
        ret.data = data
        return ret