diff --git a/once/__init__.py b/once/__init__.py index 5d1c67a..81d29f9 100644 --- a/once/__init__.py +++ b/once/__init__.py @@ -12,6 +12,7 @@ import weakref from . import _iterator_wrappers +from . import _state def _is_method(func: collections.abc.Callable): @@ -91,10 +92,107 @@ def return_value(self, value: typing.Any) -> None: self._return_value = value +# Instead of just passing in a state, we generally use a state_factory +# function, which returns a state. This lets us implement a version which +# returns a unique state per thread to implement per_thread, or the same object +# for a globally unique once. if sys.version_info.minor > 8: - _ONCE_FACTORY_TYPE = collections.abc.Callable[[], _OnceBase] + _STATE_FACTORY_TYPE = collections.abc.Callable[[], _state._CallState] else: - _ONCE_FACTORY_TYPE = collections.abc.Callable # type: ignore + _STATE_FACTORY_TYPE = collections.abc.Callable # type: ignore + + +class _OnceCallableBase: + + def __init__( + self, + func: collections.abc.Callable, + state_factory: _STATE_FACTORY_TYPE, + retry_exceptions: bool, + ): + self._func = func + self._state_factory = state_factory + self._retry_exceptions = retry_exceptions + + +class _OnceCallableSyncBase(_OnceCallableBase): + def reset(self): + with self._state_factory().lock: + self._state_factory().reset() + + +class _OnceCallableAsyncBase(_OnceCallableBase): + async def reset(self): + async with self._state_factory().async_lock: + self._state_factory().reset() + + +class _OnceCallableSyncFunction(_OnceCallableSyncBase): + def __call__(self, *args, **kwargs): + call_state = self._state_factory() + with call_state.lock: + if not call_state.called: + try: + call_state.return_value = self._func(*args, **kwargs) + except Exception as exception: + if self._retry_exceptions: + raise exception + call_state.return_value = _CachedException(exception) + call_state.called = True + return_value = call_state.return_value + if isinstance(return_value, _CachedException): + raise 
return_value.exception + return return_value + + +class _OnceCallableSyncGenerator(_OnceCallableSyncBase): + def __call__(self, *args, **kwargs): + call_state = self._state_factory() + with call_state.lock: + if not call_state.called: + call_state.return_value = _iterator_wrappers.GeneratorWrapper( + self._retry_exceptions, self._func, *args, **kwargs + ) + call_state.called = True + iterator = call_state.return_value + yield from iterator.yield_results() + + +class _OnceCallableAsyncFunction(_OnceCallableAsyncBase): + async def __call__(self, *args, **kwargs): + call_state = self._state_factory() + async with call_state.async_lock: + if not call_state.called: + try: + call_state.return_value = await self._func(*args, **kwargs) + except Exception as exception: + if self._retry_exceptions: + raise exception + call_state.return_value = _CachedException(exception) + call_state.called = True + return_value = call_state.return_value + if isinstance(return_value, _CachedException): + raise return_value.exception + return return_value + + +class _OnceCallableAsyncGenerator(_OnceCallableAsyncBase): + async def __call__(self, *args, **kwargs): + call_state = self._state_factory() + async with call_state.async_lock: + if not call_state.called: + call_state.return_value = _iterator_wrappers.AsyncGeneratorWrapper( + self._retry_exceptions, self._func, *args, **kwargs + ) + call_state.called = True + return_value = call_state.return_value + next_value = None + iterator = return_value.yield_results() + while True: + try: + next_value = yield await iterator.asend(next_value) + except StopAsyncIteration: + return class _CachedException: @@ -102,16 +200,21 @@ def __init__(self, exception: Exception): self.exception = exception +def _not_allow_reset(): + raise RuntimeError("function was not created with allow_reset flag.") + + def _wrap( func: collections.abc.Callable, - once_factory: _ONCE_FACTORY_TYPE, + state_factory: _STATE_FACTORY_TYPE, fn_type: _WrappedFunctionType, 
retry_exceptions: bool, + allow_reset: bool, ) -> collections.abc.Callable: """Generate a wrapped function appropriate to the function type. - The once_factory lets us reuse logic for both per-thread and singleton. - For a singleton, the factory always returns the same _OnceBase object, but + The state_factory lets us reuse logic for both per-thread and singleton. + For a singleton, the factory always returns the same _CallState object, but for per thread, it would return a unique one for each thread. """ # Theoretically, we could compute fn_type now. However, this code may be executed at runtime @@ -119,148 +222,44 @@ def _wrap( # definition time, so we force the caller to pass it in. But, if we're in debug mode, why not # check it again? assert fn_type == _wrapped_function_type(func) - wrapped: collections.abc.Callable + once_callable: _OnceCallableSyncBase | _OnceCallableAsyncBase if fn_type == _WrappedFunctionType.ASYNC_GENERATOR: - - async def wrapped(*args, **kwargs) -> typing.Any: - once_base: _OnceBase = once_factory() - async with once_base.async_lock: - if not once_base.called: - once_base.return_value = _iterator_wrappers.AsyncGeneratorWrapper( - retry_exceptions, - func, - *args, - allow_reset=once_base.allow_reset, - **kwargs, - ) - once_base.called = True - return_value = once_base.return_value - next_value = None - iterator = return_value.yield_results() - while True: - try: - next_value = yield await iterator.asend(next_value) - except StopAsyncIteration: - return - + once_callable = _OnceCallableAsyncGenerator(func, state_factory, retry_exceptions) elif fn_type == _WrappedFunctionType.ASYNC_FUNCTION: - - async def wrapped(*args, **kwargs) -> typing.Any: - once_base: _OnceBase = once_factory() - async with once_base.async_lock: - if not once_base.called: - try: - once_base.return_value = await func(*args, **kwargs) - except Exception as exception: - if retry_exceptions: - raise exception - once_base.return_value = _CachedException(exception) - 
once_base.called = True - return_value = once_base.return_value - if isinstance(return_value, _CachedException): - raise return_value.exception - return return_value - + once_callable = _OnceCallableAsyncFunction(func, state_factory, retry_exceptions) elif fn_type == _WrappedFunctionType.SYNC_FUNCTION: - - def wrapped(*args, **kwargs) -> typing.Any: - once_base: _OnceBase = once_factory() - with once_base.lock: - if not once_base.called: - try: - once_base.return_value = func(*args, **kwargs) - except Exception as exception: - if retry_exceptions: - raise exception - once_base.return_value = _CachedException(exception) - once_base.called = True - return_value = once_base.return_value - if isinstance(return_value, _CachedException): - raise return_value.exception - return return_value - + once_callable = _OnceCallableSyncFunction(func, state_factory, retry_exceptions) elif fn_type == _WrappedFunctionType.SYNC_GENERATOR: - - def wrapped(*args, **kwargs) -> typing.Any: - once_base: _OnceBase = once_factory() - with once_base.lock: - if not once_base.called: - once_base.return_value = _iterator_wrappers.GeneratorWrapper( - retry_exceptions, - func, - *args, - allow_reset=once_base.allow_reset, - **kwargs, - ) - once_base.called = True - iterator = once_base.return_value - yield from iterator.yield_results() - + once_callable = _OnceCallableSyncGenerator(func, state_factory, retry_exceptions) else: raise NotImplementedError() - def reset(): - once_base: _OnceBase = once_factory() - with once_base.lock: - if not once_base.called: - return - if fn_type == _WrappedFunctionType.SYNC_GENERATOR: - iterator = once_base.return_value - with iterator.lock: - iterator.reset() - else: - once_base.called = False - - async def async_reset(): - once_base: _OnceBase = once_factory() - async with once_base.async_lock: - if not once_base.called: - return - if fn_type == _WrappedFunctionType.ASYNC_GENERATOR: - iterator = once_base.return_value - async with iterator.lock: - 
iterator.reset() - else: - once_base.called = False - - def not_allowed_reset(): - # This doesn't need to be awaitable even in the async case because it will - # raise the error before an `await` has a chance to do anything. - raise RuntimeError( - f"reset() is not allowed to be called on onced function {func}.\n" - "Did you mean to add `allow_reset=True` to your once.once() annotation?" - ) - - # No need for the lock here since we're the only thread that could be running, - # since we haven't even finished wrapping the func yet. - once_base: _OnceBase = once_factory() - if not once_base.allow_reset: - wrapped.reset = not_allowed_reset # type: ignore - else: - if once_base.is_async: - wrapped.reset = async_reset # type: ignore - else: - wrapped.reset = reset # type: ignore - + # We return the class which exposes the reset function only if resettable, + # otherwise we just return the function. + wrapped = functools.partial(once_callable.__class__.__call__, once_callable) # type: ignore functools.update_wrapper(wrapped, func) + if allow_reset: + wrapped.reset = once_callable.reset # type: ignore + else: + wrapped.reset = _not_allow_reset # type: ignore return wrapped -def _once_factory(is_async: bool, per_thread: bool, allow_reset: bool) -> _ONCE_FACTORY_TYPE: +def _state_factory(is_async: bool, per_thread: bool) -> _STATE_FACTORY_TYPE: if not per_thread: - singleton_once = _OnceBase(is_async, allow_reset=allow_reset) - return lambda: singleton_once + singleton_state = _state._CallState(is_async) + return lambda: singleton_state - per_thread_onces = threading.local() + per_thread_states = threading.local() def _get_once_per_thread(): # Read then modify is thread-safe without a lock because each thread sees its own copy of - # copy of `per_thread_onces` thanks to `threading.local`, and each thread cannot race with + # copy of `per_thread_states` thanks to `threading.local`, and each thread cannot race with # itself! 
- if once := getattr(per_thread_onces, "once", None): - return once - per_thread_onces.once = _OnceBase(is_async, allow_reset=allow_reset) - return per_thread_onces.once + if state := getattr(per_thread_states, "state", None): + return state + per_thread_states.state = _state._CallState(is_async) + return per_thread_states.state return _get_once_per_thread @@ -324,12 +323,8 @@ def once( "instead of @once.once_per_class or @once.once_per_instance" ) fn_type = _wrapped_function_type(func) - once_factory = _once_factory( - is_async=fn_type in _ASYNC_FN_TYPES, - per_thread=per_thread, - allow_reset=allow_reset, - ) - return _wrap(func, once_factory, fn_type, retry_exceptions) + state_factory = _state_factory(is_async=fn_type in _ASYNC_FN_TYPES, per_thread=per_thread) + return _wrap(func, state_factory, fn_type, retry_exceptions, allow_reset) class once_per_class: # pylint: disable=invalid-name @@ -341,10 +336,7 @@ class once_per_class: # pylint: disable=invalid-name @classmethod def with_options(cls, per_thread: bool = False, retry_exceptions=False, allow_reset=False): return lambda func: cls( - func, - per_thread=per_thread, - retry_exceptions=retry_exceptions, - allow_reset=allow_reset, + func, per_thread=per_thread, retry_exceptions=retry_exceptions, allow_reset=allow_reset ) def __init__( @@ -356,12 +348,11 @@ def __init__( ) -> None: self.func = self._inspect_function(func) self.fn_type = _wrapped_function_type(self.func) - self.once_factory = _once_factory( - is_async=self.fn_type in _ASYNC_FN_TYPES, - per_thread=per_thread, - allow_reset=allow_reset, + self.state_factory = _state_factory( + is_async=self.fn_type in _ASYNC_FN_TYPES, per_thread=per_thread ) self.retry_exceptions = retry_exceptions + self.allow_reset = allow_reset def _inspect_function(self, func: collections.abc.Callable): if not _is_method(func): @@ -390,7 +381,9 @@ def __get__(self, obj, cls) -> collections.abc.Callable: func = self.func else: func = functools.partial(self.func, obj) - return 
_wrap(func, self.once_factory, self.fn_type, self.retry_exceptions) + return _wrap( + func, self.state_factory, self.fn_type, self.retry_exceptions, self.allow_reset + ) class once_per_instance: # pylint: disable=invalid-name @@ -401,7 +394,7 @@ class once_per_instance: # pylint: disable=invalid-name @classmethod def with_options(cls, per_thread: bool = False, retry_exceptions=False, allow_reset=False): return lambda func: cls( - func, per_thread=per_thread, retry_exceptions=retry_exceptions, allow_reset=False + func, per_thread=per_thread, retry_exceptions=retry_exceptions, allow_reset=allow_reset ) def __init__( @@ -422,14 +415,12 @@ def __init__( self.retry_exceptions = retry_exceptions self.allow_reset = allow_reset - def once_factory(self) -> _ONCE_FACTORY_TYPE: - """Generate a new once factory. + def _state_factory(self) -> _STATE_FACTORY_TYPE: + """Generate a new state factory. - A once factory factory if you will. + A state factory factory if you will. """ - return _once_factory( - self.is_async_fn, per_thread=self.per_thread, allow_reset=self.allow_reset - ) + return _state_factory(self.is_async_fn, per_thread=self.per_thread) def _inspect_function(self, func: collections.abc.Callable): if isinstance(func, (classmethod, staticmethod)): @@ -451,12 +442,16 @@ def _inspect_function(self, func: collections.abc.Callable): def __get__(self, obj, cls) -> collections.abc.Callable: del cls with self.callables_lock: - if (callable := self.callables.get(obj)) is None: + if (bound_callable := self.callables.get(obj)) is None: bound_func = functools.partial(self.func, obj) - callable = _wrap( - bound_func, self.once_factory(), self.fn_type, self.retry_exceptions + bound_callable = _wrap( + bound_func, + self._state_factory(), + self.fn_type, + self.retry_exceptions, + self.allow_reset, ) - self.callables[obj] = callable + self.callables[obj] = bound_callable if self.is_property: - return callable() - return callable + return bound_callable() + return bound_callable 
diff --git a/once/_iterator_wrappers.py b/once/_iterator_wrappers.py index 8416ff0..d66ed61 100644 --- a/once/_iterator_wrappers.py +++ b/once/_iterator_wrappers.py @@ -4,21 +4,8 @@ import enum import functools import threading -import time import typing -# Before we begin, a note on the assert statements in this file: -# Why are we using assert in here, you might ask, instead of implementing "proper" error handling? -# In this case, it is actually not being done out of laziness! The assert statements here -# represent our assumptions about the state at that point in time, and are always called with locks -# held, so they **REALLY** should always hold. If the assumption behind one of these asserts fails, -# the subsequent calls are going to fail anyways, so it's not like they are making the code -# artificially brittle. However, they do make testing easer, because we can directly test our -# assumption instead of having hard-to-trace errors, and also serve as very convenient -# documentation of the assumptions. -# We are always open to suggestions if there are other ways to achieve the same functionality in -# python! 
- class IteratorResults: def __init__(self) -> None: @@ -217,7 +204,7 @@ async def yield_results(self) -> collections.abc.AsyncGenerator: except StopAsyncIteration: async with self.lock: self.record_successful_completion(result) - except Exception as e: + except Exception as e: # pylint: disable=broad-except async with self.lock: self.record_exception(result, e) else: @@ -278,7 +265,7 @@ def yield_results(self) -> collections.abc.Generator: except StopIteration: with self.lock: self.record_successful_completion(result) - except Exception as e: + except Exception as e: # pylint: disable=broad-except with self.lock: self.record_exception(result, e) else: diff --git a/once/_state.py b/once/_state.py new file mode 100644 index 0000000..c8ceb6a --- /dev/null +++ b/once/_state.py @@ -0,0 +1,70 @@ +"""The internal _CallState class holds the call state and return value. + +This is a simple data class, and could have been implemented as a tuple. +However, it has a lock, ensures its properties are only accessed while it is +held, and also defines a reset method. +""" + +import asyncio +import typing +import threading + +# Before we begin, a note on the assert statements in this file: +# Why are we using assert in here, you might ask, instead of implementing "proper" error handling? +# In this case, it is actually not being done out of laziness! The assert statements here +# represent our assumptions about the state at that point in time, and are always called with locks +# held, so they **REALLY** should always hold. If the assumption behind one of these asserts fails, +# the subsequent calls are going to fail anyways, so it's not like they are making the code +# artificially brittle. However, they do make testing easier, because we can directly test our +# assumption instead of having hard-to-trace errors, and also serve as very convenient +# documentation of the assumptions.
# We are always open to suggestions if there are other ways to achieve the same functionality in
# python!


