Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions rust/sedona-spatial-join/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
//! can produce `SpatialJoinExec`.

use datafusion::execution::SessionStateBuilder;
use datafusion_common::Result;

mod logical_plan_node;
mod optimizer;
Expand All @@ -34,10 +35,12 @@ mod spatial_expr_utils;
/// implementation provided by this crate and ensures joins created by SQL or using
/// a DataFrame API that meet certain conditions (e.g. contain a spatial predicate as
/// a join condition) are executed using the `SpatialJoinExec`.
pub fn register_planner(state_builder: SessionStateBuilder) -> SessionStateBuilder {
pub fn register_planner(state_builder: SessionStateBuilder) -> Result<SessionStateBuilder> {
// Enable the logical rewrite that turns Filter(CrossJoin) into Join(filter=...)
let state_builder = optimizer::register_spatial_join_logical_optimizer(state_builder);
let state_builder = optimizer::register_spatial_join_logical_optimizer(state_builder)?;

// Enable planning SpatialJoinExec via an extension node during logical->physical planning.
physical_planner::register_spatial_join_planner(state_builder)
Ok(physical_planner::register_spatial_join_planner(
state_builder,
))
}
10 changes: 10 additions & 0 deletions rust/sedona-spatial-join/src/planner/logical_plan_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,16 @@ impl UserDefinedLogicalNodeCore for SpatialJoinPlanNode {
)
}

fn necessary_children_exprs(&self, _output_columns: &[usize]) -> Option<Vec<Vec<usize>>> {
    // Always ask both children for every one of their columns, regardless of which output
    // columns were requested. Returning `None` (the default implementation) should be
    // equivalent, but explicit column indices are needed to work around a bug in DataFusion's
    // handling of `None` projection indices in the FFI table provider.
    // See https://github.com/apache/datafusion/pull/20393
    let per_child_columns = [&self.left, &self.right]
        .iter()
        .map(|child| (0..child.schema().fields().len()).collect::<Vec<usize>>())
        .collect();
    Some(per_child_columns)
}
Comment on lines 109 to 117
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is for working around a bug in DataFusion. I'll submit a patch later.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


