Skip to content
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions mesa/experimental/altair_grid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import json
from typing import Callable, Optional

import altair as alt
import solara

import mesa


def get_agent_data_from_coord_iter(data):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In my current implementation, I'm using an agent_portrayal method to generate the values needed to draw the space. That may be drawn from the old mesa visualization approach, IDK if there's a good way to pass in something like that to jupyterviz.

I mention because I wonder if it would be cleaner and more explicit than the way you're using json to dump and filter the agent dict.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with @rlskoeser; the JSON approach is not very clean. Is there a more explicit way to do this?

for agent, (x, y) in data:
if agent:
agent_data = json.loads(
json.dumps(agent[0].__dict__, skipkeys=True, default=str)
)
agent_data["x"] = x
agent_data["y"] = y
agent_data.pop("model", None)
agent_data.pop("pos", None)
yield agent_data


def create_grid(
color: Optional[str] = None,
on_click: Optional[Callable[[mesa.Model, mesa.space.Coordinate], None]] = None,
) -> Callable[[mesa.Model], solara.component]:
return lambda model: Grid(model, color, on_click)


def Grid(model, color=None, on_click=None):
if color is None:
color = "unique_id:N"

if color[-2] != ":":
color = color + ":N"

print(model.grid.coord_iter())

data = solara.reactive(
list(get_agent_data_from_coord_iter(model.grid.coord_iter()))
)

def update_data():
data.value = list(get_agent_data_from_coord_iter(model.grid.coord_iter()))

def click_handler(datum):
if datum is None:
return
on_click(model, datum["x"], datum["y"])
update_data()

default_tooltip = [f"{key}:N" for key in data.value[0]]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice that you have tooltip and click handlers - a little hard to assess without documentation, I would want to know how I can customize these

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes please add docstrings to this explaining how to customize the tooltip

chart = (
alt.Chart(alt.Data(values=data.value))
.mark_rect()
.encode(
x=alt.X("x:N", scale=alt.Scale(domain=list(range(model.grid.width)))),
y=alt.Y(
"y:N",
scale=alt.Scale(domain=list(range(model.grid.height - 1, -1, -1))),
),
color=color,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In one of my models where I'm using a custom altair space drawer, I'm setting color, size, and shape. Probably reasonable not to support all of those on the first pass, but it would be good to think about a more generalized approach (like the agent portrayal method) that would make it possible to customize this without having to completely re-implement.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this should be pretty straightforward to do with the create_grid function

tooltip=default_tooltip,
)
.properties(width=600, height=600)
)
return solara.FigureAltair(chart, on_click=click_handler)
27 changes: 25 additions & 2 deletions mesa/experimental/jupyter_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from solara.alias import rv

import mesa
from mesa.experimental.altair_grid import create_grid

# Avoid interactive backend
plt.switch_backend("agg")
Expand Down Expand Up @@ -75,6 +76,12 @@ def ColorCard(color, layout_type):
SpaceMatplotlib(
model, agent_portrayal, dependencies=[current_step.value]
)
elif space_drawer == "altair":
# draw with the altair implementation
SpaceAltair(
model, agent_portrayal, dependencies=[current_step.value]
)

elif space_drawer:
# if specified, draw agent space with an alternate renderer
space_drawer(model, agent_portrayal)
Expand Down Expand Up @@ -109,6 +116,9 @@ def render_in_jupyter():
SpaceMatplotlib(
model, agent_portrayal, dependencies=[current_step.value]
)
elif space_drawer == "altair":
# draw with the default implementation
SpaceAltair(model, agent_portrayal, dependencies=[current_step.value])
elif space_drawer:
# if specified, draw agent space with an alternate renderer
space_drawer(model, agent_portrayal)
Expand All @@ -123,7 +133,7 @@ def render_in_jupyter():
else:
make_plot(model, measure)

def render_in_browser():
def render_in_browser(statistics=False):
# if space drawer is disabled, do not include it
layout_types = [{"Space": "default"}] if space_drawer else []

Expand All @@ -139,6 +149,13 @@ def render_in_browser():
ModelController(model, play_interval, current_step, reset_counter)
with solara.Card("Progress", margin=1, elevation=2):
solara.Markdown(md_text=f"####Step - {current_step}")
with solara.Card("Analytics", margin=1, elevation=2):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is unrelated to the PR.

if statistics:
df = model.datacollector.get_model_vars_dataframe()
for col in list(df.columns):
solara.Markdown(
md_text=f"####Avg. {col} - {df.loc[:, f'{col}'].mean()}"
)

items = [
ColorCard(color="white", layout_type=layout_types[i])
Expand Down Expand Up @@ -334,6 +351,12 @@ def SpaceMatplotlib(model, agent_portrayal, dependencies: Optional[List[any]] =
solara.FigureMatplotlib(space_fig, format="png", dependencies=dependencies)


@solara.component
def SpaceAltair(model, agent_portrayal, dependencies: Optional[List[any]] = None):
grid = create_grid(color="wealth")
grid(model)


def _draw_grid(space, space_ax, agent_portrayal):
def portray(g):
x = []
Expand Down Expand Up @@ -424,7 +447,7 @@ def get_initial_grid_layout(layout_types):
grid_lay = []
y_coord = 0
for ii in range(len(layout_types)):
template_layout = {"h": 10, "i": 0, "moved": False, "w": 6, "y": 0, "x": 0}
template_layout = {"h": 20, "i": 0, "moved": False, "w": 6, "y": 0, "x": 0}
if ii == 0:
grid_lay.append(template_layout)
else:
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ dependencies = [
"pandas",
"solara",
"tqdm",
"altair"
]
dynamic = ["version"]

Expand Down
13 changes: 13 additions & 0 deletions tests/test_jupyter_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,14 @@ def test_call_space_drawer(self, mock_space_matplotlib):
agent_portrayal=agent_portrayal,
)
)
solara.render(
JupyterViz(
model_class=mock_model_class,
model_params={},
agent_portrayal=agent_portrayal,
space_drawer="altair",
)
)
# should call default method with class instance and agent portrayal
mock_space_matplotlib.assert_called_with(
mock_model_class.return_value, agent_portrayal, dependencies=dependencies
Expand Down Expand Up @@ -132,3 +140,8 @@ def test_call_space_drawer(self, mock_space_matplotlib):
altspace_drawer.assert_called_with(
mock_model_class.return_value, agent_portrayal
)


if __name__ == "__main__":
tjv = TestJupyterViz()
tjv.test_call_space_drawer()