Skip to content

Add detailed comments to geoplot.py to improve readability #57

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 54 additions & 14 deletions geoplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from string import Template
from agent_torch.core.helpers import get_by_path

# HTML template for CesiumJS-based interactive time-series visualization
geoplot_template = """
<!doctype html>
<html lang="en">
Expand All @@ -65,12 +66,13 @@
<body>
<div id="cesiumContainer"></div>
<script>
// Your Cesium ion access token here
// Set the Cesium Ion access token for authentication
Cesium.Ion.defaultAccessToken = '$accessToken'

// Create the viewer
// Create the Cesium viewer object
const viewer = new Cesium.Viewer('cesiumContainer')

// Helper function to interpolate between two colors based on factor
function interpolateColor(color1, color2, factor) {
const result = new Cesium.Color()
result.red = color1.red + factor * (color2.red - color1.red)
Expand All @@ -82,6 +84,7 @@
return result
}

// Compute a color value based on input range
function getColor(value, min, max) {
const factor = (value - min) / (max - min)
return interpolateColor(
Expand All @@ -91,11 +94,13 @@
)
}

// Compute marker pixel size based on value
function getPixelSize(value, min, max) {
const factor = (value - min) / (max - min)
return 100 * (1 + factor)
}

// Parse GeoJSON features into time series data for Cesium entities
function processTimeSeriesData(geoJsonData) {
const timeSeriesMap = new Map()
let minValue = Infinity
Expand All @@ -121,6 +126,7 @@
return { timeSeriesMap, minValue, maxValue }
}

// Create Cesium entities for visualization from time-series data
function createTimeSeriesEntities(
timeSeriesData,
startTime,
Expand Down Expand Up @@ -167,13 +173,13 @@

if ('$visualType' == 'size') {
entity.point.pixelSize.addSample(
time,
getPixelSize(
value,
timeSeriesData.minValue,
timeSeriesData.maxValue
)
)
time,
getPixelSize(
value,
timeSeriesData.minValue,
timeSeriesData.maxValue
)
)
}
})

Expand All @@ -183,12 +189,13 @@
return dataSource
}

// Example time-series GeoJSON data
// Load the generated GeoJSON time series data
const geoJsons = $data

const start = Cesium.JulianDate.fromIso8601('$startTime')
const stop = Cesium.JulianDate.fromIso8601('$stopTime')

// Configure the Cesium viewer clock
viewer.clock.startTime = start.clone()
viewer.clock.stopTime = stop.clone()
viewer.clock.currentTime = start.clone()
Expand All @@ -197,6 +204,7 @@

viewer.timeline.zoomTo(start, stop)

// Add data sources to viewer
for (const geoJsonData of geoJsons) {
const timeSeriesData = processTimeSeriesData(geoJsonData)
const dataSource = createTimeSeriesEntities(
Expand All @@ -212,13 +220,27 @@
</html>
"""


def read_var(state, var):
"""
Retrieve a nested variable from the state dictionary using a path string.
"""
return get_by_path(state, re.split("/", var))


class GeoPlot:
def __init__(self, config, options):
"""
Initialize the GeoPlot visualizer.

Parameters:
- config: Simulation configuration dictionary.
- options: Dictionary containing visualization options:
- cesium_token: Cesium Ion access token.
- step_time: Time step between simulation states (in seconds).
- coordinates: Path to the entity coordinates in the state.
- feature: Path to the entity property to visualize.
- visualization_type: 'size' or 'color' to determine visualization mode.
"""
self.config = config
(
self.cesium_token,
Expand All @@ -235,18 +257,33 @@ def __init__(self, config, options):
)

def render(self, state_trajectory):
"""
Renders the simulation trajectory into a 3D Cesium visualization.

Args:
state_trajectory (list of lists): Each entry is a list of states from each step
of an episode. Each state is a nested dictionary where data can be accessed via paths.

Output:
- Generates two files in the current working directory:
- <simulation_name>.geojson: Contains the data in GeoJSON format.
- <simulation_name>.html: Interactive Cesium viewer with time-series entities.
"""
coords, values = [], []
name = self.config["simulation_metadata"]["name"]
geodata_path, geoplot_path = f"{name}.geojson", f"{name}.html"

# Iterate through state trajectory to extract final state of each episode
for i in range(0, len(state_trajectory) - 1):
final_state = state_trajectory[i][-1]

# Read coordinates and property values for all agents/entities
coords = np.array(read_var(final_state, self.entity_position)).tolist()
values.append(
np.array(read_var(final_state, self.entity_property)).flatten().tolist()
)

# Create a list of timestamps for each step in the trajectory
start_time = pd.Timestamp.utcnow()
timestamps = [
start_time + pd.Timedelta(seconds=i * self.step_time)
Expand All @@ -257,6 +294,7 @@ def render(self, state_trajectory):
]

geojsons = []
# For each coordinate, create a FeatureCollection of time-series features
for i, coord in enumerate(coords):
features = []
for time, value_list in zip(timestamps, values):
Expand All @@ -265,19 +303,21 @@ def render(self, state_trajectory):
"type": "Feature",
"geometry": {
"type": "Point",
"coordinates": [coord[1], coord[0]],
"coordinates": [coord[1], coord[0]], # GeoJSON uses [lon, lat]
},
"properties": {
"value": value_list[i],
"value": value_list[i], # Property value at this timestep
"time": time.isoformat(),
},
}
)
geojsons.append({"type": "FeatureCollection", "features": features})

# Write the GeoJSON to file
with open(geodata_path, "w", encoding="utf-8") as f:
json.dump(geojsons, f, ensure_ascii=False, indent=2)

# Substitute the template placeholders with actual data and write HTML
tmpl = Template(geoplot_template)
with open(geoplot_path, "w", encoding="utf-8") as f:
f.write(
Expand All @@ -290,4 +330,4 @@ def render(self, state_trajectory):
"visualType": self.visualization_type,
}
)
)
)