Skip to content

Commit 17e9b0a

Browse files
WIP: use RFFT.jl
1 parent 20df36d commit 17e9b0a

File tree

5 files changed

+140
-25
lines changed

5 files changed

+140
-25
lines changed

Project.toml

+3-2
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ author = ["Tim Holy <[email protected]>", "Jan Weidner <[email protected]>"]
44
version = "0.7.8"
55

66
[deps]
7+
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
78
CatIndices = "aafaddc9-749c-510e-ac4f-586e18779b91"
89
ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3"
910
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
@@ -13,8 +14,8 @@ ImageBase = "c817782e-172a-44cc-b673-b171935fbb9e"
1314
ImageCore = "a09fc81d-aa75-5fe9-8630-4744c3626534"
1415
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1516
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
16-
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1717
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
18+
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1819
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1920
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2021
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
@@ -29,8 +30,8 @@ FFTW = "0.3, 1"
2930
ImageBase = "0.1.5"
3031
ImageCore = "0.10"
3132
OffsetArrays = "1.9"
32-
Reexport = "1.1"
3333
PrecompileTools = "1"
34+
Reexport = "1.1"
3435
StaticArrays = "0.10, 0.11, 0.12, 1.0"
3536
TiledIteration = "0.2, 0.3, 0.4, 0.5"
3637
julia = "1.6"

demo.jl

+28-9
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,34 @@ using ImageFiltering, FFTW, LinearAlgebra, Profile, Random
22
# using ProfileView
33
using ComputationalResources
44

5-
FFTW.set_num_threads(parse(Int, ENV["FFTW_NUM_THREADS"]))
6-
BLAS.set_num_threads(parse(Int, ENV["BLAS_NUM_THREADS"]))
5+
FFTW.set_num_threads(parse(Int, get(ENV, "FFTW_NUM_THREADS", "1")))
6+
BLAS.set_num_threads(parse(Int, get(ENV, "BLAS_NUM_THREADS", string(Threads.nthreads() ÷ 2))))
77

88
function benchmark(mats)
99
kernel = ImageFiltering.factorkernel(Kernel.LoG(1))
1010
Threads.@threads for mat in mats
11-
frame_filtered = similar(mat[:, :, 1])
12-
r = CPU1(ImageFiltering.planned_fft(frame_filtered, kernel))
11+
frame_filtered = deepcopy(mat[:, :, 1])
12+
r_cached = CPU1(ImageFiltering.planned_fft(frame_filtered, kernel))
1313
for i in axes(mat, 3)
1414
frame = @view mat[:, :, i]
15-
imfilter!(r, frame_filtered, frame, kernel)
15+
imfilter!(r_cached, frame_filtered, frame, kernel)
16+
end
17+
return
18+
end
19+
end
20+
21+
function test(mats)
22+
kernel = ImageFiltering.factorkernel(Kernel.LoG(1))
23+
for mat in mats
24+
f1 = deepcopy(mat[:, :, 1])
25+
r_cached = CPU1(ImageFiltering.planned_fft(f1, kernel))
26+
f2 = deepcopy(mat[:, :, 1])
27+
r_noncached = CPU1(Algorithm.FFT())
28+
for i in axes(mat, 3)
29+
frame = @view mat[:, :, i]
30+
imfilter!(r_cached, f1, frame, kernel)
31+
imfilter!(r_noncached, f2, frame, kernel)
32+
all(f1 .≈ f2) || error("f1 !≈ f2")
1633
end
1734
return
1835
end
@@ -24,11 +41,13 @@ function profile()
2441
mats = [rand(Float32, rand(80:100), rand(80:100), rand(2000:3000)) for _ in 1:nmats]
2542
GC.gc(true)
2643

27-
benchmark(mats)
44+
# benchmark(mats)
2845

29-
for _ in 1:3
30-
@time "warm run of benchmark(mats)" benchmark(mats)
31-
end
46+
# for _ in 1:3
47+
# @time "warm run of benchmark(mats)" benchmark(mats)
48+
# end
49+
50+
test(mats)
3251

3352
# Profile.clear()
3453
# @profile benchmark(mats)

src/ImageFiltering.jl

+16-4
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
module ImageFiltering
22

33
using FFTW
4+
include("RFFT.jl") # TODO: Register RFFT.jl on General and add as a dependency
45
using ImageCore, FFTViews, OffsetArrays, StaticArrays, ComputationalResources, TiledIteration
56
# Where possible we avoid a direct dependency to reduce the number of [compat] bounds
67
# using FixedPointNumbers: Normed, N0f8 # reexported by ImageCore
78
using ImageCore.MappedArrays
89
using Statistics, LinearAlgebra
910
using Base: Indices, tail, fill_to_length, @pure, depwarn, @propagate_inbounds
11+
import Base: copy!
1012
using OffsetArrays: IdentityUnitRange # using the one in OffsetArrays makes this work with multiple Julia versions
1113
using SparseArrays # only needed to fix an ambiguity in borderarray
1214
using Reexport
@@ -51,13 +53,23 @@ end
5153

