Skip to content

Commit 33a8233

Browse files
committed
fix checks.
1 parent 99bd54e commit 33a8233

File tree

3 files changed

+288
-124
lines changed

3 files changed

+288
-124
lines changed

cirq-core/cirq/transformers/gauge_compiling/gauge_compiling.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,11 @@
2525
import sympy
2626
from attrs import field, frozen
2727

28-
from cirq.transformers import transformer_api
29-
from cirq import ops, circuits
28+
from cirq import circuits, ops
3029
from cirq.protocols import unitary_protocol
3130
from cirq.protocols.has_unitary_protocol import has_unitary
3231
from cirq.study.sweeps import Points, Sweep, Zip
32+
from cirq.transformers import transformer_api
3333
from cirq.transformers.analytical_decompositions import single_qubit_decompositions
3434

3535

cirq-core/cirq/transformers/merge_single_qubit_gates.py

+183-114
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,13 @@
1515
"""Transformer passes to combine adjacent single-qubit rotations."""
1616

1717
import itertools
18-
import warnings
19-
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
18+
from typing import Dict, Hashable, List, Optional, Tuple, TYPE_CHECKING
2019

2120
import sympy
2221

2322
from cirq import circuits, ops, protocols
2423
from cirq.study.sweeps import Points, Sweep, Zip
25-
from cirq.transformers import merge_k_qubit_gates, transformer_api, transformer_primitives, align
24+
from cirq.transformers import align, merge_k_qubit_gates, transformer_api, transformer_primitives
2625
from cirq.transformers.analytical_decompositions import single_qubit_decompositions
2726

2827
if TYPE_CHECKING:
@@ -159,115 +158,52 @@ def merge_func(m1: 'cirq.Moment', m2: 'cirq.Moment') -> Optional['cirq.Moment']:
159158
).unfreeze(copy=False)
160159

161160

161+
# ----------------------------------------------------------------------
162+
# Impl merge_single_qubit_gates_to_phxz_symbolized: Start
163+
# ----------------------------------------------------------------------
164+
165+
162166
def _values_of_sweep(sweep: Sweep, key: str | sympy.Symbol):
163167
p = sympy.Symbol(key) if isinstance(key, str) else key
164168
return [resolver.value_of(p) for resolver in sweep]
165169

166170

