Skip to content

Commit b7348e7

Browse files
authored
Merge pull request #353 from JuliaGPU/tb/statistics
Port statistics functions from CUDA.jl.
2 parents 5f5db15 + e483d11 commit b7348e7

File tree

6 files changed

+106
-0
lines changed

6 files changed

+106
-0
lines changed

Manifest.toml

+8
Original file line numberDiff line numberDiff line change
@@ -30,5 +30,13 @@ uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
3030
[[Serialization]]
3131
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
3232

33+
[[SparseArrays]]
34+
deps = ["LinearAlgebra", "Random"]
35+
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
36+
37+
[[Statistics]]
38+
deps = ["LinearAlgebra", "SparseArrays"]
39+
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
40+
3341
[[Unicode]]
3442
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"

Project.toml

+1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
99
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
1010
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1111
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
12+
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1213

1314
[compat]
1415
AbstractFFTs = "0.4, 0.5, 1.0"

src/GPUArrays.jl

+1
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ include("host/math.jl")
3333
include("host/random.jl")
3434
include("host/quirks.jl")
3535
include("host/uniformscaling.jl")
36+
include("host/statistics.jl")
3637

3738
include("deprecated.jl")
3839

src/host/statistics.jl

+43
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
using Statistics
2+
3+
Statistics.varm(A::AbstractGPUArray{<:Real},m::AbstractArray{<:Real}; dims, corrected::Bool=true) =
4+
sum((A .- m).^2, dims=dims)/(prod(size(A)[[dims...]])::Int-corrected)
5+
6+
Statistics.stdm(A::AbstractGPUArray{<:Real},m::AbstractArray{<:Real}, dim::Int; corrected::Bool=true) =
7+
sqrt.(varm(A,m;dims=dim,corrected=corrected))
8+
9+
Statistics._std(A::AbstractGPUArray, corrected::Bool, mean, dims) =
10+
sqrt.(Statistics.var(A; corrected=corrected, mean=mean, dims=dims))
11+
12+
Statistics._std(A::AbstractGPUArray, corrected::Bool, mean, ::Colon) =
13+
sqrt.(Statistics.var(A; corrected=corrected, mean=mean, dims=:))
14+
15+
# Revert https://github.com/JuliaLang/Statistics.jl/pull/25
16+
Statistics._mean(A::AbstractGPUArray, ::Colon) = sum(A) / length(A)
17+
Statistics._mean(f, A::AbstractGPUArray, ::Colon) = sum(f, A) / length(A)
18+
Statistics._mean(A::AbstractGPUArray, dims) = mean!(Base.reducedim_init(t -> t/2, +, A, dims), A)
19+
Statistics._mean(f, A::AbstractGPUArray, dims) = sum(f, A, dims=dims) / mapreduce(i -> size(A, i), *, unique(dims); init=1)
20+
21+
function Statistics.covzm(x::AbstractGPUMatrix, vardim::Int=1; corrected::Bool=true)
22+
C = Statistics.unscaled_covzm(x, vardim)
23+
T = promote_type(typeof(one(eltype(C)) / 1), eltype(C))
24+
A = convert(AbstractArray{T}, C)
25+
b = 1//(size(x, vardim) - corrected)
26+
A .*= b
27+
return A
28+
end
29+
30+
function Statistics.cov2cor!(C::AbstractGPUMatrix{T}, xsd::AbstractGPUArray) where T
31+
nx = length(xsd)
32+
size(C) == (nx, nx) || throw(DimensionMismatch("inconsistent dimensions"))
33+
tril!(C, -1)
34+
C += adjoint(C)
35+
C = Statistics.clampcor.(C ./ (xsd * xsd'))
36+
C[diagind(C)] .= oneunit(T)
37+
return C
38+
end
39+
40+
function Statistics.corzm(x::AbstractGPUMatrix, vardim::Int=1)
41+
c = Statistics.unscaled_covzm(x, vardim)
42+
return Statistics.cov2cor!(c, sqrt.(diag(c)))
43+
end

test/testsuite.jl

+1
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ include("testsuite/linalg.jl")
7373
include("testsuite/math.jl")
7474
include("testsuite/random.jl")
7575
include("testsuite/uniformscaling.jl")
76+
include("testsuite/statistics.jl")
7677

7778
"""
7879
Runs the entire GPUArrays test suite on array type `AT`

test/testsuite/statistics.jl

+52
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
using Statistics
2+
3+
@testsuite "statistics" AT->begin
4+
@testset "std" begin
5+
@test compare(std, AT, rand(10))
6+
@test compare(std, AT, rand(10,1,2))
7+
@test compare(std, AT, rand(10,1,2); corrected=true)
8+
@test compare(std, AT, rand(10,1,2); dims=1)
9+
end
10+
11+
@testset "var" begin
12+
@test compare(var, AT, rand(10))
13+
@test compare(var, AT, rand(10,1,2))
14+
@test compare(var, AT, rand(10,1,2); corrected=true)
15+
@test compare(var, AT, rand(10,1,2); dims=1)
16+
@test compare(var, AT, rand(10,1,2); dims=[1])
17+
@test compare(var, AT, rand(10,1,2); dims=(1,))
18+
@test compare(var, AT, rand(10,1,2); dims=[2,3])
19+
@test compare(var, AT, rand(10,1,2); dims=(2,3))
20+
end
21+
22+
@testset "mean" begin
23+
@test compare(mean, AT, rand(2,2))
24+
@test compare(mean, AT, rand(2,2); dims=2)
25+
@test compare(mean, AT, rand(2,2,2); dims=[1,3])
26+
@test compare(x->mean(sin, x), AT, rand(2,2))
27+
@test compare(x->mean(sin, x; dims=2), AT, rand(2,2))
28+
@test compare(x->mean(sin, x; dims=[1,3]), AT, rand(2,2,2))
29+
end
30+
31+
@testset "cov" begin
32+
s = 100
33+
@test compare(cov, AT, rand(s))
34+
@test compare(cov, AT, rand(Complex{Float64}, s))
35+
@test compare(cov, AT, rand(s, 2))
36+
@test compare(cov, AT, rand(Complex{Float64}, s, 2))
37+
@test compare(cov, AT, rand(s, 2); dims=2)
38+
@test compare(cov, AT, rand(Complex{Float64}, s, 2); dims=2)
39+
@test compare(cov, AT, rand(1:100, s))
40+
end
41+
42+
@testset "cor" begin
43+
s = 100
44+
@test compare(cor, AT, rand(s))
45+
@test compare(cor, AT, rand(Complex{Float64}, s))
46+
@test compare(cor, AT, rand(s, 2))
47+
@test compare(cor, AT, rand(Complex{Float64}, s, 2))
48+
@test compare(cor, AT, rand(s, 2); dims=2)
49+
@test compare(cor, AT, rand(Complex{Float64}, s, 2); dims=2)
50+
@test compare(cor, AT, rand(1:100, s))
51+
end
52+
end

0 commit comments

Comments
 (0)