in tensorflow_quantum/python/util.py [0:0]
def gate_approx_eq(gate_true, gate_deser, atol=1e-5):
    """Compares gates in the allowed TFQ gate set.

    Gates in TFQ support symbols, numbers or a single product of a real number
    and a symbol as their parameters. This function behaves like
    `cirq.approx_eq` specialized for these kinds of gates so that TFQ can
    support approximate equality in gates containing symbols.

    Args:
        gate_true: `cirq.Gate` which is in the TFQ gate set. These are gates
            which are instances of those found in
            `tfq.util.get_supported_gates()`
        gate_deser: `cirq.Gate` which is in the TFQ gate set. These are gates
            which are instances of those found in
            `tfq.util.get_supported_gates()`
        atol: Absolute tolerance used when comparing gate parameters. The
            same tolerance is applied to the sub-gates of controlled gates.

    Returns:
        bool which says if the two gates are approximately equal in the way
        described above.

    Raises:
        TypeError: If input gates are not of type `cirq.Gate`.
        ValueError: If invalid gate types are provided.
    """
    if not isinstance(gate_true, cirq.Gate):
        raise TypeError(f"`gate_true` not a cirq gate, got {type(gate_true)}")
    if not isinstance(gate_deser, cirq.Gate):
        raise TypeError(f"`gate_deser` not a cirq gate, got {type(gate_deser)}")

    # A controlled gate can only ever match another controlled gate.
    if isinstance(gate_true, cirq.ControlledGate) != isinstance(
            gate_deser, cirq.ControlledGate):
        return False
    if isinstance(gate_true, cirq.ControlledGate):
        if gate_true.control_qid_shape != gate_deser.control_qid_shape:
            return False
        if gate_true.control_values != gate_deser.control_values:
            return False
        # Forward `atol` so the caller's tolerance also governs the
        # comparison of the controlled sub-gates.
        return gate_approx_eq(gate_true.sub_gate, gate_deser.sub_gate, atol)

    supported_gates = serializer.SERIALIZER.supported_gate_types()
    if not any(isinstance(gate_true, g) for g in supported_gates):
        raise ValueError(f"`gate_true` not a valid TFQ gate, got {gate_true}")
    if not any(isinstance(gate_deser, g) for g in supported_gates):
        raise ValueError(f"`gate_deser` not a valid TFQ gate, got {gate_deser}")

    # NOTE: this check is one-directional -- it passes when gate_true is an
    # instance of type(gate_deser) or of a subclass of that type.
    if not isinstance(gate_true, type(gate_deser)):
        return False
    if isinstance(gate_true, type(cirq.I)) and isinstance(
            gate_deser, type(cirq.I)):
        # all identity gates are the same
        return True

    # Must be tested before the generic EigenGate branch below:
    # cirq.PhasedISwapPowGate derives from cirq.EigenGate, so checking
    # EigenGate first would return early and never compare `_phase_exponent`.
    if isinstance(gate_true, (cirq.PhasedXPowGate, cirq.PhasedISwapPowGate)):
        a = _expression_approx_eq(gate_true._global_shift,
                                  gate_deser._global_shift, atol)
        b = _expression_approx_eq(gate_true._exponent, gate_deser._exponent,
                                  atol)
        c = _expression_approx_eq(gate_true._phase_exponent,
                                  gate_deser._phase_exponent, atol)
        return a and b and c
    if isinstance(gate_true, cirq.EigenGate):
        a = _expression_approx_eq(gate_true._global_shift,
                                  gate_deser._global_shift, atol)
        b = _expression_approx_eq(gate_true._exponent, gate_deser._exponent,
                                  atol)
        return a and b
    if isinstance(gate_true, cirq.FSimGate):
        a = _expression_approx_eq(gate_true.theta, gate_deser.theta, atol)
        b = _expression_approx_eq(gate_true.phi, gate_deser.phi, atol)
        return a and b
    if any(isinstance(gate_true, x) for x in _SUPPORTED_CHANNELS):
        # Compare channels.
        return _channel_approx_eq(gate_true, gate_deser, atol)
    raise ValueError(
        f"Some valid TFQ gate type is not yet accounted for, got {gate_true}")