Skip to content

Commit ecb6a83

Browse files
committed
Expose more transformers
tag_transformers: remove_tags, index_tags. symbolize: symbolize_single_qubit_gates_by_indexed_tags
1 parent 3c46507 commit ecb6a83

8 files changed

+420
-194
lines changed

cirq-core/cirq/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,7 @@
363363
eject_z as eject_z,
364364
expand_composite as expand_composite,
365365
HardCodedInitialMapper as HardCodedInitialMapper,
366+
index_tags as index_tags,
366367
is_negligible_turn as is_negligible_turn,
367368
LineInitialMapper as LineInitialMapper,
368369
MappingManager as MappingManager,
@@ -379,13 +380,15 @@
379380
merge_single_qubit_gates_to_phxz as merge_single_qubit_gates_to_phxz,
380381
merge_single_qubit_gates_to_phxz_symbolized as merge_single_qubit_gates_to_phxz_symbolized,
381382
merge_single_qubit_moments_to_phxz as merge_single_qubit_moments_to_phxz,
383+
symbolize_single_qubit_gates_by_indexed_tags as symbolize_single_qubit_gates_by_indexed_tags,
382384
optimize_for_target_gateset as optimize_for_target_gateset,
383385
parameterized_2q_op_to_sqrt_iswap_operations as parameterized_2q_op_to_sqrt_iswap_operations,
384386
prepare_two_qubit_state_using_cz as prepare_two_qubit_state_using_cz,
385387
prepare_two_qubit_state_using_iswap as prepare_two_qubit_state_using_iswap,
386388
prepare_two_qubit_state_using_sqrt_iswap as prepare_two_qubit_state_using_sqrt_iswap,
387389
quantum_shannon_decomposition as quantum_shannon_decomposition,
388390
RouteCQC as RouteCQC,
391+
remove_tags as remove_tags,
389392
routed_circuit_with_mapping as routed_circuit_with_mapping,
390393
SqrtIswapTargetGateset as SqrtIswapTargetGateset,
391394
single_qubit_matrix_to_gates as single_qubit_matrix_to_gates,

cirq-core/cirq/transformers/__init__.py

+6
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,12 @@
104104
merge_single_qubit_moments_to_phxz as merge_single_qubit_moments_to_phxz,
105105
)
106106

107+
from cirq.transformers.tag_transformers import index_tags as index_tags, remove_tags as remove_tags
108+
from cirq.transformers.symbolize import (
109+
symbolize_single_qubit_gates_by_indexed_tags as symbolize_single_qubit_gates_by_indexed_tags,
110+
)
111+
112+
107113
from cirq.transformers.qubit_management_transformers import (
108114
map_clean_and_borrowable_qubits as map_clean_and_borrowable_qubits,
109115
)

cirq-core/cirq/transformers/merge_single_qubit_gates.py

+108-153
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,22 @@
1414

1515
"""Transformer passes to combine adjacent single-qubit rotations."""
1616

17-
import itertools
18-
from typing import Dict, Hashable, List, Optional, Tuple, TYPE_CHECKING
17+
from typing import Callable, Dict, Hashable, List, Optional, Tuple, TYPE_CHECKING
1918

2019
import sympy
2120

2221
from cirq import circuits, ops, protocols
22+
from cirq.study.result import TMeasurementKey
2323
from cirq.study.sweeps import Points, Sweep, Zip
24-
from cirq.transformers import align, merge_k_qubit_gates, transformer_api, transformer_primitives
24+
from cirq.transformers import (
25+
align,
26+
merge_k_qubit_gates,
27+
symbolize,
28+
transformer_api,
29+
transformer_primitives,
30+
)
2531
from cirq.transformers.analytical_decompositions import single_qubit_decompositions
32+
from cirq.transformers.tag_transformers import index_tags, remove_tags
2633

