diff --git a/aiodistbus/cfg.py b/aiodistbus/cfg.py index 1219cfb..a1ddc06 100644 --- a/aiodistbus/cfg.py +++ b/aiodistbus/cfg.py @@ -1,6 +1,18 @@ import json +import pathlib from typing import Callable, Dict, List, Optional, Type, TypeVar, Union +import dataclasses_json.cfg + +dataclasses_json.cfg.global_config.encoders[pathlib.Path] = str +dataclasses_json.cfg.global_config.decoders[ + pathlib.Path +] = pathlib.Path # is this necessary? +dataclasses_json.cfg.global_config.encoders[Optional[pathlib.Path]] = str +dataclasses_json.cfg.global_config.decoders[Optional[pathlib.Path]] = Optional[ + pathlib.Path +] # is this necessary? + T = TypeVar("T") ######################################################################## @@ -34,6 +46,7 @@ def __init__(self): bytes: lambda x: x, Json: lambda x: json.loads(x.decode()), } + self.dtype_map: Dict[str, str] = {} def get_encoder(self, dtype: Union[Type, Optional[Type]]) -> SFunction[Type]: """Get encoder for type @@ -75,6 +88,29 @@ def get_decoder(self, dtype: Union[Type, Optional[Type]]) -> DFunction[Type]: else: raise ValueError(f"Decoder not found for {dtype}") + def set_dtype_mapping(self, k: str, v: str): + """Set dtype mapping + + Args: + dtype (str): dtype + + """ + self.dtype_map[k] = v + + def get_dtype_mapping(self, dtype: str) -> Optional[str]: + """Get dtype mapping + + Args: + dtype (str): dtype + + Returns: + str: dtype + + """ + if dtype in self.dtype_map: + return self.dtype_map[dtype] + return None + global_config = _GlobalConfig() diff --git a/aiodistbus/eventbus/eventbus.py b/aiodistbus/eventbus/eventbus.py index d535bb1..663200b 100644 --- a/aiodistbus/eventbus/eventbus.py +++ b/aiodistbus/eventbus/eventbus.py @@ -154,8 +154,6 @@ async def listen(self, ip: str, port: int, event_types: Optional[List[str]] = No await e.connect(ip, port) for event_type in event_types: - # import pdb; pdb.set_trace() - async def _wrapper(event): await self._emit(event) @@ -168,11 +166,11 @@ async def link( self, ip: str, port: int, - to_event_types: Optional[List[str]] = None, - from_event_types: Optional[List[str]] = None, + send: Optional[List[str]] = None, + recv: Optional[List[str]] = None, ): - await self.listen(ip, port, event_types=from_event_types) - await self.forward(ip, port, event_types=to_event_types) + await self.forward(ip, port, event_types=send) + await self.listen(ip, port, event_types=recv) async def close(self): """Close the eventbus""" diff --git a/aiodistbus/utils.py b/aiodistbus/utils.py index ebbb4a3..ff67605 100644 --- a/aiodistbus/utils.py +++ b/aiodistbus/utils.py @@ -133,6 +133,10 @@ async def reconstruct(event_str: str, dtype: Optional[Type] = None) -> Event: if dtype: event = reconstruct_event_data(event, dtype) elif event.dtype and event.dtype != "builtins.NoneType": + # Handle dtype mapping + dtype_mapping = global_config.get_dtype_mapping(event.dtype) + if dtype_mapping: + event.dtype = dtype_mapping l_dtype = locate(event.dtype) event = reconstruct_event_data(event, l_dtype) # type: ignore diff --git a/aiodistbus/wrapper.py b/aiodistbus/wrapper.py index d95f735..586c40b 100644 --- a/aiodistbus/wrapper.py +++ b/aiodistbus/wrapper.py @@ -104,6 +104,7 @@ def make_evented( """ instance.bus = bus # type: ignore[attr-defined] instance.__evented_values = {} # type: ignore[attr-defined] + instance.__dataclass = f"{instance.__class__.__module__}.{instance.__class__.__name__}" # type: ignore[attr-defined] # Name of the event if not event_name: diff --git a/test/test_bridge.py b/test/test_bridge.py index 71ca939..bb7bbc9 100644 --- a/test/test_bridge.py +++ b/test/test_bridge.py @@ -283,18 +283,18 @@ async def test_local_buses_comms_bidirectional(event_type, func, dtype, dtype_in # Create entrypoint ce = EntryPoint() await ce.connect(cb) - await ce.on(f"client.{event_type}", func, dtype) + await ce.on(f"server.{event_type}", func, dtype) se = EntryPoint() await se.connect(sb) - await se.on(f"server.{event_type}", func, dtype) + await se.on(f"client.{event_type}", func, dtype) # Link - await cb.link(sdbus.ip, sdbus.port, ["server.*"], ["client.*"]) - await sb.link(sdbus.ip, sdbus.port, ["client.*"], ["server.*"]) + await cb.link(sdbus.ip, sdbus.port, ["client.*"], ["server.*"]) + await sb.link(sdbus.ip, sdbus.port, ["server.*"], ["client.*"]) # Send message - cevent = await ce.emit(f"server.{event_type}", dtype_instance) - sevent = await se.emit(f"client.{event_type}", dtype_instance) + cevent = await ce.emit(f"client.{event_type}", dtype_instance) + sevent = await se.emit(f"server.{event_type}", dtype_instance) # Flush await sdbus.flush() diff --git a/test/test_dbus.py b/test/test_dbus.py index 026c3ef..c96c194 100644 --- a/test/test_dbus.py +++ b/test/test_dbus.py @@ -3,7 +3,7 @@ import pytest -from aiodistbus import DEntryPoint, Event +from aiodistbus import DEntryPoint, Event, make_evented from .conftest import ( ExampleEvent, @@ -81,6 +81,29 @@ async def test_dbus_emit(dbus, dentrypoints, event_type, func, dtype, dtype_inst assert event1 and event1.id in e1._received +async def test_dbus_emit_evented_dataclass(bus, dbus, dentrypoints): + + # Create resources + e1, e2 = dentrypoints + + # Add funcs + await e1.on("test", func, ExampleEvent) + + # Connect + await e1.connect(dbus.ip, dbus.port) + await e2.connect(dbus.ip, dbus.port) + + # Send message + instance = make_evented(ExampleEvent("Hello"), bus=bus) + event1 = await e2.emit("test", instance) + + # Need to flush + await dbus.flush() + + # Assert + assert event1 and event1.id in e1._received + + @pytest.mark.parametrize( "event_type, func, dtype_instance", [ diff --git a/test/test_utils.py b/test/test_utils.py new file mode 100644 index 0000000..ec958eb --- /dev/null +++ b/test/test_utils.py @@ -0,0 +1,41 @@ +import pathlib +from dataclasses import dataclass + +import pytest +from dataclasses_json import DataClassJsonMixin + +from aiodistbus.protocols import Event +from aiodistbus.utils import encode, reconstruct + +from .conftest import ExampleEvent + + +@dataclass +class ExampleEventWithPathlib(DataClassJsonMixin): + path: pathlib.Path + + +async def test_encode_dataclass_with_pathlib(): + data = ExampleEventWithPathlib(pathlib.Path(".")) + assert data.to_json() + + +@pytest.mark.parametrize( + "data", [ExampleEvent("Hello"), ExampleEventWithPathlib(pathlib.Path("."))] +) +async def test_encode_decode(data): + + # Encode data + encoded_data = encode(data) + + # Obtain the dtype and format it + dtype = type(data) + dtype_str = f"{dtype.__module__}.{dtype.__name__}" + + event = Event("test", encoded_data, dtype=dtype_str, id="") + ser_event = event.to_json().encode() + + decoded_event = ser_event.decode() + deser_event = await reconstruct(decoded_event) + + assert deser_event.data == data diff --git a/test/test_wrapper.py b/test/test_wrapper.py index 741f047..96828d4 100644 --- a/test/test_wrapper.py +++ b/test/test_wrapper.py @@ -49,6 +49,7 @@ async def test_make_evented(bus, entrypoints): # Create the evented class data = make_evented(SomeClass(number=1, string="hello"), bus=bus) + logger.debug(type(data)) # Trigger an event by changing the class data.number = 2