in reagent/core/types.py [0:0]
def __getattr__(self, attr):
    """Broadcast a ``torch.Tensor`` method call over the fields of this object.

    Python only calls ``__getattr__`` when normal attribute lookup fails. If
    ``attr`` names a callable attribute of ``torch.Tensor``, this returns a
    function. When called, that function applies ``attr(*args, **kwargs)`` to
    every ``torch.Tensor`` / ``TensorDataClass`` value found in
    ``self.__dict__``, recursing into dicts and tuples and passing every other
    value through unchanged. It then builds a new ``type(self)`` from the
    results, passed as keyword arguments.

    Args:
        attr: Name of the attribute being looked up.

    Returns:
        A callable ``continuation(*args, **kwargs)`` that returns a new
        instance of ``type(self)``.

    Raises:
        AttributeError: If ``attr`` is a dunder name, or ``torch.Tensor`` has
            no attribute named ``attr``.
        RuntimeError: If ``torch.Tensor`` has an attribute named ``attr`` but
            it is not callable. Note that ``hasattr()`` does not swallow this.
    """
    # Never broadcast dunder lookups. Protocol probes (copy/pickle/etc.) must
    # see a plain AttributeError rather than a wrapper function.
    if attr.startswith("__") and attr.endswith("__"):
        raise AttributeError(
            f"{self.__class__.__name__} doesn't have {attr} attribute."
        )
    tensor_attr = getattr(torch.Tensor, attr, None)
    if tensor_attr is None or not callable(tensor_attr):
        # TODO: can we get this working well with jupyter?
        # NOTE(review): this logs on every failed lookup, including hasattr()
        # probes, which is presumably why interactive tools are noisy here.
        logger.error(
            "Attempting to call %s.%s on %s (instance of TensorDataClass).",
            self.__class__.__name__,
            attr,
            type(self),
        )
        if tensor_attr is None:
            raise AttributeError(
                f"{self.__class__.__name__} doesn't have {attr} attribute."
            )
        raise RuntimeError(f"{self.__class__.__name__}.{attr} is not callable.")

    def continuation(*args, **kwargs):
        def f(v):
            # If possible, return v.attr(*args, **kwargs); otherwise return v.
            # Dicts and tuples are rebuilt with each element processed.
            if isinstance(v, (torch.Tensor, TensorDataClass)):
                return getattr(v, attr)(*args, **kwargs)
            elif isinstance(v, dict):
                return {kk: f(vv) for kk, vv in v.items()}
            elif isinstance(v, tuple):
                return tuple(f(vv) for vv in v)
            return v

        # Rebuild from instance fields; assumes every key in __dict__ is an
        # accepted constructor keyword (true for plain dataclass fields).
        return type(self)(**f(self.__dict__))

    return continuation