Skip to content

Refactor simulator RNG handling #6944

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions cirq-core/cirq/contrib/quimb/mps_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ def to_numpy(self) -> np.ndarray:
"""An alias for the state vector."""
return self.state_vector()

def apply_op(self, op: Any, axes: Sequence[int], prng: np.random.RandomState):
def apply_op(self, op: Any, axes: Sequence[int], prng: np.random.Generator):
"""Applies a unitary operation, mutating the object to represent the new state.

op:
Expand Down Expand Up @@ -481,7 +481,7 @@ def estimation_stats(self):
}

def _measure(
self, axes: Sequence[int], prng: np.random.RandomState, collapse_state_vector=True
self, axes: Sequence[int], prng: np.random.Generator, collapse_state_vector=True
) -> List[int]:
results: List[int] = []

Expand Down Expand Up @@ -565,7 +565,7 @@ def __init__(
self,
*,
qubits: Sequence['cirq.Qid'],
prng: np.random.RandomState,
prng: Union[np.random.Generator, np.random.RandomState],
simulation_options: MPSOptions = MPSOptions(),
grouping: Optional[Dict['cirq.Qid', int]] = None,
initial_state: int = 0,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def random_rotations_between_two_qubit_circuit(
q1: 'cirq.Qid',
depth: int,
two_qubit_op_factory: Callable[
['cirq.Qid', 'cirq.Qid', 'np.random.RandomState'], 'cirq.OP_TREE'
['cirq.Qid', 'cirq.Qid', 'np.random.Generator'], 'cirq.OP_TREE'
] = lambda a, b, _: ops.CZPowGate()(a, b),
single_qubit_gates: Sequence['cirq.Gate'] = (
ops.X**0.5,
Expand Down Expand Up @@ -354,7 +354,7 @@ def _get_random_combinations(

combinations_by_layer = []
for pairs, layer in pair_gen:
combinations = rs.randint(0, n_library_circuits, size=(n_combinations, len(pairs)))
combinations = rs.integers(0, n_library_circuits, size=(n_combinations, len(pairs)))
combinations_by_layer.append(
CircuitLibraryCombination(layer=layer, combinations=combinations, pairs=pairs)
)
Expand Down Expand Up @@ -553,7 +553,7 @@ def random_rotations_between_grid_interaction_layers_circuit(
depth: int,
*, # forces keyword arguments
two_qubit_op_factory: Callable[
['cirq.GridQubit', 'cirq.GridQubit', 'np.random.RandomState'], 'cirq.OP_TREE'
['cirq.GridQubit', 'cirq.GridQubit', 'np.random.Generator'], 'cirq.OP_TREE'
] = lambda a, b, _: ops.CZPowGate()(a, b),
pattern: Sequence[GridInteractionLayer] = GRID_STAGGERED_PATTERN,
single_qubit_gates: Sequence['cirq.Gate'] = (
Expand Down Expand Up @@ -641,7 +641,7 @@ def __init__(
self,
qubits: Sequence['cirq.Qid'],
single_qubit_gates: Sequence['cirq.Gate'],
prng: 'np.random.RandomState',
prng: 'np.random.Generator',
) -> None:
self.qubits = qubits
self.single_qubit_gates = single_qubit_gates
Expand All @@ -651,9 +651,9 @@ def new_layer(self, previous_single_qubit_layer: 'cirq.Moment') -> 'cirq.Moment'
def random_gate(qubit: 'cirq.Qid') -> 'cirq.Gate':
excluded_op = previous_single_qubit_layer.operation_at(qubit)
excluded_gate = excluded_op.gate if excluded_op is not None else None
g = self.single_qubit_gates[self.prng.randint(0, len(self.single_qubit_gates))]
g = self.single_qubit_gates[self.prng.integers(0, len(self.single_qubit_gates))]
while g is excluded_gate:
g = self.single_qubit_gates[self.prng.randint(0, len(self.single_qubit_gates))]
g = self.single_qubit_gates[self.prng.integers(0, len(self.single_qubit_gates))]
return g

return circuits.Moment(random_gate(q).on(q) for q in self.qubits)
Expand All @@ -673,7 +673,7 @@ def new_layer(self, previous_single_qubit_layer: 'cirq.Moment') -> 'cirq.Moment'
def _single_qubit_gates_arg_to_factory(
single_qubit_gates: Sequence['cirq.Gate'],
qubits: Sequence['cirq.Qid'],
prng: 'np.random.RandomState',
prng: 'np.random.Generator',
) -> _SingleQubitLayerFactory:
"""Parse the `single_qubit_gates` argument for circuit generation functions.