5254
module Algorithm
5355
import FFTW
56+
import ..RFFT
57+
struct BufferedFFTPlan{T<:AbstractFloat}
58+
plan::Function
59+
buf::RFFT.RCpair{T}
60+
end
61+
function BufferedFFTPlan(a::AbstractArray{T}) where {T<:AbstractFloat}
62+
buf = RFFT.RCpair{T}(undef, size(a))
63+
plan = RFFT.plan_rfft!(buf)
64+
BufferedFFTPlan(plan, buf)
65+
end
5466
# deliberately don't export these, but it's expected that they
5567
# will be used as Algorithm.FFT(), etc.
5668
abstract type Alg end
5769
"Filter using the Fast Fourier Transform" struct FFT <: Alg
58-
plan1::Union{FFTW.rFFTWPlan,Nothing}
59-
plan2::Union{FFTW.rFFTWPlan,Nothing}
60-
plan3::Union{FFTW.AbstractFFTs.ScaledPlan,Nothing}
70+
plan1::Union{BufferedFFTPlan,Nothing}
71+
plan2::Union{BufferedFFTPlan,Nothing}
72+
plan3::Union{BufferedFFTPlan,Nothing}
6173
end
6274
FFT() = FFT(nothing, nothing, nothing)
6375
"Filter using a direct algorithm" struct FIR <: Alg end
@@ -69,7 +81,7 @@ module Algorithm
6981

7082
FIRTiled() = FIRTiled(())
7183
end
72-
using .Algorithm: Alg, FFT, FIR, FIRTiled, IIR, Mixed
84+
using .Algorithm: Alg, FFT, FIR, FIRTiled, IIR, Mixed, BufferedFFTPlan
7385

7486
Alg(r::AbstractResource{A}) where {A<:Alg} = r.settings
7587

src/RFFT.jl

+74
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
module RFFT
2+
3+
using FFTW, LinearAlgebra
4+
5+
export RCpair, plan_rfft!, plan_irfft!, rfft!, irfft!, normalization
6+
7+
import Base: real, complex, copy, copy!
8+
9+
mutable struct RCpair{T<:AbstractFloat,N,RType<:AbstractArray{T,N},CType<:AbstractArray{Complex{T},N}}
10+
R::RType
11+
C::CType
12+
region::Vector{Int}
13+
end
14+
15+
function RCpair{T}(::UndefInitializer, realsize::Dims{N}, region=1:length(realsize)) where {T<:AbstractFloat,N}
16+
sz = [realsize...]
17+
firstdim = region[1]
18+
sz[firstdim] = realsize[firstdim]>>1 + 1
19+
sz2 = copy(sz)
20+
sz2[firstdim] *= 2
21+
R = Array{T,N}(undef, (sz2...,)::Dims{N})
22+
C = unsafe_wrap(Array, convert(Ptr{Complex{T}}, pointer(R)), (sz...,)::Dims{N}) # work around performance problems of reinterpretarray
23+
RCpair(view(R, map(n->1:n, realsize)...), C, [region...])
24+
end
25+
26+
RCpair(A::Array{T}, region=1:ndims(A)) where {T<:AbstractFloat} = copy!(RCpair{T}(undef, size(A), region), A)
27+
28+
real(RC::RCpair) = RC.R
29+
complex(RC::RCpair) = RC.C
30+
31+
copy!(RC::RCpair, A::AbstractArray{T}) where {T<:Real} = (copy!(RC.R, A); RC)
32+
function copy(RC::RCpair{T,N}) where {T,N}
33+
C = copy(RC.C)
34+
R = reshape(reinterpret(T, C), size(parent(RC.R)))
35+
RCpair(view(R, RC.R.indices...), C, copy(RC.region))
36+
end
37+
38+
# New API
39+
rplan_fwd(R, C, region, flags, tlim) =
40+
FFTW.rFFTWPlan{eltype(R),FFTW.FORWARD,true,ndims(R)}(R, C, region, flags, tlim)
41+
rplan_inv(R, C, region, flags, tlim) =
42+
FFTW.rFFTWPlan{eltype(R),FFTW.BACKWARD,true,ndims(R)}(R, C, region, flags, tlim)
43+
function plan_rfft!(RC::RCpair{T}; flags::Integer = FFTW.ESTIMATE, timelimit::Real = FFTW.NO_TIMELIMIT) where T
44+
p = rplan_fwd(RC.R, RC.C, RC.region, flags, timelimit)
45+
return Z::RCpair -> begin
46+
FFTW.assert_applicable(p, Z.R, Z.C)
47+
FFTW.unsafe_execute!(p, Z.R, Z.C)
48+
return Z
49+
end
50+
end
51+
function plan_irfft!(RC::RCpair{T}; flags::Integer = FFTW.ESTIMATE, timelimit::Real = FFTW.NO_TIMELIMIT) where T
52+
p = rplan_inv(RC.C, RC.R, RC.region, flags, timelimit)
53+
return Z::RCpair -> begin
54+
FFTW.assert_applicable(p, Z.C, Z.R)
55+
FFTW.unsafe_execute!(p, Z.C, Z.R)
56+
rmul!(Z.R, 1 / prod(size(Z.R)[Z.region]))
57+
return Z
58+
end
59+
end
60+
function rfft!(RC::RCpair{T}) where T
61+
p = rplan_fwd(RC.R, RC.C, RC.region, FFTW.ESTIMATE, FFTW.NO_TIMELIMIT)
62+
FFTW.unsafe_execute!(p, RC.R, RC.C)
63+
return RC
64+
end
65+
function irfft!(RC::RCpair{T}) where T
66+
p = rplan_inv(RC.C, RC.R, RC.region, FFTW.ESTIMATE, FFTW.NO_TIMELIMIT)
67+
FFTW.unsafe_execute!(p, RC.C, RC.R)
68+
rmul!(RC.R, 1 / prod(size(RC.R)[RC.region]))
69+
return RC
70+
end
71+
72+
@deprecate RCpair(realtype::Type{T}, realsize, region=1:length(realsize)) where T<:AbstractFloat RCpair{T}(undef, realsize, region)
73+
74+
end

