Skip to content

Commit 71d2a4f

Browse files
Merge pull request #525 from quantumlib/qsim-pyopt
Reduce isinstance calls
2 parents 9e75230 + 19e0594 commit 71d2a4f

File tree

1 file changed

+187
-111
lines changed

1 file changed

+187
-111
lines changed

qsimcirq/qsim_circuit.py

Lines changed: 187 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -34,103 +34,183 @@
3434
]
3535

3636

37-
def _cirq_gate_kind(gate: cirq.ops.Gate):
38-
if isinstance(gate, cirq.ops.ControlledGate):
39-
return _cirq_gate_kind(gate.sub_gate)
40-
if isinstance(gate, cirq.ops.identity.IdentityGate):
41-
# Identity gates will decompose to no-ops.
42-
pass
43-
if isinstance(gate, cirq.ops.XPowGate):
44-
# cirq.rx also uses this path.
45-
if gate.exponent == 1 and gate.global_shift == 0:
46-
return qsim.kX
47-
return qsim.kXPowGate
48-
if isinstance(gate, cirq.ops.YPowGate):
49-
# cirq.ry also uses this path.
50-
if gate.exponent == 1 and gate.global_shift == 0:
51-
return qsim.kY
52-
return qsim.kYPowGate
53-
if isinstance(gate, cirq.ops.ZPowGate):
54-
# cirq.rz also uses this path.
55-
if gate.global_shift == 0:
56-
if gate.exponent == 1:
57-
return qsim.kZ
58-
if gate.exponent == 0.5:
59-
return qsim.kS
60-
if gate.exponent == 0.25:
61-
return qsim.kT
62-
return qsim.kZPowGate
63-
if isinstance(gate, cirq.ops.HPowGate):
64-
if gate.exponent == 1 and gate.global_shift == 0:
65-
return qsim.kH
66-
return qsim.kHPowGate
67-
if isinstance(gate, cirq.ops.CZPowGate):
68-
if gate.exponent == 1 and gate.global_shift == 0:
69-
return qsim.kCZ
70-
return qsim.kCZPowGate
71-
if isinstance(gate, cirq.ops.CXPowGate):
72-
if gate.exponent == 1 and gate.global_shift == 0:
73-
return qsim.kCX
74-
return qsim.kCXPowGate
75-
if isinstance(gate, cirq.ops.PhasedXPowGate):
76-
return qsim.kPhasedXPowGate
77-
if isinstance(gate, cirq.ops.PhasedXZGate):
78-
return qsim.kPhasedXZGate
79-
if isinstance(gate, cirq.ops.XXPowGate):
80-
if gate.exponent == 1 and gate.global_shift == 0:
81-
return qsim.kXX
82-
return qsim.kXXPowGate
83-
if isinstance(gate, cirq.ops.YYPowGate):
84-
if gate.exponent == 1 and gate.global_shift == 0:
85-
return qsim.kYY
86-
return qsim.kYYPowGate
87-
if isinstance(gate, cirq.ops.ZZPowGate):
88-
if gate.exponent == 1 and gate.global_shift == 0:
89-
return qsim.kZZ
90-
return qsim.kZZPowGate
91-
if isinstance(gate, cirq.ops.SwapPowGate):
92-
if gate.exponent == 1 and gate.global_shift == 0:
93-
return qsim.kSWAP
94-
return qsim.kSwapPowGate
95-
if isinstance(gate, cirq.ops.ISwapPowGate):
96-
# cirq.riswap also uses this path.
97-
if gate.exponent == 1 and gate.global_shift == 0:
98-
return qsim.kISWAP
99-
return qsim.kISwapPowGate
100-
if isinstance(gate, cirq.ops.PhasedISwapPowGate):
101-
# cirq.givens also uses this path.
102-
return qsim.kPhasedISwapPowGate
103-
if isinstance(gate, cirq.ops.FSimGate):
104-
return qsim.kFSimGate
105-
if isinstance(gate, cirq.ops.TwoQubitDiagonalGate):
106-
return qsim.kTwoQubitDiagonalGate
107-
if isinstance(gate, cirq.ops.ThreeQubitDiagonalGate):
108-
return qsim.kThreeQubitDiagonalGate
109-
if isinstance(gate, cirq.ops.CCZPowGate):
110-
if gate.exponent == 1 and gate.global_shift == 0:
111-
return qsim.kCCZ
112-
return qsim.kCCZPowGate
113-
if isinstance(gate, cirq.ops.CCXPowGate):
114-
if gate.exponent == 1 and gate.global_shift == 0:
115-
return qsim.kCCX
116-
return qsim.kCCXPowGate
117-
if isinstance(gate, cirq.ops.CSwapGate):
118-
return qsim.kCSwapGate
119-
if isinstance(gate, cirq.ops.MatrixGate):
120-
if gate.num_qubits() <= 6:
121-
return qsim.kMatrixGate
122-
raise NotImplementedError(
123-
f"Received matrix on {gate.num_qubits()} qubits; "
124-
+ "only up to 6-qubit gates are supported."
125-
)
126-
if isinstance(gate, cirq.ops.MeasurementGate):
127-
# needed to inherit SimulatesSamples in sims
128-
return qsim.kMeasurement
37+
def _translate_ControlledGate(gate: cirq.ControlledGate):
38+
return _cirq_gate_kind(gate.sub_gate)
39+
40+
41+
def _translate_XPowGate(gate: cirq.XPowGate):
42+
# cirq.rx also uses this path.
43+
if gate.exponent == 1 and gate.global_shift == 0:
44+
return qsim.kX
45+
return qsim.kXPowGate
46+
47+
48+
def _translate_YPowGate(gate: cirq.YPowGate):
49+
# cirq.ry also uses this path.
50+
if gate.exponent == 1 and gate.global_shift == 0:
51+
return qsim.kY
52+
return qsim.kYPowGate
53+
54+
55+
def _translate_ZPowGate(gate: cirq.ZPowGate):
56+
# cirq.rz also uses this path.
57+
if gate.global_shift == 0:
58+
if gate.exponent == 1:
59+
return qsim.kZ
60+
if gate.exponent == 0.5:
61+
return qsim.kS
62+
if gate.exponent == 0.25:
63+
return qsim.kT
64+
return qsim.kZPowGate
65+
66+
67+
def _translate_HPowGate(gate: cirq.HPowGate):
68+
if gate.exponent == 1 and gate.global_shift == 0:
69+
return qsim.kH
70+
return qsim.kHPowGate
71+
72+
73+
def _translate_CZPowGate(gate: cirq.CZPowGate):
74+
if gate.exponent == 1 and gate.global_shift == 0:
75+
return qsim.kCZ
76+
return qsim.kCZPowGate
77+
78+
79+
def _translate_CXPowGate(gate: cirq.CXPowGate):
80+
if gate.exponent == 1 and gate.global_shift == 0:
81+
return qsim.kCX
82+
return qsim.kCXPowGate
83+
84+
85+
def _translate_PhasedXPowGate(gate: cirq.PhasedXPowGate):
86+
return qsim.kPhasedXPowGate
87+
88+
89+
def _translate_PhasedXZGate(gate: cirq.PhasedXZGate):
90+
return qsim.kPhasedXZGate
91+
92+
93+
def _translate_XXPowGate(gate: cirq.XXPowGate):
94+
if gate.exponent == 1 and gate.global_shift == 0:
95+
return qsim.kXX
96+
return qsim.kXXPowGate
97+
98+
99+
def _translate_YYPowGate(gate: cirq.YYPowGate):
100+
if gate.exponent == 1 and gate.global_shift == 0:
101+
return qsim.kYY
102+
return qsim.kYYPowGate
103+
104+
105+
def _translate_ZZPowGate(gate: cirq.ZZPowGate):
106+
if gate.exponent == 1 and gate.global_shift == 0:
107+
return qsim.kZZ
108+
return qsim.kZZPowGate
109+
110+
111+
def _translate_SwapPowGate(gate: cirq.SwapPowGate):
112+
if gate.exponent == 1 and gate.global_shift == 0:
113+
return qsim.kSWAP
114+
return qsim.kSwapPowGate
115+
116+
117+
def _translate_ISwapPowGate(gate: cirq.ISwapPowGate):
118+
# cirq.riswap also uses this path.
119+
if gate.exponent == 1 and gate.global_shift == 0:
120+
return qsim.kISWAP
121+
return qsim.kISwapPowGate
122+
123+
124+
def _translate_PhasedISwapPowGate(gate: cirq.PhasedISwapPowGate):
125+
# cirq.givens also uses this path.
126+
return qsim.kPhasedISwapPowGate
127+
128+
129+
def _translate_FSimGate(gate: cirq.FSimGate):
130+
return qsim.kFSimGate
131+
132+
133+
def _translate_TwoQubitDiagonalGate(gate: cirq.TwoQubitDiagonalGate):
134+
return qsim.kTwoQubitDiagonalGate
135+
136+
137+
def _translate_ThreeQubitDiagonalGate(gate: cirq.ThreeQubitDiagonalGate):
138+
return qsim.kThreeQubitDiagonalGate
139+
140+
141+
def _translate_CCZPowGate(gate: cirq.CCZPowGate):
142+
if gate.exponent == 1 and gate.global_shift == 0:
143+
return qsim.kCCZ
144+
return qsim.kCCZPowGate
145+
146+
147+
def _translate_CCXPowGate(gate: cirq.CCXPowGate):
148+
if gate.exponent == 1 and gate.global_shift == 0:
149+
return qsim.kCCX
150+
return qsim.kCCXPowGate
151+
152+
153+
def _translate_CSwapGate(gate: cirq.CSwapGate):
154+
return qsim.kCSwapGate
155+
156+
157+
def _translate_MatrixGate(gate: cirq.MatrixGate):
158+
if gate.num_qubits() <= 6:
159+
return qsim.kMatrixGate
160+
raise NotImplementedError(
161+
f"Received matrix on {gate.num_qubits()} qubits; "
162+
+ "only up to 6-qubit gates are supported."
163+
)
164+
165+
166+
def _translate_MeasurementGate(gate: cirq.MeasurementGate):
167+
# needed to inherit SimulatesSamples in sims
168+
return qsim.kMeasurement
169+
170+
171+
TYPE_TRANSLATOR = {
172+
cirq.ControlledGate: _translate_ControlledGate,
173+
cirq.XPowGate: _translate_XPowGate,
174+
cirq.YPowGate: _translate_YPowGate,
175+
cirq.ZPowGate: _translate_ZPowGate,
176+
cirq.HPowGate: _translate_HPowGate,
177+
cirq.CZPowGate: _translate_CZPowGate,
178+
cirq.CXPowGate: _translate_CXPowGate,
179+
cirq.PhasedXPowGate: _translate_PhasedXPowGate,
180+
cirq.PhasedXZGate: _translate_PhasedXZGate,
181+
cirq.XXPowGate: _translate_XXPowGate,
182+
cirq.YYPowGate: _translate_YYPowGate,
183+
cirq.ZZPowGate: _translate_ZZPowGate,
184+
cirq.SwapPowGate: _translate_SwapPowGate,
185+
cirq.ISwapPowGate: _translate_ISwapPowGate,
186+
cirq.PhasedISwapPowGate: _translate_PhasedISwapPowGate,
187+
cirq.FSimGate: _translate_FSimGate,
188+
cirq.TwoQubitDiagonalGate: _translate_TwoQubitDiagonalGate,
189+
cirq.ThreeQubitDiagonalGate: _translate_ThreeQubitDiagonalGate,
190+
cirq.CCZPowGate: _translate_CCZPowGate,
191+
cirq.CCXPowGate: _translate_CCXPowGate,
192+
cirq.CSwapGate: _translate_CSwapGate,
193+
cirq.MatrixGate: _translate_MatrixGate,
194+
cirq.MeasurementGate: _translate_MeasurementGate,
195+
}
196+
197+
198+
def _cirq_gate_kind(gate: cirq.Gate):
199+
for gate_type in type(gate).mro():
200+
translator = TYPE_TRANSLATOR.get(gate_type, None)
201+
if translator is not None:
202+
return translator(gate)
129203
# Unrecognized gates will be decomposed.
130204
return None
131205

