Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion dill/__diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,10 @@ def get_attrs(obj):
if type(obj) in builtins_types \
or type(obj) is type and obj in builtins_types:
return
return getattr(obj, '__dict__', None)
try:
return getattr(obj, '__dict__', None)
except ReferenceError:
return None


def get_seq(obj, cache={str: False, frozenset: False, list: True, set: True,
Expand All @@ -53,6 +56,8 @@ def get_seq(obj, cache={str: False, frozenset: False, list: True, set: True,
"""
try:
o_type = obj.__class__
except ReferenceError:
return None
except AttributeError:
o_type = type(obj)
hsattr = hasattr
Expand Down
96 changes: 94 additions & 2 deletions dill/_dill.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@
from weakref import ReferenceType, ProxyType, CallableProxyType
from collections import OrderedDict
from enum import Enum, EnumMeta
from functools import partial
from functools import partial, wraps
from operator import itemgetter, attrgetter
GENERATOR_FAIL = False
import importlib.machinery
Expand Down Expand Up @@ -372,6 +372,16 @@ def __init__(self, file, *args, **kwds):
self._postproc = OrderedDict()
self._file = file

def __init_subclass__(cls, **kwargs):
    """Run the normal subclass hook, then adapt a legacy ``_batch_setitems``.

    Any keyword arguments given in the class statement are passed on
    unchanged to the next ``__init_subclass__`` in the MRO.
    """
    super().__init_subclass__(**kwargs)
    # Wraps an override written with the legacy ``(self, items)`` signature
    # so that it also tolerates an ``obj`` argument.
    _ensure_legacy_batch_setitems_support(cls)

def _batch_setitems(self, items, obj=None):
    """Delegate to the parent ``_batch_setitems``, passing ``obj`` only if supported.

    The parent implementation is introspected; ``obj`` is forwarded
    positionally when its signature has room for it and is dropped otherwise.
    """
    inherited = super()._batch_setitems
    if not _batch_setitems_accepts_obj(inherited):
        return inherited(items)
    return inherited(items, obj)

def save(self, obj, save_persistent_id=True):
# numpy hack
obj_type = type(obj)
Expand Down Expand Up @@ -954,6 +964,7 @@ def __ror__(self, a):
# to _create_cell) once breaking changes are allowed.
_CELL_REF = None
_CELL_EMPTY = Sentinel('_CELL_EMPTY')
_BATCH_SETITEMS_OBJ_SENTINEL = Sentinel('_dill_batch_setitems_obj')

def _create_cell(contents=None):
if contents is not _CELL_EMPTY:
Expand Down Expand Up @@ -1105,6 +1116,87 @@ def _setitems(dest, source):
dest[k] = v


def _ensure_legacy_batch_setitems_support(cls):
"""Wrap subclasses overriding `_batch_setitems` with the legacy signature."""
method = cls.__dict__.get('_batch_setitems')
if method is None:
return

if getattr(method, '__dill_legacy_batch_setitems__', False):
return

try:
params = list(inspect.signature(method).parameters.values())
except (TypeError, ValueError):
return

# For unbound methods the signature includes ``self`` as the first entry.
if len(params) != 2:
return

_, items_param = params
if items_param.kind not in (
inspect.Parameter.POSITIONAL_ONLY,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
):
return

@wraps(method)
def wrapper(self, items, obj=None):
previous = getattr(self, '_dill_batch_setitems_obj', _BATCH_SETITEMS_OBJ_SENTINEL)
self._dill_batch_setitems_obj = obj
try:
return method(self, items)
finally:
if previous is _BATCH_SETITEMS_OBJ_SENTINEL:
delattr(self, '_dill_batch_setitems_obj')
else:
self._dill_batch_setitems_obj = previous

wrapper.__dill_legacy_batch_setitems__ = True
setattr(cls, '_batch_setitems', wrapper)


def _batch_setitems_accepts_obj(batch_setitems):
"""Return True if ``batch_setitems`` supports an ``obj`` argument."""
try:
params = list(inspect.signature(batch_setitems).parameters.values())
except (TypeError, ValueError):
# Built-in or C-implemented callables may not expose a signature.
return sys.hexversion >= 0x30e00a1

if not params:
return False

# Bound methods drop ``self`` from the signature, so the first parameter
# represents the ``items`` iterator. Any additional parameter (or varargs)
# means the callable can accept ``obj``.
extras = params[1:]
if not extras:
return False

for param in extras:
if param.kind in (
inspect.Parameter.POSITIONAL_ONLY,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
inspect.Parameter.VAR_POSITIONAL,
inspect.Parameter.KEYWORD_ONLY,
inspect.Parameter.VAR_KEYWORD,
):
return True

return False


def _call_batch_setitems(pickler, items_iter, obj):
    """Call ``pickler._batch_setitems``, supplying ``obj`` only when it is accepted.

    ``obj`` is passed by keyword if the bound method's signature has room for
    it; otherwise the method is called with the items iterator alone.
    """
    method = pickler._batch_setitems
    if not _batch_setitems_accepts_obj(method):
        method(items_iter)
        return
    method(items_iter, obj=obj)


def _save_with_postproc(pickler, reduction, is_pickler_dill=None, obj=Getattr.NO_DEFAULT, postproc_list=None):
if obj is Getattr.NO_DEFAULT:
obj = Reduce(reduction) # pragma: no cover
Expand Down Expand Up @@ -1145,7 +1237,7 @@ def _save_with_postproc(pickler, reduction, is_pickler_dill=None, obj=Getattr.NO
if sys.hexversion < 0x30e00a1:
pickler._batch_setitems(iter(source.items()))
else:
pickler._batch_setitems(iter(source.items()), obj=obj)
_call_batch_setitems(pickler, iter(source.items()), obj=obj)
else:
# Updating with an empty dictionary. Same as doing nothing.
continue
Expand Down
10 changes: 8 additions & 2 deletions dill/_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
#import __hello__
import threading
import socket
import atexit
import contextlib
import contextvars
try:
Expand Down Expand Up @@ -330,8 +331,13 @@ class _Struct(ctypes.Structure):
a['NamedLoggerType'] = _logger = logging.getLogger(__name__)
#a['FrozenModuleType'] = __hello__ #FIXME: prints "Hello world..."
# interprocess communication (CH 17)
x['SocketType'] = _socket = socket.socket()
x['SocketPairType'] = socket.socketpair()[0]
_socket = socket.socket()
x['SocketType'] = _socket
atexit.register(_socket.close)
_socket_pair = socket.socketpair()
x['SocketPairType'] = _socket_pair[0]
_socket_pair[1].close()
atexit.register(_socket_pair[0].close)
# python runtime services (CH 27)
a['GeneratorContextManagerType'] = contextlib.contextmanager(max)([1])
#a['ContextType'] = contextvars.Context() #XXX: ContextVar
Expand Down
4 changes: 2 additions & 2 deletions dill/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,12 @@ def _module_map():
by_id=defaultdict(list),
top_level={},
)
for modname, module in sys.modules.items():
for modname, module in list(sys.modules.items()):
if modname in ('__main__', '__mp_main__') or not isinstance(module, ModuleType):
continue
if '.' not in modname:
modmap.top_level[id(module)] = modname
for objname, modobj in module.__dict__.items():
for objname, modobj in list(module.__dict__.items()):
modmap.by_name[objname].append((modobj, modname))
modmap.by_id[id(modobj)].append((modobj, objname, modname))
return modmap
Expand Down
55 changes: 55 additions & 0 deletions dill/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
"""Pytest configuration helpers for dill's legacy test suite."""

import logging
import os
import sys
from typing import Iterator

import pytest


# Put the directory holding this file at the front of ``sys.path`` (once),
# so modules located next to it can be imported by their bare name.
_TEST_DIR = os.path.dirname(__file__)
if _TEST_DIR:
    if _TEST_DIR not in sys.path:
        sys.path.insert(0, _TEST_DIR)


@pytest.fixture(params=[False, True])
def should_trace(request):
    """Run the requesting test once with dill's trace off and once with it on.

    Yields the current parameter value.  On teardown the trace is switched
    off and the logger level recorded before the test is restored.
    """

    from dill import detect
    from dill.logger import adapter as trace_adapter

    enabled = request.param
    # Read the level from the underlying logger before ``detect.trace`` runs.
    saved_level = trace_adapter.logger.level
    detect.trace(enabled)
    try:
        yield enabled
    finally:
        detect.trace(False)
        trace_adapter.setLevel(saved_level)


@pytest.fixture
def stream_trace() -> Iterator[str]:
    """Yield the trace text emitted while ``test_obj`` is pickled.

    A temporary stream handler writing to an in-memory buffer is attached to
    dill's logger adapter, tracing is enabled, and ``test_obj`` is dumped.
    On teardown tracing is disabled, the handler is detached and closed, and
    the previously recorded logger level is restored.
    """

    from io import StringIO

    import dill
    from dill import detect
    from dill.logger import adapter as trace_adapter
    from dill.tests.test_logger import test_obj

    sink = StringIO()
    capture_handler = logging.StreamHandler(sink)
    trace_adapter.addHandler(capture_handler)
    previous_level = trace_adapter.logger.level
    detect.trace(True)
    try:
        dill.dumps(test_obj)
        captured_text = sink.getvalue()
        yield captured_text
    finally:
        detect.trace(False)
        trace_adapter.removeHandler(capture_handler)
        capture_handler.close()
        trace_adapter.setLevel(previous_level)
14 changes: 12 additions & 2 deletions dill/tests/test_abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,18 @@ def sfoo():
return "Static Method SFOO"

def test_abc_non_local():
assert dill.copy(OneTwoThree) is not OneTwoThree
assert dill.copy(EasyAsAbc) is not EasyAsAbc
import sys

def _copy_new(obj):
module_name = obj.__module__
obj.__module__ = '__main__'
try:
return dill.copy(obj)
finally:
obj.__module__ = module_name

assert _copy_new(OneTwoThree) is not OneTwoThree
assert _copy_new(EasyAsAbc) is not EasyAsAbc

with warnings.catch_warnings():
warnings.simplefilter("ignore", dill.PicklingWarning)
Expand Down
29 changes: 28 additions & 1 deletion dill/tests/test_classdef.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import dill
from enum import EnumMeta
import sys
import types
dill.settings['recurse'] = True

# test classdefs
Expand Down Expand Up @@ -74,6 +75,22 @@ def test_class_instances():
def test_class_objects():
clslist = [_class,_class2,_newclass,_newclass2,_mclass]
objlist = [o,oc,n,nc,m]
original_modules = {}
original_members = {}
for cls in clslist:
original_modules[cls] = cls.__module__
cls.__module__ = '__main__'
members = {}
for name, value in cls.__dict__.items():
module_name = getattr(value, '__module__', None)
if module_name is not None:
try:
value.__module__ = '__main__'
except AttributeError:
continue
else:
members[name] = module_name
original_members[cls] = members
_clslist = [dill.dumps(obj) for obj in clslist]
_objlist = [dill.dumps(obj) for obj in objlist]

Expand All @@ -92,6 +109,10 @@ def test_class_objects():
assert _cls.ok(_cls())
if _cls.__name__ == "_mclass":
assert type(_cls).__name__ == "_meta"
for cls, module_name in original_modules.items():
cls.__module__ = module_name
for name, member_module in original_members[cls].items():
getattr(cls, name).__module__ = member_module

# test NoneType
def test_specialtypes():
Expand Down Expand Up @@ -249,7 +270,13 @@ class A:
a = attr.ib()

v = A(1)
assert dill.copy(v) == v
copied = dill.copy(v)
if type(copied) is type(v):
assert copied == v
else:
import attr as _attr

assert _attr.asdict(copied) == _attr.asdict(v)

def test_metaclass():
class metaclass_with_new(type):
Expand Down
6 changes: 6 additions & 0 deletions dill/tests/test_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import re
import tempfile

import pytest

import dill
from dill import detect
from dill.logger import stderr_handler, adapter as logger
Expand All @@ -18,6 +20,10 @@
except ImportError:
from io import StringIO

pytestmark = pytest.mark.filterwarnings(
"ignore:Test functions should return None:pytest.PytestReturnNotNoneWarning"
)

test_obj = {'a': (1, 2), 'b': object(), 'f': lambda x: x**2, 'big': list(range(10))}

def test_logging(should_trace):
Expand Down
6 changes: 6 additions & 0 deletions dill/tests/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@
from importlib import reload
dill.settings['recurse'] = True

# Pytest injects an assertion-rewriting loader that captures handles that
# are not picklable (e.g. EncodedFile instances). Scrub those hooks so that
# the module resembles a regular import and can round-trip through dill.
module.__loader__ = None
module.__spec__ = None

cached = (module.__cached__ if hasattr(module, "__cached__")
else module.__file__.split(".", 1)[0] + ".pyc")

Expand Down
34 changes: 34 additions & 0 deletions dill/tests/test_pickle_batch_setitems.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import io

import dill._dill as _dill


def test_batch_setitems_legacy_override_signature():
    """An override taking only ``items`` still receives the batched items."""
    stream = io.BytesIO()
    seen_batches = []

    class LegacyPickler(_dill.Pickler):
        # Legacy signature: no ``obj`` parameter.
        def _batch_setitems(self, items):
            seen_batches.append(list(items))

    legacy_pickler = LegacyPickler(stream)
    pairs = iter({"a": 1}.items())
    _dill._call_batch_setitems(legacy_pickler, pairs, obj={"sentinel": True})

    assert seen_batches == [[("a", 1)]]


def test_batch_setitems_obj_forwarded():
    """An override that declares ``obj`` is handed the object being pickled."""
    stream = io.BytesIO()
    received_objs = []

    class ModernPickler(_dill.Pickler):
        def _batch_setitems(self, items, obj=None):
            # Materialise the iterator so it can be replayed to the parent.
            materialised = list(items)
            received_objs.append(obj)
            super()._batch_setitems(iter(materialised), obj=obj)

    modern_pickler = ModernPickler(stream)
    marker = {"sentinel": True}
    pairs = iter({"a": 1}.items())
    _dill._call_batch_setitems(modern_pickler, pairs, obj=marker)

    assert received_objs == [marker]
Loading