src/imfilter.jl

+19-10
Original file line numberDiff line numberDiff line change
@@ -840,25 +840,34 @@ function _imfilter_fft!(r::AbstractCPU{FFT},
840840
out
841841
end
842842

843+
copy!(p::BufferedFFTPlan, a::AbstractArray{T}) where {T} = copy!(p.buf, a)
844+
function updaterun!(p::BufferedFFTPlan, a::AbstractArray{T}) where {T}
845+
copy!(p.buf, OffsetArrays.no_offset_view(a))
846+
p.plan(p.buf)
847+
end
848+
843849
function planned_fft(A::AbstractArray{T,N},
844850
kernel::Tuple{AbstractArray,Vararg{AbstractArray}},
845851
border::BorderSpecAny=Pad(:replicate)) where {T,N}
846852
bord = border(kernel, A, Algorithm.FFT())
847853
_A = padarray(T, A, bord)
848-
p1 = plan_rfft(_A)
854+
bfp1 = BufferedFFTPlan(_A)
855+
B = real(updaterun!(bfp1, _A)) * FFTW.AbstractFFTs.to1(_A)
849856
kern = samedims(_A, kernelconv(kernel...))
850857
krn = FFTView(zeros(eltype(kern), map(length, axes(_A))))
851-
p2 = plan_rfft(krn)
852-
B = p1 * _A
853-
B .*= conj!(p2 * krn)
854-
p3 = plan_irfft(B, length(axes(_A, 1)))
855-
return Algorithm.FFT(p1, p2, p3)
858+
for I in CartesianIndices(axes(kern))
859+
krn[I] = kern[I]
860+
end
861+
bfp2 = BufferedFFTPlan(krn)
862+
B .*= conj!(real(updaterun!(bfp2, krn)) * FFTW.AbstractFFTs.to1(krn))
863+
bfp3 = BufferedFFTPlan(B)
864+
return Algorithm.FFT(bfp1, bfp2, bfp3)
856865
end
857866

858-
function filtfft(A, krn, plan_A::FFTW.rFFTWPlan, plan_krn::FFTW.rFFTWPlan, plan_B::FFTW.AbstractFFTs.ScaledPlan)
859-
B = plan_A * A
860-
B .*= conj!(plan_krn * krn)
861-
plan_B * B
867+
function filtfft(A, krn, bfp1::BufferedFFTPlan, bfp2::BufferedFFTPlan, bfp3::BufferedFFTPlan)
868+
B = real(updaterun!(bfp1, A)) * FFTW.AbstractFFTs.to1(A)
869+
B .*= conj!(real(updaterun!(bfp2, krn)) * FFTW.AbstractFFTs.to1(krn))
870+
return real(updaterun!(bfp3, B)) * B
862871
end
863872
filtfft(A, krn, ::Nothing, ::Nothing, ::Nothing) = filtfft(A, krn)
864873
function filtfft(A, krn)

0 commit comments

Comments
 (0)