diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 42cb5ed..3d284b7 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -20,6 +20,8 @@ jobs: cache: pip - run: pip install pytest - run: pytest . --junitxml=junit/test_py${{ matrix.python-version }}_on_${{ matrix.os }}.xml + - run: python -O once_test.py + - run: python -OO once_test.py - name: Upload pytest test results uses: actions/upload-artifact@v3 if: success() || failure() diff --git a/once_test.py b/once_test.py index 6c10d10..627f11d 100644 --- a/once_test.py +++ b/once_test.py @@ -1,6 +1,7 @@ """Unit tests for once decorators.""" # pylint: disable=missing-function-docstring import asyncio +import collections.abc import concurrent.futures import functools import gc @@ -9,6 +10,7 @@ import sys import threading import unittest +import uuid import weakref import once @@ -30,6 +32,74 @@ async def anext(iter, default=StopAsyncIteration): _N_WORKERS = 32 +class WrappedException: + def __init__(self, exception): + self.exception = exception + + +def parallel_map( + test: unittest.TestCase, + func: collections.abc.Callable, + # would be collections.abc.Iterable[tuple] | None on py >= 3.10 + call_args=None, + n_threads: int = _N_WORKERS, + timeout: float = 10.0, +) -> list: + """Run a function multiple times in parallel. + + We ensure that N parallel tasks are all launched at the "same time", which + means all have parallel threads which are released to the GIL to execute at + the same time. + Why? + We can't rely on the thread pool executor to always spin up the full list of _N_WORKERS. + In pypy, we have observed that even with blocked tasks, the same thread executes multiple + function calls. This lets us handle the scheduling in a predictable way for testing. 
+ """ + if call_args is None: + call_args = (tuple() for _ in range(n_threads)) + + batches = [[] for i in range(n_threads)] # type: list[list[tuple[int, tuple]]] + for i, call_args in enumerate(call_args): + if not isinstance(call_args, tuple): + raise TypeError("call arguments must be a tuple") + batches[i % n_threads].append((i, call_args)) + n_calls = i + 1 # len(call_args), but it is now an exhausted iterator. + unset = object() + results_lock = threading.Lock() + results = [unset for _ in range(n_calls)] + + # This barrier is used to ensure that all calls release together, after this function has + # completed its setup of creating them. + start_barrier = threading.Barrier(min(n_threads, n_calls)) + + def wrapped_fn(batch): + start_barrier.wait() + for index, args in batch: + try: + result = func(*args) + except Exception as e: + result = WrappedException(e) + with results_lock: + results[index] = result + + # We manually set thread names for easier debugging. + invocation_id = str(uuid.uuid4()) + threads = [ + threading.Thread(target=wrapped_fn, args=[batch], name=f"{test.id()}-{i}-{invocation_id}") + for i, batch in enumerate(batches) + ] + for t in threads: + t.start() + for t in threads: + t.join(timeout=timeout) + for i, result in enumerate(results): + if result is unset: + test.fail(f"Call {i} did not complete succesfully") + elif isinstance(result, WrappedException): + raise result.exception + return results + + class Counter: """Holding object for a counter. 
@@ -429,16 +499,17 @@ def sample_failing_fn(): def test_iterator_parallel_execution(self): counter = Counter() - # Must be called over an integer multiple of _N_WORKERS - @execute_with_barrier(n_workers=_N_WORKERS) @once.once def yielding_iterator(): nonlocal counter for _ in range(3): yield counter.get_incremented() - with concurrent.futures.ThreadPoolExecutor(max_workers=_N_WORKERS) as executor: - results = executor.map(lambda _: list(yielding_iterator()), range(_N_WORKERS * 2)) + results = parallel_map( + self, + lambda: list(yielding_iterator()), + (tuple() for _ in range(_N_WORKERS * 2)), + ) for result in results: self.assertEqual(result, [1, 2, 3]) @@ -470,10 +541,7 @@ def yielding_iterator(): def test_threaded_single_function(self): counting_fn, counter = generate_once_counter_fn() - barrier_counting_fn = execute_with_barrier(counting_fn, n_workers=_N_WORKERS) - with concurrent.futures.ThreadPoolExecutor(max_workers=_N_WORKERS) as executor: - results_generator = executor.map(barrier_counting_fn, range(_N_WORKERS)) - results = list(results_generator) + results = parallel_map(self, counting_fn) self.assertEqual(len(results), _N_WORKERS) for r in results: self.assertEqual(r, 1) @@ -482,7 +550,6 @@ def test_threaded_single_function(self): def test_once_per_thread(self): counter = Counter() - @execute_with_barrier(n_workers=_N_WORKERS) # increases chance of a race @once.once(per_thread=True) @execute_with_barrier(n_workers=_N_WORKERS) def counting_fn(*args) -> int: @@ -491,8 +558,7 @@ def counting_fn(*args) -> int: del args return counter.get_incremented() - with concurrent.futures.ThreadPoolExecutor(max_workers=_N_WORKERS) as executor: - results = list(executor.map(counting_fn, range(_N_WORKERS * 4))) + results = parallel_map(self, counting_fn, (tuple() for _ in range(_N_WORKERS * 4))) self.assertEqual(min(results), 1) self.assertEqual(max(results), _N_WORKERS) @@ -503,17 +569,13 @@ def test_threaded_multiple_functions(self): for _ in range(4): cfn, counter = 
generate_once_counter_fn() counters.append(counter) - fns.append(execute_with_barrier(cfn, n_workers=_N_WORKERS)) - - promises = [] - with concurrent.futures.ThreadPoolExecutor(max_workers=_N_WORKERS) as executor: - for cfn in fns: - for _ in range(_N_WORKERS): - promises.append(executor.submit(cfn)) - del cfn - fns.clear() - for promise in promises: - self.assertEqual(promise.result(), 1) + fns.append(cfn) + + def call_all_functions(i): + for j in range(i, i + 4): + self.assertEqual(fns[j % 4](), 1) + + parallel_map(self, call_all_functions, ((i,) for i in range(_N_WORKERS))) for counter in counters: self.assertEqual(counter.value, 1) @@ -575,16 +637,22 @@ def closure(): self.assertIsNone(ephemeral_ref()) def test_function_signature_preserved(self): - @once.once def type_annotated_fn(arg: float) -> int: """Very descriptive docstring.""" del arg return 1 - sig = inspect.signature(type_annotated_fn) - self.assertIs(sig.parameters["arg"].annotation, float) - self.assertIs(sig.return_annotation, int) - self.assertEqual(type_annotated_fn.__doc__, "Very descriptive docstring.") + decorated_function = once.once(type_annotated_fn) + original_sig = inspect.signature(type_annotated_fn) + decorated_sig = inspect.signature(decorated_function) + self.assertIs(original_sig.parameters["arg"].annotation, float) + self.assertIs(decorated_sig.parameters["arg"].annotation, float) + self.assertIs(original_sig.return_annotation, int) + self.assertIs(decorated_sig.return_annotation, int) + self.assertEqual(inspect.getdoc(type_annotated_fn), inspect.getdoc(decorated_function)) + if sys.flags.optimize >= 2: + self.skipTest("docstrings get stripped with -OO") + self.assertEqual(inspect.getdoc(type_annotated_fn), "Very descriptive docstring.") def test_once_per_class(self): class _CallOnceClass(Counter): @@ -608,12 +676,10 @@ def once_fn(self): once_obj = _CallOnceClass() - @execute_with_barrier(n_workers=_N_WORKERS) # increases chance of a race - def execute(_): + def execute(): return 
once_obj.once_fn() - with concurrent.futures.ThreadPoolExecutor(max_workers=_N_WORKERS) as executor: - results = list(executor.map(execute, range(_N_WORKERS * 4))) + results = parallel_map(self, execute, (tuple() for _ in range(_N_WORKERS * 4))) self.assertEqual(min(results), 1) self.assertEqual(max(results), 1) @@ -626,12 +692,10 @@ def once_fn(self): once_obj = _CallOnceClass() - @execute_with_barrier(n_workers=_N_WORKERS) # increases chance of a race - def execute(_): + def execute(): return once_obj.once_fn() - with concurrent.futures.ThreadPoolExecutor(max_workers=_N_WORKERS) as executor: - results = list(executor.map(execute, range(_N_WORKERS * 4))) + results = parallel_map(self, execute, (tuple() for _ in range(_N_WORKERS * 4))) self.assertEqual(min(results), 1) self.assertEqual(max(results), _N_WORKERS) @@ -686,18 +750,30 @@ def value(self): # pylint: disable=inconsistent-return-statements a = _CallOnceClass("a", self) # pylint: disable=invalid-name b = _CallOnceClass("b", self) # pylint: disable=invalid-name - with concurrent.futures.ThreadPoolExecutor(max_workers=_N_WORKERS) as executor: - a_jobs = [executor.submit(a.value) for _ in range(_N_WORKERS // 2)] - b_jobs = [executor.submit(b.value) for _ in range(_N_WORKERS // 2)] - for a_job in a_jobs: - self.assertEqual(a_job.result(), "a") - for b_job in b_jobs: - self.assertEqual(b_job.result(), "b") - - self.assertEqual(a.value(), "a") - self.assertEqual(a.value(), "a") - self.assertEqual(b.value(), "b") - self.assertEqual(b.value(), "b") + def call_and_check_both(i: int): + # Run in different order based on the call + if i % 4 == 0: + self.assertEqual(a.value(), "a") + self.assertEqual(a.value(), "a") + self.assertEqual(b.value(), "b") + self.assertEqual(b.value(), "b") + elif i % 4 == 1: + self.assertEqual(a.value(), "a") + self.assertEqual(b.value(), "b") + self.assertEqual(a.value(), "a") + self.assertEqual(b.value(), "b") + elif i % 4 == 2: + self.assertEqual(b.value(), "b") + 
self.assertEqual(a.value(), "a") + self.assertEqual(b.value(), "b") + self.assertEqual(a.value(), "a") + else: + self.assertEqual(b.value(), "b") + self.assertEqual(b.value(), "b") + self.assertEqual(a.value(), "a") + self.assertEqual(a.value(), "a") + + parallel_map(self, call_and_check_both, ((i,) for i in range(_N_WORKERS))) def test_once_per_instance_do_not_block_each_other(self): class _BlockableClass: @@ -736,12 +812,10 @@ def once_fn(self): once_objs = [_CallOnceClass(), _CallOnceClass(), _CallOnceClass(), _CallOnceClass()] - @execute_with_barrier(n_workers=_N_WORKERS) def execute(i): return once_objs[i % 4].once_fn() - with concurrent.futures.ThreadPoolExecutor(max_workers=_N_WORKERS) as executor: - results = list(executor.map(execute, range(_N_WORKERS * 4))) + results = parallel_map(self, execute, ((i,) for i in range(_N_WORKERS * 4))) self.assertEqual(min(results), 1) self.assertEqual(max(results), 1) @@ -754,12 +828,10 @@ def once_fn(self): once_objs = [_CallOnceClass(), _CallOnceClass(), _CallOnceClass(), _CallOnceClass()] - @execute_with_barrier(n_workers=_N_WORKERS) # increases chance of a race def execute(i): return once_objs[i % 4].once_fn() - with concurrent.futures.ThreadPoolExecutor(max_workers=_N_WORKERS) as executor: - results = list(executor.map(execute, range(_N_WORKERS))) + results = parallel_map(self, execute, ((i,) for i in range(_N_WORKERS))) self.assertEqual(min(results), 1) self.assertEqual(max(results), math.ceil(_N_WORKERS / 4)) @@ -819,7 +891,7 @@ def receiving_iterator(): barrier = threading.Barrier(_N_WORKERS) - def call_iterator(_): + def call_iterator(): gen = receiving_iterator() result = [] barrier.wait() @@ -828,8 +900,7 @@ def call_iterator(_): result.append(gen.send(i)) return result - with concurrent.futures.ThreadPoolExecutor(max_workers=_N_WORKERS) as executor: - results = executor.map(call_iterator, range(_N_WORKERS)) + results = parallel_map(self, call_iterator) for result in results: self.assertEqual(result, 
list(range(_N_WORKERS * 4))) @@ -854,8 +925,7 @@ def call_iterator(n): # Unlike the previous test, each execution should yield lists of different lengths. # This ensures that the iterator does not hang, even if not exhausted - with concurrent.futures.ThreadPoolExecutor(max_workers=_N_WORKERS) as executor: - results = executor.map(call_iterator, range(1, _N_WORKERS + 1)) + results = parallel_map(self, call_iterator, ((i,) for i in range(1, _N_WORKERS + 1))) for i, result in enumerate(results): self.assertEqual(result, list(range(i + 1))) @@ -893,16 +963,23 @@ async def counting_fn(*args) -> int: del args return counter.get_incremented() - @execute_with_barrier(n_workers=_N_WORKERS) # increases chance of a race + results_lock = asyncio.Lock() + results = [] + + async def counting_fn_multiple_caller(*args): + """Calls counting_fn() multiple times ensuring identical result.""" + result = await counting_fn() + for i in range(5): + self.assertEqual(await counting_fn(), result) + async with results_lock: + results.append(result) + return result + def execute(*args): - coro = counting_fn(*args) + coro = counting_fn_multiple_caller(*args) return asyncio.run(coro) - with concurrent.futures.ThreadPoolExecutor(max_workers=_N_WORKERS) as executor: - results = list(executor.map(execute, range(_N_WORKERS))) - self.assertEqual(sorted(results), list(range(1, _N_WORKERS + 1))) - results = list(executor.map(execute, range(_N_WORKERS))) - self.assertEqual(sorted(results), list(range(1, _N_WORKERS + 1))) + parallel_map(self, execute) async def test_failing_function(self): counter = Counter()