Skip to content

Commit 460dfae

Browse files
committed
BNNS Random extension
1 parent 00235a2 commit 460dfae

File tree

6 files changed

+351
-4
lines changed

6 files changed

+351
-4
lines changed

Project.toml

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,24 @@ uuid = "13e28ba4-7ad8-5781-acae-3021b1ed3924"
33
version = "0.4.1"
44

55
[deps]
6+
BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
7+
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
68
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
79
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
810

11+
[weakdeps]
12+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
13+
14+
[extensions]
15+
RandomExt = "Random"
16+
917
[compat]
18+
BFloat16s = "0.5.0"
19+
CEnum = "0.5.0"
1020
julia = "1.9"
1121

1222
[extras]
23+
DSP = "717857b8-e6f2-59f4-9121-6e50c889abd2"
1324
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
1425
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
1526
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
@@ -18,7 +29,6 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1829
Sockets = "6462fe0b-24de-5631-8697-dd941f90decc"
1930
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2031
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
21-
DSP = "717857b8-e6f2-59f4-9121-6e50c889abd2"
2232
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
2333
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2434

ext/RandomExt.jl

Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
module RandomExt
2+
3+
@static if Sys.isapple()
4+
5+
using BFloat16s
6+
using AppleAccelerate: BNNS
7+
using .BNNS: BNNSFilterParameters,
8+
BNNSRandomGeneratorMethodAES_CTR,
9+
BNNSCreateRandomGenerator,
10+
BNNSCreateRandomGeneratorWithSeed,
11+
BNNSRandomGeneratorStateSize,
12+
BNNSRandomGeneratorSetState,
13+
BNNSRandomGeneratorGetState,
14+
BNNSNDArrayDescriptor,
15+
BNNSRandomFillNormalFloat,
16+
BNNSRandomFillUniformFloat,
17+
BNNSRandomFillUniformInt
18+
using Random: Random, AbstractRNG
19+
20+
"""
21+
RNG()
22+
23+
A random number generator using AppleAccelerate's BNNS functionality.
24+
"""
25+
mutable struct RNG <: AbstractRNG
26+
ptr::Ptr{Nothing}
27+
function RNG(filter_parameters::Union{Nothing, BNNSFilterParameters}=nothing)
28+
params = isnothing(filter_parameters) ? Ptr{BNNSFilterParameters}(0) : [filter_parameters]
29+
res = new(BNNSCreateRandomGenerator(BNNSRandomGeneratorMethodAES_CTR, params))
30+
# finalizer(res) do
31+
# BNNSDestroyRandomGenerator(res.ptr)
32+
# end
33+
return res
34+
end
35+
function RNG(seed::Integer, filter_parameters::Union{Nothing, BNNSFilterParameters}=nothing)
36+
seed = seed%UInt64
37+
params = isnothing(filter_parameters) ? Ptr{BNNSFilterParameters}(0) : [filter_parameters]
38+
res = new(BNNSCreateRandomGeneratorWithSeed(BNNSRandomGeneratorMethodAES_CTR, seed, params))
39+
# finalizer(res) do
40+
# BNNSDestroyRandomGenerator(res.ptr)
41+
# end
42+
return res
43+
end
44+
end
45+
46+
BNNS.bnns_rng() = RNG()
47+
BNNS.bnns_rng(seed::Integer) = RNG(seed)
48+
49+
@static if isdefined(Base, :Memory) #VERSION >= v"1.11"
50+
function _get_rng_state(rng::RNG)
51+
stateSize = BNNSRandomGeneratorStateSize(rng.ptr)
52+
state = Memory{UInt8}(undef, Int64(stateSize))
53+
BNNSRandomGeneratorGetState(rng.ptr, stateSize, state)
54+
return state
55+
end
56+
else
57+
function _get_rng_state(rng::RNG)
58+
stateSize = BNNSRandomGeneratorStateSize(rng.ptr)
59+
state = Vector{UInt8}(undef, Int64(stateSize))
60+
BNNSRandomGeneratorGetState(rng.ptr, stateSize, state)
61+
return state
62+
end
63+
end
64+
65+
function Base.copy!(dest::RNG, src::RNG)
66+
state = _get_rng_state(src)
67+
BNNSRandomGeneratorSetState(dest.ptr, length(state), state)
68+
return dest
69+
end
70+
71+
function Base.copy(rng::RNG)
72+
newrng = RNG()
73+
return copy!(newrng, rng)
74+
end
75+
76+
Base.:(==)(rng1::RNG, rng2::RNG) = _get_rng_state(rng1) == _get_rng_state(rng2)
77+
78+
function Random.seed!(rng::RNG, seed::Integer)
79+
return copy!(rng, RNG(seed))
80+
end
81+
82+
function Random.seed!(rng::RNG)
83+
return copy!(rng, RNG())
84+
end
85+
86+
const GLOBAL_RNG = Ref{RNG}()
87+
function BNNS.default_rng()
88+
if !isassigned(GLOBAL_RNG)
89+
GLOBAL_RNG[] = BNNS.bnns_rng()
90+
end
91+
return GLOBAL_RNG[]
92+
end
93+
94+
const BNNSInt = Union{UInt8,Int8,UInt16,Int16,UInt32,Int32,UInt64,Int64}
95+
const BNNSFloat = Union{Float16, Float32, BFloat16}
96+
97+
const BNNSUniform = Union{<:BNNSInt,<:BNNSFloat}
98+
const BNNSNormal = BNNSFloat
99+
100+
function Random.rand!(rng::RNG, A::DenseArray{T}) where {T<:BNNSInt}
101+
isempty(A) && return A
102+
desc = Ref(BNNSNDArrayDescriptor(A))
103+
res = BNNSRandomFillUniformInt(rng.ptr, desc, typemin(signed(T)), typemax(signed(T)))
104+
@assert res == 0
105+
return A
106+
end
107+
function Random.rand!(rng::RNG, A::DenseArray{T}) where {T<:BNNSFloat}
108+
isempty(A) && return A
109+
desc = Ref(BNNSNDArrayDescriptor(A))
110+
res = BNNSRandomFillUniformFloat(rng.ptr, desc, T(0), T(1))
111+
@assert res == 0
112+
return A
113+
end
114+
function Random.randn!(rng::RNG, A::DenseArray{T}) where {T<:BNNSFloat}
115+
isempty(A) && return A
116+
desc = Ref(BNNSNDArrayDescriptor(A))
117+
res = BNNSRandomFillNormalFloat(rng.ptr, desc, Float32(0), Float32(1))
118+
@assert res == 0
119+
return A
120+
end
121+
122+
# Out of place
123+
Random.rand(rng::RNG, ::Type{T}, dims::Dims) where T <: BNNSUniform =
124+
Random.rand!(rng, Array{T,length(dims)}(undef, dims...))
125+
Random.randn(rng::RNG, ::Type{T}, dims::Dims) where T <: BNNSNormal =
126+
Random.randn!(rng, Array{T,length(dims)}(undef, dims...))
127+
128+
# support all dimension specifications
129+
Random.rand(rng::RNG, ::Type{T}, dim1::Integer, dims::Integer...) where T <: BNNSUniform =
130+
Random.rand!(rng, Array{T,length(dims) + 1}(undef, dim1, dims...))
131+
Random.randn(rng::RNG, ::Type{T}, dim1::Integer, dims::Integer...) where T <: BNNSNormal =
132+
Random.randn!(rng, Array{T,length(dims) + 1}(undef, dim1, dims...))
133+
134+
# untyped out-of-place
135+
Random.rand(rng::RNG, dim1::Integer, dims::Integer...) =
136+
Random.rand!(rng, Array{Float32,length(dims) + 1}(undef, dim1, dims...))
137+
Random.randn(rng::RNG, dim1::Integer, dims::Integer...) =
138+
Random.randn!(rng, Array{Float32,length(dims) + 1}(undef, dim1, dims...))
139+
140+
# scalars
141+
Random.rand(rng::RNG, T::Union{Type{Float16}, Type{Float32}, Type{BFloat16},
142+
Type{Int8}, Type{UInt8},
143+
Type{Int16}, Type{UInt16},
144+
Type{Int32}, Type{UInt32},
145+
Type{Int64}, Type{UInt64}}=Float32) = Random.rand(rng, T, 1)[1]
146+
147+
# This is the only way I could fix method ambiguity
148+
Random.randn(rng::RNG, T::Type{BFloat16}) = Random.randn(rng, T, 1)[1]
149+
Random.randn(rng::RNG, T::Type{Float16}) = Random.randn(rng, T, 1)[1]
150+
Random.randn(rng::RNG, T::Type{Float32}) = Random.randn(rng, T, 1)[1]
151+
Random.randn(rng::RNG) = Random.randn(rng, Float32)
152+
153+
154+
# GPUArrays out-of-place
155+
function BNNS.rand(::Type{T}, dims::Dims) where T <: BNNSUniform
156+
return Random.rand!(BNNS.default_rng(), Array{T,length(dims)}(undef, dims...))
157+
end
158+
function BNNS.randn(::Type{T}, dims::Dims) where T <: BNNSNormal
159+
return Random.randn!(BNNS.default_rng(), Array{T,length(dims)}(undef, dims...))
160+
end
161+
162+
# support all dimension specifications
163+
function BNNS.rand(::Type{T}, dim1::Integer, dims::Integer...) where T <: BNNSUniform
164+
return Random.rand!(BNNS.default_rng(), Array{T,length(dims) + 1}(undef, dim1, dims...))
165+
end
166+
function BNNS.randn(::Type{T}, dim1::Integer, dims::Integer...) where T <: BNNSNormal
167+
return Random.randn!(BNNS.default_rng(), Array{T,length(dims) + 1}(undef, dim1, dims...))
168+
end
169+
170+
# untyped out-of-place
171+
BNNS.rand(dim1::Integer, dims::Integer...) =
172+
Random.rand!(BNNS.default_rng(), Array{Float32,length(dims) + 1}(undef, dim1, dims...))
173+
BNNS.randn(dim1::Integer, dims::Integer...) =
174+
Random.randn!(BNNS.default_rng(), Array{Float32,length(dims) + 1}(undef, dim1, dims...))
175+
176+
# scalars
177+
BNNS.rand(T::Type=Float32) = BNNS.rand(T, 1)[1]
178+
BNNS.randn(T::Type=Float32) = BNNS.randn(T, 1)[1]
179+
180+
# seeding
181+
function BNNS.seed!(seed=Base.rand(UInt64))
182+
Random.seed!(BNNS.default_rng(), seed)
183+
end
184+
185+
end
186+
end # module

