diff --git a/cirq-core/cirq/study/resolver.py b/cirq-core/cirq/study/resolver.py index d0ee7aae51d..aaf2343b90f 100644 --- a/cirq-core/cirq/study/resolver.py +++ b/cirq-core/cirq/study/resolver.py @@ -75,9 +75,20 @@ def __init__(self, param_dict: cirq.ParamResolverOrSimilarType = None) -> None: self._param_hash: int | None = None self._param_dict = cast(ParamDictType, {} if param_dict is None else param_dict) + self._param_dict_with_str_keys = self._param_dict + generate_str_keys = False for key in self._param_dict: - if isinstance(key, sympy.Expr) and not isinstance(key, sympy.Symbol): - raise TypeError(f'ParamResolver keys cannot be (non-symbol) formulas ({key})') + if isinstance(key, sympy.Expr): + if isinstance(key, sympy.Symbol): + generate_str_keys = True + else: + raise TypeError(f'ParamResolver keys cannot be (non-symbol) formulas ({key})') + if generate_str_keys: + # Remake dictionary with string keys for faster access + self._param_dict_with_str_keys = { + (key.name if isinstance(key, sympy.Symbol) else key): value + for key, value in self._param_dict.items() + } self._deep_eval_map: ParamDictType = {} @property @@ -119,22 +130,23 @@ def value_of( """ # Handle string or symbol - if isinstance(value, (str, sympy.Symbol)): - string = value if isinstance(value, str) else value.name - param_value = self._param_dict.get(string, _NOT_FOUND) + original_value = value + if isinstance(value, sympy.Symbol): + value = value.name + if isinstance(value, str): + param_value = self._param_dict_with_str_keys.get(value, _NOT_FOUND) + if isinstance(param_value, float): + return param_value if param_value is _NOT_FOUND: - symbol = value if isinstance(value, sympy.Symbol) else sympy.Symbol(value) - param_value = self._param_dict.get(symbol, _NOT_FOUND) - if param_value is _NOT_FOUND: - # Symbol or string cannot be resolved if not in param dict; return as symbol. - return symbol + # Symbol or string cannot be resolved if not in param dict; return as symbol. + return sympy.Symbol(value) v = _resolve_value(param_value) if v is not NotImplemented: return v if isinstance(param_value, str): param_value = sympy.Symbol(param_value) elif not isinstance(param_value, sympy.Basic): - return value + return original_value if recursive: param_value = self._value_of_recursive(value) return param_value @@ -210,7 +222,7 @@ def _value_of_recursive(self, value: cirq.TParamKey) -> cirq.TParamValComplex: self._deep_eval_map[value] = _RECURSION_FLAG v = self.value_of(value, recursive=False) - if v == value: + if v == value or (isinstance(v, sympy.Symbol) and v.name == value): self._deep_eval_map[value] = v else: self._deep_eval_map[value] = self.value_of(v, recursive=True) @@ -278,7 +290,7 @@ def _from_json_dict_(cls, param_dict, **kwargs): def _resolve_value(val: Any) -> Any: - if val is None or isinstance(val, float): + if isinstance(val, float) or val is None: return val if isinstance(val, numbers.Number) and not isinstance(val, sympy.Basic): return val