Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pufferlib/ocean/drive/datatypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ struct Agent {
int closest_path_idx_wp;

// Metrics and status tracking
float metrics_array[10]; // [collision, offroad, red_light, reached_goal, lane_dist,
float metrics_array[11]; // [collision, offroad, red_light, reached_goal, lane_dist,
// lane_angle, comfort_violation, velocity_progress, speed_limit, avg_displacement_error, lane_aligned]
int collision_state;
int aabb_collision_state;
Expand Down
241 changes: 213 additions & 28 deletions pufferlib/ocean/drive/drive.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,13 @@
#define CONTROL_WOSAC 2
#define CONTROL_SDC_ONLY 3

// Lane selection scoring
// Each candidate lane gets a score (lower is better):
//   LANE_SELECTION_DISTANCE_WEIGHT * (|lateral distance| / LANE_DISTANCE_NORMALIZATION)
//   + LANE_SELECTION_HEADING_WEIGHT * (|heading difference| / pi)
#define LANE_SELECTION_DISTANCE_WEIGHT 0.6f
#define LANE_SELECTION_HEADING_WEIGHT 0.4f
// Divisor applied to the lane distance in the score and in the ego observation;
// also the LANE_DIST_IDX value stored when no lane candidate is selected.
#define LANE_DISTANCE_NORMALIZATION 4.0f
#define LANE_SWITCH_THRESHOLD 0.05f // Hysteresis: added to the score of any lane other than the agent's current lane
// NOTE(review): no use of this constant is visible in this diff; the alignment
// check shown here compares |cos(theta)| against a literal 0.965 - confirm where
// this threshold is meant to be consumed.
#define LANE_ALIGN_COS_THRESHOLD 0.5f

// Minimum distance to goal position
#define MIN_DISTANCE_TO_GOAL 2.0f

Expand All @@ -78,8 +85,15 @@
// Metrics array indices
#define COLLISION_IDX 0
#define OFFROAD_IDX 1
#define REACHED_GOAL_IDX 2
#define LANE_ALIGNED_IDX 3
#define RED_LIGHT_IDX 2
#define REACHED_GOAL_IDX 3
#define LANE_DIST_IDX 4
#define LANE_ANGLE_IDX 5
#define COMFORT_VIOLATION_IDX 6
#define VELOCITY_PROGRESS_IDX 7
#define SPEED_LIMIT_IDX 8
#define AVG_DISPLACEMENT_ERROR_IDX 9
#define LANE_ALIGNED_IDX 10
Comment on lines 85 to +96
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Metrics indices now mismatch

metrics_array is resized to 11 in datatypes.h, but the new indices in drive.h imply you’re now writing up to index 10 (LANE_ALIGNED_IDX). Existing comments in Agent still describe only 10 items, and several reset/respawn paths only clear a subset of indices. This will leave stale values (e.g., LANE_DIST_IDX, LANE_ANGLE_IDX) across steps/episodes and makes the meaning of each index ambiguous. Please update the Agent comment to list all 11 entries and ensure all codepaths that reset metrics zero the full [0..LANE_ALIGNED_IDX] range (or explicitly initialize new slots).


// Grid cell size
#define GRID_CELL_SIZE 5.0f
Expand All @@ -96,9 +110,11 @@
#define ROAD_FEATURES_ONEHOT 14
#define PARTNER_FEATURES 8

#define MAX_CHECKED_LANES 32

// Ego features depend on dynamics model
#define EGO_FEATURES_CLASSIC 8
#define EGO_FEATURES_JERK 11
#define EGO_FEATURES_CLASSIC 11
#define EGO_FEATURES_JERK 14

// Observation normalization constants
#define MAX_SPEED 100.0f
Expand Down Expand Up @@ -329,6 +345,11 @@ float normalize_heading(float heading) {
return heading;
}

// Difference heading1 - heading2 (radians), passed through normalize_heading().
// NOTE(review): normalize_heading presumably wraps into [-pi, pi] - confirm
// against its definition.
static float compute_heading_diff(float heading1, float heading2) {
    return normalize_heading(heading1 - heading2);
}

// Note: added for 2.5D
typedef struct {
float dis;
Expand Down Expand Up @@ -850,6 +871,37 @@ void load_map_binary(const char *filename, Drive *env) {
// ========================================

// void compute_multi_segment_alignment(void){}
// Weighted mean heading (radians) of the lane segments around center_seg_idx.
//
// The window is the center segment plus one neighbour on each side, clipped to
// the valid segment indices. The center segment counts with weight 2, the
// neighbours with weight 1. The mean is accumulated incrementally using
// compute_heading_diff() so that headings are combined via their angular
// difference rather than by raw addition.
//
// Returns 0.0f when no valid segment falls inside the window (for example a
// lane with fewer than two points).
static float compute_multi_segment_alignment(RoadMapElement *element, int center_seg_idx) {
    const int last_seg = element->segment_length - 2; // index of the final segment

    int first = (center_seg_idx > 0) ? (center_seg_idx - 1) : center_seg_idx;
    const int last = (center_seg_idx < last_seg) ? (center_seg_idx + 1) : last_seg;

    // `last` can never exceed last_seg, so only the lower bound needs clipping.
    if (first < 0) {
        first = 0;
    }

    float mean_heading = 0.0f;
    float weight_sum = 0.0f;

    for (int s = first; s <= last; ++s) {
        const float step_x = element->x[s + 1] - element->x[s];
        const float step_y = element->y[s + 1] - element->y[s];
        const float heading = atan2f(step_y, step_x);
        const float w = (s == center_seg_idx) ? 2.0f : 1.0f;

        if (weight_sum == 0.0f) {
            // First contributing segment seeds the running mean.
            mean_heading = heading;
        } else {
            // Move the mean towards this heading by its share of the total weight.
            const float delta = compute_heading_diff(heading, mean_heading);
            mean_heading += w * delta / (weight_sum + w);
        }
        weight_sum += w;
    }

    return mean_heading;
}

// void get_drivable_lane_indices(void){}

Expand All @@ -860,6 +912,60 @@ void load_map_binary(const char *filename, Drive *env) {
// void compute_remaining_lane_distance(void){}

// void find_closest_segment_on_lane(void){}
// Finds the segment of `lane` closest to the point (agent_x, agent_y).
//
// Writes the index of that segment to *out_segment_idx and returns the signed
// distance to it: negative when the point lies to the left of the segment
// direction (cross product >= 0), positive when it lies to the right.
//
// A lane with fewer than two points has no segments; in that case the index is
// set to 0 and 1e9f is returned.
static float find_closest_segment_on_lane(RoadMapElement *lane, float agent_x, float agent_y, int *out_segment_idx) {
    const int segment_count = lane->segment_length - 1;
    if (segment_count < 1) {
        *out_segment_idx = 0;
        return 1e9f;
    }

    float best_dist_sq = 1e18f;
    float best_cross = 0.0f;
    int best_idx = 0;

    for (int i = 0; i < segment_count; ++i) {
        const float ax = lane->x[i];
        const float ay = lane->y[i];
        const float bx = lane->x[i + 1];
        const float by = lane->y[i + 1];

        // Segment direction and vector from the segment start to the agent.
        const float ux = bx - ax;
        const float uy = by - ay;
        const float len_sq = ux * ux + uy * uy;
        const float px = agent_x - ax;
        const float py = agent_y - ay;

        // cross > 0 means the agent is left of the lane direction
        const float cross = ux * py - uy * px;

        // Default: distance to the segment start. This covers degenerate
        // (near zero-length) segments and projections at or before the start.
        float dist_sq = px * px + py * py;
        if (len_sq > 1e-6f) {
            const float t = (px * ux + py * uy) / len_sq;
            if (t >= 1.0f) {
                // Projection falls at or past the segment end.
                const float ex = agent_x - bx;
                const float ey = agent_y - by;
                dist_sq = ex * ex + ey * ey;
            } else if (!(t <= 0.0f)) {
                // Projection falls strictly inside: perpendicular distance.
                dist_sq = (cross * cross) / len_sq;
            }
        }

        // Strict comparison keeps the earliest segment on ties.
        if (dist_sq < best_dist_sq) {
            best_dist_sq = dist_sq;
            best_idx = i;
            best_cross = cross;
        }
    }

    *out_segment_idx = best_idx;

    const float distance = sqrtf(best_dist_sq);
    if (best_cross >= 0.0f) {
        return -distance;
    }
    return distance;
}

// void compute_log_trajectory_distance(void){}

Expand Down Expand Up @@ -1100,7 +1206,10 @@ void reset_agent_metrics(Drive *env, int agent_idx) {
Agent *agent = &env->agents[agent_idx];
agent->metrics_array[COLLISION_IDX] = 0.0f; // vehicle collision
agent->metrics_array[OFFROAD_IDX] = 0.0f; // offroad
agent->metrics_array[REACHED_GOAL_IDX] = 0.0f; // goal reached
agent->metrics_array[LANE_ALIGNED_IDX] = 0.0f; // lane aligned
agent->metrics_array[LANE_ANGLE_IDX] = 0.0f; // lane angle
agent->metrics_array[LANE_DIST_IDX] = 0.0f; // distance from lane center
agent->collision_state = 0;
agent->aabb_collision_state = 0;
}
Expand Down Expand Up @@ -1155,6 +1264,8 @@ void set_start_position(Drive *env) {
e->metrics_array[OFFROAD_IDX] = 0.0f; // offroad
e->metrics_array[REACHED_GOAL_IDX] = 0.0f; // reached goal
e->metrics_array[LANE_ALIGNED_IDX] = 0.0f; // lane aligned
e->metrics_array[LANE_ANGLE_IDX] = 0.0f; // lane angle
e->metrics_array[LANE_DIST_IDX] = 0.0f; // distance from lane center
e->respawn_timestep = -1;
e->stopped = 0;
e->removed = 0;
Expand Down Expand Up @@ -1530,8 +1641,11 @@ void compute_agent_metrics(Drive *env, int agent_idx) {
float sin_heading = sinf(agent->sim_heading);
float min_distance = (float)INT16_MAX;

int closest_lane_entity_idx = -1;
int closest_lane_geometry_idx = -1;
float best_score = 1e9f;
int best_candidate_entity_idx = -1;
int best_candidate_geometry_idx = -1;
float best_candidate_signed_lane_distance = 0.0f;
float best_candidate_lane_heading = 0.0f;

float corners[4][2];
for (int i = 0; i < 4; i++) {
Expand All @@ -1541,6 +1655,12 @@ void compute_agent_metrics(Drive *env, int agent_idx) {
agent->sim_y + (offsets[i][0] * half_length * sin_heading + offsets[i][1] * half_width * cos_heading);
}
int list_size = 0;
// Vehicle-width based distance threshold (3x width)
float max_distance_threshold = 3.0f * agent->sim_width;

// Track already-checked drivable lanes to avoid redundant processing
int checked_lanes[MAX_CHECKED_LANES];
int num_checked_lanes = 0;
GridMapEntity *entity_list = checkNeighbors(env, agent->sim_x, agent->sim_y, collision_offsets,
COLLISION_RANGE * COLLISION_RANGE, &list_size);
for (int i = 0; i < list_size; i++) {
Expand All @@ -1550,6 +1670,7 @@ void compute_agent_metrics(Drive *env, int agent_idx) {
continue;
RoadMapElement *entity;
entity = &env->road_elements[entity_list[i].entity_idx];
int entity_idx = entity_list[i].entity_idx;

// Check for offroad collision with road edges
if (entity->type == ROAD_EDGE) {
Expand All @@ -1571,41 +1692,88 @@ void compute_agent_metrics(Drive *env, int agent_idx) {
break;

// Find closest point on the road centerline to the agent
if (entity->type == ROAD_LANE) {
int entity_idx = entity_list[i].entity_idx;
int geometry_idx = entity_list[i].geometry_idx;
if (is_drivable_road_lane(entity->type) || entity->type == ROAD_LANE) {
// Check if we've already processed this lane (skip duplicates)
int already_checked = 0;
for (int c = 0; c < num_checked_lanes; c++) {
if (checked_lanes[c] == entity_idx) {
already_checked = 1;
break;
}
}
if (already_checked)
continue;

float start[2] = {entity->x[geometry_idx], entity->y[geometry_idx]};
float end[2] = {entity->x[geometry_idx + 1], entity->y[geometry_idx + 1]};
// Mark this lane as checked
if (num_checked_lanes < MAX_CHECKED_LANES) {
checked_lanes[num_checked_lanes++] = entity_idx;
}

float dist = point_to_segment_distance_2d(agent->sim_x, agent->sim_y, start[0], start[1], end[0], end[1]);
float heading_diff = fabsf(atan2f(end[1] - start[1], end[0] - start[0]) - agent->sim_heading);
// Find closest segment on this lane (returns signed distance)
int closest_segment_idx;
float signed_dist = find_closest_segment_on_lane(entity, agent->sim_x, agent->sim_y, &closest_segment_idx);
float abs_dist = fabsf(signed_dist);

// Normalize heading difference to [0, pi]
if (heading_diff > M_PI)
heading_diff = 2.0f * M_PI - heading_diff;
if (abs_dist > max_distance_threshold)
continue; // Skip this lane, too far away

// Penalize if heading differs by more than 30 degrees
if (heading_diff > (M_PI / 6.0f))
dist += 3.0f;
// Compute lane heading using multi-segment alignment
float avg_lane_heading = compute_multi_segment_alignment(entity, closest_segment_idx);

if (dist < min_distance) {
min_distance = dist;
closest_lane_entity_idx = entity_idx;
closest_lane_geometry_idx = geometry_idx;
// Compute heading alignment penalty (0.0 = perfect, 1.0 = opposite)
float heading_diff = compute_heading_diff(agent->sim_heading, avg_lane_heading);
float heading_penalty = fabsf(heading_diff) / M_PI; // Normalize to [0, 1]

// Normalize distance for scoring
float distance_penalty = abs_dist / LANE_DISTANCE_NORMALIZATION;

// Combined score using defined weights
float score =
LANE_SELECTION_DISTANCE_WEIGHT * distance_penalty + LANE_SELECTION_HEADING_WEIGHT * heading_penalty;

// Hysteresis: penalize switching away from current lane
if (agent->current_lane_index != entity_idx && agent->current_lane_index != -1) {
score += LANE_SWITCH_THRESHOLD;
}

// Track best candidate
if (score < best_score) {
min_distance = abs_dist;
best_score = score;
best_candidate_entity_idx = entity_idx;
best_candidate_geometry_idx = closest_segment_idx;
best_candidate_signed_lane_distance = signed_dist;
best_candidate_lane_heading = avg_lane_heading;
}
}
}

// Commit the best-scoring lane (if any) and record its lane distance/angle metrics
if (best_candidate_entity_idx != -1) {
agent->current_lane_index = best_candidate_entity_idx;
agent->current_lane_geometry_idx = best_candidate_geometry_idx;

// Lane distance and angle metrics (GIGAFLOW Frenet coordinates)
// x_f = lateral offset from lane center (left = negative, right = positive)
agent->metrics_array[LANE_DIST_IDX] = best_candidate_signed_lane_distance;
// theta_f = angle relative to lane heading
float theta_f = compute_heading_diff(agent->sim_heading, best_candidate_lane_heading);
agent->metrics_array[LANE_ANGLE_IDX] = cosf(theta_f); // Store cos(θ_f)
} else {
// Agent not on any lane - use "bad" values to indicate offroad state
agent->current_lane_index = -1;
agent->current_lane_geometry_idx = -1;
agent->metrics_array[LANE_DIST_IDX] = LANE_DISTANCE_NORMALIZATION; // Max distance (far from lane)
agent->metrics_array[LANE_ANGLE_IDX] = 0.0f;
}

// check if aligned with closest lane and set current lane
// Agents farther than max_distance_threshold (3x vehicle width) from every candidate lane are treated as not on a lane
if (min_distance > 4.0f || closest_lane_entity_idx == -1) {
if (min_distance > max_distance_threshold || best_candidate_entity_idx == -1) {
agent->metrics_array[LANE_ALIGNED_IDX] = 0.0f;
agent->current_lane_index = -1;
} else {
agent->current_lane_index = closest_lane_entity_idx;
int lane_aligned =
check_lane_aligned(agent, &env->road_elements[closest_lane_entity_idx], closest_lane_geometry_idx);
agent->current_lane_index = best_candidate_entity_idx;
int lane_aligned = (fabs(agent->metrics_array[LANE_ANGLE_IDX]) > 0.965) ? 1 : 0;
agent->metrics_array[LANE_ALIGNED_IDX] = lane_aligned;
}

Expand Down Expand Up @@ -1663,6 +1831,13 @@ void compute_observations(Drive *env) {
float v_dot_heading = ego_entity->sim_vx * cos_heading + ego_entity->sim_vy * sin_heading;
float signed_speed = copysignf(speed_magnitude, v_dot_heading);

// Adding speed limit calculation
float speed_limit = 20.0f;
// TODO: replace this fixed placeholder with the actual speed limit lookup

// Adding lane angle and center information
float lane_center_dist = ego_entity->metrics_array[LANE_DIST_IDX] / LANE_DISTANCE_NORMALIZATION;
lane_center_dist = fmaxf(-1.0f, fminf(1.0f, lane_center_dist));
// Set goal distances
float goal_x = ego_entity->goal_position_x - ego_entity->sim_x;
float goal_y = ego_entity->goal_position_y - ego_entity->sim_y;
Expand All @@ -1688,8 +1863,14 @@ void compute_observations(Drive *env) {
(ego_entity->a_long < 0) ? ego_entity->a_long / (-JERK_LONG[0]) : ego_entity->a_long / JERK_LONG[3];
obs[9] = ego_entity->a_lat / JERK_LAT[2];
obs[10] = (ego_entity->respawn_timestep != -1) ? 1 : 0;
obs[11] = fminf(speed_limit / MAX_SPEED, 1.0f);
obs[12] = lane_center_dist;
obs[13] = ego_entity->metrics_array[LANE_ANGLE_IDX];
} else {
obs[7] = (ego_entity->respawn_timestep != -1) ? 1 : 0;
obs[8] = fminf(speed_limit / MAX_SPEED, 1.0f);
obs[9] = lane_center_dist;
obs[10] = ego_entity->metrics_array[LANE_ANGLE_IDX];
}

// Relative Pos of other cars
Expand Down Expand Up @@ -1836,6 +2017,8 @@ void respawn_agent(Drive *env, int agent_idx) {
agent->metrics_array[OFFROAD_IDX] = 0.0f;
agent->metrics_array[REACHED_GOAL_IDX] = 0.0f;
agent->metrics_array[LANE_ALIGNED_IDX] = 0.0f;
agent->metrics_array[LANE_ANGLE_IDX] = 0.0f; // lane angle
agent->metrics_array[LANE_DIST_IDX] = 0.0f; // distance from lane center

agent->respawn_timestep = env->timestep;
agent->collided_before_goal = 0;
Expand Down Expand Up @@ -2106,6 +2289,8 @@ void c_reset(Drive *env) {
agent->metrics_array[OFFROAD_IDX] = 0.0f;
agent->metrics_array[REACHED_GOAL_IDX] = 0.0f;
agent->metrics_array[LANE_ALIGNED_IDX] = 0.0f;
agent->metrics_array[LANE_ANGLE_IDX] = 0.0f; // lane angle
agent->metrics_array[LANE_DIST_IDX] = 0.0f; // distance from lane center
agent->stopped = 0;
agent->removed = 0;

Expand Down
2 changes: 1 addition & 1 deletion pufferlib/ocean/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def __init__(self, env, input_size=128, hidden_size=128, **kwargs):
self.road_features_after_onehot = env.road_features + 6 # 6 is the number of one-hot encoded categories

# Determine ego dimension from environment's dynamics model
self.ego_dim = 11 if env.dynamics_model == "jerk" else 8
self.ego_dim = env.ego_features

self.ego_encoder = nn.Sequential(
pufferlib.pytorch.layer_init(nn.Linear(self.ego_dim, input_size)),
Expand Down