Skip to content

Update geoplot.py #70

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
343 changes: 97 additions & 246 deletions geoplot.py
Original file line number Diff line number Diff line change
@@ -1,293 +1,144 @@
"""
geoplot.py
----------

This visualization renders a 3-D plot of the data given the state
trajectory of a simulation, and the path of the property to render.

It generates an HTML file that contains code to render the plot
using Cesium Ion, and the GeoJSON file of data provided to the plot.

An example of its usage is as follows:

```py
from agent_torch.visualize import GeoPlot

# create a simulation
# ...

# create a visualizer
engine = GeoPlot(config, {
cesium_token: "...",
step_time: 3600,
coordinates = "agents/consumers/coordinates",
feature = "agents/consumers/money_spent",
})

# visualize in the runner-loop
for i in range(0, num_episodes):
runner.step(num_steps_per_episode)
engine.render(runner.state_trajectory)
```
"""

import re
import json

import pandas as pd
import numpy as np

from string import Template
from agent_torch.core.helpers import get_by_path
@@ -26,7 +25,6 @@
# HTML template for Cesium visualization
# This template defines the structure and behavior of the Cesium-based visualization.
# It includes functions for interpolating colors, determining pixel sizes, and processing time-series data.

geoplot_template = """
<!doctype html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<meta
name="viewport"
content="width=device-width, initial-scale=1.0"
/>
<title>Cesium Time-Series Heatmap Visualization</title>
<script src="https://cesium.com/downloads/cesiumjs/releases/1.95/Build/Cesium/Cesium.js"></script>
<link
href="https://cesium.com/downloads/cesiumjs/releases/1.95/Build/Cesium/Widgets/widgets.css"
rel="stylesheet"
/>
<style>
#cesiumContainer {
width: 100%;
height: 100%;
}
</style>
</head>
<body>
<div id="cesiumContainer"></div>
<script>
// Your Cesium ion access token here
Cesium.Ion.defaultAccessToken = '$accessToken'

// Create the viewer
const viewer = new Cesium.Viewer('cesiumContainer')

function interpolateColor(color1, color2, factor) {
const result = new Cesium.Color()
result.red = color1.red + factor * (color2.red - color1.red)
result.green =
color1.green + factor * (color2.green - color1.green)
result.blue = color1.blue + factor * (color2.blue - color1.blue)
result.alpha = '$visualType' == 'size' ? 0.2 :
color1.alpha + factor * (color2.alpha - color1.alpha)
return result
}

function getColor(value, min, max) {
const factor = (value - min) / (max - min)
return interpolateColor(
Cesium.Color.BLUE,
Cesium.Color.RED,
factor
)
}

function getPixelSize(value, min, max) {
const factor = (value - min) / (max - min)
return 100 * (1 + factor)
}

function processTimeSeriesData(geoJsonData) {
const timeSeriesMap = new Map()
let minValue = Infinity
let maxValue = -Infinity

geoJsonData.features.forEach((feature) => {
const id = feature.properties.id
const time = Cesium.JulianDate.fromIso8601(
feature.properties.time
)
const value = feature.properties.value
const coordinates = feature.geometry.coordinates

if (!timeSeriesMap.has(id)) {
timeSeriesMap.set(id, [])
}
timeSeriesMap.get(id).push({ time, value, coordinates })

minValue = Math.min(minValue, value)
maxValue = Math.max(maxValue, value)
})

return { timeSeriesMap, minValue, maxValue }
}

function createTimeSeriesEntities(
timeSeriesData,
startTime,
stopTime
) {
const dataSource = new Cesium.CustomDataSource(
'AgentTorch Simulation'
)

for (const [id, timeSeries] of timeSeriesData.timeSeriesMap) {
const entity = new Cesium.Entity({
id: id,
availability: new Cesium.TimeIntervalCollection([
new Cesium.TimeInterval({
start: startTime,
stop: stopTime,
}),
]),
position: new Cesium.SampledPositionProperty(),
point: {
pixelSize: '$visualType' == 'size' ? new Cesium.SampledProperty(Number) : 10,
color: new Cesium.SampledProperty(Cesium.Color),
},
properties: {
value: new Cesium.SampledProperty(Number),
},
})

timeSeries.forEach(({ time, value, coordinates }) => {
const position = Cesium.Cartesian3.fromDegrees(
coordinates[0],
coordinates[1]
)
entity.position.addSample(time, position)
entity.properties.value.addSample(time, value)
entity.point.color.addSample(
time,
getColor(
value,
timeSeriesData.minValue,
timeSeriesData.maxValue
)
)

if ('$visualType' == 'size') {
entity.point.pixelSize.addSample(
time,
getPixelSize(
value,
timeSeriesData.minValue,
timeSeriesData.maxValue
)
)
}
})

dataSource.entities.add(entity)
}

return dataSource
}

// Example time-series GeoJSON data
const geoJsons = $data

const start = Cesium.JulianDate.fromIso8601('$startTime')
const stop = Cesium.JulianDate.fromIso8601('$stopTime')

viewer.clock.startTime = start.clone()
viewer.clock.stopTime = stop.clone()
viewer.clock.currentTime = start.clone()
viewer.clock.clockRange = Cesium.ClockRange.LOOP_STOP
viewer.clock.multiplier = 3600 // 1 hour per second

viewer.timeline.zoomTo(start, stop)

for (const geoJsonData of geoJsons) {
const timeSeriesData = processTimeSeriesData(geoJsonData)
const dataSource = createTimeSeriesEntities(
timeSeriesData,
start,
stop
)
viewer.dataSources.add(dataSource)
viewer.zoomTo(dataSource)
}
</script>
</body>
@@ -41,9 +39,11 @@
<body>
<div id="cesiumContainer"></div>
<script>
// Setup Cesium viewer
Cesium.Ion.defaultAccessToken = '$accessToken';
const viewer = new Cesium.Viewer('cesiumContainer');
// Interpolate between two colors based on a factor
function interpolateColor(color1, color2, factor) {
const result = new Cesium.Color();
result.red = color1.red + factor * (color2.red - color1.red);
@@ -53,16 +53,19 @@
return result;
}
// Get color based on value (between blue and red)
function getColor(value, min, max) {
const factor = (value - min) / (max - min);
return interpolateColor(Cesium.Color.BLUE, Cesium.Color.RED, factor);
}
// Get pixel size based on value
function getPixelSize(value, min, max) {
const factor = (value - min) / (max - min);
return 100 * (1 + factor);
}
// Process GeoJSON data to build a time-series map
function processTimeSeriesData(geoJsonData) {
const timeSeriesMap = new Map();
let minValue = Infinity;
@@ -86,6 +89,7 @@
return { timeSeriesMap, minValue, maxValue };
}
// Create Cesium entities for the time series
function createTimeSeriesEntities(timeSeriesData, startTime, stopTime) {
const dataSource = new Cesium.CustomDataSource('AgentTorch Simulation');
@@ -105,6 +109,7 @@
},
});
// Add each time sample
timeSeries.forEach(({ time, value, coordinates }) => {
const position = Cesium.Cartesian3.fromDegrees(coordinates[0], coordinates[1]);
entity.position.addSample(time, position);
@@ -121,6 +126,7 @@
return dataSource;
}
// Load and visualize the time series data
const geoJsons = $data;
const start = Cesium.JulianDate.fromIso8601('$startTime');
const stop = Cesium.JulianDate.fromIso8601('$stopTime');
@@ -133,6 +139,7 @@
viewer.timeline.zoomTo(start, stop);
// Load all GeoJSON datasets
for (const geoJsonData of geoJsons) {
const timeSeriesData = processTimeSeriesData(geoJsonData);
const dataSource = createTimeSeriesEntities(timeSeriesData, start, stop);
@@ -144,17 +151,12 @@
</html>
"""