lib/BNNS/BNNS.jl

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
using BFloat16s
2+
3+
include("libBNNS.jl")
4+
5+
bnnsdatatype_modifier(::Type{T}) where {T <: Union{AbstractFloat, Bool}} = BNNSDataTypeFloatBit
6+
bnnsdatatype_modifier(::Type{T}) where {T <: Signed} = BNNSDataTypeIntBit
7+
bnnsdatatype_modifier(::Type{T}) where {T <: Unsigned} = BNNSDataTypeUIntBit
8+
bnnsdatatype_modifier(::Type{Bool}) = BNNSDataTypeMiscellaneousBit
9+
bnnsdatatype_modifier(::Type{BFloat16}) = 0x18000
10+
11+
Base.convert(::Type{BNNSDataType}, T) = BNNSDataType(bnnsdatatype_modifier(T) | UInt32(sizeof(T)*8))
12+
13+
function BNNSNDArrayDescriptor(arr::AbstractArray{T, N}) where {T,N}
14+
N > 8 && throw(ArgumentError("BNNSNDArrays do not support more than 8 dimensions."))
15+
16+
17+
layout = BNNSDataLayout(UInt32(N) * UInt32(BNNSDataLayoutVector) | 0x8000)
18+
# layout = datalayout[N]
19+
sz = ntuple(Val(8)) do i
20+
Csize_t(get(size(arr), i, 0))
21+
end
22+
stride = ntuple(_ -> Csize_t(0), Val(8))
23+
return GC.@preserve arr BNNSNDArrayDescriptor(BNNSNDArrayFlagBackpropSet,
24+
layout,
25+
sz,
26+
stride,
27+
Ptr{Nothing}(pointer(arr)),
28+
T,
29+
0,
30+
T,
31+
1,
32+
0)
33+
end
34+
35+
# Definitions for the Random extension
36+
function bnns_rng end
37+
function default_rng end
38+
function rand end
39+
function randn end
40+
function rand! end
41+
function randn! end
42+
function seed! end