2734
if TYPE_CHECKING:
2835
import cirq
@@ -69,6 +76,7 @@ def merge_single_qubit_gates_to_phxz(
6976
circuit: 'cirq.AbstractCircuit',
7077
*,
7178
context: Optional['cirq.TransformerContext'] = None,
79+
merge_tags_fn: Optional[Callable[['cirq.CircuitOperation'], List[Hashable]]] = None,
7280
atol: float = 1e-8,
7381
) -> 'cirq.Circuit':
7482
"""Replaces runs of single qubit rotations with a single optional `cirq.PhasedXZGate`.
@@ -79,19 +87,24 @@ def merge_single_qubit_gates_to_phxz(
7987
Args:
8088
circuit: Input circuit to transform. It will not be modified.
8189
context: `cirq.TransformerContext` storing common configurable options for transformers.
90+
merge_tag: If provided, tag merged PhXZ gate with it.
91+
merge_tags_fn: A callable returns the tags to be added to the merged operation.
8292
atol: Absolute tolerance to angle error. Larger values allow more negligible gates to be
8393
dropped, smaller values increase accuracy.
8494
8595
Returns:
8696
Copy of the transformed input circuit.
8797
"""
8898

89-
def rewriter(op: 'cirq.CircuitOperation') -> 'cirq.OP_TREE':
90-
u = protocols.unitary(op)
91-
if protocols.num_qubits(op) == 0:
99+
def rewriter(circuit_op: 'cirq.CircuitOperation') -> 'cirq.OP_TREE':
100+
101+
u = protocols.unitary(circuit_op)
102+
if protocols.num_qubits(circuit_op) == 0:
92103
return ops.GlobalPhaseGate(u[0, 0]).on()
93-
gate = single_qubit_decompositions.single_qubit_matrix_to_phxz(u, atol)
94-
return gate(op.qubits[0]) if gate else []
104+
105+
gate = single_qubit_decompositions.single_qubit_matrix_to_phxz(u, atol) or ops.I
106+
phxz_op = gate.on(circuit_op.qubits[0])
107+
return phxz_op.with_tags(*merge_tags_fn(circuit_op)) if merge_tags_fn else phxz_op
95108

96109
return merge_k_qubit_gates.merge_k_qubit_unitaries(
97110
circuit, k=1, context=context, rewriter=rewriter
@@ -158,141 +171,33 @@ def merge_func(m1: 'cirq.Moment', m2: 'cirq.Moment') -> Optional['cirq.Moment']:
158171
).unfreeze(copy=False)
159172

160173

161-
# ----------------------------------------------------------------------
162-
# Impl merge_single_qubit_gates_to_phxz_symbolized: Start
163-
# ----------------------------------------------------------------------
164-
165-
166-
def _values_of_sweep(sweep: Sweep, key: str | sympy.Symbol):
174+
def _values_of_sweep(sweep: Sweep, key: TMeasurementKey):
167175
p = sympy.Symbol(key) if isinstance(key, str) else key
168176
return [resolver.value_of(p) for resolver in sweep]
169177

170178

