diff --git a/base/ordering.jl b/base/ordering.jl index e49102159c962..a5a7ed2159309 100644 --- a/base/ordering.jl +++ b/base/ordering.jl @@ -16,7 +16,7 @@ export # not exported by Base By, Lt, Perm, ReverseOrdering, ForwardOrdering, DirectOrdering, - lt, ord, ordtype + lt, ord, ordtype, maybe_skip_by """ Base.Order.Ordering @@ -70,6 +70,30 @@ Reverse ordering according to [`isless`](@ref). """ const Reverse = ReverseOrdering() +struct MaybeSkipBy{T} + id::T +end + +""" + Base.maybe_skip_by(a) + +Generates an object such that if the `require_by` flag is set +in `ord` or the By class is used, then `by` is applied to `a`. +But if the `require_by` flag is cleared or the `MaybeBy` class is +used, then `by` is not applied to `a`. This exists to support the +apply_by_to_key=false option in the searchsorted family of +functions. +""" +maybe_skip_by(a) = MaybeSkipBy(a) + +maybe_apply(by, a) = by(a) +maybe_apply(::Any, a::MaybeSkipBy) = a.id +always_apply(by, a) = by(a) +always_apply(by, a::MaybeSkipBy) = by(a.id) +_id(a) = a +_id(a::MaybeSkipBy) = a.id + + """ By(by, order::Ordering=Forward) @@ -81,6 +105,12 @@ struct By{T, O} <: Ordering order::O end +struct MaybeBy{T, O} <: Ordering + by::T + order::O +end + + # backwards compatibility with VERSION < v"1.5-" By(by) = By(by, Forward) @@ -107,6 +137,7 @@ struct Perm{O<:Ordering,V<:AbstractVector} <: Ordering end ReverseOrdering(by::By) = By(by.by, ReverseOrdering(by.order)) +ReverseOrdering(by::MaybeBy) = MaybeBy(by.by, ReverseOrdering(by.order)) ReverseOrdering(perm::Perm) = Perm(ReverseOrdering(perm.order), perm.data) """ @@ -114,10 +145,13 @@ ReverseOrdering(perm::Perm) = Perm(ReverseOrdering(perm.order), perm.data) Test whether `a` is less than `b` according to the ordering `o`. 
""" -lt(o::ForwardOrdering, a, b) = isless(a,b) -lt(o::ReverseOrdering, a, b) = lt(o.fwd,b,a) -lt(o::By, a, b) = lt(o.order,o.by(a),o.by(b)) -lt(o::Lt, a, b) = o.lt(a,b) +lt(o::ForwardOrdering, a, b) = isless(_id(a),_id(b)) +lt(o::ReverseOrdering, a, b) = lt(o.fwd,_id(b),_id(a)) +lt(o::By, a, b) = + lt(o.order, always_apply(o.by,a), always_apply(o.by,b)) +lt(o::MaybeBy, a, b) = + lt(o.order, maybe_apply(o.by,a), maybe_apply(o.by,b)) +lt(o::Lt, a, b) = o.lt(_id(a),_id(b)) @propagate_inbounds function lt(p::Perm, a::Integer, b::Integer) da = p.data[a] @@ -125,21 +159,56 @@ lt(o::Lt, a, b) = o.lt(a,b) lt(p.order, da, db) | (!lt(p.order, db, da) & (a < b)) end -_ord(lt::typeof(isless), by::typeof(identity), order::Ordering) = order -_ord(lt::typeof(isless), by, order::Ordering) = By(by, order) +# This is necessary in case Base.Val is not yet defined during +# the bootstrap process + + +struct _Val{x} +end + +_Val(x) = _Val{x}() + -function _ord(lt, by, order::Ordering) + + +# If the 4th argument to _ord is _Val{false}, then the `by` function +# is skipped for any element a of the form maybe_skip_by(a) + +_ord(lt::typeof(isless), by::typeof(identity), order::Ordering, ::_Val{true}) = + order +_ord(lt::typeof(isless), by::typeof(identity), order::Ordering, ::_Val{false}) = + order +_ord(lt::typeof(isless), by, order::Ordering, ::_Val{true}) = + By(by, order) +_ord(lt::typeof(isless), by, order::Ordering, ::_Val{false}) = + MaybeBy(by, order) + +function _ord(lt, by, order::Ordering, ::_Val{false}) + if order === Forward + return Lt((x, y) -> lt(maybe_apply(by,x), maybe_apply(by,y))) + elseif order === Reverse + return Lt((x, y) -> lt(maybe_apply(by,y), maybe_apply(by,x))) + else + error("Passing both lt= and order= arguments is ambiguous; please pass order=Forward or order=Reverse (or leave default)") + end +end + +function _ord(lt, by, order::Ordering, ::_Val{true}) if order === Forward - return Lt((x, y) -> lt(by(x), by(y))) + return Lt((x, y) -> lt(always_apply(by,x), 
always_apply(by,y))) elseif order === Reverse - return Lt((x, y) -> lt(by(y), by(x))) + return Lt((x, y) -> lt(always_apply(by,y), always_apply(by,x))) else error("Passing both lt= and order= arguments is ambiguous; please pass order=Forward or order=Reverse (or leave default)") end end + + + """ - ord(lt, by, rev::Union{Bool, Nothing}, order::Ordering=Forward) + ord(lt, by, rev::Union{Bool, Nothing}, + order::Ordering=Forward, require_by=true) Construct an [`Ordering`](@ref) object from the same arguments used by [`sort!`](@ref). @@ -153,11 +222,21 @@ Passing an `lt` other than `isless` along with an `order` other than [`Base.Order.Forward`](@ref) or [`Base.Order.Reverse`](@ref) is not permitted, otherwise all options are independent and can be used together in all possible combinations. + +If `require_by` is true, then function `maybe_skip_by(a)` +is the same as `identity(a)` and therefore +`by` is applied in all cases. +If `require_by` is false then the `by` function is +not applied to elements of the form `maybe_skip_by(a)`; instead `a` itself +is used unchanged in the comparison. This option exists to support the `apply_by_to_key` option +in the searchsorted family of functions. + """ -ord(lt, by, rev::Nothing, order::Ordering=Forward) = _ord(lt, by, order) +ord(lt, by, rev::Nothing, order::Ordering=Forward, require_by=true) = + _ord(lt, by, order, _Val(require_by)) -function ord(lt, by, rev::Bool, order::Ordering=Forward) - o = _ord(lt, by, order) +function ord(lt, by, rev::Bool, order::Ordering=Forward, require_by=true) + o = _ord(lt, by, order, _Val(require_by)) return rev ? 
ReverseOrdering(o) : o end diff --git a/base/sort.jl b/base/sort.jl index bfa3e1d0dc0e2..2d1bf74828a90 100644 --- a/base/sort.jl +++ b/base/sort.jl @@ -181,7 +181,7 @@ function searchsortedfirst(v::AbstractVector, x, lo::T, hi::T, o::Ordering)::key hi = hi + u @inbounds while lo < hi - u m = midpoint(lo, hi) - if lt(o, v[m], x) + if lt(o, v[m], maybe_skip_by(x)) lo = m else hi = m @@ -198,7 +198,7 @@ function searchsortedlast(v::AbstractVector, x, lo::T, hi::T, o::Ordering)::keyt hi = hi + u @inbounds while lo < hi - u m = midpoint(lo, hi) - if lt(o, x, v[m]) + if lt(o, maybe_skip_by(x), v[m]) hi = m else lo = m @@ -216,9 +216,9 @@ function searchsorted(v::AbstractVector, x, ilo::T, ihi::T, o::Ordering)::UnitRa hi = ihi + u @inbounds while lo < hi - u m = midpoint(lo, hi) - if lt(o, v[m], x) + if lt(o, v[m], maybe_skip_by(x)) lo = m - elseif lt(o, x, v[m]) + elseif lt(o, maybe_skip_by(x), v[m]) hi = m else a = searchsortedfirst(v, x, max(lo,ilo), m, o) @@ -294,18 +294,23 @@ for s in [:searchsortedfirst, :searchsortedlast, :searchsorted] @eval begin $s(v::AbstractVector, x, o::Ordering) = (inds = axes(v, 1); $s(v,x,first(inds),last(inds),o)) $s(v::AbstractVector, x; - lt=isless, by=identity, rev::Union{Bool,Nothing}=nothing, order::Ordering=Forward) = - $s(v,x,ord(lt,by,rev,order)) + lt=isless, by=identity, rev::Union{Bool,Nothing}=nothing, order::Ordering=Forward, apply_by_to_key=true) = + $s(v,x,ord(lt,by,rev,order,apply_by_to_key)) end end """ - searchsorted(a, x; by=, lt=, rev=false) + searchsorted(a, x; by=, lt=, rev=false, + apply_by_to_key=true) Return the range of indices of `a` which compare as equal to `x` (using binary search) according to the order specified by the `by`, `lt` and `rev` keywords, assuming that `a` is already sorted in that order. Return an empty range located at the insertion point -if `a` does not contain values equal to `x`. +if `a` does not contain values equal to `x`. 
If the `apply_by_to_key` +keyword is set to `true`, then the `by` function is also applied +to the key `x` (default, for legacy reasons). If the `apply_by_to_key` +keyword is `false`, then the `by` function is applied to the elements of `a` +but not to the key (more typical usage). See also: [`insorted`](@ref), [`searchsortedfirst`](@ref), [`sort`](@ref), [`findall`](@ref). @@ -329,7 +334,8 @@ julia> searchsorted([1, 2, 4, 5, 5, 7], 0) # no match, insert at start """ searchsorted """ - searchsortedfirst(a, x; by=, lt=, rev=false) + searchsortedfirst(a, x; by=, lt=, rev=false, + apply_by_to_key=true) Return the index of the first value in `a` greater than or equal to `x`, according to the specified order. Return `lastindex(a) + 1` if `x` is greater than all values in `a`. @@ -357,7 +363,8 @@ julia> searchsortedfirst([1, 2, 4, 5, 5, 7], 0) # no match, insert at start """ searchsortedfirst """ - searchsortedlast(a, x; by=, lt=, rev=false) + searchsortedlast(a, x; by=, lt=, rev=false, + apply_by_to_key=true) Return the index of the last value in `a` less than or equal to `x`, according to the specified order. Return `firstindex(a) - 1` if `x` is less than all values in `a`. 
`a` is @@ -383,7 +390,8 @@ julia> searchsortedlast([1, 2, 4, 5, 5, 7], 0) # no match, insert at start """ searchsortedlast """ - insorted(a, x; by=, lt=, rev=false) -> Bool + insorted(a, x; by=, lt=, rev=false, + apply_by_to_key=true) -> Bool Determine whether an item is in the given sorted collection, in the sense that it is [`==`](@ref) to one of the values of the collection according to the order diff --git a/test/sorting.jl b/test/sorting.jl index e90138549afd8..c971740411325 100644 --- a/test/sorting.jl +++ b/test/sorting.jl @@ -86,7 +86,11 @@ end Float16, Float32, Float64, BigInt, BigFloat] @test searchsorted([1:10;], 1, by=(x -> x >= 5)) == 1:4 + @test searchsorted([1:10;], false, by=(x -> x >= 5), + apply_by_to_key=false) == 1:4 @test searchsorted([1:10;], 10, by=(x -> x >= 5)) == 5:10 + @test searchsorted([1:10;], true, by=(x -> x >= 5), + apply_by_to_key=false) == 5:10 @test searchsorted([1:5; 1:5; 1:5], 1, 6, 10, Forward) == 6:6 @test searchsorted(fill(1, 15), 1, 6, 10, Forward) == 6:10 @@ -346,6 +350,7 @@ end @test insorted(1, collect(1:10), by=(>=(5))) @test insorted(10, collect(1:10), by=(>=(5))) + @test insorted(true, collect(1:10), by=(>=(5)), apply_by_to_key=false) for R in numTypes, T in numTypes @test !insorted(T(0), R[1, 1, 2, 2, 3, 3]) @@ -393,6 +398,7 @@ end @test !insorted(0, [1,2,3]) @test !insorted(4, [1,2,3]) @test insorted(3, [10,8,6,9,4,7,2,5,3,1], by=(x -> iseven(x) ? x+5 : x), rev=true) + @test insorted(4, [10,8,6,9,4,7,2,5,3,1], by=(x -> iseven(x) ? x+6 : x+1), rev=true, apply_by_to_key=false) end @testset "PartialQuickSort" begin a = rand(1:10000, 1000)