Skip to content

apply_by_to_key keyword for searchsorted #43509

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 93 additions & 14 deletions base/ordering.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ export # not exported by Base
By, Lt, Perm,
ReverseOrdering, ForwardOrdering,
DirectOrdering,
lt, ord, ordtype
lt, ord, ordtype, maybe_skip_by

"""
Base.Order.Ordering
Expand Down Expand Up @@ -70,6 +70,30 @@ Reverse ordering according to [`isless`](@ref).
"""
const Reverse = ReverseOrdering()

# Marker wrapper around a single value, stored in `id`. The helpers
# `maybe_apply`, `always_apply` and `_id` dispatch on this type to decide
# whether a `by` function is applied to the wrapped value or skipped.
struct MaybeSkipBy{T}
id::T
end

"""
    Base.maybe_skip_by(a)

Wrap `a` in a marker object so that an ordering can tell it apart from an
ordinary element. If the `require_by` flag is set in `ord`, or the `By`
ordering is used, then `by` is still applied to `a`. But if the `require_by`
flag is cleared and the `MaybeBy` ordering is used, then `by` is *not*
applied to `a`; `a` itself takes part in the comparison. This exists to
support the `apply_by_to_key=false` option in the searchsorted family of
functions.
"""
maybe_skip_by(a) = MaybeSkipBy(a)

# Strip the `MaybeSkipBy` marker, if present, and return the plain value.
_id(value) = value
_id(wrapped::MaybeSkipBy) = wrapped.id

# Apply `f` to ordinary values; a `MaybeSkipBy` value is unwrapped and
# returned as-is, without calling `f`.
maybe_apply(f, value) = f(value)
maybe_apply(::Any, wrapped::MaybeSkipBy) = wrapped.id

# Apply `f` unconditionally; a `MaybeSkipBy` value is unwrapped first.
always_apply(f, value) = f(value)
always_apply(f, wrapped::MaybeSkipBy) = f(wrapped.id)


"""
By(by, order::Ordering=Forward)

Expand All @@ -81,6 +105,12 @@ struct By{T, O} <: Ordering
order::O
end

# Ordering with the same fields as `By` (a transform `by` and an inner
# `order`). Its `lt` method routes both arguments through `maybe_apply`, so
# `by` is skipped for values wrapped with `maybe_skip_by` and applied to
# everything else.
struct MaybeBy{T, O} <: Ordering
by::T
order::O
end


# backwards compatibility with VERSION < v"1.5-"
By(by) = By(by, Forward)

Expand All @@ -107,39 +137,78 @@ struct Perm{O<:Ordering,V<:AbstractVector} <: Ordering
end

# Reversing a composite ordering pushes the reversal down to its inner
# ordering and keeps the outer wrapper (and its transform or data) in place.
ReverseOrdering(o::By) = By(o.by, ReverseOrdering(o.order))
ReverseOrdering(o::MaybeBy) = MaybeBy(o.by, ReverseOrdering(o.order))
ReverseOrdering(p::Perm) = Perm(ReverseOrdering(p.order), p.data)

"""
lt(o::Ordering, a, b)

