in torchtext/data/datasets_utils.py [0:0]
def _wrap_split_argument_with_fn(fn, splits):
    """
    Wrap a dataset function so that its ``split`` argument also accepts a single string.

    The given function is expected to take ``root`` and ``split`` as its first
    two positional parameters, where ``split`` is handled one entry at a time.
    The returned function has a ``split`` argument that accepts either a tuple
    of strings, e.g. ('train', 'valid'), or a single string, e.g. 'train',
    which is then turned into a single-entry tuple. Furthermore, the return
    value of the wrapped function is unpacked if split is only a single string
    to enable behavior such as

        train = AG_NEWS(split='train')
        train, valid = AG_NEWS(split=('train', 'valid'))

    Args:
        fn: Function to wrap. Its first two positional parameters must be
            named ``root`` and ``split``, and it must not declare ``*args``,
            ``**kwargs`` or keyword-only parameters.
        splits: Default value for the ``split`` argument of the returned
            function; also passed to ``_check_default_set`` together with the
            requested split (presumably the set of valid split names -- confirm
            against ``_check_default_set``).

    Returns:
        A function ``new_fn(root, split, **kwargs)`` that calls ``fn`` once per
        entry yielded by ``_check_default_set`` and passes the tuple of results
        through ``_wrap_datasets``.

    Raises:
        ValueError: If ``fn`` does not adhere to the signature described above.
    """
    argspec = inspect.getfullargspec(fn)
    # Compare a slice instead of indexing args[0]/args[1] so that a function
    # with fewer than two positional parameters is rejected with the
    # ValueError below rather than escaping as an IndexError.
    has_standard_signature = (
        argspec.args[:2] == ["root", "split"]
        and argspec.varargs is None
        and argspec.varkw is None
        and not argspec.kwonlyargs
    )
    if not has_standard_signature:
        raise ValueError("Internal Error: Given function {} did not adhere to standard signature.".format(fn))

    @functools.wraps(fn)
    def new_fn(root=_CACHE_DIR, split=splits, **kwargs):
        # fn is invoked once per requested split; extra keyword arguments are
        # forwarded unchanged to every call.
        requested = _check_default_set(split, splits, fn.__name__)
        datasets = tuple(fn(root, item, **kwargs) for item in requested)
        return _wrap_datasets(datasets, split)

    # functools.wraps sets new_fn.__wrapped__, so inspect.signature() follows it
    # and reports fn's own parameters here (not new_fn's **kwargs). The first
    # two entries are therefore fn's `root` and `split`; the remainder are
    # fn's additional parameters, carried over as-is.
    signature = inspect.signature(new_fn)
    params = signature.parameters
    advertised = [
        # NOTE(review): the advertised default for root is '.data' while the
        # runtime default of new_fn is _CACHE_DIR -- confirm this is intended.
        params["root"].replace(default=".data"),
        params["split"].replace(default=splits),
    ]
    advertised.extend(list(params.values())[2:])
    new_fn.__signature__ = signature.replace(parameters=tuple(advertised))
    return new_fn