in expanded_checklist/checklist/editor.py [0:0]
def template(self, templates, nsamples=None,
             product=True, remove_duplicates=False, mask_only=False,
             unroll=False, labels=None, meta=False, save=False,
             **kwargs):
    """Fill in templates with fill-in lists and masked-language-model suggestions.

    Parameters
    ----------
    templates : str, list, tuple, or dict
        On leaves: templates with {tags}, which will be substituted for mapping in **kwargs
        Can have {mask} tags, which will be replaced by a masked language model.
        Other tags can be numbered for distinction, e.g. {person} and {person1} will be considered
        separate tags, but both will use fill-ins for 'person'
    nsamples : int
        Number of samples. When given, values are sampled per tag instead of
        enumerating combinations (``product`` is then ignored).
    product : bool
        If True, take the cartesian product of the fill-in lists; otherwise
        the lists are zipped together.
    remove_duplicates : bool
        If True, will not generate any strings where two or more fill-in values are duplicates
        (values are compared by their ``str`` form).
    mask_only : bool
        If True, return only fill-in values for {mask} tokens
    unroll : bool
        If True, returns list of strings regardless of template type (i.e. unrolls)
    labels : int or object with strings on leaves
        If int, all generated strings will have the same label. Otherwise, can refer
        to tags, or be strings, etc. Output will be in ret.labels
    meta : bool
        If True, ret.meta will contain a dict of fill in values for each item in ret.data
    save : bool
        If True, ret.templates will contain all parameters and fill-in lists
    **kwargs : type
        Must include fill-in lists for every tag not in editor.lexicons.
        Also forwarded to the masked language model call; ``beam_size``
        defaults to 100 when a mask is present and it was not supplied.

    Returns
    -------
    MunchWithAdd
        Returns ret, a glorified dict, which will have the filled in templates in ret.data.
        It may contain ret.labels, ret.templates and ret.meta (depending on parameters as noted above)
        You can add or += two MunchWithAdd, which will concatenate values
        If ``mask_only`` is True, the first ``nsamples`` masked-language-model
        options of the last processed mask are returned instead (an empty
        list when the templates contain no mask).
    """
    # Snapshot the call arguments before any other local is created; this is
    # what `save=True` stores. A copy is taken so the snapshot never aliases
    # the frame's own locals mapping.
    params = dict(locals())
    del params['kwargs']
    del params['self']

    ret = MunchWithAdd()
    templates = copy.deepcopy(templates)

    # Exact type checks are kept on purpose so that bools and other int-like
    # objects keep being treated as structured labels, as before.
    int_label = type(labels) is int
    added_labels = labels is not None and not int_label
    if added_labels:
        # Labels are formatted alongside the templates and split off again
        # once everything has been generated.
        templates = (templates, labels)

    all_keys = find_all_keys(templates)
    items = self._get_fillin_items(all_keys, **kwargs)
    mask_index, mask_options = get_mask_index(templates)

    # Placeholders that stand in for the mask and for an 'a'/'an' article
    # while the surrounding context is being sampled.
    tok = 'VERYLONGTOKENTHATWILLNOTEXISTEVER'
    a_tok = 'thisisaratherlongtokenthatwillnotexist'
    # Defined up front so `mask_only` is well-behaved without any mask tag.
    options = []
    for mask, strings in mask_index.items():
        # An option string containing t<N> limits the suggestions kept to N.
        top = 100
        find_top = re.search(r't(\d+)', mask_options[mask])
        if find_top:
            top = int(find_top.group(1))

        # A tag for this mask whose option prefix contains 'a' gets an
        # article placeholder inserted in front of it, so that contexts can
        # be sampled with both 'a' and 'an'.
        article_pattern = r'{[^:}]*a[^:}]*:(%s)}' % mask
        article_repl = r'{%s} {\1}' % a_tok
        strings = recursive_apply(
            strings, lambda x: re.sub(article_pattern, article_repl, x))

        ks = {mask: tok, a_tok: '{%s}' % a_tok}
        ts = recursive_format(strings, ks, ignore_missing=True)

        # Fixed seed so the sampled contexts do not vary between calls.
        np.random.seed(1)
        samp = self.template(ts, nsamples=5, remove_duplicates=remove_duplicates,
                             **{a_tok: ['a']}, **kwargs).data
        samp += self.template(ts, nsamples=5, remove_duplicates=remove_duplicates,
                              **{a_tok: ['an']}, **kwargs).data

        mask_token = self.tg.tokenizer.mask_token
        samp = [x.replace(tok, mask_token) for y in samp for x in y][:20]
        # De-duplicate while preserving order; a plain set() would make the
        # order (and thus the model input) depend on string hashing.
        samp = list(dict.fromkeys(samp))

        kwargs.setdefault('beam_size', 100)
        options = self.tg.unmask_multiple(samp, **kwargs)
        items[mask] = [x[0] for x in options][:top]

    if mask_only:
        return options[:nsamples]
    if save:
        ret.templates = [(params, items)]

    templates = recursive_apply(templates, replace_mask)

    keys = list(items.keys())
    # Scalars are wrapped so every tag has a sequence of candidate values.
    vals = [v if type(v) in (list, tuple) else [v] for v in items.values()]
    if nsamples is not None:
        sampled = [wrapped_random_choice(x, nsamples) for x in vals]
        # With no tags at all there is still exactly one (empty) assignment.
        vals = zip(*sampled) if sampled else [[]]
    elif product:
        vals = itertools.product(*vals)
    else:
        vals = zip(*vals)

    data = []
    all_meta = []
    for v in vals:
        if remove_duplicates and len(v) != len({str(x) for x in v}):
            continue
        mapping = dict(zip(keys, v))
        data.append(recursive_format(templates, mapping))
        all_meta.append(mapping)

    # NOTE(review): with non-int labels each data entry is a
    # (templates, labels) tuple and is itself flattened here; confirm whether
    # combining `unroll` with structured labels is meant to be supported.
    if unroll and data and type(data[0]) in (list, np.ndarray, tuple):
        # Repeat each mapping once per unrolled item so that meta stays
        # aligned with data (items of one group share the same dict).
        all_meta = [m for m, group in zip(all_meta, data) for _ in group]
        data = [x for group in data for x in group]

    if meta:
        ret.meta = all_meta
    if added_labels:
        if data:
            data, labels = map(list, zip(*data))
        else:
            # Nothing was generated (e.g. everything filtered out).
            labels = []
        ret.labels = labels
    elif int_label:
        ret.labels = [labels] * len(data)
    ret.data = data
    return ret