# Helper function to extract nested property from state based on path
# This function uses the get_by_path utility to navigate nested dictionaries.

# Helper function to extract nested property from the simulation state
def read_var(state, var):
"""Helper to extract nested property from state based on path."""
return get_by_path(state, re.split("/", var))

# GeoPlot class
# This class encapsulates the logic for generating GeoJSON and HTML visualizations.
# It takes configuration and visualization options as input.

# GeoPlot class for generating visualization outputs
class GeoPlot:
def __init__(self, config, options):
self.config = config
(
self.cesium_token,
self.step_time,
self.entity_position,
self.entity_property,
self.visualization_type,
) = (
options["cesium_token"],
options["step_time"],
options["coordinates"],
options["feature"],
"""Initialize GeoPlot with config and visualization options."""
@@ -173,24 +175,21 @@ def __init__(self, config, options):
options["visualization_type"],
)

# Render the trajectory to a GeoJSON and HTML visualization
# This method processes the state trajectory to generate GeoJSON features and an HTML file.
# It extracts coordinates and property values, generates timestamps, and constructs GeoJSON features.

def render(self, state_trajectory):
"""Render the trajectory to a GeoJSON and HTML visualization."""
"""Render the trajectory to GeoJSON and HTML visualization."""
coords, values = [], []
name = self.config["simulation_metadata"]["name"]
geodata_path, geoplot_path = f"{name}.geojson", f"{name}.html"

# Extract coordinates and property values from final states
# Extract final state coordinates and properties
for i in range(0, len(state_trajectory) - 1):
final_state = state_trajectory[i][-1]

coords = np.array(read_var(final_state, self.entity_position)).tolist()
values.append(
np.array(read_var(final_state, self.entity_property)).flatten().tolist()
)

# Start time for the simulation
start_time = pd.Timestamp.utcnow()
timestamps = [
start_time + pd.Timedelta(seconds=i * self.step_time)
for i in range(
self.config["simulation_metadata"]["num_episodes"]
* self.config["simulation_metadata"]["num_steps_per_episode"]
)
]

# Generate timestamps spaced by step_time
@@ -204,7 +203,7 @@ def render(self, state_trajectory):

geojsons = []

# Construct GeoJSON features for each coordinate
# Construct GeoJSON features for visualization
for i, coord in enumerate(coords):
features = []
for time, value_list in zip(timestamps, values):
features.append(
{
"type": "Feature",
"geometry": {
"type": "Point",
"coordinates": [coord[1], coord[0]],
},
"properties": {
"value": value_list[i],
"time": time.isoformat(),
},
}
)
@@ -221,14 +220,11 @@ def render(self, state_trajectory):
})
geojsons.append({"type": "FeatureCollection", "features": features})

# Write GeoJSON file
# Write GeoJSON output file
with open(geodata_path, "w", encoding="utf-8") as f:
json.dump(geojsons, f, ensure_ascii=False, indent=2)

# Fill the HTML template with real data and token
# The HTML file is generated by substituting values into the Cesium template.
# It includes the Cesium token, start and stop times, GeoJSON data, and visualization type.

# Generate HTML file by substituting template values
tmpl = Template(geoplot_template)
with open(geoplot_path, "w", encoding="utf-8") as f:
f.write(
tmpl.substitute(
{
"accessToken": self.cesium_token,
"startTime": timestamps[0].isoformat(),
"stopTime": timestamps[-1].isoformat(),
"data": json.dumps(geojsons),
"visualType": self.visualization_type,
}
)
@@ -240,3 +236,4 @@ def render(self, state_trajectory):
"visualType": self.visualization_type,
})
)