Skip to content

Commit 7ac61a5

Browse files
committed
fix checks.
1 parent 99bd54e commit 7ac61a5

File tree

3 files changed

+289
-124
lines changed

3 files changed

+289
-124
lines changed

cirq-core/cirq/transformers/gauge_compiling/gauge_compiling.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,11 @@
2525
import sympy
2626
from attrs import field, frozen
2727

28-
from cirq.transformers import transformer_api
29-
from cirq import ops, circuits
28+
from cirq import circuits, ops
3029
from cirq.protocols import unitary_protocol
3130
from cirq.protocols.has_unitary_protocol import has_unitary
3231
from cirq.study.sweeps import Points, Sweep, Zip
32+
from cirq.transformers import transformer_api
3333
from cirq.transformers.analytical_decompositions import single_qubit_decompositions
3434

3535

cirq-core/cirq/transformers/merge_single_qubit_gates.py

+184-114
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,13 @@
1515
"""Transformer passes to combine adjacent single-qubit rotations."""
1616

1717
import itertools
18-
import warnings
19-
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
18+
from typing import Dict, Hashable, List, Optional, Tuple, TYPE_CHECKING
2019

2120
import sympy
2221

2322
from cirq import circuits, ops, protocols
2423
from cirq.study.sweeps import Points, Sweep, Zip
25-
from cirq.transformers import merge_k_qubit_gates, transformer_api, transformer_primitives, align
24+
from cirq.transformers import align, merge_k_qubit_gates, transformer_api, transformer_primitives
2625
from cirq.transformers.analytical_decompositions import single_qubit_decompositions
2726

2827
if TYPE_CHECKING:
@@ -159,115 +158,53 @@ def merge_func(m1: 'cirq.Moment', m2: 'cirq.Moment') -> Optional['cirq.Moment']:
159158
).unfreeze(copy=False)
160159

161160

161+
# ----------------------------------------------------------------------
162+
# Impl merge_single_qubit_gates_to_phxz_symbolized: Start
163+
# ----------------------------------------------------------------------
164+
165+
162166
def _values_of_sweep(sweep: Sweep, key: str | sympy.Symbol):
163167
p = sympy.Symbol(key) if isinstance(key, str) else key
164168
return [resolver.value_of(p) for resolver in sweep]
165169

166170

