diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index ba26b05258d..18a6f80b961 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -64,10 +64,12 @@ jobs:
           pip install -r additional-tests-requirements.txt --no-deps
       - name: Install dependencies (latest versions)
         if: ${{ matrix.deps_versions == 'latest' }}
-        run: pip install --upgrade pyarrow huggingface-hub
-      - name: Install depencencies (minimum versions)
+        run: |
+          pip uninstall -y apache-beam
+          pip install --upgrade pyarrow huggingface-hub dill
+      - name: Install dependencies (minimum versions)
         if: ${{ matrix.deps_versions != 'latest' }}
-        run: pip install pyarrow==6.0.1 huggingface-hub==0.2.0 transformers
+        run: pip install pyarrow==6.0.1 huggingface-hub==0.2.0 transformers dill==0.3.1.1
       - name: Test with pytest
         run: |
           python -m pytest -rfExX -m ${{ matrix.test }} -n 2 --dist loadfile -sv ./tests/
diff --git a/setup.py b/setup.py
index f7743bb509d..b899c027c4b 100644
--- a/setup.py
+++ b/setup.py
@@ -70,7 +70,7 @@
     # Minimum 6.0.0 to support wrap_array which is needed for ArrayND features
     "pyarrow>=6.0.0",
     # For smart caching dataset processing
-    "dill<0.3.6",  # tmp pin until 0.3.6 release: see https://github.com/huggingface/datasets/pull/4397
+    "dill<0.3.7",  # tmp pin until next 0.3.7 release: see https://github.com/huggingface/datasets/pull/5166
     # For performance gains with apache arrow
     "pandas",
     # for downloading datasets over HTTPS
diff --git a/src/datasets/utils/py_utils.py b/src/datasets/utils/py_utils.py
index c4b7152307e..fb8b0003fe5 100644
--- a/src/datasets/utils/py_utils.py
+++ b/src/datasets/utils/py_utils.py
@@ -634,39 +634,76 @@ def proxy(func):
     return proxy
 
 
-@pklregister(CodeType)
-def _save_code(pickler, obj):
-    """
-    From dill._dill.save_code
-    This is a modified version that removes the origin (filename + line no.)
-    of functions created in notebooks or shells for example.
-    """
-    dill._dill.log.info(f"Co: {obj}")
-    # The filename of a function is the .py file where it is defined.
-    # Filenames of functions created in notebooks or shells start with '<'
-    # ex: <ipython-input-2-e0a5733b25f9> for ipython, and <stdin> for shell
-    # Moreover lambda functions have a special name: '<lambda>'
-    # ex: (lambda x: x).__code__.co_name == "<lambda>"  # True
-    #
-    # For the hashing mechanism we ignore where the function has been defined
-    # More specifically:
-    # - we ignore the filename of special functions (filename starts with '<')
-    # - we always ignore the line number
-    # - we only use the base name of the file instead of the whole path,
-    # to be robust in case a script is moved for example.
-    #
-    # Only those two lines are different from the original implementation:
-    co_filename = (
-        "" if obj.co_filename.startswith("<") or obj.co_name == "<lambda>" else os.path.basename(obj.co_filename)
-    )
-    co_firstlineno = 1
-    # The rest is the same as in the original dill implementation
-    if dill._dill.PY3:
-        if hasattr(obj, "co_posonlyargcount"):
+if config.DILL_VERSION < version.parse("0.3.6"):
+
+    @pklregister(CodeType)
+    def _save_code(pickler, obj):
+        """
+        From dill._dill.save_code
+        This is a modified version that removes the origin (filename + line no.)
+        of functions created in notebooks or shells for example.
+        """
+        dill._dill.log.info(f"Co: {obj}")
+        # The filename of a function is the .py file where it is defined.
+        # Filenames of functions created in notebooks or shells start with '<'
+        # ex: <ipython-input-2-e0a5733b25f9> for ipython, and <stdin> for shell
+        # Moreover lambda functions have a special name: '<lambda>'
+        # ex: (lambda x: x).__code__.co_name == "<lambda>"  # True
+        #
+        # For the hashing mechanism we ignore where the function has been defined
+        # More specifically:
+        # - we ignore the filename of special functions (filename starts with '<')
+        # - we always ignore the line number
+        # - we only use the base name of the file instead of the whole path,
+        # to be robust in case a script is moved for example.
+        #
+        # Only those two lines are different from the original implementation:
+        co_filename = (
+            "" if obj.co_filename.startswith("<") or obj.co_name == "<lambda>" else os.path.basename(obj.co_filename)
+        )
+        co_firstlineno = 1
+        # The rest is the same as in the original dill implementation
+        if dill._dill.PY3:
+            if hasattr(obj, "co_posonlyargcount"):
+                args = (
+                    obj.co_argcount,
+                    obj.co_posonlyargcount,
+                    obj.co_kwonlyargcount,
+                    obj.co_nlocals,
+                    obj.co_stacksize,
+                    obj.co_flags,
+                    obj.co_code,
+                    obj.co_consts,
+                    obj.co_names,
+                    obj.co_varnames,
+                    co_filename,
+                    obj.co_name,
+                    co_firstlineno,
+                    obj.co_lnotab,
+                    obj.co_freevars,
+                    obj.co_cellvars,
+                )
+            else:
+                args = (
+                    obj.co_argcount,
+                    obj.co_kwonlyargcount,
+                    obj.co_nlocals,
+                    obj.co_stacksize,
+                    obj.co_flags,
+                    obj.co_code,
+                    obj.co_consts,
+                    obj.co_names,
+                    obj.co_varnames,
+                    co_filename,
+                    obj.co_name,
+                    co_firstlineno,
+                    obj.co_lnotab,
+                    obj.co_freevars,
+                    obj.co_cellvars,
+                )
+        else:
             args = (
                 obj.co_argcount,
-                obj.co_posonlyargcount,
-                obj.co_kwonlyargcount,
                 obj.co_nlocals,
                 obj.co_stacksize,
                 obj.co_flags,
@@ -681,9 +718,47 @@ def _save_code(pickler, obj):
                 obj.co_freevars,
                 obj.co_cellvars,
             )
-        else:
+        pickler.save_reduce(CodeType, args, obj=obj)
+        dill._dill.log.info("# Co")
+        return
+
+elif config.DILL_VERSION.release[:3] == version.parse("0.3.6").release:
+
+    # From: https://github.com/uqfoundation/dill/blob/dill-0.3.6/dill/_dill.py#L1104
+    @pklregister(CodeType)
+    def save_code(pickler, obj):
+        dill._dill.logger.trace(pickler, "Co: %s", obj)
+
+        ############################################################################################################
+        # Modification here for huggingface/datasets
+        # The filename of a function is the .py file where it is defined.
+        # Filenames of functions created in notebooks or shells start with '<'
+        # ex: <ipython-input-2-e0a5733b25f9> for ipython, and <stdin> for shell
+        # Moreover lambda functions have a special name: '<lambda>'
+        # ex: (lambda x: x).__code__.co_name == "<lambda>"  # True
+        #
+        # For the hashing mechanism we ignore where the function has been defined
+        # More specifically:
+        # - we ignore the filename of special functions (filename starts with '<')
+        # - we always ignore the line number
+        # - we only use the base name of the file instead of the whole path,
+        # to be robust in case a script is moved for example.
+        #
+        # Only those two lines are different from the original implementation:
+        co_filename = (
+            "" if obj.co_filename.startswith("<") or obj.co_name == "<lambda>" else os.path.basename(obj.co_filename)
+        )
+        co_firstlineno = 1
+        # The rest is the same as in the original dill implementation, except for the replacements:
+        # - obj.co_filename => co_filename
+        # - obj.co_firstlineno => co_firstlineno
+        ############################################################################################################
+
+        if hasattr(obj, "co_endlinetable"):  # python 3.11a (20 args)
             args = (
+                obj.co_lnotab,  # for < python 3.10 [not counted in args]
                 obj.co_argcount,
+                obj.co_posonlyargcount,
                 obj.co_kwonlyargcount,
                 obj.co_nlocals,
                 obj.co_stacksize,
@@ -692,33 +767,100 @@
                 obj.co_consts,
                 obj.co_names,
                 obj.co_varnames,
-                co_filename,
+                co_filename,  # Modification for huggingface/datasets ############################################
                 obj.co_name,
-                co_firstlineno,
+                obj.co_qualname,
+                co_firstlineno,  # Modification for huggingface/datasets #########################################
+                obj.co_linetable,
+                obj.co_endlinetable,
+                obj.co_columntable,
+                obj.co_exceptiontable,
+                obj.co_freevars,
+                obj.co_cellvars,
+            )
+        elif hasattr(obj, "co_exceptiontable"):  # python 3.11 (18 args)
+            args = (
+                obj.co_lnotab,  # for < python 3.10 [not counted in args]
+                obj.co_argcount,
+                obj.co_posonlyargcount,
+                obj.co_kwonlyargcount,
+                obj.co_nlocals,
+                obj.co_stacksize,
+                obj.co_flags,
+                obj.co_code,
+                obj.co_consts,
+                obj.co_names,
+                obj.co_varnames,
+                co_filename,  # Modification for huggingface/datasets ############################################
+                obj.co_name,
+                obj.co_qualname,
+                co_firstlineno,  # Modification for huggingface/datasets #########################################
+                obj.co_linetable,
+                obj.co_exceptiontable,
+                obj.co_freevars,
+                obj.co_cellvars,
+            )
+        elif hasattr(obj, "co_linetable"):  # python 3.10 (16 args)
+            args = (
+                obj.co_lnotab,  # for < python 3.10 [not counted in args]
+                obj.co_argcount,
+                obj.co_posonlyargcount,
+                obj.co_kwonlyargcount,
+                obj.co_nlocals,
+                obj.co_stacksize,
+                obj.co_flags,
+                obj.co_code,
+                obj.co_consts,
+                obj.co_names,
+                obj.co_varnames,
+                co_filename,  # Modification for huggingface/datasets ############################################
+                obj.co_name,
+                co_firstlineno,  # Modification for huggingface/datasets #########################################
+                obj.co_linetable,
+                obj.co_freevars,
+                obj.co_cellvars,
+            )
+        elif hasattr(obj, "co_posonlyargcount"):  # python 3.8 (16 args)
+            args = (
+                obj.co_argcount,
+                obj.co_posonlyargcount,
+                obj.co_kwonlyargcount,
+                obj.co_nlocals,
+                obj.co_stacksize,
+                obj.co_flags,
+                obj.co_code,
+                obj.co_consts,
+                obj.co_names,
+                obj.co_varnames,
+                co_filename,  # Modification for huggingface/datasets ############################################
+                obj.co_name,
+                co_firstlineno,  # Modification for huggingface/datasets #########################################
+                obj.co_lnotab,
+                obj.co_freevars,
+                obj.co_cellvars,
+            )
+        else:  # python 3.7 (15 args)
+            args = (
+                obj.co_argcount,
+                obj.co_kwonlyargcount,
+                obj.co_nlocals,
+                obj.co_stacksize,
+                obj.co_flags,
+                obj.co_code,
+                obj.co_consts,
+                obj.co_names,
+                obj.co_varnames,
+                co_filename,  # Modification for huggingface/datasets ############################################
+                obj.co_name,
+                co_firstlineno,  # Modification for huggingface/datasets #########################################
                 obj.co_lnotab,
                 obj.co_freevars,
                 obj.co_cellvars,
             )
-    else:
-        args = (
-            obj.co_argcount,
-            obj.co_nlocals,
-            obj.co_stacksize,
-            obj.co_flags,
-            obj.co_code,
-            obj.co_consts,
-            obj.co_names,
-            obj.co_varnames,
-            co_filename,
-            obj.co_name,
-            co_firstlineno,
-            obj.co_lnotab,
-            obj.co_freevars,
-            obj.co_cellvars,
-        )
-    pickler.save_reduce(CodeType, args, obj=obj)
-    dill._dill.log.info("# Co")
-    return
+
+        pickler.save_reduce(dill._dill._create_code, args, obj=obj)
+        dill._dill.logger.trace(pickler, "# Co")
+        return
 
 
 if config.DILL_VERSION < version.parse("0.3.5"):
@@ -796,7 +938,7 @@ def save_function(pickler, obj):
             dill._dill.log.info("# F2")
         return
 
-else:  # config.DILL_VERSION >= version.parse("0.3.5")
+elif config.DILL_VERSION.release[:3] == version.parse("0.3.5").release:  # 0.3.5, 0.3.5.1
 
     # https://github.com/uqfoundation/dill/blob/dill-0.3.5.1/dill/_dill.py
     @pklregister(FunctionType)
@@ -804,7 +946,6 @@ def save_function(pickler, obj):
         if not dill._dill._locate_function(obj, pickler):
             dill._dill.log.info("F1: %s" % obj)
             _recurse = getattr(pickler, "_recurse", None)
-            # _byref = getattr(pickler, "_byref", None)   # TODO: not used
             _postproc = getattr(pickler, "_postproc", None)
             _main_modified = getattr(pickler, "_main_modified", None)
             _original_main = getattr(pickler, "_original_main", dill._dill.__builtin__)  # 'None'
@@ -941,6 +1082,147 @@ def save_function(pickler, obj):
             dill._dill.log.info("# F2")
         return
 
+elif config.DILL_VERSION.release[:3] == version.parse("0.3.6").release:
+
+    # From: https://github.com/uqfoundation/dill/blob/dill-0.3.6/dill/_dill.py#L1739
+    @pklregister(FunctionType)
+    def save_function(pickler, obj):
+        if not dill._dill._locate_function(obj, pickler):
+            if type(obj.__code__) is not CodeType:
+                # Some PyPy builtin functions have no module name, and thus are not
+                # able to be located
+                module_name = getattr(obj, "__module__", None)
+                if module_name is None:
+                    module_name = dill._dill.__builtin__.__name__
+                module = dill._dill._import_module(module_name, safe=True)
+                _pypy_builtin = False
+                try:
+                    found, _ = dill._dill._getattribute(module, obj.__qualname__)
+                    if getattr(found, "__func__", None) is obj:
+                        _pypy_builtin = True
+                except AttributeError:
+                    pass
+
+                if _pypy_builtin:
+                    dill._dill.logger.trace(pickler, "F3: %s", obj)
+                    pickler.save_reduce(getattr, (found, "__func__"), obj=obj)
+                    dill._dill.logger.trace(pickler, "# F3")
+                    return
+
+            dill._dill.logger.trace(pickler, "F1: %s", obj)
+            _recurse = getattr(pickler, "_recurse", None)
+            _postproc = getattr(pickler, "_postproc", None)
+            _main_modified = getattr(pickler, "_main_modified", None)
+            _original_main = getattr(pickler, "_original_main", dill._dill.__builtin__)  # 'None'
+            postproc_list = []
+            if _recurse:
+                # recurse to get all globals referred to by obj
+                from dill.detect import globalvars
+
+                globs_copy = globalvars(obj, recurse=True, builtin=True)
+
+                # Add the name of the module to the globs dictionary to prevent
+                # the duplication of the dictionary. Pickle the unpopulated
+                # globals dictionary and set the remaining items after the function
+                # is created to correctly handle recursion.
+                globs = {"__name__": obj.__module__}
+            else:
+                globs_copy = obj.__globals__
+
+                # If the globals is the __dict__ from the module being saved as a
+                # session, substitute it by the dictionary being actually saved.
+                if _main_modified and globs_copy is _original_main.__dict__:
+                    globs_copy = getattr(pickler, "_main", _original_main).__dict__
+                    globs = globs_copy
+                # If the globals is a module __dict__, do not save it in the pickle.
+                elif (
+                    globs_copy is not None
+                    and obj.__module__ is not None
+                    and getattr(dill._dill._import_module(obj.__module__, True), "__dict__", None) is globs_copy
+                ):
+                    globs = globs_copy
+                else:
+                    globs = {"__name__": obj.__module__}
+
+            ########################################################################################################
+            # Modification here for huggingface/datasets
+            # - globs is a dictionary with keys = var names (str) and values = python objects
+            # - globs_copy is a dictionary with keys = var names (str) and values = ids of the python objects
+            # However the dictionary is not always loaded in the same order,
+            # therefore we have to sort the keys to make it deterministic.
+            # This is important to make `dump` deterministic.
+            # Only these lines are different from the original implementation:
+            # START
+            globs_is_globs_copy = globs is globs_copy
+            globs = dict(sorted(globs.items()))
+            if globs_is_globs_copy:
+                globs_copy = globs
+            elif globs_copy is not None:
+                globs_copy = dict(sorted(globs_copy.items()))
+            # END
+            ########################################################################################################
+
+            if globs_copy is not None and globs is not globs_copy:
+                # In the case that the globals are copied, we need to ensure that
+                # the globals dictionary is updated when all objects in the
+                # dictionary are already created.
+                glob_ids = {id(g) for g in globs_copy.values()}
+                for stack_element in _postproc:
+                    if stack_element in glob_ids:
+                        _postproc[stack_element].append((dill._dill._setitems, (globs, globs_copy)))
+                        break
+                else:
+                    postproc_list.append((dill._dill._setitems, (globs, globs_copy)))
+
+            closure = obj.__closure__
+            state_dict = {}
+            for fattrname in ("__doc__", "__kwdefaults__", "__annotations__"):
+                fattr = getattr(obj, fattrname, None)
+                if fattr is not None:
+                    state_dict[fattrname] = fattr
+            if obj.__qualname__ != obj.__name__:
+                state_dict["__qualname__"] = obj.__qualname__
+            if "__name__" not in globs or obj.__module__ != globs["__name__"]:
+                state_dict["__module__"] = obj.__module__
+
+            state = obj.__dict__
+            if type(state) is not dict:
+                state_dict["__dict__"] = state
+                state = None
+            if state_dict:
+                state = state, state_dict
+
+            dill._dill._save_with_postproc(
+                pickler,
+                (dill._dill._create_function, (obj.__code__, globs, obj.__name__, obj.__defaults__, closure), state),
+                obj=obj,
+                postproc_list=postproc_list,
+            )
+
+            # Lift closure cell update to earliest function (#458)
+            if _postproc:
+                topmost_postproc = next(iter(_postproc.values()), None)
+                if closure and topmost_postproc:
+                    for cell in closure:
+                        possible_postproc = (setattr, (cell, "cell_contents", obj))
+                        try:
+                            topmost_postproc.remove(possible_postproc)
+                        except ValueError:
+                            continue
+
+                        # Change the value of the cell
+                        pickler.save_reduce(*possible_postproc)
+                        # pop None created by calling preprocessing step off stack
+                        pickler.write(bytes("0", "UTF-8"))
+
+            dill._dill.logger.trace(pickler, "# F1")
+        else:
+            dill._dill.logger.trace(pickler, "F2: %s", obj)
+            name = getattr(obj, "__qualname__", getattr(obj, "__name__", None))
+            dill._dill.StockPickler.save_global(pickler, obj, name=name)
+            dill._dill.logger.trace(pickler, "# F2")
+        return
+
 
 def copyfunc(func):
     result = types.FunctionType(func.__code__, func.__globals__, func.__name__, func.__defaults__, func.__closure__)
@@ -951,16 +1233,31 @@ def copyfunc(func):
 try:
     import regex
 
-    @pklregister(type(regex.Regex("", 0)))
-    def _save_regex(pickler, obj):
-        dill._dill.log.info(f"Re: {obj}")
-        args = (
-            obj.pattern,
-            obj.flags,
-        )
-        pickler.save_reduce(regex.compile, args, obj=obj)
-        dill._dill.log.info("# Re")
-        return
+    if config.DILL_VERSION < version.parse("0.3.6"):
+
+        @pklregister(type(regex.Regex("", 0)))
+        def _save_regex(pickler, obj):
+            dill._dill.log.info(f"Re: {obj}")
+            args = (
+                obj.pattern,
+                obj.flags,
+            )
+            pickler.save_reduce(regex.compile, args, obj=obj)
+            dill._dill.log.info("# Re")
+            return
+
+    elif config.DILL_VERSION.release[:3] == version.parse("0.3.6").release:
+
+        @pklregister(type(regex.Regex("", 0)))
+        def _save_regex(pickler, obj):
+            dill._dill.logger.trace(pickler, "Re: %s", obj)
+            args = (
+                obj.pattern,
+                obj.flags,
+            )
+            pickler.save_reduce(regex.compile, args, obj=obj)
+            dill._dill.logger.trace(pickler, "# Re")
+            return
 
 except ImportError:
     pass
diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py
index 2b585dd2f89..ba1b5767c7d 100644
--- a/tests/test_arrow_dataset.py
+++ b/tests/test_arrow_dataset.py
@@ -9,7 +9,7 @@
 from functools import partial
 from pathlib import Path
 from unittest import TestCase
-from unittest.mock import patch
+from unittest.mock import MagicMock, patch
 
 import numpy as np
 import numpy.testing as npt
@@ -62,6 +62,11 @@
 )
 
 
+class PickableMagicMock(MagicMock):
+    def __reduce__(self):
+        return MagicMock, ()
+
+
 class Unpicklable:
     def __getstate__(self):
         raise pickle.PicklingError()
@@ -105,14 +110,6 @@ def assert_arrow_metadata_are_synced_with_dataset_features(dataset: Dataset):
 
 @parameterized.named_parameters(IN_MEMORY_PARAMETERS)
 class BaseDatasetTest(TestCase):
-    def setUp(self):
-        # google colab doesn't allow to pickle loggers
-        # so we want to make sure each tests passes without pickling the logger
-        def reduce_ex(self):
-            raise pickle.PicklingError()
-
-        datasets.arrow_dataset.logger.__reduce_ex__ = reduce_ex
-
     @pytest.fixture(autouse=True)
     def inject_fixtures(self, caplog):
         self._caplog = caplog
@@ -1255,7 +1252,11 @@ def test_map_caching(self, in_memory):
             self._caplog.clear()
             with self._caplog.at_level(WARNING):
                 with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
-                    with patch("datasets.arrow_dataset.Pool", side_effect=datasets.arrow_dataset.Pool) as mock_pool:
+                    with patch(
+                        "datasets.arrow_dataset.Pool",
+                        new_callable=PickableMagicMock,
+                        side_effect=datasets.arrow_dataset.Pool,
+                    ) as mock_pool:
                         with dset.map(lambda x: {"foo": "bar"}, num_proc=2) as dset_test1:
                             dset_test1_data_files = list(dset_test1.cache_files)
                             self.assertEqual(mock_pool.call_count, 1)
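
Reviewer note (a sketch, not part of the patch): the reason save_code/save_function are overridden at all is that `datasets` fingerprints transforms by hashing their dill dump, so the dump must be deterministic and independent of where a function was defined. Assuming the public `datasets.fingerprint.Hasher` API and a plain Python script, the masking of co_filename/co_firstlineno can be checked like this:

    from datasets.fingerprint import Hasher

    f = lambda x: x + 1
    g = lambda x: x + 1  # same body, defined on a different line

    # The raw code objects still differ in their origin metadata...
    assert f.__code__.co_firstlineno != g.__code__.co_firstlineno
    # ...but the patched pickler forces co_firstlineno to 1 (and masks special
    # filenames), so both get the same fingerprint and `map` caching still hits.
    assert Hasher.hash(f) == Hasher.hash(g)

Likewise, `PickableMagicMock` in the test exists because a plain `MagicMock` cannot be pickled, and the patched `datasets.arrow_dataset.Pool` attribute presumably gets picked up by the pickling that `dset.map(..., num_proc=2)` performs. A minimal standalone check of the trick:

    import pickle
    from unittest.mock import MagicMock

    class PickableMagicMock(MagicMock):
        def __reduce__(self):
            # Round-trip to a fresh, bare MagicMock instead of failing to pickle
            return MagicMock, ()

    clone = pickle.loads(pickle.dumps(PickableMagicMock()))
    assert isinstance(clone, MagicMock)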