Skip to content

Dynamic colorscale to Geoplot Visualizer #61

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
167 changes: 82 additions & 85 deletions geoplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,21 +40,17 @@
from string import Template
from agent_torch.core.helpers import get_by_path

# HTML/JS template string is being used here to render cesium-based 3D visualisation.
# Populated with actual data, timestamps, etc.
geoplot_template = """
<!doctype html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<meta
name="viewport"
content="width=device-width, initial-scale=1.0"
/>
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Cesium Time-Series Heatmap Visualization</title>
<script src="https://cesium.com/downloads/cesiumjs/releases/1.95/Build/Cesium/Cesium.js"></script>
<link
href="https://cesium.com/downloads/cesiumjs/releases/1.95/Build/Cesium/Widgets/widgets.css"
rel="stylesheet"
/>
<link href="https://cesium.com/downloads/cesiumjs/releases/1.95/Build/Cesium/Widgets/widgets.css" rel="stylesheet" />
<style>
#cesiumContainer {
width: 100%;
Expand All @@ -65,17 +61,16 @@
<body>
<div id="cesiumContainer"></div>
<script>
// Your Cesium ion access token here
Cesium.Ion.defaultAccessToken = '$accessToken'
const globalMinValue = $minValue;
const globalMaxValue = $maxValue;

// Create the viewer
const viewer = new Cesium.Viewer('cesiumContainer')

function interpolateColor(color1, color2, factor) {
const result = new Cesium.Color()
result.red = color1.red + factor * (color2.red - color1.red)
result.green =
color1.green + factor * (color2.green - color1.green)
result.green = color1.green + factor * (color2.green - color1.green)
result.blue = color1.blue + factor * (color2.blue - color1.blue)
result.alpha = '$visualType' == 'size' ? 0.2 :
color1.alpha + factor * (color2.alpha - color1.alpha)
Expand All @@ -98,46 +93,34 @@

function processTimeSeriesData(geoJsonData) {
const timeSeriesMap = new Map()
let minValue = Infinity
let maxValue = -Infinity

geoJsonData.features.forEach((feature) => {
const id = feature.properties.id
const time = Cesium.JulianDate.fromIso8601(
feature.properties.time
)
const time = Cesium.JulianDate.fromIso8601(feature.properties.time)
const value = feature.properties.value
const coordinates = feature.geometry.coordinates

if (!timeSeriesMap.has(id)) {
timeSeriesMap.set(id, [])
}
timeSeriesMap.get(id).push({ time, value, coordinates })

minValue = Math.min(minValue, value)
maxValue = Math.max(maxValue, value)
})

return { timeSeriesMap, minValue, maxValue }
return {
timeSeriesMap,
minValue: globalMinValue,
maxValue: globalMaxValue,
}
}

function createTimeSeriesEntities(
timeSeriesData,
startTime,
stopTime
) {
const dataSource = new Cesium.CustomDataSource(
'AgentTorch Simulation'
)
function createTimeSeriesEntities(timeSeriesData, startTime, stopTime) {
const dataSource = new Cesium.CustomDataSource('AgentTorch Simulation')

for (const [id, timeSeries] of timeSeriesData.timeSeriesMap) {
const entity = new Cesium.Entity({
id: id,
availability: new Cesium.TimeIntervalCollection([
new Cesium.TimeInterval({
start: startTime,
stop: stopTime,
}),
new Cesium.TimeInterval({ start: startTime, stop: stopTime }),
]),
position: new Cesium.SampledPositionProperty(),
point: {
Expand All @@ -150,30 +133,19 @@
})

timeSeries.forEach(({ time, value, coordinates }) => {
const position = Cesium.Cartesian3.fromDegrees(
coordinates[0],
coordinates[1]
)
const position = Cesium.Cartesian3.fromDegrees(coordinates[0], coordinates[1])
entity.position.addSample(time, position)
entity.properties.value.addSample(time, value)
entity.point.color.addSample(
time,
getColor(
value,
timeSeriesData.minValue,
timeSeriesData.maxValue
)
getColor(value, timeSeriesData.minValue, timeSeriesData.maxValue)
)

if ('$visualType' == 'size') {
entity.point.pixelSize.addSample(
time,
getPixelSize(
value,
timeSeriesData.minValue,
timeSeriesData.maxValue
)
)
entity.point.pixelSize.addSample(
time,
getPixelSize(value, timeSeriesData.minValue, timeSeriesData.maxValue)
)
}
})

Expand All @@ -183,27 +155,21 @@
return dataSource
}

// Example time-series GeoJSON data
const geoJsons = $data

const start = Cesium.JulianDate.fromIso8601('$startTime')
const stop = Cesium.JulianDate.fromIso8601('$stopTime')

viewer.clock.startTime = start.clone()
viewer.clock.stopTime = stop.clone()
viewer.clock.currentTime = start.clone()
viewer.clock.clockRange = Cesium.ClockRange.LOOP_STOP
viewer.clock.multiplier = 3600 // 1 hour per second
viewer.clock.multiplier = 3600

viewer.timeline.zoomTo(start, stop)

for (const geoJsonData of geoJsons) {
const timeSeriesData = processTimeSeriesData(geoJsonData)
const dataSource = createTimeSeriesEntities(
timeSeriesData,
start,
stop
)
const dataSource = createTimeSeriesEntities(timeSeriesData, start, stop)
viewer.dataSources.add(dataSource)
viewer.zoomTo(dataSource)
}
Expand All @@ -214,10 +180,20 @@


def read_var(state, var):
'''
Helper function to get access to nested fields in the state using '/'

'''
return get_by_path(state, re.split("/", var))


class GeoPlot:
'''
init function for initializing the geoplot visualizer.
Args:
config: dictionary containing simulation metadata like num_episodes, num_steps...etc.
options: dictionary containing with required config like cesium_token, feature, type of visualization,..etc.
'''
def __init__(self, config, options):
self.config = config
(
Expand All @@ -233,16 +209,26 @@ def __init__(self, config, options):
options["feature"],
options["visualization_type"],
)
self.auto_color_scale = options.get("auto_color_scale", True)
self.manual_min_value = options.get("min_value", None)
self.manual_max_value= options.get("max_value", None)

def render(self, state_trajectory):
'''
Converts simulation trajectory into GeoJSON and HTML visualization.
Args:
state_trajectory: A 2D list representing simulation states across episodes & steps.
'''
coords, values = [], []
name = self.config["simulation_metadata"]["name"]
geodata_path, geoplot_path = f"{name}.geojson", f"{name}.html"
name = self.config["simulation_metadata"]["name"] # Extracting output file names from simulation metadata
geodata_path, geoplot_path = f"{name}.geojson", f"{name}.html"

for i in range(0, len(state_trajectory) - 1):
final_state = state_trajectory[i][-1]
# Only get the final state ONCE for coordinates (assumes agents don't move)
final_state = state_trajectory[0][-1]
coords = np.array(read_var(final_state, self.entity_position)).tolist()

coords = np.array(read_var(final_state, self.entity_position)).tolist()
for i in range(len(state_trajectory)):
final_state = state_trajectory[i][-1]
values.append(
np.array(read_var(final_state, self.entity_property)).flatten().tolist()
)
Expand All @@ -256,38 +242,49 @@ def render(self, state_trajectory):
)
]

# Flatten values and compute min/max if needed
all_values = [v for episode_values in values for v in episode_values]
if self.auto_color_scale:
min_value = min(all_values)
max_value = max(all_values)
else:
min_value = self.manual_min_value
max_value = self.manual_max_value

# Generate GeoJSON time series
geojsons = []
for i, coord in enumerate(coords):
features = []
for time, value_list in zip(timestamps, values):
features.append(
{
"type": "Feature",
"geometry": {
"type": "Point",
"coordinates": [coord[1], coord[0]],
},
"properties": {
"value": value_list[i],
"time": time.isoformat(),
},
}
)
features.append({
"type": "Feature",
"geometry": {
"type": "Point",
"coordinates": [coord[1], coord[0]],
},
"properties": {
"id": i, # Required for Cesium
"value": value_list[i],
"time": time.isoformat(),
},
})
geojsons.append({"type": "FeatureCollection", "features": features})

# Write GeoJSON to disk
with open(geodata_path, "w", encoding="utf-8") as f:
json.dump(geojsons, f, ensure_ascii=False, indent=2)

# Render the HTML visualization
tmpl = Template(geoplot_template)
with open(geoplot_path, "w", encoding="utf-8") as f:
f.write(
tmpl.substitute(
{
"accessToken": self.cesium_token,
"startTime": timestamps[0].isoformat(),
"stopTime": timestamps[-1].isoformat(),
"data": json.dumps(geojsons),
"visualType": self.visualization_type,
}
)
tmpl.substitute({
"accessToken": self.cesium_token,
"startTime": timestamps[0].isoformat(),
"stopTime": timestamps[-1].isoformat(),
"data": json.dumps(geojsons),
"visualType": self.visualization_type,
"min_value": min_value,
"max_value": max_value
})
)
34 changes: 18 additions & 16 deletions readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,22 @@ An example of its usage is as follows:
```py
from agent_torch.visualize import GeoPlot

# create a simulation
# ...

# create a visualizer
geoplot = GeoPlot(config, cesium_token)

# visualize in the runner-loop
for i in range(0, num_episodes):
runner.step(num_steps_per_episode)

geoplot.visualize(
name = f"consumer-money-spent-{i}",
state_trajectory = runner.state_trajectory,
entity_position = "consumers/coordinates",
entity_property = "consumers/money_spent",
)
# Initialize the visualizer
geoplot = GeoPlot(
config=config,
cesium_token="your-cesium-ion-access-token",
visualization_type="color", # or "size"
auto_color_scale=True, # or False
manual_min_value=0, # required if auto_color_scale=False
manual_max_value=100 # required if auto_color_scale=False
)

# Inside the runner loop
geoplot.visualize(
name=f"consumer-money-spent-{i}",
state_trajectory=runner.state_trajectory,
entity_position="consumers/coordinates",
entity_property="consumers/money_spent"
)

```