Skip to content

Commit 69d59aa

Browse files
committed
Add tag_transformers remove_tags, index_tags.
To be used in #6994.
1 parent de37106 commit 69d59aa

File tree

4 files changed

+189
-0
lines changed

4 files changed

+189
-0
lines changed

cirq-core/cirq/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,7 @@
363363
eject_z as eject_z,
364364
expand_composite as expand_composite,
365365
HardCodedInitialMapper as HardCodedInitialMapper,
366+
index_tags as index_tags,
366367
is_negligible_turn as is_negligible_turn,
367368
LineInitialMapper as LineInitialMapper,
368369
MappingManager as MappingManager,
@@ -385,6 +386,7 @@
385386
prepare_two_qubit_state_using_sqrt_iswap as prepare_two_qubit_state_using_sqrt_iswap,
386387
quantum_shannon_decomposition as quantum_shannon_decomposition,
387388
RouteCQC as RouteCQC,
389+
remove_tags as remove_tags,
388390
routed_circuit_with_mapping as routed_circuit_with_mapping,
389391
SqrtIswapTargetGateset as SqrtIswapTargetGateset,
390392
single_qubit_matrix_to_gates as single_qubit_matrix_to_gates,

cirq-core/cirq/transformers/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,8 @@
119119
transformer as transformer,
120120
)
121121