167-
@transformer_api.transformer
168-
def merge_single_qubit_gates_to_phxz_symbolized(
169-
circuit: 'cirq.AbstractCircuit',
170-
*,
171-
context: Optional['cirq.TransformerContext'] = None,
172-
sweep: Sweep,
173-
atol: float = 1e-8,
174-
) -> Tuple['cirq.Circuit', Sweep]:
175-
"""Merge consecutive single qubit gates as PhasedXZ Gates. Symbolize if any of the consecutive gates is symbolized.
176-
177-
Example:
178-
# pylint: disable=line-too-long
179-
>>> q0, q1 = cirq.LineQubit.range(2)
180-
>>> c = cirq.Circuit(cirq.X(q0),cirq.CZ(q0,q1)**sympy.Symbol("cz_exp"),cirq.Y(q0)**sympy.Symbol("y_exp"),cirq.X(q0))
181-
>>> print(c)
182-
0: ───X───@──────────Y^y_exp───X───
183-
184-
1: ───────@^cz_exp─────────────────
185-
>>> new_circuit, new_sweep = cirq.merge_single_qubit_gates_to_phxz_symbolized(\
186-
c, sweep=cirq.Points(key="cz_exp", points=[0, 1]) * cirq.Points(key="y_exp", points=[0, 1])\
187-
)
188-
>>> print(new_circuit)
189-
0: ───PhXZ(a=-1,x=1,z=0)───@──────────PhXZ(a=a0,x=x0,z=z0)───
190-
191-
1: ────────────────────────@^cz_exp──────────────────────────
192-
>>> print(new_sweep)
193-
cirq.Points('z0', [0, -1.0, 0, -1.0]) + cirq.Points('x0', [1, 0.0, 1, 0.0]) + cirq.Points('a0', [-1.0, -0.5, -1.0, -0.5]) + cirq.Points('cz_exp', [0, 0, 1, 1])
194-
# pylint: disable=line-too-long
171+
def _merge_single_qubit_gates_to_circuit_op_symbolized(
172+
resolved_circuits: List['cirq.AbstractCircuit'],
173+
symbolized_single_tag: str,
174+
context: Optional['cirq.TransformerContext'],
175+
atol: float,
176+
) -> Tuple[List['cirq.Circuit'], frozenset[str], frozenset[str]]:
177+
"""Helper function to merge single qubit gates of resolved circuits to op of CircuitOperation type
178+
using merge_k_qubit_unitaries.
195179
196180
Args:
197-
circuit: Input circuit to transform. It will not be modified.
198-
sweep: Sweep of the symbols in the input circuit, updated Sweep will be returned
199-
based on the transformation.
200-
context: `cirq.TransformerContext` storing common configurable options for transformers.
201-
atol: Absolute tolerance to angle error. Larger values allow more negligible gates to be
202-
dropped, smaller values increase accuracy.
181+
resolved_circuits: A list of circuits where symbols have been replaced with concrete values.
182+
symbolized_single_tag: The tag applied to single-qubit operations that originally contained symbols before parameterization.
203183
204184
Returns:
205-
Copy of the transformed input circuit.
185+
Tuple of merge counts, merged circuits, and merge tags.
206186
"""
207-
deep = context.deep if context else False
208-
209-
if not protocols.is_parameterized(circuit):
210-
warnings.warn(
211-
"Expect parameterized circuits. "
212-
"Please use cirq.merge_single_qubit_gates_to_phxz instead.",
213-
UserWarning,
214-
)
215-
return merge_single_qubit_gates_to_phxz(circuit, context=context, atol=atol)
216-
217-
# Tag symbolized single qubit op.
218-
symbolized_single_tag = "_symbolized_single"
219-
220-
circuit_tagged = transformer_primitives.map_operations(
221-
circuit,
222-
lambda op, _: (
223-
op.with_tags(symbolized_single_tag)
224-
if protocols.is_parameterized(op) and len(op.qubits) == 1
225-
else op
226-
),
227-
deep=deep,
228-
)
229-
230-
# Symbols of the single qubit symbolized ops.
231-
single_qubit_gate_symbols: set[sympy.Symbol] = set().union(
232-
*[
233-
protocols.parameter_symbols(op) if symbolized_single_tag in op.tags else set()
234-
for op in circuit_tagged.all_operations()
235-
]
236-
)
237-
# Remaing symbols, e.g., 2 qubit gates' symbols. Sweep of those symbols keeps unchanged.
238-
remaining_symbols = protocols.parameter_symbols(circuit) - single_qubit_gate_symbols
239-
240-
sweep_of_single: Sweep = Zip(
241-
*[Points(key=k, points=_values_of_sweep(sweep, k)) for k in single_qubit_gate_symbols]
242-
)
243-
244-
# Get all resolved circuits from all sets of resolvers in sweep.
245-
resolved_circuits = [
246-
protocols.resolve_parameters(circuit_tagged, resolver) for resolver in sweep_of_single
247-
]
248-
249-
# Store the number of merges for all set of resolvers,
250-
# it should be the same for all resolved circuits.
251-
merge_counts: list[int] = []
252-
merged_circuits = []
253-
phxz_tag_prefix = "_phxz"
187+
merge_counts: list[int] = [] # number of merges per resolved_circuit
188+
merged_circuits: list['cirq.Circuit'] = []
254189
tag_iter: itertools.count
190+
phxz_tag_prefix = "_phxz"
255191

256192
def rewriter(circuit_op: 'cirq.CircuitOperation') -> 'cirq.OP_TREE':
257193
nonlocal tag_iter
258194
tag: Optional[str] = None
195+
259196
u = protocols.unitary(circuit_op)
260197
if protocols.num_qubits(circuit_op) == 0:
261198
return ops.GlobalPhaseGate(u[0, 0]).on()
199+
# If any of the op in the merged circuit_op is a symbolized single qubit gate,
200+
# tag the merged phxz gate with next tag id, for further parameterization references.
262201
for op in circuit_op.circuit.all_operations():
263202
if symbolized_single_tag in op.tags:
264-
# Record parameterizations info via tags.
265203
tag = f"{phxz_tag_prefix}_{next(tag_iter)}"
266204
break
267205
gate = single_qubit_decompositions.single_qubit_matrix_to_phxz(u, atol) or ops.I
268206
op = gate.on(circuit_op.qubits[0])
269-
if not gate:
270-
return []
271207
return op.with_tags(tag) if tag else op
272208

273209
for resolved_circuit in resolved_circuits:
@@ -280,62 +216,195 @@ def rewriter(circuit_op: 'cirq.CircuitOperation') -> 'cirq.OP_TREE':
280216
merge_counts.append(next(tag_iter))
281217

282218
if not all(count == merge_counts[0] for count in merge_counts):
283-
raise RuntimeError("Different resolvers in sweep result different merged strcuture.")
219+
raise RuntimeError("Different resolvers in sweep resulted in different merged structures.")
284220

285-
# Get the output circuit from the first resolved circuits.
286-
merge_tags: set[str] = {f"{phxz_tag_prefix}_{i}" for i in range(merge_counts[0])}
287-
new_symbols: set[str] = set().union(
288-
*[{f"x{i}", f"z{i}", f"a{i}"} for i in range(merge_counts[0])]
221+
merge_tags: frozenset[str] = frozenset(
222+
{f"{phxz_tag_prefix}_{i}" for i in range(merge_counts[0])}
223+
)
224+
new_symbols: frozenset[str] = frozenset(
225+
set().union(*[{f"x{i}", f"z{i}", f"a{i}"} for i in range(merge_counts[0])])
289226
)
290227

228+
return merged_circuits, merge_tags, new_symbols
229+
230+
231+
def _get_merge_tag_id(merge_tags: frozenset[str], op_tags: Tuple[Hashable, ...]) -> Optional[str]:
232+
"""Extract the id `i` from the merge tag `_phxz_i` if it exists."""
233+
the_merge_tag: set[str] = set(merge_tags.intersection(op_tags))
234+
if len(the_merge_tag) == 0:
235+
return None
236+
if len(the_merge_tag) > 1:
237+
raise RuntimeError("Multiple merge tags found.")
238+
return the_merge_tag.pop().split("_")[-1]
239+
240+
241+
def _map_merged_ops_to_symbolized_phxz(
242+
circuit: 'cirq.Circuit', merge_tags: frozenset[str], deep: bool
243+
) -> 'cirq.Circuit':
244+
"""Maps merged operations (tagged with merge_tags) in the circuit to symbolized PhasedXZGates.
245+
246+
Args:
247+
circuit: Circuit with merge tags to be mapped.
248+
merge_tags: The set of tags used to identify the merged PhasedXZ gates that need to be
249+
symbolized.
250+
deep: Whether to perform the mapping recursively within CircuitOperations.
251+
252+
Returns:
253+
A new circuit where tagged PhasedXZ gates are replaced by symbolized versions.
254+
"""
255+
256+
# Map merged ops to `PhasedXZGate(xi,zi,ai)` based on the tag "_phxz_i".
291257
def _map_func(op: 'cirq.Operation', _):
292-
"""Maps op with tag `_phxz_i` to a symbolzied `PhasedXZGate(xi,zi,ai)`"""
293-
the_merge_tag = merge_tags.intersection(op.tags)
294-
if len(the_merge_tag) == 0:
258+
"""Maps an op with tag `_phxz_i` to a symbolzied `PhasedXZGate(xi,zi,ai)`"""
259+
sid = _get_merge_tag_id(merge_tags, op.tags)
260+
if sid is None:
295261
return op
296-
if len(the_merge_tag) > 1:
297-
raise RuntimeError("Multiple merge tags found.")
298-
sid = the_merge_tag.pop().split("_")[-1]
299262
phxz_params = {
300263
"x_exponent": sympy.Symbol(f"x{sid}"),
301264
"z_exponent": sympy.Symbol(f"z{sid}"),
302265
"axis_phase_exponent": sympy.Symbol(f"a{sid}"),
303266
}
304267
return ops.PhasedXZGate(**phxz_params).on(*op.qubits)
305268

306-
output_circuit: 'cirq.Circuit' = align.align_right(
307-
transformer_primitives.map_operations(merged_circuits[0].freeze(), _map_func, deep=deep)
269+
return align.align_right(
270+
transformer_primitives.map_operations(circuit.freeze(), _map_func, deep=deep)
308271
)
309272

273+
274+
def _parameterize_merged_circuits(
275+
merged_circuits: List['cirq.Circuit'],
276+
merge_tags: frozenset[str],
277+
new_symbols: frozenset[str],
278+
remaining_symbols: frozenset[str],
279+
sweep: Sweep,
280+
) -> Sweep:
281+
"""Parameterizes the merged circuits and returns a new sweep."""
310282
values_by_params: Dict[str, List[float]] = {
311-
**{s: [] for s in new_symbols}, # New symbols introduced in merging
283+
**{s: [] for s in new_symbols}, # New symbols introduced during merging
312284
**{
313285
s: _values_of_sweep(sweep, s) for s in remaining_symbols
314-
}, # Existing symbols in ops that are not merged, e.g., symbols in 2 qubit gates.
286+
}, # Existing symbols in ops that were not merged, e.g., symbols in 2-qubit gates.
315287
}
316288

317-
# Get parameterization for the merged phxz gates.
318289
for merged_circuit in merged_circuits:
319290
for op in merged_circuit.all_operations():
320-
the_merge_tag = merge_tags.intersection(op.tags)
321-
if len(the_merge_tag) == 0:
291+
sid = _get_merge_tag_id(merge_tags, op.tags)
292+
if sid is None:
322293
continue
323-
if len(the_merge_tag) > 1:
324-
raise RuntimeError("Multiple merge tags found.")
325-
sid = the_merge_tag.pop().split("_")[-1]
326-
x, z, a = 0.0, 0.0, 0.0 # Identity gate's parameters.
294+
x, z, a = 0.0, 0.0, 0.0 # Identity gate's parameters
327295
if isinstance(op.gate, ops.PhasedXZGate):
328296
x, z, a = op.gate.x_exponent, op.gate.z_exponent, op.gate.axis_phase_exponent
329297
elif op.gate is not ops.I:
330298
raise RuntimeError(
331-
f"Expect the merged gate to be a PhasedXZGate or IdentityGate. But got {op.gate}."
299+
f"Expected the merged gate to be a PhasedXZGate or IdentityGate,"
300+
f" but got {op.gate}."
332301
)
333302
values_by_params[f"x{sid}"].append(x)
334303
values_by_params[f"z{sid}"].append(z)
335304
values_by_params[f"a{sid}"].append(a)
336305

337-
new_sweep: Sweep = Zip(
338-
*[Points(key=key, points=values) for key, values in values_by_params.items()]
306+
return Zip(*[Points(key=key, points=values) for key, values in values_by_params.items()])
307+
308+
309+
def merge_single_qubit_gates_to_phxz_symbolized(
310+
circuit: 'cirq.AbstractCircuit',
311+
*,
312+
context: Optional['cirq.TransformerContext'] = None,
313+
sweep: Sweep,
314+
atol: float = 1e-8,
315+
) -> Tuple['cirq.Circuit', Sweep]:
316+
"""Merge consecutive single qubit gates as PhasedXZ Gates. Symbolize if any of the consecutive
317+
gates is symbolized.
318+
319+
Example:
320+
>>> q0, q1 = cirq.LineQubit.range(2)
321+
>>> c = cirq.Circuit(\
322+
cirq.X(q0),\
323+
cirq.CZ(q0,q1)**sympy.Symbol("cz_exp"),\
324+
cirq.Y(q0)**sympy.Symbol("y_exp"),\
325+
cirq.X(q0))
326+
>>> print(c)
327+
0: ───X───@──────────Y^y_exp───X───
328+
329+
1: ───────@^cz_exp─────────────────
330+
>>> new_circuit, new_sweep = cirq.merge_single_qubit_gates_to_phxz_symbolized(\
331+
c, sweep=cirq.Zip(cirq.Points(key="cz_exp", points=[0, 1]),\
332+
cirq.Points(key="y_exp", points=[0, 1])))
333+
>>> print(new_circuit)
334+
0: ───PhXZ(a=-1,x=1,z=0)───@──────────PhXZ(a=a0,x=x0,z=z0)───
335+
336+
1: ────────────────────────@^cz_exp──────────────────────────
337+
>>> assert new_sweep[0] == cirq.ParamResolver({'a0': -1, 'x0': 1, 'z0': 0, 'cz_exp': 0})
338+
>>> assert new_sweep[1] == cirq.ParamResolver({'a0': -0.5, 'x0': 0, 'z0': -1, 'cz_exp': 1})
339+
340+
Args:
341+
circuit: Input circuit to transform. It will not be modified.
342+
sweep: Sweep of the symbols in the input circuit, updated Sweep will be returned
343+
based on the transformation.
344+
context: `cirq.TransformerContext` storing common configurable options for transformers.
345+
atol: Absolute tolerance to angle error. Larger values allow more negligible gates to be
346+
dropped, smaller values increase accuracy.
347+
348+
Returns:
349+
Copy of the transformed input circuit.
350+
"""
351+
deep = context.deep if context else False
352+
353+
# Tag symbolized single-qubit op.
354+
symbolized_single_tag = "_symbolized_single"
355+
356+
circuit_tagged = transformer_primitives.map_operations(
357+
circuit,
358+
lambda op, _: (
359+
op.with_tags(symbolized_single_tag)
360+
if protocols.is_parameterized(op) and len(op.qubits) == 1
361+
else op
362+
),
363+
deep=deep,
364+
)
365+
366+
# Step 0, isolate single qubit symbolized symbols and resolve the circuit on them.
367+
368+
single_qubit_gate_symbols: frozenset[sympy.Symbol] = frozenset(
369+
set().union(
370+
*[
371+
protocols.parameter_symbols(op) if symbolized_single_tag in op.tags else set()
372+
for op in circuit_tagged.all_operations()
373+
]
374+
)
339375
)
376+
# If all single qubit gates are not parameterized, call the nonparamerized version of
377+
# the transformer.
378+
if not single_qubit_gate_symbols:
379+
return merge_single_qubit_gates_to_phxz(circuit, context=context, atol=atol), sweep
380+
# Remaining symbols, e.g., 2 qubit gates' symbols. Sweep of those symbols keeps unchanged.
381+
remaining_symbols: frozenset[sympy.Symbol] = frozenset(
382+
protocols.parameter_symbols(circuit) - single_qubit_gate_symbols
383+
)
384+
sweep_of_single: Sweep = Zip(
385+
*[Points(key=k, points=_values_of_sweep(sweep, k)) for k in single_qubit_gate_symbols]
386+
)
387+
# Get all resolved circuits from all sets of resolvers in the sweep.
388+
resolved_circuits = [
389+
protocols.resolve_parameters(circuit_tagged, resolver) for resolver in sweep_of_single
390+
]
391+
392+
# Step 1, merge single qubit gates of resolved circuits using merge_k_qubit_unitaries.
393+
merged_circuits, merge_tags, new_symbols = _merge_single_qubit_gates_to_circuit_op_symbolized(
394+
resolved_circuits, symbolized_single_tag, context, atol
395+
)
396+
397+
# Step 2, get the new symbolzied circuit as new_sweep by mapping merged operations.
398+
new_circuit = _map_merged_ops_to_symbolized_phxz(merged_circuits[0], merge_tags, deep)
399+
400+
# Step 3, get N sets of parameterizations as new_sweep.
401+
new_sweep = _parameterize_merged_circuits(
402+
merged_circuits, merge_tags, new_symbols, remaining_symbols, sweep
403+
)
404+
405+
return new_circuit.unfreeze(copy=False), new_sweep
406+
340407

341-
return output_circuit.unfreeze(copy=False), new_sweep
408+
# ----------------------------------------------------------------------
409+
# Impl merge_single_qubit_gates_to_phxz_symbolized: End
410+
# ----------------------------------------------------------------------

0 commit comments

Comments
 (0)