Skip to content

Commit 423dc4b

Browse files
committed
Add 2-arg versions of findmax/min, argmax/min
Fixes JuliaLang#27613. Related: JuliaLang#27639, JuliaLang#27612, JuliaLang#34674. Thanks to @tkf, @StefanKarpinski and @drewrobson for their assistance with this PR.
1 parent 7bbb84f commit 423dc4b

File tree

4 files changed

+229
-134
lines changed

4 files changed

+229
-134
lines changed

NEWS.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ New library functions
8787
* New function `bitrotate(x, k)` for rotating the bits in a fixed-width integer ([#33937]).
8888
* One argument methods `startswith(x)` and `endswith(x)` have been added, returning partially-applied versions of the functions, similar to existing methods like `isequal(x)` ([#33193]).
8989
* New function `isgreater(a, b)` defines a descending total order where unorderable values and missing are ordered smaller than any regular value.
90+
* Two argument methods `findmax(f, domain)`, `argmax(f, domain)` and the corresponding `min` versions ([#27613]).
9091

9192
New library features
9293
--------------------

base/array.jl

Lines changed: 0 additions & 134 deletions
Original file line numberDiff line numberDiff line change
@@ -2096,140 +2096,6 @@ findall(x::Bool) = x ? [1] : Vector{Int}()
20962096
findall(testf::Function, x::Number) = testf(x) ? [1] : Vector{Int}()
20972097
findall(p::Fix2{typeof(in)}, x::Number) = x in p.x ? [1] : Vector{Int}()
20982098

2099-
"""
2100-
findmax(itr) -> (x, index)
2101-
2102-
Return the maximum element of the collection `itr` and its index. If there are multiple
2103-
maximal elements, then the first one will be returned.
2104-
If any data element is `NaN`, this element is returned.
2105-
The result is in line with `max`.
2106-
2107-
The collection must not be empty.
2108-
2109-
# Examples
2110-
```jldoctest
2111-
julia> findmax([8,0.1,-9,pi])
2112-
(8.0, 1)
2113-
2114-
julia> findmax([1,7,7,6])
2115-
(7, 2)
2116-
2117-
julia> findmax([1,7,7,NaN])
2118-
(NaN, 4)
2119-
```
2120-
"""
2121-
findmax(a) = _findmax(a, :)
2122-
2123-
function _findmax(a, ::Colon)
2124-
p = pairs(a)
2125-
y = iterate(p)
2126-
if y === nothing
2127-
throw(ArgumentError("collection must be non-empty"))
2128-
end
2129-
(mi, m), s = y
2130-
i = mi
2131-
while true
2132-
y = iterate(p, s)
2133-
y === nothing && break
2134-
m != m && break
2135-
(i, ai), s = y
2136-
if ai != ai || isless(m, ai)
2137-
m = ai
2138-
mi = i
2139-
end
2140-
end
2141-
return (m, mi)
2142-
end
2143-
2144-
"""
2145-
findmin(itr) -> (x, index)
2146-
2147-
Return the minimum element of the collection `itr` and its index. If there are multiple
2148-
minimal elements, then the first one will be returned.
2149-
If any data element is `NaN`, this element is returned.
2150-
The result is in line with `min`.
2151-
2152-
The collection must not be empty.
2153-
2154-
# Examples
2155-
```jldoctest
2156-
julia> findmin([8,0.1,-9,pi])
2157-
(-9.0, 3)
2158-
2159-
julia> findmin([7,1,1,6])
2160-
(1, 2)
2161-
2162-
julia> findmin([7,1,1,NaN])
2163-
(NaN, 4)
2164-
```
2165-
"""
2166-
findmin(a) = _findmin(a, :)
2167-
2168-
function _findmin(a, ::Colon)
2169-
p = pairs(a)
2170-
y = iterate(p)
2171-
if y === nothing
2172-
throw(ArgumentError("collection must be non-empty"))
2173-
end
2174-
(mi, m), s = y
2175-
i = mi
2176-
while true
2177-
y = iterate(p, s)
2178-
y === nothing && break
2179-
m != m && break
2180-
(i, ai), s = y
2181-
if ai != ai || isless(ai, m)
2182-
m = ai
2183-
mi = i
2184-
end
2185-
end
2186-
return (m, mi)
2187-
end
2188-
2189-
"""
2190-
argmax(itr) -> Integer
2191-
2192-
Return the index of the maximum element in a collection. If there are multiple maximal
2193-
elements, then the first one will be returned.
2194-
2195-
The collection must not be empty.
2196-
2197-
# Examples
2198-
```jldoctest
2199-
julia> argmax([8,0.1,-9,pi])
2200-
1
2201-
2202-
julia> argmax([1,7,7,6])
2203-
2
2204-
2205-
julia> argmax([1,7,7,NaN])
2206-
4
2207-
```
2208-
"""
2209-
argmax(a) = findmax(a)[2]
2210-
2211-
"""
2212-
argmin(itr) -> Integer
2213-
2214-
Return the index of the minimum element in a collection. If there are multiple minimal
2215-
elements, then the first one will be returned.
2216-
2217-
The collection must not be empty.
2218-
2219-
# Examples
2220-
```jldoctest
2221-
julia> argmin([8,0.1,-9,pi])
2222-
3
2223-
2224-
julia> argmin([7,1,1,6])
2225-
2
2226-
2227-
julia> argmin([7,1,1,NaN])
2228-
4
2229-
```
2230-
"""
2231-
argmin(a) = findmin(a)[2]
2232-
22332099
# similar to Matlab's ismember
22342100
"""
22352101
indexin(a, b)

base/reduce.jl

Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -659,6 +659,202 @@ julia> minimum([1,2,3])
659659
"""
660660
minimum(a) = mapreduce(identity, min, a)
661661

662+
## findmax, findmin, argmax & argmin
663+
664+
"""
665+
findmax(f, domain) -> (f(x), x)
666+
findmax(f)
667+
668+
Returns a pair of a value in the codomain (outputs of `f`) and the corresponding
669+
value in the `domain` (inputs to `f`) such that `f(x)` is maximised. If there
670+
are multiple maximal points, then the first one will be returned.
671+
672+
When `domain` is provided it may be any iterable and must not be empty.
673+
674+
When `domain` is omitted, `f` must have an implicit domain. In particular, if
675+
`f` is an indexable collection, it is interpreted as a function mapping keys
676+
(domain) to values (codomain), i.e. `findmax(itr)` returns the maximal element
677+
of the collection `itr` and its index.
678+
679+
Values are compared with `isless`.
680+
681+
# Examples
682+
683+
```jldoctest
684+
julia> findmax(identity, 5:9)
685+
(9, 9)
686+
687+
julia> findmax(-, 1:10)
688+
(-1, 1)
689+
690+
julia> findmax(cos, 0:π/2:2π)
691+
(1.0, 0.0)
692+
693+
julia> findmax([8,0.1,-9,pi])
694+
(8.0, 1)
695+
696+
julia> findmax([1,7,7,6])
697+
(7, 2)
698+
699+
julia> findmax([1,7,7,NaN])
700+
(NaN, 4)
701+
```
702+
703+
"""
704+
findmax(f, domain) = mapfoldl(x -> (f(x), x), _rf_findmax, domain)
705+
_rf_findmax((fm, m), (fx, x)) = isless(fm, fx) ? (fx, x) : (fm, m)
706+
707+
"""
708+
findmin(f, domain) -> (f(x), x)
709+
findmin(f)
710+
711+
Returns a pair of a value in the codomain (outputs of `f`) and the corresponding
712+
value in the `domain` (inputs to `f`) such that `f(x)` is minimised. If there
713+
are multiple minimal points, then the first one will be returned.
714+
715+
When `domain` is provided it may be any iterable and must not be empty.
716+
717+
When `domain` is omitted, `f` must have an implicit domain. In particular, if
718+
`f` is an indexable collection, it is interpreted as a function mapping keys
719+
(domain) to values (codomain), i.e. `findmin(itr)` returns the minimal element
720+
of the collection `itr` and its index.
721+
722+
Values are compared with `isgreater`.
723+
724+
# Examples
725+
726+
```jldoctest
727+
julia> findmin(identity, 5:9)
728+
(5, 5)
729+
730+
julia> findmin(-, 1:10)
731+
(-10, 10)
732+
733+
julia> findmin(cos, 0:π/2:2π)
734+
(-1.0, 3.141592653589793)
735+
736+
julia> findmin([8,0.1,-9,pi])
737+
(-9, 3)
738+
739+
julia> findmin([1,7,7,6])
740+
(1, 1)
741+
742+
julia> findmin([1,7,7,NaN])
743+
(NaN, 4)
744+
```
745+
746+
"""
747+
findmin(f, domain) = mapfoldl(x -> (f(x), x), _rf_findmin, domain)
748+
_rf_findmin((fm, m), (fx, x)) = isgreater(fm, fx) ? (fx, x) : (fm, m)
749+
750+
findmax(a) = _findmax(a, :)
751+
752+
function _findmax(a, ::Colon)
753+
p = pairs(a)
754+
y = iterate(p)
755+
if y === nothing
756+
throw(ArgumentError("collection must be non-empty"))
757+
end
758+
(mi, m), s = y
759+
i = mi
760+
while true
761+
y = iterate(p, s)
762+
y === nothing && break
763+
m != m && break
764+
(i, ai), s = y
765+
if ai != ai || isless(m, ai)
766+
m = ai
767+
mi = i
768+
end
769+
end
770+
return (m, mi)
771+
end
772+
773+
findmin(a) = _findmin(a, :)
774+
775+
function _findmin(a, ::Colon)
776+
p = pairs(a)
777+
y = iterate(p)
778+
if y === nothing
779+
throw(ArgumentError("collection must be non-empty"))
780+
end
781+
(mi, m), s = y
782+
i = mi
783+
while true
784+
y = iterate(p, s)
785+
y === nothing && break
786+
m != m && break
787+
(i, ai), s = y
788+
if ai != ai || isless(ai, m)
789+
m = ai
790+
mi = i
791+
end
792+
end
793+
return (m, mi)
794+
end
795+
796+
"""
797+
argmax(f, domain)
798+
argmax(f)
799+
800+
Return a value `x` in the domain of `f` for which `f(x)` is maximised.
801+
If there are multiple maximal values for `f(x)` then the first one will be found.
802+
803+
When `domain` is provided it may be any iterable and must not be empty.
804+
805+
When `domain` is omitted, `f` must have an implicit domain. In particular, if
806+
`f` is an indexable collection, it is interpreted as a function mapping keys
807+
(domain) to values (codomain), i.e. `argmax(itr)` returns the index of the
808+
maximal element in `itr`.
809+
810+
Values are compared with `isless`.
811+
812+
# Examples
813+
```jldoctest
814+
julia> argmax([8,0.1,-9,pi])
815+
1
816+
817+
julia> argmax([1,7,7,6])
818+
2
819+
820+
julia> argmax([1,7,7,NaN])
821+
4
822+
```
823+
"""
824+
argmax(f, domain) = findmax(f, domain)[2]
825+
argmax(f) = findmax(f)[2]
826+
827+
"""
828+
argmin(f, domain)
829+
argmin(f)
830+
831+
Return a value `x` in the domain of `f` for which `f(x)` is minimised.
832+
If there are multiple minimal values for `f(x)` then the first one will be found.
833+
834+
When `domain` is provided it may be any iterable and must not be empty.
835+
836+
When `domain` is omitted, `f` must have an implicit domain. In particular, if
837+
`f` is an indexable collection, it is interpreted as a function mapping keys
838+
(domain) to values (codomain), i.e. `argmin(itr)` returns the index of the
839+
minimal element in `itr`.
840+
841+
Values are compared with `isgreater`.
842+
843+
# Examples
844+
```jldoctest
845+
julia> argmin([8,0.1,-9,pi])
846+
3
847+
848+
julia> argmin([7,1,1,6])
849+
2
850+
851+
julia> argmin([7,1,1,NaN])
852+
4
853+
```
854+
"""
855+
argmin(f, domain) = findmin(f, domain)[2]
856+
argmin(f) = findmin(f)[2]
857+
662858
## all & any
663859

664860
"""

test/reduce.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,38 @@ A = circshift(reshape(1:24,2,3,4), (0,1,1))
338338
@test size(extrema(A,dims=(1,2,3))) == size(maximum(A,dims=(1,2,3)))
339339
@test extrema(x->div(x, 2), A, dims=(2,3)) == reshape([(0,11),(1,12)],2,1,1)
340340

341+
# findmin, findmax, argmin, argmax
342+
343+
@testset "findmin(f, domain)" begin
344+
@test findmin(-, 1:10) == (-10, 10)
345+
@test findmin(identity, [1, 2, 3, missing]) === (missing, missing)
346+
@test findmin(identity, [1, NaN, 3, missing]) === (missing, missing)
347+
@test findmin(identity, [1, missing, NaN, 3]) === (missing, missing)
348+
@test findmin(identity, [1, NaN, 3]) === (NaN, NaN)
349+
@test findmin(identity, [1, 3, NaN]) === (NaN, NaN)
350+
@test all(findmin(cos, 0:π/2:2π) .≈ (-1.0, π))
351+
end
352+
353+
@testset "findmax(f, domain)" begin
354+
@test findmax(-, 1:10) == (-1, 1)
355+
@test findmax(identity, [1, 2, 3, missing]) === (missing, missing)
356+
@test findmax(identity, [1, NaN, 3, missing]) === (missing, missing)
357+
@test findmax(identity, [1, missing, NaN, 3]) === (missing, missing)
358+
@test findmax(identity, [1, NaN, 3]) === (NaN, NaN)
359+
@test findmax(identity, [1, 3, NaN]) === (NaN, NaN)
360+
@test findmax(cos, 0:π/2:2π) == (1.0, 0.0)
361+
end
362+
363+
@testset "argmin(f, domain)" begin
364+
@test argmin(-, 1:10) == 10
365+
@test argmin(sum, Iterators.product(1:5, 1:5)) == (1, 1)
366+
end
367+
368+
@testset "argmax(f, domain)" begin
369+
@test argmax(-, 1:10) == 1
370+
@test argmax(sum, Iterators.product(1:5, 1:5)) == (5, 5)
371+
end
372+
341373
# any & all
342374

343375
@test @inferred any([]) == false

0 commit comments

Comments
 (0)