Expand All @@ -690,10 +690,10 @@ def _single_qubit_gates_arg_to_factory(
def _two_qubit_layer(
coupled_qubit_pairs: List[GridQubitPairT],
two_qubit_op_factory: Callable[
['cirq.GridQubit', 'cirq.GridQubit', 'np.random.RandomState'], 'cirq.OP_TREE'
['cirq.GridQubit', 'cirq.GridQubit', 'np.random.Generator'], 'cirq.OP_TREE'
],
layer: GridInteractionLayer,
prng: 'np.random.RandomState',
prng: 'np.random.Generator',
) -> Iterator['cirq.OP_TREE']:
for a, b in coupled_qubit_pairs:
if (a, b) in layer or (b, a) in layer:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,15 +52,15 @@ def test_random_rotation_between_two_qubit_circuit():
"""\
0 1
│ │
Y^0.5 X^0.5
PhX(0.25)^0.5 Y^0.5
│ │
@─────────────@
│ │
PhX(0.25)^0.5 Y^0.5
X^0.5 PhX(0.25)^0.5
│ │
@─────────────@
│ │
Y^0.5 X^0.5
Y^0.5 Y^0.5
│ │
@─────────────@
│ │
Expand Down Expand Up @@ -361,7 +361,7 @@ def test_random_rotations_between_grid_interaction_layers(
qubits: Iterable[cirq.GridQubit],
depth: int,
two_qubit_op_factory: Callable[
[cirq.GridQubit, cirq.GridQubit, np.random.RandomState], cirq.OP_TREE
[cirq.GridQubit, cirq.GridQubit, np.random.Generator], cirq.OP_TREE
],
pattern: Sequence[GridInteractionLayer],
single_qubit_gates: Sequence[cirq.Gate],
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/linalg/decompositions_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,7 +597,7 @@ def _random_two_qubit_unitaries(num_samples: int, random_state: 'cirq.RANDOM_STA

prng = value.parse_random_state(random_state)
# Generate the non-local part by explict matrix exponentiation.
kak_vecs = prng.rand(num_samples, 3) * np.pi
kak_vecs = prng.random((num_samples, 3)) * np.pi
gens = np.einsum('...a,abc->...bc', kak_vecs, _kak_gens)
evals, evecs = np.linalg.eigh(gens)
A = np.einsum('...ab,...b,...cb', evecs, np.exp(1j * evals), evecs.conj())
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/protocols/act_on_protocol_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def measure(self, axes, seed=None):

class ExampleSimulationState(cirq.SimulationState):
def __init__(self, fallback_result: Any = NotImplemented):
super().__init__(prng=np.random.RandomState(), state=ExampleQuantumState())
super().__init__(prng=np.random.default_rng(), state=ExampleQuantumState())
self.fallback_result = fallback_result

def _act_on_fallback_(
Expand Down
4 changes: 2 additions & 2 deletions cirq-core/cirq/qis/clifford_tableau.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,7 +509,7 @@ def destabilizers(self) -> List['cirq.DensePauliString']:
generators above generate the full Pauli group on n qubits."""
return [self._row_to_dense_pauli(i) for i in range(self.n)]

def _measure(self, q, prng: np.random.RandomState) -> int:
def _measure(self, q, prng: np.random.Generator) -> int:
"""Performs a projective measurement on the q'th qubit.

Returns: the result (0 or 1) of the measurement.
Expand Down Expand Up @@ -544,7 +544,7 @@ def _measure(self, q, prng: np.random.RandomState) -> int:

self.zs[p, q] = True

self.rs[p] = bool(prng.randint(2))
self.rs[p] = bool(prng.integers(2))

return int(self.rs[p])

Expand Down
4 changes: 2 additions & 2 deletions cirq-core/cirq/sim/clifford/clifford_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def state_vector(self):

def apply_unitary(self, op: 'cirq.Operation'):
ch_form_args = clifford.StabilizerChFormSimulationState(
prng=np.random.RandomState(), qubits=self.qubit_map.keys(), initial_state=self.ch_form
prng=np.random.default_rng(), qubits=self.qubit_map.keys(), initial_state=self.ch_form
)
try:
act_on(op, ch_form_args)
Expand All @@ -254,7 +254,7 @@ def apply_measurement(
self,
op: 'cirq.Operation',
measurements: Dict[str, List[int]],
prng: np.random.RandomState,
prng: Union[np.random.Generator, np.random.RandomState],
collapse_state_vector=True,
):
if not isinstance(op.gate, cirq.MeasurementGate):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"""A protocol for implementing high performance clifford tableau evolutions
for Clifford Simulator."""

from typing import Optional, Sequence, TYPE_CHECKING
from typing import Optional, Sequence, TYPE_CHECKING, Union

import numpy as np

Expand All @@ -31,7 +31,7 @@ class CliffordTableauSimulationState(StabilizerSimulationState[clifford_tableau.
def __init__(
self,
tableau: 'cirq.CliffordTableau',
prng: Optional[np.random.RandomState] = None,
prng: Optional[Union[np.random.Generator, np.random.RandomState]] = None,
qubits: Optional[Sequence['cirq.Qid']] = None,
classical_data: Optional['cirq.ClassicalDataStore'] = None,
):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class StabilizerChFormSimulationState(
def __init__(
self,
*,
prng: Optional[np.random.RandomState] = None,
prng: Optional[Union[np.random.Generator, np.random.RandomState]] = None,
qubits: Optional[Sequence['cirq.Qid']] = None,
initial_state: Union[int, 'cirq.StabilizerStateChForm'] = 0,
classical_data: Optional['cirq.ClassicalDataStore'] = None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __init__(
self,
*,
state: TStabilizerState,
prng: Optional[np.random.RandomState] = None,
prng: Optional[Union[np.random.Generator, np.random.RandomState]] = None,
qubits: Optional[Sequence['cirq.Qid']] = None,
classical_data: Optional['cirq.ClassicalDataStore'] = None,
):
Expand Down
4 changes: 2 additions & 2 deletions cirq-core/cirq/sim/clifford/stabilizer_state_ch_form.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def to_state_vector(self) -> np.ndarray:

return arr

def _measure(self, q, prng: np.random.RandomState) -> int:
def _measure(self, q, prng: np.random.Generator) -> int:
"""Measures the q'th qubit.

Reference: Section 4.1 "Simulating measurements"
Expand All @@ -246,7 +246,7 @@ def _measure(self, q, prng: np.random.RandomState) -> int:
w = self.s.copy()
for i, v_i in enumerate(self.v):
if v_i == 1:
w[i] = bool(prng.randint(2))
w[i] = bool(prng.integers(2))
x_i = sum(w & self.G[q, :]) % 2
# Project the state to the above measurement outcome.
self.project_Z(q, x_i)
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/sim/density_matrix_simulation_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def __init__(
self,
*,
available_buffer: Optional[List[np.ndarray]] = None,
prng: Optional[np.random.RandomState] = None,
prng: Optional[Union[np.random.Generator, np.random.RandomState]] = None,
qubits: Optional[Sequence['cirq.Qid']] = None,
initial_state: Union[np.ndarray, 'cirq.STATE_VECTOR_LIKE'] = 0,
dtype: Type[np.complexfloating] = np.complex64,
Expand Down
9 changes: 6 additions & 3 deletions cirq-core/cirq/sim/simulation_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
TypeVar,
TYPE_CHECKING,
Tuple,
Union,
)
from typing_extensions import Self

Expand All @@ -49,7 +50,7 @@ def __init__(
self,
*,
state: TState,
prng: Optional[np.random.RandomState] = None,
prng: Optional[Union[np.random.Generator, np.random.RandomState]] = None,
qubits: Optional[Sequence['cirq.Qid']] = None,
classical_data: Optional['cirq.ClassicalDataStore'] = None,
):
Expand All @@ -70,12 +71,14 @@ def __init__(
classical_data = classical_data or value.ClassicalDataDictionaryStore()
super().__init__(qubits=qubits, classical_data=classical_data)
if prng is None:
prng = cast(np.random.RandomState, np.random)
prng = np.random.default_rng()
elif isinstance(prng, np.random.RandomState):
prng = np.random.default_rng(prng._bit_generator)
self._prng = prng
self._state = state

@property
def prng(self) -> np.random.RandomState:
def prng(self) -> np.random.Generator:
return self._prng

def measure(
Expand Down
29 changes: 23 additions & 6 deletions cirq-core/cirq/sim/simulator_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
Type,
TypeVar,
TYPE_CHECKING,
Union,
)

import numpy as np
Expand Down Expand Up @@ -93,21 +94,27 @@ def __init__(
*,
dtype: Type[np.complexfloating] = np.complex64,
noise: 'cirq.NOISE_MODEL_LIKE' = None,
seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None,
seed: Optional[Union[int, np.random.Generator, np.random.RandomState]] = None,
split_untangled_states: bool = False,
):
"""Initializes the simulator.

Args:
dtype: The `numpy.dtype` used by the simulation.
noise: A noise model to apply while simulating.
seed: The random seed to use for this simulator.
seed: The random seed or generator to use for this simulator.
split_untangled_states: If True, optimizes simulation by running
unentangled qubit sets independently and merging those states
at the end.
"""
self._dtype = dtype
self._prng = value.parse_random_state(seed)
if isinstance(seed, np.random.RandomState):
# Convert RandomState to Generator for backward compatibility
self._prng = np.random.Generator(seed._bit_generator)
elif isinstance(seed, np.random.Generator):
self._prng = seed
else:
self._prng = np.random.default_rng(seed)
self._noise = devices.NoiseModel.from_noise_model_like(noise)
self._split_untangled_states = split_untangled_states

Expand Down Expand Up @@ -228,6 +235,7 @@ def _run(
circuit: 'cirq.AbstractCircuit',
param_resolver: 'cirq.ParamResolver',
repetitions: int,
rng: Optional[np.random.Generator] = None,
) -> Dict[str, np.ndarray]:
"""See definition in `cirq.SimulatesSamples`."""
param_resolver = param_resolver or study.ParamResolver({})
Expand All @@ -254,7 +262,10 @@ def _run(
assert step_result is not None
measurement_ops = [cast(ops.GateOperation, op) for op in general_ops]
return step_result.sample_measurement_ops(
measurement_ops, repetitions, seed=self._prng, _allow_repeated=True
measurement_ops,
repetitions,
seed=rng if rng is not None else self._prng,
_allow_repeated=True,
)

records: Dict['cirq.MeasurementKey', List[Sequence[Sequence[int]]]] = {}
Expand Down Expand Up @@ -395,9 +406,15 @@ def sample(
self,
qubits: List['cirq.Qid'],
repetitions: int = 1,
seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None,
seed: Optional[Union[int, np.random.Generator, np.random.RandomState]] = None,
) -> np.ndarray:
return self._sim_state.sample(qubits, repetitions, seed)
if isinstance(seed, np.random.RandomState):
rng = np.random.Generator(seed._bit_generator)
elif isinstance(seed, np.random.Generator):
rng = seed
else:
rng = np.random.default_rng(seed)
return self._sim_state.sample(qubits, repetitions, rng)


class SimulationTrialResultBase(
Expand Down
26 changes: 26 additions & 0 deletions cirq-core/cirq/sim/simulator_base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,3 +434,29 @@ def test_inhomogeneous_measurement_count_padding():
results = sim.run(c, repetitions=10)
for i in range(10):
assert np.sum(results.records['m'][i, :, :]) == 1


def test_run_with_custom_rng():
sim = cirq.Simulator()
circuit = cirq.Circuit(cirq.H(q0), cirq.measure(q0))
rng1 = np.random.default_rng(seed=1234)
rng2 = np.random.default_rng(seed=1234)

result1 = sim._run(circuit, param_resolver=cirq.ParamResolver({}), repetitions=10, rng=rng1)
result2 = sim._run(circuit, param_resolver=cirq.ParamResolver({}), repetitions=10, rng=rng2)
assert np.array_equal(result1['q(0)'], result2['q(0)'])

rng3 = np.random.default_rng(seed=5678)
result3 = sim._run(circuit, param_resolver=cirq.ParamResolver({}), repetitions=10, rng=rng3)
assert not np.array_equal(result1['q(0)'], result3['q(0)'])


def test_run_with_explicit_rng_override():
sim1 = cirq.Simulator(seed=1234)
sim2 = cirq.Simulator(seed=5678)
circuit = cirq.Circuit(cirq.H(q0), cirq.measure(q0))
rng = np.random.default_rng(1234)

result1 = sim1._run(circuit, cirq.ParamResolver({}), repetitions=10)
result2 = sim2._run(circuit, cirq.ParamResolver({}), repetitions=10, rng=rng)
assert np.array_equal(result1['q(0)'], result2['q(0)'])
2 changes: 1 addition & 1 deletion cirq-core/cirq/sim/state_vector_simulation_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ def __init__(
self,
*,
available_buffer: Optional[np.ndarray] = None,
prng: Optional[np.random.RandomState] = None,
prng: Optional[Union[np.random.Generator, np.random.RandomState]] = None,
qubits: Optional[Sequence['cirq.Qid']] = None,
initial_state: Union[np.ndarray, 'cirq.STATE_VECTOR_LIKE'] = 0,
dtype: Type[np.complexfloating] = np.complex64,
Expand Down
Loading
Loading