132206

133-
def _control_details(gate: cirq.ops.ControlledGate, qubits):
207+
def _has_cirq_gate_kind(op: cirq.Operation):
208+
if isinstance(op, cirq.ControlledOperation):
209+
return _has_cirq_gate_kind(op.sub_operation)
210+
return any(t in TYPE_TRANSLATOR for t in type(op.gate).mro())
211+
212+
213+
def _control_details(gate: cirq.ControlledGate, qubits):
134214
control_qubits = []
135215
control_values = []
136216
# TODO: support qudit control
@@ -169,7 +249,7 @@ def add_op_to_opstring(
169249
if len(qsim_op.qubits) != 1:
170250
raise ValueError(f"OpString ops should have 1 qubit; got {len(qsim_op.qubits)}")
171251

172-
is_controlled = isinstance(qsim_gate, cirq.ops.ControlledGate)
252+
is_controlled = isinstance(qsim_gate, cirq.ControlledGate)
173253
if is_controlled:
174254
raise ValueError(f"OpString ops should not be controlled.")
175255

@@ -189,7 +269,7 @@ def add_op_to_circuit(
189269
qubits = [qubit_to_index_dict[q] for q in qsim_op.qubits]
190270

191271
qsim_qubits = qubits
192-
is_controlled = isinstance(qsim_gate, cirq.ops.ControlledGate)
272+
is_controlled = isinstance(qsim_gate, cirq.ControlledGate)
193273
if is_controlled:
194274
control_qubits, control_values = _control_details(qsim_gate, qubits)
195275
if control_qubits is None:
@@ -276,7 +356,7 @@ def _resolve_parameters_(
276356
return QSimCircuit(cirq.resolve_parameters(super(), param_resolver, recursive))
277357

278358
def translate_cirq_to_qsim(
279-
self, qubit_order: cirq.ops.QubitOrderOrList = cirq.ops.QubitOrder.DEFAULT
359+
self, qubit_order: cirq.QubitOrderOrList = cirq.QubitOrder.DEFAULT
280360
) -> qsim.Circuit:
281361
"""
282362
Translates this Cirq circuit to the qsim representation.
@@ -286,31 +366,30 @@ def translate_cirq_to_qsim(
286366
"""
287367

288368
qsim_circuit = qsim.Circuit()
289-
ordered_qubits = cirq.ops.QubitOrder.as_qubit_order(qubit_order).order_for(
369+
ordered_qubits = cirq.QubitOrder.as_qubit_order(qubit_order).order_for(
290370
self.all_qubits()
291371
)
292372
qsim_circuit.num_qubits = len(ordered_qubits)
293373

294374
# qsim numbers qubits in reverse order from cirq
295375
ordered_qubits = list(reversed(ordered_qubits))
296376

297-
def has_qsim_kind(op: cirq.ops.GateOperation):
298-
return _cirq_gate_kind(op.gate) != None
299-
300-
def to_matrix(op: cirq.ops.GateOperation):
377+
def to_matrix(op: cirq.GateOperation):
301378
mat = cirq.unitary(op.gate, None)
302379
if mat is None:
303380
return NotImplemented
304381

305-
return cirq.ops.MatrixGate(mat).on(*op.qubits)
382+
return cirq.MatrixGate(mat).on(*op.qubits)
306383

307384
qubit_to_index_dict = {q: i for i, q in enumerate(ordered_qubits)}
308385
time_offset = 0
309386
gate_count = 0
310387
moment_indices = []
311388
for moment in self:
312389
ops_by_gate = [
313-
cirq.decompose(op, fallback_decomposer=to_matrix, keep=has_qsim_kind)
390+
cirq.decompose(
391+
op, fallback_decomposer=to_matrix, keep=_has_cirq_gate_kind
392+
)
314393
for op in moment
315394
]
316395
moment_length = max((len(gate_ops) for gate_ops in ops_by_gate), default=0)
@@ -330,7 +409,7 @@ def to_matrix(op: cirq.ops.GateOperation):
330409
return qsim_circuit, moment_indices
331410

332411
def translate_cirq_to_qtrajectory(
333-
self, qubit_order: cirq.ops.QubitOrderOrList = cirq.ops.QubitOrder.DEFAULT
412+
self, qubit_order: cirq.QubitOrderOrList = cirq.QubitOrder.DEFAULT
334413
) -> qsim.NoisyCircuit:
335414
"""
336415
Translates this noisy Cirq circuit to the qsim representation.
@@ -339,7 +418,7 @@ def translate_cirq_to_qtrajectory(
339418
gate indices)
340419
"""
341420
qsim_ncircuit = qsim.NoisyCircuit()
342-
ordered_qubits = cirq.ops.QubitOrder.as_qubit_order(qubit_order).order_for(
421+
ordered_qubits = cirq.QubitOrder.as_qubit_order(qubit_order).order_for(
343422
self.all_qubits()
344423
)
345424

@@ -348,15 +427,12 @@ def translate_cirq_to_qtrajectory(
348427

349428
qsim_ncircuit.num_qubits = len(ordered_qubits)
350429

351-
def has_qsim_kind(op: cirq.ops.GateOperation):
352-
return _cirq_gate_kind(op.gate) != None
353-
354-
def to_matrix(op: cirq.ops.GateOperation):
430+
def to_matrix(op: cirq.GateOperation):
355431
mat = cirq.unitary(op.gate, None)
356432
if mat is None:
357433
return NotImplemented
358434

359-
return cirq.ops.MatrixGate(mat).on(*op.qubits)
435+
return cirq.MatrixGate(mat).on(*op.qubits)
360436

361437
qubit_to_index_dict = {q: i for i, q in enumerate(ordered_qubits)}
362438
time_offset = 0
@@ -371,7 +447,7 @@ def to_matrix(op: cirq.ops.GateOperation):
371447
for qsim_op in moment:
372448
if cirq.has_unitary(qsim_op) or cirq.is_measurement(qsim_op):
373449
oplist = cirq.decompose(
374-
qsim_op, fallback_decomposer=to_matrix, keep=has_qsim_kind
450+
qsim_op, fallback_decomposer=to_matrix, keep=_has_cirq_gate_kind
375451
)
376452
ops_by_gate.append(oplist)
377453
moment_length = max(moment_length, len(oplist))

0 commit comments

Comments
 (0)