src/AppleAccelerate.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,10 +91,16 @@ function __init__()
9191
load_accelerate(; load_ilp64=true, use_external_lapack=false)
9292
end
9393

94-
if Sys.isapple()
94+
@static if Sys.isapple()
9595
include("Util.jl")
9696
include("Array.jl")
9797
include("DSP.jl")
9898
end
9999

100+
module BNNS
101+
@static if Sys.isapple()
102+
include("../lib/BNNS/BNNS.jl")
103+
end
104+
end # module BNNS
105+
100106
end # module

test/BNNS.jl

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
const RAND_TYPES = [BFloat16, Float16, Float32, Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64,
2+
UInt64]
3+
const RANDN_TYPES = [BFloat16, Float16, Float32]
4+
const INPLACE_TUPLES = [[(rand!, T) for T in RAND_TYPES];
5+
[(randn!, T) for T in RANDN_TYPES]]
6+
const OOPLACE_TUPLES = [[(BNNS.rand, rand, T) for T in RAND_TYPES];
7+
[(BNNS.randn, rand, T) for T in RANDN_TYPES]]
8+
9+
@testset "random" begin
10+
# in-place
11+
@testset "in-place" begin
12+
rng = BNNS.bnns_rng()
13+
14+
@testset "$f with $T" for (f, T) in INPLACE_TUPLES
15+
# d == 2 and d == 3 are to hit the test cases where sizeof(A) <= 4
16+
@testset "$d" for d in (2, 3, (3, 3), (3, 3, 3), 16, (16, 16), (16, 16, 16), (1000,), (1000,1000))
17+
A = Array{T}(undef, d)
18+
19+
# specifie BNNS rng
20+
fill!(A, T(0))
21+
f(rng, A)
22+
@test !iszero(collect(A))
23+
end
24+
25+
@testset "0" begin
26+
A = Array{T}(undef, 0)
27+
28+
# specified BNNS rng
29+
fill!(A, T(0))
30+
f(rng, A)
31+
@test Array(A) == fill(1, 0)
32+
end
33+
end
34+
end
35+
# out-of-place
36+
@testset "out-of-place" begin
37+
@testset "$fr with implicit type" for (fm, fr, T) in
38+
((BNNS.rand, Random.rand, Float32), (BNNS.randn, Random.randn, Float32))
39+
rng = BNNS.bnns_rng()
40+
@testset "args" for args in ((0,), (1,), (3,), (3, 3), (16,), (16, 16), (1000,), (1000,1000))
41+
# default_rng
42+
A = fm(args...)
43+
@test eltype(A) == T
44+
45+
# specified MPS rng
46+
B = fr(rng, args...)
47+
@test eltype(B) == T
48+
end
49+
50+
@testset "scalar" begin
51+
a = fm()
52+
@test typeof(a) == T
53+
b = fr(rng)
54+
@test typeof(b) == T
55+
end
56+
end
57+
58+
# out-of-place, with type specified
59+
@testset "$fr with $T" for (fm, fr, T) in OOPLACE_TUPLES
60+
rng = BNNS.bnns_rng()
61+
@testset "$args" for args in ((T, 0),
62+
(T, 1),
63+
(T, 3),
64+
(T, 3, 3),
65+
(T, (3, 3)),
66+
(T, 16),
67+
(T, 16, 16),
68+
(T, (16, 16)),
69+
(T, 1000),
70+
(T, 1000, 1000),)
71+
# default_rng
72+
A = fm(args...)
73+
@test eltype(A) == T
74+
75+
# specified RNG rng
76+
B = fr(rng, args...)
77+
@test eltype(B) == T
78+
end
79+
80+
@testset "scalar" begin
81+
a = fm(T)
82+
@test typeof(a) == T
83+
b = fr(rng, T)
84+
@test typeof(b) == T
85+
end
86+
end
87+
end
88+
89+
## seeding
90+
@testset "Seeding" begin
91+
@testset "$d" for d in (1, 3, (3, 3, 3), 16, (16, 16), (16, 16, 16), (1000,), (1000,1000), (3,3,3,3), (3,3,3,3,3), (3,3,3,3,3,3))
92+
rng = BNNS.bnns_rng(1)
93+
a = rand(rng, Float32, d)
94+
Random.seed!(rng, 1)
95+
b = rand(rng, Float32, d)
96+
@test a == b
97+
end
98+
end
99+
end # testset

0 commit comments

Comments
 (0)