Skip to content

Commit c40a6fb

Browse files
Allow dropdims with reduction to take mutliple args and kwargs
1 parent 012d910 commit c40a6fb

File tree

3 files changed

+54
-16
lines changed

3 files changed

+54
-16
lines changed

NEWS.md

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,12 @@ New library functions
2626
* The `tempname` function now takes a `cleanup::Bool` keyword argument defaulting to `true`, which causes the process to try to ensure that any file or directory at the path returned by `tempname` is deleted upon process exit ([#33090]).
2727
* The `readdir` function now takes a `join::Bool` keyword argument defaulting to `false`, which when set causes `readdir` to join its directory argument with each listed name ([#33113]).
2828

29-
3029
Standard library changes
3130
------------------------
3231

3332
* The methods of `mktemp` and `mktempdir` which take a function body to pass temporary paths to no longer throw errors if the path is already deleted when the function body returns ([#33091]).
3433

35-
* A new `squeeze(f, A, dims)` method computes the reduction `f` over the region in
36-
`A` described by `dims` and then drops those dimensions from the result ([#23500]).
34+
* A new `dropdims(f, args...; dims, kwargs...)` method computes the reduction `f` over the region described by `dims` and then drops those dimensions from the result ([#23500]).
3735

3836
#### Libdl
3937

base/abstractarraymath.jl

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,11 +88,31 @@ end
8888
_dropdims(A::AbstractArray, dim::Integer) = _dropdims(A, (Int(dim),))
8989

9090
"""
91-
squeeze(f, A, dims)
91+
dropdims(f, args...; dims, kwargs...)
9292
93-
Compute reduction `f` over dimensions `dims` in array `A` and drop those dimensions from the result.
93+
Compute reduction `f` over dimensions `dims` and drop those dimensions from the result.
94+
95+
# Examples
96+
```jldoctest
97+
julia> a = [3.0 2.0 6.0 8.0
98+
6.0 1.0 4.0 2.0
99+
3.0 0.0 7.0 6.0];
100+
101+
julia> dropdims(sum, a, dims=1)
102+
4-element Array{Float64,1}:
103+
12.0
104+
3.0
105+
17.0
106+
16.0
107+
108+
julia> dropdims(sum, abs2, a, dims=2)
109+
3-element Array{Float64,1}:
110+
113.0
111+
57.0
112+
94.0
113+
```
94114
"""
95-
squeeze(f, A::AbstractArray, dims::Union{Dims, Integer}) = squeeze(f(A, dims), dims)
115+
dropdims(f, args...; dims, kwargs...) = _dropdims(f(args...; kwargs..., dims=dims), dims)
96116

97117
## Unary operators ##
98118

test/arrayops.jl

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -303,16 +303,36 @@ end
303303
@test_throws ArgumentError dropdims(a, dims=4)
304304
@test_throws ArgumentError dropdims(a, dims=6)
305305

306-
@test @inferred(squeeze(sum, a, 1)) == @inferred(squeeze(sum, a, (1,))) == reshape(sum(a, 1), (1, 8, 8, 1))
307-
@test @inferred(squeeze(sum, a, 3)) == @inferred(squeeze(sum, a, (3,))) == reshape(sum(a, 3), (1, 1, 8, 1))
308-
@test @inferred(squeeze(sum, a, 4)) == @inferred(squeeze(sum, a, (4,))) == reshape(sum(a, 4), (1, 1, 8, 1))
309-
@test @inferred(squeeze(sum, a, (1, 5))) == squeeze(sum, a, (5, 1)) == reshape(sum(a, (5, 1)), (1, 8, 8))
310-
@test @inferred(squeeze(sum, a, (1, 2, 5))) == squeeze(sum, a, (5, 2, 1)) == reshape(sum(a, (5, 2, 1)), (8, 8))
311-
@test_throws ArgumentError squeeze(sum, a, 0)
312-
@test_throws ArgumentError squeeze(sum, a, (1, 1))
313-
@test_throws ArgumentError squeeze(sum, a, (1, 2, 1))
314-
@test_throws ArgumentError squeeze(sum, a, (1, 1, 2))
315-
@test_throws ArgumentError squeeze(sum, a, 6)
306+
# dropdims with reductions. issue #16606
307+
@test (@inferred(dropdims(sum, a, dims=1)) ==
308+
@inferred(dropdims(sum, a, dims=(1,))) ==
309+
reshape(sum(a, dims=1), (1, 8, 8, 1)))
310+
@test (@inferred(dropdims(sum, a, dims=3)) ==
311+
@inferred(dropdims(sum, a, dims=(3,))) ==
312+
reshape(sum(a, dims=3), (1, 1, 8, 1)))
313+
@test (@inferred(dropdims(sum, a, dims=4)) ==
314+
@inferred(dropdims(sum, a, dims=(4,))) ==
315+
reshape(sum(a, dims=4), (1, 1, 8, 1)))
316+
@test (@inferred(dropdims(sum, a, dims=(1, 5))) ==
317+
dropdims(sum, a, dims=(5, 1)) ==
318+
reshape(sum(a, dims=(5, 1)), (1, 8, 8)))
319+
@test (@inferred(dropdims(sum, a, dims=(1, 2, 5))) ==
320+
dropdims(sum, a, dims=(5, 2, 1)) ==
321+
reshape(sum(a, dims=(5, 2, 1)), (8, 8)))
322+
@test (@inferred(dropdims(sum, abs2, a, dims=1)) ==
323+
@inferred(dropdims(sum, abs2, a, dims=(1,))) ==
324+
reshape(sum(abs2, a, dims=1), (1, 8, 8, 1)))
325+
_sumplus(x; dims, plus) = sum(x; dims=dims) .+ plus # reduction with keywords
326+
@test (@inferred(dropdims(_sumplus, a, dims=4, plus=1)) ==
327+
@inferred(dropdims(_sumplus, a, dims=(4,), plus=1)) ==
328+
reshape(sum(a, dims=4) .+ 1, (1, 1, 8, 1)))
329+
@test_throws UndefKeywordError dropdims(sum, a)
330+
@test_throws UndefKeywordError dropdims(sum, a, 1)
331+
@test_throws ArgumentError dropdims(sum, a, dims=0)
332+
@test_throws ArgumentError dropdims(sum, a, dims=(1, 1))
333+
@test_throws ArgumentError dropdims(sum, a, dims=(1, 2, 1))
334+
@test_throws ArgumentError dropdims(sum, a, dims=(1, 1, 2))
335+
@test_throws ArgumentError dropdims(sum, a, dims=6)
316336

317337
sz = (5,8,7)
318338
A = reshape(1:prod(sz),sz...)

0 commit comments

Comments
 (0)