class _CallState:
    """Lock-guarded record of whether a function has been called and what it returned.

    Exactly one lock attribute exists on each instance: ``async_lock`` (an
    ``asyncio.Lock``) when ``is_async`` is true, otherwise ``lock`` (a
    ``threading.Lock``). The ``called`` and ``return_value`` properties, as
    well as ``reset()``, assert that this lock is currently held.

    Args:
        is_async: Selects which kind of lock guards this state.
    """

    def __init__(self, is_async: bool) -> None:
        self.is_async = is_async
        # We are going to be extra pedantic about these next two variables only being read or set
        # with a lock by defining getters and setters which enforce that the lock is held. If this
        # was C++, we would use something like the ABSL_GUARDED_BY macro for compile-time checking
        # (https://github.com/abseil/abseil-cpp/blob/master/absl/base/thread_annotations.h), but
        # this is python :)
        self._called = False
        self._return_value: typing.Any = None
        if self.is_async:
            # NOTE(review): on Python < 3.10 an asyncio.Lock binds to the event loop that is
            # current at construction time -- confirm this state is always built on the loop
            # that later awaits it.
            self.async_lock = asyncio.Lock()
        else:
            self.lock = threading.Lock()

    def _locked(self) -> bool:
        """Return True if this state's lock is currently held.

        ``locked()`` reports whether the lock is held at all, not whether the
        caller is the holder, so this is a sanity check rather than a proof of
        ownership.
        """
        return self.async_lock.locked() if self.is_async else self.lock.locked()

    def _assert_locked(self, action: str) -> None:
        """Assert the lock is held, naming the offending action on failure."""
        lock_name = "async_lock" if self.is_async else "lock"
        assert self._locked(), f"{action} requires {lock_name} to be held"

    @property
    def called(self) -> bool:
        """Indicates if the function has been called."""
        self._assert_locked("reading called")
        return self._called

    @called.setter
    def called(self, state: bool) -> None:
        self._assert_locked("setting called")
        self._called = state

    @property
    def return_value(self) -> typing.Any:
        """Stores the returned value of the function."""
        self._assert_locked("reading return_value")
        return self._return_value

    @return_value.setter
    def return_value(self, value: typing.Any) -> None:
        self._assert_locked("setting return_value")
        self._return_value = value

    def reset(self) -> None:
        """Resets the state back to uncalled and drops the stored return value."""
        self._assert_locked("reset()")
        self._called = False
        self._return_value = None