def infer_dtype()

in core/maxframe/tensor/utils.py [0:0]


def infer_dtype(np_func, multi_outputs=False, empty=True, reverse=False, check=True):
    """
    Decorator factory that fills in the ``dtype`` keyword of a tensor function
    by evaluating ``np_func`` on small NumPy stand-ins of its arguments.

    Parameters
    ----------
    np_func : callable
        NumPy counterpart of the decorated function; it is called only to
        discover the result dtype.
    multi_outputs : bool
        If True, ``np_func`` returns a sequence of outputs and the dtype of
        the first one is used.
    empty : bool
        If True, tensor-like arguments are replaced by uninitialized arrays of
        shape ``(1,) * max(1, ndim)`` with the same dtype. Otherwise the first
        element of the argument (or of ``arg.op.data`` when present) is used.
    reverse : bool
        If True, the positional stand-ins are passed to ``np_func`` in reverse
        order, and no reflected ``r<name>`` fallback is attempted.
    check : bool
        If True and the user passed ``dtype``, raise ``TypeError`` when the
        inferred dtype cannot be cast to it (using the ``casting`` keyword
        when one is given).

    Notes
    -----
    An argument counts as tensor-like when it has ``ndim`` and ``dtype`` and
    does not implement ``__tensor_ufunc__``. Inference is skipped when any
    argument implements ``__tensor_ufunc__``. Inference is best effort: any
    exception raised by ``np_func`` leaves the inferred dtype as ``None``.

    The wrapped function always receives a ``dtype`` keyword. It is the
    user-supplied dtype when one was given (``dtype=None`` counts as "not
    given"). Otherwise it is the inferred dtype, which may be ``None``.

    If the wrapped function returns ``NotImplemented`` and ``reverse`` is
    False, a function named ``r<name>`` in the same module (if any) is called
    with the positional arguments reversed. If that is unavailable or also
    returns ``NotImplemented``, ``TypeError`` is raised.
    """

    def make_arg(arg):
        # Build a tiny NumPy stand-in so np_func can be evaluated cheaply.
        if empty:
            return np.empty((1,) * max(1, arg.ndim), dtype=arg.dtype)
        if hasattr(arg, "op") and hasattr(arg.op, "data"):
            arg = arg.op.data
        return arg[(0,) * max(1, arg.ndim)]

    tensor_ufunc = "__tensor_ufunc__"

    def is_arg(arg):
        # Objects implementing __tensor_ufunc__ (e.g. maxframe DataFrames)
        # handle ufuncs themselves and are not treated as plain tensors.
        if hasattr(arg, tensor_ufunc):
            return False
        return hasattr(arg, "ndim") and hasattr(arg, "dtype")

    def inner(func):
        @wraps(func)
        def h(*tensors, **kw):
            # ``dtype=None`` means "unspecified", as for NumPy ufuncs;
            # np.dtype(None) would otherwise silently turn it into float64.
            usr_dtype = kw.pop("dtype", None)
            if usr_dtype is not None:
                usr_dtype = np.dtype(usr_dtype)

            args = [make_arg(t) if is_arg(t) else t for t in tensors]
            if reverse:
                args = args[::-1]
            np_kw = {
                k: make_arg(v) if hasattr(v, "op") else v
                for k, v in kw.items()
                if is_arg(v) and k != "out"
            }

            dtype = None
            if not any(
                hasattr(arg, tensor_ufunc)
                for arg in itertools.chain(args, np_kw.values())
            ):
                # skip inference if we encounter maxframe DataFrame etc.
                # that implement __tensor_ufunc__
                try:
                    with np.errstate(all="ignore"):
                        res = np_func(*args, **np_kw)
                        dtype = (res[0] if multi_outputs else res).dtype
                except Exception:  # best effort: inference is optional
                    dtype = None

            if usr_dtype is not None:
                if check and dtype is not None:
                    can_cast_kwargs = {}
                    if kw.get("casting") is not None:
                        can_cast_kwargs["casting"] = kw["casting"]
                    if not np.can_cast(dtype, usr_dtype, **can_cast_kwargs):
                        raise TypeError(
                            "No loop matching the specified signature "
                            f"and casting was found for ufunc {np_func}"
                        )
                # Honour the user's dtype even when inference was impossible.
                kw["dtype"] = usr_dtype
            else:
                kw["dtype"] = dtype

            ret = func(*tensors, **kw)
            if ret is NotImplemented:
                reverse_func = (
                    None
                    if reverse
                    else getattr(inspect.getmodule(func), f"r{func.__name__}", None)
                )
                if reverse_func is not None:
                    ret = reverse_func(*tensors[::-1], **kw)
                if ret is NotImplemented:
                    # Works for any number of operands (format() with fixed
                    # {1}/{2} slots raised IndexError for unary calls).
                    operand_types = " and ".join(f"'{type(t)}'" for t in tensors)
                    raise TypeError(
                        f"unsupported operator type(s) for {func.__name__}: "
                        f"{operand_types}"
                    )
            return ret

        return h

    return inner