Skip to content

Commit 82435c1

Browse files
authored
Parameterize QasmUGate (#7759)
Update `QasmUGate` to handle parameterized values for `theta`, `phi`, and `lmda`. Resolves #5983 Resolves #5985
1 parent e95ae4b commit 82435c1

File tree

2 files changed

+59
-10
lines changed

2 files changed

+59
-10
lines changed

cirq-core/cirq/circuits/qasm_output.py

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,20 +17,22 @@
1717
from __future__ import annotations
1818

1919
import re
20-
from collections.abc import Callable, Iterator, Sequence
20+
from collections.abc import Callable, Iterator, Sequence, Set
2121
from typing import TYPE_CHECKING
2222

2323
import numpy as np
24+
import sympy
2425

2526
from cirq import linalg, ops, protocols, value
27+
from cirq._compat import proper_repr
2628

2729
if TYPE_CHECKING:
2830
import cirq
2931

3032

3133
@value.value_equality(approximate=True)
3234
class QasmUGate(ops.Gate):
33-
def __init__(self, theta, phi, lmda) -> None:
35+
def __init__(self, theta: cirq.TParamVal, phi: cirq.TParamVal, lmda: cirq.TParamVal) -> None:
3436
"""A QASM gate representing any single qubit unitary with a series of
3537
three rotations, Z, Y, and Z.
3638
@@ -41,9 +43,9 @@ def __init__(self, theta, phi, lmda) -> None:
4143
phi: Half turns to rotate about Z (applied last).
4244
lmda: Half turns to rotate about Z (applied first).
4345
"""
44-
self.lmda = lmda % 2
4546
self.theta = theta % 2
4647
self.phi = phi % 2
48+
self.lmda = lmda % 2
4749

4850
def _num_qubits_(self) -> int:
4951
return 1
@@ -54,7 +56,28 @@ def from_matrix(mat: np.ndarray) -> QasmUGate:
5456
return QasmUGate(rotation / np.pi, post_phase / np.pi, pre_phase / np.pi)
5557

5658
def _has_unitary_(self):
57-
return True
59+
return not self._is_parameterized_()
60+
61+
def _is_parameterized_(self) -> bool:
62+
return (
63+
protocols.is_parameterized(self.theta)
64+
or protocols.is_parameterized(self.phi)
65+
or protocols.is_parameterized(self.lmda)
66+
)
67+
68+
def _parameter_names_(self) -> Set[str]:
69+
return (
70+
protocols.parameter_names(self.theta)
71+
| protocols.parameter_names(self.phi)
72+
| protocols.parameter_names(self.lmda)
73+
)
74+
75+
def _resolve_parameters_(self, resolver: cirq.ParamResolver, recursive: bool) -> QasmUGate:
76+
return QasmUGate(
77+
protocols.resolve_parameters(self.theta, resolver, recursive),
78+
protocols.resolve_parameters(self.phi, resolver, recursive),
79+
protocols.resolve_parameters(self.lmda, resolver, recursive),
80+
)
5881

5982
def _qasm_(self, qubits: tuple[cirq.Qid, ...], args: cirq.QasmArgs) -> str:
6083
args.validate_version('2.0', '3.0')
@@ -69,18 +92,21 @@ def _qasm_(self, qubits: tuple[cirq.Qid, ...], args: cirq.QasmArgs) -> str:
6992
def __repr__(self) -> str:
7093
return (
7194
f'cirq.circuits.qasm_output.QasmUGate('
72-
f'theta={self.theta!r}, '
73-
f'phi={self.phi!r}, '
74-
f'lmda={self.lmda})'
95+
f'theta={proper_repr(self.theta)}, '
96+
f'phi={proper_repr(self.phi)}, '
97+
f'lmda={proper_repr(self.lmda)})'
7598
)
7699

77100
def _decompose_(self, qubits):
101+
def mul_pi(x):
102+
return x * (sympy.pi if protocols.is_parameterized(x) else np.pi)
103+
78104
q = qubits[0]
79105
phase_correction_half_turns = (self.phi + self.lmda) / 2
80106
return [
81-
ops.rz(self.lmda * np.pi).on(q),
82-
ops.ry(self.theta * np.pi).on(q),
83-
ops.rz(self.phi * np.pi).on(q),
107+
ops.rz(mul_pi(self.lmda)).on(q),
108+
ops.ry(mul_pi(self.theta)).on(q),
109+
ops.rz(mul_pi(self.phi)).on(q),
84110
ops.global_phase_operation(1j ** (2 * phase_correction_half_turns)),
85111
]
86112

cirq-core/cirq/circuits/qasm_output_test.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import numpy as np
2121
import pytest
22+
import sympy
2223

2324
import cirq
2425
from cirq.circuits.qasm_output import QasmTwoQubitGate, QasmUGate
@@ -68,6 +69,28 @@ def test_u_gate_from_qiskit_ugate_unitary(_) -> None:
6869
np.testing.assert_allclose(cirq.unitary(g), u, atol=1e-7)
6970

7071

72+
def test_u_gate_params() -> None:
73+
q = cirq.LineQubit(0)
74+
a, b, c = sympy.symbols('a b c')
75+
u_gate = QasmUGate(a, b, c)
76+
assert u_gate == QasmUGate(a, b + 2, c - 2)
77+
assert u_gate != QasmUGate(a, b + 1, c - 1)
78+
assert cirq.is_parameterized(u_gate)
79+
assert cirq.parameter_names(u_gate) == {'a', 'b', 'c'}
80+
assert not cirq.has_unitary(u_gate)
81+
cirq.testing.assert_equivalent_repr(u_gate)
82+
cirq.testing.assert_implements_consistent_protocols(u_gate)
83+
u_gate_caps = cirq.resolve_parameters(u_gate, {'a': 'A', 'b': 'B', 'c': 'C'})
84+
assert u_gate_caps == QasmUGate(*sympy.symbols('A B C'))
85+
resolver = {'A': 0.1, 'B': 2.2, 'C': -1.7}
86+
resolved = cirq.resolve_parameters(u_gate_caps, resolver)
87+
assert cirq.approx_eq(resolved, QasmUGate(0.1, 0.2, 0.3))
88+
resolved_then_decomposed = cirq.decompose_once_with_qubits(resolved, [q])
89+
decomposed = cirq.decompose_once_with_qubits(u_gate_caps, [q])
90+
decomposed_then_resolved = [cirq.resolve_parameters(g, resolver) for g in decomposed]
91+
assert resolved_then_decomposed == decomposed_then_resolved
92+
93+
7194
def test_qasm_two_qubit_gate_repr() -> None:
7295
cirq.testing.assert_equivalent_repr(
7396
QasmTwoQubitGate.from_matrix(cirq.testing.random_unitary(4))

0 commit comments

Comments
 (0)