From 39ce55ea199d38789337d508b40d39f563ad41b3 Mon Sep 17 00:00:00 2001 From: Marco Neumann Date: Tue, 24 Jun 2025 10:06:58 +0200 Subject: [PATCH 01/13] ci: test CI --- ci.test | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 ci.test diff --git a/ci.test b/ci.test new file mode 100644 index 000000000000..e69de29bb2d1 From 365aa9e7cb5c9d885aebb7d911cc1f181c99ec20 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sun, 31 Aug 2025 22:08:12 +0800 Subject: [PATCH 02/13] chore(deps): bump tracing-subscriber from 0.3.19 to 0.3.20 (#17355) Bumps [tracing-subscriber](https://github.com/tokio-rs/tracing) from 0.3.19 to 0.3.20. - [Release notes](https://github.com/tokio-rs/tracing/releases) - [Commits](https://github.com/tokio-rs/tracing/compare/tracing-subscriber-0.3.19...tracing-subscriber-0.3.20) --- updated-dependencies: - dependency-name: tracing-subscriber dependency-version: 0.3.20 dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- Cargo.lock | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 8ffdb8c6403c..1012587c0c27 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4235,12 +4235,11 @@ dependencies = [ [[package]] name = "nu-ansi-term" -version = "0.46.0" +version = "0.50.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84" +checksum = "d4a28e057d01f97e61255210fcff094d74ed0466038633e95017f5beb68e4399" dependencies = [ - "overload", - "winapi", + "windows-sys 0.52.0", ] [[package]] @@ -4434,12 +4433,6 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1a80800c0488c3a21695ea981a54918fbb37abf04f4d0720c453632255e2ff0e" -[[package]] -name = "overload" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" - [[package]] name = "owo-colors" version = "4.2.1" @@ -6675,9 +6668,9 @@ dependencies = [ [[package]] name = "tracing-subscriber" -version = "0.3.19" +version = "0.3.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8189decb5ac0fa7bc8b96b7cb9b2701d60d48805aca84a238004d665fcc4008" +checksum = "2054a14f5307d601f88daf0553e1cbf472acc4f2c51afab632431cdcd72124d5" dependencies = [ "nu-ansi-term", "sharded-slab", From 7ff9daaa684c085189782a4a34c4c0bfd82655cb Mon Sep 17 00:00:00 2001 From: wiedld Date: Tue, 15 Oct 2024 12:51:48 -0700 Subject: [PATCH 03/13] chore: default=true for skip_physical_aggregate_schema_check, and add warn logging --- datafusion/common/src/config.rs | 2 +- datafusion/core/src/physical_planner.rs | 3 +++ datafusion/sqllogictest/test_files/information_schema.slt | 4 ++-- docs/source/user-guide/configs.md | 2 +- 4 files changed, 7 insertions(+), 4 deletions(-) diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs index 31159d4a8588..6758ed4799d5 100644 --- a/datafusion/common/src/config.rs +++ b/datafusion/common/src/config.rs @@ -389,7 +389,7 @@ config_namespace! { /// /// This is used to workaround bugs in the planner that are now caught by /// the new schema verification step. 
- pub skip_physical_aggregate_schema_check: bool, default = false + pub skip_physical_aggregate_schema_check: bool, default = true /// Sets the compression codec used when spilling data to disk. /// diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index ab123dcceada..b355f5f894c7 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -712,6 +712,9 @@ impl DefaultPhysicalPlanner { differences.push(format!("field nullability at index {} [{}]: (physical) {} vs (logical) {}", i, physical_field.name(), physical_field.is_nullable(), logical_field.is_nullable())); } } + + log::warn!("Physical input schema should be the same as the one converted from logical input schema, but did not match for logical plan:\n{}", input.display_indent()); + return internal_err!("Physical input schema should be the same as the one converted from logical input schema. Differences: {}", differences .iter() .map(|s| format!("\n\t- {s}")) diff --git a/datafusion/sqllogictest/test_files/information_schema.slt b/datafusion/sqllogictest/test_files/information_schema.slt index f76e436e0ad3..0f1fb892c746 100644 --- a/datafusion/sqllogictest/test_files/information_schema.slt +++ b/datafusion/sqllogictest/test_files/information_schema.slt @@ -259,7 +259,7 @@ datafusion.execution.parquet.writer_version 1.0 datafusion.execution.planning_concurrency 13 datafusion.execution.skip_partial_aggregation_probe_ratio_threshold 0.8 datafusion.execution.skip_partial_aggregation_probe_rows_threshold 100000 -datafusion.execution.skip_physical_aggregate_schema_check false +datafusion.execution.skip_physical_aggregate_schema_check true datafusion.execution.soft_max_rows_per_output_file 50000000 datafusion.execution.sort_in_place_threshold_bytes 1048576 datafusion.execution.sort_spill_reservation_bytes 10485760 @@ -371,7 +371,7 @@ datafusion.execution.parquet.writer_version 1.0 (writing) Sets parquet writer ve datafusion.execution.planning_concurrency 13 Fan-out during initial physical planning. This is mostly use to plan `UNION` children in parallel. Defaults to the number of CPU cores on the system datafusion.execution.skip_partial_aggregation_probe_ratio_threshold 0.8 Aggregation ratio (number of distinct groups / number of input rows) threshold for skipping partial aggregation. If the value is greater then partial aggregation will skip aggregation for further input datafusion.execution.skip_partial_aggregation_probe_rows_threshold 100000 Number of input rows partial aggregation partition should process, before aggregation ratio check and trying to switch to skipping aggregation mode -datafusion.execution.skip_physical_aggregate_schema_check false When set to true, skips verifying that the schema produced by planning the input of `LogicalPlan::Aggregate` exactly matches the schema of the input plan. When set to false, if the schema does not match exactly (including nullability and metadata), a planning error will be raised. This is used to workaround bugs in the planner that are now caught by the new schema verification step. +datafusion.execution.skip_physical_aggregate_schema_check true When set to true, skips verifying that the schema produced by planning the input of `LogicalPlan::Aggregate` exactly matches the schema of the input plan. When set to false, if the schema does not match exactly (including nullability and metadata), a planning error will be raised. 
This is used to workaround bugs in the planner that are now caught by the new schema verification step. datafusion.execution.soft_max_rows_per_output_file 50000000 Target number of rows in output files when writing multiple. This is a soft max, so it can be exceeded slightly. There also will be one file smaller than the limit if the total number of rows written is not roughly divisible by the soft max datafusion.execution.sort_in_place_threshold_bytes 1048576 When sorting, below what size should data be concatenated and sorted in a single RecordBatch rather than sorted in batches and merged. datafusion.execution.sort_spill_reservation_bytes 10485760 Specifies the reserved memory for each spillable sort operation to facilitate an in-memory merge. When a sort operation spills to disk, the in-memory data must be sorted and merged before being written to a file. This setting reserves a specific amount of memory for that in-memory sort/merge process. Note: This setting is irrelevant if the sort operation cannot spill (i.e., if there's no `DiskManager` configured). diff --git a/docs/source/user-guide/configs.md b/docs/source/user-guide/configs.md index d453cb0684da..62be57ec5b85 100644 --- a/docs/source/user-guide/configs.md +++ b/docs/source/user-guide/configs.md @@ -82,7 +82,7 @@ Environment variables are read during `SessionConfig` initialisation so they mus | datafusion.execution.parquet.maximum_parallel_row_group_writers | 1 | (writing) By default parallel parquet writer is tuned for minimum memory usage in a streaming execution plan. You may see a performance benefit when writing large parquet files by increasing maximum_parallel_row_group_writers and maximum_buffered_record_batches_per_stream if your system has idle cores and can tolerate additional memory usage. Boosting these values is likely worthwhile when writing out already in-memory data, such as from a cached data frame. | | datafusion.execution.parquet.maximum_buffered_record_batches_per_stream | 2 | (writing) By default parallel parquet writer is tuned for minimum memory usage in a streaming execution plan. You may see a performance benefit when writing large parquet files by increasing maximum_parallel_row_group_writers and maximum_buffered_record_batches_per_stream if your system has idle cores and can tolerate additional memory usage. Boosting these values is likely worthwhile when writing out already in-memory data, such as from a cached data frame. | | datafusion.execution.planning_concurrency | 0 | Fan-out during initial physical planning. This is mostly use to plan `UNION` children in parallel. Defaults to the number of CPU cores on the system | -| datafusion.execution.skip_physical_aggregate_schema_check | false | When set to true, skips verifying that the schema produced by planning the input of `LogicalPlan::Aggregate` exactly matches the schema of the input plan. When set to false, if the schema does not match exactly (including nullability and metadata), a planning error will be raised. This is used to workaround bugs in the planner that are now caught by the new schema verification step. | +| datafusion.execution.skip_physical_aggregate_schema_check | true | When set to true, skips verifying that the schema produced by planning the input of `LogicalPlan::Aggregate` exactly matches the schema of the input plan. When set to false, if the schema does not match exactly (including nullability and metadata), a planning error will be raised. 
This is used to workaround bugs in the planner that are now caught by the new schema verification step. | | datafusion.execution.spill_compression | uncompressed | Sets the compression codec used when spilling data to disk. Since datafusion writes spill files using the Arrow IPC Stream format, only codecs supported by the Arrow IPC Stream Writer are allowed. Valid values are: uncompressed, lz4_frame, zstd. Note: lz4_frame offers faster (de)compression, but typically results in larger spill files. In contrast, zstd achieves higher compression ratios at the cost of slower (de)compression speed. | | datafusion.execution.sort_spill_reservation_bytes | 10485760 | Specifies the reserved memory for each spillable sort operation to facilitate an in-memory merge. When a sort operation spills to disk, the in-memory data must be sorted and merged before being written to a file. This setting reserves a specific amount of memory for that in-memory sort/merge process. Note: This setting is irrelevant if the sort operation cannot spill (i.e., if there's no `DiskManager` configured). | | datafusion.execution.sort_in_place_threshold_bytes | 1048576 | When sorting, below what size should data be concatenated and sorted in a single RecordBatch rather than sorted in batches and merged. | From 08f5e76b071bfe21cdd0cc8102de94be5653d024 Mon Sep 17 00:00:00 2001 From: Marco Neumann Date: Tue, 1 Jul 2025 16:43:37 +0200 Subject: [PATCH 04/13] fix: temporary fix to handle incorrect coalesce (inserted during EnforceDistribution) which later causes an error during EnforceSort (without our patch). The next DataFusion version 46 upgrade does the proper fix, which is to not insert the coalesce in the first place. test: recreating the iox plan: * demonstrate the insertion of coalesce after the use of column estimates, and the removal of the test scenario's forcing of rr repartitioning test: reproducer of SanityCheck failure after EnforceSorting removes the coalesce added in the EnforceDistribution fix: special case to not remove the needed coalesce --- .../enforce_distribution.rs | 332 +++++++++++++++++- .../physical_optimizer/enforce_sorting.rs | 99 +++++- .../tests/physical_optimizer/test_utils.rs | 7 + .../src/enforce_sorting/mod.rs | 6 +- datafusion/physical-optimizer/src/utils.rs | 6 + 5 files changed, 432 insertions(+), 18 deletions(-) diff --git a/datafusion/core/tests/physical_optimizer/enforce_distribution.rs b/datafusion/core/tests/physical_optimizer/enforce_distribution.rs index fd847763124a..2dce87de00ed 100644 --- a/datafusion/core/tests/physical_optimizer/enforce_distribution.rs +++ b/datafusion/core/tests/physical_optimizer/enforce_distribution.rs @@ -23,7 +23,7 @@ use crate::physical_optimizer::test_utils::{ check_integrity, coalesce_partitions_exec, parquet_exec_with_sort, parquet_exec_with_stats, repartition_exec, schema, sort_exec, sort_exec_with_preserve_partitioning, sort_merge_join_exec, - sort_preserving_merge_exec, union_exec, + sort_preserving_merge_exec, trim_plan_display, union_exec, }; use arrow::array::{RecordBatch, UInt64Array, UInt8Array}; @@ -39,10 +39,12 @@ use datafusion::datasource::MemTable; use datafusion::prelude::{SessionConfig, SessionContext}; use datafusion_common::error::Result; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_common::ScalarValue; +use datafusion_common::{assert_contains, ScalarValue}; use datafusion_datasource::file_groups::FileGroup; use datafusion_datasource::file_scan_config::FileScanConfigBuilder; -use 
datafusion_expr::{JoinType, Operator}; +use datafusion_execution::{SendableRecordBatchStream, TaskContext}; +use datafusion_expr::{AggregateUDF, JoinType, Operator}; +use datafusion_physical_expr::aggregate::AggregateExprBuilder; use datafusion_physical_expr::expressions::{binary, lit, BinaryExpr, Column, Literal}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_expr_common::sort_expr::{ @@ -51,6 +53,7 @@ use datafusion_physical_expr_common::sort_expr::{ use datafusion_physical_optimizer::enforce_distribution::*; use datafusion_physical_optimizer::enforce_sorting::EnforceSorting; use datafusion_physical_optimizer::output_requirements::OutputRequirements; +use datafusion_physical_optimizer::sanity_checker::check_plan_sanity; use datafusion_physical_optimizer::PhysicalOptimizerRule; use datafusion_physical_plan::aggregates::{ AggregateExec, AggregateMode, PhysicalGroupBy, @@ -66,7 +69,7 @@ use datafusion_physical_plan::projection::ProjectionExec; use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use datafusion_physical_plan::union::UnionExec; use datafusion_physical_plan::{ - get_plan_string, DisplayAs, DisplayFormatType, ExecutionPlanProperties, + displayable, get_plan_string, DisplayAs, DisplayFormatType, ExecutionPlanProperties, PlanProperties, Statistics, }; @@ -162,8 +165,8 @@ impl ExecutionPlan for SortRequiredExec { fn execute( &self, _partition: usize, - _context: Arc, - ) -> Result { + _context: Arc, + ) -> Result { unreachable!(); } @@ -237,7 +240,7 @@ fn csv_exec_multiple_sorted(output_ordering: Vec) -> Arc, alias_pairs: Vec<(String, String)>, ) -> Arc { @@ -251,6 +254,15 @@ fn projection_exec_with_alias( fn aggregate_exec_with_alias( input: Arc, alias_pairs: Vec<(String, String)>, +) -> Arc { + aggregate_exec_with_aggr_expr_and_alias(input, vec![], alias_pairs) +} + +#[expect(clippy::type_complexity)] +fn aggregate_exec_with_aggr_expr_and_alias( + input: Arc, + aggr_expr: Vec<(Arc, Vec>)>, + alias_pairs: Vec<(String, String)>, ) -> Arc { let schema = schema(); let mut group_by_expr: Vec<(Arc, String)> = vec![]; @@ -271,18 +283,31 @@ fn aggregate_exec_with_alias( .collect::>(); let final_grouping = PhysicalGroupBy::new_single(final_group_by_expr); + let aggr_expr = aggr_expr + .into_iter() + .map(|(udaf, exprs)| { + AggregateExprBuilder::new(udaf.clone(), exprs) + .alias(udaf.name()) + .schema(Arc::clone(&schema)) + .build() + .map(Arc::new) + .unwrap() + }) + .collect::>(); + let filter_exprs = std::iter::repeat_n(None, aggr_expr.len()).collect::>(); + Arc::new( AggregateExec::try_new( AggregateMode::FinalPartitioned, final_grouping, - vec![], - vec![], + aggr_expr.clone(), + filter_exprs.clone(), Arc::new( AggregateExec::try_new( AggregateMode::Partial, group_by, - vec![], - vec![], + aggr_expr, + filter_exprs, input, schema.clone(), ) @@ -439,6 +464,12 @@ impl TestConfig { self } + /// Set batch size. + fn with_batch_size(mut self, batch_size: usize) -> Self { + self.config.execution.batch_size = batch_size; + self + } + /// Perform a series of runs using the current [`TestConfig`], /// assert the expected plan result, /// and return the result plan (for potentional subsequent runs). @@ -2027,6 +2058,285 @@ fn repartition_ignores_union() -> Result<()> { Ok(()) } +fn aggregate_over_union(input: Vec>) -> Arc { + let union = union_exec(input); + let plan = + aggregate_exec_with_alias(union, vec![("a".to_string(), "a1".to_string())]); + + // Demonstrate starting plan. 
+ let before = displayable(plan.as_ref()).indent(true).to_string(); + let before = trim_plan_display(&before); + assert_eq!( + before, + vec![ + "AggregateExec: mode=FinalPartitioned, gby=[a1@0 as a1], aggr=[]", + "AggregateExec: mode=Partial, gby=[a@0 as a1], aggr=[]", + "UnionExec", + "DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + "DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + ], + ); + + plan +} + +// Aggregate over a union, +// with current testing setup. +// +// It will repartiton twice for an aggregate over a union. +// * repartitions before the partial aggregate. +// * repartitions before the final aggregation. +#[test] +fn repartitions_twice_for_aggregate_after_union() -> Result<()> { + let plan = aggregate_over_union(vec![parquet_exec(); 2]); + + // We get a distribution error without repartitioning. + let err = check_plan_sanity(plan.clone(), &Default::default()).unwrap_err(); + assert_contains!( + err.message(), + "DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet\"] does not satisfy distribution requirements: HashPartitioned[[a1@0]]). Child-0 output partitioning: UnknownPartitioning(2)" + ); + + // Updated plan (post optimization) will have added RepartitionExecs (btwn union and aggregation). + let expected = &[ + "AggregateExec: mode=FinalPartitioned, gby=[a1@0 as a1], aggr=[]", + " RepartitionExec: partitioning=Hash([a1@0], 10), input_partitions=10", + " AggregateExec: mode=Partial, gby=[a@0 as a1], aggr=[]", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2", + " UnionExec", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + ]; + let test_config = TestConfig::default(); + test_config.run(expected, plan.clone(), &DISTRIB_DISTRIB_SORT)?; + test_config.run(expected, plan, &SORT_DISTRIB_DISTRIB)?; + + Ok(()) +} + +// Aggregate over a union, +// but make the test setup more realistic. +// +// It will repartiton once for an aggregate over a union. +// * repartitions btwn partial & final aggregations. +#[test] +fn repartitions_once_for_aggregate_after_union() -> Result<()> { + // use parquet exec with stats + let plan: Arc = + aggregate_over_union(vec![parquet_exec_with_stats(10000); 2]); + + // We get a distribution error without repartitioning. + let err = check_plan_sanity(plan.clone(), &Default::default()).unwrap_err(); + assert_contains!( + err.message(), + "DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet\"] does not satisfy distribution requirements: HashPartitioned[[a1@0]]). Child-0 output partitioning: UnknownPartitioning(2)" + ); + + // This removes the forced round-robin repartitioning, + // by no longer hard-coding batch_size=1. + // + // Updated plan (post optimization) will have added only 1 RepartitionExec. 
+ let expected = &[ + "AggregateExec: mode=FinalPartitioned, gby=[a1@0 as a1], aggr=[]", + " RepartitionExec: partitioning=Hash([a1@0], 10), input_partitions=10", + " AggregateExec: mode=Partial, gby=[a@0 as a1], aggr=[]", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2", + " UnionExec", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + ]; + let test_config = TestConfig::default().with_batch_size(100); + test_config.run(expected, plan.clone(), &DISTRIB_DISTRIB_SORT)?; + test_config.run(expected, plan, &SORT_DISTRIB_DISTRIB)?; + + Ok(()) +} + +/// Same as [`aggregate_over_union`], but with a sort btwn the union and aggregation. +fn aggregate_over_sorted_union( + input: Vec>, +) -> Arc { + let union = union_exec(input); + let schema = schema(); + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { + expr: col("a", &schema).unwrap(), + options: SortOptions::default(), + }]) + .unwrap(); + let sort = sort_exec(sort_key, union); + let plan = aggregate_exec_with_alias(sort, vec![("a".to_string(), "a1".to_string())]); + + // Demonstrate starting plan. + // Notice the `ordering_mode=Sorted` on the aggregations. + let before = displayable(plan.as_ref()).indent(true).to_string(); + let before = trim_plan_display(&before); + assert_eq!( + before, + vec![ + "AggregateExec: mode=FinalPartitioned, gby=[a1@0 as a1], aggr=[], ordering_mode=Sorted", + "AggregateExec: mode=Partial, gby=[a@0 as a1], aggr=[], ordering_mode=Sorted", + "SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", + "UnionExec", + "DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + "DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + ], + ); + + plan +} + +/// Same as [`repartitions_once_for_aggregate_after_union`], but adds a sort btwn +/// the union and the aggregate. This changes the outcome: +/// +/// * we no longer get a distribution error. +/// * but we still get repartitioning? +#[test] +fn repartitions_for_aggregate_after_sorted_union() -> Result<()> { + let plan = aggregate_over_sorted_union(vec![parquet_exec_with_stats(10000); 2]); + + // With the sort, there is no distribution error. + let checker = check_plan_sanity(plan.clone(), &Default::default()); + assert!(checker.is_ok()); + + // It does not repartition on the first run + let expected_after_first_run = &[ + "AggregateExec: mode=FinalPartitioned, gby=[a1@0 as a1], aggr=[], ordering_mode=Sorted", + " SortExec: expr=[a1@0 ASC NULLS LAST], preserve_partitioning=[true]", + " RepartitionExec: partitioning=Hash([a1@0], 10), input_partitions=10", + " AggregateExec: mode=Partial, gby=[a@0 as a1], aggr=[], ordering_mode=Sorted", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + " SortPreservingMergeExec: [a@0 ASC]", + " UnionExec", + " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + ]; + let test_config = TestConfig::default().with_batch_size(100); + test_config.run( + expected_after_first_run, + plan.clone(), + &DISTRIB_DISTRIB_SORT, + )?; + + // But does repartition on the second run. 
+ let expected_after_second_run = &[ + "AggregateExec: mode=FinalPartitioned, gby=[a1@0 as a1], aggr=[], ordering_mode=Sorted", + " SortExec: expr=[a1@0 ASC NULLS LAST], preserve_partitioning=[true]", + " RepartitionExec: partitioning=Hash([a1@0], 10), input_partitions=10", + " SortExec: expr=[a1@0 ASC NULLS LAST], preserve_partitioning=[true]", + " AggregateExec: mode=Partial, gby=[a@0 as a1], aggr=[], ordering_mode=Sorted", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", + " CoalescePartitionsExec", + " UnionExec", + " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + ]; + test_config.run(expected_after_second_run, plan, &SORT_DISTRIB_DISTRIB)?; + + Ok(()) +} + +/// Same as [`aggregate_over_sorted_union`], but with a sort btwn the union and aggregation. +fn aggregate_over_sorted_union_projection( + input: Vec>, +) -> Arc { + let union = union_exec(input); + let union_projection = projection_exec_with_alias( + union, + vec![ + ("a".to_string(), "a".to_string()), + ("b".to_string(), "value".to_string()), + ], + ); + let schema = schema(); + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { + expr: col("a", &schema).unwrap(), + options: SortOptions::default(), + }]) + .unwrap(); + let sort = sort_exec(sort_key, union_projection); + let plan = aggregate_exec_with_alias(sort, vec![("a".to_string(), "a1".to_string())]); + + // Demonstrate starting plan. + // Notice the `ordering_mode=Sorted` on the aggregations. + let before = displayable(plan.as_ref()).indent(true).to_string(); + let before = trim_plan_display(&before); + assert_eq!( + before, + vec![ + "AggregateExec: mode=FinalPartitioned, gby=[a1@0 as a1], aggr=[], ordering_mode=Sorted", + "AggregateExec: mode=Partial, gby=[a@0 as a1], aggr=[], ordering_mode=Sorted", + "SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", + "ProjectionExec: expr=[a@0 as a, b@1 as value]", + "UnionExec", + "DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + "DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + ], + ); + + plan +} + +/// Same as [`repartitions_for_aggregate_after_sorted_union`], but adds a projection +/// as well between the union and aggregate. This change the outcome: +/// +/// * we no longer get repartitioning, and instead get coalescing. +#[test] +fn coalesces_for_aggregate_after_sorted_union_projection() -> Result<()> { + let plan = + aggregate_over_sorted_union_projection(vec![parquet_exec_with_stats(10000); 2]); + + // Same as `repartitions_for_aggregate_after_sorted_union`. No error. + let checker = check_plan_sanity(plan.clone(), &Default::default()); + assert!(checker.is_ok()); + + // It no longer does a repartition on the first run. + // Instead adds a SPM. 
+ let expected_after_first_run = &[ + "AggregateExec: mode=FinalPartitioned, gby=[a1@0 as a1], aggr=[], ordering_mode=Sorted", + " SortExec: expr=[a1@0 ASC NULLS LAST], preserve_partitioning=[true]", + " RepartitionExec: partitioning=Hash([a1@0], 10), input_partitions=10", + " AggregateExec: mode=Partial, gby=[a@0 as a1], aggr=[], ordering_mode=Sorted", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + " SortPreservingMergeExec: [a@0 ASC]", + " SortExec: expr=[a@0 ASC], preserve_partitioning=[true]", + " ProjectionExec: expr=[a@0 as a, b@1 as value]", + " UnionExec", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + ]; + let test_config = TestConfig::default().with_batch_size(100); + test_config.run( + expected_after_first_run, + plan.clone(), + &DISTRIB_DISTRIB_SORT, + )?; + + // Then it removes the SPM, and inserts a coalesace on the second run. + let expected_after_second_run = &[ + "AggregateExec: mode=FinalPartitioned, gby=[a1@0 as a1], aggr=[], ordering_mode=Sorted", + " SortExec: expr=[a1@0 ASC NULLS LAST], preserve_partitioning=[true]", + " RepartitionExec: partitioning=Hash([a1@0], 10), input_partitions=10", + " SortExec: expr=[a1@0 ASC NULLS LAST], preserve_partitioning=[true]", + " AggregateExec: mode=Partial, gby=[a@0 as a1], aggr=[], ordering_mode=Sorted", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", + " CoalescePartitionsExec", + " SortExec: expr=[a@0 ASC], preserve_partitioning=[true]", + " ProjectionExec: expr=[a@0 as a, b@1 as value]", + " UnionExec", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + ]; + test_config.run(expected_after_second_run, plan, &SORT_DISTRIB_DISTRIB)?; + + Ok(()) +} + #[test] fn repartition_through_sort_preserving_merge() -> Result<()> { // sort preserving merge with non-sorted input diff --git a/datafusion/core/tests/physical_optimizer/enforce_sorting.rs b/datafusion/core/tests/physical_optimizer/enforce_sorting.rs index e31a30cc0883..8f3cc3191b17 100644 --- a/datafusion/core/tests/physical_optimizer/enforce_sorting.rs +++ b/datafusion/core/tests/physical_optimizer/enforce_sorting.rs @@ -17,15 +17,16 @@ use std::sync::Arc; +use crate::physical_optimizer::enforce_distribution::projection_exec_with_alias; use crate::physical_optimizer::test_utils::{ aggregate_exec, bounded_window_exec, bounded_window_exec_with_partition, check_integrity, coalesce_batches_exec, coalesce_partitions_exec, create_test_schema, create_test_schema2, create_test_schema3, filter_exec, global_limit_exec, hash_join_exec, local_limit_exec, memory_exec, parquet_exec, parquet_exec_with_sort, - projection_exec, repartition_exec, sort_exec, sort_exec_with_fetch, sort_expr, - sort_expr_options, sort_merge_join_exec, sort_preserving_merge_exec, - sort_preserving_merge_exec_with_fetch, spr_repartition_exec, stream_exec_ordered, - union_exec, RequirementsTestExec, + parquet_exec_with_stats, projection_exec, repartition_exec, schema, sort_exec, + sort_exec_with_fetch, sort_expr, sort_expr_options, sort_merge_join_exec, + sort_preserving_merge_exec, sort_preserving_merge_exec_with_fetch, + spr_repartition_exec, stream_exec_ordered, union_exec, RequirementsTestExec, }; use 
arrow::compute::SortOptions; @@ -47,6 +48,9 @@ use datafusion_physical_expr_common::sort_expr::{ }; use datafusion_physical_expr::{Distribution, Partitioning}; use datafusion_physical_expr::expressions::{col, BinaryExpr, Column, NotExpr}; +use datafusion_physical_optimizer::sanity_checker::SanityCheckPlan; +use datafusion_physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}; +use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use datafusion_physical_plan::repartition::RepartitionExec; use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; @@ -2292,6 +2296,93 @@ async fn test_commutativity() -> Result<()> { Ok(()) } +fn single_partition_aggregate( + input: Arc, + alias_pairs: Vec<(String, String)>, +) -> Arc { + let schema = schema(); + let group_by = alias_pairs + .iter() + .map(|(column, alias)| (col(column, &input.schema()).unwrap(), alias.to_string())) + .collect::>(); + let group_by = PhysicalGroupBy::new_single(group_by); + + Arc::new( + AggregateExec::try_new( + AggregateMode::SinglePartitioned, + group_by, + vec![], + vec![], + input, + schema, + ) + .unwrap(), + ) +} + +#[tokio::test] +async fn test_preserve_needed_coalesce() -> Result<()> { + // Input to EnforceSorting, from our test case. + let plan = projection_exec_with_alias( + union_exec(vec![parquet_exec_with_stats(10000); 2]), + vec![ + ("a".to_string(), "a".to_string()), + ("b".to_string(), "value".to_string()), + ], + ); + let plan = Arc::new(CoalescePartitionsExec::new(plan)); + let schema = schema(); + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { + expr: col("a", &schema).unwrap(), + options: SortOptions::default(), + }]) + .unwrap(); + let plan: Arc = + single_partition_aggregate(plan, vec![("a".to_string(), "a1".to_string())]); + let plan = sort_exec(sort_key, plan); + + // Starting plan: as in our test case. + assert_eq!( + get_plan_string(&plan), + vec![ + "SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", + " AggregateExec: mode=SinglePartitioned, gby=[a@0 as a1], aggr=[]", + " CoalescePartitionsExec", + " ProjectionExec: expr=[a@0 as a, b@1 as value]", + " UnionExec", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + ], + ); + + let checker = SanityCheckPlan::new().optimize(plan.clone(), &Default::default()); + assert!(checker.is_ok()); + + // EnforceSorting will remove the coalesce, and add an SPM further up (above the aggregate). + let optimizer = EnforceSorting::new(); + let optimized = optimizer.optimize(plan, &Default::default())?; + assert_eq!( + get_plan_string(&optimized), + vec![ + "SortPreservingMergeExec: [a@0 ASC]", + " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", + " AggregateExec: mode=SinglePartitioned, gby=[a@0 as a1], aggr=[]", + " CoalescePartitionsExec", + " ProjectionExec: expr=[a@0 as a, b@1 as value]", + " UnionExec", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + ], + ); + + // Plan is valid. 
+ let checker = SanityCheckPlan::new(); + let checker = checker.optimize(optimized, &Default::default()); + assert!(checker.is_ok()); + + Ok(()) +} + #[tokio::test] async fn test_coalesce_propagate() -> Result<()> { let schema = create_test_schema()?; diff --git a/datafusion/core/tests/physical_optimizer/test_utils.rs b/datafusion/core/tests/physical_optimizer/test_utils.rs index 7fb0f795f294..c5b3e2743811 100644 --- a/datafusion/core/tests/physical_optimizer/test_utils.rs +++ b/datafusion/core/tests/physical_optimizer/test_utils.rs @@ -509,6 +509,13 @@ pub fn check_integrity(context: PlanContext) -> Result Vec<&str> { + plan.split('\n') + .map(|s| s.trim()) + .filter(|s| !s.is_empty()) + .collect() +} + // construct a stream partition for test purposes #[derive(Debug)] pub struct TestStreamPartition { diff --git a/datafusion/physical-optimizer/src/enforce_sorting/mod.rs b/datafusion/physical-optimizer/src/enforce_sorting/mod.rs index 8a71b28486a2..dae0edcfb171 100644 --- a/datafusion/physical-optimizer/src/enforce_sorting/mod.rs +++ b/datafusion/physical-optimizer/src/enforce_sorting/mod.rs @@ -48,8 +48,8 @@ use crate::enforce_sorting::sort_pushdown::{ }; use crate::output_requirements::OutputRequirementExec; use crate::utils::{ - add_sort_above, add_sort_above_with_check, is_coalesce_partitions, is_limit, - is_repartition, is_sort, is_sort_preserving_merge, is_union, is_window, + add_sort_above, add_sort_above_with_check, is_aggregation, is_coalesce_partitions, + is_limit, is_repartition, is_sort, is_sort_preserving_merge, is_union, is_window, }; use crate::PhysicalOptimizerRule; @@ -678,7 +678,7 @@ fn remove_bottleneck_in_subplan( ) -> Result { let plan = &requirements.plan; let children = &mut requirements.children; - if is_coalesce_partitions(&children[0].plan) { + if is_coalesce_partitions(&children[0].plan) && !is_aggregation(plan) { // We can safely use the 0th index since we have a `CoalescePartitionsExec`. let mut new_child_node = children[0].children.swap_remove(0); while new_child_node.plan.output_partitioning() == plan.output_partitioning() diff --git a/datafusion/physical-optimizer/src/utils.rs b/datafusion/physical-optimizer/src/utils.rs index 3655e555a744..d3207d4880a7 100644 --- a/datafusion/physical-optimizer/src/utils.rs +++ b/datafusion/physical-optimizer/src/utils.rs @@ -19,6 +19,7 @@ use std::sync::Arc; use datafusion_common::Result; use datafusion_physical_expr::{LexOrdering, LexRequirement}; +use datafusion_physical_plan::aggregates::AggregateExec; use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use datafusion_physical_plan::repartition::RepartitionExec; @@ -113,3 +114,8 @@ pub fn is_repartition(plan: &Arc) -> bool { pub fn is_limit(plan: &Arc) -> bool { plan.as_any().is::() || plan.as_any().is::() } + +/// Checks whether the given operator is a [`AggregateExec`]. 
+pub fn is_aggregation(plan: &Arc) -> bool { + plan.as_any().is::() +} From ba566f7349859d690740b6db7c55b5bfb6865250 Mon Sep 17 00:00:00 2001 From: Christian van der Loo Date: Wed, 23 Jul 2025 10:11:53 -0400 Subject: [PATCH 05/13] fix(build-wasm): put `arrow-ipc/zstd` dep under `compression` feature flag (#16844) --- Cargo.toml | 1 - datafusion/core/Cargo.toml | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 601d11f12dd8..bb28098104df 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -99,7 +99,6 @@ arrow-flight = { version = "55.2.0", features = [ ] } arrow-ipc = { version = "55.2.0", default-features = false, features = [ "lz4", - "zstd", ] } arrow-ord = { version = "55.2.0", default-features = false } arrow-schema = { version = "55.2.0", default-features = false } diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index c4455e271c84..1a6a66923e55 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -47,6 +47,7 @@ compression = [ "bzip2", "flate2", "zstd", + "arrow-ipc/zstd", "datafusion-datasource/compression", ] crypto_expressions = ["datafusion-functions/crypto_expressions"] From fde4c359c4a886d68cb9f491d6076e7fe369af2e Mon Sep 17 00:00:00 2001 From: Liam Bao Date: Wed, 6 Aug 2025 06:00:14 -0400 Subject: [PATCH 06/13] Support `centroids` config for `approx_percentile_cont_with_weight` (#17003) * Support centroids config for `approx_percentile_cont_with_weight` * Match two functions' signature * Update docs * Address comments and unify centroids config --- .../src/approx_percentile_cont.rs | 15 ++- .../src/approx_percentile_cont_with_weight.rs | 111 +++++++++++++----- .../tests/cases/roundtrip_logical_plan.rs | 13 +- .../sqllogictest/test_files/aggregate.slt | 10 ++ docs/source/user-guide/expressions.md | 42 +++---- .../user-guide/sql/aggregate_functions.md | 17 ++- 6 files changed, 151 insertions(+), 57 deletions(-) diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont.rs b/datafusion/functions-aggregate/src/approx_percentile_cont.rs index 55c8c847ad0a..863ee15d89ec 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont.rs @@ -77,8 +77,14 @@ pub fn approx_percentile_cont( #[user_doc( doc_section(label = "Approximate Functions"), description = "Returns the approximate percentile of input values using the t-digest algorithm.", - syntax_example = "approx_percentile_cont(percentile, centroids) WITHIN GROUP (ORDER BY expression)", + syntax_example = "approx_percentile_cont(percentile [, centroids]) WITHIN GROUP (ORDER BY expression)", sql_example = r#"```sql +> SELECT approx_percentile_cont(0.75) WITHIN GROUP (ORDER BY column_name) FROM table_name; ++------------------------------------------------------------------+ +| approx_percentile_cont(0.75) WITHIN GROUP (ORDER BY column_name) | ++------------------------------------------------------------------+ +| 65.0 | ++------------------------------------------------------------------+ > SELECT approx_percentile_cont(0.75, 100) WITHIN GROUP (ORDER BY column_name) FROM table_name; +-----------------------------------------------------------------------+ | approx_percentile_cont(0.75, 100) WITHIN GROUP (ORDER BY column_name) | @@ -313,7 +319,7 @@ impl AggregateUDFImpl for ApproxPercentileCont { } if arg_types.len() == 3 && !arg_types[2].is_integer() { return plan_err!( - "approx_percentile_cont requires integer max_size input types" + "approx_percentile_cont requires 
integer centroids input types" ); } Ok(arg_types[0].clone()) @@ -360,6 +366,11 @@ impl ApproxPercentileAccumulator { } } + // public for approx_percentile_cont_with_weight + pub(crate) fn max_size(&self) -> usize { + self.digest.max_size() + } + // public for approx_percentile_cont_with_weight pub fn merge_digests(&mut self, digests: &[TDigest]) { let digests = digests.iter().chain(std::iter::once(&self.digest)); diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs b/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs index ab847e838869..d30ea624cae9 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs @@ -25,32 +25,53 @@ use arrow::datatypes::FieldRef; use arrow::{array::ArrayRef, datatypes::DataType}; use datafusion_common::ScalarValue; use datafusion_common::{not_impl_err, plan_err, Result}; +use datafusion_expr::expr::{AggregateFunction, Sort}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; -use datafusion_expr::type_coercion::aggregates::NUMERICS; +use datafusion_expr::type_coercion::aggregates::{INTEGERS, NUMERICS}; use datafusion_expr::Volatility::Immutable; use datafusion_expr::{ - Accumulator, AggregateUDFImpl, Documentation, Signature, TypeSignature, -}; -use datafusion_functions_aggregate_common::tdigest::{ - Centroid, TDigest, DEFAULT_MAX_SIZE, + Accumulator, AggregateUDFImpl, Documentation, Expr, Signature, TypeSignature, }; +use datafusion_functions_aggregate_common::tdigest::{Centroid, TDigest}; use datafusion_macros::user_doc; use crate::approx_percentile_cont::{ApproxPercentileAccumulator, ApproxPercentileCont}; -make_udaf_expr_and_func!( +create_func!( ApproxPercentileContWithWeight, - approx_percentile_cont_with_weight, - expression weight percentile, - "Computes the approximate percentile continuous with weight of a set of numbers", approx_percentile_cont_with_weight_udaf ); +/// Computes the approximate percentile continuous with weight of a set of numbers +pub fn approx_percentile_cont_with_weight( + order_by: Sort, + weight: Expr, + percentile: Expr, + centroids: Option, +) -> Expr { + let expr = order_by.expr.clone(); + + let args = if let Some(centroids) = centroids { + vec![expr, weight, percentile, centroids] + } else { + vec![expr, weight, percentile] + }; + + Expr::AggregateFunction(AggregateFunction::new_udf( + approx_percentile_cont_with_weight_udaf(), + args, + false, + None, + vec![order_by], + None, + )) +} + /// APPROX_PERCENTILE_CONT_WITH_WEIGHT aggregate expression #[user_doc( doc_section(label = "Approximate Functions"), description = "Returns the weighted approximate percentile of input values using the t-digest algorithm.", - syntax_example = "approx_percentile_cont_with_weight(weight, percentile) WITHIN GROUP (ORDER BY expression)", + syntax_example = "approx_percentile_cont_with_weight(weight, percentile [, centroids]) WITHIN GROUP (ORDER BY expression)", sql_example = r#"```sql > SELECT approx_percentile_cont_with_weight(weight_column, 0.90) WITHIN GROUP (ORDER BY column_name) FROM table_name; +---------------------------------------------------------------------------------------------+ @@ -58,6 +79,12 @@ make_udaf_expr_and_func!( +---------------------------------------------------------------------------------------------+ | 78.5 | +---------------------------------------------------------------------------------------------+ +> SELECT 
approx_percentile_cont_with_weight(weight_column, 0.90, 100) WITHIN GROUP (ORDER BY column_name) FROM table_name; ++--------------------------------------------------------------------------------------------------+ +| approx_percentile_cont_with_weight(weight_column, 0.90, 100) WITHIN GROUP (ORDER BY column_name) | ++--------------------------------------------------------------------------------------------------+ +| 78.5 | ++--------------------------------------------------------------------------------------------------+ ```"#, standard_argument(name = "expression", prefix = "The"), argument( @@ -67,6 +94,10 @@ make_udaf_expr_and_func!( argument( name = "percentile", description = "Percentile to compute. Must be a float value between 0 and 1 (inclusive)." + ), + argument( + name = "centroids", + description = "Number of centroids to use in the t-digest algorithm. _Default is 100_. A higher number results in more accurate approximation but requires more memory." ) )] pub struct ApproxPercentileContWithWeight { @@ -91,21 +122,26 @@ impl Default for ApproxPercentileContWithWeight { impl ApproxPercentileContWithWeight { /// Create a new [`ApproxPercentileContWithWeight`] aggregate function. pub fn new() -> Self { + let mut variants = Vec::with_capacity(NUMERICS.len() * (INTEGERS.len() + 1)); + // Accept any numeric value paired with weight and float64 percentile + for num in NUMERICS { + variants.push(TypeSignature::Exact(vec![ + num.clone(), + num.clone(), + DataType::Float64, + ])); + // Additionally accept an integer number of centroids for T-Digest + for int in INTEGERS { + variants.push(TypeSignature::Exact(vec![ + num.clone(), + num.clone(), + DataType::Float64, + int.clone(), + ])); + } + } Self { - signature: Signature::one_of( - // Accept any numeric value paired with a float64 percentile - NUMERICS - .iter() - .map(|t| { - TypeSignature::Exact(vec![ - t.clone(), - t.clone(), - DataType::Float64, - ]) - }) - .collect(), - Immutable, - ), + signature: Signature::one_of(variants, Immutable), approx_percentile_cont: ApproxPercentileCont::new(), } } @@ -138,6 +174,11 @@ impl AggregateUDFImpl for ApproxPercentileContWithWeight { if arg_types[2] != DataType::Float64 { return plan_err!("approx_percentile_cont_with_weight requires float64 percentile input types"); } + if arg_types.len() == 4 && !arg_types[3].is_integer() { + return plan_err!( + "approx_percentile_cont_with_weight requires integer centroids input types" + ); + } Ok(arg_types[0].clone()) } @@ -148,17 +189,25 @@ impl AggregateUDFImpl for ApproxPercentileContWithWeight { ); } - if acc_args.exprs.len() != 3 { + if acc_args.exprs.len() != 3 && acc_args.exprs.len() != 4 { return plan_err!( - "approx_percentile_cont_with_weight requires three arguments: value, weight, percentile" + "approx_percentile_cont_with_weight requires three or four arguments: value, weight, percentile[, centroids]" ); } let sub_args = AccumulatorArgs { - exprs: &[ - Arc::clone(&acc_args.exprs[0]), - Arc::clone(&acc_args.exprs[2]), - ], + exprs: if acc_args.exprs.len() == 4 { + &[ + Arc::clone(&acc_args.exprs[0]), // value + Arc::clone(&acc_args.exprs[2]), // percentile + Arc::clone(&acc_args.exprs[3]), // centroids + ] + } else { + &[ + Arc::clone(&acc_args.exprs[0]), // value + Arc::clone(&acc_args.exprs[2]), // percentile + ] + }, ..acc_args }; let approx_percentile_cont_accumulator = @@ -244,7 +293,7 @@ impl Accumulator for ApproxPercentileWithWeightAccumulator { let mut digests: Vec = vec![]; for (mean, weight) in 
means_f64.iter().zip(weights_f64.iter()) { digests.push(TDigest::new_with_centroid( - DEFAULT_MAX_SIZE, + self.approx_percentile_cont_accumulator.max_size(), Centroid::new(*mean, *weight), )) } diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 6c51d553fe16..b56fdc0fede6 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -981,7 +981,18 @@ async fn roundtrip_expr_api() -> Result<()> { approx_median(lit(2)), approx_percentile_cont(lit(2).sort(true, false), lit(0.5), None), approx_percentile_cont(lit(2).sort(true, false), lit(0.5), Some(lit(50))), - approx_percentile_cont_with_weight(lit(2), lit(1), lit(0.5)), + approx_percentile_cont_with_weight( + lit(2).sort(true, false), + lit(1), + lit(0.5), + None, + ), + approx_percentile_cont_with_weight( + lit(2).sort(true, false), + lit(1), + lit(0.5), + Some(lit(50)), + ), grouping(lit(1)), bit_and(lit(2)), bit_or(lit(2)), diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 753820b6b619..122907d83181 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -1790,6 +1790,16 @@ c 123 d 124 e 115 +# approx_percentile_cont_with_weight with centroids +query TI +SELECT c1, approx_percentile_cont_with_weight(c2, 0.95, 200) WITHIN GROUP (ORDER BY c3) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 +---- +a 74 +b 68 +c 123 +d 124 +e 115 + # csv_query_sum_crossjoin query TTI SELECT a.c1, b.c1, SUM(a.c2) FROM aggregate_test_100 as a CROSS JOIN aggregate_test_100 as b GROUP BY a.c1, b.c1 ORDER BY a.c1, b.c1 diff --git a/docs/source/user-guide/expressions.md b/docs/source/user-guide/expressions.md index 03ab86eeb813..abf0286fa85b 100644 --- a/docs/source/user-guide/expressions.md +++ b/docs/source/user-guide/expressions.md @@ -285,27 +285,27 @@ select log(-1), log(0), sqrt(-1); ## Aggregate Functions -| Syntax | Description | -| ----------------------------------------------------------------- | --------------------------------------------------------------------------------------- | -| avg(expr) | Сalculates the average value for `expr`. | -| approx_distinct(expr) | Calculates an approximate count of the number of distinct values for `expr`. | -| approx_median(expr) | Calculates an approximation of the median for `expr`. | -| approx_percentile_cont(expr, percentile) | Calculates an approximation of the specified `percentile` for `expr`. | -| approx_percentile_cont_with_weight(expr, weight_expr, percentile) | Calculates an approximation of the specified `percentile` for `expr` and `weight_expr`. | -| bit_and(expr) | Computes the bitwise AND of all non-null input values for `expr`. | -| bit_or(expr) | Computes the bitwise OR of all non-null input values for `expr`. | -| bit_xor(expr) | Computes the bitwise exclusive OR of all non-null input values for `expr`. | -| bool_and(expr) | Returns true if all non-null input values (`expr`) are true, otherwise false. | -| bool_or(expr) | Returns true if any non-null input value (`expr`) is true, otherwise false. | -| count(expr) | Returns the number of rows for `expr`. | -| count_distinct | Creates an expression to represent the count(distinct) aggregate function | -| cube(exprs) | Creates a grouping set for all combination of `exprs` | -| grouping_set(exprs) | Create a grouping set. 
| -| max(expr) | Finds the maximum value of `expr`. | -| median(expr) | Сalculates the median of `expr`. | -| min(expr) | Finds the minimum value of `expr`. | -| rollup(exprs) | Creates a grouping set for rollup sets. | -| sum(expr) | Сalculates the sum of `expr`. | +| Syntax | Description | +| ------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------- | +| avg(expr) | Сalculates the average value for `expr`. | +| approx_distinct(expr) | Calculates an approximate count of the number of distinct values for `expr`. | +| approx_median(expr) | Calculates an approximation of the median for `expr`. | +| approx_percentile_cont(expr, percentile [, centroids]) | Calculates an approximation of the specified `percentile` for `expr`. Optional `centroids` parameter controls accuracy (default: 100). | +| approx_percentile_cont_with_weight(expr, weight_expr, percentile [, centroids]) | Calculates an approximation of the specified `percentile` for `expr` and `weight_expr`. Optional `centroids` parameter controls accuracy (default: 100). | +| bit_and(expr) | Computes the bitwise AND of all non-null input values for `expr`. | +| bit_or(expr) | Computes the bitwise OR of all non-null input values for `expr`. | +| bit_xor(expr) | Computes the bitwise exclusive OR of all non-null input values for `expr`. | +| bool_and(expr) | Returns true if all non-null input values (`expr`) are true, otherwise false. | +| bool_or(expr) | Returns true if any non-null input value (`expr`) is true, otherwise false. | +| count(expr) | Returns the number of rows for `expr`. | +| count_distinct | Creates an expression to represent the count(distinct) aggregate function | +| cube(exprs) | Creates a grouping set for all combination of `exprs` | +| grouping_set(exprs) | Create a grouping set. | +| max(expr) | Finds the maximum value of `expr`. | +| median(expr) | Сalculates the median of `expr`. | +| min(expr) | Finds the minimum value of `expr`. | +| rollup(exprs) | Creates a grouping set for rollup sets. | +| sum(expr) | Сalculates the sum of `expr`. | ## Aggregate Function Builder diff --git a/docs/source/user-guide/sql/aggregate_functions.md b/docs/source/user-guide/sql/aggregate_functions.md index 774a4fae6bf3..3c88714b5f10 100644 --- a/docs/source/user-guide/sql/aggregate_functions.md +++ b/docs/source/user-guide/sql/aggregate_functions.md @@ -834,7 +834,7 @@ approx_median(expression) Returns the approximate percentile of input values using the t-digest algorithm. 
```sql -approx_percentile_cont(percentile, centroids) WITHIN GROUP (ORDER BY expression) +approx_percentile_cont(percentile [, centroids]) WITHIN GROUP (ORDER BY expression) ``` #### Arguments @@ -846,6 +846,12 @@ approx_percentile_cont(percentile, centroids) WITHIN GROUP (ORDER BY expression) #### Example ```sql +> SELECT approx_percentile_cont(0.75) WITHIN GROUP (ORDER BY column_name) FROM table_name; ++------------------------------------------------------------------+ +| approx_percentile_cont(0.75) WITHIN GROUP (ORDER BY column_name) | ++------------------------------------------------------------------+ +| 65.0 | ++------------------------------------------------------------------+ > SELECT approx_percentile_cont(0.75, 100) WITHIN GROUP (ORDER BY column_name) FROM table_name; +-----------------------------------------------------------------------+ | approx_percentile_cont(0.75, 100) WITHIN GROUP (ORDER BY column_name) | @@ -859,7 +865,7 @@ approx_percentile_cont(percentile, centroids) WITHIN GROUP (ORDER BY expression) Returns the weighted approximate percentile of input values using the t-digest algorithm. ```sql -approx_percentile_cont_with_weight(weight, percentile) WITHIN GROUP (ORDER BY expression) +approx_percentile_cont_with_weight(weight, percentile [, centroids]) WITHIN GROUP (ORDER BY expression) ``` #### Arguments @@ -867,6 +873,7 @@ approx_percentile_cont_with_weight(weight, percentile) WITHIN GROUP (ORDER BY ex - **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators. - **weight**: Expression to use as weight. Can be a constant, column, or function, and any combination of arithmetic operators. - **percentile**: Percentile to compute. Must be a float value between 0 and 1 (inclusive). +- **centroids**: Number of centroids to use in the t-digest algorithm. _Default is 100_. A higher number results in more accurate approximation but requires more memory. 
#### Example @@ -877,4 +884,10 @@ approx_percentile_cont_with_weight(weight, percentile) WITHIN GROUP (ORDER BY ex +---------------------------------------------------------------------------------------------+ | 78.5 | +---------------------------------------------------------------------------------------------+ +> SELECT approx_percentile_cont_with_weight(weight_column, 0.90, 100) WITHIN GROUP (ORDER BY column_name) FROM table_name; ++--------------------------------------------------------------------------------------------------+ +| approx_percentile_cont_with_weight(weight_column, 0.90, 100) WITHIN GROUP (ORDER BY column_name) | ++--------------------------------------------------------------------------------------------------+ +| 78.5 | ++--------------------------------------------------------------------------------------------------+ ``` From 6ce84620e3a64eb9448a61c8192c7194616a9414 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 13 Aug 2025 05:44:44 -0700 Subject: [PATCH 07/13] (Re)Support old syntax for `approx_percentile_cont` and `approx_percentile_cont_with_weight` (#16999) * Add sqllogictests * Allow both new and old sytanx for approx_percentile_cont and approx_percentile_cont_with_weight * Update docs * Add documentation and more tests --- .../src/approx_percentile_cont.rs | 19 ++++++++- .../src/approx_percentile_cont_with_weight.rs | 10 +++++ datafusion/sql/src/expr/function.rs | 7 +--- .../sqllogictest/test_files/aggregate.slt | 41 ++++++++++++++++++- .../user-guide/sql/aggregate_functions.md | 29 +++++++++++++ 5 files changed, 99 insertions(+), 7 deletions(-) diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont.rs b/datafusion/functions-aggregate/src/approx_percentile_cont.rs index 863ee15d89ec..fce300e79bea 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont.rs @@ -91,7 +91,24 @@ pub fn approx_percentile_cont( +-----------------------------------------------------------------------+ | 65.0 | +-----------------------------------------------------------------------+ -```"#, +``` +An alternate syntax is also supported: +```sql +> SELECT approx_percentile_cont(column_name, 0.75) FROM table_name; ++-----------------------------------------------+ +| approx_percentile_cont(column_name, 0.75) | ++-----------------------------------------------+ +| 65.0 | ++-----------------------------------------------+ + +> SELECT approx_percentile_cont(column_name, 0.75, 100) FROM table_name; ++----------------------------------------------------------+ +| approx_percentile_cont(column_name, 0.75, 100) | ++----------------------------------------------------------+ +| 65.0 | ++----------------------------------------------------------+ +``` +"#, standard_argument(name = "expression",), argument( name = "percentile", diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs b/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs index d30ea624cae9..f70d751a8cb9 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs @@ -85,6 +85,16 @@ pub fn approx_percentile_cont_with_weight( +--------------------------------------------------------------------------------------------------+ | 78.5 | +--------------------------------------------------------------------------------------------------+ +``` +An alternative syntax is also 
supported: + +```sql +> SELECT approx_percentile_cont_with_weight(column_name, weight_column, 0.90) FROM table_name; ++--------------------------------------------------+ +| approx_percentile_cont_with_weight(column_name, weight_column, 0.90) | ++--------------------------------------------------+ +| 78.5 | ++--------------------------------------------------+ ```"#, standard_argument(name = "expression", prefix = "The"), argument( diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index e63ca75d019d..74d40e145f1c 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -380,10 +380,6 @@ impl SqlToRel<'_, S> { } else { // User defined aggregate functions (UDAF) have precedence in case it has the same name as a scalar built-in function if let Some(fm) = self.context_provider.get_aggregate_meta(&name) { - if fm.is_ordered_set_aggregate() && within_group.is_empty() { - return plan_err!("WITHIN GROUP clause is required when calling ordered set aggregate function({})", fm.name()); - } - if null_treatment.is_some() && !fm.supports_null_handling_clause() { return plan_err!( "[IGNORE | RESPECT] NULLS are not permitted for {}", @@ -403,7 +399,8 @@ impl SqlToRel<'_, S> { None, )?; - // add target column expression in within group clause to function arguments + // Add the WITHIN GROUP ordering expressions to the front of the argument list + // So function(arg) WITHIN GROUP (ORDER BY x) becomes function(x, arg) if !within_group.is_empty() { args = within_group .iter() diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 122907d83181..4671408349e2 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -1287,7 +1287,7 @@ SELECT approx_distinct(c9) AS a, approx_distinct(c9) AS b FROM aggregate_test_10 ## Column `c12` is omitted due to a large relative error (~10%) due to the small ## float values. 
-#csv_query_approx_percentile_cont (c2) +# csv_query_approx_percentile_cont (c2) query B SELECT (ABS(1 - CAST(approx_percentile_cont(0.1) WITHIN GROUP (ORDER BY c2) AS DOUBLE) / 1.0) < 0.05) AS q FROM aggregate_test_100 ---- @@ -1303,6 +1303,23 @@ SELECT (ABS(1 - CAST(approx_percentile_cont(0.9) WITHIN GROUP (ORDER BY c2) AS D ---- true + +# csv_query_approx_percentile_cont (c2, alternate syntax, should be the same as above) +query B +SELECT (ABS(1 - CAST(approx_percentile_cont(c2, 0.1) AS DOUBLE) / 1.0) < 0.05) AS q FROM aggregate_test_100 +---- +true + +query B +SELECT (ABS(1 - CAST(approx_percentile_cont(c2, 0.5) AS DOUBLE) / 3.0) < 0.05) AS q FROM aggregate_test_100 +---- +true + +query B +SELECT (ABS(1 - CAST(approx_percentile_cont(c2, 0.9) AS DOUBLE) / 5.0) < 0.05) AS q FROM aggregate_test_100 +---- +true + # csv_query_approx_percentile_cont (c3) query B SELECT (ABS(1 - CAST(approx_percentile_cont(0.1) WITHIN GROUP (ORDER BY c3) AS DOUBLE) / -95.3) < 0.05) AS q FROM aggregate_test_100 @@ -1743,6 +1760,17 @@ c 122 d 124 e 115 + +# csv_query_approx_percentile_cont_with_weight (should be the same as above) +query TI +SELECT c1, approx_percentile_cont(c3, 0.95) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 +---- +a 73 +b 68 +c 122 +d 124 +e 115 + query TI SELECT c1, approx_percentile_cont(0.95) WITHIN GROUP (ORDER BY c3 DESC) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 ---- @@ -1762,6 +1790,17 @@ c 122 d 124 e 115 +# csv_query_approx_percentile_cont_with_weight alternate syntax +query TI +SELECT c1, approx_percentile_cont_with_weight(c3, 1, 0.95) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 +---- +a 73 +b 68 +c 122 +d 124 +e 115 + + query TI SELECT c1, approx_percentile_cont_with_weight(1, 0.95) WITHIN GROUP (ORDER BY c3 DESC) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 ---- diff --git a/docs/source/user-guide/sql/aggregate_functions.md b/docs/source/user-guide/sql/aggregate_functions.md index 3c88714b5f10..4f2f0abe55c9 100644 --- a/docs/source/user-guide/sql/aggregate_functions.md +++ b/docs/source/user-guide/sql/aggregate_functions.md @@ -860,6 +860,24 @@ approx_percentile_cont(percentile [, centroids]) WITHIN GROUP (ORDER BY expressi +-----------------------------------------------------------------------+ ``` +An alternate syntax is also supported: + +```sql +> SELECT approx_percentile_cont(column_name, 0.75) FROM table_name; ++-----------------------------------------------+ +| approx_percentile_cont(column_name, 0.75) | ++-----------------------------------------------+ +| 65.0 | ++-----------------------------------------------+ + +> SELECT approx_percentile_cont(column_name, 0.75, 100) FROM table_name; ++----------------------------------------------------------+ +| approx_percentile_cont(column_name, 0.75, 100) | ++----------------------------------------------------------+ +| 65.0 | ++----------------------------------------------------------+ +``` + ### `approx_percentile_cont_with_weight` Returns the weighted approximate percentile of input values using the t-digest algorithm. 
@@ -891,3 +909,14 @@ approx_percentile_cont_with_weight(weight, percentile [, centroids]) WITHIN GROU | 78.5 | +--------------------------------------------------------------------------------------------------+ ``` + +An alternative syntax is also supported: + +```sql +> SELECT approx_percentile_cont_with_weight(column_name, weight_column, 0.90) FROM table_name; ++--------------------------------------------------+ +| approx_percentile_cont_with_weight(column_name, weight_column, 0.90) | ++--------------------------------------------------+ +| 78.5 | ++--------------------------------------------------+ +``` From f189789d2de5a4d9f3d989068f73f18ae9636cba Mon Sep 17 00:00:00 2001 From: Qi Zhu <821684824@qq.com> Date: Mon, 28 Jul 2025 17:26:49 +0800 Subject: [PATCH 08/13] feat: support distinct for window (#16925) * feat: support distinct for window * fix * fix * fisx * fix unparse * fix test * fix test * easy way * add test * add comments --- datafusion-examples/examples/advanced_udwf.rs | 1 + datafusion/core/src/physical_planner.rs | 2 + .../core/tests/fuzz_cases/window_fuzz.rs | 3 + .../physical_optimizer/enforce_sorting.rs | 1 + .../tests/physical_optimizer/test_utils.rs | 1 + datafusion/expr/src/expr.rs | 27 ++- datafusion/expr/src/expr_fn.rs | 1 + datafusion/expr/src/planner.rs | 1 + datafusion/expr/src/tree_node.rs | 12 ++ datafusion/expr/src/udaf.rs | 42 +++-- datafusion/functions-aggregate/src/count.rs | 163 +++++++++++++++++- datafusion/functions-window/src/planner.rs | 15 +- .../optimizer/src/analyzer/type_coercion.rs | 29 +++- .../src/windows/bounded_window_agg_exec.rs | 1 + datafusion/physical-plan/src/windows/mod.rs | 42 +++-- datafusion/proto/src/logical_plan/to_proto.rs | 1 + .../proto/src/physical_plan/from_proto.rs | 1 + datafusion/sql/src/expr/function.rs | 12 ++ datafusion/sql/src/unparser/expr.rs | 13 +- datafusion/sqllogictest/test_files/window.slt | 79 +++++++++ .../consumer/expr/window_function.rs | 1 + .../producer/expr/window_function.rs | 1 + 22 files changed, 407 insertions(+), 42 deletions(-) diff --git a/datafusion-examples/examples/advanced_udwf.rs b/datafusion-examples/examples/advanced_udwf.rs index f7316ddc1bec..e0fab7ee9f31 100644 --- a/datafusion-examples/examples/advanced_udwf.rs +++ b/datafusion-examples/examples/advanced_udwf.rs @@ -199,6 +199,7 @@ impl WindowUDFImpl for SimplifySmoothItUdf { order_by: window_function.params.order_by, window_frame: window_function.params.window_frame, null_treatment: window_function.params.null_treatment, + distinct: window_function.params.distinct, }, })) }; diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index b355f5f894c7..6d2393c99d2b 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -1649,6 +1649,7 @@ pub fn create_window_expr_with_name( order_by, window_frame, null_treatment, + distinct, }, } = window_fun.as_ref(); let physical_args = @@ -1677,6 +1678,7 @@ pub fn create_window_expr_with_name( window_frame, physical_schema, ignore_nulls, + *distinct, ) } other => plan_err!("Invalid window expression '{other:?}'"), diff --git a/datafusion/core/tests/fuzz_cases/window_fuzz.rs b/datafusion/core/tests/fuzz_cases/window_fuzz.rs index 316d3ba5a926..23e3281cf386 100644 --- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs @@ -288,6 +288,7 @@ async fn bounded_window_causal_non_causal() -> Result<()> { Arc::new(window_frame), &extended_schema, false, + false, )?; let 
running_window_exec = Arc::new(BoundedWindowAggExec::try_new( vec![window_expr], @@ -660,6 +661,7 @@ async fn run_window_test( Arc::new(window_frame.clone()), &extended_schema, false, + false, )?], exec1, false, @@ -678,6 +680,7 @@ async fn run_window_test( Arc::new(window_frame.clone()), &extended_schema, false, + false, )?], exec2, search_mode.clone(), diff --git a/datafusion/core/tests/physical_optimizer/enforce_sorting.rs b/datafusion/core/tests/physical_optimizer/enforce_sorting.rs index 8f3cc3191b17..ef29d51e5d37 100644 --- a/datafusion/core/tests/physical_optimizer/enforce_sorting.rs +++ b/datafusion/core/tests/physical_optimizer/enforce_sorting.rs @@ -3766,6 +3766,7 @@ async fn test_window_partial_constant_and_set_monotonicity() -> Result<()> { case.window_frame, input_schema.as_ref(), false, + false, )?; let window_exec = if window_expr.uses_bounded_memory() { Arc::new(BoundedWindowAggExec::try_new( diff --git a/datafusion/core/tests/physical_optimizer/test_utils.rs b/datafusion/core/tests/physical_optimizer/test_utils.rs index c5b3e2743811..5e2d61e68f8d 100644 --- a/datafusion/core/tests/physical_optimizer/test_utils.rs +++ b/datafusion/core/tests/physical_optimizer/test_utils.rs @@ -265,6 +265,7 @@ pub fn bounded_window_exec_with_partition( Arc::new(WindowFrame::new(Some(false))), schema.as_ref(), false, + false, ) .unwrap(); diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 0749ff0e98b7..efe8a639087a 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -1131,6 +1131,8 @@ pub struct WindowFunctionParams { pub window_frame: WindowFrame, /// Specifies how NULL value is treated: ignore or respect pub null_treatment: Option, + /// Distinct flag + pub distinct: bool, } impl WindowFunction { @@ -1145,6 +1147,7 @@ impl WindowFunction { order_by: Vec::default(), window_frame: WindowFrame::new(None), null_treatment: None, + distinct: false, }, } } @@ -2291,6 +2294,7 @@ impl NormalizeEq for Expr { partition_by: self_partition_by, order_by: self_order_by, null_treatment: self_null_treatment, + distinct: self_distinct, }, } = left.as_ref(); let WindowFunction { @@ -2302,6 +2306,7 @@ impl NormalizeEq for Expr { partition_by: other_partition_by, order_by: other_order_by, null_treatment: other_null_treatment, + distinct: other_distinct, }, } = other.as_ref(); @@ -2325,6 +2330,7 @@ impl NormalizeEq for Expr { && a.nulls_first == b.nulls_first && a.expr.normalize_eq(&b.expr) }) + && self_distinct == other_distinct } ( Expr::Exists(Exists { @@ -2558,11 +2564,13 @@ impl HashNode for Expr { order_by: _, window_frame, null_treatment, + distinct, }, } = window_fun.as_ref(); fun.hash(state); window_frame.hash(state); null_treatment.hash(state); + distinct.hash(state); } Expr::InList(InList { expr: _expr, @@ -2865,15 +2873,27 @@ impl Display for SchemaDisplay<'_> { order_by, window_frame, null_treatment, + distinct, } = params; + // Write function name and open parenthesis + write!(f, "{fun}(")?; + + // If DISTINCT, emit the keyword + if *distinct { + write!(f, "DISTINCT ")?; + } + + // Write the comma‑separated argument list write!( f, - "{}({})", - fun, + "{}", schema_name_from_exprs_comma_separated_without_space(args)? 
)?; + // **Close the argument parenthesis** + write!(f, ")")?; + if let Some(null_treatment) = null_treatment { write!(f, " {null_treatment}")?; } @@ -3260,9 +3280,10 @@ impl Display for Expr { order_by, window_frame, null_treatment, + distinct, } = params; - fmt_function(f, &fun.to_string(), false, args, true)?; + fmt_function(f, &fun.to_string(), *distinct, args, true)?; if let Some(nt) = null_treatment { write!(f, "{nt}")?; diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index c0351a9dcaca..fab86fe7663d 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -945,6 +945,7 @@ impl ExprFuncBuilder { window_frame: window_frame .unwrap_or_else(|| WindowFrame::new(has_order_by)), null_treatment, + distinct, }, }) } diff --git a/datafusion/expr/src/planner.rs b/datafusion/expr/src/planner.rs index 067c7a94279f..b04fe32d376e 100644 --- a/datafusion/expr/src/planner.rs +++ b/datafusion/expr/src/planner.rs @@ -308,6 +308,7 @@ pub struct RawWindowExpr { pub order_by: Vec, pub window_frame: WindowFrame, pub null_treatment: Option, + pub distinct: bool, } /// Result of planning a raw expr with [`ExprPlanner`] diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs index f953aec5a1e3..b6f583ca4c74 100644 --- a/datafusion/expr/src/tree_node.rs +++ b/datafusion/expr/src/tree_node.rs @@ -242,10 +242,22 @@ impl TreeNode for Expr { order_by, window_frame, null_treatment, + distinct, }, } = *window_fun; (args, partition_by, order_by).map_elements(f)?.update_data( |(new_args, new_partition_by, new_order_by)| { + if distinct { + return Expr::from(WindowFunction::new(fun, new_args)) + .partition_by(new_partition_by) + .order_by(new_order_by) + .window_frame(window_frame) + .null_treatment(null_treatment) + .distinct() + .build() + .unwrap(); + } + Expr::from(WindowFunction::new(fun, new_args)) .partition_by(new_partition_by) .order_by(new_order_by) diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index b6c8eb627c77..15c0dd57ad2c 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -554,14 +554,25 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { order_by, window_frame, null_treatment, + distinct, } = params; let mut schema_name = String::new(); - schema_name.write_fmt(format_args!( - "{}({})", - self.name(), - schema_name_from_exprs(args)? - ))?; + + // Inject DISTINCT into the schema name when requested + if *distinct { + schema_name.write_fmt(format_args!( + "{}(DISTINCT {})", + self.name(), + schema_name_from_exprs(args)? + ))?; + } else { + schema_name.write_fmt(format_args!( + "{}({})", + self.name(), + schema_name_from_exprs(args)? + ))?; + } if let Some(null_treatment) = null_treatment { schema_name.write_fmt(format_args!(" {null_treatment}"))?; @@ -579,7 +590,7 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { " ORDER BY [{}]", schema_name_from_sorts(order_by)? 
))?; - }; + } schema_name.write_fmt(format_args!(" {window_frame}"))?; @@ -648,15 +659,24 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { order_by, window_frame, null_treatment, + distinct, } = params; let mut display_name = String::new(); - display_name.write_fmt(format_args!( - "{}({})", - self.name(), - expr_vec_fmt!(args) - ))?; + if *distinct { + display_name.write_fmt(format_args!( + "{}(DISTINCT {})", + self.name(), + expr_vec_fmt!(args) + ))?; + } else { + display_name.write_fmt(format_args!( + "{}({})", + self.name(), + expr_vec_fmt!(args) + ))?; + } if let Some(null_treatment) = null_treatment { display_name.write_fmt(format_args!(" {null_treatment}"))?; diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs index 09904bbad6ec..7a7c2879aa79 100644 --- a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -31,7 +31,7 @@ use arrow::{ }; use datafusion_common::{ downcast_value, internal_err, not_impl_err, stats::Precision, - utils::expr::COUNT_STAR_EXPANSION, Result, ScalarValue, + utils::expr::COUNT_STAR_EXPANSION, HashMap, Result, ScalarValue, }; use datafusion_expr::{ expr::WindowFunction, @@ -59,6 +59,7 @@ use std::{ ops::BitAnd, sync::Arc, }; + make_udaf_expr_and_func!( Count, count, @@ -406,6 +407,98 @@ impl AggregateUDFImpl for Count { // the same as new values are seen. SetMonotonicity::Increasing } + + fn create_sliding_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result> { + if args.is_distinct { + let acc = + SlidingDistinctCountAccumulator::try_new(args.return_field.data_type())?; + Ok(Box::new(acc)) + } else { + let acc = CountAccumulator::new(); + Ok(Box::new(acc)) + } + } +} + +// DistinctCountAccumulator does not support retract_batch and sliding window +// this is a specialized accumulator for distinct count that supports retract_batch +// and sliding window. 
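+//
+// Illustrative motivation (mirrors the window.slt case added in this patch): a query
+// such as
+//   COUNT(DISTINCT v) OVER (PARTITION BY k ORDER BY time
+//                           RANGE BETWEEN INTERVAL '2 minutes' PRECEDING AND CURRENT ROW)
+// must forget values as they leave the sliding frame, which is only possible with an
+// accumulator that implements `retract_batch`.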
+#[derive(Debug)] +pub struct SlidingDistinctCountAccumulator { + counts: HashMap, + data_type: DataType, +} + +impl SlidingDistinctCountAccumulator { + pub fn try_new(data_type: &DataType) -> Result { + Ok(Self { + counts: HashMap::default(), + data_type: data_type.clone(), + }) + } +} + +impl Accumulator for SlidingDistinctCountAccumulator { + fn state(&mut self) -> Result> { + let keys = self.counts.keys().cloned().collect::>(); + Ok(vec![ScalarValue::List(ScalarValue::new_list_nullable( + keys.as_slice(), + &self.data_type, + ))]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let arr = &values[0]; + for i in 0..arr.len() { + let v = ScalarValue::try_from_array(arr, i)?; + if !v.is_null() { + *self.counts.entry(v).or_default() += 1; + } + } + Ok(()) + } + + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let arr = &values[0]; + for i in 0..arr.len() { + let v = ScalarValue::try_from_array(arr, i)?; + if !v.is_null() { + if let Some(cnt) = self.counts.get_mut(&v) { + *cnt -= 1; + if *cnt == 0 { + self.counts.remove(&v); + } + } + } + } + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + let list_arr = states[0].as_list::(); + for inner in list_arr.iter().flatten() { + for j in 0..inner.len() { + let v = ScalarValue::try_from_array(&*inner, j)?; + *self.counts.entry(v).or_default() += 1; + } + } + Ok(()) + } + + fn evaluate(&mut self) -> Result { + Ok(ScalarValue::Int64(Some(self.counts.len() as i64))) + } + + fn supports_retract_batch(&self) -> bool { + true + } + + fn size(&self) -> usize { + size_of_val(self) + } } #[derive(Debug)] @@ -878,4 +971,72 @@ mod tests { assert_eq!(accumulator.evaluate()?, ScalarValue::Int64(Some(0))); Ok(()) } + + #[test] + fn sliding_distinct_count_accumulator_basic() -> Result<()> { + // Basic update_batch + evaluate functionality + let mut acc = SlidingDistinctCountAccumulator::try_new(&DataType::Int32)?; + // Create an Int32Array: [1, 2, 2, 3, null] + let values: ArrayRef = Arc::new(Int32Array::from(vec![ + Some(1), + Some(2), + Some(2), + Some(3), + None, + ])); + acc.update_batch(&[values])?; + // Expect distinct values {1,2,3} → count = 3 + assert_eq!(acc.evaluate()?, ScalarValue::Int64(Some(3))); + Ok(()) + } + + #[test] + fn sliding_distinct_count_accumulator_retract() -> Result<()> { + // Test that retract_batch properly decrements counts + let mut acc = SlidingDistinctCountAccumulator::try_new(&DataType::Utf8)?; + // Initial batch: ["a", "b", "a"] + let arr1 = Arc::new(StringArray::from(vec![Some("a"), Some("b"), Some("a")])) + as ArrayRef; + acc.update_batch(&[arr1])?; + assert_eq!(acc.evaluate()?, ScalarValue::Int64(Some(2))); // {"a","b"} + + // Retract batch: ["a", null, "b"] + let arr2 = + Arc::new(StringArray::from(vec![Some("a"), None, Some("b")])) as ArrayRef; + acc.retract_batch(&[arr2])?; + // Before: a→2, b→1; after retract a→1, b→0 → b removed; remaining {"a"} + assert_eq!(acc.evaluate()?, ScalarValue::Int64(Some(1))); + Ok(()) + } + + #[test] + fn sliding_distinct_count_accumulator_merge_states() -> Result<()> { + // Test merging multiple accumulator states with merge_batch + let mut acc1 = SlidingDistinctCountAccumulator::try_new(&DataType::Int32)?; + let mut acc2 = SlidingDistinctCountAccumulator::try_new(&DataType::Int32)?; + // acc1 sees [1, 2] + acc1.update_batch(&[Arc::new(Int32Array::from(vec![Some(1), Some(2)]))])?; + // acc2 sees [2, 3] + acc2.update_batch(&[Arc::new(Int32Array::from(vec![Some(2), Some(3)]))])?; + // Extract their states as Vec + let 
state_sv1 = acc1.state()?; + let state_sv2 = acc2.state()?; + // Convert ScalarValue states into Vec, propagating errors + // NOTE we pass `1` because each ScalarValue.to_array produces a 1‑row ListArray + let state_arr1: Vec = state_sv1 + .into_iter() + .map(|sv| sv.to_array()) + .collect::>()?; + let state_arr2: Vec = state_sv2 + .into_iter() + .map(|sv| sv.to_array()) + .collect::>()?; + // Merge both states into a fresh accumulator + let mut merged = SlidingDistinctCountAccumulator::try_new(&DataType::Int32)?; + merged.merge_batch(&state_arr1)?; + merged.merge_batch(&state_arr2)?; + // Expect distinct {1,2,3} → count = 3 + assert_eq!(merged.evaluate()?, ScalarValue::Int64(Some(3))); + Ok(()) + } } diff --git a/datafusion/functions-window/src/planner.rs b/datafusion/functions-window/src/planner.rs index 091737bb9c15..5e3a6bc6336c 100644 --- a/datafusion/functions-window/src/planner.rs +++ b/datafusion/functions-window/src/planner.rs @@ -41,6 +41,7 @@ impl ExprPlanner for WindowFunctionPlanner { order_by, window_frame, null_treatment, + distinct, } = raw_expr; let origin_expr = Expr::from(WindowFunction { @@ -51,6 +52,7 @@ impl ExprPlanner for WindowFunctionPlanner { order_by, window_frame, null_treatment, + distinct, }, }); @@ -68,6 +70,7 @@ impl ExprPlanner for WindowFunctionPlanner { order_by, window_frame, null_treatment, + distinct, }, } = *window_fun; let raw_expr = RawWindowExpr { @@ -77,6 +80,7 @@ impl ExprPlanner for WindowFunctionPlanner { order_by, window_frame, null_treatment, + distinct, }; // TODO: remove the next line after `Expr::Wildcard` is removed @@ -93,18 +97,23 @@ impl ExprPlanner for WindowFunctionPlanner { order_by, window_frame, null_treatment, + distinct, } = raw_expr; - let new_expr = Expr::from(WindowFunction::new( + let mut new_expr_before_build = Expr::from(WindowFunction::new( func_def, vec![Expr::Literal(COUNT_STAR_EXPANSION, None)], )) .partition_by(partition_by) .order_by(order_by) .window_frame(window_frame) - .null_treatment(null_treatment) - .build()?; + .null_treatment(null_treatment); + if distinct { + new_expr_before_build = new_expr_before_build.distinct(); + } + + let new_expr = new_expr_before_build.build()?; let new_expr = saved_name.restore(new_expr); return Ok(PlannerResult::Planned(new_expr)); diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index a98b0fdcc3d3..e6fc006cb2ff 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -549,6 +549,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { order_by, window_frame, null_treatment, + distinct, }, } = *window_fun; let window_frame = @@ -565,14 +566,26 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { _ => args, }; - Ok(Transformed::yes( - Expr::from(WindowFunction::new(fun, args)) - .partition_by(partition_by) - .order_by(order_by) - .window_frame(window_frame) - .null_treatment(null_treatment) - .build()?, - )) + if distinct { + Ok(Transformed::yes( + Expr::from(WindowFunction::new(fun, args)) + .partition_by(partition_by) + .order_by(order_by) + .window_frame(window_frame) + .null_treatment(null_treatment) + .distinct() + .build()?, + )) + } else { + Ok(Transformed::yes( + Expr::from(WindowFunction::new(fun, args)) + .partition_by(partition_by) + .order_by(order_by) + .window_frame(window_frame) + .null_treatment(null_treatment) + .build()?, + )) + } } // TODO: remove the next line after `Expr::Wildcard` is removed #[expect(deprecated)] diff 
--git a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs index d3335c0e7fe1..4c991544f877 100644 --- a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs +++ b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs @@ -1377,6 +1377,7 @@ mod tests { Arc::new(window_frame), &input.schema(), false, + false, )?], input, input_order_mode, diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index 5583abfd72a2..085b17cab9bc 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -103,21 +103,38 @@ pub fn create_window_expr( window_frame: Arc, input_schema: &Schema, ignore_nulls: bool, + distinct: bool, ) -> Result> { Ok(match fun { WindowFunctionDefinition::AggregateUDF(fun) => { - let aggregate = AggregateExprBuilder::new(Arc::clone(fun), args.to_vec()) - .schema(Arc::new(input_schema.clone())) - .alias(name) - .with_ignore_nulls(ignore_nulls) - .build() - .map(Arc::new)?; - window_expr_from_aggregate_expr( - partition_by, - order_by, - window_frame, - aggregate, - ) + if distinct { + let aggregate = AggregateExprBuilder::new(Arc::clone(fun), args.to_vec()) + .schema(Arc::new(input_schema.clone())) + .alias(name) + .with_ignore_nulls(ignore_nulls) + .distinct() + .build() + .map(Arc::new)?; + window_expr_from_aggregate_expr( + partition_by, + order_by, + window_frame, + aggregate, + ) + } else { + let aggregate = AggregateExprBuilder::new(Arc::clone(fun), args.to_vec()) + .schema(Arc::new(input_schema.clone())) + .alias(name) + .with_ignore_nulls(ignore_nulls) + .build() + .map(Arc::new)?; + window_expr_from_aggregate_expr( + partition_by, + order_by, + window_frame, + aggregate, + ) + } } WindowFunctionDefinition::WindowUDF(fun) => Arc::new(StandardWindowExpr::new( create_udwf_window_expr(fun, args, input_schema, name, ignore_nulls)?, @@ -805,6 +822,7 @@ mod tests { Arc::new(WindowFrame::new(None)), schema.as_ref(), false, + false, )?], blocking_exec, false, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 43afaa0fbe65..f59e97df0d46 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -316,6 +316,7 @@ pub fn serialize_expr( ref window_frame, // TODO: support null treatment in proto null_treatment: _, + distinct: _, }, } = window_fun.as_ref(); let mut buf = Vec::new(); diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index 1c60470b2218..2ed6ec037fc8 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -179,6 +179,7 @@ pub fn parse_physical_window_expr( Arc::new(window_frame), &extended_schema, false, + false, ) } diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 74d40e145f1c..fd0e7dc6e3b9 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -352,6 +352,7 @@ impl SqlToRel<'_, S> { order_by, window_frame, null_treatment, + distinct: function_args.distinct, }; for planner in self.context_provider.get_expr_planners().iter() { @@ -368,8 +369,19 @@ impl SqlToRel<'_, S> { order_by, window_frame, null_treatment, + distinct, } = window_expr; + if distinct { + return Expr::from(expr::WindowFunction::new(func_def, args)) + .partition_by(partition_by) + 
.order_by(order_by) + .window_frame(window_frame) + .null_treatment(null_treatment) + .distinct() + .build(); + } + return Expr::from(expr::WindowFunction::new(func_def, args)) .partition_by(partition_by) .order_by(order_by) diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 4ddd5ccccbbd..4c0dc316615c 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -18,8 +18,9 @@ use datafusion_expr::expr::{AggregateFunctionParams, Unnest, WindowFunctionParams}; use sqlparser::ast::Value::SingleQuotedString; use sqlparser::ast::{ - self, Array, BinaryOperator, CaseWhen, Expr as AstExpr, Function, Ident, Interval, - ObjectName, OrderByOptions, Subscript, TimezoneInfo, UnaryOperator, ValueWithSpan, + self, Array, BinaryOperator, CaseWhen, DuplicateTreatment, Expr as AstExpr, Function, + Ident, Interval, ObjectName, OrderByOptions, Subscript, TimezoneInfo, UnaryOperator, + ValueWithSpan, }; use std::sync::Arc; use std::vec; @@ -198,6 +199,7 @@ impl Unparser<'_> { partition_by, order_by, window_frame, + distinct, .. }, } = window_fun.as_ref(); @@ -256,7 +258,8 @@ impl Unparser<'_> { span: Span::empty(), }]), args: ast::FunctionArguments::List(ast::FunctionArgumentList { - duplicate_treatment: None, + duplicate_treatment: distinct + .then_some(DuplicateTreatment::Distinct), args, clauses: vec![], }), @@ -339,7 +342,7 @@ impl Unparser<'_> { }]), args: ast::FunctionArguments::List(ast::FunctionArgumentList { duplicate_treatment: distinct - .then_some(ast::DuplicateTreatment::Distinct), + .then_some(DuplicateTreatment::Distinct), args, clauses: vec![], }), @@ -2051,6 +2054,7 @@ mod tests { order_by: vec![], window_frame: WindowFrame::new(None), null_treatment: None, + distinct: false, }, }), r#"row_number(col) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)"#, @@ -2076,6 +2080,7 @@ mod tests { ), ), null_treatment: None, + distinct: false, }, }), r#"count(*) OVER (ORDER BY a DESC NULLS FIRST RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING)"#, diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt index 82de11302857..bed9121eec3f 100644 --- a/datafusion/sqllogictest/test_files/window.slt +++ b/datafusion/sqllogictest/test_files/window.slt @@ -5650,3 +5650,82 @@ WINDOW 3 7 4 11 5 16 + + +# window with distinct operation +statement ok +CREATE TABLE table_test_distinct_count ( + k VARCHAR, + v Int, + time TIMESTAMP WITH TIME ZONE +); + +statement ok +INSERT INTO table_test_distinct_count (k, v, time) VALUES + ('a', 1, '1970-01-01T00:01:00.00Z'), + ('a', 1, '1970-01-01T00:02:00.00Z'), + ('a', 1, '1970-01-01T00:03:00.00Z'), + ('a', 2, '1970-01-01T00:03:00.00Z'), + ('a', 1, '1970-01-01T00:04:00.00Z'), + ('b', 3, '1970-01-01T00:01:00.00Z'), + ('b', 3, '1970-01-01T00:02:00.00Z'), + ('b', 4, '1970-01-01T00:03:00.00Z'), + ('b', 4, '1970-01-01T00:03:00.00Z'); + +query TPII +SELECT + k, + time, + COUNT(v) OVER ( + PARTITION BY k + ORDER BY time + RANGE BETWEEN INTERVAL '2 minutes' PRECEDING AND CURRENT ROW + ) AS normal_count, + COUNT(DISTINCT v) OVER ( + PARTITION BY k + ORDER BY time + RANGE BETWEEN INTERVAL '2 minutes' PRECEDING AND CURRENT ROW + ) AS distinct_count +FROM table_test_distinct_count +ORDER BY k, time; +---- +a 1970-01-01T00:01:00Z 1 1 +a 1970-01-01T00:02:00Z 2 1 +a 1970-01-01T00:03:00Z 4 2 +a 1970-01-01T00:03:00Z 4 2 +a 1970-01-01T00:04:00Z 4 2 +b 1970-01-01T00:01:00Z 1 1 +b 1970-01-01T00:02:00Z 2 1 +b 1970-01-01T00:03:00Z 4 2 +b 1970-01-01T00:03:00Z 4 2 + + 
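+# EXPLAIN the same window expressions: the DISTINCT count should survive logical and
+# physical planning and be evaluated alongside the plain count in a single
+# BoundedWindowAggExec.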
+query TT +EXPLAIN SELECT + k, + time, + COUNT(v) OVER ( + PARTITION BY k + ORDER BY time + RANGE BETWEEN INTERVAL '2 minutes' PRECEDING AND CURRENT ROW + ) AS normal_count, + COUNT(DISTINCT v) OVER ( + PARTITION BY k + ORDER BY time + RANGE BETWEEN INTERVAL '2 minutes' PRECEDING AND CURRENT ROW + ) AS distinct_count +FROM table_test_distinct_count +ODER BY k, time; +---- +logical_plan +01)Projection: oder.k, oder.time, count(oder.v) PARTITION BY [oder.k] ORDER BY [oder.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW AS normal_count, count(DISTINCT oder.v) PARTITION BY [oder.k] ORDER BY [oder.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW AS distinct_count +02)--WindowAggr: windowExpr=[[count(oder.v) PARTITION BY [oder.k] ORDER BY [oder.time ASC NULLS LAST] RANGE BETWEEN IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 120000000000 } PRECEDING AND CURRENT ROW AS count(oder.v) PARTITION BY [oder.k] ORDER BY [oder.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW, count(DISTINCT oder.v) PARTITION BY [oder.k] ORDER BY [oder.time ASC NULLS LAST] RANGE BETWEEN IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 120000000000 } PRECEDING AND CURRENT ROW AS count(DISTINCT oder.v) PARTITION BY [oder.k] ORDER BY [oder.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW]] +03)----SubqueryAlias: oder +04)------TableScan: table_test_distinct_count projection=[k, v, time] +physical_plan +01)ProjectionExec: expr=[k@0 as k, time@2 as time, count(oder.v) PARTITION BY [oder.k] ORDER BY [oder.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW@3 as normal_count, count(DISTINCT oder.v) PARTITION BY [oder.k] ORDER BY [oder.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW@4 as distinct_count] +02)--BoundedWindowAggExec: wdw=[count(oder.v) PARTITION BY [oder.k] ORDER BY [oder.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW: Field { name: "count(oder.v) PARTITION BY [oder.k] ORDER BY [oder.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 120000000000 } PRECEDING AND CURRENT ROW, count(DISTINCT oder.v) PARTITION BY [oder.k] ORDER BY [oder.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW: Field { name: "count(DISTINCT oder.v) PARTITION BY [oder.k] ORDER BY [oder.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 120000000000 } PRECEDING AND CURRENT ROW], mode=[Sorted] +03)----SortExec: expr=[k@0 ASC NULLS LAST, time@2 ASC NULLS LAST], preserve_partitioning=[true] +04)------CoalesceBatchesExec: target_batch_size=1 +05)--------RepartitionExec: partitioning=Hash([k@0], 2), input_partitions=2 +06)----------DataSourceExec: partitions=2, partition_sizes=[5, 4] diff --git a/datafusion/substrait/src/logical_plan/consumer/expr/window_function.rs b/datafusion/substrait/src/logical_plan/consumer/expr/window_function.rs index 80b643a547ee..27f0de84b7a0 100644 --- a/datafusion/substrait/src/logical_plan/consumer/expr/window_function.rs +++ b/datafusion/substrait/src/logical_plan/consumer/expr/window_function.rs @@ -112,6 +112,7 @@ pub async fn from_window_function( order_by, 
window_frame, null_treatment: None, + distinct: false, }, })) } diff --git a/datafusion/substrait/src/logical_plan/producer/expr/window_function.rs b/datafusion/substrait/src/logical_plan/producer/expr/window_function.rs index 17e71f2d7c14..94a39e930f1c 100644 --- a/datafusion/substrait/src/logical_plan/producer/expr/window_function.rs +++ b/datafusion/substrait/src/logical_plan/producer/expr/window_function.rs @@ -42,6 +42,7 @@ pub fn from_window_function( order_by, window_frame, null_treatment: _, + distinct: _, }, } = window_fn; // function reference From c57e72cc12ed01b13f019bea6397104375d9806b Mon Sep 17 00:00:00 2001 From: Marco Neumann Date: Fri, 5 Sep 2025 13:42:04 +0200 Subject: [PATCH 09/13] fix: return ALL constants in `EquivalenceProperties::constants` (#17404) * test: regression test for #17372 * test: add more direct regression for #17372 * fix: return ALL constants in `EquivalenceProperties::constants` --- .../physical_optimizer/sanity_checker.rs | 81 ++++++++++++++++++- .../src/equivalence/properties/mod.rs | 9 ++- .../src/equivalence/properties/union.rs | 59 ++++++++++++++ 3 files changed, 142 insertions(+), 7 deletions(-) diff --git a/datafusion/core/tests/physical_optimizer/sanity_checker.rs b/datafusion/core/tests/physical_optimizer/sanity_checker.rs index 6233f5d09c56..ce6eb13c86c4 100644 --- a/datafusion/core/tests/physical_optimizer/sanity_checker.rs +++ b/datafusion/core/tests/physical_optimizer/sanity_checker.rs @@ -20,7 +20,8 @@ use std::sync::Arc; use crate::physical_optimizer::test_utils::{ bounded_window_exec, global_limit_exec, local_limit_exec, memory_exec, - repartition_exec, sort_exec, sort_expr_options, sort_merge_join_exec, + projection_exec, repartition_exec, sort_exec, sort_expr, sort_expr_options, + sort_merge_join_exec, sort_preserving_merge_exec, union_exec, }; use arrow::compute::SortOptions; @@ -28,8 +29,8 @@ use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion::datasource::stream::{FileStreamProvider, StreamConfig, StreamTable}; use datafusion::prelude::{CsvReadOptions, SessionContext}; use datafusion_common::config::ConfigOptions; -use datafusion_common::{JoinType, Result}; -use datafusion_physical_expr::expressions::col; +use datafusion_common::{JoinType, Result, ScalarValue}; +use datafusion_physical_expr::expressions::{col, Literal}; use datafusion_physical_expr::Partitioning; use datafusion_physical_expr_common::sort_expr::LexOrdering; use datafusion_physical_optimizer::sanity_checker::SanityCheckPlan; @@ -665,3 +666,77 @@ async fn test_sort_merge_join_dist_missing() -> Result<()> { assert_sanity_check(&smj, false); Ok(()) } + +/// A particular edge case. +/// +/// See . 
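+///
+/// The union below projects two constant columns on top of inputs that are each
+/// sorted on `a`: `const_1` carries the same literal in both branches, while
+/// `const_2` does not. The merged ordering `[const_1, const_2, a]` is only
+/// satisfiable when `EquivalenceProperties::constants` reports every constant
+/// expression in an equivalence class, not just the canonical one (the regression
+/// fixed here).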
+#[tokio::test] +async fn test_union_with_sorts_and_constants() -> Result<()> { + let schema_in = create_test_schema2(); + + let proj_exprs_1 = vec![ + ( + Arc::new(Literal::new(ScalarValue::Utf8(Some("foo".to_owned())))) as _, + "const_1".to_owned(), + ), + ( + Arc::new(Literal::new(ScalarValue::Utf8(Some("foo".to_owned())))) as _, + "const_2".to_owned(), + ), + (col("a", &schema_in).unwrap(), "a".to_owned()), + ]; + let proj_exprs_2 = vec![ + ( + Arc::new(Literal::new(ScalarValue::Utf8(Some("foo".to_owned())))) as _, + "const_1".to_owned(), + ), + ( + Arc::new(Literal::new(ScalarValue::Utf8(Some("bar".to_owned())))) as _, + "const_2".to_owned(), + ), + (col("a", &schema_in).unwrap(), "a".to_owned()), + ]; + + let source_1 = memory_exec(&schema_in); + let source_1 = projection_exec(proj_exprs_1.clone(), source_1).unwrap(); + let schema_sources = source_1.schema(); + let ordering_sources: LexOrdering = + [sort_expr("a", &schema_sources).nulls_last()].into(); + let source_1 = sort_exec(ordering_sources.clone(), source_1); + + let source_2 = memory_exec(&schema_in); + let source_2 = projection_exec(proj_exprs_2, source_2).unwrap(); + let source_2 = sort_exec(ordering_sources.clone(), source_2); + + let plan = union_exec(vec![source_1, source_2]); + + let schema_out = plan.schema(); + let ordering_out: LexOrdering = [ + sort_expr("const_1", &schema_out).nulls_last(), + sort_expr("const_2", &schema_out).nulls_last(), + sort_expr("a", &schema_out).nulls_last(), + ] + .into(); + + let plan = sort_preserving_merge_exec(ordering_out, plan); + + let plan_str = displayable(plan.as_ref()).indent(true).to_string(); + let plan_str = plan_str.trim(); + assert_snapshot!( + plan_str, + @r" + SortPreservingMergeExec: [const_1@0 ASC NULLS LAST, const_2@1 ASC NULLS LAST, a@2 ASC NULLS LAST] + UnionExec + SortExec: expr=[a@2 ASC NULLS LAST], preserve_partitioning=[false] + ProjectionExec: expr=[foo as const_1, foo as const_2, a@0 as a] + DataSourceExec: partitions=1, partition_sizes=[0] + SortExec: expr=[a@2 ASC NULLS LAST], preserve_partitioning=[false] + ProjectionExec: expr=[foo as const_1, bar as const_2, a@0 as a] + DataSourceExec: partitions=1, partition_sizes=[0] + " + ); + + assert_sanity_check(&plan, true); + + Ok(()) +} diff --git a/datafusion/physical-expr/src/equivalence/properties/mod.rs b/datafusion/physical-expr/src/equivalence/properties/mod.rs index 6d18d34ca4de..fed3b78de801 100644 --- a/datafusion/physical-expr/src/equivalence/properties/mod.rs +++ b/datafusion/physical-expr/src/equivalence/properties/mod.rs @@ -255,10 +255,11 @@ impl EquivalenceProperties { pub fn constants(&self) -> Vec { self.eq_group .iter() - .filter_map(|c| { - c.constant.as_ref().and_then(|across| { - c.canonical_expr() - .map(|expr| ConstExpr::new(Arc::clone(expr), across.clone())) + .flat_map(|c| { + c.iter().filter_map(|expr| { + c.constant + .as_ref() + .map(|across| ConstExpr::new(Arc::clone(expr), across.clone())) }) }) .collect() diff --git a/datafusion/physical-expr/src/equivalence/properties/union.rs b/datafusion/physical-expr/src/equivalence/properties/union.rs index 4f44b9b0c9d4..efbefd0d39bf 100644 --- a/datafusion/physical-expr/src/equivalence/properties/union.rs +++ b/datafusion/physical-expr/src/equivalence/properties/union.rs @@ -921,4 +921,63 @@ mod tests { .collect::>(), )) } + + #[test] + fn test_constants_share_values() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("const_1", DataType::Utf8, false), + Field::new("const_2", DataType::Utf8, false), + ])); + + let col_const_1 
= col("const_1", &schema)?; + let col_const_2 = col("const_2", &schema)?; + + let literal_foo = ScalarValue::Utf8(Some("foo".to_owned())); + let literal_bar = ScalarValue::Utf8(Some("bar".to_owned())); + + let const_expr_1_foo = ConstExpr::new( + Arc::clone(&col_const_1), + AcrossPartitions::Uniform(Some(literal_foo.clone())), + ); + let const_expr_2_foo = ConstExpr::new( + Arc::clone(&col_const_2), + AcrossPartitions::Uniform(Some(literal_foo.clone())), + ); + let const_expr_2_bar = ConstExpr::new( + Arc::clone(&col_const_2), + AcrossPartitions::Uniform(Some(literal_bar.clone())), + ); + + let mut input1 = EquivalenceProperties::new(Arc::clone(&schema)); + let mut input2 = EquivalenceProperties::new(Arc::clone(&schema)); + + // | Input | Const_1 | Const_2 | + // | ----- | ------- | ------- | + // | 1 | foo | foo | + // | 2 | foo | bar | + input1.add_constants(vec![const_expr_1_foo.clone(), const_expr_2_foo.clone()])?; + input2.add_constants(vec![const_expr_1_foo.clone(), const_expr_2_bar.clone()])?; + + // Calculate union properties + let union_props = calculate_union(vec![input1, input2], schema)?; + + // This should result in: + // const_1 = Uniform("foo") + // const_2 = Heterogeneous + assert_eq!(union_props.constants().len(), 2); + let union_const_1 = &union_props.constants()[0]; + assert!(union_const_1.expr.eq(&col_const_1)); + assert_eq!( + union_const_1.across_partitions, + AcrossPartitions::Uniform(Some(literal_foo)), + ); + let union_const_2 = &union_props.constants()[1]; + assert!(union_const_2.expr.eq(&col_const_2)); + assert_eq!( + union_const_2.across_partitions, + AcrossPartitions::Heterogeneous, + ); + + Ok(()) + } } From f7cae504b8b7760e1091f452beb031b195151237 Mon Sep 17 00:00:00 2001 From: Stuart Carnie Date: Sun, 7 Sep 2025 02:05:17 +1000 Subject: [PATCH 10/13] feat: Support binary data types for `SortMergeJoin` `on` clause (#17431) * feat: Support binary data types for `SortMergeJoin` `on` clause * Add sql level tests for merge join on binary keys --------- Co-authored-by: Andrew Lamb --- .../src/joins/sort_merge_join.rs | 155 +++++++++++++++++- .../test_files/sort_merge_join.slt | 58 ++++++- 2 files changed, 209 insertions(+), 4 deletions(-) diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index 9a6832283486..3708ec4900a0 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -2503,6 +2503,10 @@ fn compare_join_arrays( DataType::Utf8 => compare_value!(StringArray), DataType::Utf8View => compare_value!(StringViewArray), DataType::LargeUtf8 => compare_value!(LargeStringArray), + DataType::Binary => compare_value!(BinaryArray), + DataType::BinaryView => compare_value!(BinaryViewArray), + DataType::FixedSizeBinary(_) => compare_value!(FixedSizeBinaryArray), + DataType::LargeBinary => compare_value!(LargeBinaryArray), DataType::Decimal128(..) 
=> compare_value!(Decimal128Array), DataType::Timestamp(time_unit, None) => match time_unit { TimeUnit::Second => compare_value!(TimestampSecondArray), @@ -2571,6 +2575,10 @@ fn is_join_arrays_equal( DataType::Utf8 => compare_value!(StringArray), DataType::Utf8View => compare_value!(StringViewArray), DataType::LargeUtf8 => compare_value!(LargeStringArray), + DataType::Binary => compare_value!(BinaryArray), + DataType::BinaryView => compare_value!(BinaryViewArray), + DataType::FixedSizeBinary(_) => compare_value!(FixedSizeBinaryArray), + DataType::LargeBinary => compare_value!(LargeBinaryArray), DataType::Decimal128(..) => compare_value!(Decimal128Array), DataType::Timestamp(time_unit, None) => match time_unit { TimeUnit::Second => compare_value!(TimestampSecondArray), @@ -2600,7 +2608,8 @@ mod tests { use arrow::array::{ builder::{BooleanBuilder, UInt64Builder}, - BooleanArray, Date32Array, Date64Array, Int32Array, RecordBatch, UInt64Array, + BinaryArray, BooleanArray, Date32Array, Date64Array, FixedSizeBinaryArray, + Int32Array, RecordBatch, UInt64Array, }; use arrow::compute::{concat_batches, filter_record_batch, SortOptions}; use arrow::datatypes::{DataType, Field, Schema}; @@ -2694,6 +2703,56 @@ mod tests { TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap() } + fn build_binary_table( + a: (&str, &Vec<&[u8]>), + b: (&str, &Vec), + c: (&str, &Vec), + ) -> Arc { + let schema = Schema::new(vec![ + Field::new(a.0, DataType::Binary, false), + Field::new(b.0, DataType::Int32, false), + Field::new(c.0, DataType::Int32, false), + ]); + + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![ + Arc::new(BinaryArray::from(a.1.clone())), + Arc::new(Int32Array::from(b.1.clone())), + Arc::new(Int32Array::from(c.1.clone())), + ], + ) + .unwrap(); + + let schema = batch.schema(); + TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap() + } + + fn build_fixed_size_binary_table( + a: (&str, &Vec<&[u8]>), + b: (&str, &Vec), + c: (&str, &Vec), + ) -> Arc { + let schema = Schema::new(vec![ + Field::new(a.0, DataType::FixedSizeBinary(3), false), + Field::new(b.0, DataType::Int32, false), + Field::new(c.0, DataType::Int32, false), + ]); + + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![ + Arc::new(FixedSizeBinaryArray::from(a.1.clone())), + Arc::new(Int32Array::from(b.1.clone())), + Arc::new(Int32Array::from(c.1.clone())), + ], + ) + .unwrap(); + + let schema = batch.schema(); + TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap() + } + /// returns a table with 3 columns of i32 in memory pub fn build_table_i32_nullable( a: (&str, &Vec>), @@ -3932,6 +3991,100 @@ mod tests { Ok(()) } + #[tokio::test] + async fn join_binary() -> Result<()> { + let left = build_binary_table( + ( + "a1", + &vec![ + &[0xc0, 0xff, 0xee], + &[0xde, 0xca, 0xde], + &[0xfa, 0xca, 0xde], + ], + ), + ("b1", &vec![5, 10, 15]), // this has a repetition + ("c1", &vec![7, 8, 9]), + ); + let right = build_binary_table( + ( + "a1", + &vec![ + &[0xc0, 0xff, 0xee], + &[0xde, 0xca, 0xde], + &[0xfa, 0xca, 0xde], + ], + ), + ("b2", &vec![105, 110, 115]), + ("c2", &vec![70, 80, 90]), + ); + + let on = vec![( + Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("a1", &right.schema())?) 
as _, + )]; + + let (_, batches) = join_collect(left, right, on, Inner).await?; + + // The output order is important as SMJ preserves sortedness + assert_snapshot!(batches_to_string(&batches), @r#" + +--------+----+----+--------+-----+----+ + | a1 | b1 | c1 | a1 | b2 | c2 | + +--------+----+----+--------+-----+----+ + | c0ffee | 5 | 7 | c0ffee | 105 | 70 | + | decade | 10 | 8 | decade | 110 | 80 | + | facade | 15 | 9 | facade | 115 | 90 | + +--------+----+----+--------+-----+----+ + "#); + Ok(()) + } + + #[tokio::test] + async fn join_fixed_size_binary() -> Result<()> { + let left = build_fixed_size_binary_table( + ( + "a1", + &vec![ + &[0xc0, 0xff, 0xee], + &[0xde, 0xca, 0xde], + &[0xfa, 0xca, 0xde], + ], + ), + ("b1", &vec![5, 10, 15]), // this has a repetition + ("c1", &vec![7, 8, 9]), + ); + let right = build_fixed_size_binary_table( + ( + "a1", + &vec![ + &[0xc0, 0xff, 0xee], + &[0xde, 0xca, 0xde], + &[0xfa, 0xca, 0xde], + ], + ), + ("b2", &vec![105, 110, 115]), + ("c2", &vec![70, 80, 90]), + ); + + let on = vec![( + Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("a1", &right.schema())?) as _, + )]; + + let (_, batches) = join_collect(left, right, on, Inner).await?; + + // The output order is important as SMJ preserves sortedness + assert_snapshot!(batches_to_string(&batches), @r#" + +--------+----+----+--------+-----+----+ + | a1 | b1 | c1 | a1 | b2 | c2 | + +--------+----+----+--------+-----+----+ + | c0ffee | 5 | 7 | c0ffee | 105 | 70 | + | decade | 10 | 8 | decade | 110 | 80 | + | facade | 15 | 9 | facade | 115 | 90 | + +--------+----+----+--------+-----+----+ + "#); + Ok(()) + } + #[tokio::test] async fn join_left_sort_order() -> Result<()> { let left = build_table( diff --git a/datafusion/sqllogictest/test_files/sort_merge_join.slt b/datafusion/sqllogictest/test_files/sort_merge_join.slt index c17fe8dfc7e6..ed463333217a 100644 --- a/datafusion/sqllogictest/test_files/sort_merge_join.slt +++ b/datafusion/sqllogictest/test_files/sort_merge_join.slt @@ -833,9 +833,61 @@ t2 as ( 11 14 12 15 -# return sql params back to default values statement ok -set datafusion.optimizer.prefer_hash_join = true; +set datafusion.execution.batch_size = 8192; + +###### +## Tests for Binary, LargeBinary, BinaryView, FixedSizeBinary join keys +###### statement ok -set datafusion.execution.batch_size = 8192; +create table t1(x varchar, id1 int) as values ('aa', 1), ('bb', 2), ('aa', 3), (null, 4), ('ee', 5); + +statement ok +create table t2(y varchar, id2 int) as values ('ee', 10), ('bb', 20), ('cc', 30), ('cc', 40), (null, 50); + +# Binary join keys +query ?I?I +with t1 as (select arrow_cast(x, 'Binary') as x, id1 from t1), + t2 as (select arrow_cast(y, 'Binary') as y, id2 from t2) +select * from t1 join t2 on t1.x = t2.y order by id1, id2 +---- +6262 2 6262 20 +6565 5 6565 10 + +# LargeBinary join keys +query ?I?I +with t1 as (select arrow_cast(x, 'LargeBinary') as x, id1 from t1), + t2 as (select arrow_cast(y, 'LargeBinary') as y, id2 from t2) +select * from t1 join t2 on t1.x = t2.y order by id1, id2 +---- +6262 2 6262 20 +6565 5 6565 10 + +# BinaryView join keys +query ?I?I +with t1 as (select arrow_cast(x, 'BinaryView') as x, id1 from t1), + t2 as (select arrow_cast(y, 'BinaryView') as y, id2 from t2) +select * from t1 join t2 on t1.x = t2.y order by id1, id2 +---- +6262 2 6262 20 +6565 5 6565 10 + +# FixedSizeBinary join keys +query ?I?I +with t1 as (select arrow_cast(arrow_cast(x, 'Binary'), 'FixedSizeBinary(2)') as x, id1 from t1), + t2 as 
(select arrow_cast(arrow_cast(y, 'Binary'), 'FixedSizeBinary(2)') as y, id2 from t2) +select * from t1 join t2 on t1.x = t2.y order by id1, id2 +---- +6262 2 6262 20 +6565 5 6565 10 + +statement ok +drop table t1; + +statement ok +drop table t2; + +# return sql params back to default values +statement ok +set datafusion.optimizer.prefer_hash_join = true; From d13bb15d1f45dcb7beef9a823b799b819c4c7597 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Tue, 24 Jun 2025 14:54:00 +0200 Subject: [PATCH 11/13] chore: skip order calculation / exponential planning --- .../src/equivalence/properties/union.rs | 40 +++++++++++++++++-- 1 file changed, 36 insertions(+), 4 deletions(-) diff --git a/datafusion/physical-expr/src/equivalence/properties/union.rs b/datafusion/physical-expr/src/equivalence/properties/union.rs index efbefd0d39bf..8ec2464068ef 100644 --- a/datafusion/physical-expr/src/equivalence/properties/union.rs +++ b/datafusion/physical-expr/src/equivalence/properties/union.rs @@ -67,16 +67,43 @@ fn calculate_union_binary( }) .collect::>(); + // TEMP HACK WORKAROUND + // Revert code from https://github.com/apache/datafusion/pull/12562 + // Context: https://github.com/apache/datafusion/issues/13748 + // Context: https://github.com/influxdata/influxdb_iox/issues/13038 + // Next, calculate valid orderings for the union by searching for prefixes // in both sides. - let mut orderings = UnionEquivalentOrderingBuilder::new(); - orderings.add_satisfied_orderings(&lhs, &rhs)?; - orderings.add_satisfied_orderings(&rhs, &lhs)?; - let orderings = orderings.build(); + let mut orderings = vec![]; + for ordering in lhs.normalized_oeq_class().into_iter() { + let mut ordering: Vec = ordering.into(); + + // Progressively shorten the ordering to search for a satisfied prefix: + while !rhs.ordering_satisfy(ordering.clone())? { + ordering.pop(); + } + // There is a non-trivial satisfied prefix, add it as a valid ordering: + if !ordering.is_empty() { + orderings.push(ordering); + } + } + for ordering in rhs.normalized_oeq_class().into_iter() { + let mut ordering: Vec = ordering.into(); + + // Progressively shorten the ordering to search for a satisfied prefix: + while !lhs.ordering_satisfy(ordering.clone())? 
{ + ordering.pop(); + } + // There is a non-trivial satisfied prefix, add it as a valid ordering: + if !ordering.is_empty() { + orderings.push(ordering); + } + } let mut eq_properties = EquivalenceProperties::new(lhs.schema); eq_properties.add_constants(constants)?; eq_properties.add_orderings(orderings); + Ok(eq_properties) } @@ -122,6 +149,7 @@ struct UnionEquivalentOrderingBuilder { orderings: Vec, } +#[expect(unused)] impl UnionEquivalentOrderingBuilder { fn new() -> Self { Self { orderings: vec![] } @@ -504,6 +532,7 @@ mod tests { } #[test] + #[ignore = "InfluxData patch: chore: skip order calculation / exponential planning"] fn test_union_equivalence_properties_constants_fill_gaps() -> Result<()> { let schema = create_test_schema().unwrap(); UnionEquivalenceTest::new(&schema) @@ -579,6 +608,7 @@ mod tests { } #[test] + #[ignore = "InfluxData patch: chore: skip order calculation / exponential planning"] fn test_union_equivalence_properties_constants_fill_gaps_non_symmetric() -> Result<()> { let schema = create_test_schema().unwrap(); @@ -607,6 +637,7 @@ mod tests { } #[test] + #[ignore = "InfluxData patch: chore: skip order calculation / exponential planning"] fn test_union_equivalence_properties_constants_gap_fill_symmetric() -> Result<()> { let schema = create_test_schema().unwrap(); UnionEquivalenceTest::new(&schema) @@ -658,6 +689,7 @@ mod tests { } #[test] + #[ignore = "InfluxData patch: chore: skip order calculation / exponential planning"] fn test_union_equivalence_properties_constants_middle_desc() -> Result<()> { let schema = create_test_schema().unwrap(); UnionEquivalenceTest::new(&schema) From a9cf9aca9ebf0d6c04e0861d2baebffa0ba77dbc Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Tue, 16 Jul 2024 12:14:19 -0400 Subject: [PATCH 12/13] (New) Test + workaround for SanityCheck plan --- datafusion/physical-optimizer/src/sanity_checker.rs | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/datafusion/physical-optimizer/src/sanity_checker.rs b/datafusion/physical-optimizer/src/sanity_checker.rs index acc70d39f057..3cc5319f9e10 100644 --- a/datafusion/physical-optimizer/src/sanity_checker.rs +++ b/datafusion/physical-optimizer/src/sanity_checker.rs @@ -32,6 +32,8 @@ use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_physical_expr::intervals::utils::{check_support, is_datatype_supported}; use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType}; use datafusion_physical_plan::joins::SymmetricHashJoinExec; +use datafusion_physical_plan::sorts::sort::SortExec; +use datafusion_physical_plan::union::UnionExec; use datafusion_physical_plan::{get_plan_string, ExecutionPlanProperties}; use crate::PhysicalOptimizerRule; @@ -135,6 +137,14 @@ pub fn check_plan_sanity( plan.required_input_ordering(), plan.required_input_distribution(), ) { + // TEMP HACK WORKAROUND https://github.com/apache/datafusion/issues/11492 + if child.as_any().downcast_ref::().is_some() { + continue; + } + if child.as_any().downcast_ref::().is_some() { + continue; + } + let child_eq_props = child.equivalence_properties(); if let Some(sort_req) = sort_req { let sort_req = sort_req.into_single(); From f9256d5c661c176139ae1f98d5b8339886f11a48 Mon Sep 17 00:00:00 2001 From: Denise Wiedl Date: Thu, 18 Sep 2025 23:07:21 +0300 Subject: [PATCH 13/13] Keep aggregate udaf schema names unique when missing an order-by * test: reproducer of bug * fix: make schema names unique for approx_percentile_cont * test: regression test is now resolved --- 
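Context for reviewers: an ordered-set aggregate drops its first argument from the
schema name because that argument is normally repeated in the WITHIN GROUP clause.
With the old positional syntax there is no WITHIN GROUP clause, so the query from
the new regression test,

```sql
SELECT c1,
       approx_percentile_cont(c2, 0.95) AS c2,
       approx_percentile_cont(c3, 0.95) AS c3
FROM aggregate_test_100
GROUP BY 1;
```

previously derived the same internal schema name (roughly
`approx_percentile_cont(0.95)`) for both expressions, making them indistinguishable
in the plan. Restricting the argument-stripping to calls that actually have an
ORDER BY keeps the names unique.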
datafusion/expr/src/udaf.rs | 2 +- .../sqllogictest/test_files/aggregate.slt | 23 +++++++++++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 15c0dd57ad2c..a0d2b6a96a48 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -459,7 +459,7 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { // exclude the first function argument(= column) in ordered set aggregate function, // because it is duplicated with the WITHIN GROUP clause in schema name. - let args = if self.is_ordered_set_aggregate() { + let args = if self.is_ordered_set_aggregate() && !order_by.is_empty() { &args[1..] } else { &args[..] diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 4671408349e2..ab31a87b9e35 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -1771,6 +1771,29 @@ c 122 d 124 e 115 + +# using approx_percentile_cont on 2 columns with same signature +query TII +SELECT c1, approx_percentile_cont(c2, 0.95) AS c2, approx_percentile_cont(c3, 0.95) AS c3 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 +---- +a 5 73 +b 5 68 +c 5 122 +d 5 124 +e 5 115 + +# error is unique to this UDAF +query TRR +SELECT c1, avg(c2) AS c2, avg(c3) AS c3 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 +---- +a 2.857142857143 -18.333333333333 +b 3.263157894737 -5.842105263158 +c 2.666666666667 -1.333333333333 +d 2.444444444444 25.444444444444 +e 3 40.333333333333 + + + query TI SELECT c1, approx_percentile_cont(0.95) WITHIN GROUP (ORDER BY c3 DESC) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 ----