|
| 1 | +"""Distributed engine and memory format configuration.""" |
| 2 | + |
| 3 | +# pylint: disable=import-outside-toplevel |
| 4 | + |
| 5 | +import importlib.util |
| 6 | +from collections import defaultdict |
| 7 | +from enum import Enum, unique |
| 8 | +from functools import wraps |
| 9 | +from typing import Any, Callable, Dict, Optional |
| 10 | + |
| 11 | + |
| 12 | +@unique |
| 13 | +class EngineEnum(Enum): |
| 14 | + """Execution engine enum.""" |
| 15 | + |
| 16 | + RAY = "ray" |
| 17 | + PYTHON = "python" |
| 18 | + |
| 19 | + |
| 20 | +@unique |
| 21 | +class MemoryFormatEnum(Enum): |
| 22 | + """Memory format enum.""" |
| 23 | + |
| 24 | + MODIN = "modin" |
| 25 | + PANDAS = "pandas" |
| 26 | + |
| 27 | + |
| 28 | +class Engine: |
| 29 | + """Execution engine configuration class.""" |
| 30 | + |
| 31 | + _enum: Optional[Enum] = None |
| 32 | + _registry: Dict[str, Dict[str, Callable[..., Any]]] = defaultdict(dict) |
| 33 | + |
| 34 | + @classmethod |
| 35 | + def get_installed(cls) -> Enum: |
| 36 | + """Get the installed distribution engine. |
| 37 | +
|
| 38 | + This is the engine that can be imported. |
| 39 | +
|
| 40 | + Returns |
| 41 | + ------- |
| 42 | + EngineEnum |
| 43 | + The distribution engine installed. |
| 44 | + """ |
| 45 | + if importlib.util.find_spec("ray"): |
| 46 | + return EngineEnum.RAY |
| 47 | + return EngineEnum.PYTHON |
| 48 | + |
| 49 | + @classmethod |
| 50 | + def get(cls) -> Enum: |
| 51 | + """Get the configured distribution engine. |
| 52 | +
|
| 53 | + This is the engine currently configured. If None, the installed engine is returned. |
| 54 | +
|
| 55 | + Returns |
| 56 | + ------- |
| 57 | + str |
| 58 | + The distribution engine configured. |
| 59 | + """ |
| 60 | + return cls._enum if cls._enum else cls.get_installed() |
| 61 | + |
| 62 | + @classmethod |
| 63 | + def set(cls, name: str) -> None: |
| 64 | + """Set the distribution engine.""" |
| 65 | + cls._enum = EngineEnum._member_map_[name.upper()] # pylint: disable=protected-access,no-member |
| 66 | + |
| 67 | + @classmethod |
| 68 | + def dispatch_func(cls, source_func: Callable[..., Any], value: Optional[Any] = None) -> Callable[..., Any]: |
| 69 | + """Dispatch a func based on value or the distribution engine and the source function.""" |
| 70 | + try: |
| 71 | + return cls._registry[value or cls.get().value][source_func.__name__] |
| 72 | + except KeyError: |
| 73 | + return getattr(source_func, "_source_func", source_func) |
| 74 | + |
| 75 | + @classmethod |
| 76 | + def register_func(cls, source_func: Callable[..., Any], destination_func: Callable[..., Any]) -> Callable[..., Any]: |
| 77 | + """Register a func based on the distribution engine and source function.""" |
| 78 | + cls._registry[cls.get().value][source_func.__name__] = destination_func |
| 79 | + return destination_func |
| 80 | + |
| 81 | + @classmethod |
| 82 | + def dispatch_on_engine(cls, func: Callable[..., Any]) -> Callable[..., Any]: |
| 83 | + """Dispatch on engine function decorator.""" |
| 84 | + |
| 85 | + @wraps(func) |
| 86 | + def wrapper(*args: Any, **kw: Dict[str, Any]) -> Any: |
| 87 | + return cls.dispatch_func(func)(*args, **kw) |
| 88 | + |
| 89 | + # Save the original function |
| 90 | + wrapper._source_func = func # type: ignore # pylint: disable=protected-access |
| 91 | + return wrapper |
| 92 | + |
| 93 | + @classmethod |
| 94 | + def register(cls, name: Optional[str] = None) -> None: |
| 95 | + """Register the distribution engine dispatch methods.""" |
| 96 | + engine_name = cls.get_installed().value if not name else name |
| 97 | + cls.set(engine_name) |
| 98 | + cls._registry.clear() |
| 99 | + |
| 100 | + if engine_name == EngineEnum.RAY.value: |
| 101 | + from awswrangler.distributed.ray._register import register_ray |
| 102 | + |
| 103 | + register_ray() |
| 104 | + |
| 105 | + @classmethod |
| 106 | + def initialize(cls, name: Optional[str] = None) -> None: |
| 107 | + """Initialize the distribution engine.""" |
| 108 | + engine_name = cls.get_installed().value if not name else name |
| 109 | + if engine_name == EngineEnum.RAY.value: |
| 110 | + from awswrangler.distributed.ray import initialize_ray |
| 111 | + |
| 112 | + initialize_ray() |
| 113 | + cls.register(engine_name) |
| 114 | + |
| 115 | + |
| 116 | +class MemoryFormat: |
| 117 | + """Memory format configuration class.""" |
| 118 | + |
| 119 | + _enum: Optional[Enum] = None |
| 120 | + |
| 121 | + @classmethod |
| 122 | + def get_installed(cls) -> Enum: |
| 123 | + """Get the installed memory format. |
| 124 | +
|
| 125 | + This is the format that can be imported. |
| 126 | +
|
| 127 | + Returns |
| 128 | + ------- |
| 129 | + Enum |
| 130 | + The memory format installed. |
| 131 | + """ |
| 132 | + if importlib.util.find_spec("modin"): |
| 133 | + return MemoryFormatEnum.MODIN |
| 134 | + return MemoryFormatEnum.PANDAS |
| 135 | + |
| 136 | + @classmethod |
| 137 | + def get(cls) -> Enum: |
| 138 | + """Get the configured memory format. |
| 139 | +
|
| 140 | + This is the memory format currently configured. If None, the installed memory format is returned. |
| 141 | +
|
| 142 | + Returns |
| 143 | + ------- |
| 144 | + Enum |
| 145 | + The memory format configured. |
| 146 | + """ |
| 147 | + return cls._enum if cls._enum else cls.get_installed() |
| 148 | + |
| 149 | + @classmethod |
| 150 | + def set(cls, name: str) -> None: |
| 151 | + """Set the memory format.""" |
| 152 | + cls._enum = MemoryFormatEnum._member_map_[name.upper()] # pylint: disable=protected-access,no-member |
| 153 | + |
| 154 | + |
| 155 | +engine: Engine = Engine() |
| 156 | +memory_format: MemoryFormat = MemoryFormat() |
0 commit comments