Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions aiodistbus/cfg.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,18 @@
import json
import pathlib
from typing import Callable, Dict, List, Optional, Type, TypeVar, Union

import dataclasses_json.cfg

dataclasses_json.cfg.global_config.encoders[pathlib.Path] = str
dataclasses_json.cfg.global_config.decoders[
pathlib.Path
] = pathlib.Path # is this necessary?
dataclasses_json.cfg.global_config.encoders[Optional[pathlib.Path]] = str
dataclasses_json.cfg.global_config.decoders[Optional[pathlib.Path]] = Optional[
pathlib.Path
] # is this necessary?

T = TypeVar("T")

########################################################################
Expand Down Expand Up @@ -34,6 +46,7 @@ def __init__(self):
bytes: lambda x: x,
Json: lambda x: json.loads(x.decode()),
}
self.dtype_map: Dict[str, str] = {}

def get_encoder(self, dtype: Union[Type, Optional[Type]]) -> SFunction[Type]:
"""Get encoder for type
Expand Down Expand Up @@ -75,6 +88,29 @@ def get_decoder(self, dtype: Union[Type, Optional[Type]]) -> DFunction[Type]:
else:
raise ValueError(f"Decoder not found for {dtype}")

def set_dtype_mapping(self, k: str, v: str):
"""Set dtype mapping

Args:
dtype (str): dtype

"""
self.dtype_map[k] = v

def get_dtype_mapping(self, dtype: str) -> Optional[str]:
"""Get dtype mapping

Args:
dtype (str): dtype

Returns:
str: dtype

"""
if dtype in self.dtype_map:
return self.dtype_map[dtype]
return None


global_config = _GlobalConfig()

Expand Down
10 changes: 4 additions & 6 deletions aiodistbus/eventbus/eventbus.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,6 @@ async def listen(self, ip: str, port: int, event_types: Optional[List[str]] = No
await e.connect(ip, port)
for event_type in event_types:

# import pdb; pdb.set_trace()

async def _wrapper(event):
await self._emit(event)

Expand All @@ -168,11 +166,11 @@ async def link(
self,
ip: str,
port: int,
to_event_types: Optional[List[str]] = None,
from_event_types: Optional[List[str]] = None,
send: Optional[List[str]] = None,
recv: Optional[List[str]] = None,
):
await self.listen(ip, port, event_types=from_event_types)
await self.forward(ip, port, event_types=to_event_types)
await self.forward(ip, port, event_types=send)
await self.listen(ip, port, event_types=recv)

async def close(self):
"""Close the eventbus"""
Expand Down
4 changes: 4 additions & 0 deletions aiodistbus/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,10 @@ async def reconstruct(event_str: str, dtype: Optional[Type] = None) -> Event:
if dtype:
event = reconstruct_event_data(event, dtype)
elif event.dtype and event.dtype != "builtins.NoneType":
# Handle dtype mapping
dtype_mapping = global_config.get_dtype_mapping(event.dtype)
if dtype_mapping:
event.dtype = dtype_mapping
l_dtype = locate(event.dtype)
event = reconstruct_event_data(event, l_dtype) # type: ignore

Expand Down
1 change: 1 addition & 0 deletions aiodistbus/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def make_evented(
"""
instance.bus = bus # type: ignore[attr-defined]
instance.__evented_values = {} # type: ignore[attr-defined]
instance.__dataclass = f"{instance.__class__.__module__}.{instance.__class__.__name__}" # type: ignore[attr-defined]

# Name of the event
if not event_name:
Expand Down
12 changes: 6 additions & 6 deletions test/test_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,18 +283,18 @@ async def test_local_buses_comms_bidirectional(event_type, func, dtype, dtype_in
# Create entrypoint
ce = EntryPoint()
await ce.connect(cb)
await ce.on(f"client.{event_type}", func, dtype)
await ce.on(f"server.{event_type}", func, dtype)
se = EntryPoint()
await se.connect(sb)
await se.on(f"server.{event_type}", func, dtype)
await se.on(f"client.{event_type}", func, dtype)

# Link
await cb.link(sdbus.ip, sdbus.port, ["server.*"], ["client.*"])
await sb.link(sdbus.ip, sdbus.port, ["client.*"], ["server.*"])
await cb.link(sdbus.ip, sdbus.port, ["client.*"], ["server.*"])
await sb.link(sdbus.ip, sdbus.port, ["server.*"], ["client.*"])

# Send message
cevent = await ce.emit(f"server.{event_type}", dtype_instance)
sevent = await se.emit(f"client.{event_type}", dtype_instance)
cevent = await ce.emit(f"client.{event_type}", dtype_instance)
sevent = await se.emit(f"server.{event_type}", dtype_instance)

# Flush
await sdbus.flush()
Expand Down
25 changes: 24 additions & 1 deletion test/test_dbus.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import pytest

from aiodistbus import DEntryPoint, Event
from aiodistbus import DEntryPoint, Event, make_evented

from .conftest import (
ExampleEvent,
Expand Down Expand Up @@ -81,6 +81,29 @@ async def test_dbus_emit(dbus, dentrypoints, event_type, func, dtype, dtype_inst
assert event1 and event1.id in e1._received


async def test_dbus_emit_evented_dataclass(bus, dbus, dentrypoints):

# Create resources
e1, e2 = dentrypoints

# Add funcs
await e1.on("test", func, ExampleEvent)

# Connect
await e1.connect(dbus.ip, dbus.port)
await e2.connect(dbus.ip, dbus.port)

# Send message
instance = make_evented(ExampleEvent("Hello"), bus=bus)
event1 = await e2.emit("test", instance)

# Need to flush
await dbus.flush()

# Assert
assert event1 and event1.id in e1._received


@pytest.mark.parametrize(
"event_type, func, dtype_instance",
[
Expand Down
41 changes: 41 additions & 0 deletions test/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import pathlib
from dataclasses import dataclass

import pytest
from dataclasses_json import DataClassJsonMixin

from aiodistbus.protocols import Event
from aiodistbus.utils import encode, reconstruct

from .conftest import ExampleEvent


@dataclass
class ExampleEventWithPathlib(DataClassJsonMixin):
path: pathlib.Path


async def test_encode_dataclass_with_pathlib():
data = ExampleEventWithPathlib(pathlib.Path("."))
assert data.to_json()


@pytest.mark.parametrize(
"data", [ExampleEvent("Hello"), ExampleEventWithPathlib(pathlib.Path("."))]
)
async def test_encode_decode(data):

# Encode data
encoded_data = encode(data)

# Obtain the dtype and format it
dtype = type(data)
dtype_str = f"{dtype.__module__}.{dtype.__name__}"

event = Event("test", encoded_data, dtype=dtype_str, id="")
ser_event = event.to_json().encode()

decoded_event = ser_event.decode()
deser_event = await reconstruct(decoded_event)

assert deser_event.data == data
1 change: 1 addition & 0 deletions test/test_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ async def test_make_evented(bus, entrypoints):

# Create the evented class
data = make_evented(SomeClass(number=1, string="hello"), bus=bus)
logger.debug(type(data))

# Trigger an event by changing the class
data.number = 2
Expand Down