def reshape_undo()

in lib/misc.py [0:0]


def reshape_undo(inp, before, after, *, undo=None, known=None, **kwargs):
    """
    Reshape ``inp`` from the ``before`` pattern to the ``after`` pattern.

    Returns the reshaped array together with a function that maps data in
    the ``after`` layout back to the ``before`` layout.

    Usage:
    x_Bhwse, undo = reshape_undo(
        x_bthwe,
        'b, t, ..., stride*e',
        'b*t, ..., stride, e',
        stride=7
    )
    x_Bhwse = do_some_stuff(x_Bhwse)
    x_bthwe = undo(x_Bhwse)

    Dimension sizes only need to be supplied (as keywords) when they cannot
    be inferred from ``inp.shape``. In the example above only ``stride`` is
    given; ``b``, ``t`` and ``e`` follow from the input shape once
    ``stride`` is known.

    Args:
        inp: a ``th.Tensor`` or ``np.ndarray``.
        before: pattern for the current shape of ``inp``. This is either a
            comma-separated string (which may contain ``...``) or a list.
            It must have the same type as ``after``.
        after: pattern for the target shape.
        undo: optional existing undo function. It is combined with the new
            inverse via ``compose_undo``.
        known: optional dict of dimension sizes. It is merged with
            ``kwargs``; on a name clash, the value in ``known`` wins.
        **kwargs: additional dimension sizes.

    Returns:
        ``(reshaped, undo_fn)`` where ``reshaped`` is ``inp.reshape(...)``
        with the grounded ``after`` dimensions.

    Raises:
        AssertionError: raised by this function's own checks in these cases:
            - ``before`` and ``after`` have different types.
            - ``inp`` is not a tensor or ndarray.
            - The pattern is neither a str nor a list.
            - ``inp.shape`` does not match the grounded ``before`` pattern.
            - The element counts of the two patterns differ.
        The helper functions called here may raise as well.
    """
    bindings = {**kwargs, **known} if known else kwargs

    assert type(before) is type(after), f"{type(before)} != {type(after)}"
    assert isinstance(inp, (th.Tensor, np.ndarray)), f"require tensor or ndarray but got {type(inp)}"
    assert isinstance(before, (str, list)), f"require str or list but got {type(before)}"

    if isinstance(before, str):
        # Parse both string patterns (before first, then after), then expand
        # any ellipsis against the actual input.
        before, after = _handle_ellipsis(
            inp,
            _parse_reshape_str(before, "before"),
            _parse_reshape_str(after, "after"),
        )

    # Keep the symbolic (parsed) patterns; the inverse reshape reuses them.
    src_pattern, dst_pattern = before, after

    src_dims, bindings = _infer(known=bindings, desc=src_pattern, shape=inp.shape)
    total = product(inp.shape)
    src_dims = _ground(src_dims, bindings, total)
    dst_dims = _ground(dst_pattern, bindings, total)

    # Names carrying the NO_BIND prefix are not forwarded to the inverse.
    bindings = {name: size for name, size in bindings.items() if not name.startswith(NO_BIND)}

    assert tuple(inp.shape) == tuple(src_dims), f"expected shape {src_dims} but got shape {inp.shape}"
    assert total == product(
        dst_dims
    ), f"cannot reshape {inp.shape} to {dst_dims} because the number of elements does not match"

    def _inverse(arr):
        # Swap the patterns to go from the `after` layout back to `before`.
        return reshape(arr, dst_pattern, src_pattern, known=bindings)

    return inp.reshape(dst_dims), compose_undo(undo, _inverse)