From 1296df58ef069cc6364fc08d71657d4c157df466 Mon Sep 17 00:00:00 2001 From: Codrut Date: Sat, 19 Apr 2025 23:55:27 +0300 Subject: [PATCH] Simplify decomposition of controlled eigengates with global phase --- .../classically_controlled_operation_test.py | 3 +- cirq-core/cirq/ops/common_gates.py | 18 ++++++ cirq-core/cirq/ops/controlled_gate.py | 12 ++++ cirq-core/cirq/ops/controlled_gate_test.py | 16 +++++ cirq-core/cirq/ops/eigen_gate.py | 20 ++++++- cirq-core/cirq/ops/eigen_gate_test.py | 41 +++++++++++++ cirq-core/cirq/ops/parity_gates_test.py | 59 +++++++++++++++++++ cirq-core/cirq/ops/raw_types_test.py | 2 +- .../transformers/merge_single_qubit_gates.py | 6 +- .../merge_single_qubit_gates_test.py | 36 ++++++++++- 10 files changed, 207 insertions(+), 6 deletions(-) diff --git a/cirq-core/cirq/ops/classically_controlled_operation_test.py b/cirq-core/cirq/ops/classically_controlled_operation_test.py index 755f43f9368..2dd098a0462 100644 --- a/cirq-core/cirq/ops/classically_controlled_operation_test.py +++ b/cirq-core/cirq/ops/classically_controlled_operation_test.py @@ -473,7 +473,8 @@ def test_decompose(): op = cirq.H(q0).with_classical_controls('a') assert cirq.decompose(op) == [ (cirq.Y(q0) ** 0.5).with_classical_controls('a'), - cirq.XPowGate(exponent=1.0, global_shift=-0.25).on(q0).with_classical_controls('a'), + cirq.X(q0).with_classical_controls('a'), + cirq.global_phase_operation(1j**-0.5).with_classical_controls('a'), ] diff --git a/cirq-core/cirq/ops/common_gates.py b/cirq-core/cirq/ops/common_gates.py index af9c1100b34..a02bae29b6f 100644 --- a/cirq-core/cirq/ops/common_gates.py +++ b/cirq-core/cirq/ops/common_gates.py @@ -352,6 +352,12 @@ def __init__(self, *, rads: value.TParamVal): def _with_exponent(self, exponent: value.TParamVal) -> 'Rx': return Rx(rads=exponent * _pi(exponent)) + def _decompose_(self, qubits: Tuple['cirq.Qid', ...]) -> NotImplementedType: + """Returns: + NotImplemented, to signify the gate doesn't decompose further. + """ + return NotImplemented + def _circuit_diagram_info_( self, args: 'cirq.CircuitDiagramInfoArgs' ) -> Union[str, 'protocols.CircuitDiagramInfo']: @@ -537,6 +543,12 @@ def __init__(self, *, rads: value.TParamVal): def _with_exponent(self, exponent: value.TParamVal) -> 'Ry': return Ry(rads=exponent * _pi(exponent)) + def _decompose_(self, qubits: Tuple['cirq.Qid', ...]) -> NotImplementedType: + """Returns: + NotImplemented, to signify the gate doesn't decompose further. + """ + return NotImplemented + def _circuit_diagram_info_( self, args: 'cirq.CircuitDiagramInfoArgs' ) -> Union[str, 'protocols.CircuitDiagramInfo']: @@ -882,6 +894,12 @@ def __init__(self, *, rads: value.TParamVal): def _with_exponent(self, exponent: value.TParamVal) -> 'Rz': return Rz(rads=exponent * _pi(exponent)) + def _decompose_(self, qubits: Tuple['cirq.Qid', ...]) -> NotImplementedType: + """Returns: + NotImplemented, to signify the gate doesn't decompose further. + """ + return NotImplemented + def _circuit_diagram_info_( self, args: 'cirq.CircuitDiagramInfoArgs' ) -> Union[str, 'protocols.CircuitDiagramInfo']: diff --git a/cirq-core/cirq/ops/controlled_gate.py b/cirq-core/cirq/ops/controlled_gate.py index e47602ac942..45626e8079d 100644 --- a/cirq-core/cirq/ops/controlled_gate.py +++ b/cirq-core/cirq/ops/controlled_gate.py @@ -33,6 +33,7 @@ control_values as cv, controlled_operation as cop, diagonal_gate as dg, + eigen_gate, global_phase_op as gp, op_tree, raw_types, @@ -159,6 +160,12 @@ def _decompose_with_context_( self, qubits: Tuple['cirq.Qid', ...], context: Optional['cirq.DecompositionContext'] = None ) -> Union[None, NotImplementedType, 'cirq.OP_TREE']: control_qubits = list(qubits[: self.num_controls()]) + # If the subgate is an EigenGate with non-zero phase, try to decompose it + # into a phase-free gate and a global phase gate. + if isinstance(self.sub_gate, eigen_gate.EigenGate) and self.sub_gate.global_shift != 0: + result = self._decompose_sub_gate_with_controls(qubits, context) + if result is not NotImplemented: + return result if ( protocols.has_unitary(self.sub_gate) and protocols.num_qubits(self.sub_gate) == 1 @@ -219,6 +226,11 @@ def _decompose_with_context_( control_qid_shape=self.control_qid_shape, ).on(*control_qubits) return [result, controlled_phase_op] + return self._decompose_sub_gate_with_controls(qubits, context) + + def _decompose_sub_gate_with_controls( + self, qubits: Tuple['cirq.Qid', ...], context: Optional['cirq.DecompositionContext'] = None + ) -> Union[None, NotImplementedType, 'cirq.OP_TREE']: result = protocols.decompose_once_with_qubits( self.sub_gate, qubits[self.num_controls() :], diff --git a/cirq-core/cirq/ops/controlled_gate_test.py b/cirq-core/cirq/ops/controlled_gate_test.py index ebff6b9c709..36601033d86 100644 --- a/cirq-core/cirq/ops/controlled_gate_test.py +++ b/cirq-core/cirq/ops/controlled_gate_test.py @@ -494,6 +494,22 @@ def _test_controlled_gate_is_consistent( np.testing.assert_allclose(cirq.unitary(cgate), cirq.unitary(circuit), atol=1e-13) +@pytest.mark.parametrize( + 'sub_gate, expected_decomposition', + [ + (cirq.XPowGate(global_shift=0.22), [cirq.Y**-0.5, cirq.CZ, cirq.Y**0.5, cirq.Z**0.22]), + (cirq.ZPowGate(exponent=1.2, global_shift=0.3), [cirq.CZ**1.2, cirq.Z**0.36]), + ], +) +def test_decompose_takes_out_global_phase( + sub_gate: cirq.Gate, expected_decomposition: Sequence[cirq.Gate] +): + cgate = cirq.ControlledGate(sub_gate, num_controls=1) + qubits = cirq.LineQubit.range(cgate.num_qubits()) + dec = cirq.decompose(cgate.on(*qubits)) + assert [op.gate for op in dec] == expected_decomposition + + def test_pow_inverse(): assert cirq.inverse(CRestricted, None) is None assert cirq.pow(CRestricted, 1.5, None) is None diff --git a/cirq-core/cirq/ops/eigen_gate.py b/cirq-core/cirq/ops/eigen_gate.py index 2aec2976f27..300a20b3548 100644 --- a/cirq-core/cirq/ops/eigen_gate.py +++ b/cirq-core/cirq/ops/eigen_gate.py @@ -34,7 +34,7 @@ import sympy from cirq import protocols, value -from cirq.ops import raw_types +from cirq.ops import global_phase_op, raw_types if TYPE_CHECKING: import cirq @@ -375,6 +375,24 @@ def _json_dict_(self) -> Dict[str, Any]: def _measurement_key_objs_(self): return frozenset() + def _decompose_( + self, qubits: Tuple['cirq.Qid', ...] + ) -> Union[NotImplementedType, 'cirq.OP_TREE']: + """Attempts to decompose the gate into a phase-free gate and a global phase gate. + + Returns: + NotImplemented, if global phase or exponent are 0. Otherwise a phase-free gate + applied to the qubits followed by a global phase gate. + """ + if self.global_shift == 0 or self.exponent == 0: + return NotImplemented + self_without_phase = self._with_exponent(self.exponent) + # This doesn't work for gates that fix global_shift, such as Rx. These gates must define + # their own _decompose_ method. + self_without_phase._global_shift = 0 + global_phase = 1j ** (2 * self.global_shift * self.exponent) + return [self_without_phase.on(*qubits), global_phase_op.GlobalPhaseGate(global_phase)()] + def _lcm(vals: Iterable[int]) -> int: t = 1 diff --git a/cirq-core/cirq/ops/eigen_gate_test.py b/cirq-core/cirq/ops/eigen_gate_test.py index 177aac9a6a5..b04c4523c21 100644 --- a/cirq-core/cirq/ops/eigen_gate_test.py +++ b/cirq-core/cirq/ops/eigen_gate_test.py @@ -20,6 +20,7 @@ import cirq from cirq import value +from cirq.ops import global_phase_op from cirq.testing import assert_has_consistent_trace_distance_bound @@ -421,3 +422,43 @@ def _with_exponent(self, exponent): ) def test_equal_up_to_global_phase(gate1, gate2, eq_up_to_global_phase): assert cirq.equal_up_to_global_phase(gate1, gate2) == eq_up_to_global_phase + + +@pytest.mark.parametrize( + 'gate', + [ + cirq.Z, + cirq.Z**2, + cirq.XPowGate(global_shift=0.0), + cirq.rx(0), + cirq.ry(0), + cirq.rz(0), + cirq.CZPowGate(exponent=0.0, global_shift=0.25), + ], +) +def test_decompose_once_returns_not_implemented(gate: cirq.Gate): + qubits = cirq.LineQubit.range(gate.num_qubits()) + assert cirq.decompose_once(gate.on(*qubits), default=NotImplemented) == NotImplemented + + +@pytest.mark.parametrize( + 'gate, expected_decomposition', + [ + (cirq.X, [cirq.X]), + (cirq.ZPowGate(global_shift=0.5), [cirq.Z, global_phase_op.GlobalPhaseGate(1j)]), + ( + cirq.ZPowGate(global_shift=0.5) ** sympy.Symbol('e'), + [ + cirq.Z ** sympy.Symbol('e'), + global_phase_op.GlobalPhaseGate(1j ** (1.0 * sympy.Symbol('e'))), + ], + ), + (cirq.rx(np.pi / 2), [cirq.rx(np.pi / 2)]), + (cirq.ry(np.pi / 2), [cirq.ry(np.pi / 2)]), + (cirq.rz(np.pi / 2), [cirq.rz(np.pi / 2)]), + ], +) +def test_decompose_takes_out_global_phase(gate: cirq.Gate, expected_decomposition: List[cirq.Gate]): + qubits = cirq.LineQubit.range(gate.num_qubits()) + dec = cirq.decompose(gate.on(*qubits)) + assert [op.gate for op in dec] == expected_decomposition diff --git a/cirq-core/cirq/ops/parity_gates_test.py b/cirq-core/cirq/ops/parity_gates_test.py index 750571258c1..4f88aee8f46 100644 --- a/cirq-core/cirq/ops/parity_gates_test.py +++ b/cirq-core/cirq/ops/parity_gates_test.py @@ -14,6 +14,8 @@ """Tests for `parity_gates.py`.""" +from typing import List + import numpy as np import pytest import sympy @@ -348,3 +350,60 @@ def test_clifford_protocols(gate_cls: type[cirq.EigenGate], exponent: float, is_ else: assert not cirq.has_stabilizer_effect(gate) assert gate._decompose_into_clifford_with_qubits_(cirq.LineQubit.range(2)) is NotImplemented + + +@pytest.mark.parametrize( + 'gate, expected_decomposition', + [ + ( + cirq.XXPowGate(), + [ + (cirq.Y**-0.5).on(cirq.LineQubit(0)), + (cirq.Y**-0.5).on(cirq.LineQubit(1)), + cirq.Z(cirq.LineQubit(0)), + cirq.Z(cirq.LineQubit(1)), + (cirq.CZ**-2.0).on(cirq.LineQubit(0), cirq.LineQubit(1)), + (cirq.Y**0.5).on(cirq.LineQubit(0)), + (cirq.Y**0.5).on(cirq.LineQubit(1)), + ], + ), + ( + cirq.YYPowGate(), + [ + (cirq.X**0.5).on(cirq.LineQubit(0)), + (cirq.X**0.5).on(cirq.LineQubit(1)), + cirq.Z(cirq.LineQubit(0)), + cirq.Z(cirq.LineQubit(1)), + (cirq.CZ**-2.0).on(cirq.LineQubit(0), cirq.LineQubit(1)), + (cirq.X**-0.5).on(cirq.LineQubit(0)), + (cirq.X**-0.5).on(cirq.LineQubit(1)), + ], + ), + ( + cirq.ZZPowGate(), + [ + cirq.Z(cirq.LineQubit(0)), + cirq.Z(cirq.LineQubit(1)), + (cirq.CZ**-2.0).on(cirq.LineQubit(0), cirq.LineQubit(1)), + ], + ), + ( + cirq.MSGate(rads=0), + [ + (cirq.Y**-0.5).on(cirq.LineQubit(0)), + (cirq.Y**-0.5).on(cirq.LineQubit(1)), + (cirq.Z**0.0).on(cirq.LineQubit(0)), + (cirq.Z**0.0).on(cirq.LineQubit(1)), + cirq.CZPowGate(exponent=-0.0, global_shift=0.25).on( + cirq.LineQubit(0), cirq.LineQubit(1) + ), + (cirq.Y**0.5).on(cirq.LineQubit(0)), + (cirq.Y**0.5).on(cirq.LineQubit(1)), + ], + ), + ], +) +def test_gate_decomposition(gate: cirq.Gate, expected_decomposition: List[cirq.Gate]): + qubits = cirq.LineQubit.range(gate.num_qubits()) + dec = cirq.decompose(gate.on(*qubits)) + assert [op for op in dec] == expected_decomposition diff --git a/cirq-core/cirq/ops/raw_types_test.py b/cirq-core/cirq/ops/raw_types_test.py index a4985ad1844..67fd728f618 100644 --- a/cirq-core/cirq/ops/raw_types_test.py +++ b/cirq-core/cirq/ops/raw_types_test.py @@ -647,7 +647,7 @@ def test_tagged_operation_forwards_protocols(): np.testing.assert_equal(cirq.unitary(tagged_h), cirq.unitary(h)) assert cirq.has_unitary(tagged_h) assert cirq.decompose(tagged_h) == cirq.decompose(h) - assert [*tagged_h._decompose_()] == cirq.decompose(h) + assert [*tagged_h._decompose_()] == cirq.decompose_once(h) assert cirq.pauli_expansion(tagged_h) == cirq.pauli_expansion(h) assert cirq.equal_up_to_global_phase(h, tagged_h) assert np.isclose(cirq.kraus(h), cirq.kraus(tagged_h)).all() diff --git a/cirq-core/cirq/transformers/merge_single_qubit_gates.py b/cirq-core/cirq/transformers/merge_single_qubit_gates.py index c48e73fae8e..7e86b7c13f9 100644 --- a/cirq-core/cirq/transformers/merge_single_qubit_gates.py +++ b/cirq-core/cirq/transformers/merge_single_qubit_gates.py @@ -116,7 +116,7 @@ def merge_single_qubit_moments_to_phxz( def can_merge_moment(m: 'cirq.Moment'): return all( - protocols.num_qubits(op) == 1 + (protocols.num_qubits(op) == 1 or protocols.num_qubits(op) == 0) and protocols.has_unitary(op) and tags_to_ignore.isdisjoint(op.tags) for op in m @@ -144,6 +144,10 @@ def merge_func(m1: 'cirq.Moment', m2: 'cirq.Moment') -> Optional['cirq.Moment']: ) if gate: ret_ops.append(gate(q)) + # Transfer global phase + for op in m1.operations + m2.operations: + if protocols.num_qubits(op) == 0: + ret_ops.append(op) return circuits.Moment(ret_ops) return transformer_primitives.merge_moments( diff --git a/cirq-core/cirq/transformers/merge_single_qubit_gates_test.py b/cirq-core/cirq/transformers/merge_single_qubit_gates_test.py index 8ea1fd3d273..1eae6259d95 100644 --- a/cirq-core/cirq/transformers/merge_single_qubit_gates_test.py +++ b/cirq-core/cirq/transformers/merge_single_qubit_gates_test.py @@ -221,13 +221,45 @@ def test_merge_single_qubit_moments_to_phxz_deep(): ) -def test_merge_single_qubit_moments_to_phxz_global_phase(): +def test_merge_single_qubit_gates_to_phxz_global_phase(): c = cirq.Circuit(cirq.GlobalPhaseGate(1j).on()) c2 = cirq.merge_single_qubit_gates_to_phxz(c) assert c == c2 -def test_merge_single_qubit_moments_to_phased_x_and_z_global_phase(): +def test_merge_single_qubit_gates_to_phased_x_and_z_global_phase(): c = cirq.Circuit(cirq.GlobalPhaseGate(1j).on()) c2 = cirq.merge_single_qubit_gates_to_phased_x_and_z(c) assert c == c2 + + +def test_merge_single_qubit_moments_to_phxz_with_global_phase_in_first_moment(): + q0 = cirq.LineQubit(0) + c_orig = cirq.Circuit( + cirq.Moment(cirq.Y(q0) ** 0.5, cirq.GlobalPhaseGate(1j**0.5).on()), cirq.Moment(cirq.X(q0)) + ) + c_expected = cirq.Circuit( + cirq.Moment( + cirq.PhasedXZGate(axis_phase_exponent=-0.5, x_exponent=0.5, z_exponent=-1.0).on(q0), + cirq.GlobalPhaseGate(1j**0.5).on(), + ) + ) + context = cirq.TransformerContext(tags_to_ignore=["ignore"]) + c_new = cirq.merge_single_qubit_moments_to_phxz(c_orig, context=context) + assert c_new == c_expected + + +def test_merge_single_qubit_moments_to_phxz_with_global_phase_in_second_moment(): + q0 = cirq.LineQubit(0) + c_orig = cirq.Circuit( + cirq.Moment(cirq.Y(q0) ** 0.5), cirq.Moment(cirq.X(q0), cirq.GlobalPhaseGate(1j**0.5).on()) + ) + c_expected = cirq.Circuit( + cirq.Moment( + cirq.PhasedXZGate(axis_phase_exponent=-0.5, x_exponent=0.5, z_exponent=-1.0).on(q0), + cirq.GlobalPhaseGate(1j**0.5).on(), + ) + ) + context = cirq.TransformerContext(tags_to_ignore=["ignore"]) + c_new = cirq.merge_single_qubit_moments_to_phxz(c_orig, context=context) + assert c_new == c_expected