diff --git a/metaflow/cli.py b/metaflow/cli.py index a4a1558304d..083f3ff744c 100644 --- a/metaflow/cli.py +++ b/metaflow/cli.py @@ -15,7 +15,7 @@ from .datastore import FlowDataStore, TaskDataStoreSet from .debug import debug from .exception import CommandException, MetaflowException -from .flowspec import _FlowState +from .flowspec import FlowStateItems from .graph import FlowGraph from .metaflow_config import ( DEFAULT_DATASTORE, @@ -458,8 +458,8 @@ def start( # We can now set the the CONFIGS value in the flow properly. This will overwrite # anything that may have been passed in by default and we will use exactly what # the original flow had. Note that these are accessed through the parameter name - ctx.obj.flow._flow_state[_FlowState.CONFIGS].clear() - d = ctx.obj.flow._flow_state[_FlowState.CONFIGS] + ctx.obj.flow._flow_state[FlowStateItems.CONFIGS].clear() + d = ctx.obj.flow._flow_state[FlowStateItems.CONFIGS] for param_name, var_name in zip(config_param_names, config_var_names): val = param_ds[var_name] debug.userconf_exec("Loaded config %s as: %s" % (param_name, val)) @@ -471,7 +471,7 @@ def start( raise ctx.obj.delayed_config_exception # Init all values in the flow mutators and then process them - for decorator in ctx.obj.flow._flow_state.get(_FlowState.FLOW_MUTATORS, []): + for decorator in ctx.obj.flow._flow_state[FlowStateItems.FLOW_MUTATORS]: decorator.external_init() new_cls = ctx.obj.flow._process_config_decorators(config_options) @@ -593,9 +593,9 @@ def start( ctx.obj.flow_datastore, { k: ConfigValue(v) if v is not None else None - for k, v in ctx.obj.flow.__class__._flow_state.get( - _FlowState.CONFIGS, {} - ).items() + for k, v in ctx.obj.flow.__class__._flow_state[ + FlowStateItems.CONFIGS + ].items() }, ) diff --git a/metaflow/decorators.py b/metaflow/decorators.py index 760508497f0..83ad043f9f8 100644 --- a/metaflow/decorators.py +++ b/metaflow/decorators.py @@ -5,7 +5,7 @@ from functools import partial from typing import Any, Callable, Dict, 
def flow_decorators(flow_cls):
    """Return all flow-level decorators of *flow_cls* as a single flat list.

    The flow state stores flow decorators as ``{name: [deco, ...]}``; the
    per-name lists are concatenated in the iteration order of that mapping.
    """
    decos_by_name = flow_cls._flow_state[FlowStateItems.FLOW_DECORATORS]
    flattened = []
    for same_name_decos in decos_by_name.values():
        flattened.extend(same_name_decos)
    return flattened
are stored as `{key:[deco]}` we iterate through each of them. - for decorators in flow._flow_decorators.values(): + flow_decos = flow._flow_state[FlowStateItems.FLOW_DECORATORS] + for decorators in flow_decos.values(): # First resolve the `options` for the flow decorator. # Options are passed from cli. # For example `@project` can take a `--name` / `--branch` from the cli as options. @@ -789,7 +803,7 @@ def _init_step_decorators( # and then the step level ones to maintain a consistent order with how # other decorators are run. - for deco in cls._flow_state.get(_FlowState.FLOW_MUTATORS, []): + for deco in cls._flow_state[FlowStateItems.FLOW_MUTATORS]: if isinstance(deco, FlowMutator): inserted_by_value = [deco.decorator_name] + (deco.inserted_by or []) mutable_flow = MutableFlow( @@ -811,8 +825,8 @@ def _init_step_decorators( deco.mutate(mutable_flow) # We reset cached_parameters on the very off chance that the user added # more configurations based on the configuration - if _FlowState.CACHED_PARAMETERS in cls._flow_state: - del cls._flow_state[_FlowState.CACHED_PARAMETERS] + if cls._flow_state[FlowStateItems.CACHED_PARAMETERS] is not None: + cls._flow_state[FlowStateItems.CACHED_PARAMETERS] = None else: raise MetaflowInternalError( "A non FlowMutator found in flow custom decorators" diff --git a/metaflow/flowspec.py b/metaflow/flowspec.py index a8df867e644..e068f9b03b5 100644 --- a/metaflow/flowspec.py +++ b/metaflow/flowspec.py @@ -4,10 +4,11 @@ import traceback import reprlib +from collections.abc import MutableMapping from enum import Enum from itertools import islice from types import FunctionType, MethodType -from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple +from typing import Any, Callable, List, Optional, Tuple from . 
import cmd_with_io, parameters from .debug import debug @@ -76,14 +77,84 @@ def __getitem__(self, item): return item or 0 # item is None for the control task, but it is also split 0 -class _FlowState(Enum): +class FlowStateItems(Enum): CONFIGS = 1 FLOW_MUTATORS = 2 CACHED_PARAMETERS = 3 - SET_CONFIG_PARAMETERS = ( - 4 # These are Parameters that now have a ConfigValue (converted) - ) - # but we need to remember them. + SET_CONFIG_PARAMETERS = 4 # Parameters that now have a ConfigValue (converted) + FLOW_DECORATORS = 5 + + +class _FlowState(MutableMapping): + # Dict like structure to hold state information about the flow but it holds + # the key/values in two sub dictionaries: the ones that are specific to the flow + # and the ones that are inherited from parent classes. + # This is NOT a general purpose class and is meant to only work with FlowSpec. + # For example, it assumes that items are only list, dicts or None and assumes that + # self._self_data has all keys properly initialized. + + def __init__(self, *args, **kwargs): + self._self_data = dict(*args, **kwargs) + self._merged_data = {} + self._inherited = {} + + def __getitem__(self, key): + # ORDER IS IMPORTANT: we use inherited first and extend by whatever is in + # the flowspec + if key in self._merged_data: + return self._merged_data[key] + + # We haven't accessed this yet so compute it for the first time + self_value = self._self_data.get(key) + inherited_value = self._inherited.get(key) + + if self_value is not None: + self._merged_data[key] = self._merge_value(inherited_value, self_value) + return self._merged_data[key] + elif key in self._self_data: + # Case of CACHED_PARAMETERS; a valid value is None. 
It is never inherited + self._merged_data[key] = None + return None + raise KeyError(key) + + def __setitem__(self, key, value): + self._self_data[key] = value + + def __delitem__(self, key): + if key in self._merged_data: + del self._merged_data[key] + else: + raise KeyError(key) + + def __iter__(self): + # All keys are in self._self_data + for key in self._self_data: + yield self[key] + + def __len__(self): + return len(self._self_data) + + @property + def self_data(self): + self._merged_data.clear() + return self._self_data + + @property + def inherited_data(self): + return self._inherited + + def _merge_value(self, inherited_value, self_value): + if self_value is None: + return None + inherited_value = inherited_value or type(self_value)() + if isinstance(self_value, dict): + return {**inherited_value, **self_value} + elif isinstance(self_value, list): + return inherited_value + self_value + raise RuntimeError( + f"Cannot merge values of type {type(inherited_value)} and {type(self_value)} -- " + "please report this as a bug" + ) class FlowSpecMeta(type): @@ -104,12 +175,16 @@ def _init_attrs(cls): # Runner/NBRunner. This is also created here in the meta class to avoid it being # shared between different children classes. - # We should move _flow_decorators into this structure as well but keeping it - # out to limit the changes for now. 
- cls._flow_decorators = {} - - # Keys are _FlowState enum values - cls._flow_state = {} + # Keys are FlowStateItems enum values + cls._flow_state = _FlowState( + { + FlowStateItems.CONFIGS: {}, + FlowStateItems.FLOW_MUTATORS: [], + FlowStateItems.CACHED_PARAMETERS: None, + FlowStateItems.SET_CONFIG_PARAMETERS: [], + FlowStateItems.FLOW_DECORATORS: {}, + } + ) # Keep track if configs have been processed -- this is particularly applicable # for the Runner/Deployer where calling multiple APIs on the same flow could @@ -119,7 +194,7 @@ def _init_attrs(cls): # We inherit stuff from our parent classes as well -- we need to be careful # in terms of the order; we will follow the MRO with the following rules: - # - decorators (cls._flow_decorators) will cause an error if they do not + # - decorators will cause an error if they do not # support multiple and we see multiple instances of the same # - config decorators will be joined # - configs will be added later directly by the class; base class configs will @@ -127,25 +202,51 @@ def _init_attrs(cls): # We only need to do this for the base classes since the current class will # get updated as decorators are parsed. + + # We also need to be sure to not duplicate things. Consider something like + # class A(FlowSpec): + # pass + # + # class B(A): + # pass + # + # class C(B): + # pass + # + # C inherits from both B and A but we need to duplicate things from A only + # ONCE. To do this, we only propagate the self data from each class. 
+ for base in cls.__mro__: if base != cls and base != FlowSpec and issubclass(base, FlowSpec): # Take care of decorators - for deco_name, deco in base._flow_decorators.items(): - if deco_name in cls._flow_decorators and not deco.allow_multiple: - raise DuplicateFlowDecoratorException(deco_name) - cls._flow_decorators.setdefault(deco_name, []).extend(deco) + base_flow_decorators = base._flow_state.self_data[ + FlowStateItems.FLOW_DECORATORS + ] - # Take care of configs and flow mutators - base_configs = base._flow_state.get(_FlowState.CONFIGS) - if base_configs: - cls._flow_state.setdefault(_FlowState.CONFIGS, {}).update( - base_configs + inherited_cls_flow_decorators = ( + cls._flow_state.inherited_data.setdefault( + FlowStateItems.FLOW_DECORATORS, {} ) - base_mutators = base._flow_state.get(_FlowState.FLOW_MUTATORS) + ) + for deco_name, deco in base_flow_decorators.items(): + if not deco: + continue + deco_allow_multiple = deco[0].allow_multiple + if ( + deco_name in inherited_cls_flow_decorators + and not deco_allow_multiple + ): + raise DuplicateFlowDecoratorException(deco_name) + inherited_cls_flow_decorators.setdefault(deco_name, []).extend(deco) + + # Take care of flow mutators -- configs are just objects in the class + # so they are naturally inherited. We do not need to do anything special + # for them. 
@property
def _flow_decorators(self):
    # Backward compatible accessor: flow decorators are now stored in the flow
    # state under FlowStateItems.FLOW_DECORATORS instead of in a dedicated
    # attribute.
    # NOTE(review): being a property, this only resolves on instances;
    # `SomeFlow._flow_decorators` on the class itself evaluates to the property
    # object -- confirm no caller still reads it at class level.
    flow_state = self._flow_state
    return flow_state[FlowStateItems.FLOW_DECORATORS]
- for deco in cls._flow_state.get(_FlowState.FLOW_MUTATORS, []): + for deco in cls._flow_state[FlowStateItems.FLOW_MUTATORS]: if isinstance(deco, FlowMutator): inserted_by_value = [deco.decorator_name] + (deco.inserted_by or []) mutable_flow = MutableFlow( @@ -313,8 +418,8 @@ def _process_config_decorators(cls, config_options, process_configs=True): deco.pre_mutate(mutable_flow) # We reset cached_parameters on the very off chance that the user added # more configurations based on the configuration - if _FlowState.CACHED_PARAMETERS in cls._flow_state: - del cls._flow_state[_FlowState.CACHED_PARAMETERS] + if cls._flow_state[FlowStateItems.CACHED_PARAMETERS] is not None: + cls._flow_state[FlowStateItems.CACHED_PARAMETERS] = None else: raise MetaflowInternalError( "A non FlowMutator found in flow custom decorators" @@ -429,7 +534,7 @@ def _set_constants(self, graph, kwargs, config_options): "statically_defined": deco.statically_defined, "inserted_by": deco.inserted_by, } - for deco in self._flow_state.get(_FlowState.FLOW_MUTATORS, []) + for deco in self._flow_state[FlowStateItems.FLOW_MUTATORS] ], "extensions": extension_info(), } @@ -437,10 +542,10 @@ def _set_constants(self, graph, kwargs, config_options): @classmethod def _get_parameters(cls): - cached = cls._flow_state.get(_FlowState.CACHED_PARAMETERS) + cached = cls._flow_state[FlowStateItems.CACHED_PARAMETERS] returned = set() if cached is not None: - for set_config in cls._flow_state.get(_FlowState.SET_CONFIG_PARAMETERS, []): + for set_config in cls._flow_state[FlowStateItems.SET_CONFIG_PARAMETERS]: returned.add(set_config[0]) yield set_config[0], set_config[1] for var in cached: @@ -448,7 +553,7 @@ def _get_parameters(cls): yield var, getattr(cls, var) return build_list = [] - for set_config in cls._flow_state.get(_FlowState.SET_CONFIG_PARAMETERS, []): + for set_config in cls._flow_state[FlowStateItems.SET_CONFIG_PARAMETERS]: returned.add(set_config[0]) yield set_config[0], set_config[1] for var in dir(cls): @@ 
-461,7 +566,7 @@ def _get_parameters(cls): if isinstance(val, Parameter) and var not in returned: build_list.append(var) yield var, val - cls._flow_state[_FlowState.CACHED_PARAMETERS] = build_list + cls._flow_state[FlowStateItems.CACHED_PARAMETERS] = build_list def _set_datastore(self, datastore): self._datastore = datastore diff --git a/metaflow/runner/click_api.py b/metaflow/runner/click_api.py index bda7ddb0157..a0cbfc30c69 100644 --- a/metaflow/runner/click_api.py +++ b/metaflow/runner/click_api.py @@ -43,7 +43,7 @@ ) from metaflow.decorators import add_decorator_options from metaflow.exception import MetaflowException -from metaflow.flowspec import _FlowState +from metaflow.flowspec import FlowStateItems from metaflow.includefile import FilePathClass from metaflow.metaflow_config import CLICK_API_PROCESS_CONFIG from metaflow.parameters import JSONTypeClass, flow_context @@ -532,7 +532,7 @@ def _compute_flow_parameters(self): # We ignore any errors if we don't check the configs in the click API. 
# Init all values in the flow mutators and then process them - for decorator in self._flow_cls._flow_state.get(_FlowState.FLOW_MUTATORS, []): + for decorator in self._flow_cls._flow_state[FlowStateItems.FLOW_MUTATORS]: decorator.external_init() new_cls = self._flow_cls._process_config_decorators( diff --git a/metaflow/runtime.py b/metaflow/runtime.py index 3dfb01f529d..4c6d1417b4b 100644 --- a/metaflow/runtime.py +++ b/metaflow/runtime.py @@ -46,7 +46,7 @@ from .datastore import FlowDataStore, TaskDataStoreSet from .debug import debug from .decorators import flow_decorators -from .flowspec import _FlowState +from .flowspec import FlowStateItems from .mflog import mflog, RUNTIME_LOG_SOURCE from .util import to_unicode, compress_list, unicode_type, get_latest_task_pathspec from .clone_util import clone_task_helper @@ -940,7 +940,8 @@ def execute(self): def _run_exit_hooks(self): try: - exit_hook_decos = self._flow._flow_decorators.get("exit_hook", []) + flow_decos = self._flow._flow_state[FlowStateItems.FLOW_DECORATORS] + exit_hook_decos = flow_decos.get("exit_hook", []) if not exit_hook_decos: return @@ -2077,7 +2078,7 @@ def __init__( # We also pass configuration options using the kv. 
syntax which will cause # the configuration options to be loaded from the CONFIG file (or local-config-file # in the case of the local runtime) - configs = self.task.flow._flow_state.get(_FlowState.CONFIGS) + configs = self.task.flow._flow_state[FlowStateItems.CONFIGS] if configs: self.top_level_options["config-value"] = [ (k, ConfigInput.make_key_name(k)) for k in configs diff --git a/metaflow/user_configs/config_options.py b/metaflow/user_configs/config_options.py index f77f50dfc97..3d996af1c71 100644 --- a/metaflow/user_configs/config_options.py +++ b/metaflow/user_configs/config_options.py @@ -186,7 +186,7 @@ def process_configs( click_obj: Optional[Any] = None, ): from ..cli import echo_always, echo_dev_null # Prevent circular import - from ..flowspec import _FlowState # Prevent circular import + from ..flowspec import FlowStateItems # Prevent circular import flow_cls = getattr(current_flow, "flow_cls", None) if flow_cls is None: @@ -260,7 +260,6 @@ def process_configs( for k in all_keys ) - flow_cls._flow_state[_FlowState.CONFIGS] = {} to_return = {} if not has_all_kv: @@ -332,7 +331,7 @@ def process_configs( if val is None: missing_configs.add(name) to_return[name] = None - flow_cls._flow_state[_FlowState.CONFIGS][name] = None + flow_cls._flow_state.self_data[FlowStateItems.CONFIGS][name] = None continue if val.startswith(_CONVERTED_NO_FILE): no_file.append(name) @@ -356,7 +355,9 @@ def process_configs( click_obj.delayed_config_exception = exc return None raise exc from e - flow_cls._flow_state[_FlowState.CONFIGS][name] = read_value + flow_cls._flow_state.self_data[FlowStateItems.CONFIGS][ + name + ] = read_value to_return[name] = ( ConfigValue(read_value) if read_value is not None else None ) @@ -373,7 +374,9 @@ def process_configs( ) continue # TODO: Support YAML - flow_cls._flow_state[_FlowState.CONFIGS][name] = read_value + flow_cls._flow_state.self_data[FlowStateItems.CONFIGS][ + name + ] = read_value to_return[name] = ( ConfigValue(read_value) if 
def dump_config_values(flow: "FlowSpec"):
    """Return the flow's configs wrapped as ``{"user_configs": ...}``.

    An empty dict is returned when the flow state holds no configs.
    """
    from ..flowspec import FlowStateItems  # Prevent circular import

    user_configs = flow._flow_state[FlowStateItems.CONFIGS]
    return {"user_configs": user_configs} if user_configs else {}
""" - for decos in self._flow_cls._flow_decorators.values(): + from metaflow.flowspec import FlowStateItems + + flow_decos = self._flow_cls._flow_state[FlowStateItems.FLOW_DECORATORS] + for decos in flow_decos.values(): for deco in decos: # 3.7 does not support yield foo, *bar syntax so we # work around @@ -108,12 +111,10 @@ def start(self): Tuple[str, ConfigValue] Iterates over the configurations of the flow """ - from metaflow.flowspec import _FlowState + from metaflow.flowspec import FlowStateItems - # When configs are parsed, they are loaded in _flow_state[_FlowState.CONFIGS] - for name, value in self._flow_cls._flow_state.get( - _FlowState.CONFIGS, {} - ).items(): + # When configs are parsed, they are loaded in _flow_state[FlowStateItems.CONFIGS] + for name, value in self._flow_cls._flow_state[FlowStateItems.CONFIGS].items(): r = name, ConfigValue(value) if value is not None else None debug.userconf_exec("Mutable flow yielding config: %s" % str(r)) yield r @@ -228,7 +229,7 @@ def remove_parameter(self, parameter_name: str) -> bool: "method and not the `mutate` method" % (parameter_name, " from ".join(self._inserted_by)) ) - from metaflow.flowspec import _FlowState + from metaflow.flowspec import FlowStateItems for var, param in self._flow_cls._get_parameters(): if param.IS_CONFIG_PARAMETER: @@ -239,7 +240,7 @@ def remove_parameter(self, parameter_name: str) -> bool: "Mutable flow removing parameter %s from flow" % var ) # Reset so that we don't list it again - self._flow_cls._flow_state.pop(_FlowState.CACHED_PARAMETERS, None) + self._flow_cls._flow_state.pop(FlowStateItems.CACHED_PARAMETERS, None) return True debug.userconf_exec( "Mutable flow failed to remove parameter %s from flow" % parameter_name @@ -324,11 +325,17 @@ def add_decorator( FlowDecorator, extract_flow_decorator_from_decospec, ) + from metaflow.flowspec import FlowStateItems deco_args = deco_args or [] deco_kwargs = deco_kwargs or {} def _add_flow_decorator(flow_deco): + # NOTE: Here we operate 
not on self_data or inherited_data because mutators + # are processed on the end flow anyways (they can come from any of the base + # flow classes but they only execute on the flow actually being run). This makes + # it easier particularly for the case of OVERRIDE where we need to override + # a decorator that could be in either of the inherited or self dictionaries. if deco_args: raise MetaflowException( "Flow decorators do not take additional positional arguments" @@ -340,18 +347,17 @@ def _add_flow_decorator(flow_deco): def _do_add(): flow_deco.statically_defined = self._statically_defined flow_deco.inserted_by = self._inserted_by - self._flow_cls._flow_decorators.setdefault(flow_deco.name, []).append( - flow_deco - ) + flow_decos = self._flow_cls._flow_state[FlowStateItems.FLOW_DECORATORS] + + flow_decos.setdefault(flow_deco.name, []).append(flow_deco) debug.userconf_exec( "Mutable flow adding flow decorator '%s'" % deco_type ) - # self._flow_cls._flow_decorators is a dictionary of form : + # self._flow_cls._flow_state[FlowStateItems.FLOW_DECORATORS] is a dictionary of form : # : [deco_instance, deco_instance, ...] 
- existing_deco = [ - d for d in self._flow_cls._flow_decorators if d == flow_deco.name - ] + flow_decos = self._flow_cls._flow_state[FlowStateItems.FLOW_DECORATORS] + existing_deco = [d for d in flow_decos if d == flow_deco.name] if flow_deco.allow_multiple or not existing_deco: _do_add() @@ -367,10 +373,9 @@ def _do_add(): "Mutable flow overriding flow decorator '%s' " "(removing existing decorator and adding new one)" % flow_deco.name ) - self._flow_cls._flow_decorators = { - d: self._flow_cls._flow_decorators[d] - for d in self._flow_cls._flow_decorators - if d != flow_deco.name + flow_decos = self._flow_cls._flow_state[FlowStateItems.FLOW_DECORATORS] + self._flow_cls._flow_state[FlowStateItems.FLOW_DECORATORS] = { + d: flow_decos[d] for d in flow_decos if d != flow_deco.name } _do_add() elif duplicates == MutableFlow.ERROR: @@ -443,6 +448,9 @@ def remove_decorator( Returns True if a decorator was removed. """ + # Prevent circular import + from metaflow.flowspec import FlowStateItems + if not self._pre_mutate: raise MetaflowException( "Removing flow-decorator '%s' from %s is only allowed in the `pre_mutate` " @@ -451,10 +459,12 @@ def remove_decorator( do_all = deco_args is None and deco_kwargs is None did_remove = False - if do_all and deco_name in self._flow_cls._flow_decorators: - del self._flow_cls._flow_decorators[deco_name] + flow_decos = self._flow_cls._flow_state[FlowStateItems.FLOW_DECORATORS] + + if do_all and deco_name in flow_decos: + del flow_decos[deco_name] return True - old_deco_list = self._flow_cls._flow_decorators.get(deco_name) + old_deco_list = flow_decos.get(deco_name) if not old_deco_list: debug.userconf_exec( "Mutable flow failed to remove decorator '%s' from flow (non present)" @@ -471,10 +481,11 @@ def remove_decorator( "Mutable flow removed %d decorators from flow" % (len(old_deco_list) - len(new_deco_list)) ) + if new_deco_list: - self._flow_cls._flow_decorators[deco_name] = new_deco_list + flow_decos[deco_name] = new_deco_list else: 
def _set_flow_cls(
    self, flow_spec: "metaflow.flowspec.FlowSpecMeta"
) -> "metaflow.flowspec.FlowSpecMeta":
    """Attach this mutator to *flow_spec* and return *flow_spec*.

    The mutator is appended to the flow's own (non-inherited) list of flow
    mutators and remembers the flow class in ``self._flow_cls``.
    """
    from ..flowspec import FlowStateItems

    # Go through `self_data` so the entry lands in the class' own layer of the
    # flow state.
    own_state = flow_spec._flow_state.self_data
    own_state[FlowStateItems.FLOW_MUTATORS].append(self)
    self._flow_cls = flow_spec
    return flow_spec