From 339897d90f7fe347e6762f999bedec581b68f136 Mon Sep 17 00:00:00 2001 From: ddddddanni Date: Mon, 17 Nov 2025 12:44:16 -0800 Subject: [PATCH] init version - need to test --- ...ing_measurement_with_readout_mitigation.py | 71 +++++++++++++++---- ...easurement_with_readout_mitigation_test.py | 9 ++- 2 files changed, 65 insertions(+), 15 deletions(-) diff --git a/cirq-core/cirq/contrib/paulistring/pauli_string_measurement_with_readout_mitigation.py b/cirq-core/cirq/contrib/paulistring/pauli_string_measurement_with_readout_mitigation.py index e69d097b014..42ff9fc8b65 100644 --- a/cirq-core/cirq/contrib/paulistring/pauli_string_measurement_with_readout_mitigation.py +++ b/cirq-core/cirq/contrib/paulistring/pauli_string_measurement_with_readout_mitigation.py @@ -274,12 +274,19 @@ def _pauli_strings_to_basis_change_with_sweep( def _generate_basis_change_circuits( normalized_circuits_to_pauli: dict[circuits.FrozenCircuit, list[list[ops.PauliString]]], insert_strategy: circuits.InsertStrategy, + qubits_to_measure: Sequence[ops.Qid] | None = None, ) -> list[circuits.Circuit]: """Generates basis change circuits for each group of Pauli strings.""" pauli_measurement_circuits = list[circuits.Circuit]() for input_circuit, pauli_string_groups in normalized_circuits_to_pauli.items(): - qid_list = list(sorted(input_circuit.all_qubits())) + # If qubits_to_measure is provided, use it; otherwise, use all qubits in the circuit. + qid_list = ( + list(qubits_to_measure) + if qubits_to_measure is not None + else list(sorted(input_circuit.all_qubits())) + ) + basis_change_circuits = [] input_circuit_unfrozen = input_circuit.unfreeze() for pauli_strings in pauli_string_groups: @@ -298,12 +305,19 @@ def _generate_basis_change_circuits( def _generate_basis_change_circuits_with_sweep( normalized_circuits_to_pauli: dict[circuits.FrozenCircuit, list[list[ops.PauliString]]], insert_strategy: circuits.InsertStrategy, + qubits_to_measure: Sequence[ops.Qid] | None = None, ) -> tuple[list[circuits.Circuit], list[study.Sweepable]]: """Generates basis change circuits for each group of Pauli strings with sweep.""" parameterized_circuits = list[circuits.Circuit]() sweep_params = list[study.Sweepable]() for input_circuit, pauli_string_groups in normalized_circuits_to_pauli.items(): - qid_list = list(sorted(input_circuit.all_qubits())) + # If qubits_to_measure is provided, use it; otherwise default to circuit qubits. + qid_list = ( + list(qubits_to_measure) + if qubits_to_measure is not None + else list(sorted(input_circuit.all_qubits())) + ) + phi_symbols = sympy.symbols(f"phi:{len(qid_list)}") theta_symbols = sympy.symbols(f"theta:{len(qid_list)}") @@ -379,6 +393,7 @@ def _process_pauli_measurement_results( pauli_repetitions: int, timestamp: float, disable_readout_mitigation: bool = False, + fixed_calibration_key: tuple[ops.Qid, ...] | None = None, ) -> list[PauliStringMeasurementResult]: """Calculates both error-mitigated expectation values and unmitigated expectation values from measurement results. @@ -399,11 +414,13 @@ def _process_pauli_measurement_results( timestamp: The timestamp of the calibration results. disable_readout_mitigation: If set to True, returns no error-mitigated error expectation values. + fixed_calibration_key: If provided, uses this key to retrieve the calibration result + from `calibration_results` for all Pauli strings, regardless of their specific + support. This is used when `measure_on_full_support` is True. Returns: A list of PauliStringMeasurementResult. """ - pauli_measurement_results: list[PauliStringMeasurementResult] = [] for pauli_group_index, circuit_result in enumerate(circuit_results): @@ -411,10 +428,13 @@ def _process_pauli_measurement_results( pauli_strs = pauli_string_groups[pauli_group_index] pauli_readout_qubits = _extract_readout_qubits(pauli_strs) + if fixed_calibration_key is not None: + calibration_key = fixed_calibration_key + else: + calibration_key = tuple(pauli_readout_qubits) + calibration_result = ( - calibration_results[tuple(pauli_readout_qubits)] - if not disable_readout_mitigation - else None + calibration_results[calibration_key] if not disable_readout_mitigation else None ) for pauli_str in pauli_strs: @@ -491,6 +511,7 @@ def measure_pauli_strings( rng_or_seed: np.random.Generator | int, use_sweep: bool = False, insert_strategy: circuits.InsertStrategy = circuits.InsertStrategy.INLINE, + measure_on_full_support: bool = False, ) -> list[CircuitToPauliStringsMeasurementResult]: """Measures expectation values of Pauli strings on given circuits with/without readout error mitigation. @@ -524,7 +545,11 @@ def measure_pauli_strings( use_sweep: If True, uses parameterized circuits and sweeps parameters for both Pauli measurements and readout benchmarking. Defaults to False. insert_strategy: The strategy for inserting measurement operations into the circuit. - Defaults to circuits.InsertStrategy.INLINE. + measure_on_full_support: If True, calculates the union of all qubits used in all + Pauli strings (the full support). All circuits will then measure this full set + of qubits, and readout benchmarking will be performed only once on this full set, + rather than for every unique subset of Pauli qubits. This significantly reduces + overhead when measuring many Pauli strings with varying support. Returns: A list of CircuitToPauliStringsMeasurementResult objects, where each object contains: @@ -545,12 +570,24 @@ def measure_pauli_strings( # Extract unique qubit tuples from input pauli strings unique_qubit_tuples = set() - for pauli_string_groups in normalized_circuits_to_pauli.values(): - for pauli_strings in pauli_string_groups: - unique_qubit_tuples.add(tuple(_extract_readout_qubits(pauli_strings))) + if measure_on_full_support: + full_support = set() + for pauli_string_groups in normalized_circuits_to_pauli.values(): + for pauli_strings in pauli_string_groups: + for pauli_string in pauli_strings: + full_support.update(pauli_string.qubits) + # One calibration group + unique_qubit_tuples.add(tuple(sorted(full_support))) + else: + for pauli_string_groups in normalized_circuits_to_pauli.values(): + for pauli_strings in pauli_string_groups: + unique_qubit_tuples.add(tuple(_extract_readout_qubits(pauli_strings))) + # qubits_list is a list of qubit tuples qubits_list = sorted(unique_qubit_tuples) + qubits_to_measure_arg = list(qubits_list[0]) if measure_on_full_support else None + # Build the basis-change circuits for each Pauli string group pauli_measurement_circuits: list[circuits.Circuit] = [] sweep_params: list[study.Sweepable] = [] @@ -565,7 +602,7 @@ def measure_pauli_strings( if use_sweep: pauli_measurement_circuits, sweep_params = _generate_basis_change_circuits_with_sweep( - normalized_circuits_to_pauli, insert_strategy + normalized_circuits_to_pauli, insert_strategy, qubits_to_measure_arg ) # Run benchmarking using sweep for readout calibration @@ -580,7 +617,7 @@ def measure_pauli_strings( else: pauli_measurement_circuits = _generate_basis_change_circuits( - normalized_circuits_to_pauli, insert_strategy + normalized_circuits_to_pauli, insert_strategy, qubits_to_measure_arg ) # Run shuffled benchmarking for readout calibration @@ -598,7 +635,11 @@ def measure_pauli_strings( results: list[CircuitToPauliStringsMeasurementResult] = [] circuit_result_index = 0 for i, (input_circuit, pauli_string_groups) in enumerate(normalized_circuits_to_pauli.items()): - qubits_in_circuit = tuple(sorted(input_circuit.all_qubits())) + current_measurement_qubits = ( + qubits_to_measure_arg + if measure_on_full_support and qubits_to_measure_arg is not None + else sorted(input_circuit.all_qubits()) + ) disable_readout_mitigation = False if num_random_bitstrings != 0 else True @@ -611,14 +652,16 @@ def measure_pauli_strings( ] circuit_result_index += len(pauli_string_groups) + fixed_calibration_key = tuple(qubits_to_measure_arg) if measure_on_full_support else None pauli_measurement_results = _process_pauli_measurement_results( - list(qubits_in_circuit), + list(current_measurement_qubits), pauli_string_groups, circuits_results_for_group, calibration_results, pauli_repetitions, time.time(), disable_readout_mitigation, + fixed_calibration_key, ) results.append( CircuitToPauliStringsMeasurementResult( diff --git a/cirq-core/cirq/contrib/paulistring/pauli_string_measurement_with_readout_mitigation_test.py b/cirq-core/cirq/contrib/paulistring/pauli_string_measurement_with_readout_mitigation_test.py index 2d052713078..40d16b3defe 100644 --- a/cirq-core/cirq/contrib/paulistring/pauli_string_measurement_with_readout_mitigation_test.py +++ b/cirq-core/cirq/contrib/paulistring/pauli_string_measurement_with_readout_mitigation_test.py @@ -510,7 +510,14 @@ def test_many_group_pauli_in_circuits_with_coefficient(use_sweep: bool) -> None: simulator = cirq.Simulator() circuits_with_pauli_expectations = measure_pauli_strings( - circuits_to_pauli, sampler, 300, 300, 300, np.random.default_rng(), use_sweep + circuits_to_pauli, + sampler, + 300, + 300, + 300, + np.random.default_rng(), + use_sweep, + measure_on_full_support=True, ) for circuit_with_pauli_expectations in circuits_with_pauli_expectations: