def solve()

in tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_block_diag.py [0:0]


  def solve(self, rhs, adjoint=False, adjoint_arg=False, name="solve"):
    """Solve (exact or approx) `R` (batch) systems of equations: `A X = rhs`.

    The returned `Tensor` will be close to an exact solution if `A` is well
    conditioned. Otherwise closeness will vary. See class docstring for details.

    Examples:

    ```python
    # Make an operator acting like batch matrix A.  Assume tensor_shape.TensorShape(A.shape) = [..., M, N]
    operator = LinearOperator(...)
    tensor_shape.TensorShape(operator.shape) = [..., M, N]

    # Solve R > 0 linear systems for every member of the batch.
    RHS = ... # shape [..., M, R]

    X = operator.solve(RHS)
    # X[..., :, r] is the solution to the r'th linear system
    # sum_j A[..., :, j] X[..., j, r] = RHS[..., :, r]

    operator.matmul(X)
    ==> RHS
    ```

    Args:
      rhs: `Tensor` with same `dtype` as this operator and compatible shape,
        or a list of `Tensor`s (for blockwise operators). `Tensor`s are treated
        like a [batch] matrices meaning for every set of leading dimensions, the
        last two dimensions defines a matrix.
        See class docstring for definition of compatibility.
      adjoint: Python `bool`.  If `True`, solve the system involving the adjoint
        of this `LinearOperator`:  `A^H X = rhs`.
      adjoint_arg:  Python `bool`.  If `True`, solve `A X = rhs^H` where `rhs^H`
        is the hermitian transpose (transposition and complex conjugation).
      name:  A name scope to use for ops added by this method.

    Returns:
      `Tensor` with shape `[...,N, R]` and same `dtype` as `rhs`.

    Raises:
      NotImplementedError:  If `self.is_non_singular` or `is_square` is False.
    """
    if self.is_non_singular is False:
      raise NotImplementedError(
          "Exact solve not implemented for an operator that is expected to "
          "be singular.")
    if self.is_square is False:
      raise NotImplementedError(
          "Exact solve not implemented for an operator that is expected to "
          "not be square.")

    def _check_operators_agree(r, l, message):
      if (r.range_dimension is not None and
          l.domain_dimension is not None and
          r.range_dimension != l.domain_dimension):
        raise ValueError(message)

    if isinstance(rhs, linear_operator.LinearOperator):
      left_operator = self.adjoint() if adjoint else self
      right_operator = rhs.adjoint() if adjoint_arg else rhs

      _check_operators_agree(
          right_operator, left_operator,
          "Operators are incompatible. Expected `x` to have dimension"
          " {} but got {}.".format(
              left_operator.domain_dimension, right_operator.range_dimension))

      # We can efficiently solve BlockDiag LinearOperators if the number of
      # blocks agree.
      if isinstance(right_operator, LinearOperatorBlockDiag):
        if len(left_operator.operators) != len(right_operator.operators):
          raise ValueError(
              "Can not efficiently solve `LinearOperatorBlockDiag` when "
              "number of blocks differ.")

        for o1, o2 in zip(left_operator.operators, right_operator.operators):
          _check_operators_agree(
              o2, o1,
              "Blocks are incompatible. Expected `x` to have dimension"
              " {} but got {}.".format(
                  o1.domain_dimension, o2.range_dimension))

      with self._name_scope(name):  # pylint: disable=not-callable
        return linear_operator_algebra.solve(left_operator, right_operator)

    with self._name_scope(name):  # pylint: disable=not-callable
      block_dimensions = (self._block_domain_dimensions() if adjoint
                          else self._block_range_dimensions())
      arg_dim = -1 if adjoint_arg else -2
      blockwise_arg = linear_operator_util.arg_is_blockwise(
          block_dimensions, rhs, arg_dim)

      if blockwise_arg:
        split_rhs = rhs
        for i, block in enumerate(split_rhs):
          if not isinstance(block, linear_operator.LinearOperator):
            block = ops.convert_to_tensor(block)
            # self._check_input_dtype(block)
            block_dimensions[i].assert_is_compatible_with(tensor_shape.TensorShape(block.shape)[arg_dim])
            split_rhs[i] = block
      else:
        rhs = ops.convert_to_tensor(rhs, name="rhs")
        # self._check_input_dtype(rhs)
        op_dimension = (self.domain_dimension if adjoint
                        else self.range_dimension)
        op_dimension.assert_is_compatible_with(tensor_shape.TensorShape(rhs.shape)[arg_dim])
        split_dim = -1 if adjoint_arg else -2
        # Split input by rows normally, and otherwise columns.
        split_rhs = linear_operator_util.split_arg_into_blocks(
            self._block_domain_dimensions(),
            self._block_domain_dimension_tensors,
            rhs, axis=split_dim)

      solution_list = []
      for index, operator in enumerate(self.operators):
        solution_list = solution_list + [operator.solve(
            split_rhs[index], adjoint=adjoint, adjoint_arg=adjoint_arg)]

      if blockwise_arg:
        return solution_list

      solution_list = linear_operator_util.broadcast_matrix_batch_dims(
          solution_list)
      return prefer_static.concat(solution_list, axis=-2)