fn with_exprs_and_inputs(
&self,
mut exprs: Vec<Expr>,
Expand Down
227 changes: 175 additions & 52 deletions rust/sedona-spatial-join/src/planner/optimizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,31 +18,127 @@ use std::sync::Arc;

use crate::planner::logical_plan_node::SpatialJoinPlanNode;
use crate::planner::spatial_expr_utils::collect_spatial_predicate_names;
use crate::planner::spatial_expr_utils::is_spatial_predicate;
use datafusion::execution::session_state::SessionStateBuilder;
use datafusion::optimizer::{ApplyOrder, OptimizerConfig, OptimizerRule};
use datafusion::optimizer::{ApplyOrder, Optimizer, OptimizerConfig, OptimizerRule};
use datafusion_common::tree_node::Transformed;
use datafusion_common::NullEquality;
use datafusion_common::Result;
use datafusion_expr::logical_plan::Extension;
use datafusion_expr::{BinaryExpr, Expr, Operator};
use datafusion_expr::{Filter, Join, JoinType, LogicalPlan};
use sedona_common::option::SedonaOptions;
use sedona_common::{sedona_internal_datafusion_err, sedona_internal_err};

/// Register only the logical spatial join optimizer rule.
/// Register the logical spatial join optimizer rules.
///
/// This enables building `Join(filter=...)` from patterns like `Filter(CrossJoin)`.
/// It intentionally does not register any physical plan rewrite rules.
/// This inserts rules at specific positions relative to DataFusion's built-in `PushDownFilter`
/// rule to ensure correct semantics for KNN joins:
///
/// - `MergeSpatialFilterIntoJoin` and `KnnJoinEarlyRewrite` are inserted *before*
/// `PushDownFilter` so that KNN joins are converted to `SpatialJoinPlanNode` extension nodes
/// before filter pushdown runs. Extension nodes naturally block filter pushdown via
/// `prevent_predicate_push_down_columns()`, preventing incorrect pushdown to the build side
/// of KNN joins.
///
/// - `SpatialJoinLogicalRewrite` is appended at the end so that non-KNN spatial joins still
/// benefit from filter pushdown before being converted to extension nodes.
pub(crate) fn register_spatial_join_logical_optimizer(
session_state_builder: SessionStateBuilder,
) -> SessionStateBuilder {
session_state_builder
.with_optimizer_rule(Arc::new(MergeSpatialProjectionIntoJoin))
.with_optimizer_rule(Arc::new(SpatialJoinLogicalRewrite))
mut session_state_builder: SessionStateBuilder,
) -> Result<SessionStateBuilder> {
let optimizer = session_state_builder
.optimizer()
.get_or_insert_with(Optimizer::new);

// Find PushDownFilter position by name
let push_down_pos = optimizer
.rules
.iter()
.position(|r| r.name() == "push_down_filter")
.ok_or_else(|| {
sedona_internal_datafusion_err!(
"PushDownFilter rule not found in default optimizer rules"
)
})?;

// Insert KNN-specific rules BEFORE PushDownFilter.
// MergeSpatialFilterIntoJoin must come first because it creates the Join(filter=...)
// nodes that KnnJoinEarlyRewrite then converts to SpatialJoinPlanNode.
optimizer
.rules
.insert(push_down_pos, Arc::new(KnnJoinEarlyRewrite));
optimizer
.rules
.insert(push_down_pos, Arc::new(MergeSpatialFilterIntoJoin));

// Append SpatialJoinLogicalRewrite at the end so non-KNN joins benefit from filter pushdown.
optimizer.rules.push(Arc::new(SpatialJoinLogicalRewrite));

Ok(session_state_builder)
}
/// Logical optimizer rule that enables spatial join planning.

/// Early optimizer rule that converts KNN joins to `SpatialJoinPlanNode` extension nodes
/// *before* DataFusion's `PushDownFilter` runs.
///
/// This rule turns `Join(filter=...)` nodes whose filter contains `ST_KNN` into a
/// `SpatialJoinPlanNode` extension. This prevents `PushDownFilter` from pushing filters to the
/// build (object) side of KNN joins, which would change which objects are the K nearest
/// neighbors and produce incorrect results.
///
/// Extension nodes naturally block filter pushdown because their default
/// `prevent_predicate_push_down_columns()` returns all columns.
///
/// Covers the two patterns that DataFusion's SQL planner creates:
///
/// 1. `Join(filter=ST_KNN(...))` — when the ON clause has only the spatial predicate. This rule
///    matches that shape directly.
/// 2. `Filter(ST_KNN(...), Join(on=[...]))` — when the ON clause also has equi-join conditions,
///    the SQL planner separates equi-conditions into `on` and the spatial predicate into a Filter.
///    This rule does not match `Filter` nodes itself; it relies on `MergeSpatialFilterIntoJoin`
///    having run first to fold the predicate into the join filter.
#[derive(Default, Debug)]
struct KnnJoinEarlyRewrite;

impl OptimizerRule for KnnJoinEarlyRewrite {
    /// Name under which this rule is registered with the optimizer.
    fn name(&self) -> &str {
        "knn_join_early_rewrite"
    }

    /// The rule is applied bottom-up over the plan tree.
    fn apply_order(&self) -> Option<ApplyOrder> {
        Some(ApplyOrder::BottomUp)
    }

    /// This rule implements `rewrite`.
    fn supports_rewrite(&self) -> bool {
        true
    }

    /// Replaces a `Join` whose filter mentions `st_knn` with a `SpatialJoinPlanNode` extension.
    ///
    /// The plan is returned untouched when `SedonaOptions` is not registered, when the spatial
    /// join is disabled, or when the node is not a join carrying an `st_knn` filter.
    fn rewrite(
        &self,
        plan: LogicalPlan,
        config: &dyn OptimizerConfig,
    ) -> Result<Transformed<LogicalPlan>> {
        let spatial_join_enabled = config
            .options()
            .extensions
            .get::<SedonaOptions>()
            .is_some_and(|sedona| sedona.spatial_join.enable);
        if !spatial_join_enabled {
            return Ok(Transformed::no(plan));
        }

        match &plan {
            // Join(filter=ST_KNN(...))
            LogicalPlan::Join(knn_join)
                if knn_join.filter.as_ref().is_some_and(|predicate| {
                    collect_spatial_predicate_names(predicate).contains("st_knn")
                }) =>
            {
                rewrite_join_to_spatial_join_plan_node(knn_join)
            }
            _ => Ok(Transformed::no(plan)),
        }
    }
}

/// Logical optimizer rule that converts non-KNN spatial joins to `SpatialJoinPlanNode`.
///
/// This rule runs *after* `PushDownFilter` so that non-KNN spatial joins benefit from
/// filter pushdown before being converted to extension nodes.
///
/// KNN joins are not rewritten here because they are expected to have been handled already by
/// [`KnnJoinEarlyRewrite`]; finding an `ST_KNN` predicate in this rule is reported as an
/// internal error.
#[derive(Default, Debug)]
struct SpatialJoinLogicalRewrite;

Expand Down Expand Up @@ -86,54 +182,74 @@ impl OptimizerRule for SpatialJoinLogicalRewrite {
return Ok(Transformed::no(plan));
}

// Join with both an equi-join condition and a spatial join condition. Only handle it
// when the join condition contains ST_KNN. KNN join is not a regular join and
// ST_KNN is also not a regular predicate. It must be handled by our spatial join exec.
if !join.on.is_empty() && !spatial_predicate_names.contains("st_knn") {
return Ok(Transformed::no(plan));
if spatial_predicate_names.contains("st_knn") {
// KNN joins should have already been rewritten by KnnJoinEarlyRewrite, so we shouldn't
// see them here.
return sedona_internal_err!(
"Found KNN predicate in SpatialJoinLogicalRewrite, which should have been handled by KnnJoinEarlyRewrite");
}

// Build new filter expression including equi-join conditions
let filter = filter.clone();
let eq_op = if join.null_equality == NullEquality::NullEqualsNothing {
Operator::Eq
} else {
Operator::IsNotDistinctFrom
};
let filter = join.on.iter().fold(filter, |acc, (l, r)| {
let eq_expr = Expr::BinaryExpr(BinaryExpr::new(
Box::new(l.clone()),
eq_op,
Box::new(r.clone()),
));
Expr::and(acc, eq_expr)
});

let schema = Arc::clone(&join.schema);
let node = SpatialJoinPlanNode {
left: join.left.as_ref().clone(),
right: join.right.as_ref().clone(),
join_type: join.join_type,
filter,
schema,
join_constraint: join.join_constraint,
null_equality: join.null_equality,
};
// A join with an equi-join condition should be planned as a regular HashJoin
// or SortMergeJoin.
if !join.on.is_empty() {
return Ok(Transformed::no(plan));
}

Ok(Transformed::yes(LogicalPlan::Extension(Extension {
node: Arc::new(node),
})))
rewrite_join_to_spatial_join_plan_node(join)
}
}

