Skip to content

Commit 298a099

Browse files
committed
Session: check id against module being saved
1 parent 5b77a06 commit 298a099

File tree

2 files changed

+12
-8
lines changed

2 files changed

+12
-8
lines changed

dill/session.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
import dill
1919
from dill import Pickler, Unpickler
20-
from ._dill import ModuleType, _import_module, _is_builtin_module
20+
from ._dill import ModuleType, _import_module, _is_builtin_module, _main_module
2121
from .utils import AttrDict, CheckerSet, TransSet
2222
from .settings import settings
2323

@@ -111,10 +111,14 @@ def _exclude_objs(main, exclude_extra, filters_extra, settings):
111111
categories = {'ids': int, 'names': str, 'regex': re.Pattern, 'types': type}
112112
exclude = AttrDict({cat: copy(settings.session_exclude[cat]) for cat in categories})
113113
filters = copy(settings.session_filters)
114+
del categories['ids'] # special case
114115
if exclude_extra is not None:
115116
if isinstance(exclude_extra, str):
116117
raise ValueError("'exclude' can be of type Iterable[str], but not str")
117118
for item in exclude_extra:
119+
if isinstance(item, int):
120+
exclude.ids.add(item, main=main)
121+
continue
118122
for category, klass in categories.items():
119123
if isinstance(item, klass):
120124
exclude[category].add(item)
@@ -220,12 +224,12 @@ def load_session(filename: Union[os.PathLike, io.BytesIO] = '/tmp/session.pkl',
220224
# Settings #
221225
##############
222226

223-
def _as_id(item):
227+
def _as_id(item, *, main=_main_module):
224228
if isinstance(item, int):
225-
import warnings, __main__
226-
if not any(id(obj) == item for obj in __main__.__dict__.values()):
227-
warnings.warn("%d isn't the id of any object in __main__ namespace. "
228-
"Did you mean 'id(%d)?'" % (item, item))
229+
import warnings
230+
if not any(id(obj) == item for obj in main.__dict__.values()):
231+
warnings.warn("%d isn't the id of any object in the '%s' namespace. "
232+
"Did you mean 'id(%d)'?" % (item, main.__name__, item))
229233
return item
230234
return id(item)
231235

dill/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,8 @@ class TransSet(set):
4747
def __init__(self, func: Callable, *args):
4848
self.constructor = func
4949
super().__init__(*args)
50-
def add(self, item):
51-
super().add(self.constructor(item))
50+
def add(self, item, **kwargs):
51+
super().add(self.constructor(item, **kwargs))
5252
def discard(self, item):
5353
super().discard(self.constructor(item))
5454
def remove(self, item):

0 commit comments

Comments
 (0)