122+
from cirq.transformers.tag_transformers import index_tags, remove_tags
123+
122124
from cirq.transformers.transformer_primitives import (
123125
map_moments as map_moments,
124126
map_operations as map_operations,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
# Copyright 2025 The Cirq Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import itertools
16+
from typing import Callable, Hashable, Optional, TYPE_CHECKING
17+
18+
from cirq.transformers import transformer_api, transformer_primitives
19+
20+
if TYPE_CHECKING:
21+
import cirq
22+
23+
24+
@transformer_api.transformer
25+
def index_tags(
26+
circuit: 'cirq.AbstractCircuit',
27+
*,
28+
context: Optional['cirq.TransformerContext'] = None,
29+
target_tags: Optional[set[Hashable]] = None,
30+
) -> 'cirq.Circuit':
31+
"""Indexes tags in target_tags as tag_0, tag_1, ... per tag.
32+
33+
Args:
34+
circuit: Input circuit to apply the transformations on. The input circuit is not mutated.
35+
context: `cirq.TransformerContext` storing common configurable options for transformers.
36+
target_tags: Tags to be indexed.
37+
38+
Returns:
39+
Copy of the transformed input circuit.
40+
"""
41+
target_tags = target_tags or set()
42+
tag_iter_by_tags = {tag: itertools.count(start=0, step=1) for tag in target_tags}
43+
44+
def _map_func(op: 'cirq.Operation', _) -> 'cirq.OP_TREE':
45+
tag_set = set(op.tags)
46+
nonlocal tag_iter_by_tags
47+
for tag in target_tags.intersection(op.tags):
48+
tag_set.remove(tag)
49+
tag_set.add(f"{tag}_{next(tag_iter_by_tags[tag])}")
50+
51+
return op.untagged.with_tags(*tag_set)
52+
53+
return transformer_primitives.map_operations(
54+
circuit,
55+
_map_func,
56+
deep=context.deep if context else False,
57+
tags_to_ignore=context.tags_to_ignore if context else [],
58+
).unfreeze(copy=False)
59+
60+
61+
@transformer_api.transformer
62+
def remove_tags(
63+
circuit: 'cirq.AbstractCircuit',
64+
*,
65+
context: Optional['cirq.TransformerContext'] = None,
66+
target_tags: Optional[set[Hashable]] = None,
67+
remove_if: Callable[[Hashable], bool] = lambda _: False,
68+
) -> 'cirq.Circuit':
69+
"""Removes tags from the operations based on the input args.
70+
71+
Note: context.tags_to_ignore has higher priority than target_tags and remove_if.
72+
73+
Args:
74+
circuit: Input circuit to apply the transformations on. The input circuit is not mutated.
75+
context: `cirq.TransformerContext` storing common configurable options for transformers.
76+
target_tags: Tags to be removed.
77+
remove_if: A callable(tag) that returns True if the tag should be removed.
78+
Defaults to False.
79+
80+
Returns:
81+
Copy of the transformed input circuit.
82+
"""
83+
target_tags = target_tags or set()
84+
85+
def _map_func(op: 'cirq.Operation', _) -> 'cirq.OP_TREE':
86+
remaing_tags = set()
87+
for tag in op.tags:
88+
if not remove_if(tag) and tag not in target_tags:
89+
remaing_tags.add(tag)
90+
91+
return op.untagged.with_tags(*remaing_tags)
92+
93+
return transformer_primitives.map_operations(
94+
circuit,
95+
_map_func,
96+
deep=context.deep if context else False,
97+
tags_to_ignore=context.tags_to_ignore if context else [],
98+
).unfreeze(copy=False)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
# Copyright 2025 The Cirq Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import cirq
16+
17+
18+
def check_same_circuit_with_same_tag_sets(circuit1, circuit2):
19+
for op1, op2 in zip(circuit1.all_operations(), circuit2.all_operations()):
20+
assert set(op1.tags) == set(op2.tags)
21+
assert op1.untagged == op2.untagged
22+
23+
24+
def test_index_tags():
25+
q0, q1 = cirq.LineQubit.range(2)
26+
input_circuit = cirq.Circuit(
27+
cirq.X(q0).with_tags("tag1", "tag2"),
28+
cirq.Y(q1).with_tags("tag1"),
29+
cirq.CZ(q0, q1).with_tags("tag2"),
30+
)
31+
expected_circuit = cirq.Circuit(
32+
cirq.X(q0).with_tags("tag1_0", "tag2_0"),
33+
cirq.Y(q1).with_tags("tag1_1"),
34+
cirq.CZ(q0, q1).with_tags("tag2_1"),
35+
)
36+
check_same_circuit_with_same_tag_sets(
37+
cirq.index_tags(input_circuit, target_tags={"tag1", "tag2"}), expected_circuit
38+
)
39+
40+
41+
def test_remove_tags():
42+
q0, q1 = cirq.LineQubit.range(2)
43+
input_circuit = cirq.Circuit(
44+
cirq.X(q0).with_tags("tag1", "tag2"),
45+
cirq.Y(q1).with_tags("tag1"),
46+
cirq.CZ(q0, q1).with_tags("tag2"),
47+
)
48+
expected_circuit = cirq.Circuit(
49+
cirq.X(q0).with_tags("tag2"), cirq.Y(q1), cirq.CZ(q0, q1).with_tags("tag2")
50+
)
51+
check_same_circuit_with_same_tag_sets(
52+
cirq.remove_tags(input_circuit, target_tags={"tag1"}), expected_circuit
53+
)
54+
55+
56+
def test_remove_tags_via_remove_if():
57+
q0, q1 = cirq.LineQubit.range(2)
58+
input_circuit = cirq.Circuit(
59+
cirq.X(q0).with_tags("tag1", "tag2"),
60+
cirq.Y(q1).with_tags("not_tag1"),
61+
cirq.CZ(q0, q1).with_tags("tag2"),
62+
)
63+
expected_circuit = cirq.Circuit(cirq.X(q0), cirq.Y(q1).with_tags("not_tag1"), cirq.CZ(q0, q1))
64+
check_same_circuit_with_same_tag_sets(
65+
cirq.remove_tags(input_circuit, remove_if=lambda tag: tag.startswith("tag")),
66+
expected_circuit,
67+
)
68+
69+
70+
def test_remove_tags_with_tags_to_ignore():
71+
q0, q1 = cirq.LineQubit.range(2)
72+
input_circuit = cirq.Circuit(
73+
cirq.X(q0).with_tags("tag1", "tag0"),
74+
cirq.Y(q1).with_tags("not_tag1"),
75+
cirq.CZ(q0, q1).with_tags("tag2"),
76+
)
77+
expected_circuit = cirq.Circuit(
78+
cirq.X(q0).with_tags("tag1", "tag0"), cirq.Y(q1).with_tags("not_tag1"), cirq.CZ(q0, q1)
79+
)
80+
check_same_circuit_with_same_tag_sets(
81+
cirq.remove_tags(
82+
input_circuit,
83+
remove_if=lambda tag: tag.startswith("tag"),
84+
context=cirq.TransformerContext(tags_to_ignore=["tag0"]),
85+
),
86+
expected_circuit,
87+
)

0 commit comments

Comments
 (0)