in lib/misc.py [0:0]
def reshape_undo(inp, before, after, *, undo=None, known=None, **kwargs):
    """
    Reshape ``inp`` from the ``before`` layout to the ``after`` layout and
    return, alongside the result, a callable that reverses the reshape.

    Example:
        x_Bhwse, undo = reshape_undo(
            x_bthwe,
            'b, t, ..., stride*e',
            'b*t, ..., stride, e',
            stride=7
        )
        x_Bhwse = do_some_stuff(x_Bhwse)
        x_bthwe = undo(x_Bhwse)

    Dimension sizes only need to be supplied (as keyword arguments or through
    the ``known`` dict) when they cannot be derived from ``inp.shape``. In the
    example above ``stride`` has to be given, whereas ``b``, ``t`` and ``e``
    follow from the shape once ``stride`` is fixed.

    Args:
        inp: a ``th.Tensor`` or ``np.ndarray``.
        before: description of the current shape, either a str or a list.
        after: description of the target shape; must have the same type as
            ``before``.
        undo: optional previously obtained undo callable; it is handed to
            ``compose_undo`` together with the inverse of this reshape.
        known: optional dict of dimension sizes. When non-empty it is merged
            with ``**kwargs`` and wins on key clashes.
        **kwargs: dimension sizes given by name.

    Returns:
        A 2-tuple ``(reshaped, undo_fn)`` where ``reshaped`` is
        ``inp.reshape(...)`` with the resolved target shape and ``undo_fn`` is
        the result of ``compose_undo(undo, <inverse reshape>)``.

    Raises:
        AssertionError: if the argument types are not accepted, if ``inp.shape``
            does not match the resolved ``before`` shape, or if the element
            counts of the two shapes differ.
    """
    # An empty/None ``known`` means the keyword arguments are used as-is.
    bindings = {**kwargs, **known} if known else kwargs

    assert type(before) is type(after), f"{type(before)} != {type(after)}"
    assert isinstance(inp, (th.Tensor, np.ndarray)), f"require tensor or ndarray but got {type(inp)}"
    assert isinstance(before, (str, list)), f"require str or list but got {type(before)}"

    if isinstance(before, str):
        before = _parse_reshape_str(before, "before")
        after = _parse_reshape_str(after, "after")
        before, after = _handle_ellipsis(inp, before, after)

    # Remember the descriptions as they are at this point; the inverse
    # reshape is expressed with them rather than with the grounded shapes.
    symbolic_before = before
    symbolic_after = after

    before, bindings = _infer(known=bindings, desc=before, shape=inp.shape)
    before = _ground(before, bindings, product(inp.shape))
    after = _ground(after, bindings, product(inp.shape))

    # Entries whose name starts with NO_BIND are not forwarded to the inverse.
    exported = {}
    for name, size in bindings.items():
        if not name.startswith(NO_BIND):
            exported[name] = size

    assert tuple(inp.shape) == tuple(before), f"expected shape {before} but got shape {inp.shape}"
    assert product(inp.shape) == product(
        after
    ), f"cannot reshape {inp.shape} to {after} because the number of elements does not match"

    def _invert(inp):
        # Same reshape with source and target descriptions swapped.
        return reshape(inp, symbolic_after, symbolic_before, known=exported)

    reshaped = inp.reshape(after)
    return reshaped, compose_undo(undo, _invert)