Skip to content

Simplify decomposition of controlled eigengates with global phase #7291

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,8 @@ def test_decompose():
op = cirq.H(q0).with_classical_controls('a')
assert cirq.decompose(op) == [
(cirq.Y(q0) ** 0.5).with_classical_controls('a'),
cirq.XPowGate(exponent=1.0, global_shift=-0.25).on(q0).with_classical_controls('a'),
cirq.X(q0).with_classical_controls('a'),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, so here's where we probably need to have an opt in flag on the decompose context object. Having the basic gates suddenly decompose to multiple gates by default probably would be too likely to break something. While generally it seems that changing details around complex decompositions have been approved, this change seems a little too fundamental.

cirq.global_phase_operation(1j**-0.5).with_classical_controls('a'),
]


Expand Down
18 changes: 18 additions & 0 deletions cirq-core/cirq/ops/common_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,12 @@ def __init__(self, *, rads: value.TParamVal):
def _with_exponent(self, exponent: value.TParamVal) -> 'Rx':
return Rx(rads=exponent * _pi(exponent))

def _decompose_(self, qubits: Tuple['cirq.Qid', ...]) -> NotImplementedType:
"""Returns:
NotImplemented, to signify the gate doesn't decompose further.
"""
return NotImplemented

def _circuit_diagram_info_(
self, args: 'cirq.CircuitDiagramInfoArgs'
) -> Union[str, 'protocols.CircuitDiagramInfo']:
Expand Down Expand Up @@ -537,6 +543,12 @@ def __init__(self, *, rads: value.TParamVal):
def _with_exponent(self, exponent: value.TParamVal) -> 'Ry':
return Ry(rads=exponent * _pi(exponent))

def _decompose_(self, qubits: Tuple['cirq.Qid', ...]) -> NotImplementedType:
"""Returns:
NotImplemented, to signify the gate doesn't decompose further.
"""
return NotImplemented

def _circuit_diagram_info_(
self, args: 'cirq.CircuitDiagramInfoArgs'
) -> Union[str, 'protocols.CircuitDiagramInfo']:
Expand Down Expand Up @@ -882,6 +894,12 @@ def __init__(self, *, rads: value.TParamVal):
def _with_exponent(self, exponent: value.TParamVal) -> 'Rz':
return Rz(rads=exponent * _pi(exponent))

def _decompose_(self, qubits: Tuple['cirq.Qid', ...]) -> NotImplementedType:
"""Returns:
NotImplemented, to signify the gate doesn't decompose further.
"""
return NotImplemented

def _circuit_diagram_info_(
self, args: 'cirq.CircuitDiagramInfoArgs'
) -> Union[str, 'protocols.CircuitDiagramInfo']:
Expand Down
12 changes: 12 additions & 0 deletions cirq-core/cirq/ops/controlled_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
control_values as cv,
controlled_operation as cop,
diagonal_gate as dg,
eigen_gate,
global_phase_op as gp,
op_tree,
raw_types,
Expand Down Expand Up @@ -159,6 +160,12 @@ def _decompose_with_context_(
self, qubits: Tuple['cirq.Qid', ...], context: Optional['cirq.DecompositionContext'] = None
) -> Union[None, NotImplementedType, 'cirq.OP_TREE']:
control_qubits = list(qubits[: self.num_controls()])
# If the subgate is an EigenGate with non-zero phase, try to decompose it
# into a phase-free gate and a global phase gate.
if isinstance(self.sub_gate, eigen_gate.EigenGate) and self.sub_gate.global_shift != 0:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any way you can get rid of this condition? Like try decomposing the subgate first, then only enter the branch if the subgate has decomposed? A big goal is to reduce the amount of type checking needed here.

If this works, then the duplicate function call at the end of this function can be removed, and the function can just return NotImplemented if it gets to that point.

