Skip to content

Commit a826503

Browse files
authored
(refactor): Enable new engines with custom dispatching and other constructs (#1666)
* Major refactoring: - Move distributed methods out of non-distributed modules - Refactor dispatching - Refactor structure of distributed modules - Add classes for execution engine and memory format
1 parent d218d79 commit a826503

39 files changed

+927
-718
lines changed

awswrangler/__init__.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,10 @@
3131
timestream,
3232
)
3333
from awswrangler.__metadata__ import __description__, __license__, __title__, __version__ # noqa
34-
from awswrangler._config import ExecutionEngine, config # noqa
35-
from awswrangler.distributed import initialize_ray
36-
37-
if config.execution_engine == ExecutionEngine.RAY.value:
38-
initialize_ray()
34+
from awswrangler._config import config # noqa
35+
from awswrangler._distributed import engine, memory_format # noqa
3936

37+
engine.initialize()
4038

4139
__all__ = [
4240
"athena",
@@ -60,6 +58,8 @@
6058
"secretsmanager",
6159
"sqlserver",
6260
"config",
61+
"engine",
62+
"memory_format",
6363
"timestream",
6464
"__description__",
6565
"__license__",

awswrangler/_config.py

Lines changed: 0 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
"""Configuration file for AWS SDK for pandas."""
22

3-
import importlib.util
43
import inspect
54
import logging
65
import os
7-
from enum import Enum
86
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Type, Union, cast
97

108
import botocore.config
@@ -18,20 +16,6 @@
1816
_ConfigValueType = Union[str, bool, int, botocore.config.Config, None]
1917

2018

21-
class ExecutionEngine(Enum):
22-
"""Execution engine enum."""
23-
24-
RAY = "ray"
25-
PYTHON = "python"
26-
27-
28-
class MemoryFormat(Enum):
29-
"""Memory format enum."""
30-
31-
MODIN = "modin"
32-
PANDAS = "pandas"
33-
34-
3519
class _ConfigArg(NamedTuple):
3620
dtype: Type[Union[str, bool, int, botocore.config.Config]]
3721
nullable: bool
@@ -69,12 +53,6 @@ class _ConfigArg(NamedTuple):
6953
"botocore_config": _ConfigArg(dtype=botocore.config.Config, nullable=True),
7054
"verify": _ConfigArg(dtype=str, nullable=True, loaded=True),
7155
# Distributed
72-
"execution_engine": _ConfigArg(
73-
dtype=str, nullable=False, loaded=True, default="ray" if importlib.util.find_spec("ray") else "python"
74-
),
75-
"memory_format": _ConfigArg(
76-
dtype=str, nullable=False, loaded=True, default="modin" if importlib.util.find_spec("modin") else "pandas"
77-
),
7856
"address": _ConfigArg(dtype=str, nullable=True),
7957
"redis_password": _ConfigArg(dtype=str, nullable=True),
8058
"ignore_reinit_error": _ConfigArg(dtype=bool, nullable=True),
@@ -440,24 +418,6 @@ def verify(self) -> Optional[str]:
440418
def verify(self, value: Optional[str]) -> None:
441419
self._set_config_value(key="verify", value=value)
442420

443-
@property
444-
def execution_engine(self) -> str:
445-
"""Property execution_engine."""
446-
return cast(str, self["execution_engine"])
447-
448-
@execution_engine.setter
449-
def execution_engine(self, value: str) -> None:
450-
self._set_config_value(key="execution_engine", value=value)
451-
452-
@property
453-
def memory_format(self) -> str:
454-
"""Property memory_format."""
455-
return cast(str, self["memory_format"])
456-
457-
@memory_format.setter
458-
def memory_format(self, value: str) -> None:
459-
self._set_config_value(key="memory_format", value=value)
460-
461421
@property
462422
def address(self) -> Optional[str]:
463423
"""Property address."""

awswrangler/_distributed.py

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
"""Distributed engine and memory format configuration."""
2+
3+
# pylint: disable=import-outside-toplevel
4+
5+
import importlib.util
6+
from collections import defaultdict
7+
from enum import Enum, unique
8+
from functools import wraps
9+
from typing import Any, Callable, Dict, Optional
10+
11+
12+
@unique
class EngineEnum(Enum):
    """Names of the supported execution engines.

    Each member's value is the lower-case engine name. ``@unique`` rejects
    two members sharing the same value.
    """

    RAY = "ray"
    PYTHON = "python"
18+
19+
20+
@unique
class MemoryFormatEnum(Enum):
    """Names of the supported in-memory data formats.

    Each member's value is the lower-case format name. ``@unique`` rejects
    two members sharing the same value.
    """

    MODIN = "modin"
    PANDAS = "pandas"
26+
27+
28+
class Engine:
    """Execution engine configuration class.

    All state lives on the class itself and every method is a classmethod,
    so the configuration is shared by every instance of ``Engine``.
    """

    # Explicitly configured engine; ``None`` means "use the installed engine".
    _enum: Optional[Enum] = None
    # Maps engine name -> {source function name -> replacement function}.
    _registry: Dict[str, Dict[str, Callable[..., Any]]] = defaultdict(dict)

    @classmethod
    def get_installed(cls) -> Enum:
        """Get the installed distribution engine.

        This is the engine that can be imported.

        Returns
        -------
        EngineEnum
            ``EngineEnum.RAY`` if the ``ray`` package can be found,
            otherwise ``EngineEnum.PYTHON``.
        """
        if importlib.util.find_spec("ray"):
            return EngineEnum.RAY
        return EngineEnum.PYTHON

    @classmethod
    def get(cls) -> Enum:
        """Get the configured distribution engine.

        This is the engine currently configured. If None, the installed engine is returned.

        Returns
        -------
        EngineEnum
            The distribution engine configured.
        """
        return cls._enum if cls._enum else cls.get_installed()

    @classmethod
    def set(cls, name: str) -> None:
        """Set the distribution engine.

        Parameters
        ----------
        name : str
            Engine name, matched case-insensitively against the ``EngineEnum`` member names.

        Raises
        ------
        KeyError
            If ``name`` does not match any ``EngineEnum`` member.
        """
        # Public by-name lookup; no need to reach into the private ``_member_map_``.
        cls._enum = EngineEnum[name.upper()]

    @classmethod
    def dispatch_func(cls, source_func: Callable[..., Any], value: Optional[Any] = None) -> Callable[..., Any]:
        """Dispatch a func based on value or the distribution engine and the source function.

        Parameters
        ----------
        source_func : Callable[..., Any]
            Function whose ``__name__`` is looked up in the registry.
        value : Any, optional
            Engine name to look up. When falsy, the configured engine's value is used.

        Returns
        -------
        Callable[..., Any]
            The function registered for that engine under ``source_func.__name__``.
            If none is registered, the original function saved by ``dispatch_on_engine``
            (``_source_func``) when present, else ``source_func`` itself.
        """
        engine_name = value or cls.get().value
        # ``.get`` instead of indexing: indexing the defaultdict would insert an
        # empty entry for every unknown engine name merely by looking it up.
        registered = cls._registry.get(engine_name, {}).get(source_func.__name__)
        if registered is not None:
            return registered
        return getattr(source_func, "_source_func", source_func)

    @classmethod
    def register_func(cls, source_func: Callable[..., Any], destination_func: Callable[..., Any]) -> Callable[..., Any]:
        """Register a func based on the distribution engine and source function.

        ``destination_func`` is stored under the currently configured engine,
        keyed by ``source_func.__name__``.

        Returns
        -------
        Callable[..., Any]
            ``destination_func``, unchanged.
        """
        cls._registry[cls.get().value][source_func.__name__] = destination_func
        return destination_func

    @classmethod
    def dispatch_on_engine(cls, func: Callable[..., Any]) -> Callable[..., Any]:
        """Dispatch on engine function decorator.

        The returned wrapper resolves the implementation through ``dispatch_func``
        on every call, so the registry and engine are consulted at call time.
        """

        @wraps(func)
        def wrapper(*args: Any, **kw: Any) -> Any:
            return cls.dispatch_func(func)(*args, **kw)

        # Save the original function so dispatch_func can fall back to it
        # instead of the wrapper when nothing is registered.
        wrapper._source_func = func  # type: ignore # pylint: disable=protected-access
        return wrapper

    @classmethod
    def register(cls, name: Optional[str] = None) -> None:
        """Register the distribution engine dispatch methods.

        Sets the engine, clears the whole registry and, when the engine is Ray,
        calls ``register_ray()``.

        Parameters
        ----------
        name : str, optional
            Engine name (case-insensitive). Defaults to the installed engine.

        Raises
        ------
        KeyError
            If ``name`` does not match any ``EngineEnum`` member.
        """
        # Lower-cased so the comparison below agrees with ``set``, which is case-insensitive.
        engine_name = name.lower() if name else cls.get_installed().value
        cls.set(engine_name)
        cls._registry.clear()

        if engine_name == EngineEnum.RAY.value:
            from awswrangler.distributed.ray._register import register_ray

            register_ray()

    @classmethod
    def initialize(cls, name: Optional[str] = None) -> None:
        """Initialize the distribution engine.

        When the engine is Ray, ``initialize_ray()`` is called first; the engine's
        dispatch methods are then registered through ``register``.

        Parameters
        ----------
        name : str, optional
            Engine name (case-insensitive). Defaults to the installed engine.
        """
        # Lower-cased so the comparison below agrees with ``set``, which is case-insensitive.
        engine_name = name.lower() if name else cls.get_installed().value
        if engine_name == EngineEnum.RAY.value:
            from awswrangler.distributed.ray import initialize_ray

            initialize_ray()
        cls.register(engine_name)
114+
115+
116+
class MemoryFormat:
    """Memory format configuration class.

    All state lives on the class itself and every method is a classmethod,
    so the configuration is shared by every instance of ``MemoryFormat``.
    """

    # Explicitly configured format; ``None`` means "use the installed format".
    _enum: Optional[Enum] = None

    @classmethod
    def get_installed(cls) -> Enum:
        """Get the installed memory format.

        This is the format that can be imported.

        Returns
        -------
        Enum
            ``MemoryFormatEnum.MODIN`` if the ``modin`` package can be found,
            otherwise ``MemoryFormatEnum.PANDAS``.
        """
        if importlib.util.find_spec("modin"):
            return MemoryFormatEnum.MODIN
        return MemoryFormatEnum.PANDAS

    @classmethod
    def get(cls) -> Enum:
        """Get the configured memory format.

        This is the memory format currently configured. If None, the installed memory format is returned.

        Returns
        -------
        Enum
            The memory format configured.
        """
        return cls._enum if cls._enum else cls.get_installed()

    @classmethod
    def set(cls, name: str) -> None:
        """Set the memory format.

        Parameters
        ----------
        name : str
            Format name, matched case-insensitively against the ``MemoryFormatEnum`` member names.

        Raises
        ------
        KeyError
            If ``name`` does not match any ``MemoryFormatEnum`` member.
        """
        # Public by-name lookup; no need to reach into the private ``_member_map_``.
        cls._enum = MemoryFormatEnum[name.upper()]
153+
154+
155+
engine: Engine = Engine()
156+
memory_format: MemoryFormat = MemoryFormat()

awswrangler/_threading.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,27 +3,16 @@
33
import concurrent.futures
44
import itertools
55
import logging
6-
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union
6+
from typing import Any, Callable, List, Optional, Union
77

88
import boto3
99

1010
from awswrangler import _utils
11-
from awswrangler._config import ExecutionEngine, config
12-
13-
if config.execution_engine == ExecutionEngine.RAY.value or TYPE_CHECKING:
14-
from awswrangler.distributed.ray._pool import _RayPoolExecutor
11+
from awswrangler._distributed import EngineEnum, engine
1512

1613
_logger: logging.Logger = logging.getLogger(__name__)
1714

1815

19-
def _get_executor(use_threads: Union[bool, int]) -> Union["_ThreadPoolExecutor", "_RayPoolExecutor"]:
20-
return (
21-
_RayPoolExecutor()
22-
if config.execution_engine == ExecutionEngine.RAY.value
23-
else _ThreadPoolExecutor(use_threads) # type: ignore
24-
)
25-
26-
2716
class _ThreadPoolExecutor:
2817
def __init__(self, use_threads: Union[bool, int]):
2918
super().__init__()
@@ -42,3 +31,11 @@ def map(self, func: Callable[..., Any], boto3_session: boto3.Session, *iterables
4231
return list(self._exec.map(func, *args))
4332
# Single-threaded
4433
return list(map(func, *(itertools.repeat(boto3_session), *iterables))) # type: ignore
34+
35+
36+
def _get_executor(use_threads: Union[bool, int]) -> _ThreadPoolExecutor:
    """Return the executor for the currently configured engine.

    ``use_threads`` is forwarded to ``_ThreadPoolExecutor``; it is not passed
    on when the Ray executor is selected.
    """
    if engine.get() != EngineEnum.RAY:
        return _ThreadPoolExecutor(use_threads)

    # Imported only on the Ray path.
    from awswrangler.distributed.ray._pool import _RayPoolExecutor  # pylint: disable=import-outside-toplevel

    return _RayPoolExecutor()  # type: ignore

awswrangler/_utils.py

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import random
99
import time
1010
from concurrent.futures import FIRST_COMPLETED, Future, wait
11-
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Optional, Sequence, Tuple, Union, cast
11+
from typing import Any, Callable, Dict, Generator, List, Optional, Sequence, Tuple, Union, cast
1212

1313
import boto3
1414
import botocore.config
@@ -19,13 +19,8 @@
1919
from awswrangler import _config, exceptions
2020
from awswrangler.__metadata__ import __version__
2121
from awswrangler._arrow import _table_to_df
22-
from awswrangler._config import ExecutionEngine, MemoryFormat, apply_configs, config
23-
24-
if config.execution_engine == ExecutionEngine.RAY.value or TYPE_CHECKING:
25-
import ray # pylint: disable=unused-import
26-
27-
if config.memory_format == MemoryFormat.MODIN.value:
28-
from awswrangler.distributed.ray._utils import _arrow_refs_to_df # pylint: disable=ungrouped-imports
22+
from awswrangler._config import apply_configs
23+
from awswrangler._distributed import engine
2924

3025
_logger: logging.Logger = logging.getLogger(__name__)
3126

@@ -413,13 +408,10 @@ def check_schema_changes(columns_types: Dict[str, str], table_input: Optional[Di
413408
)
414409

415410

416-
def table_refs_to_df(
417-
tables: Union[List[pa.Table], List["ray.ObjectRef"]], kwargs: Dict[str, Any] # type: ignore
418-
) -> pd.DataFrame:
411+
@engine.dispatch_on_engine
412+
def table_refs_to_df(tables: List[pa.Table], kwargs: Dict[str, Any]) -> pd.DataFrame: # type: ignore
419413
"""Build Pandas dataframe from list of PyArrow tables."""
420-
if isinstance(tables[0], pa.Table):
421-
return _table_to_df(pa.concat_tables(tables, promote=True), kwargs=kwargs)
422-
return _arrow_refs_to_df(arrow_refs=tables, kwargs=kwargs) # type: ignore
414+
return _table_to_df(pa.concat_tables(tables, promote=True), kwargs=kwargs)
423415

424416

425417
def list_to_arrow_table(
Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1 @@
11
"""Distributed Module."""
2-
3-
from awswrangler.distributed._distributed import ( # noqa
4-
RayLogger,
5-
initialize_ray,
6-
modin_repartition,
7-
ray_get,
8-
ray_remote,
9-
)
10-
11-
__all__ = [
12-
"RayLogger",
13-
"initialize_ray",
14-
"modin_repartition",
15-
"ray_get",
16-
"ray_remote",
17-
]
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,10 @@
11
"""Ray Module."""
2+
3+
from awswrangler.distributed.ray._core import RayLogger, initialize_ray, ray_get, ray_remote # noqa
4+
5+
__all__ = [
6+
"RayLogger",
7+
"initialize_ray",
8+
"ray_get",
9+
"ray_remote",
10+
]

0 commit comments

Comments
 (0)