15
15
"""Transformer passes to combine adjacent single-qubit rotations."""
16
16
17
17
import itertools
18
- import warnings
19
- from typing import Dict , List , Optional , Tuple , TYPE_CHECKING
18
+ from typing import Dict , Hashable , List , Optional , Tuple , TYPE_CHECKING
20
19
21
20
import sympy
22
21
23
22
from cirq import circuits , ops , protocols
24
23
from cirq .study .sweeps import Points , Sweep , Zip
25
- from cirq .transformers import merge_k_qubit_gates , transformer_api , transformer_primitives , align
24
+ from cirq .transformers import align , merge_k_qubit_gates , transformer_api , transformer_primitives
26
25
from cirq .transformers .analytical_decompositions import single_qubit_decompositions
27
26
28
27
if TYPE_CHECKING :
@@ -159,115 +158,53 @@ def merge_func(m1: 'cirq.Moment', m2: 'cirq.Moment') -> Optional['cirq.Moment']:
159
158
).unfreeze (copy = False )
160
159
161
160
161
+ # ----------------------------------------------------------------------
162
+ # Impl merge_single_qubit_gates_to_phxz_symbolized: Start
163
+ # ----------------------------------------------------------------------
164
+
165
+
162
166
def _values_of_sweep (sweep : Sweep , key : str | sympy .Symbol ):
163
167
p = sympy .Symbol (key ) if isinstance (key , str ) else key
164
168
return [resolver .value_of (p ) for resolver in sweep ]
165
169
166
170
167
- @transformer_api .transformer
168
- def merge_single_qubit_gates_to_phxz_symbolized (
169
- circuit : 'cirq.AbstractCircuit' ,
170
- * ,
171
- context : Optional ['cirq.TransformerContext' ] = None ,
172
- sweep : Sweep ,
173
- atol : float = 1e-8 ,
174
- ) -> Tuple ['cirq.Circuit' , Sweep ]:
175
- """Merge consecutive single qubit gates as PhasedXZ Gates. Symbolize if any of the consecutive gates is symbolized.
176
-
177
- Example:
178
- # pylint: disable=line-too-long
179
- >>> q0, q1 = cirq.LineQubit.range(2)
180
- >>> c = cirq.Circuit(cirq.X(q0),cirq.CZ(q0,q1)**sympy.Symbol("cz_exp"),cirq.Y(q0)**sympy.Symbol("y_exp"),cirq.X(q0))
181
- >>> print(c)
182
- 0: ───X───@──────────Y^y_exp───X───
183
- │
184
- 1: ───────@^cz_exp─────────────────
185
- >>> new_circuit, new_sweep = cirq.merge_single_qubit_gates_to_phxz_symbolized(\
186
- c, sweep=cirq.Points(key="cz_exp", points=[0, 1]) * cirq.Points(key="y_exp", points=[0, 1])\
187
- )
188
- >>> print(new_circuit)
189
- 0: ───PhXZ(a=-1,x=1,z=0)───@──────────PhXZ(a=a0,x=x0,z=z0)───
190
- │
191
- 1: ────────────────────────@^cz_exp──────────────────────────
192
- >>> print(new_sweep)
193
- cirq.Points('z0', [0, -1.0, 0, -1.0]) + cirq.Points('x0', [1, 0.0, 1, 0.0]) + cirq.Points('a0', [-1.0, -0.5, -1.0, -0.5]) + cirq.Points('cz_exp', [0, 0, 1, 1])
194
- # pylint: disable=line-too-long
171
+ def _merge_single_qubit_gates_to_circuit_op_symbolized (
172
+ resolved_circuits : List ['cirq.AbstractCircuit' ],
173
+ symbolized_single_tag : str ,
174
+ context : Optional ['cirq.TransformerContext' ],
175
+ atol : float ,
176
+ ) -> Tuple [List ['cirq.Circuit' ], frozenset [str ], frozenset [str ]]:
177
+ """Helper function to merge single qubit ops of resolved circuits to ops of CircuitOperation
178
+ type using merge_k_qubit_unitaries.
195
179
196
180
Args:
197
- circuit: Input circuit to transform. It will not be modified.
198
- sweep: Sweep of the symbols in the input circuit, updated Sweep will be returned
199
- based on the transformation.
200
- context: `cirq.TransformerContext` storing common configurable options for transformers.
201
- atol: Absolute tolerance to angle error. Larger values allow more negligible gates to be
202
- dropped, smaller values increase accuracy.
181
+ resolved_circuits: A list of circuits where symbols have been replaced with concrete values.
182
+ symbolized_single_tag: The tag applied to single-qubit operations that originally contained symbols
183
+ before parameterizations.
203
184
204
185
Returns:
205
- Copy of the transformed input circuit .
186
+ Tuple of merge counts, merged circuits, and merge tags .
206
187
"""
207
- deep = context .deep if context else False
208
-
209
- if not protocols .is_parameterized (circuit ):
210
- warnings .warn (
211
- "Expect parameterized circuits. "
212
- "Please use cirq.merge_single_qubit_gates_to_phxz instead." ,
213
- UserWarning ,
214
- )
215
- return merge_single_qubit_gates_to_phxz (circuit , context = context , atol = atol )
216
-
217
- # Tag symbolized single qubit op.
218
- symbolized_single_tag = "_symbolized_single"
219
-
220
- circuit_tagged = transformer_primitives .map_operations (
221
- circuit ,
222
- lambda op , _ : (
223
- op .with_tags (symbolized_single_tag )
224
- if protocols .is_parameterized (op ) and len (op .qubits ) == 1
225
- else op
226
- ),
227
- deep = deep ,
228
- )
229
-
230
- # Symbols of the single qubit symbolized ops.
231
- single_qubit_gate_symbols : set [sympy .Symbol ] = set ().union (
232
- * [
233
- protocols .parameter_symbols (op ) if symbolized_single_tag in op .tags else set ()
234
- for op in circuit_tagged .all_operations ()
235
- ]
236
- )
237
- # Remaing symbols, e.g., 2 qubit gates' symbols. Sweep of those symbols keeps unchanged.
238
- remaining_symbols = protocols .parameter_symbols (circuit ) - single_qubit_gate_symbols
239
-
240
- sweep_of_single : Sweep = Zip (
241
- * [Points (key = k , points = _values_of_sweep (sweep , k )) for k in single_qubit_gate_symbols ]
242
- )
243
-
244
- # Get all resolved circuits from all sets of resolvers in sweep.
245
- resolved_circuits = [
246
- protocols .resolve_parameters (circuit_tagged , resolver ) for resolver in sweep_of_single
247
- ]
248
-
249
- # Store the number of merges for all set of resolvers,
250
- # it should be the same for all resolved circuits.
251
- merge_counts : list [int ] = []
252
- merged_circuits = []
253
- phxz_tag_prefix = "_phxz"
188
+ merge_counts : list [int ] = [] # number of merges per resolved_circuit
189
+ merged_circuits : list ['cirq.Circuit' ] = []
254
190
tag_iter : itertools .count
191
+ phxz_tag_prefix = "_phxz"
255
192
256
193
def rewriter (circuit_op : 'cirq.CircuitOperation' ) -> 'cirq.OP_TREE' :
257
194
nonlocal tag_iter
258
195
tag : Optional [str ] = None
196
+
259
197
u = protocols .unitary (circuit_op )
260
198
if protocols .num_qubits (circuit_op ) == 0 :
261
199
return ops .GlobalPhaseGate (u [0 , 0 ]).on ()
200
+ # If any of the op in the merged circuit_op is a symbolized single qubit gate,
201
+ # tag the merged phxz gate with next tag id, for further parameterization references.
262
202
for op in circuit_op .circuit .all_operations ():
263
203
if symbolized_single_tag in op .tags :
264
- # Record parameterizations info via tags.
265
204
tag = f"{ phxz_tag_prefix } _{ next (tag_iter )} "
266
205
break
267
206
gate = single_qubit_decompositions .single_qubit_matrix_to_phxz (u , atol ) or ops .I
268
207
op = gate .on (circuit_op .qubits [0 ])
269
- if not gate :
270
- return []
271
208
return op .with_tags (tag ) if tag else op
272
209
273
210
for resolved_circuit in resolved_circuits :
@@ -280,62 +217,195 @@ def rewriter(circuit_op: 'cirq.CircuitOperation') -> 'cirq.OP_TREE':
280
217
merge_counts .append (next (tag_iter ))
281
218
282
219
if not all (count == merge_counts [0 ] for count in merge_counts ):
283
- raise RuntimeError ("Different resolvers in sweep result different merged strcuture ." )
220
+ raise RuntimeError ("Different resolvers in sweep resulted in different merged structures ." )
284
221
285
- # Get the output circuit from the first resolved circuits.
286
- merge_tags : set [str ] = {f"{ phxz_tag_prefix } _{ i } " for i in range (merge_counts [0 ])}
287
- new_symbols : set [str ] = set ().union (
288
- * [{f"x{ i } " , f"z{ i } " , f"a{ i } " } for i in range (merge_counts [0 ])]
222
+ merge_tags : frozenset [str ] = frozenset (
223
+ {f"{ phxz_tag_prefix } _{ i } " for i in range (merge_counts [0 ])}
224
+ )
225
+ new_symbols : frozenset [str ] = frozenset (
226
+ set ().union (* [{f"x{ i } " , f"z{ i } " , f"a{ i } " } for i in range (merge_counts [0 ])])
289
227
)
290
228
229
+ return merged_circuits , merge_tags , new_symbols
230
+
231
+
232
+ def _get_merge_tag_id (merge_tags : frozenset [str ], op_tags : Tuple [Hashable , ...]) -> Optional [str ]:
233
+ """Extract the id `i` from the merge tag `_phxz_i` if it exists."""
234
+ the_merge_tag : set [str ] = set (merge_tags .intersection (op_tags ))
235
+ if len (the_merge_tag ) == 0 :
236
+ return None
237
+ if len (the_merge_tag ) > 1 :
238
+ raise RuntimeError ("Multiple merge tags found." )
239
+ return the_merge_tag .pop ().split ("_" )[- 1 ]
240
+
241
+
242
+ def _map_merged_ops_to_symbolized_phxz (
243
+ circuit : 'cirq.Circuit' , merge_tags : frozenset [str ], deep : bool
244
+ ) -> 'cirq.Circuit' :
245
+ """Maps merged operations (tagged with merge_tags) in the circuit to symbolized PhasedXZGates.
246
+
247
+ Args:
248
+ circuit: Circuit with merge tags to be mapped.
249
+ merge_tags: The set of tags used to identify the merged PhasedXZ gates that need to be
250
+ symbolized.
251
+ deep: Whether to perform the mapping recursively within CircuitOperations.
252
+
253
+ Returns:
254
+ A new circuit where tagged PhasedXZ gates are replaced by symbolized versions.
255
+ """
256
+
257
+ # Map merged ops to `PhasedXZGate(xi,zi,ai)` based on the tag "_phxz_i".
291
258
def _map_func (op : 'cirq.Operation' , _ ):
292
- """Maps op with tag `_phxz_i` to a symbolzied `PhasedXZGate(xi,zi,ai)`"""
293
- the_merge_tag = merge_tags . intersection ( op .tags )
294
- if len ( the_merge_tag ) == 0 :
259
+ """Maps an op with tag `_phxz_i` to a symbolzied `PhasedXZGate(xi,zi,ai)`"""
260
+ sid = _get_merge_tag_id ( merge_tags , op .tags )
261
+ if sid is None :
295
262
return op
296
- if len (the_merge_tag ) > 1 :
297
- raise RuntimeError ("Multiple merge tags found." )
298
- sid = the_merge_tag .pop ().split ("_" )[- 1 ]
299
263
phxz_params = {
300
264
"x_exponent" : sympy .Symbol (f"x{ sid } " ),
301
265
"z_exponent" : sympy .Symbol (f"z{ sid } " ),
302
266
"axis_phase_exponent" : sympy .Symbol (f"a{ sid } " ),
303
267
}
304
268
return ops .PhasedXZGate (** phxz_params ).on (* op .qubits )
305
269
306
- output_circuit : 'cirq.Circuit' = align .align_right (
307
- transformer_primitives .map_operations (merged_circuits [ 0 ] .freeze (), _map_func , deep = deep )
270
+ return align .align_right (
271
+ transformer_primitives .map_operations (circuit .freeze (), _map_func , deep = deep )
308
272
)
309
273
274
+
275
+ def _parameterize_merged_circuits (
276
+ merged_circuits : List ['cirq.Circuit' ],
277
+ merge_tags : frozenset [str ],
278
+ new_symbols : frozenset [str ],
279
+ remaining_symbols : frozenset [str ],
280
+ sweep : Sweep ,
281
+ ) -> Sweep :
282
+ """Parameterizes the merged circuits and returns a new sweep."""
310
283
values_by_params : Dict [str , List [float ]] = {
311
- ** {s : [] for s in new_symbols }, # New symbols introduced in merging
284
+ ** {s : [] for s in new_symbols }, # New symbols introduced during merging
312
285
** {
313
286
s : _values_of_sweep (sweep , s ) for s in remaining_symbols
314
- }, # Existing symbols in ops that are not merged, e.g., symbols in 2 qubit gates.
287
+ }, # Existing symbols in ops that were not merged, e.g., symbols in 2- qubit gates.
315
288
}
316
289
317
- # Get parameterization for the merged phxz gates.
318
290
for merged_circuit in merged_circuits :
319
291
for op in merged_circuit .all_operations ():
320
- the_merge_tag = merge_tags . intersection ( op .tags )
321
- if len ( the_merge_tag ) == 0 :
292
+ sid = _get_merge_tag_id ( merge_tags , op .tags )
293
+ if sid is None :
322
294
continue
323
- if len (the_merge_tag ) > 1 :
324
- raise RuntimeError ("Multiple merge tags found." )
325
- sid = the_merge_tag .pop ().split ("_" )[- 1 ]
326
- x , z , a = 0.0 , 0.0 , 0.0 # Identity gate's parameters.
295
+ x , z , a = 0.0 , 0.0 , 0.0 # Identity gate's parameters
327
296
if isinstance (op .gate , ops .PhasedXZGate ):
328
297
x , z , a = op .gate .x_exponent , op .gate .z_exponent , op .gate .axis_phase_exponent
329
298
elif op .gate is not ops .I :
330
299
raise RuntimeError (
331
- f"Expect the merged gate to be a PhasedXZGate or IdentityGate. But got { op .gate } ."
300
+ f"Expected the merged gate to be a PhasedXZGate or IdentityGate,"
301
+ f" but got { op .gate } ."
332
302
)
333
303
values_by_params [f"x{ sid } " ].append (x )
334
304
values_by_params [f"z{ sid } " ].append (z )
335
305
values_by_params [f"a{ sid } " ].append (a )
336
306
337
- new_sweep : Sweep = Zip (
338
- * [Points (key = key , points = values ) for key , values in values_by_params .items ()]
307
+ return Zip (* [Points (key = key , points = values ) for key , values in values_by_params .items ()])
308
+
309
+
310
+ def merge_single_qubit_gates_to_phxz_symbolized (
311
+ circuit : 'cirq.AbstractCircuit' ,
312
+ * ,
313
+ context : Optional ['cirq.TransformerContext' ] = None ,
314
+ sweep : Sweep ,
315
+ atol : float = 1e-8 ,
316
+ ) -> Tuple ['cirq.Circuit' , Sweep ]:
317
+ """Merge consecutive single qubit gates as PhasedXZ Gates. Symbolize if any of the consecutive
318
+ gates is symbolized.
319
+
320
+ Example:
321
+ >>> q0, q1 = cirq.LineQubit.range(2)
322
+ >>> c = cirq.Circuit(\
323
+ cirq.X(q0),\
324
+ cirq.CZ(q0,q1)**sympy.Symbol("cz_exp"),\
325
+ cirq.Y(q0)**sympy.Symbol("y_exp"),\
326
+ cirq.X(q0))
327
+ >>> print(c)
328
+ 0: ───X───@──────────Y^y_exp───X───
329
+ │
330
+ 1: ───────@^cz_exp─────────────────
331
+ >>> new_circuit, new_sweep = cirq.merge_single_qubit_gates_to_phxz_symbolized(\
332
+ c, sweep=cirq.Zip(cirq.Points(key="cz_exp", points=[0, 1]),\
333
+ cirq.Points(key="y_exp", points=[0, 1])))
334
+ >>> print(new_circuit)
335
+ 0: ───PhXZ(a=-1,x=1,z=0)───@──────────PhXZ(a=a0,x=x0,z=z0)───
336
+ │
337
+ 1: ────────────────────────@^cz_exp──────────────────────────
338
+ >>> assert new_sweep[0] == cirq.ParamResolver({'a0': -1, 'x0': 1, 'z0': 0, 'cz_exp': 0})
339
+ >>> assert new_sweep[1] == cirq.ParamResolver({'a0': -0.5, 'x0': 0, 'z0': -1, 'cz_exp': 1})
340
+
341
+ Args:
342
+ circuit: Input circuit to transform. It will not be modified.
343
+ sweep: Sweep of the symbols in the input circuit, updated Sweep will be returned
344
+ based on the transformation.
345
+ context: `cirq.TransformerContext` storing common configurable options for transformers.
346
+ atol: Absolute tolerance to angle error. Larger values allow more negligible gates to be
347
+ dropped, smaller values increase accuracy.
348
+
349
+ Returns:
350
+ Copy of the transformed input circuit.
351
+ """
352
+ deep = context .deep if context else False
353
+
354
+ # Tag symbolized single-qubit op.
355
+ symbolized_single_tag = "_symbolized_single"
356
+
357
+ circuit_tagged = transformer_primitives .map_operations (
358
+ circuit ,
359
+ lambda op , _ : (
360
+ op .with_tags (symbolized_single_tag )
361
+ if protocols .is_parameterized (op ) and len (op .qubits ) == 1
362
+ else op
363
+ ),
364
+ deep = deep ,
365
+ )
366
+
367
+ # Step 0, isolate single qubit symbolized symbols and resolve the circuit on them.
368
+
369
+ single_qubit_gate_symbols : frozenset [sympy .Symbol ] = frozenset (
370
+ set ().union (
371
+ * [
372
+ protocols .parameter_symbols (op ) if symbolized_single_tag in op .tags else set ()
373
+ for op in circuit_tagged .all_operations ()
374
+ ]
375
+ )
339
376
)
377
+ # If all single qubit gates are not parameterized, call the nonparamerized version of
378
+ # the transformer.
379
+ if not single_qubit_gate_symbols :
380
+ return merge_single_qubit_gates_to_phxz (circuit , context = context , atol = atol ), sweep
381
+ # Remaining symbols, e.g., 2 qubit gates' symbols. Sweep of those symbols keeps unchanged.
382
+ remaining_symbols : frozenset [sympy .Symbol ] = frozenset (
383
+ protocols .parameter_symbols (circuit ) - single_qubit_gate_symbols
384
+ )
385
+ sweep_of_single : Sweep = Zip (
386
+ * [Points (key = k , points = _values_of_sweep (sweep , k )) for k in single_qubit_gate_symbols ]
387
+ )
388
+ # Get all resolved circuits from all sets of resolvers in the sweep.
389
+ resolved_circuits = [
390
+ protocols .resolve_parameters (circuit_tagged , resolver ) for resolver in sweep_of_single
391
+ ]
392
+
393
+ # Step 1, merge single qubit gates of resolved circuits using merge_k_qubit_unitaries.
394
+ merged_circuits , merge_tags , new_symbols = _merge_single_qubit_gates_to_circuit_op_symbolized (
395
+ resolved_circuits , symbolized_single_tag , context , atol
396
+ )
397
+
398
+ # Step 2, get the new symbolzied circuit as new_sweep by mapping merged operations.
399
+ new_circuit = _map_merged_ops_to_symbolized_phxz (merged_circuits [0 ], merge_tags , deep )
400
+
401
+ # Step 3, get N sets of parameterizations as new_sweep.
402
+ new_sweep = _parameterize_merged_circuits (
403
+ merged_circuits , merge_tags , new_symbols , remaining_symbols , sweep
404
+ )
405
+
406
+ return new_circuit .unfreeze (copy = False ), new_sweep
407
+
340
408
341
- return output_circuit .unfreeze (copy = False ), new_sweep
409
+ # ----------------------------------------------------------------------
410
+ # Impl merge_single_qubit_gates_to_phxz_symbolized: End
411
+ # ----------------------------------------------------------------------
0 commit comments