def _torch_to_tensor_nd()

in neuron_explainer/activation_server/dst_helpers.py [0:0]


def _torch_to_tensor_nd(x: torch.Tensor) -> TensorND:
    """Wrap a torch tensor of rank 0 to 3 in the matching ``Tensor{0,1,2,3}D`` value.

    A rank-0 tensor is unwrapped to a Python scalar via ``.item()``. Higher
    ranks are built by iterating along the leading dimension and converting
    each innermost row with ``_float_tensor_to_list``. That helper is defined
    elsewhere; presumably it yields a list of Python floats (confirm against
    its definition).

    Args:
        x: The tensor to convert. Only ``x.ndim`` in {0, 1, 2, 3} is supported.

    Returns:
        A ``Tensor0D``, ``Tensor1D``, ``Tensor2D`` or ``Tensor3D`` whose
        ``value`` holds the tensor's contents as (nested) Python data.

    Raises:
        NotImplementedError: If ``x`` has more than three dimensions.
    """
    rank = x.ndim

    def _matrix_rows(matrix: torch.Tensor) -> list:
        # Iterating a tensor yields its slices along dim 0, i.e. one row at a time.
        return [_float_tensor_to_list(row) for row in matrix]

    if rank == 0:
        return Tensor0D(value=x.item())
    if rank == 1:
        return Tensor1D(value=_float_tensor_to_list(x))
    if rank == 2:
        return Tensor2D(value=_matrix_rows(x))
    if rank == 3:
        return Tensor3D(value=[_matrix_rows(matrix) for matrix in x])
    raise NotImplementedError(f"Unknown ndim: {rank}")