Skip to content

Commit d5da63a

Browse files
committed
Support merge 1-qubit gates for parameterized circuits
Note: if users expect update sweeps together with merging, it's not supported yet. It's a todo item to be supported in the followup PRs.
1 parent 0551f85 commit d5da63a

File tree

4 files changed

+162
-1
lines changed

4 files changed

+162
-1
lines changed

cirq-core/cirq/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,7 @@
378378
merge_single_qubit_gates_to_phased_x_and_z as merge_single_qubit_gates_to_phased_x_and_z,
379379
merge_single_qubit_gates_to_phxz as merge_single_qubit_gates_to_phxz,
380380
merge_single_qubit_moments_to_phxz as merge_single_qubit_moments_to_phxz,
381+
merge_into_symbolized_phxz as merge_into_symbolized_phxz,
381382
optimize_for_target_gateset as optimize_for_target_gateset,
382383
parameterized_2q_op_to_sqrt_iswap_operations as parameterized_2q_op_to_sqrt_iswap_operations,
383384
prepare_two_qubit_state_using_cz as prepare_two_qubit_state_using_cz,

cirq-core/cirq/transformers/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@
101101
merge_single_qubit_gates_to_phased_x_and_z as merge_single_qubit_gates_to_phased_x_and_z,
102102
merge_single_qubit_gates_to_phxz as merge_single_qubit_gates_to_phxz,
103103
merge_single_qubit_moments_to_phxz as merge_single_qubit_moments_to_phxz,
104+
merge_into_symbolized_phxz as merge_into_symbolized_phxz,
104105
)
105106

106107
from cirq.transformers.qubit_management_transformers import (

cirq-core/cirq/transformers/merge_single_qubit_gates.py

+89-1
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,13 @@
1414

1515
"""Transformer passes to combine adjacent single-qubit rotations."""
1616

17-
from typing import Optional, TYPE_CHECKING
17+
import enum
18+
import warnings
19+
from typing import Optional, TYPE_CHECKING
20+
1821

1922
from cirq import circuits, ops, protocols
23+
from cirq.study import sweepable
2024
from cirq.transformers.analytical_decompositions import single_qubit_decompositions
2125
from cirq.transformers import transformer_api, transformer_primitives, merge_k_qubit_gates
2226

@@ -152,3 +156,87 @@ def merge_func(m1: 'cirq.Moment', m2: 'cirq.Moment') -> Optional['cirq.Moment']:
152156
deep=context.deep if context else False,
153157
tags_to_ignore=tuple(tags_to_ignore),
154158
).unfreeze(copy=False)
159+
160+
161+
@transformer_api.transformer
162+
def merge_into_symbolized_phxz(
163+
circuit: 'cirq.AbstractCircuit',
164+
*,
165+
context: Optional['cirq.TransformerContext'] = None,
166+
sweeps: Optional['sweepable.Sweepable'] = None,
167+
atol: float = 1e-8,
168+
) -> 'cirq.Circuit':
169+
"""Merge consecutive single qubit gates into connected symbolized PhasedXZ gates.
170+
171+
Specifically, if at least one of the consecutive gates is symbolized, then the merged gate
172+
will be a symbolized gate.
173+
174+
e.g., X-Y-H-phxz(sa, sx, sz) ---transform---> phxz(sa, sx, sz)
175+
176+
Note, we only consider merging non-parameterized gates to symbolized phxz with
177+
3 degrees of freedom, meaning that gates like Z^exp_symbol will be considered non-mergable.
178+
179+
Args:
180+
circuit: Input circuit to transform. It will not be modified.
181+
sweeps: Sweeps of the symbols in the input circuit, updated Sweeps will be returned
182+
based on the transformation.
183+
context: `cirq.TransformerContext` storing common configurable options for transformers.
184+
atol: Absolute tolerance to angle error. Larger values allow more negligible gates to be
185+
dropped, smaller values increase accuracy.
186+
187+
Returns:
188+
Copy of the transformed input circuit.
189+
"""
190+
191+
# TODO(#6994): support returning update sweeps when sweeps are provided.
192+
if sweeps is not None:
193+
raise NotImplementedError("To be supported in #6994.")
194+
195+
if not protocols.is_parameterized(circuit):
196+
warnings.warn(
197+
"Expect parameterized circuits. "
198+
"Please use cirq.merge_single_qubit_gates_to_phxz instead.",
199+
UserWarning,
200+
)
201+
return merge_single_qubit_gates_to_phxz(circuit, context=context, atol=atol)
202+
203+
# Merge all non parameterized single qubit gates first.
204+
circuit = merge_single_qubit_gates_to_phxz(circuit, context=context, atol=atol)
205+
206+
def _merge_func(op1: 'cirq.Operation', op2: 'cirq.Operation'):
207+
208+
class _MergeGateType(enum.Enum):
209+
MERAGABLE_NON_PARAMETERIZED = 0
210+
MERAGABLE_PARAMETERIZED_PHXZ = 1
211+
NON_MERGEABLE = 2
212+
213+
def _categorize(op: 'cirq.Operation') -> _MergeGateType:
214+
if protocols.has_unitary(op) and protocols.num_qubits(op) == 1:
215+
return _MergeGateType.MERAGABLE_NON_PARAMETERIZED
216+
if isinstance(op.gate, ops.PhasedXZGate) and protocols.is_parameterized(op):
217+
return _MergeGateType.MERAGABLE_PARAMETERIZED_PHXZ
218+
return _MergeGateType.NON_MERGEABLE
219+
220+
merge_type1 = _categorize(op1)
221+
merge_type2 = _categorize(op2)
222+
223+
if (
224+
merge_type1 == _MergeGateType.NON_MERGEABLE
225+
or merge_type2 == _MergeGateType.NON_MERGEABLE
226+
):
227+
return None
228+
229+
# absorb the non-parameterized gate into the parameterized gate.
230+
if merge_type1 == _MergeGateType.MERAGABLE_PARAMETERIZED_PHXZ:
231+
return op1
232+
if merge_type2 == _MergeGateType.MERAGABLE_PARAMETERIZED_PHXZ:
233+
return op2
234+
235+
return None # pragma: no cover
236+
237+
return transformer_primitives.merge_operations(
238+
circuit,
239+
_merge_func,
240+
deep=context.deep if context else False,
241+
tags_to_ignore=context.tags_to_ignore if context else (),
242+
).unfreeze()

