Skip to content

Commit da32a43

Browse files
Merge pull request #1615 from CliMA/ck/fv_specialize
Specialize on diagonal fieldvector broadcasts
2 parents 202ed63 + 7d31f3b commit da32a43

File tree

2 files changed

+129
-0
lines changed

2 files changed

+129
-0
lines changed

src/Fields/fieldvector.jl

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,93 @@ end
203203
@inline transform_broadcasted(fv::FieldVector, symb, axes) =
204204
parent(getfield(_values(fv), symb))
205205
@inline transform_broadcasted(x, symb, axes) = x
206+
207+
@inline function first_fieldvector_in_bc(args::Tuple, rargs...)
208+
x1 = first_fieldvector_in_bc(args[1], rargs...)
209+
x1 isa FieldVector && return x1
210+
return first_fieldvector_in_bc(Base.tail(args), rargs...)
211+
end
212+
213+
@inline first_fieldvector_in_bc(args::Tuple{Any}, rargs...) =
214+
first_fieldvector_in_bc(args[1], rargs...)
215+
@inline first_fieldvector_in_bc(args::Tuple{}, rargs...) = nothing
216+
@inline first_fieldvector_in_bc(x) = nothing
217+
@inline first_fieldvector_in_bc(x::FieldVector) = x
218+
219+
@inline first_fieldvector_in_bc(
220+
bc::Base.Broadcast.Broadcasted{FieldVectorStyle},
221+
) = first_fieldvector_in_bc(bc.args)
222+
223+
@inline _is_diagonal_bc_args(
224+
truesofar,
225+
::Type{TStart},
226+
args::Tuple,
227+
rargs...,
228+
) where {TStart} =
229+
truesofar &&
230+
_is_diagonal_bc(truesofar, TStart, args[1], rargs...) &&
231+
_is_diagonal_bc_args(truesofar, TStart, Base.tail(args), rargs...)
232+
233+
@inline _is_diagonal_bc_args(
234+
truesofar,
235+
::Type{TStart},
236+
args::Tuple{Any},
237+
rargs...,
238+
) where {TStart} =
239+
truesofar && _is_diagonal_bc(truesofar, TStart, args[1], rargs...)
240+
@inline _is_diagonal_bc_args(
241+
truesofar,
242+
::Type{TStart},
243+
args::Tuple{},
244+
rargs...,
245+
) where {TStart} = truesofar
246+
247+
@inline function _is_diagonal_bc(
248+
truesofar,
249+
::Type{TStart},
250+
bc::Base.Broadcast.Broadcasted{FieldVectorStyle},
251+
) where {TStart}
252+
return truesofar && _is_diagonal_bc_args(truesofar, TStart, bc.args)
253+
end
254+
255+
@inline _is_diagonal_bc(
256+
truesofar,
257+
::Type{TStart},
258+
::TStart,
259+
) where {TStart <: FieldVector} = true
260+
@inline _is_diagonal_bc(
261+
truesofar,
262+
::Type{TStart},
263+
x::FieldVector,
264+
) where {TStart} = false
265+
@inline _is_diagonal_bc(truesofar, ::Type{TStart}, x) where {TStart} = truesofar
266+
267+
# Find the first fieldvector in the broadcast expression (BCE),
268+
# and compare against every other fieldvector in the BCE
269+
@inline is_diagonal_bc(bc::Base.Broadcast.Broadcasted{FieldVectorStyle}) =
270+
_is_diagonal_bc_args(true, typeof(first_fieldvector_in_bc(bc)), bc.args)
271+
272+
# Specialize on FieldVectorStyle to avoid inference failure
273+
# in fieldvector broadcast expressions:
274+
# https://github.com/JuliaArrays/BlockArrays.jl/issues/310
275+
function Base.Broadcast.instantiate(
276+
bc::Base.Broadcast.Broadcasted{FieldVectorStyle},
277+
)
278+
if bc.axes isa Nothing # Not done via dispatch to make it easier to extend instantiate(::Broadcasted{Style})
279+
axes = Base.Broadcast.combine_axes(bc.args...)
280+
else
281+
axes = bc.axes
282+
# Base.Broadcast.check_broadcast_axes is type-unstable
283+
# for broadcast expressions with multiple fieldvectors.
284+
# So, let's statically elide this when we have "diagonal"
285+
# broadcast expressions:
286+
if !is_diagonal_bc(bc)
287+
Base.Broadcast.check_broadcast_axes(axes, bc.args...)
288+
end
289+
end
290+
return Base.Broadcast.Broadcasted(bc.style, bc.f, bc.args, axes)
291+
end
292+
206293
@inline function Base.copyto!(
207294
dest::FieldVector,
208295
bc::Base.Broadcast.Broadcasted{FieldVectorStyle},

test/Fields/field.jl

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,48 @@ end
278278
@test Y.k.z === 3.0
279279
end
280280

281+
# https://github.com/CliMA/ClimaCore.jl/issues/1465
282+
@testset "Diagonal FieldVector broadcast expressions" begin
283+
FT = Float64
284+
device = ClimaComms.device()
285+
comms_ctx = ClimaComms.context(device)
286+
cspace = TU.CenterExtrudedFiniteDifferenceSpace(FT; context = comms_ctx)
287+
fspace = TU.FaceExtrudedFiniteDifferenceSpace(FT; context = comms_ctx)
288+
cx = Fields.fill((; a = FT(1), b = FT(2)), cspace)
289+
cy = Fields.fill((; a = FT(1), b = FT(2)), cspace)
290+
fx = Fields.fill((; a = FT(1), b = FT(2)), fspace)
291+
fy = Fields.fill((; a = FT(1), b = FT(2)), fspace)
292+
Y1 = Fields.FieldVector(; x = cx, y = cy)
293+
Y2 = Fields.FieldVector(; x = cx, y = cy)
294+
Y3 = Fields.FieldVector(; x = cx, y = cy)
295+
Y4 = Fields.FieldVector(; x = cx, y = cy)
296+
Z = Fields.FieldVector(; x = fx, y = fy)
297+
function test_fv_allocations!(X1, X2, X3, X4)
298+
@. X1 += X2 * X3 + X4
299+
return nothing
300+
end
301+
test_fv_allocations!(Y1, Y2, Y3, Y4)
302+
p_allocated = @allocated test_fv_allocations!(Y1, Y2, Y3, Y4)
303+
if device isa ClimaComms.AbstractCPUDevice
304+
@test p_allocated == 0
305+
elseif device isa ClimaComms.CUDADevice
306+
@test_broken p_allocated == 0
307+
end
308+
309+
bc1 = Base.broadcasted(
310+
:-,
311+
Base.broadcasted(:+, Y1, Base.broadcasted(:*, 2, Y2)),
312+
Base.broadcasted(:*, 3, Y3),
313+
)
314+
bc2 = Base.broadcasted(
315+
:-,
316+
Base.broadcasted(:+, Y1, Base.broadcasted(:*, 2, Y1)),
317+
Base.broadcasted(:*, 3, Z),
318+
)
319+
@test Fields.is_diagonal_bc(bc1)
320+
@test !Fields.is_diagonal_bc(bc2)
321+
end
322+
281323
function call_getcolumn(fv, colidx)
282324
@allowscalar fvcol = fv[colidx]
283325
nothing

0 commit comments

Comments
 (0)