Test whether `a` is less than `b` according to the ordering `o`.
"""
lt(o::ForwardOrdering, a, b) = isless(a,b)
lt(o::ReverseOrdering, a, b) = lt(o.fwd,b,a)
lt(o::By, a, b) = lt(o.order,o.by(a),o.by(b))
lt(o::Lt, a, b) = o.lt(a,b)
# Leaf orderings unwrap any `maybe_skip_by` marker before comparing.
lt(::ForwardOrdering, a, b) = isless(_id(a), _id(b))
lt(rev::ReverseOrdering, a, b) = lt(rev.fwd, _id(b), _id(a))

# `By` applies its transform to both operands, wrapped or not.
function lt(o::By, a, b)
    ka = always_apply(o.by, a)
    kb = always_apply(o.by, b)
    return lt(o.order, ka, kb)
end

# `MaybeBy` applies its transform only to operands that are not wrapped with
# `maybe_skip_by`; wrapped operands are unwrapped and used directly.
function lt(o::MaybeBy, a, b)
    ka = maybe_apply(o.by, a)
    kb = maybe_apply(o.by, b)
    return lt(o.order, ka, kb)
end

lt(o::Lt, a, b) = o.lt(_id(a), _id(b))

# Compare two indices `a` and `b` by the values they select in `p.data`,
# using the inner ordering `p.order`.
@propagate_inbounds function lt(p::Perm, a::Integer, b::Integer)
da = p.data[a]
db = p.data[b]
# True when `da` sorts before `db`, or when neither sorts before the other
# and `a < b` (ties are broken by index). `|` and `&` do not short-circuit,
# so both `lt` calls and the index comparison are always evaluated.
lt(p.order, da, db) | (!lt(p.order, db, da) & (a < b))
end

_ord(lt::typeof(isless), by::typeof(identity), order::Ordering) = order
_ord(lt::typeof(isless), by, order::Ordering) = By(by, order)
# This is necessary in case Base.Val is not yet defined during
# the bootstrap process


# Field-less type whose only information is its type parameter; used to turn
# a runtime flag into something methods can dispatch on.
struct _Val{x} end

# Lift the value `flag` into the type domain.
function _Val(flag)
    return _Val{flag}()
end


function _ord(lt, by, order::Ordering)


# If the 4th argument to _ord is _Val{false}, then the `by` function
# is skipped for any element a of the form maybe_skip_by(a)

# Default `isless` comparison: no `Lt` wrapper is needed. An identity `by`
# leaves `order` untouched whatever the flag. Any other `by` is wrapped in
# `By` (flag true: always applied) or `MaybeBy` (flag false: skipped for
# values wrapped with `maybe_skip_by`).
_ord(::typeof(isless), ::typeof(identity), order::Ordering, ::_Val{true}) = order
_ord(::typeof(isless), ::typeof(identity), order::Ordering, ::_Val{false}) = order
_ord(::typeof(isless), by, order::Ordering, ::_Val{true}) = By(by, order)
_ord(::typeof(isless), by, order::Ordering, ::_Val{false}) = MaybeBy(by, order)

# Build the ordering for a custom `lt` when `by` must be skipped for values
# wrapped with `maybe_skip_by` (fourth argument `_Val{false}`).
#
# `order` must be `Forward` or `Reverse`; any other ordering raises an error
# because combining it with a custom `lt` is ambiguous.
function _ord(lt, by, order::Ordering, ::_Val{false})
    # A closure stored in `Lt` cannot be used here: `lt(::Lt, a, b)` strips the
    # `maybe_skip_by` marker before calling the stored function, so the closure
    # would never see the marker and would apply `by` to the key as well.
    # `MaybeBy` inspects its operands while they are still wrapped and only
    # then hands the results to the inner ordering.
    if order === Forward
        return MaybeBy(by, Lt(lt))
    elseif order === Reverse
        return MaybeBy(by, ReverseOrdering(Lt(lt)))
    else
        error("Passing both lt= and order= arguments is ambiguous; please pass order=Forward or order=Reverse (or leave default)")
    end
end

# Build the ordering for a custom `lt` when `by` is applied to every operand
# (fourth argument `_Val{true}`).
#
# `order` must be `Forward` or `Reverse`; any other ordering raises an error
# because combining it with a custom `lt` is ambiguous.
function _ord(lt, by, order::Ordering, ::_Val{true})
    # One return per branch: a second `return` after the first is unreachable.
    # `always_apply` unwraps a `maybe_skip_by` value before calling `by`,
    # matching what `lt(::By, a, b)` does.
    if order === Forward
        return Lt((x, y) -> lt(always_apply(by, x), always_apply(by, y)))
    elseif order === Reverse
        return Lt((x, y) -> lt(always_apply(by, y), always_apply(by, x)))
    else
        error("Passing both lt= and order= arguments is ambiguous; please pass order=Forward or order=Reverse (or leave default)")
    end
end




"""
ord(lt, by, rev::Union{Bool, Nothing}, order::Ordering=Forward)
ord(lt, by, rev::Union{Bool, Nothing},
order::Ordering=Forward, require_by=true)

