15
15
"""Transformer passes to combine adjacent single-qubit rotations."""
16
16
17
17
import itertools
18
- import warnings
19
- from typing import Dict , List , Optional , Tuple , TYPE_CHECKING
18
+ from typing import Dict , Hashable , List , Optional , Tuple , TYPE_CHECKING
20
19
21
20
import sympy
22
21
23
22
from cirq import circuits , ops , protocols
24
23
from cirq .study .sweeps import Points , Sweep , Zip
25
- from cirq .transformers import merge_k_qubit_gates , transformer_api , transformer_primitives , align
24
+ from cirq .transformers import align , merge_k_qubit_gates , transformer_api , transformer_primitives
26
25
from cirq .transformers .analytical_decompositions import single_qubit_decompositions
27
26
28
27
if TYPE_CHECKING :
@@ -159,115 +158,52 @@ def merge_func(m1: 'cirq.Moment', m2: 'cirq.Moment') -> Optional['cirq.Moment']:
159
158
).unfreeze (copy = False )
160
159
161
160
161
+ # ----------------------------------------------------------------------
162
+ # Impl merge_single_qubit_gates_to_phxz_symbolized: Start
163
+ # ----------------------------------------------------------------------
164
+
165
+
162
166
def _values_of_sweep (sweep : Sweep , key : str | sympy .Symbol ):
163
167
p = sympy .Symbol (key ) if isinstance (key , str ) else key
164
168
return [resolver .value_of (p ) for resolver in sweep ]
165
169
166
170
167
- @transformer_api .transformer
168
- def merge_single_qubit_gates_to_phxz_symbolized (
169
- circuit : 'cirq.AbstractCircuit' ,
170
- * ,
171
- context : Optional ['cirq.TransformerContext' ] = None ,
172
- sweep : Sweep ,
173
- atol : float = 1e-8 ,
174
- ) -> Tuple ['cirq.Circuit' , Sweep ]:
175
- """Merge consecutive single qubit gates as PhasedXZ Gates. Symbolize if any of the consecutive gates is symbolized.
176
-
177
- Example:
178
- # pylint: disable=line-too-long
179
- >>> q0, q1 = cirq.LineQubit.range(2)
180
- >>> c = cirq.Circuit(cirq.X(q0),cirq.CZ(q0,q1)**sympy.Symbol("cz_exp"),cirq.Y(q0)**sympy.Symbol("y_exp"),cirq.X(q0))
181
- >>> print(c)
182
- 0: ───X───@──────────Y^y_exp───X───
183
- │
184
- 1: ───────@^cz_exp─────────────────
185
- >>> new_circuit, new_sweep = cirq.merge_single_qubit_gates_to_phxz_symbolized(\
186
- c, sweep=cirq.Points(key="cz_exp", points=[0, 1]) * cirq.Points(key="y_exp", points=[0, 1])\
187
- )
188
- >>> print(new_circuit)
189
- 0: ───PhXZ(a=-1,x=1,z=0)───@──────────PhXZ(a=a0,x=x0,z=z0)───
190
- │
191
- 1: ────────────────────────@^cz_exp──────────────────────────
192
- >>> print(new_sweep)
193
- cirq.Points('z0', [0, -1.0, 0, -1.0]) + cirq.Points('x0', [1, 0.0, 1, 0.0]) + cirq.Points('a0', [-1.0, -0.5, -1.0, -0.5]) + cirq.Points('cz_exp', [0, 0, 1, 1])
194
- # pylint: disable=line-too-long
171
+ def _merge_single_qubit_gates_to_circuit_op_symbolized (
172
+ resolved_circuits : List ['cirq.AbstractCircuit' ],
173
+ symbolized_single_tag : str ,
174
+ context : Optional ['cirq.TransformerContext' ],
175
+ atol : float ,
176
+ ) -> Tuple [List ['cirq.Circuit' ], frozenset [str ], frozenset [str ]]:
177
+ """Helper function to merge single qubit gates of resolved circuits to op of CircuitOperation type
178
+ using merge_k_qubit_unitaries.
195
179
196
180
Args:
197
- circuit: Input circuit to transform. It will not be modified.
198
- sweep: Sweep of the symbols in the input circuit, updated Sweep will be returned
199
- based on the transformation.
200
- context: `cirq.TransformerContext` storing common configurable options for transformers.
201
- atol: Absolute tolerance to angle error. Larger values allow more negligible gates to be
202
- dropped, smaller values increase accuracy.
181
+ resolved_circuits: A list of circuits where symbols have been replaced with concrete values.
182
+ symbolized_single_tag: The tag applied to single-qubit operations that originally contained symbols before parameterization.
203
183
204
184
Returns:
205
- Copy of the transformed input circuit .
185
+ Tuple of merge counts, merged circuits, and merge tags .
206
186
"""
207
- deep = context .deep if context else False
208
-
209
- if not protocols .is_parameterized (circuit ):
210
- warnings .warn (
211
- "Expect parameterized circuits. "
212
- "Please use cirq.merge_single_qubit_gates_to_phxz instead." ,
213
- UserWarning ,
214
- )
215
- return merge_single_qubit_gates_to_phxz (circuit , context = context , atol = atol )
216
-
217
- # Tag symbolized single qubit op.
218
- symbolized_single_tag = "_symbolized_single"
219
-
220
- circuit_tagged = transformer_primitives .map_operations (
221
- circuit ,
222
- lambda op , _ : (
223
- op .with_tags (symbolized_single_tag )
224
- if protocols .is_parameterized (op ) and len (op .qubits ) == 1
225
- else op
226
- ),
227
- deep = deep ,
228
- )
229
-
230
- # Symbols of the single qubit symbolized ops.
231
- single_qubit_gate_symbols : set [sympy .Symbol ] = set ().union (
232
- * [
233
- protocols .parameter_symbols (op ) if symbolized_single_tag in op .tags else set ()
234
- for op in circuit_tagged .all_operations ()
235
- ]
236
- )
237
- # Remaing symbols, e.g., 2 qubit gates' symbols. Sweep of those symbols keeps unchanged.
238
- remaining_symbols = protocols .parameter_symbols (circuit ) - single_qubit_gate_symbols
239
-
240
- sweep_of_single : Sweep = Zip (
241
- * [Points (key = k , points = _values_of_sweep (sweep , k )) for k in single_qubit_gate_symbols ]
242
- )
243
-
244
- # Get all resolved circuits from all sets of resolvers in sweep.
245
- resolved_circuits = [
246
- protocols .resolve_parameters (circuit_tagged , resolver ) for resolver in sweep_of_single
247
- ]
248
-
249
- # Store the number of merges for all set of resolvers,
250
- # it should be the same for all resolved circuits.
251
- merge_counts : list [int ] = []
252
- merged_circuits = []
253
- phxz_tag_prefix = "_phxz"
187
+ merge_counts : list [int ] = [] # number of merges per resolved_circuit
188
+ merged_circuits : list ['cirq.Circuit' ] = []
254
189
tag_iter : itertools .count
190
+ phxz_tag_prefix = "_phxz"
255
191
256
192
def rewriter (circuit_op : 'cirq.CircuitOperation' ) -> 'cirq.OP_TREE' :
257
193
nonlocal tag_iter
258
194
tag : Optional [str ] = None
195
+
259
196
u = protocols .unitary (circuit_op )
260
197
if protocols .num_qubits (circuit_op ) == 0 :
261
198
return ops .GlobalPhaseGate (u [0 , 0 ]).on ()
199
+ # If any of the op in the merged circuit_op is a symbolized single qubit gate,
200
+ # tag the merged phxz gate with next tag id, for further parameterization references.
262
201
for op in circuit_op .circuit .all_operations ():
263
202
if symbolized_single_tag in op .tags :
264
- # Record parameterizations info via tags.
265
203
tag = f"{ phxz_tag_prefix } _{ next (tag_iter )} "
266
204
break
267
205
gate = single_qubit_decompositions .single_qubit_matrix_to_phxz (u , atol ) or ops .I
268
206
op = gate .on (circuit_op .qubits [0 ])
269
- if not gate :
270
- return []
271
207
return op .with_tags (tag ) if tag else op
272
208
273
209
for resolved_circuit in resolved_circuits :
@@ -280,62 +216,195 @@ def rewriter(circuit_op: 'cirq.CircuitOperation') -> 'cirq.OP_TREE':
280
216
merge_counts .append (next (tag_iter ))
281
217
282
218
if not all (count == merge_counts [0 ] for count in merge_counts ):
283
- raise RuntimeError ("Different resolvers in sweep result different merged strcuture ." )
219
+ raise RuntimeError ("Different resolvers in sweep resulted in different merged structures ." )
284
220
285
- # Get the output circuit from the first resolved circuits.
286
- merge_tags : set [str ] = {f"{ phxz_tag_prefix } _{ i } " for i in range (merge_counts [0 ])}
287
- new_symbols : set [str ] = set ().union (
288
- * [{f"x{ i } " , f"z{ i } " , f"a{ i } " } for i in range (merge_counts [0 ])]
221
+ merge_tags : frozenset [str ] = frozenset (
222
+ {f"{ phxz_tag_prefix } _{ i } " for i in range (merge_counts [0 ])}
223
+ )
224
+ new_symbols : frozenset [str ] = frozenset (
225
+ set ().union (* [{f"x{ i } " , f"z{ i } " , f"a{ i } " } for i in range (merge_counts [0 ])])
289
226
)
290
227
228
+ return merged_circuits , merge_tags , new_symbols
229
+
230
+
231
+ def _get_merge_tag_id (merge_tags : frozenset [str ], op_tags : Tuple [Hashable , ...]) -> Optional [str ]:
232
+ """Extract the id `i` from the merge tag `_phxz_i` if it exists."""
233
+ the_merge_tag : set [str ] = set (merge_tags .intersection (op_tags ))
234
+ if len (the_merge_tag ) == 0 :
235
+ return None
236
+ if len (the_merge_tag ) > 1 :
237
+ raise RuntimeError ("Multiple merge tags found." )
238
+ return the_merge_tag .pop ().split ("_" )[- 1 ]
239
+
240
+
241
+ def _map_merged_ops_to_symbolized_phxz (
242
+ circuit : 'cirq.Circuit' , merge_tags : frozenset [str ], deep : bool
243
+ ) -> 'cirq.Circuit' :
244
+ """Maps merged operations (tagged with merge_tags) in the circuit to symbolized PhasedXZGates.
245
+
246
+ Args:
247
+ circuit: Circuit with merge tags to be mapped.
248
+ merge_tags: The set of tags used to identify the merged PhasedXZ gates that need to be
249
+ symbolized.
250
+ deep: Whether to perform the mapping recursively within CircuitOperations.
251
+
252
+ Returns:
253
+ A new circuit where tagged PhasedXZ gates are replaced by symbolized versions.
254
+ """
255
+
256
+ # Map merged ops to `PhasedXZGate(xi,zi,ai)` based on the tag "_phxz_i".
291
257
def _map_func (op : 'cirq.Operation' , _ ):
292
- """Maps op with tag `_phxz_i` to a symbolzied `PhasedXZGate(xi,zi,ai)`"""
293
- the_merge_tag = merge_tags . intersection ( op .tags )
294
- if len ( the_merge_tag ) == 0 :
258
+ """Maps an op with tag `_phxz_i` to a symbolzied `PhasedXZGate(xi,zi,ai)`"""
259
+ sid = _get_merge_tag_id ( merge_tags , op .tags )
260
+ if sid is None :
295
261
return op
296
- if len (the_merge_tag ) > 1 :
297
- raise RuntimeError ("Multiple merge tags found." )
298
- sid = the_merge_tag .pop ().split ("_" )[- 1 ]
299
262
phxz_params = {
300
263
"x_exponent" : sympy .Symbol (f"x{ sid } " ),
301
264
"z_exponent" : sympy .Symbol (f"z{ sid } " ),
302
265
"axis_phase_exponent" : sympy .Symbol (f"a{ sid } " ),
303
266
}
304
267
return ops .PhasedXZGate (** phxz_params ).on (* op .qubits )
305
268
306
- output_circuit : 'cirq.Circuit' = align .align_right (
307
- transformer_primitives .map_operations (merged_circuits [ 0 ] .freeze (), _map_func , deep = deep )
269
+ return align .align_right (
270
+ transformer_primitives .map_operations (circuit .freeze (), _map_func , deep = deep )
308
271
)
309
272
273
+
274
+ def _parameterize_merged_circuits (
275
+ merged_circuits : List ['cirq.Circuit' ],
276
+ merge_tags : frozenset [str ],
277
+ new_symbols : frozenset [str ],
278
+ remaining_symbols : frozenset [str ],
279
+ sweep : Sweep ,
280
+ ) -> Sweep :
281
+ """Parameterizes the merged circuits and returns a new sweep."""
310
282
values_by_params : Dict [str , List [float ]] = {
311
- ** {s : [] for s in new_symbols }, # New symbols introduced in merging
283
+ ** {s : [] for s in new_symbols }, # New symbols introduced during merging
312
284
** {
313
285
s : _values_of_sweep (sweep , s ) for s in remaining_symbols
314
- }, # Existing symbols in ops that are not merged, e.g., symbols in 2 qubit gates.
286
+ }, # Existing symbols in ops that were not merged, e.g., symbols in 2- qubit gates.
315
287
}
316
288
317
- # Get parameterization for the merged phxz gates.
318
289
for merged_circuit in merged_circuits :
319
290
for op in merged_circuit .all_operations ():
320
- the_merge_tag = merge_tags . intersection ( op .tags )
321
- if len ( the_merge_tag ) == 0 :
291
+ sid = _get_merge_tag_id ( merge_tags , op .tags )
292
+ if sid is None :
322
293
continue
323
- if len (the_merge_tag ) > 1 :
324
- raise RuntimeError ("Multiple merge tags found." )
325
- sid = the_merge_tag .pop ().split ("_" )[- 1 ]
326
- x , z , a = 0.0 , 0.0 , 0.0 # Identity gate's parameters.
294
+ x , z , a = 0.0 , 0.0 , 0.0 # Identity gate's parameters
327
295
if isinstance (op .gate , ops .PhasedXZGate ):
328
296
x , z , a = op .gate .x_exponent , op .gate .z_exponent , op .gate .axis_phase_exponent
329
297
elif op .gate is not ops .I :
330
298
raise RuntimeError (
331
- f"Expect the merged gate to be a PhasedXZGate or IdentityGate. But got { op .gate } ."
299
+ f"Expected the merged gate to be a PhasedXZGate or IdentityGate,"
300
+ f" but got { op .gate } ."
332
301
)
333
302
values_by_params [f"x{ sid } " ].append (x )
334
303
values_by_params [f"z{ sid } " ].append (z )
335
304
values_by_params [f"a{ sid } " ].append (a )
336
305
337
- new_sweep : Sweep = Zip (
338
- * [Points (key = key , points = values ) for key , values in values_by_params .items ()]
306
+ return Zip (* [Points (key = key , points = values ) for key , values in values_by_params .items ()])
307
+
308
+
309
+ def merge_single_qubit_gates_to_phxz_symbolized (
310
+ circuit : 'cirq.AbstractCircuit' ,
311
+ * ,
312
+ context : Optional ['cirq.TransformerContext' ] = None ,
313
+ sweep : Sweep ,
314
+ atol : float = 1e-8 ,
315
+ ) -> Tuple ['cirq.Circuit' , Sweep ]:
316
+ """Merge consecutive single qubit gates as PhasedXZ Gates. Symbolize if any of the consecutive
317
+ gates is symbolized.
318
+
319
+ Example:
320
+ >>> q0, q1 = cirq.LineQubit.range(2)
321
+ >>> c = cirq.Circuit(\
322
+ cirq.X(q0),\
323
+ cirq.CZ(q0,q1)**sympy.Symbol("cz_exp"),\
324
+ cirq.Y(q0)**sympy.Symbol("y_exp"),\
325
+ cirq.X(q0))
326
+ >>> print(c)
327
+ 0: ───X───@──────────Y^y_exp───X───
328
+ │
329
+ 1: ───────@^cz_exp─────────────────
330
+ >>> new_circuit, new_sweep = cirq.merge_single_qubit_gates_to_phxz_symbolized(\
331
+ c, sweep=cirq.Zip(cirq.Points(key="cz_exp", points=[0, 1]),\
332
+ cirq.Points(key="y_exp", points=[0, 1])))
333
+ >>> print(new_circuit)
334
+ 0: ───PhXZ(a=-1,x=1,z=0)───@──────────PhXZ(a=a0,x=x0,z=z0)───
335
+ │
336
+ 1: ────────────────────────@^cz_exp──────────────────────────
337
+ >>> assert new_sweep[0] == cirq.ParamResolver({'a0': -1, 'x0': 1, 'z0': 0, 'cz_exp': 0})
338
+ >>> assert new_sweep[1] == cirq.ParamResolver({'a0': -0.5, 'x0': 0, 'z0': -1, 'cz_exp': 1})
339
+
340
+ Args:
341
+ circuit: Input circuit to transform. It will not be modified.
342
+ sweep: Sweep of the symbols in the input circuit, updated Sweep will be returned
343
+ based on the transformation.
344
+ context: `cirq.TransformerContext` storing common configurable options for transformers.
345
+ atol: Absolute tolerance to angle error. Larger values allow more negligible gates to be
346
+ dropped, smaller values increase accuracy.
347
+
348
+ Returns:
349
+ Copy of the transformed input circuit.
350
+ """
351
+ deep = context .deep if context else False
352
+
353
+ # Tag symbolized single-qubit op.
354
+ symbolized_single_tag = "_symbolized_single"
355
+
356
+ circuit_tagged = transformer_primitives .map_operations (
357
+ circuit ,
358
+ lambda op , _ : (
359
+ op .with_tags (symbolized_single_tag )
360
+ if protocols .is_parameterized (op ) and len (op .qubits ) == 1
361
+ else op
362
+ ),
363
+ deep = deep ,
364
+ )
365
+
366
+ # Step 0, isolate single qubit symbolized symbols and resolve the circuit on them.
367
+
368
+ single_qubit_gate_symbols : frozenset [sympy .Symbol ] = frozenset (
369
+ set ().union (
370
+ * [
371
+ protocols .parameter_symbols (op ) if symbolized_single_tag in op .tags else set ()
372
+ for op in circuit_tagged .all_operations ()
373
+ ]
374
+ )
339
375
)
376
+ # If all single qubit gates are not parameterized, call the nonparamerized version of
377
+ # the transformer.
378
+ if not single_qubit_gate_symbols :
379
+ return merge_single_qubit_gates_to_phxz (circuit , context = context , atol = atol ), sweep
380
+ # Remaining symbols, e.g., 2 qubit gates' symbols. Sweep of those symbols keeps unchanged.
381
+ remaining_symbols : frozenset [sympy .Symbol ] = frozenset (
382
+ protocols .parameter_symbols (circuit ) - single_qubit_gate_symbols
383
+ )
384
+ sweep_of_single : Sweep = Zip (
385
+ * [Points (key = k , points = _values_of_sweep (sweep , k )) for k in single_qubit_gate_symbols ]
386
+ )
387
+ # Get all resolved circuits from all sets of resolvers in the sweep.
388
+ resolved_circuits = [
389
+ protocols .resolve_parameters (circuit_tagged , resolver ) for resolver in sweep_of_single
390
+ ]
391
+
392
+ # Step 1, merge single qubit gates of resolved circuits using merge_k_qubit_unitaries.
393
+ merged_circuits , merge_tags , new_symbols = _merge_single_qubit_gates_to_circuit_op_symbolized (
394
+ resolved_circuits , symbolized_single_tag , context , atol
395
+ )
396
+
397
+ # Step 2, get the new symbolzied circuit as new_sweep by mapping merged operations.
398
+ new_circuit = _map_merged_ops_to_symbolized_phxz (merged_circuits [0 ], merge_tags , deep )
399
+
400
+ # Step 3, get N sets of parameterizations as new_sweep.
401
+ new_sweep = _parameterize_merged_circuits (
402
+ merged_circuits , merge_tags , new_symbols , remaining_symbols , sweep
403
+ )
404
+
405
+ return new_circuit .unfreeze (copy = False ), new_sweep
406
+
340
407
341
- return output_circuit .unfreeze (copy = False ), new_sweep
408
+ # ----------------------------------------------------------------------
409
+ # Impl merge_single_qubit_gates_to_phxz_symbolized: End
410
+ # ----------------------------------------------------------------------
0 commit comments