result = self._decompose_sub_gate_with_controls(qubits, context)
if result is not NotImplemented:
return result
if (
protocols.has_unitary(self.sub_gate)
and protocols.num_qubits(self.sub_gate) == 1
Expand Down Expand Up @@ -219,6 +226,11 @@ def _decompose_with_context_(
control_qid_shape=self.control_qid_shape,
).on(*control_qubits)
return [result, controlled_phase_op]
return self._decompose_sub_gate_with_controls(qubits, context)

def _decompose_sub_gate_with_controls(
self, qubits: Tuple['cirq.Qid', ...], context: Optional['cirq.DecompositionContext'] = None
) -> Union[None, NotImplementedType, 'cirq.OP_TREE']:
result = protocols.decompose_once_with_qubits(
self.sub_gate,
qubits[self.num_controls() :],
Expand Down
16 changes: 16 additions & 0 deletions cirq-core/cirq/ops/controlled_gate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,22 @@ def _test_controlled_gate_is_consistent(
np.testing.assert_allclose(cirq.unitary(cgate), cirq.unitary(circuit), atol=1e-13)


@pytest.mark.parametrize(
'sub_gate, expected_decomposition',
[
(cirq.XPowGate(global_shift=0.22), [cirq.Y**-0.5, cirq.CZ, cirq.Y**0.5, cirq.Z**0.22]),
(cirq.ZPowGate(exponent=1.2, global_shift=0.3), [cirq.CZ**1.2, cirq.Z**0.36]),
],
)
def test_decompose_takes_out_global_phase(
sub_gate: cirq.Gate, expected_decomposition: Sequence[cirq.Gate]
):
cgate = cirq.ControlledGate(sub_gate, num_controls=1)
qubits = cirq.LineQubit.range(cgate.num_qubits())
dec = cirq.decompose(cgate.on(*qubits))
assert [op.gate for op in dec] == expected_decomposition


def test_pow_inverse():
assert cirq.inverse(CRestricted, None) is None
assert cirq.pow(CRestricted, 1.5, None) is None
Expand Down
20 changes: 19 additions & 1 deletion cirq-core/cirq/ops/eigen_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
import sympy

from cirq import protocols, value
from cirq.ops import raw_types
from cirq.ops import global_phase_op, raw_types

if TYPE_CHECKING:
import cirq
Expand Down Expand Up @@ -375,6 +375,24 @@ def _json_dict_(self) -> Dict[str, Any]:
def _measurement_key_objs_(self):
return frozenset()

def _decompose_(
self, qubits: Tuple['cirq.Qid', ...]
) -> Union[NotImplementedType, 'cirq.OP_TREE']:
"""Attempts to decompose the gate into a phase-free gate and a global phase gate.

Returns:
NotImplemented, if global phase or exponent are 0. Otherwise a phase-free gate
applied to the qubits followed by a global phase gate.
"""
if self.global_shift == 0 or self.exponent == 0:
return NotImplemented
self_without_phase = self._with_exponent(self.exponent)
# This doesn't work for gates that fix global_shift, such as Rx. These gates must define
# their own _decompose_ method.
self_without_phase._global_shift = 0
global_phase = 1j ** (2 * self.global_shift * self.exponent)
return [self_without_phase.on(*qubits), global_phase_op.GlobalPhaseGate(global_phase)()]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe check that the remaining phase isn't zero, and drop it if it is. This could be the case if shift==0.5 and exponent==4 for instance. The decomposition would factor out the shift, leaving, say, X**4 and an identity phase gate that there's no reason to keep.



def _lcm(vals: Iterable[int]) -> int:
t = 1
Expand Down
41 changes: 41 additions & 0 deletions cirq-core/cirq/ops/eigen_gate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import cirq
from cirq import value
from cirq.ops import global_phase_op
from cirq.testing import assert_has_consistent_trace_distance_bound


Expand Down Expand Up @@ -421,3 +422,43 @@ def _with_exponent(self, exponent):
)
def test_equal_up_to_global_phase(gate1, gate2, eq_up_to_global_phase):
assert cirq.equal_up_to_global_phase(gate1, gate2) == eq_up_to_global_phase


@pytest.mark.parametrize(
'gate',
[
cirq.Z,
cirq.Z**2,
cirq.XPowGate(global_shift=0.0),
cirq.rx(0),
cirq.ry(0),
cirq.rz(0),
cirq.CZPowGate(exponent=0.0, global_shift=0.25),
],
)
def test_decompose_once_returns_not_implemented(gate: cirq.Gate):
qubits = cirq.LineQubit.range(gate.num_qubits())
assert cirq.decompose_once(gate.on(*qubits), default=NotImplemented) == NotImplemented


@pytest.mark.parametrize(
'gate, expected_decomposition',
[
(cirq.X, [cirq.X]),
(cirq.ZPowGate(global_shift=0.5), [cirq.Z, global_phase_op.GlobalPhaseGate(1j)]),
(
cirq.ZPowGate(global_shift=0.5) ** sympy.Symbol('e'),
[
cirq.Z ** sympy.Symbol('e'),
global_phase_op.GlobalPhaseGate(1j ** (1.0 * sympy.Symbol('e'))),
],
),
(cirq.rx(np.pi / 2), [cirq.rx(np.pi / 2)]),
(cirq.ry(np.pi / 2), [cirq.ry(np.pi / 2)]),
(cirq.rz(np.pi / 2), [cirq.rz(np.pi / 2)]),
],
)
def test_decompose_takes_out_global_phase(gate: cirq.Gate, expected_decomposition: List[cirq.Gate]):
qubits = cirq.LineQubit.range(gate.num_qubits())
dec = cirq.decompose(gate.on(*qubits))
assert [op.gate for op in dec] == expected_decomposition
59 changes: 59 additions & 0 deletions cirq-core/cirq/ops/parity_gates_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

"""Tests for `parity_gates.py`."""

from typing import List

import numpy as np
import pytest
import sympy
Expand Down Expand Up @@ -348,3 +350,60 @@ def test_clifford_protocols(gate_cls: type[cirq.EigenGate], exponent: float, is_
else:
assert not cirq.has_stabilizer_effect(gate)
assert gate._decompose_into_clifford_with_qubits_(cirq.LineQubit.range(2)) is NotImplemented


@pytest.mark.parametrize(
'gate, expected_decomposition',
[
(
cirq.XXPowGate(),
[
(cirq.Y**-0.5).on(cirq.LineQubit(0)),
(cirq.Y**-0.5).on(cirq.LineQubit(1)),
cirq.Z(cirq.LineQubit(0)),
cirq.Z(cirq.LineQubit(1)),
(cirq.CZ**-2.0).on(cirq.LineQubit(0), cirq.LineQubit(1)),
(cirq.Y**0.5).on(cirq.LineQubit(0)),
(cirq.Y**0.5).on(cirq.LineQubit(1)),
],
),
(
cirq.YYPowGate(),
[
(cirq.X**0.5).on(cirq.LineQubit(0)),
(cirq.X**0.5).on(cirq.LineQubit(1)),
cirq.Z(cirq.LineQubit(0)),
cirq.Z(cirq.LineQubit(1)),
(cirq.CZ**-2.0).on(cirq.LineQubit(0), cirq.LineQubit(1)),
(cirq.X**-0.5).on(cirq.LineQubit(0)),
(cirq.X**-0.5).on(cirq.LineQubit(1)),
],
),
(
cirq.ZZPowGate(),
[
cirq.Z(cirq.LineQubit(0)),
cirq.Z(cirq.LineQubit(1)),
(cirq.CZ**-2.0).on(cirq.LineQubit(0), cirq.LineQubit(1)),
],
),
(
cirq.MSGate(rads=0),
[
(cirq.Y**-0.5).on(cirq.LineQubit(0)),
(cirq.Y**-0.5).on(cirq.LineQubit(1)),
(cirq.Z**0.0).on(cirq.LineQubit(0)),
(cirq.Z**0.0).on(cirq.LineQubit(1)),
cirq.CZPowGate(exponent=-0.0, global_shift=0.25).on(
cirq.LineQubit(0), cirq.LineQubit(1)
),
(cirq.Y**0.5).on(cirq.LineQubit(0)),
(cirq.Y**0.5).on(cirq.LineQubit(1)),
],
),
],
)
def test_gate_decomposition(gate: cirq.Gate, expected_decomposition: List[cirq.Gate]):
qubits = cirq.LineQubit.range(gate.num_qubits())
dec = cirq.decompose(gate.on(*qubits))
assert [op for op in dec] == expected_decomposition
2 changes: 1 addition & 1 deletion cirq-core/cirq/ops/raw_types_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -647,7 +647,7 @@ def test_tagged_operation_forwards_protocols():
np.testing.assert_equal(cirq.unitary(tagged_h), cirq.unitary(h))
assert cirq.has_unitary(tagged_h)
assert cirq.decompose(tagged_h) == cirq.decompose(h)
assert [*tagged_h._decompose_()] == cirq.decompose(h)
assert [*tagged_h._decompose_()] == cirq.decompose_once(h)
assert cirq.pauli_expansion(tagged_h) == cirq.pauli_expansion(h)
assert cirq.equal_up_to_global_phase(h, tagged_h)
assert np.isclose(cirq.kraus(h), cirq.kraus(tagged_h)).all()
Expand Down
6 changes: 5 additions & 1 deletion cirq-core/cirq/transformers/merge_single_qubit_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def merge_single_qubit_moments_to_phxz(

def can_merge_moment(m: 'cirq.Moment'):
return all(
protocols.num_qubits(op) == 1
(protocols.num_qubits(op) == 1 or protocols.num_qubits(op) == 0)
and protocols.has_unitary(op)
and tags_to_ignore.isdisjoint(op.tags)
for op in m
Expand Down Expand Up @@ -144,6 +144,10 @@ def merge_func(m1: 'cirq.Moment', m2: 'cirq.Moment') -> Optional['cirq.Moment']:
)
if gate:
ret_ops.append(gate(q))
# Transfer global phase
for op in m1.operations + m2.operations:
if protocols.num_qubits(op) == 0:
ret_ops.append(op)
return circuits.Moment(ret_ops)

return transformer_primitives.merge_moments(
Expand Down
36 changes: 34 additions & 2 deletions cirq-core/cirq/transformers/merge_single_qubit_gates_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,13 +221,45 @@ def test_merge_single_qubit_moments_to_phxz_deep():
)


def test_merge_single_qubit_moments_to_phxz_global_phase():
def test_merge_single_qubit_gates_to_phxz_global_phase():
c = cirq.Circuit(cirq.GlobalPhaseGate(1j).on())
c2 = cirq.merge_single_qubit_gates_to_phxz(c)
assert c == c2


def test_merge_single_qubit_moments_to_phased_x_and_z_global_phase():
def test_merge_single_qubit_gates_to_phased_x_and_z_global_phase():
c = cirq.Circuit(cirq.GlobalPhaseGate(1j).on())
c2 = cirq.merge_single_qubit_gates_to_phased_x_and_z(c)
assert c == c2


def test_merge_single_qubit_moments_to_phxz_with_global_phase_in_first_moment():
q0 = cirq.LineQubit(0)
c_orig = cirq.Circuit(
cirq.Moment(cirq.Y(q0) ** 0.5, cirq.GlobalPhaseGate(1j**0.5).on()), cirq.Moment(cirq.X(q0))
)
c_expected = cirq.Circuit(
cirq.Moment(
cirq.PhasedXZGate(axis_phase_exponent=-0.5, x_exponent=0.5, z_exponent=-1.0).on(q0),
cirq.GlobalPhaseGate(1j**0.5).on(),
)
)
context = cirq.TransformerContext(tags_to_ignore=["ignore"])
c_new = cirq.merge_single_qubit_moments_to_phxz(c_orig, context=context)
assert c_new == c_expected


def test_merge_single_qubit_moments_to_phxz_with_global_phase_in_second_moment():
q0 = cirq.LineQubit(0)
c_orig = cirq.Circuit(
cirq.Moment(cirq.Y(q0) ** 0.5), cirq.Moment(cirq.X(q0), cirq.GlobalPhaseGate(1j**0.5).on())
)
c_expected = cirq.Circuit(
cirq.Moment(
cirq.PhasedXZGate(axis_phase_exponent=-0.5, x_exponent=0.5, z_exponent=-1.0).on(q0),
cirq.GlobalPhaseGate(1j**0.5).on(),
)
)
context = cirq.TransformerContext(tags_to_ignore=["ignore"])
c_new = cirq.merge_single_qubit_moments_to_phxz(c_orig, context=context)
assert c_new == c_expected