Construct an [`Ordering`](@ref) object from the same arguments used by
[`sort!`](@ref).
Expand All @@ -153,11 +222,21 @@ Passing an `lt` other than `isless` along with an `order` other than
[`Base.Order.Forward`](@ref) or [`Base.Order.Reverse`](@ref) is not permitted,
otherwise all options are independent and can be used together in all possible
combinations.

If `require_by` is `true`, a value wrapped as `maybe_skip_by(a)` is treated
like `a` itself, and therefore
`by` is applied in all cases.
If `require_by` is `false`, then the `by` function is
not applied to elements of the form `maybe_skip_by(a)`; instead `a` itself
is used in the comparison. This option exists to support the `apply_by_to_key` option
in the searchsorted family of functions.

"""
ord(lt, by, rev::Nothing, order::Ordering=Forward) = _ord(lt, by, order)
# No `rev` requested: the ordering built by `_ord` is returned unchanged.
function ord(lt, by, rev::Nothing, order::Ordering=Forward, require_by=true)
    return _ord(lt, by, order, _Val(require_by))
end

function ord(lt, by, rev::Bool, order::Ordering=Forward)
o = _ord(lt, by, order)
# Boolean `rev`: build the ordering with `_ord`, then reverse it when `rev`
# is true.
function ord(lt, by, rev::Bool, order::Ordering=Forward, require_by=true)
    base = _ord(lt, by, order, _Val(require_by))
    if rev
        return ReverseOrdering(base)
    end
    return base
end

Expand Down
30 changes: 19 additions & 11 deletions base/sort.jl
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ function searchsortedfirst(v::AbstractVector, x, lo::T, hi::T, o::Ordering)::key
hi = hi + u
@inbounds while lo < hi - u
m = midpoint(lo, hi)
if lt(o, v[m], x)
if lt(o, v[m], maybe_skip_by(x))
lo = m
else
hi = m
Expand All @@ -198,7 +198,7 @@ function searchsortedlast(v::AbstractVector, x, lo::T, hi::T, o::Ordering)::keyt
hi = hi + u
@inbounds while lo < hi - u
m = midpoint(lo, hi)
if lt(o, x, v[m])
if lt(o, maybe_skip_by(x), v[m])
hi = m
else
lo = m
Expand All @@ -216,9 +216,9 @@ function searchsorted(v::AbstractVector, x, ilo::T, ihi::T, o::Ordering)::UnitRa
hi = ihi + u
@inbounds while lo < hi - u
m = midpoint(lo, hi)
if lt(o, v[m], x)
if lt(o, v[m], maybe_skip_by(x))
lo = m
elseif lt(o, x, v[m])
elseif lt(o, maybe_skip_by(x), v[m])
hi = m
else
a = searchsortedfirst(v, x, max(lo,ilo), m, o)
Expand Down Expand Up @@ -294,18 +294,23 @@ for s in [:searchsortedfirst, :searchsortedlast, :searchsorted]
@eval begin
$s(v::AbstractVector, x, o::Ordering) = (inds = axes(v, 1); $s(v,x,first(inds),last(inds),o))
$s(v::AbstractVector, x;
lt=isless, by=identity, rev::Union{Bool,Nothing}=nothing, order::Ordering=Forward) =
$s(v,x,ord(lt,by,rev,order))
lt=isless, by=identity, rev::Union{Bool,Nothing}=nothing, order::Ordering=Forward, apply_by_to_key=true) =
$s(v,x,ord(lt,by,rev,order,apply_by_to_key))
end
end

"""
searchsorted(a, x; by=<transform>, lt=<comparison>, rev=false)
searchsorted(a, x; by=<transform>, lt=<comparison>, rev=false,
apply_by_to_key=true)

