1717from __future__ import annotations
1818
1919import re
20- from collections .abc import Callable , Iterator , Sequence
20+ from collections .abc import Callable , Iterator , Sequence , Set
2121from typing import TYPE_CHECKING
2222
2323import numpy as np
24+ import sympy
2425
2526from cirq import linalg , ops , protocols , value
27+ from cirq ._compat import proper_repr
2628
2729if TYPE_CHECKING :
2830 import cirq
2931
3032
3133@value .value_equality (approximate = True )
3234class QasmUGate (ops .Gate ):
33- def __init__ (self , theta , phi , lmda ) -> None :
35+ def __init__ (self , theta : cirq . TParamVal , phi : cirq . TParamVal , lmda : cirq . TParamVal ) -> None :
3436 """A QASM gate representing any single qubit unitary with a series of
3537 three rotations, Z, Y, and Z.
3638
@@ -41,9 +43,9 @@ def __init__(self, theta, phi, lmda) -> None:
4143 phi: Half turns to rotate about Z (applied last).
4244 lmda: Half turns to rotate about Z (applied first).
4345 """
44- self .lmda = lmda % 2
4546 self .theta = theta % 2
4647 self .phi = phi % 2
48+ self .lmda = lmda % 2
4749
4850 def _num_qubits_ (self ) -> int :
4951 return 1
@@ -54,7 +56,28 @@ def from_matrix(mat: np.ndarray) -> QasmUGate:
5456 return QasmUGate (rotation / np .pi , post_phase / np .pi , pre_phase / np .pi )
5557
5658 def _has_unitary_ (self ):
57- return True
59+ return not self ._is_parameterized_ ()
60+
61+ def _is_parameterized_ (self ) -> bool :
62+ return (
63+ protocols .is_parameterized (self .theta )
64+ or protocols .is_parameterized (self .phi )
65+ or protocols .is_parameterized (self .lmda )
66+ )
67+
68+ def _parameter_names_ (self ) -> Set [str ]:
69+ return (
70+ protocols .parameter_names (self .theta )
71+ | protocols .parameter_names (self .phi )
72+ | protocols .parameter_names (self .lmda )
73+ )
74+
75+ def _resolve_parameters_ (self , resolver : cirq .ParamResolver , recursive : bool ) -> QasmUGate :
76+ return QasmUGate (
77+ protocols .resolve_parameters (self .theta , resolver , recursive ),
78+ protocols .resolve_parameters (self .phi , resolver , recursive ),
79+ protocols .resolve_parameters (self .lmda , resolver , recursive ),
80+ )
5881
5982 def _qasm_ (self , qubits : tuple [cirq .Qid , ...], args : cirq .QasmArgs ) -> str :
6083 args .validate_version ('2.0' , '3.0' )
@@ -69,18 +92,21 @@ def _qasm_(self, qubits: tuple[cirq.Qid, ...], args: cirq.QasmArgs) -> str:
6992 def __repr__ (self ) -> str :
7093 return (
7194 f'cirq.circuits.qasm_output.QasmUGate('
72- f'theta={ self .theta !r } , '
73- f'phi={ self .phi !r } , '
74- f'lmda={ self .lmda } )'
95+ f'theta={ proper_repr ( self .theta ) } , '
96+ f'phi={ proper_repr ( self .phi ) } , '
97+ f'lmda={ proper_repr ( self .lmda ) } )'
7598 )
7699
77100 def _decompose_ (self , qubits ):
101+ def mul_pi (x ):
102+ return x * (sympy .pi if protocols .is_parameterized (x ) else np .pi )
103+
78104 q = qubits [0 ]
79105 phase_correction_half_turns = (self .phi + self .lmda ) / 2
80106 return [
81- ops .rz (self .lmda * np . pi ).on (q ),
82- ops .ry (self .theta * np . pi ).on (q ),
83- ops .rz (self .phi * np . pi ).on (q ),
107+ ops .rz (mul_pi ( self .lmda ) ).on (q ),
108+ ops .ry (mul_pi ( self .theta ) ).on (q ),
109+ ops .rz (mul_pi ( self .phi ) ).on (q ),
84110 ops .global_phase_operation (1j ** (2 * phase_correction_half_turns )),
85111 ]
86112
0 commit comments