171-
def _merge_single_qubit_gates_to_phxz_symbolized(
172-
resolved_circuits: List['cirq.AbstractCircuit'],
173-
symbolized_single_tag: str,
174-
context: Optional['cirq.TransformerContext'],
175-
atol: float,
176-
) -> Tuple[List['cirq.Circuit'], frozenset[str], frozenset[str]]:
177-
"""Helper function to merge single qubit ops of resolved circuits to PhasedXZ ops
178-
using merge_k_qubit_unitaries.
179-
180-
Args:
181-
resolved_circuits: A list of circuits where symbols have been replaced with concrete values.
182-
symbolized_single_tag: The tag applied to single-qubit operations that originally
183-
contained symbols before parameterizations.
184-
185-
Returns:
186-
Tuple of merge_counts, merged_circuits, and merge_tags, where
187-
merged ops in merged_circuits are tagged by merge_tags.
188-
"""
189-
merge_counts: list[int] = [] # number of merges per resolved_circuit
190-
merged_circuits: list['cirq.Circuit'] = []
191-
tag_iter: itertools.count
192-
phxz_tag_prefix = "_phxz"
193-
194-
def rewriter(circuit_op: 'cirq.CircuitOperation') -> 'cirq.OP_TREE':
195-
nonlocal tag_iter
196-
tag: Optional[str] = None
197-
198-
u = protocols.unitary(circuit_op)
199-
if protocols.num_qubits(circuit_op) == 0:
200-
return ops.GlobalPhaseGate(u[0, 0]).on()
201-
# If any of the op in the merged circuit_op is a symbolized single qubit gate,
202-
# tag the merged phxz gate with next tag id, for further parameterization references.
203-
for op in circuit_op.circuit.all_operations():
204-
if symbolized_single_tag in op.tags:
205-
tag = f"{phxz_tag_prefix}_{next(tag_iter)}"
206-
break
207-
gate = single_qubit_decompositions.single_qubit_matrix_to_phxz(u, atol) or ops.I
208-
op = gate.on(circuit_op.qubits[0])
209-
return op.with_tags(tag) if tag else op
210-
211-
for resolved_circuit in resolved_circuits:
212-
tag_iter = itertools.count(start=0, step=1)
213-
merged_circuits.append(
214-
merge_k_qubit_gates.merge_k_qubit_unitaries(
215-
resolved_circuit, k=1, context=context, rewriter=rewriter
216-
)
217-
)
218-
merge_counts.append(next(tag_iter))
219-
220-
if not all(count == merge_counts[0] for count in merge_counts):
221-
raise RuntimeError("Different resolvers in sweep resulted in different merged structures.")
222-
223-
merge_tags: frozenset[str] = frozenset(
224-
{f"{phxz_tag_prefix}_{i}" for i in range(merge_counts[0])}
225-
)
226-
new_symbols: frozenset[str] = frozenset(
227-
set().union(*[{f"x{i}", f"z{i}", f"a{i}"} for i in range(merge_counts[0])])
228-
)
229-
230-
return merged_circuits, merge_tags, new_symbols
231-
232-
233-
def _get_merge_tag_id(merge_tags: frozenset[str], op_tags: Tuple[Hashable, ...]) -> Optional[str]:
234-
"""Extract the id `i` from the merge tag `_phxz_i` if it exists."""
235-
the_merge_tag: set[str] = set(merge_tags.intersection(op_tags))
236-
if len(the_merge_tag) == 0:
237-
return None
238-
if len(the_merge_tag) > 1:
239-
raise RuntimeError("Multiple merge tags found.")
240-
return the_merge_tag.pop().split("_")[-1]
241-
242-
243-
def _map_merged_ops_to_symbolized_phxz(
244-
circuit: 'cirq.Circuit', merge_tags: frozenset[str], deep: bool
245-
) -> 'cirq.Circuit':
246-
"""Maps merged operations (tagged with merge_tags) in the circuit to symbolized PhasedXZGates.
247-
248-
Args:
249-
circuit: Circuit with merge tags to be mapped.
250-
merge_tags: The set of tags used to identify the merged PhasedXZ gates that need to be
251-
symbolized.
252-
deep: Whether to perform the mapping recursively within CircuitOperations.
253-
254-
Returns:
255-
A new circuit where tagged PhasedXZ gates are replaced by symbolized versions.
256-
"""
257-
258-
# Map merged ops to `PhasedXZGate(xi,zi,ai)` based on the tag "_phxz_i".
259-
def _map_func(op: 'cirq.Operation', _):
260-
"""Maps an op with tag `_phxz_i` to a symbolzied `PhasedXZGate(xi,zi,ai)`"""
261-
sid = _get_merge_tag_id(merge_tags, op.tags)
262-
if sid is None:
263-
return op
264-
phxz_params = {
265-
"x_exponent": sympy.Symbol(f"x{sid}"),
266-
"z_exponent": sympy.Symbol(f"z{sid}"),
267-
"axis_phase_exponent": sympy.Symbol(f"a{sid}"),
268-
}
269-
return ops.PhasedXZGate(**phxz_params).on(*op.qubits)
270-
271-
return align.align_right(
272-
transformer_primitives.map_operations(circuit.freeze(), _map_func, deep=deep)
273-
)
274-
275-
276-
def _parameterize_merged_circuits(
277-
merged_circuits: List['cirq.Circuit'],
278-
merge_tags: frozenset[str],
279-
new_symbols: frozenset[str],
280-
remaining_symbols: frozenset[str],
179+
def _parameterize_phxz_in_circuits(
180+
circuit_list: List['cirq.Circuit'],
181+
merge_tag_prefix: str,
182+
phxz_symbols: frozenset[sympy.Symbol],
183+
remaining_symbols: frozenset[sympy.Symbol],
281184
sweep: Sweep,
282185
) -> Sweep:
283-
"""Parameterizes the merged circuits and returns a new sweep."""
186+
"""Parameterizes the circuits and returns a new sweep."""
284187
values_by_params: Dict[str, List[float]] = {
285-
**{s: [] for s in new_symbols}, # New symbols introduced during merging
286-
**{
287-
s: _values_of_sweep(sweep, s) for s in remaining_symbols
288-
}, # Existing symbols in ops that were not merged, e.g., symbols in 2-qubit gates.
188+
**{str(s): [] for s in phxz_symbols},
189+
**{str(s): _values_of_sweep(sweep, s) for s in remaining_symbols},
289190
}
290191

291-
for merged_circuit in merged_circuits:
292-
for op in merged_circuit.all_operations():
293-
sid = _get_merge_tag_id(merge_tags, op.tags)
294-
if sid is None:
192+
for circuit in circuit_list:
193+
for op in circuit.all_operations():
194+
the_merge_tag: Optional[str] = None
195+
for tag in op.tags:
196+
if str(tag).startswith(merge_tag_prefix):
197+
the_merge_tag = str(tag)
198+
if not the_merge_tag:
295199
continue
200+
sid = the_merge_tag.rsplit("_", maxsplit=-1)[-1]
296201
x, z, a = 0.0, 0.0, 0.0 # Identity gate's parameters
297202
if isinstance(op.gate, ops.PhasedXZGate):
298203
x, z, a = op.gate.x_exponent, op.gate.z_exponent, op.gate.axis_phase_exponent
@@ -308,6 +213,15 @@ def _parameterize_merged_circuits(
308213
return Zip(*[Points(key=key, points=values) for key, values in values_by_params.items()])
309214

310215

216+
def _all_tags_startswith(circuit: 'cirq.AbstractCircuit', startswith: str):
217+
tag_set: set[Hashable] = set()
218+
for op in circuit.all_operations():
219+
for tag in op.tags:
220+
if str(tag).startswith(startswith):
221+
tag_set.add(tag)
222+
return tag_set
223+
224+
311225
def merge_single_qubit_gates_to_phxz_symbolized(
312226
circuit: 'cirq.AbstractCircuit',
313227
*,
@@ -353,7 +267,7 @@ def merge_single_qubit_gates_to_phxz_symbolized(
353267
deep = context.deep if context else False
354268

355269
# Tag symbolized single-qubit op.
356-
symbolized_single_tag = "_symbolized_single"
270+
symbolized_single_tag = "TMP-TAG-symbolized-single"
357271

358272
circuit_tagged = transformer_primitives.map_operations(
359273
circuit,
@@ -366,7 +280,6 @@ def merge_single_qubit_gates_to_phxz_symbolized(
366280
)
367281

368282
# Step 0, isolate single qubit symbolized symbols and resolve the circuit on them.
369-
370283
single_qubit_gate_symbols: frozenset[sympy.Symbol] = frozenset(
371284
set().union(
372285
*[
@@ -378,11 +291,7 @@ def merge_single_qubit_gates_to_phxz_symbolized(
378291
# If all single qubit gates are not parameterized, call the nonparamerized version of
379292
# the transformer.
380293
if not single_qubit_gate_symbols:
381-
return merge_single_qubit_gates_to_phxz(circuit, context=context, atol=atol), sweep
382-
# Remaining symbols, e.g., 2 qubit gates' symbols. Sweep of those symbols keeps unchanged.
383-
remaining_symbols: frozenset[sympy.Symbol] = frozenset(
384-
protocols.parameter_symbols(circuit) - single_qubit_gate_symbols
385-
)
294+
return (merge_single_qubit_gates_to_phxz(circuit, context=context, atol=atol), sweep)
386295
sweep_of_single: Sweep = Zip(
387296
*[Points(key=k, points=_values_of_sweep(sweep, k)) for k in single_qubit_gate_symbols]
388297
)
@@ -391,22 +300,68 @@ def merge_single_qubit_gates_to_phxz_symbolized(
391300
protocols.resolve_parameters(circuit_tagged, resolver) for resolver in sweep_of_single
392301
]
393302

394-
# Step 1, merge single qubit gates of resolved circuits using merge_k_qubit_unitaries.
395-
merged_circuits, merge_tags, new_symbols = _merge_single_qubit_gates_to_phxz_symbolized(
396-
resolved_circuits, symbolized_single_tag, context, atol
397-
)
303+
# Step 1, merge single qubit gates per resolved circuit, preserving the "symbolized_single_tag".
304+
merged_circuits: List['cirq.Circuit'] = []
305+
phxz_symbols: set[sympy.Symbols] = set()
306+
for resolved_circuit in resolved_circuits:
307+
merged_circuit = index_tags(
308+
merge_single_qubit_gates_to_phxz(
309+
resolved_circuit,
310+
context=context,
311+
merge_tags_fn=lambda circuit_op: (
312+
[symbolized_single_tag]
313+
if any(
314+
symbolized_single_tag in set(op.tags)
315+
for op in circuit_op.circuit.all_operations()
316+
)
317+
else []
318+
),
319+
atol=atol,
320+
),
321+
target_tags={symbolized_single_tag},
322+
context=context,
323+
)
324+
merged_circuits.append(merged_circuit)
325+
326+
if not all(
327+
_all_tags_startswith(merged_circuits[0], startswith=symbolized_single_tag)
328+
== _all_tags_startswith(merged_circuit, startswith=symbolized_single_tag)
329+
for merged_circuit in merged_circuits
330+
):
331+
raise RuntimeError("Different resolvers in sweep resulted in different merged structures.")
398332

399-
# Step 2, get the new symbolzied circuit as new_sweep by mapping merged operations.
400-
new_circuit = _map_merged_ops_to_symbolized_phxz(merged_circuits[0], merge_tags, deep)
333+
# Step 2, get the new symbolized circuit by mapping merged operations.
334+
new_circuit = align.align_right(
335+
remove_tags(
336+
symbolize.symbolize_single_qubit_gates_by_indexed_tags(
337+
merged_circuits[0], tag_prefix=symbolized_single_tag
338+
),
339+
remove_if=lambda tag: tag.startswith(symbolized_single_tag),
340+
)
341+
)
401342

402343
# Step 3, get N sets of parameterizations as new_sweep.
403-
new_sweep = _parameterize_merged_circuits(
404-
merged_circuits, merge_tags, new_symbols, remaining_symbols, sweep
344+
phxz_symbols: frozenset[sympy.Symbol] = frozenset(
345+
set().union(
346+
*[
347+
set(
348+
[
349+
sympy.Symbol(tag.replace(f"{symbolized_single_tag}_", s))
350+
for s in ["x", "z", "a"]
351+
]
352+
)
353+
for tag in _all_tags_startswith(
354+
merged_circuits[0], startswith=symbolized_single_tag
355+
)
356+
]
357+
)
358+
)
359+
# Remaining symbols, e.g., 2 qubit gates' symbols. Sweep of those symbols keeps unchanged.
360+
remaining_symbols: frozenset[sympy.Symbol] = frozenset(
361+
protocols.parameter_symbols(circuit) - single_qubit_gate_symbols
362+
)
363+
new_sweep = _parameterize_phxz_in_circuits(
364+
merged_circuits, symbolized_single_tag, phxz_symbols, remaining_symbols, sweep
405365
)
406366

407367
return new_circuit.unfreeze(copy=False), new_sweep
408-
409-
410-
# ----------------------------------------------------------------------
411-
# Impl merge_single_qubit_gates_to_phxz_symbolized: End
412-
# ----------------------------------------------------------------------

0 commit comments

Comments
 (0)