def make_ndarray()

in tensorboard/util/tensor_util.py [0:0]


def make_ndarray(tensor):
    """Create a numpy ndarray from a tensor.

    Create a numpy ndarray with the same shape and data as the tensor.

    Args:
      tensor: A TensorProto.

    Returns:
      A numpy array with the tensor contents.

    Raises:
      TypeError: if tensor has unsupported type.
    """
    shape = [d.size for d in tensor.tensor_shape.dim]
    num_elements = np.prod(shape, dtype=np.int64)
    tensor_dtype = dtypes.as_dtype(tensor.dtype)
    dtype = tensor_dtype.as_numpy_dtype

    if tensor.tensor_content:
        return (
            np.frombuffer(tensor.tensor_content, dtype=dtype)
            .copy()
            .reshape(shape)
        )
    elif tensor_dtype == dtypes.float16 or tensor_dtype == dtypes.bfloat16:
        # the half_val field of the TensorProto stores the binary representation
        # of the fp16: we need to reinterpret this as a proper float16
        if len(tensor.half_val) == 1:
            tmp = np.array(tensor.half_val[0], dtype=np.uint16)
            tmp.dtype = tensor_dtype.as_numpy_dtype
            return np.repeat(tmp, num_elements).reshape(shape)
        else:
            tmp = np.fromiter(tensor.half_val, dtype=np.uint16)
            tmp.dtype = tensor_dtype.as_numpy_dtype
            return tmp.reshape(shape)
    elif tensor_dtype == dtypes.float32:
        if len(tensor.float_val) == 1:
            return np.repeat(
                np.array(tensor.float_val[0], dtype=dtype), num_elements
            ).reshape(shape)
        else:
            return np.fromiter(tensor.float_val, dtype=dtype).reshape(shape)
    elif tensor_dtype == dtypes.float64:
        if len(tensor.double_val) == 1:
            return np.repeat(
                np.array(tensor.double_val[0], dtype=dtype), num_elements
            ).reshape(shape)
        else:
            return np.fromiter(tensor.double_val, dtype=dtype).reshape(shape)
    elif tensor_dtype in [
        dtypes.int32,
        dtypes.uint8,
        dtypes.uint16,
        dtypes.int16,
        dtypes.int8,
        dtypes.qint32,
        dtypes.quint8,
        dtypes.qint8,
        dtypes.qint16,
        dtypes.quint16,
    ]:
        if len(tensor.int_val) == 1:
            return np.repeat(
                np.array(tensor.int_val[0], dtype=dtype), num_elements
            ).reshape(shape)
        else:
            return np.fromiter(tensor.int_val, dtype=dtype).reshape(shape)
    elif tensor_dtype == dtypes.int64:
        if len(tensor.int64_val) == 1:
            return np.repeat(
                np.array(tensor.int64_val[0], dtype=dtype), num_elements
            ).reshape(shape)
        else:
            return np.fromiter(tensor.int64_val, dtype=dtype).reshape(shape)
    elif tensor_dtype == dtypes.string:
        if len(tensor.string_val) == 1:
            return np.repeat(
                np.array(tensor.string_val[0], dtype=dtype), num_elements
            ).reshape(shape)
        else:
            return np.array(
                [x for x in tensor.string_val], dtype=dtype
            ).reshape(shape)
    elif tensor_dtype == dtypes.complex64:
        it = iter(tensor.scomplex_val)
        if len(tensor.scomplex_val) == 2:
            return np.repeat(
                np.array(
                    complex(tensor.scomplex_val[0], tensor.scomplex_val[1]),
                    dtype=dtype,
                ),
                num_elements,
            ).reshape(shape)
        else:
            return np.array(
                [complex(x[0], x[1]) for x in zip(it, it)], dtype=dtype
            ).reshape(shape)
    elif tensor_dtype == dtypes.complex128:
        it = iter(tensor.dcomplex_val)
        if len(tensor.dcomplex_val) == 2:
            return np.repeat(
                np.array(
                    complex(tensor.dcomplex_val[0], tensor.dcomplex_val[1]),
                    dtype=dtype,
                ),
                num_elements,
            ).reshape(shape)
        else:
            return np.array(
                [complex(x[0], x[1]) for x in zip(it, it)], dtype=dtype
            ).reshape(shape)
    elif tensor_dtype == dtypes.bool:
        if len(tensor.bool_val) == 1:
            return np.repeat(
                np.array(tensor.bool_val[0], dtype=dtype), num_elements
            ).reshape(shape)
        else:
            return np.fromiter(tensor.bool_val, dtype=dtype).reshape(shape)
    else:
        raise TypeError("Unsupported tensor type: %s" % tensor.dtype)