cirq-core/cirq/transformers/merge_single_qubit_gates_test.py

+71
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,10 @@
1414

1515
from typing import List
1616

17+
import pytest
18+
import sympy
1719
import cirq
20+
from cirq.study.sweeps import Points
1821

1922

2023
def assert_optimizes(optimized: cirq.AbstractCircuit, expected: cirq.AbstractCircuit):
@@ -231,3 +234,71 @@ def test_merge_single_qubit_moments_to_phased_x_and_z_global_phase():
231234
c = cirq.Circuit(cirq.GlobalPhaseGate(1j).on())
232235
c2 = cirq.merge_single_qubit_gates_to_phased_x_and_z(c)
233236
assert c == c2
237+
238+
239+
def test_merge_into_symbolized_phxz():
240+
"""Test case diagram.
241+
Input circuit:
242+
0: ───X───────@───H[ignore]───H───X───PhXZ(a=a1,x=x1,z=z1)───X───PhXZ(a=a2,x=x2,z=z2)───H───
243+
│ ║
244+
1: ───Y^0.5───@───M─────────────────────────────────────────────────────────────────────╫───
245+
║ ║
246+
m: ═══════════════@═════════════════════════════════════════════════════════════════════^═══
247+
Expected output:
248+
0: ───PhXZ(a=-1,x=1,z=0)──────@───H[ignore]───PhXZ(a=a1,x=x1,z=z1)───H───
249+
│ ║
250+
1: ───PhXZ(a=0.5,x=0.5,z=0)───@───M──────────────────────────────────╫───
251+
║ ║
252+
m: ═══════════════════════════════@══════════════════════════════════^═══
253+
"""
254+
a, b = cirq.LineQubit.range(2)
255+
sa1, sa2 = [sympy.Symbol(a) for a in ["a1", "a2"]]
256+
sx1, sx2 = [sympy.Symbol(x) for x in ["x1", "x2"]]
257+
sz1, sz2 = [sympy.Symbol(z) for z in ["z1", "z2"]]
258+
input_circuit = cirq.Circuit(
259+
cirq.X(a),
260+
cirq.Y(b) ** 0.5,
261+
cirq.CZ(a, b),
262+
cirq.H(a).with_tags("ignore"),
263+
cirq.H(a),
264+
cirq.X(a),
265+
_phxz(sa1, sx1, sz1).on(a),
266+
cirq.X(a),
267+
_phxz(sa2, sx2, sz2).on(a),
268+
cirq.measure(b, key="m"),
269+
cirq.H(a).with_classical_controls("m"),
270+
)
271+
context = cirq.TransformerContext(tags_to_ignore=["ignore"])
272+
assert_optimizes(
273+
optimized=cirq.merge_into_symbolized_phxz(input_circuit, context=context),
274+
expected=cirq.Circuit(
275+
_phxz(-1, 1, 0).on(a),
276+
_phxz(0.5, 0.5, 0).on(b),
277+
cirq.CZ(a, b),
278+
cirq.H(a).with_tags("ignore"),
279+
_phxz(sa1, sx1, sz1).on(a),
280+
cirq.measure(b, key="m"),
281+
cirq.H(a).with_classical_controls("m"),
282+
),
283+
)
284+
285+
286+
def test_merge_into_symbolized_phxz_other_symbolized_gates():
287+
a = cirq.NamedQubit('a')
288+
input_circuit = cirq.Circuit(_phxz(1, 1, 1).on(a), cirq.H(a) ** sympy.Symbol("exp"))
289+
assert_optimizes(
290+
optimized=cirq.merge_into_symbolized_phxz(input_circuit), expected=input_circuit
291+
)
292+
293+
294+
def test_merge_into_symbolized_phxz_non_symbolized_input():
295+
a = cirq.NamedQubit('a')
296+
with pytest.warns(UserWarning):
297+
cirq.merge_into_symbolized_phxz(cirq.Circuit(cirq.H(a), cirq.H(a)))
298+
299+
300+
def test_merge_into_symbolized_phxz_with_sweeps():
301+
with pytest.raises(NotImplementedError):
302+
cirq.merge_into_symbolized_phxz(
303+
cirq.Circuit(), sweeps=[Points(key="x", points=[0.1, 0.2, 0.5])]
304+
)

0 commit comments

Comments
 (0)