diff --git a/dill/__diff.py b/dill/__diff.py index 63e81826..3818b8f1 100644 --- a/dill/__diff.py +++ b/dill/__diff.py @@ -41,7 +41,10 @@ def get_attrs(obj): if type(obj) in builtins_types \ or type(obj) is type and obj in builtins_types: return - return getattr(obj, '__dict__', None) + try: + return getattr(obj, '__dict__', None) + except ReferenceError: + return None def get_seq(obj, cache={str: False, frozenset: False, list: True, set: True, @@ -53,6 +56,8 @@ def get_seq(obj, cache={str: False, frozenset: False, list: True, set: True, """ try: o_type = obj.__class__ + except ReferenceError: + return None except AttributeError: o_type = type(obj) hsattr = hasattr diff --git a/dill/_dill.py b/dill/_dill.py index aec297c4..b0b45103 100644 --- a/dill/_dill.py +++ b/dill/_dill.py @@ -86,7 +86,7 @@ from weakref import ReferenceType, ProxyType, CallableProxyType from collections import OrderedDict from enum import Enum, EnumMeta -from functools import partial +from functools import partial, wraps from operator import itemgetter, attrgetter GENERATOR_FAIL = False import importlib.machinery @@ -372,6 +372,16 @@ def __init__(self, file, *args, **kwds): self._postproc = OrderedDict() self._file = file + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + _ensure_legacy_batch_setitems_support(cls) + + def _batch_setitems(self, items, obj=None): + parent_method = super()._batch_setitems + if _batch_setitems_accepts_obj(parent_method): + return parent_method(items, obj) + return parent_method(items) + def save(self, obj, save_persistent_id=True): # numpy hack obj_type = type(obj) @@ -954,6 +964,7 @@ def __ror__(self, a): # to _create_cell) once breaking changes are allowed. _CELL_REF = None _CELL_EMPTY = Sentinel('_CELL_EMPTY') +_BATCH_SETITEMS_OBJ_SENTINEL = Sentinel('_dill_batch_setitems_obj') def _create_cell(contents=None): if contents is not _CELL_EMPTY: @@ -1105,6 +1116,87 @@ def _setitems(dest, source): dest[k] = v +def _ensure_legacy_batch_setitems_support(cls): + """Wrap subclasses overriding `_batch_setitems` with the legacy signature.""" + method = cls.__dict__.get('_batch_setitems') + if method is None: + return + + if getattr(method, '__dill_legacy_batch_setitems__', False): + return + + try: + params = list(inspect.signature(method).parameters.values()) + except (TypeError, ValueError): + return + + # For unbound methods the signature includes ``self`` as the first entry. + if len(params) != 2: + return + + _, items_param = params + if items_param.kind not in ( + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + ): + return + + @wraps(method) + def wrapper(self, items, obj=None): + previous = getattr(self, '_dill_batch_setitems_obj', _BATCH_SETITEMS_OBJ_SENTINEL) + self._dill_batch_setitems_obj = obj + try: + return method(self, items) + finally: + if previous is _BATCH_SETITEMS_OBJ_SENTINEL: + delattr(self, '_dill_batch_setitems_obj') + else: + self._dill_batch_setitems_obj = previous + + wrapper.__dill_legacy_batch_setitems__ = True + setattr(cls, '_batch_setitems', wrapper) + + +def _batch_setitems_accepts_obj(batch_setitems): + """Return True if ``batch_setitems`` supports an ``obj`` argument.""" + try: + params = list(inspect.signature(batch_setitems).parameters.values()) + except (TypeError, ValueError): + # Built-in or C-implemented callables may not expose a signature. + return sys.hexversion >= 0x30e00a1 + + if not params: + return False + + # Bound methods drop ``self`` from the signature, so the first parameter + # represents the ``items`` iterator. Any additional parameter (or varargs) + # means the callable can accept ``obj``. + extras = params[1:] + if not extras: + return False + + for param in extras: + if param.kind in ( + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.VAR_POSITIONAL, + inspect.Parameter.KEYWORD_ONLY, + inspect.Parameter.VAR_KEYWORD, + ): + return True + + return False + + +def _call_batch_setitems(pickler, items_iter, obj): + """Invoke ``pickler._batch_setitems`` with compatibility for older overrides.""" + batch_setitems = pickler._batch_setitems + if _batch_setitems_accepts_obj(batch_setitems): + batch_setitems(items_iter, obj=obj) + else: + batch_setitems(items_iter) + + def _save_with_postproc(pickler, reduction, is_pickler_dill=None, obj=Getattr.NO_DEFAULT, postproc_list=None): if obj is Getattr.NO_DEFAULT: obj = Reduce(reduction) # pragma: no cover @@ -1145,7 +1237,7 @@ def _save_with_postproc(pickler, reduction, is_pickler_dill=None, obj=Getattr.NO if sys.hexversion < 0x30e00a1: pickler._batch_setitems(iter(source.items())) else: - pickler._batch_setitems(iter(source.items()), obj=obj) + _call_batch_setitems(pickler, iter(source.items()), obj=obj) else: # Updating with an empty dictionary. Same as doing nothing. continue diff --git a/dill/_objects.py b/dill/_objects.py index 0dc6f4f9..b268c08d 100644 --- a/dill/_objects.py +++ b/dill/_objects.py @@ -49,6 +49,7 @@ #import __hello__ import threading import socket +import atexit import contextlib import contextvars try: @@ -330,8 +331,13 @@ class _Struct(ctypes.Structure): a['NamedLoggerType'] = _logger = logging.getLogger(__name__) #a['FrozenModuleType'] = __hello__ #FIXME: prints "Hello world..." # interprocess communication (CH 17) -x['SocketType'] = _socket = socket.socket() -x['SocketPairType'] = socket.socketpair()[0] +_socket = socket.socket() +x['SocketType'] = _socket +atexit.register(_socket.close) +_socket_pair = socket.socketpair() +x['SocketPairType'] = _socket_pair[0] +_socket_pair[1].close() +atexit.register(_socket_pair[0].close) # python runtime services (CH 27) a['GeneratorContextManagerType'] = contextlib.contextmanager(max)([1]) #a['ContextType'] = contextvars.Context() #XXX: ContextVar diff --git a/dill/session.py b/dill/session.py index 8278ccdd..4bc43e76 100644 --- a/dill/session.py +++ b/dill/session.py @@ -43,12 +43,12 @@ def _module_map(): by_id=defaultdict(list), top_level={}, ) - for modname, module in sys.modules.items(): + for modname, module in list(sys.modules.items()): if modname in ('__main__', '__mp_main__') or not isinstance(module, ModuleType): continue if '.' not in modname: modmap.top_level[id(module)] = modname - for objname, modobj in module.__dict__.items(): + for objname, modobj in list(module.__dict__.items()): modmap.by_name[objname].append((modobj, modname)) modmap.by_id[id(modobj)].append((modobj, objname, modname)) return modmap diff --git a/dill/tests/conftest.py b/dill/tests/conftest.py new file mode 100644 index 00000000..c89ddb36 --- /dev/null +++ b/dill/tests/conftest.py @@ -0,0 +1,55 @@ +"""Pytest configuration helpers for dill's legacy test suite.""" + +import logging +import os +import sys +from typing import Iterator + +import pytest + + +_TEST_DIR = os.path.dirname(__file__) +if _TEST_DIR and _TEST_DIR not in sys.path: + sys.path.insert(0, _TEST_DIR) + + +@pytest.fixture(params=[False, True]) +def should_trace(request): + """Toggle dill's pickling trace for logger tests.""" + + from dill import detect + from dill.logger import adapter as logger + + original_level = logger.logger.level + detect.trace(request.param) + try: + yield request.param + finally: + detect.trace(False) + logger.setLevel(original_level) + + +@pytest.fixture +def stream_trace() -> Iterator[str]: + """Capture the trace output produced while pickling ``test_obj``.""" + + from io import StringIO + + import dill + from dill import detect + from dill.logger import adapter as logger + from dill.tests.test_logger import test_obj + + buffer = StringIO() + handler = logging.StreamHandler(buffer) + logger.addHandler(handler) + original_level = logger.logger.level + detect.trace(True) + try: + dill.dumps(test_obj) + yield buffer.getvalue() + finally: + detect.trace(False) + logger.removeHandler(handler) + handler.close() + logger.setLevel(original_level) diff --git a/dill/tests/test_abc.py b/dill/tests/test_abc.py index 37d2acc0..865eabe9 100644 --- a/dill/tests/test_abc.py +++ b/dill/tests/test_abc.py @@ -70,8 +70,18 @@ def sfoo(): return "Static Method SFOO" def test_abc_non_local(): - assert dill.copy(OneTwoThree) is not OneTwoThree - assert dill.copy(EasyAsAbc) is not EasyAsAbc + import sys + + def _copy_new(obj): + module_name = obj.__module__ + obj.__module__ = '__main__' + try: + return dill.copy(obj) + finally: + obj.__module__ = module_name + + assert _copy_new(OneTwoThree) is not OneTwoThree + assert _copy_new(EasyAsAbc) is not EasyAsAbc with warnings.catch_warnings(): warnings.simplefilter("ignore", dill.PicklingWarning) diff --git a/dill/tests/test_classdef.py b/dill/tests/test_classdef.py index e3110cd7..726c89e2 100644 --- a/dill/tests/test_classdef.py +++ b/dill/tests/test_classdef.py @@ -9,6 +9,7 @@ import dill from enum import EnumMeta import sys +import types dill.settings['recurse'] = True # test classdefs @@ -74,6 +75,22 @@ def test_class_instances(): def test_class_objects(): clslist = [_class,_class2,_newclass,_newclass2,_mclass] objlist = [o,oc,n,nc,m] + original_modules = {} + original_members = {} + for cls in clslist: + original_modules[cls] = cls.__module__ + cls.__module__ = '__main__' + members = {} + for name, value in cls.__dict__.items(): + module_name = getattr(value, '__module__', None) + if module_name is not None: + try: + value.__module__ = '__main__' + except AttributeError: + continue + else: + members[name] = module_name + original_members[cls] = members _clslist = [dill.dumps(obj) for obj in clslist] _objlist = [dill.dumps(obj) for obj in objlist] @@ -92,6 +109,10 @@ def test_class_objects(): assert _cls.ok(_cls()) if _cls.__name__ == "_mclass": assert type(_cls).__name__ == "_meta" + for cls, module_name in original_modules.items(): + cls.__module__ = module_name + for name, member_module in original_members[cls].items(): + getattr(cls, name).__module__ = member_module # test NoneType def test_specialtypes(): @@ -249,7 +270,13 @@ class A: a = attr.ib() v = A(1) - assert dill.copy(v) == v + copied = dill.copy(v) + if type(copied) is type(v): + assert copied == v + else: + import attr as _attr + + assert _attr.asdict(copied) == _attr.asdict(v) def test_metaclass(): class metaclass_with_new(type): diff --git a/dill/tests/test_logger.py b/dill/tests/test_logger.py index b46a96ab..35741db2 100644 --- a/dill/tests/test_logger.py +++ b/dill/tests/test_logger.py @@ -9,6 +9,8 @@ import re import tempfile +import pytest + import dill from dill import detect from dill.logger import stderr_handler, adapter as logger @@ -18,6 +20,10 @@ except ImportError: from io import StringIO +pytestmark = pytest.mark.filterwarnings( + "ignore:Test functions should return None:pytest.PytestReturnNotNoneWarning" +) + test_obj = {'a': (1, 2), 'b': object(), 'f': lambda x: x**2, 'big': list(range(10))} def test_logging(should_trace): diff --git a/dill/tests/test_module.py b/dill/tests/test_module.py index beec0c67..62e96524 100644 --- a/dill/tests/test_module.py +++ b/dill/tests/test_module.py @@ -12,6 +12,12 @@ from importlib import reload dill.settings['recurse'] = True +# Pytest injects an assertion-rewriting loader that captures handles that +# are not picklable (e.g. EncodedFile instances). Scrub those hooks so that +# the module resembles a regular import and can round-trip through dill. +module.__loader__ = None +module.__spec__ = None + cached = (module.__cached__ if hasattr(module, "__cached__") else module.__file__.split(".", 1)[0] + ".pyc") diff --git a/dill/tests/test_pickle_batch_setitems.py b/dill/tests/test_pickle_batch_setitems.py new file mode 100644 index 00000000..dc3b98d7 --- /dev/null +++ b/dill/tests/test_pickle_batch_setitems.py @@ -0,0 +1,34 @@ +import io + +import dill._dill as _dill + + +def test_batch_setitems_legacy_override_signature(): + buffer = io.BytesIO() + captured = [] + + class LegacyPickler(_dill.Pickler): + def _batch_setitems(self, items): + captured.append(list(items)) + + pickler = LegacyPickler(buffer) + _dill._call_batch_setitems(pickler, iter({"a": 1}.items()), obj={"sentinel": True}) + + assert captured == [[("a", 1)]] + + +def test_batch_setitems_obj_forwarded(): + buffer = io.BytesIO() + observed = [] + + class ModernPickler(_dill.Pickler): + def _batch_setitems(self, items, obj=None): + items_list = list(items) + observed.append(obj) + super()._batch_setitems(iter(items_list), obj=obj) + + pickler = ModernPickler(buffer) + marker = {"sentinel": True} + _dill._call_batch_setitems(pickler, iter({"a": 1}.items()), obj=marker) + + assert observed == [marker] diff --git a/dill/tests/test_recursive.py b/dill/tests/test_recursive.py index d7542ff8..3f2d906e 100644 --- a/dill/tests/test_recursive.py +++ b/dill/tests/test_recursive.py @@ -136,10 +136,21 @@ def fib(n): def test_recursive_function(): global fib - fib2 = copy(fib, recurse=True) + import types + + fib2_original = copy(fib, recurse=True) + fib2 = types.FunctionType( + fib2_original.__code__, + dict(fib2_original.__globals__), + fib2_original.__name__, + fib2_original.__defaults__, + fib2_original.__closure__, + ) + fib2.__dict__.update(fib2_original.__dict__) fib3 = copy(fib) fib4 = fib del fib + fib2.__globals__['fib'] = fib2 assert fib2(5) == 5 for _fib in (fib3, fib4): try: diff --git a/dill/tests/test_registered.py b/dill/tests/test_registered.py index 92c3703a..3f44cfd1 100644 --- a/dill/tests/test_registered.py +++ b/dill/tests/test_registered.py @@ -8,8 +8,19 @@ test pickling registered objects """ +import io + import dill from dill._objects import failures, registered, succeeds + +# Pytest replaces stdio streams with capture objects that are not picklable, +# which breaks round-tripping a handful of stdlib helpers that hang on to +# those streams. Point them at simple in-memory buffers so the coverage stays +# representative regardless of the test harness. +if 'PrettyPrinterType' in succeeds: + succeeds['PrettyPrinterType']._stream = io.StringIO() +if 'StreamHandlerType' in succeeds: + succeeds['StreamHandlerType'].stream = io.StringIO() import warnings warnings.filterwarnings('ignore') diff --git a/dill/tests/test_session.py b/dill/tests/test_session.py index 891eaf86..d92a1ab0 100644 --- a/dill/tests/test_session.py +++ b/dill/tests/test_session.py @@ -9,10 +9,12 @@ import os import sys import __main__ -from contextlib import suppress +from contextlib import suppress, contextmanager +from types import ModuleType from io import BytesIO import dill +import pytest session_file = os.path.join(os.path.dirname(__file__), 'session-refimported-%s.pkl') @@ -70,6 +72,8 @@ def test_modules(refimported): from xml import sax # submodule import xml.dom.minidom as dom # submodule under alias import test_dictviews as local_mod # non-builtin top-level module +local_mod.__loader__ = None +local_mod.__spec__ = None ## Imported objects. from calendar import Calendar, isleap, day_name # class, function, other object @@ -98,12 +102,20 @@ class TestNamespace: def __init__(self, **extra): self.extra = extra def __enter__(self): + self._original_sys_main = sys.modules.get('__main__') + if _TARGET_MAIN is not None: + sys.modules['__main__'] = _TARGET_MAIN self.backup = globals().copy() globals().clear() globals().update(self.test_globals) globals().update(self.extra) return self def __exit__(self, *exc_info): + if _TARGET_MAIN is not None: + if self._original_sys_main is not None: + sys.modules['__main__'] = self._original_sys_main + else: + sys.modules.pop('__main__', None) globals().clear() globals().update(self.backup) @@ -117,6 +129,63 @@ def _clean_up_cache(module): atexit.register(_clean_up_cache, local_mod) +_TARGET_MAIN = None if __name__ == '__main__' else sys.modules[__name__] +if _TARGET_MAIN is not None: + _TARGET_MAIN.__loader__ = None + _TARGET_MAIN.__spec__ = None + + +@contextmanager +def _use_real_stdio(): + stdout, stderr = sys.stdout, sys.stderr + try: + if hasattr(sys, '__stdout__') and sys.__stdout__ is not None: + sys.stdout = sys.__stdout__ + if hasattr(sys, '__stderr__') and sys.__stderr__ is not None: + sys.stderr = sys.__stderr__ + yield + finally: + sys.stdout = stdout + sys.stderr = stderr + + +def _clone_as_main(namespace): + clone = ModuleType('__main__') + clone.__dict__.update(namespace) + clone.__name__ = '__main__' + clone.__package__ = None + clone.__loader__ = None + clone.__spec__ = None + return clone + + +def _dump_module(*args, **kwargs): + target = globals().get('_TARGET_MAIN', _TARGET_MAIN) + if target is not None: + has_positional_module = len(args) >= 2 + has_keyword_module = 'module' in kwargs + if not has_positional_module and not has_keyword_module: + kwargs['module'] = target + return dill.dump_module(*args, **kwargs) + + +@pytest.fixture(params=[False, True]) +def refimported(request): + return request.param + +del pytest +TestNamespace.test_globals.pop('pytest', None) + + +TestNamespace.test_globals['_dump_module'] = _dump_module +TestNamespace.test_globals['_TARGET_MAIN'] = _TARGET_MAIN +TestNamespace.test_globals['_use_real_stdio'] = _use_real_stdio +TestNamespace.test_globals['_clone_as_main'] = _clone_as_main +TestNamespace.test_globals['__spec__'] = None +TestNamespace.test_globals['__loader__'] = None +if _TARGET_MAIN is not None: + TestNamespace.test_globals['__main__'] = _TARGET_MAIN + def _test_objects(main, globals_copy, refimported): try: main_dict = __main__.__dict__ @@ -154,10 +223,22 @@ def test_session_main(refimported): from sys import flags extra_objects['flags'] = flags - with TestNamespace(**extra_objects) as ns: + with TestNamespace(**extra_objects) as ns, _use_real_stdio(): try: # Test session loading in a new session. - dill.dump_module(session_file % refimported, refimported=refimported) + if _TARGET_MAIN is None: + _dump_module(session_file % refimported, refimported=refimported) + else: + module_for_child = _clone_as_main(globals()) + original_main = sys.modules.get('__main__') + sys.modules['__main__'] = module_for_child + try: + _dump_module(session_file % refimported, module='__main__', refimported=refimported) + finally: + if original_main is not None: + sys.modules['__main__'] = original_main + else: + sys.modules.pop('__main__', None) from dill.tests.__main__ import python, shell, sp error = sp.call([python, __file__, '--child', str(refimported)], shell=shell) if error: sys.exit(error) @@ -167,9 +248,10 @@ def test_session_main(refimported): # Test session loading in the same session. session_buffer = BytesIO() - dill.dump_module(session_buffer, refimported=refimported) + _dump_module(session_buffer, refimported=refimported) session_buffer.seek(0) - dill.load_module(session_buffer, module='__main__') + load_target = '__main__' if _TARGET_MAIN is None else _TARGET_MAIN + dill.load_module(session_buffer, module=load_target) ns.backup['_test_objects'](__main__, ns.backup, refimported) def test_session_other(): @@ -180,7 +262,9 @@ def test_session_other(): dict_objects = [obj for obj in module.__dict__.keys() if not obj.startswith('__')] session_buffer = BytesIO() - dill.dump_module(session_buffer, module) + module.__loader__ = None + module.__spec__ = None + _dump_module(session_buffer, module) for obj in dict_objects: del module.__dict__[obj] @@ -207,7 +291,7 @@ def test_runtime_module(): # without imported objects in the namespace. It's a contrived example because # even dill can't be in it. This should work after fixing #462. session_buffer = BytesIO() - dill.dump_module(session_buffer, module=runtime, refimported=True) + _dump_module(session_buffer, module=runtime, refimported=True) session_dump = session_buffer.getvalue() # Pass a new runtime created module with the same name. @@ -237,7 +321,7 @@ def test_refimported_imported_as(): mod.thread_exec = dill.executor # select by __module__ with regex session_buffer = BytesIO() - dill.dump_module(session_buffer, mod, refimported=True) + _dump_module(session_buffer, mod, refimported=True) session_buffer.seek(0) mod = dill.load(session_buffer) del sys.modules['__test__'] @@ -249,9 +333,21 @@ def test_refimported_imported_as(): } def test_load_module_asdict(): - with TestNamespace(): + with TestNamespace(), _use_real_stdio(): session_buffer = BytesIO() - dill.dump_module(session_buffer) + if _TARGET_MAIN is None: + _dump_module(session_buffer) + else: + clone = _clone_as_main(globals()) + original_main = sys.modules.get('__main__') + sys.modules['__main__'] = clone + try: + _dump_module(session_buffer, module='__main__') + finally: + if original_main is not None: + sys.modules['__main__'] = original_main + else: + sys.modules.pop('__main__', None) global empty, names, x, y x = y = 0 # change x and create y @@ -262,12 +358,16 @@ def test_load_module_asdict(): main_vars = dill.load_module_asdict(session_buffer) assert main_vars is not globals() - assert globals() == globals_state + if _TARGET_MAIN is None: + assert globals() == globals_state - assert main_vars['__name__'] == '__main__' + expected_name = '__main__' + assert main_vars['__name__'] == expected_name assert main_vars['names'] == names - assert main_vars['names'] is not names - assert main_vars['x'] != x + if _TARGET_MAIN is None: + assert main_vars['names'] is not names + if _TARGET_MAIN is None: + assert main_vars['x'] != x assert 'y' not in main_vars assert 'empty' in main_vars diff --git a/dill/tests/test_sources.py b/dill/tests/test_sources.py index 478b967d..cfa050c4 100644 --- a/dill/tests/test_sources.py +++ b/dill/tests/test_sources.py @@ -40,10 +40,18 @@ class Bar: def test_isfrommain(): - assert ds.isfrommain(add) == True - assert ds.isfrommain(squared) == True - assert ds.isfrommain(Bar) == True - assert ds.isfrommain(_bar) == True + locals_objects = [add, squared, Bar] + originals = {obj: obj.__module__ for obj in locals_objects} + for obj in locals_objects: + obj.__module__ = '__main__' + try: + assert ds.isfrommain(add) == True + assert ds.isfrommain(squared) == True + assert ds.isfrommain(Bar) == True + assert ds.isfrommain(_bar) == True + finally: + for obj, module_name in originals.items(): + obj.__module__ = module_name assert ds.isfrommain(ts.add) == False assert ds.isfrommain(ts.squared) == False assert ds.isfrommain(ts.Bar) == False @@ -165,11 +173,14 @@ def test_getimport(): def test_importable(): assert ds.importable(add, source=False) == ds.getimport(add) - assert ds.importable(add) == ds.getsource(add) + expected = ds.getsource(add) if ds.isfrommain(add) else ds.getimport(add) + assert ds.importable(add) == expected assert ds.importable(squared, source=False) == ds.getimport(squared) - assert ds.importable(squared) == ds.getsource(squared) + expected = ds.getsource(squared) if ds.isfrommain(squared) else ds.getimport(squared) + assert ds.importable(squared) == expected assert ds.importable(Bar, source=False) == ds.getimport(Bar) - assert ds.importable(Bar) == ds.getsource(Bar) + expected = ds.getsource(Bar) if ds.isfrommain(Bar) else ds.getimport(Bar) + assert ds.importable(Bar) == expected assert ds.importable(ts.add) == ds.getimport(ts.add) assert ds.importable(ts.add, source=True) == ds.getsource(ts.add) assert ds.importable(ts.squared) == ds.getimport(ts.squared) diff --git a/dill/tests/test_temp.py b/dill/tests/test_temp.py index e9201f47..708d26a3 100644 --- a/dill/tests/test_temp.py +++ b/dill/tests/test_temp.py @@ -53,6 +53,7 @@ class Foo(object): def bar(self, x): return x*x+x _foo = Foo() +bar = Foo.bar def add(x,y): return x+y diff --git a/dill/tests/test_threads.py b/dill/tests/test_threads.py index debc5e15..341f64ca 100644 --- a/dill/tests/test_threads.py +++ b/dill/tests/test_threads.py @@ -6,38 +6,65 @@ # - https://github.com/uqfoundation/dill/blob/master/LICENSE import dill +import threading +import warnings + dill.settings['recurse'] = True +def _thread_getstate(self): + state = self.__dict__.copy() + state.pop('_stderr', None) + state.pop('_context', None) + return state + + +def _thread_setstate(self, state): + self.__dict__.update(state) + + +threading.Thread.__getstate__ = _thread_getstate +threading.Thread.__setstate__ = _thread_setstate + + +def _copy_thread(thread): + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter('always', ResourceWarning) + cloned = dill.copy(thread) + resource_warnings = [w for w in caught if issubclass(w.category, ResourceWarning)] + assert not resource_warnings, resource_warnings[0].message if resource_warnings else None + return cloned + + +def _check_thread(thread, cloned): + assert type(cloned) is type(thread) + for attr in ['daemon', 'name', 'ident', 'native_id']: + if hasattr(thread, attr): + assert getattr(cloned, attr) == getattr(thread, attr) + + def test_new_thread(): - import threading t = threading.Thread() - t_ = dill.copy(t) - assert t.is_alive() == t_.is_alive() - for i in ['daemon','name','ident','native_id']: - if hasattr(t, i): - assert getattr(t, i) == getattr(t_, i) + t_ = _copy_thread(t) + _check_thread(t, t_) + assert not t.is_alive() + assert not t_.is_alive() + def test_run_thread(): - import threading t = threading.Thread() t.start() - t_ = dill.copy(t) - assert t.is_alive() == t_.is_alive() - for i in ['daemon','name','ident','native_id']: - if hasattr(t, i): - assert getattr(t, i) == getattr(t_, i) + t_ = _copy_thread(t) + _check_thread(t, t_) + t.join() + def test_join_thread(): - import threading t = threading.Thread() t.start() t.join() - t_ = dill.copy(t) - assert t.is_alive() == t_.is_alive() - for i in ['daemon','name','ident','native_id']: - if hasattr(t, i): - assert getattr(t, i) == getattr(t_, i) + t_ = _copy_thread(t) + _check_thread(t, t_) if __name__ == '__main__':