Return the range of indices of `a` which compare as equal to `x` (using binary search)
according to the order specified by the `by`, `lt` and `rev` keywords, assuming that `a`
is already sorted in that order. Return an empty range located at the insertion point
if `a` does not contain values equal to `x`.
if `a` does not contain values equal to `x`. If the `apply_by_to_key`
keyword is set to `true`, then the `by` function is also applied
to the key `x` (the default, for legacy reasons). If the `apply_by_to_key`
keyword is `false`, then the `by` function is applied to the elements of `a`
but not to the key `x` (the more typical usage).

See also: [`insorted`](@ref), [`searchsortedfirst`](@ref), [`sort`](@ref), [`findall`](@ref).

Expand All @@ -329,7 +334,8 @@ julia> searchsorted([1, 2, 4, 5, 5, 7], 0) # no match, insert at start
""" searchsorted

"""
searchsortedfirst(a, x; by=<transform>, lt=<comparison>, rev=false)
searchsortedfirst(a, x; by=<transform>, lt=<comparison>, rev=false,
apply_by_to_key=true)

Return the index of the first value in `a` greater than or equal to `x`, according to the
specified order. Return `lastindex(a) + 1` if `x` is greater than all values in `a`.
Expand Down Expand Up @@ -357,7 +363,8 @@ julia> searchsortedfirst([1, 2, 4, 5, 5, 7], 0) # no match, insert at start
""" searchsortedfirst

"""
searchsortedlast(a, x; by=<transform>, lt=<comparison>, rev=false)
searchsortedlast(a, x; by=<transform>, lt=<comparison>, rev=false,
apply_by_to_key=true)

Return the index of the last value in `a` less than or equal to `x`, according to the
specified order. Return `firstindex(a) - 1` if `x` is less than all values in `a`. `a` is
Expand All @@ -383,7 +390,8 @@ julia> searchsortedlast([1, 2, 4, 5, 5, 7], 0) # no match, insert at start
""" searchsortedlast

"""
insorted(a, x; by=<transform>, lt=<comparison>, rev=false) -> Bool
insorted(a, x; by=<transform>, lt=<comparison>, rev=false,
apply_by_to_key=true) -> Bool

Determine whether an item is in the given sorted collection, in the sense that
it is [`==`](@ref) to one of the values of the collection according to the order
Expand Down
6 changes: 6 additions & 0 deletions test/sorting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,11 @@ end
Float16, Float32, Float64, BigInt, BigFloat]

@test searchsorted([1:10;], 1, by=(x -> x >= 5)) == 1:4
@test searchsorted([1:10;], false, by=(x -> x >= 5),
apply_by_to_key=false) == 1:4
@test searchsorted([1:10;], 10, by=(x -> x >= 5)) == 5:10
@test searchsorted([1:10;], true, by=(x -> x >= 5),
apply_by_to_key=false) == 5:10
@test searchsorted([1:5; 1:5; 1:5], 1, 6, 10, Forward) == 6:6
@test searchsorted(fill(1, 15), 1, 6, 10, Forward) == 6:10

Expand Down Expand Up @@ -346,6 +350,7 @@ end

@test insorted(1, collect(1:10), by=(>=(5)))
@test insorted(10, collect(1:10), by=(>=(5)))
@test insorted(true, collect(1:10), by=(>=(5)), apply_by_to_key=false)

for R in numTypes, T in numTypes
@test !insorted(T(0), R[1, 1, 2, 2, 3, 3])
Expand Down Expand Up @@ -393,6 +398,7 @@ end
@test !insorted(0, [1,2,3])
@test !insorted(4, [1,2,3])
@test insorted(3, [10,8,6,9,4,7,2,5,3,1], by=(x -> iseven(x) ? x+5 : x), rev=true)
@test insorted(4, [10,8,6,9,4,7,2,5,3,1], by=(x -> iseven(x) ? x+6 : x+1), rev=true, apply_by_to_key=false)
end
@testset "PartialQuickSort" begin
a = rand(1:10000, 1000)
Expand Down