167-
@transformer_api.transformer
168-
def merge_single_qubit_gates_to_phxz_symbolized(
169-
circuit: 'cirq.AbstractCircuit',
170-
*,
171-
context: Optional['cirq.TransformerContext'] = None,
172-
sweep: Sweep,
173-
atol: float = 1e-8,
174-
) -> Tuple['cirq.Circuit', Sweep]:
175-
"""Merge consecutive single qubit gates as PhasedXZ Gates. Symbolize if any of the consecutive gates is symbolized.
176-
177-
Example:
178-
# pylint: disable=line-too-long
179-
>>> q0, q1 = cirq.LineQubit.range(2)
180-
>>> c = cirq.Circuit(cirq.X(q0),cirq.CZ(q0,q1)**sympy.Symbol("cz_exp"),cirq.Y(q0)**sympy.Symbol("y_exp"),cirq.X(q0))
181-
>>> print(c)
182-
0: ───X───@──────────Y^y_exp───X───
183-
184-
1: ───────@^cz_exp─────────────────
185-
>>> new_circuit, new_sweep = cirq.merge_single_qubit_gates_to_phxz_symbolized(\
186-
c, sweep=cirq.Points(key="cz_exp", points=[0, 1]) * cirq.Points(key="y_exp", points=[0, 1])\
187-
)
188-
>>> print(new_circuit)
189-
0: ───PhXZ(a=-1,x=1,z=0)───@──────────PhXZ(a=a0,x=x0,z=z0)───
190-
191-
1: ────────────────────────@^cz_exp──────────────────────────
192-
>>> print(new_sweep)
193-
cirq.Points('z0', [0, -1.0, 0, -1.0]) + cirq.Points('x0', [1, 0.0, 1, 0.0]) + cirq.Points('a0', [-1.0, -0.5, -1.0, -0.5]) + cirq.Points('cz_exp', [0, 0, 1, 1])
194-
# pylint: disable=line-too-long
171+
def _merge_single_qubit_gates_to_circuit_op_symbolized(
172+
resolved_circuits: List['cirq.AbstractCircuit'],
173+
symbolized_single_tag: str,
174+
context: Optional['cirq.TransformerContext'],
175+
atol: float,
176+
) -> Tuple[List['cirq.Circuit'], frozenset[str], frozenset[str]]:
177+
"""Helper function to merge single qubit ops of resolved circuits to ops of CircuitOperation
178+
type using merge_k_qubit_unitaries.
195179
196180
Args:
197-
circuit: Input circuit to transform. It will not be modified.
198-
sweep: Sweep of the symbols in the input circuit, updated Sweep will be returned
199-
based on the transformation.
200-
context: `cirq.TransformerContext` storing common configurable options for transformers.
201-
atol: Absolute tolerance to angle error. Larger values allow more negligible gates to be
202-
dropped, smaller values increase accuracy.
181+
resolved_circuits: A list of circuits where symbols have been replaced with concrete values.
182+
symbolized_single_tag: The tag applied to single-qubit operations that originally contained symbols
183+
before parameterizations.
203184
204185
Returns:
205-
Copy of the transformed input circuit.
186+
Tuple of merge counts, merged circuits, and merge tags.
206187
"""
207-
deep = context.deep if context else False
208-
209-
if not protocols.is_parameterized(circuit):
210-
warnings.warn(
211-
"Expect parameterized circuits. "
212-
"Please use cirq.merge_single_qubit_gates_to_phxz instead.",
213-
UserWarning,
214-
)
215-
return merge_single_qubit_gates_to_phxz(circuit, context=context, atol=atol)
216-
217-
# Tag symbolized single qubit op.
218-
symbolized_single_tag = "_symbolized_single"
219-
220-
circuit_tagged = transformer_primitives.map_operations(
221-
circuit,
222-
lambda op, _: (
223-
op.with_tags(symbolized_single_tag)
224-
if protocols.is_parameterized(op) and len(op.qubits) == 1
225-
else op
226-
),
227-
deep=deep,
228-
)
229-
230-
# Symbols of the single qubit symbolized ops.
231-
single_qubit_gate_symbols: set[sympy.Symbol] = set().union(
232-
*[
233-
protocols.parameter_symbols(op) if symbolized_single_tag in op.tags else set()
234-
for op in circuit_tagged.all_operations()
235-
]
236-
)
237-
# Remaing symbols, e.g., 2 qubit gates' symbols. Sweep of those symbols keeps unchanged.
238-
remaining_symbols = protocols.parameter_symbols(circuit) - single_qubit_gate_symbols
239-
240-
sweep_of_single: Sweep = Zip(
241-
*[Points(key=k, points=_values_of_sweep(sweep, k)) for k in single_qubit_gate_symbols]
242-
)
243-
244-
# Get all resolved circuits from all sets of resolvers in sweep.
245-
resolved_circuits = [
246-
protocols.resolve_parameters(circuit_tagged, resolver) for resolver in sweep_of_single
247-
]
248-
249-
# Store the number of merges for all set of resolvers,
250-
# it should be the same for all resolved circuits.
251-
merge_counts: list[int] = []
252-
merged_circuits = []
253-
phxz_tag_prefix = "_phxz"
188+
merge_counts: list[int] = [] # number of merges per resolved_circuit
189+
merged_circuits: list['cirq.Circuit'] = []
254190
tag_iter: itertools.count
191+
phxz_tag_prefix = "_phxz"
255192

256193
def rewriter(circuit_op: 'cirq.CircuitOperation') -> 'cirq.OP_TREE':
257194
nonlocal tag_iter
258195
tag: Optional[str] = None
196+
259197
u = protocols.unitary(circuit_op)
260198
if protocols.num_qubits(circuit_op) == 0:
261199
return ops.GlobalPhaseGate(u[0, 0]).on()
200+
# If any of the op in the merged circuit_op is a symbolized single qubit gate,
201+
# tag the merged phxz gate with next tag id, for further parameterization references.
262202
for op in circuit_op.circuit.all_operations():
263203
if symbolized_single_tag in op.tags:
264-
# Record parameterizations info via tags.
265204
tag = f"{phxz_tag_prefix}_{next(tag_iter)}"
266205
break
267206
gate = single_qubit_decompositions.single_qubit_matrix_to_phxz(u, atol) or ops.I
268207
op = gate.on(circuit_op.qubits[0])
269-
if not gate:
270-
return []
271208
return op.with_tags(tag) if tag else op
272209

273210
for resolved_circuit in resolved_circuits:
@@ -280,62 +217,195 @@ def rewriter(circuit_op: 'cirq.CircuitOperation') -> 'cirq.OP_TREE':
280217
merge_counts.append(next(tag_iter))
281218

282219
if not all(count == merge_counts[0] for count in merge_counts):
283-
raise RuntimeError("Different resolvers in sweep result different merged strcuture.")
220+
raise RuntimeError("Different resolvers in sweep resulted in different merged structures.")
284221

285-
# Get the output circuit from the first resolved circuits.
286-
merge_tags: set[str] = {f"{phxz_tag_prefix}_{i}" for i in range(merge_counts[0])}
287-
new_symbols: set[str] = set().union(
288-
*[{f"x{i}", f"z{i}", f"a{i}"} for i in range(merge_counts[0])]
222+
merge_tags: frozenset[str] = frozenset(
223+
{f"{phxz_tag_prefix}_{i}" for i in range(merge_counts[0])}
224+
)
225+
new_symbols: frozenset[str] = frozenset(
226+
set().union(*[{f"x{i}", f"z{i}", f"a{i}"} for i in range(merge_counts[0])])
289227
)
290228

229+
return merged_circuits, merge_tags, new_symbols
230+
231+
232+
def _get_merge_tag_id(merge_tags: frozenset[str], op_tags: Tuple[Hashable, ...]) -> Optional[str]:
233+
"""Extract the id `i` from the merge tag `_phxz_i` if it exists."""
234+
the_merge_tag: set[str] = set(merge_tags.intersection(op_tags))
235+
if len(the_merge_tag) == 0:
236+
return None
237+
if len(the_merge_tag) > 1:
238+
raise RuntimeError("Multiple merge tags found.")
239+
return the_merge_tag.pop().split("_")[-1]
240+
241+
242+
def _map_merged_ops_to_symbolized_phxz(
243+
circuit: 'cirq.Circuit', merge_tags: frozenset[str], deep: bool
244+
) -> 'cirq.Circuit':
245+
"""Maps merged operations (tagged with merge_tags) in the circuit to symbolized PhasedXZGates.
246+
247+
Args:
248+
circuit: Circuit with merge tags to be mapped.
249+
merge_tags: The set of tags used to identify the merged PhasedXZ gates that need to be
250+
symbolized.
251+
deep: Whether to perform the mapping recursively within CircuitOperations.
252+
253+
Returns:
254+
A new circuit where tagged PhasedXZ gates are replaced by symbolized versions.
255+
"""
256+
257+
# Map merged ops to `PhasedXZGate(xi,zi,ai)` based on the tag "_phxz_i".
291258
def _map_func(op: 'cirq.Operation', _):
292-
"""Maps op with tag `_phxz_i` to a symbolzied `PhasedXZGate(xi,zi,ai)`"""
293-
the_merge_tag = merge_tags.intersection(op.tags)
294-
if len(the_merge_tag) == 0:
259+
"""Maps an op with tag `_phxz_i` to a symbolzied `PhasedXZGate(xi,zi,ai)`"""
260+
sid = _get_merge_tag_id(merge_tags, op.tags)
261+
if sid is None:
295262
return op
296-
if len(the_merge_tag) > 1:
297-
raise RuntimeError("Multiple merge tags found.")
298-
sid = the_merge_tag.pop().split("_")[-1]
299263
phxz_params = {
300264
"x_exponent": sympy.Symbol(f"x{sid}"),
301265
"z_exponent": sympy.Symbol(f"z{sid}"),
302266
"axis_phase_exponent": sympy.Symbol(f"a{sid}"),
303267
}
304268
return ops.PhasedXZGate(**phxz_params).on(*op.qubits)
305269

306-
output_circuit: 'cirq.Circuit' = align.align_right(
307-
transformer_primitives.map_operations(merged_circuits[0].freeze(), _map_func, deep=deep)
270+
return align.align_right(
271+
transformer_primitives.map_operations(circuit.freeze(), _map_func, deep=deep)
308272
)
309273

274+
275+
def _parameterize_merged_circuits(
276+
merged_circuits: List['cirq.Circuit'],
277+
merge_tags: frozenset[str],
278+
new_symbols: frozenset[str],
279+
remaining_symbols: frozenset[str],
280+
sweep: Sweep,
281+
) -> Sweep:
282+
"""Parameterizes the merged circuits and returns a new sweep."""
310283
values_by_params: Dict[str, List[float]] = {
311-
**{s: [] for s in new_symbols}, # New symbols introduced in merging
284+
**{s: [] for s in new_symbols}, # New symbols introduced during merging
312285
**{
313286
s: _values_of_sweep(sweep, s) for s in remaining_symbols
314-
}, # Existing symbols in ops that are not merged, e.g., symbols in 2 qubit gates.
287+
}, # Existing symbols in ops that were not merged, e.g., symbols in 2-qubit gates.
315288
}
316289

317-
# Get parameterization for the merged phxz gates.
318290
for merged_circuit in merged_circuits:
319291
for op in merged_circuit.all_operations():
320-
the_merge_tag = merge_tags.intersection(op.tags)
321-
if len(the_merge_tag) == 0:
292+
sid = _get_merge_tag_id(merge_tags, op.tags)
293+
if sid is None:
322294
continue
323-
if len(the_merge_tag) > 1:
324-
raise RuntimeError("Multiple merge tags found.")
325-
sid = the_merge_tag.pop().split("_")[-1]
326-
x, z, a = 0.0, 0.0, 0.0 # Identity gate's parameters.
295+
x, z, a = 0.0, 0.0, 0.0 # Identity gate's parameters
327296
if isinstance(op.gate, ops.PhasedXZGate):
328297
x, z, a = op.gate.x_exponent, op.gate.z_exponent, op.gate.axis_phase_exponent
329298
elif op.gate is not ops.I:
330299
raise RuntimeError(
331-
f"Expect the merged gate to be a PhasedXZGate or IdentityGate. But got {op.gate}."
300+
f"Expected the merged gate to be a PhasedXZGate or IdentityGate,"
301+
f" but got {op.gate}."
332302
)
333303
values_by_params[f"x{sid}"].append(x)
334304
values_by_params[f"z{sid}"].append(z)
335305
values_by_params[f"a{sid}"].append(a)
336306

337-
new_sweep: Sweep = Zip(
338-
*[Points(key=key, points=values) for key, values in values_by_params.items()]
307+
return Zip(*[Points(key=key, points=values) for key, values in values_by_params.items()])
308+
309+
310+
def merge_single_qubit_gates_to_phxz_symbolized(
311+
circuit: 'cirq.AbstractCircuit',
312+
*,
313+
context: Optional['cirq.TransformerContext'] = None,
314+
sweep: Sweep,
315+
atol: float = 1e-8,
316+
) -> Tuple['cirq.Circuit', Sweep]:
317+
"""Merge consecutive single qubit gates as PhasedXZ Gates. Symbolize if any of the consecutive
318+
gates is symbolized.
319+
320+
Example:
321+
>>> q0, q1 = cirq.LineQubit.range(2)
322+
>>> c = cirq.Circuit(\
323+
cirq.X(q0),\
324+
cirq.CZ(q0,q1)**sympy.Symbol("cz_exp"),\
325+
cirq.Y(q0)**sympy.Symbol("y_exp"),\
326+
cirq.X(q0))
327+
>>> print(c)
328+
0: ───X───@──────────Y^y_exp───X───
329+
330+
1: ───────@^cz_exp─────────────────
331+
>>> new_circuit, new_sweep = cirq.merge_single_qubit_gates_to_phxz_symbolized(\
332+
c, sweep=cirq.Zip(cirq.Points(key="cz_exp", points=[0, 1]),\
333+
cirq.Points(key="y_exp", points=[0, 1])))
334+
>>> print(new_circuit)
335+
0: ───PhXZ(a=-1,x=1,z=0)───@──────────PhXZ(a=a0,x=x0,z=z0)───
336+
337+
1: ────────────────────────@^cz_exp──────────────────────────
338+
>>> assert new_sweep[0] == cirq.ParamResolver({'a0': -1, 'x0': 1, 'z0': 0, 'cz_exp': 0})
339+
>>> assert new_sweep[1] == cirq.ParamResolver({'a0': -0.5, 'x0': 0, 'z0': -1, 'cz_exp': 1})
340+
341+
Args:
342+
circuit: Input circuit to transform. It will not be modified.
343+
sweep: Sweep of the symbols in the input circuit, updated Sweep will be returned
344+
based on the transformation.
345+
context: `cirq.TransformerContext` storing common configurable options for transformers.
346+
atol: Absolute tolerance to angle error. Larger values allow more negligible gates to be
347+
dropped, smaller values increase accuracy.
348+
349+
Returns:
350+
Copy of the transformed input circuit.
351+
"""
352+
deep = context.deep if context else False
353+
354+
# Tag symbolized single-qubit op.
355+
symbolized_single_tag = "_symbolized_single"
356+
357+
circuit_tagged = transformer_primitives.map_operations(
358+
circuit,
359+
lambda op, _: (
360+
op.with_tags(symbolized_single_tag)
361+
if protocols.is_parameterized(op) and len(op.qubits) == 1
362+
else op
363+
),
364+
deep=deep,
365+
)
366+
367+
# Step 0, isolate single qubit symbolized symbols and resolve the circuit on them.
368+
369+
single_qubit_gate_symbols: frozenset[sympy.Symbol] = frozenset(
370+
set().union(
371+
*[
372+
protocols.parameter_symbols(op) if symbolized_single_tag in op.tags else set()
373+
for op in circuit_tagged.all_operations()
374+
]
375+
)
339376
)
377+
# If all single qubit gates are not parameterized, call the nonparamerized version of
378+
# the transformer.
379+
if not single_qubit_gate_symbols:
380+
return merge_single_qubit_gates_to_phxz(circuit, context=context, atol=atol), sweep
381+
# Remaining symbols, e.g., 2 qubit gates' symbols. Sweep of those symbols keeps unchanged.
382+
remaining_symbols: frozenset[sympy.Symbol] = frozenset(
383+
protocols.parameter_symbols(circuit) - single_qubit_gate_symbols
384+
)
385+
sweep_of_single: Sweep = Zip(
386+
*[Points(key=k, points=_values_of_sweep(sweep, k)) for k in single_qubit_gate_symbols]
387+
)
388+
# Get all resolved circuits from all sets of resolvers in the sweep.
389+
resolved_circuits = [
390+
protocols.resolve_parameters(circuit_tagged, resolver) for resolver in sweep_of_single
391+
]
392+
393+
# Step 1, merge single qubit gates of resolved circuits using merge_k_qubit_unitaries.
394+
merged_circuits, merge_tags, new_symbols = _merge_single_qubit_gates_to_circuit_op_symbolized(
395+
resolved_circuits, symbolized_single_tag, context, atol
396+
)
397+
398+
# Step 2, get the new symbolzied circuit as new_sweep by mapping merged operations.
399+
new_circuit = _map_merged_ops_to_symbolized_phxz(merged_circuits[0], merge_tags, deep)
400+
401+
# Step 3, get N sets of parameterizations as new_sweep.
402+
new_sweep = _parameterize_merged_circuits(
403+
merged_circuits, merge_tags, new_symbols, remaining_symbols, sweep
404+
)
405+
406+
return new_circuit.unfreeze(copy=False), new_sweep
407+
340408

341-
return output_circuit.unfreeze(copy=False), new_sweep
409+
# ----------------------------------------------------------------------
410+
# Impl merge_single_qubit_gates_to_phxz_symbolized: End
411+
# ----------------------------------------------------------------------

0 commit comments

Comments
 (0)