Skip to content

Commit 4960284

Browse files
authored
fix: Median() integer overflow (apache#19509)
## Which issue does this PR close? <!-- We generally require a GitHub issue to be filed for all bug fixes and enhancements and this helps us generate change logs for our releases. You can link an issue to this PR using the GitHub syntax. For example `Closes apache#123` indicates that this PR will close issue apache#123. --> - Closes apache#19322 ## Rationale for this change When calculating the median of an even-length array of integers, averaging the two middle values using `add_wrapping` causes incorrect results due to integer overflow. For example, with Int8 values -85 and -56: ``` Expected: (-85 + -56) / 2 = -70 Actual: -85 + -56 = -141 wraps to 115, then 115 / 2 = 57 ``` <!-- Why are you proposing this change? If this is already explained clearly in the issue then this section is not needed. Explaining clearly why changes are proposed helps reviewers understand your changes and offer better suggestions for fixes. --> ## What changes are included in this PR? - Fix `calculate_median` : Use add_checked to detect overflow, and fall back to a safe midpoint formula `a/2 + b/2 + ((a%2 + b%2) / 2)` when overflow occurs. - Add tests <!-- There is no need to duplicate the description in the issue here but it is sometimes worth providing a summary of the individual changes in this PR. --> ## Are these changes tested? - All previous tests pass - Added new tests <!-- We typically require tests for all PRs in order to: 1. Prevent the code from being accidentally broken by subsequent changes 2. Serve as another way to document the expected behavior of the code If tests are not included in your PR, please explain why (for example, are they covered by existing tests)? --> ## Are there any user-facing changes? <!-- If there are user-facing changes then we may require documentation to be updated before approving the PR. --> <!-- If there are any breaking changes to public APIs, please add the `api change` label. -->
1 parent 9eddf47 commit 4960284

File tree

3 files changed

+76
-4
lines changed

3 files changed

+76
-4
lines changed

datafusion/core/tests/dataframe/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1112,7 +1112,7 @@ async fn window_using_aggregates() -> Result<()> {
11121112
| -85 | -48 | 6 | -35 | -36 | 83 | -85 | 2 | -43 |
11131113
| -85 | -5 | 4 | -37 | -40 | -5 | -85 | 1 | 83 |
11141114
| -85 | -54 | 15 | -17 | -18 | 83 | -101 | 4 | -38 |
1115-
| -85 | -56 | 2 | -70 | 57 | -56 | -85 | 1 | -25 |
1115+
| -85 | -56 | 2 | -70 | -70 | -56 | -85 | 1 | -25 |
11161116
| -85 | -72 | 9 | -43 | -43 | 83 | -85 | 3 | -12 |
11171117
| -85 | -85 | 1 | -85 | -85 | -85 | -85 | 1 | -56 |
11181118
| -85 | 13 | 11 | -17 | -18 | 83 | -85 | 3 | 14 |

datafusion/functions-aggregate/src/median.rs

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -604,9 +604,25 @@ fn calculate_median<T: ArrowNumericType>(values: &mut [T::Native]) -> Option<T::
604604
let (low, high, _) = values.select_nth_unstable_by(len / 2, cmp);
605605
// Get the maximum of the low (left side after bi-partitioning)
606606
let left_max = slice_max::<T>(low);
607-
let median = left_max
608-
.add_wrapping(*high)
609-
.div_wrapping(T::Native::usize_as(2));
607+
// Calculate median as the average of the two middle values.
608+
// Use checked arithmetic to detect overflow and fall back to safe formula.
609+
let two = T::Native::usize_as(2);
610+
let median = match left_max.add_checked(*high) {
611+
Ok(sum) => sum.div_wrapping(two),
612+
Err(_) => {
613+
// Overflow detected - use safe midpoint formula:
614+
// a/2 + b/2 + ((a%2 + b%2) / 2)
615+
// This avoids overflow by dividing before adding.
616+
let half_left = left_max.div_wrapping(two);
617+
let half_right = (*high).div_wrapping(two);
618+
let rem_left = left_max.mod_wrapping(two);
619+
let rem_right = (*high).mod_wrapping(two);
620+
// The sum of remainders (0, 1, or 2 for unsigned; -2 to 2 for signed)
621+
// divided by 2 gives the correction factor (0 or 1 for unsigned; -1, 0, or 1 for signed)
622+
let correction = rem_left.add_wrapping(rem_right).div_wrapping(two);
623+
half_left.add_wrapping(half_right).add_wrapping(correction)
624+
}
625+
};
610626
Some(median)
611627
} else {
612628
let (_, median, _) = values.select_nth_unstable_by(len / 2, cmp);

datafusion/sqllogictest/test_files/aggregate.slt

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -990,6 +990,62 @@ SELECT approx_median(col_f64_nan) FROM median_table
990990
----
991991
NaN
992992

993+
994+
# median_i8_overflow_negative
995+
query I
996+
SELECT median(v) FROM (VALUES (arrow_cast(-85, 'Int8')), (arrow_cast(-56, 'Int8'))) AS t(v);
997+
----
998+
-70
999+
1000+
# median_i8_overflow_positive
1001+
# Test overflow with positive values: 100 + 120 = 220 > 127 (max i8)
1002+
query I
1003+
SELECT median(v) FROM (VALUES (arrow_cast(100, 'Int8')), (arrow_cast(120, 'Int8'))) AS t(v);
1004+
----
1005+
110
1006+
1007+
# median_u8_overflow
1008+
# Test unsigned overflow: 200 + 250 = 450 > 255 (max u8)
1009+
query I
1010+
SELECT median(v) FROM (VALUES (arrow_cast(200, 'UInt8')), (arrow_cast(250, 'UInt8'))) AS t(v);
1011+
----
1012+
225
1013+
1014+
# median_i8_no_overflow_normal_case
1015+
# Normal case that doesn't overflow for comparison
1016+
query I
1017+
SELECT median(v) FROM (VALUES (arrow_cast(4, 'Int8')), (arrow_cast(5, 'Int8'))) AS t(v);
1018+
----
1019+
4
1020+
1021+
# median_i8_max_values
1022+
# Test with both i8::MAX values: 127 + 127 = 254 > 127, overflow
1023+
query I
1024+
SELECT median(v) FROM (VALUES (arrow_cast(127, 'Int8')), (arrow_cast(127, 'Int8'))) AS t(v);
1025+
----
1026+
127
1027+
1028+
# median_i8_min_values
1029+
# Test with both i8::MIN values: -128 + -128 = -256 < -128, underflow
1030+
query I
1031+
SELECT median(v) FROM (VALUES (arrow_cast(-128, 'Int8')), (arrow_cast(-128, 'Int8'))) AS t(v);
1032+
----
1033+
-128
1034+
1035+
# median_i8_min_max_values
1036+
# Test with i8::MIN and i8::MAX: -128 + 127 = -1, no overflow, median = 0 (truncated from -0.5)
1037+
query I
1038+
SELECT median(v) FROM (VALUES (arrow_cast(-128, 'Int8')), (arrow_cast(127, 'Int8'))) AS t(v);
1039+
----
1040+
0
1041+
1042+
# median_u8_max_values
1043+
# Test with both u8::MAX values: 255 + 255 = 510 > 255, overflow
1044+
query I
1045+
SELECT median(v) FROM (VALUES (arrow_cast(255, 'UInt8')), (arrow_cast(255, 'UInt8'))) AS t(v);
1046+
----
1047+
255
1048+
9931049
# median_sliding_window
9941050
statement ok
9951051
CREATE TABLE median_window_test (

0 commit comments

Comments
 (0)