def gate_approx_eq()

in tensorflow_quantum/python/util.py [0:0]


def gate_approx_eq(gate_true, gate_deser, atol=1e-5):
    """Compares gates in the allowed TFQ gate set.

    Gates in TFQ support symbols, numbers or a single product of a real number
    and a symbol as their parameters. This function behaves like
    `cirq.approx_eq` specialized for these kinds of gates so that TFQ can
    support approximate equality in gates containing symbols.

    Args:
        gate_true: `cirq.Gate` which is in the TFQ gate set.  These are gates
          which are instances of those found in `tfq.util.get_supported_gates()`
        gate_deser: `cirq.Gate` which is in the TFQ gate set.  These are gates
          which are instances of those found in `tfq.util.get_supported_gates()`
        atol: Absolute tolerance forwarded to the parameter/channel comparison
          helpers. For `cirq.ControlledGate` inputs the same tolerance is used
          when comparing the wrapped sub gates.

    Returns:
        bool which says if the two gates are approximately equal in the way
            described above.

    Raises:
        TypeError: If input gates are not of type `cirq.Gate`.
        ValueError: If invalid gate types are provided, or if a supported gate
            type has no comparison rule in this function.
    """
    if not isinstance(gate_true, cirq.Gate):
        raise TypeError(f"`gate_true` not a cirq gate, got {type(gate_true)}")
    if not isinstance(gate_deser, cirq.Gate):
        raise TypeError(f"`gate_deser` not a cirq gate, got {type(gate_deser)}")

    # Controlled gates: the control structure must match exactly, then the
    # wrapped sub gates are compared recursively.
    true_is_controlled = isinstance(gate_true, cirq.ControlledGate)
    if true_is_controlled != isinstance(gate_deser, cirq.ControlledGate):
        return False
    if true_is_controlled:
        if gate_true.control_qid_shape != gate_deser.control_qid_shape:
            return False
        if gate_true.control_values != gate_deser.control_values:
            return False
        # Forward `atol` so a caller-supplied tolerance is not silently reset
        # to the default when comparing the sub gates.
        return gate_approx_eq(gate_true.sub_gate, gate_deser.sub_gate, atol)

    supported_gates = serializer.SERIALIZER.supported_gate_types()
    if not any(isinstance(gate_true, g) for g in supported_gates):
        raise ValueError(f"`gate_true` not a valid TFQ gate, got {gate_true}")
    if not any(isinstance(gate_deser, g) for g in supported_gates):
        raise ValueError(f"`gate_deser` not a valid TFQ gate, got {gate_deser}")
    if not isinstance(gate_true, type(gate_deser)):
        return False
    if isinstance(gate_true, type(cirq.I)) and isinstance(
            gate_deser, type(cirq.I)):
        # all identity gates are the same
        return True

    # Phased gates must be handled before the generic EigenGate branch:
    # `cirq.PhasedISwapPowGate` is itself an `EigenGate` subclass, so testing
    # EigenGate first would skip the `_phase_exponent` comparison entirely.
    if isinstance(gate_true, (cirq.PhasedXPowGate, cirq.PhasedISwapPowGate)):
        a = _expression_approx_eq(gate_true._global_shift,
                                  gate_deser._global_shift, atol)
        b = _expression_approx_eq(gate_true._exponent, gate_deser._exponent,
                                  atol)
        c = _expression_approx_eq(gate_true._phase_exponent,
                                  gate_deser._phase_exponent, atol)
        return a and b and c
    if isinstance(gate_true, cirq.FSimGate):
        a = _expression_approx_eq(gate_true.theta, gate_deser.theta, atol)
        b = _expression_approx_eq(gate_true.phi, gate_deser.phi, atol)
        return a and b
    if isinstance(gate_true, cirq.EigenGate):
        a = _expression_approx_eq(gate_true._global_shift,
                                  gate_deser._global_shift, atol)
        b = _expression_approx_eq(gate_true._exponent, gate_deser._exponent,
                                  atol)
        return a and b
    if any(isinstance(gate_true, x) for x in _SUPPORTED_CHANNELS):
        # Compare channels.
        return _channel_approx_eq(gate_true, gate_deser, atol)
    raise ValueError(
        f"Some valid TFQ gate type is not yet accounted for, got {gate_true}")