Skip to content

Commit de1943f

Browse files
committed
Add options to exclude objects from dump_session()
1 parent a650f62 commit de1943f

File tree

5 files changed

+345
-32
lines changed

5 files changed

+345
-32
lines changed

dill/__init__.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -283,16 +283,17 @@
283283
284284
"""
285285

286-
from ._dill import dump, dumps, load, loads, \
287-
Pickler, Unpickler, register, copy, pickle, pickles, check, \
288-
HIGHEST_PROTOCOL, DEFAULT_PROTOCOL, PicklingError, UnpicklingError, \
289-
HANDLE_FMODE, CONTENTS_FMODE, FILE_FMODE, PickleError, PickleWarning, \
290-
PicklingWarning, UnpicklingWarning
286+
from ._dill import (
287+
Pickler, Unpickler,
288+
dump, dumps, load, loads, copy, check, pickle, pickles, register,
289+
DEFAULT_PROTOCOL, HIGHEST_PROTOCOL, HANDLE_FMODE, CONTENTS_FMODE, FILE_FMODE,
290+
PicklingError, UnpicklingError, PickleError, PicklingWarning, UnpicklingWarning, PickleWarning,
291+
)
291292
from .session import dump_session, load_session
292293
from . import detect, session, source, temp
293294

294295
# get global settings
295-
from .settings import settings
296+
from .settings import Settings, settings
296297

297298
# make sure "trace" is turned off
298299
detect.trace(False)

dill/_dill.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1917,8 +1917,7 @@ def save_function(pickler, obj):
19171917
_recurse = getattr(pickler, '_recurse', None)
19181918
_byref = getattr(pickler, '_byref', None)
19191919
_postproc = getattr(pickler, '_postproc', None)
1920-
_main_modified = getattr(pickler, '_main_modified', None)
1921-
_original_main = getattr(pickler, '_original_main', __builtin__)#'None'
1920+
_original_main = getattr(pickler, '_original_main', None)
19221921
postproc_list = []
19231922
if _recurse:
19241923
# recurse to get all globals referred to by obj
@@ -1935,7 +1934,7 @@ def save_function(pickler, obj):
19351934

19361935
# If the globals is the __dict__ from the module being saved as a
19371936
# session, substitute it by the dictionary being actually saved.
1938-
if _main_modified and globs_copy is _original_main.__dict__:
1937+
if _original_main and globs_copy is _original_main.__dict__:
19391938
globs_copy = getattr(pickler, '_main', _original_main).__dict__
19401939
globs = globs_copy
19411940
# If the globals is a module __dict__, do not save it in the pickle.

dill/_utils.py

Lines changed: 234 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,234 @@
1+
#!/usr/bin/env python
2+
#
3+
# Author: Leonardo Gama (@leogama)
4+
# Copyright (c) 2022 The Uncertainty Quantification Foundation.
5+
# License: 3-clause BSD. The full license text is available at:
6+
# - https://github.com/uqfoundation/dill/blob/master/LICENSE
7+
"""auxiliary internal classes used in multiple submodules, set here to avoid import recursion"""
8+
9+
__all__ = ['AttrDict', 'ExcludeRules', 'Filter', 'RuleType']
10+
11+
import logging
12+
logger = logging.getLogger('dill._utils')
13+
14+
import inspect
15+
from functools import partialmethod
16+
17+
class AttrDict(dict):
18+
"""syntactic sugar for accessing dictionary items"""
19+
_CAST = object() # singleton
20+
def __init__(self, *args, **kwargs):
21+
data = args[0] if len(args) == 2 and args[1] is self._CAST else dict(*args, **kwargs)
22+
for key, val in tuple(data.items()):
23+
if isinstance(val, dict) and not isinstance(val, AttrDict):
24+
data[key] = AttrDict(val, self._CAST)
25+
super().__setattr__('_data', data)
26+
def _check_attr(self, name):
27+
try:
28+
super().__getattribute__(name)
29+
except AttributeError:
30+
pass
31+
else:
32+
raise AttributeError("'AttrDict' object attribute %r is read-only" % name)
33+
def __getattr__(self, key):
34+
# This is called only if dict.__getattribute__(key) fails.
35+
try:
36+
return self._data[key]
37+
except KeyError:
38+
raise AttributeError("'AttrDict' object has no attribute %r" % key)
39+
def __setattr__(self, key, value):
40+
self._check_attr(key)
41+
if isinstance(value, dict):
42+
self._data[key] = AttrDict(value, self._CAST)
43+
else:
44+
self._data[key] = value
45+
def __delattr__(self, key):
46+
self._check_attr(key)
47+
del self._data[key]
48+
def __proxy__(self, method, *args, **kwargs):
49+
return getattr(self._data, method)(*args, **kwargs)
50+
def __reduce__(self):
51+
return AttrDict, (self._data,)
52+
def copy(self):
53+
# Deep copy.
54+
copy = AttrDict(self._data)
55+
for key, val in tuple(copy.items()):
56+
if isinstance(val, AttrDict):
57+
copy[key] = val.copy()
58+
return copy
59+
60+
for method, _ in inspect.getmembers(dict, inspect.ismethoddescriptor):
61+
if method not in vars(AttrDict) and method not in {'__getattribute__', '__reduce_ex__'}:
62+
setattr(AttrDict, method, partialmethod(AttrDict.__proxy__, method))
63+
64+
65+
### Namespace filtering
66+
import re
67+
from dataclasses import InitVar, dataclass, field, fields
68+
from collections import abc, namedtuple
69+
from enum import Enum
70+
from functools import partialmethod
71+
from itertools import filterfalse
72+
from re import Pattern
73+
from typing import Callable, Iterable, Set, Tuple, Union
74+
75+
RuleType = Enum('RuleType', 'EXCLUDE INCLUDE', module=__name__)
76+
NamedObj = namedtuple('NamedObj', 'name value', module=__name__)
77+
78+
Filter = Union[str, Pattern, int, type, Callable]
79+
Rule = Tuple[RuleType, Union[Filter, Iterable[Filter]]]
80+
81+
def isiterable(arg):
82+
return isinstance(arg, abc.Iterable) and not isinstance(arg, (str, bytes))
83+
84+
@dataclass
85+
class ExcludeFilters:
86+
ids: Set[int] = field(default_factory=set)
87+
names: Set[str] = field(default_factory=set)
88+
regex: Set[Pattern] = field(default_factory=set)
89+
types: Set[type] = field(default_factory=set)
90+
funcs: Set[Callable] = field(default_factory=set)
91+
92+
@property
93+
def filter_sets(self):
94+
return tuple(field.name for field in fields(self))
95+
def __bool__(self):
96+
return any(getattr(self, filter_set) for filter_set in self.filter_sets)
97+
def _check(self, filter):
98+
if isinstance(filter, str):
99+
if filter.isidentifier():
100+
field = 'names'
101+
else:
102+
filter, field = re.compile(filter), 'regex'
103+
elif isinstance(filter, Pattern):
104+
field = 'regex'
105+
elif isinstance(filter, int):
106+
field = 'ids'
107+
elif isinstance(filter, type):
108+
field = 'types'
109+
elif callable(filter):
110+
field = 'funcs'
111+
else:
112+
raise ValueError("invalid filter: %r" % filter)
113+
return filter, getattr(self, field)
114+
def add(self, filter):
115+
filter, filter_set = self._check(filter)
116+
filter_set.add(filter)
117+
def discard(self, filter):
118+
filter, filter_set = self._check(filter)
119+
filter_set.discard(filter)
120+
def remove(self, filter):
121+
filter, filter_set = self._check(filter)
122+
filter_set.remove(filter)
123+
def update(self, filters):
124+
for filter in filters:
125+
self.add(filter)
126+
def clear(self):
127+
for filter_set in self.filter_sets:
128+
getattr(self, filter_set).clear()
129+
def add_type(self, type_name):
130+
import types
131+
name_suffix = type_name + 'Type' if not type_name.endswith('Type') else type_name
132+
if hasattr(types, name_suffix):
133+
type_name = name_suffix
134+
type_obj = getattr(types, type_name, None)
135+
if not isinstance(type_obj, type):
136+
named = type_name if type_name == name_suffix else "%r or %r" % (type_name, name_suffix)
137+
raise NameError("could not find a type named %s in module 'types'" % named)
138+
self.types.add(type_obj)
139+
140+
@dataclass
141+
class ExcludeRules:
142+
exclude: ExcludeFilters = field(init=False, default_factory=ExcludeFilters)
143+
include: ExcludeFilters = field(init=False, default_factory=ExcludeFilters)
144+
rules: InitVar[Iterable[Rule]] = None
145+
146+
def __post_init__(self, rules):
147+
if rules is not None:
148+
self.update(rules)
149+
150+
def __proxy__(self, method, filter, *, rule_type=RuleType.EXCLUDE):
151+
if rule_type is RuleType.EXCLUDE:
152+
getattr(self.exclude, method)(filter)
153+
elif rule_type is RuleType.INCLUDE:
154+
getattr(self.include, method)(filter)
155+
else:
156+
raise ValueError("invalid rule type: %r (must be one of %r)" % (rule_type, list(RuleType)))
157+
158+
add = partialmethod(__proxy__, 'add')
159+
discard = partialmethod(__proxy__, 'discard')
160+
remove = partialmethod(__proxy__, 'remove')
161+
162+
def update(self, rules):
163+
if isinstance(rules, ExcludeRules):
164+
for filter_set in self.exclude.filter_sets:
165+
getattr(self.exclude, filter_set).update(getattr(rules.exclude, filter_set))
166+
getattr(self.include, filter_set).update(getattr(rules.include, filter_set))
167+
else:
168+
# Validate rules.
169+
for rule in rules:
170+
if not isinstance(rule, tuple) or len(rule) != 2:
171+
raise ValueError("invalid rule format: %r" % rule)
172+
for rule_type, filter in rules:
173+
if isiterable(filter):
174+
for f in filter:
175+
self.add(f, rule_type=rule_type)
176+
else:
177+
self.add(filter, rule_type=rule_type)
178+
179+
def clear(self):
180+
self.exclude.clear()
181+
self.include.clear()
182+
183+
def filter_namespace(self, namespace, obj=None):
184+
if not self.exclude and not self.include:
185+
return namespace
186+
187+
# Protect agains dict changes during the call.
188+
namespace_copy = namespace.copy() if obj is None or namespace is vars(obj) else namespace
189+
objects = all_objects = [NamedObj._make(item) for item in namespace_copy.items()]
190+
191+
for filters in (self.exclude, self.include):
192+
if filters is self.exclude and not filters:
193+
# Treat the rule set as an allowlist.
194+
exclude_objs = objects
195+
continue
196+
elif filters is self.include:
197+
if not filters or not exclude_objs:
198+
break
199+
objects = exclude_objs
200+
201+
flist = []
202+
types_list = tuple(filters.types)
203+
# Apply cheaper/broader filters first.
204+
if types_list:
205+
flist.append(lambda obj: isinstance(obj.value, types_list))
206+
if filters.ids:
207+
flist.append(lambda obj: id(obj.value) in filters.ids)
208+
if filters.names:
209+
flist.append(lambda obj: obj.name in filters.names)
210+
if filters.regex:
211+
flist.append(lambda obj: any(regex.fullmatch(obj.name) for regex in filters.regex))
212+
flist.extend(filters.funcs)
213+
for f in flist:
214+
objects = filterfalse(f, objects)
215+
216+
if filters is self.exclude:
217+
include_names = {obj.name for obj in objects}
218+
exclude_objs = [obj for obj in all_objects if obj.name not in include_names]
219+
else:
220+
exclude_objs = list(objects)
221+
222+
if not exclude_objs:
223+
return namespace
224+
if len(exclude_objs) == len(namespace):
225+
warnings.warn("filtering operation left the namespace empty!", PicklingWarning)
226+
return {}
227+
if logger.isEnabledFor(logging.INFO):
228+
exclude_listing = {obj.name: type(obj.value).__name__ for obj in sorted(exclude_objs)}
229+
exclude_listing = str(exclude_listing).translate({ord(","): "\n", ord("'"): None})
230+
logger.info("Objects excluded from dump_session():\n%s\n", exclude_listing)
231+
232+
for obj in exclude_objs:
233+
del namespace_copy[obj.name]
234+
return namespace_copy

0 commit comments

Comments
 (0)