/// Builds a `SpatialJoinPlanNode` extension from a `Join` whose `filter` holds the spatial
/// predicate.
///
/// Every equi-join pair in `join.on` is AND-ed onto the filter, using `=` when the join's null
/// equality is `NullEqualsNothing` and `IS NOT DISTINCT FROM` otherwise, so the resulting node
/// carries the whole join condition in a single expression.
///
/// # Errors
///
/// Returns an internal error when `join.filter` is `None`.
fn rewrite_join_to_spatial_join_plan_node(join: &Join) -> Result<Transformed<LogicalPlan>> {
    let Some(spatial_filter) = join.filter.as_ref() else {
        return Err(datafusion_common::DataFusionError::Internal(
            "join filter must be present for spatial join rewrite".to_string(),
        ));
    };

    let key_comparison = match join.null_equality {
        NullEquality::NullEqualsNothing => Operator::Eq,
        _ => Operator::IsNotDistinctFrom,
    };

    // Start from the spatial predicate and append one comparison per equi-join key pair.
    let mut combined_filter = spatial_filter.clone();
    for (left_key, right_key) in &join.on {
        let key_equality = Expr::BinaryExpr(BinaryExpr::new(
            Box::new(left_key.clone()),
            key_comparison,
            Box::new(right_key.clone()),
        ));
        combined_filter = Expr::and(combined_filter, key_equality);
    }

    let spatial_join = SpatialJoinPlanNode {
        left: join.left.as_ref().clone(),
        right: join.right.as_ref().clone(),
        join_type: join.join_type,
        filter: combined_filter,
        schema: Arc::clone(&join.schema),
        join_constraint: join.join_constraint,
        null_equality: join.null_equality,
    };

    let extension = Extension {
        node: Arc::new(spatial_join),
    };
    Ok(Transformed::yes(LogicalPlan::Extension(extension)))
}

/// Logical optimizer rule that enables spatial join planning.
///
/// This rule turns eligible `Filter(Join(filter=...))` nodes into a `Join(filter=...)` node,
/// so that the spatial join can be rewritten later by [`KnnJoinEarlyRewrite`] (KNN joins) or
/// [`SpatialJoinLogicalRewrite`] (all other spatial joins).
#[derive(Debug, Default)]
struct MergeSpatialProjectionIntoJoin;
struct MergeSpatialFilterIntoJoin;

impl OptimizerRule for MergeSpatialProjectionIntoJoin {
impl OptimizerRule for MergeSpatialFilterIntoJoin {
fn name(&self) -> &str {
"merge_spatial_filter_into_join"
}
Expand Down Expand Up @@ -188,7 +304,9 @@ impl OptimizerRule for MergeSpatialProjectionIntoJoin {
else {
return Ok(Transformed::no(plan));
};
if !is_spatial_predicate(predicate) {

let spatial_predicates = collect_spatial_predicate_names(predicate);
if spatial_predicates.is_empty() {
return Ok(Transformed::no(plan));
}

Expand All @@ -207,20 +325,25 @@ impl OptimizerRule for MergeSpatialProjectionIntoJoin {
};

// Check if this is a suitable join for rewriting
let is_equi_join = !on.is_empty() && !spatial_predicates.contains("st_knn");
if !matches!(
join_type,
JoinType::Inner | JoinType::Left | JoinType::Right
) || !on.is_empty()
|| filter.is_some()
) || is_equi_join
{
return Ok(Transformed::no(plan));
}

let new_filter = match filter {
Some(existing_filter) => Expr::and(predicate.clone(), existing_filter.clone()),
None => predicate.clone(),
};

let rewritten_plan = Join::try_new(
Arc::clone(left),
Arc::clone(right),
on.clone(),
Some(predicate.clone()),
Some(new_filter),
JoinType::Inner,
*join_constraint,
*null_equality,
Expand Down
Loading