14
14
15
15
"""Transformer passes to combine adjacent single-qubit rotations."""
16
16
17
- import itertools
18
- from typing import Dict , Hashable , List , Optional , Tuple , TYPE_CHECKING
17
+ from typing import Callable , Dict , Hashable , List , Optional , Tuple , TYPE_CHECKING
19
18
20
19
import sympy
21
20
22
21
from cirq import circuits , ops , protocols
22
+ from cirq .study .result import TMeasurementKey
23
23
from cirq .study .sweeps import Points , Sweep , Zip
24
- from cirq .transformers import align , merge_k_qubit_gates , transformer_api , transformer_primitives
24
+ from cirq .transformers import (
25
+ align ,
26
+ merge_k_qubit_gates ,
27
+ symbolize ,
28
+ transformer_api ,
29
+ transformer_primitives ,
30
+ )
25
31
from cirq .transformers .analytical_decompositions import single_qubit_decompositions
32
+ from cirq .transformers .tag_transformers import index_tags , remove_tags
26
33
27
34
if TYPE_CHECKING :
28
35
import cirq
@@ -69,6 +76,7 @@ def merge_single_qubit_gates_to_phxz(
69
76
circuit : 'cirq.AbstractCircuit' ,
70
77
* ,
71
78
context : Optional ['cirq.TransformerContext' ] = None ,
79
+ merge_tags_fn : Optional [Callable [['cirq.CircuitOperation' ], List [Hashable ]]] = None ,
72
80
atol : float = 1e-8 ,
73
81
) -> 'cirq.Circuit' :
74
82
"""Replaces runs of single qubit rotations with a single optional `cirq.PhasedXZGate`.
@@ -79,19 +87,24 @@ def merge_single_qubit_gates_to_phxz(
79
87
Args:
80
88
circuit: Input circuit to transform. It will not be modified.
81
89
context: `cirq.TransformerContext` storing common configurable options for transformers.
90
+ merge_tag: If provided, tag merged PhXZ gate with it.
91
+ merge_tags_fn: A callable returns the tags to be added to the merged operation.
82
92
atol: Absolute tolerance to angle error. Larger values allow more negligible gates to be
83
93
dropped, smaller values increase accuracy.
84
94
85
95
Returns:
86
96
Copy of the transformed input circuit.
87
97
"""
88
98
89
- def rewriter (op : 'cirq.CircuitOperation' ) -> 'cirq.OP_TREE' :
90
- u = protocols .unitary (op )
91
- if protocols .num_qubits (op ) == 0 :
99
+ def rewriter (circuit_op : 'cirq.CircuitOperation' ) -> 'cirq.OP_TREE' :
100
+
101
+ u = protocols .unitary (circuit_op )
102
+ if protocols .num_qubits (circuit_op ) == 0 :
92
103
return ops .GlobalPhaseGate (u [0 , 0 ]).on ()
93
- gate = single_qubit_decompositions .single_qubit_matrix_to_phxz (u , atol )
94
- return gate (op .qubits [0 ]) if gate else []
104
+
105
+ gate = single_qubit_decompositions .single_qubit_matrix_to_phxz (u , atol ) or ops .I
106
+ phxz_op = gate .on (circuit_op .qubits [0 ])
107
+ return phxz_op .with_tags (* merge_tags_fn (circuit_op )) if merge_tags_fn else phxz_op
95
108
96
109
return merge_k_qubit_gates .merge_k_qubit_unitaries (
97
110
circuit , k = 1 , context = context , rewriter = rewriter
@@ -158,141 +171,33 @@ def merge_func(m1: 'cirq.Moment', m2: 'cirq.Moment') -> Optional['cirq.Moment']:
158
171
).unfreeze (copy = False )
159
172
160
173
161
- # ----------------------------------------------------------------------
162
- # Impl merge_single_qubit_gates_to_phxz_symbolized: Start
163
- # ----------------------------------------------------------------------
164
-
165
-
166
- def _values_of_sweep (sweep : Sweep , key : str | sympy .Symbol ):
174
+ def _values_of_sweep (sweep : Sweep , key : TMeasurementKey ):
167
175
p = sympy .Symbol (key ) if isinstance (key , str ) else key
168
176
return [resolver .value_of (p ) for resolver in sweep ]
169
177
170
178
171
- def _merge_single_qubit_gates_to_phxz_symbolized (
172
- resolved_circuits : List ['cirq.AbstractCircuit' ],
173
- symbolized_single_tag : str ,
174
- context : Optional ['cirq.TransformerContext' ],
175
- atol : float ,
176
- ) -> Tuple [List ['cirq.Circuit' ], frozenset [str ], frozenset [str ]]:
177
- """Helper function to merge single qubit ops of resolved circuits to PhasedXZ ops
178
- using merge_k_qubit_unitaries.
179
-
180
- Args:
181
- resolved_circuits: A list of circuits where symbols have been replaced with concrete values.
182
- symbolized_single_tag: The tag applied to single-qubit operations that originally
183
- contained symbols before parameterizations.
184
-
185
- Returns:
186
- Tuple of merge_counts, merged_circuits, and merge_tags, where
187
- merged ops in merged_circuits are tagged by merge_tags.
188
- """
189
- merge_counts : list [int ] = [] # number of merges per resolved_circuit
190
- merged_circuits : list ['cirq.Circuit' ] = []
191
- tag_iter : itertools .count
192
- phxz_tag_prefix = "_phxz"
193
-
194
- def rewriter (circuit_op : 'cirq.CircuitOperation' ) -> 'cirq.OP_TREE' :
195
- nonlocal tag_iter
196
- tag : Optional [str ] = None
197
-
198
- u = protocols .unitary (circuit_op )
199
- if protocols .num_qubits (circuit_op ) == 0 :
200
- return ops .GlobalPhaseGate (u [0 , 0 ]).on ()
201
- # If any of the op in the merged circuit_op is a symbolized single qubit gate,
202
- # tag the merged phxz gate with next tag id, for further parameterization references.
203
- for op in circuit_op .circuit .all_operations ():
204
- if symbolized_single_tag in op .tags :
205
- tag = f"{ phxz_tag_prefix } _{ next (tag_iter )} "
206
- break
207
- gate = single_qubit_decompositions .single_qubit_matrix_to_phxz (u , atol ) or ops .I
208
- op = gate .on (circuit_op .qubits [0 ])
209
- return op .with_tags (tag ) if tag else op
210
-
211
- for resolved_circuit in resolved_circuits :
212
- tag_iter = itertools .count (start = 0 , step = 1 )
213
- merged_circuits .append (
214
- merge_k_qubit_gates .merge_k_qubit_unitaries (
215
- resolved_circuit , k = 1 , context = context , rewriter = rewriter
216
- )
217
- )
218
- merge_counts .append (next (tag_iter ))
219
-
220
- if not all (count == merge_counts [0 ] for count in merge_counts ):
221
- raise RuntimeError ("Different resolvers in sweep resulted in different merged structures." )
222
-
223
- merge_tags : frozenset [str ] = frozenset (
224
- {f"{ phxz_tag_prefix } _{ i } " for i in range (merge_counts [0 ])}
225
- )
226
- new_symbols : frozenset [str ] = frozenset (
227
- set ().union (* [{f"x{ i } " , f"z{ i } " , f"a{ i } " } for i in range (merge_counts [0 ])])
228
- )
229
-
230
- return merged_circuits , merge_tags , new_symbols
231
-
232
-
233
- def _get_merge_tag_id (merge_tags : frozenset [str ], op_tags : Tuple [Hashable , ...]) -> Optional [str ]:
234
- """Extract the id `i` from the merge tag `_phxz_i` if it exists."""
235
- the_merge_tag : set [str ] = set (merge_tags .intersection (op_tags ))
236
- if len (the_merge_tag ) == 0 :
237
- return None
238
- if len (the_merge_tag ) > 1 :
239
- raise RuntimeError ("Multiple merge tags found." )
240
- return the_merge_tag .pop ().split ("_" )[- 1 ]
241
-
242
-
243
- def _map_merged_ops_to_symbolized_phxz (
244
- circuit : 'cirq.Circuit' , merge_tags : frozenset [str ], deep : bool
245
- ) -> 'cirq.Circuit' :
246
- """Maps merged operations (tagged with merge_tags) in the circuit to symbolized PhasedXZGates.
247
-
248
- Args:
249
- circuit: Circuit with merge tags to be mapped.
250
- merge_tags: The set of tags used to identify the merged PhasedXZ gates that need to be
251
- symbolized.
252
- deep: Whether to perform the mapping recursively within CircuitOperations.
253
-
254
- Returns:
255
- A new circuit where tagged PhasedXZ gates are replaced by symbolized versions.
256
- """
257
-
258
- # Map merged ops to `PhasedXZGate(xi,zi,ai)` based on the tag "_phxz_i".
259
- def _map_func (op : 'cirq.Operation' , _ ):
260
- """Maps an op with tag `_phxz_i` to a symbolzied `PhasedXZGate(xi,zi,ai)`"""
261
- sid = _get_merge_tag_id (merge_tags , op .tags )
262
- if sid is None :
263
- return op
264
- phxz_params = {
265
- "x_exponent" : sympy .Symbol (f"x{ sid } " ),
266
- "z_exponent" : sympy .Symbol (f"z{ sid } " ),
267
- "axis_phase_exponent" : sympy .Symbol (f"a{ sid } " ),
268
- }
269
- return ops .PhasedXZGate (** phxz_params ).on (* op .qubits )
270
-
271
- return align .align_right (
272
- transformer_primitives .map_operations (circuit .freeze (), _map_func , deep = deep )
273
- )
274
-
275
-
276
- def _parameterize_merged_circuits (
277
- merged_circuits : List ['cirq.Circuit' ],
278
- merge_tags : frozenset [str ],
279
- new_symbols : frozenset [str ],
280
- remaining_symbols : frozenset [str ],
179
+ def _parameterize_phxz_in_circuits (
180
+ circuit_list : List ['cirq.Circuit' ],
181
+ merge_tag_prefix : str ,
182
+ phxz_symbols : frozenset [sympy .Symbol ],
183
+ remaining_symbols : frozenset [sympy .Symbol ],
281
184
sweep : Sweep ,
282
185
) -> Sweep :
283
- """Parameterizes the merged circuits and returns a new sweep."""
186
+ """Parameterizes the circuits and returns a new sweep."""
284
187
values_by_params : Dict [str , List [float ]] = {
285
- ** {s : [] for s in new_symbols }, # New symbols introduced during merging
286
- ** {
287
- s : _values_of_sweep (sweep , s ) for s in remaining_symbols
288
- }, # Existing symbols in ops that were not merged, e.g., symbols in 2-qubit gates.
188
+ ** {str (s ): [] for s in phxz_symbols },
189
+ ** {str (s ): _values_of_sweep (sweep , s ) for s in remaining_symbols },
289
190
}
290
191
291
- for merged_circuit in merged_circuits :
292
- for op in merged_circuit .all_operations ():
293
- sid = _get_merge_tag_id (merge_tags , op .tags )
294
- if sid is None :
192
+ for circuit in circuit_list :
193
+ for op in circuit .all_operations ():
194
+ the_merge_tag : Optional [str ] = None
195
+ for tag in op .tags :
196
+ if str (tag ).startswith (merge_tag_prefix ):
197
+ the_merge_tag = str (tag )
198
+ if not the_merge_tag :
295
199
continue
200
+ sid = the_merge_tag .rsplit ("_" , maxsplit = - 1 )[- 1 ]
296
201
x , z , a = 0.0 , 0.0 , 0.0 # Identity gate's parameters
297
202
if isinstance (op .gate , ops .PhasedXZGate ):
298
203
x , z , a = op .gate .x_exponent , op .gate .z_exponent , op .gate .axis_phase_exponent
@@ -308,6 +213,15 @@ def _parameterize_merged_circuits(
308
213
return Zip (* [Points (key = key , points = values ) for key , values in values_by_params .items ()])
309
214
310
215
216
+ def _all_tags_startswith (circuit : 'cirq.AbstractCircuit' , startswith : str ):
217
+ tag_set : set [Hashable ] = set ()
218
+ for op in circuit .all_operations ():
219
+ for tag in op .tags :
220
+ if str (tag ).startswith (startswith ):
221
+ tag_set .add (tag )
222
+ return tag_set
223
+
224
+
311
225
def merge_single_qubit_gates_to_phxz_symbolized (
312
226
circuit : 'cirq.AbstractCircuit' ,
313
227
* ,
@@ -353,7 +267,7 @@ def merge_single_qubit_gates_to_phxz_symbolized(
353
267
deep = context .deep if context else False
354
268
355
269
# Tag symbolized single-qubit op.
356
- symbolized_single_tag = "_symbolized_single "
270
+ symbolized_single_tag = "TMP-TAG-symbolized-single "
357
271
358
272
circuit_tagged = transformer_primitives .map_operations (
359
273
circuit ,
@@ -366,7 +280,6 @@ def merge_single_qubit_gates_to_phxz_symbolized(
366
280
)
367
281
368
282
# Step 0, isolate single qubit symbolized symbols and resolve the circuit on them.
369
-
370
283
single_qubit_gate_symbols : frozenset [sympy .Symbol ] = frozenset (
371
284
set ().union (
372
285
* [
@@ -378,11 +291,7 @@ def merge_single_qubit_gates_to_phxz_symbolized(
378
291
# If all single qubit gates are not parameterized, call the nonparamerized version of
379
292
# the transformer.
380
293
if not single_qubit_gate_symbols :
381
- return merge_single_qubit_gates_to_phxz (circuit , context = context , atol = atol ), sweep
382
- # Remaining symbols, e.g., 2 qubit gates' symbols. Sweep of those symbols keeps unchanged.
383
- remaining_symbols : frozenset [sympy .Symbol ] = frozenset (
384
- protocols .parameter_symbols (circuit ) - single_qubit_gate_symbols
385
- )
294
+ return (merge_single_qubit_gates_to_phxz (circuit , context = context , atol = atol ), sweep )
386
295
sweep_of_single : Sweep = Zip (
387
296
* [Points (key = k , points = _values_of_sweep (sweep , k )) for k in single_qubit_gate_symbols ]
388
297
)
@@ -391,22 +300,68 @@ def merge_single_qubit_gates_to_phxz_symbolized(
391
300
protocols .resolve_parameters (circuit_tagged , resolver ) for resolver in sweep_of_single
392
301
]
393
302
394
- # Step 1, merge single qubit gates of resolved circuits using merge_k_qubit_unitaries.
395
- merged_circuits , merge_tags , new_symbols = _merge_single_qubit_gates_to_phxz_symbolized (
396
- resolved_circuits , symbolized_single_tag , context , atol
397
- )
303
+ # Step 1, merge single qubit gates per resolved circuit, preserving the "symbolized_single_tag".
304
+ merged_circuits : List ['cirq.Circuit' ] = []
305
+ phxz_symbols : set [sympy .Symbols ] = set ()
306
+ for resolved_circuit in resolved_circuits :
307
+ merged_circuit = index_tags (
308
+ merge_single_qubit_gates_to_phxz (
309
+ resolved_circuit ,
310
+ context = context ,
311
+ merge_tags_fn = lambda circuit_op : (
312
+ [symbolized_single_tag ]
313
+ if any (
314
+ symbolized_single_tag in set (op .tags )
315
+ for op in circuit_op .circuit .all_operations ()
316
+ )
317
+ else []
318
+ ),
319
+ atol = atol ,
320
+ ),
321
+ target_tags = {symbolized_single_tag },
322
+ context = context ,
323
+ )
324
+ merged_circuits .append (merged_circuit )
325
+
326
+ if not all (
327
+ _all_tags_startswith (merged_circuits [0 ], startswith = symbolized_single_tag )
328
+ == _all_tags_startswith (merged_circuit , startswith = symbolized_single_tag )
329
+ for merged_circuit in merged_circuits
330
+ ):
331
+ raise RuntimeError ("Different resolvers in sweep resulted in different merged structures." )
398
332
399
- # Step 2, get the new symbolzied circuit as new_sweep by mapping merged operations.
400
- new_circuit = _map_merged_ops_to_symbolized_phxz (merged_circuits [0 ], merge_tags , deep )
333
+ # Step 2, get the new symbolized circuit by mapping merged operations.
334
+ new_circuit = align .align_right (
335
+ remove_tags (
336
+ symbolize .symbolize_single_qubit_gates_by_indexed_tags (
337
+ merged_circuits [0 ], tag_prefix = symbolized_single_tag
338
+ ),
339
+ remove_if = lambda tag : tag .startswith (symbolized_single_tag ),
340
+ )
341
+ )
401
342
402
343
# Step 3, get N sets of parameterizations as new_sweep.
403
- new_sweep = _parameterize_merged_circuits (
404
- merged_circuits , merge_tags , new_symbols , remaining_symbols , sweep
344
+ phxz_symbols : frozenset [sympy .Symbol ] = frozenset (
345
+ set ().union (
346
+ * [
347
+ set (
348
+ [
349
+ sympy .Symbol (tag .replace (f"{ symbolized_single_tag } _" , s ))
350
+ for s in ["x" , "z" , "a" ]
351
+ ]
352
+ )
353
+ for tag in _all_tags_startswith (
354
+ merged_circuits [0 ], startswith = symbolized_single_tag
355
+ )
356
+ ]
357
+ )
358
+ )
359
+ # Remaining symbols, e.g., 2 qubit gates' symbols. Sweep of those symbols keeps unchanged.
360
+ remaining_symbols : frozenset [sympy .Symbol ] = frozenset (
361
+ protocols .parameter_symbols (circuit ) - single_qubit_gate_symbols
362
+ )
363
+ new_sweep = _parameterize_phxz_in_circuits (
364
+ merged_circuits , symbolized_single_tag , phxz_symbols , remaining_symbols , sweep
405
365
)
406
366
407
367
return new_circuit .unfreeze (copy = False ), new_sweep
408
-
409
-
410
- # ----------------------------------------------------------------------
411
- # Impl merge_single_qubit_gates_to_phxz_symbolized: End
412
- # ----------------------------------------------------------------------
0 commit comments