|
203 | 203 | @inline transform_broadcasted(fv::FieldVector, symb, axes) =
|
204 | 204 | parent(getfield(_values(fv), symb))
|
205 | 205 | @inline transform_broadcasted(x, symb, axes) = x
|
| 206 | + |
| 207 | +@inline function first_fieldvector_in_bc(args::Tuple, rargs...) |
| 208 | + x1 = first_fieldvector_in_bc(args[1], rargs...) |
| 209 | + x1 isa FieldVector && return x1 |
| 210 | + return first_fieldvector_in_bc(Base.tail(args), rargs...) |
| 211 | +end |
| 212 | + |
| 213 | +@inline first_fieldvector_in_bc(args::Tuple{Any}, rargs...) = |
| 214 | + first_fieldvector_in_bc(args[1], rargs...) |
| 215 | +@inline first_fieldvector_in_bc(args::Tuple{}, rargs...) = nothing |
| 216 | +@inline first_fieldvector_in_bc(x) = nothing |
| 217 | +@inline first_fieldvector_in_bc(x::FieldVector) = x |
| 218 | + |
| 219 | +@inline first_fieldvector_in_bc( |
| 220 | + bc::Base.Broadcast.Broadcasted{FieldVectorStyle}, |
| 221 | +) = first_fieldvector_in_bc(bc.args) |
| 222 | + |
| 223 | +@inline _is_diagonal_bc_args( |
| 224 | + truesofar, |
| 225 | + ::Type{TStart}, |
| 226 | + args::Tuple, |
| 227 | + rargs..., |
| 228 | +) where {TStart} = |
| 229 | + truesofar && |
| 230 | + _is_diagonal_bc(truesofar, TStart, args[1], rargs...) && |
| 231 | + _is_diagonal_bc_args(truesofar, TStart, Base.tail(args), rargs...) |
| 232 | + |
| 233 | +@inline _is_diagonal_bc_args( |
| 234 | + truesofar, |
| 235 | + ::Type{TStart}, |
| 236 | + args::Tuple{Any}, |
| 237 | + rargs..., |
| 238 | +) where {TStart} = |
| 239 | + truesofar && _is_diagonal_bc(truesofar, TStart, args[1], rargs...) |
| 240 | +@inline _is_diagonal_bc_args( |
| 241 | + truesofar, |
| 242 | + ::Type{TStart}, |
| 243 | + args::Tuple{}, |
| 244 | + rargs..., |
| 245 | +) where {TStart} = truesofar |
| 246 | + |
| 247 | +@inline function _is_diagonal_bc( |
| 248 | + truesofar, |
| 249 | + ::Type{TStart}, |
| 250 | + bc::Base.Broadcast.Broadcasted{FieldVectorStyle}, |
| 251 | +) where {TStart} |
| 252 | + return truesofar && _is_diagonal_bc_args(truesofar, TStart, bc.args) |
| 253 | +end |
| 254 | + |
| 255 | +@inline _is_diagonal_bc( |
| 256 | + truesofar, |
| 257 | + ::Type{TStart}, |
| 258 | + ::TStart, |
| 259 | +) where {TStart <: FieldVector} = true |
| 260 | +@inline _is_diagonal_bc( |
| 261 | + truesofar, |
| 262 | + ::Type{TStart}, |
| 263 | + x::FieldVector, |
| 264 | +) where {TStart} = false |
| 265 | +@inline _is_diagonal_bc(truesofar, ::Type{TStart}, x) where {TStart} = truesofar |
| 266 | + |
| 267 | +# Find the first fieldvector in the broadcast expression (BCE), |
| 268 | +# and compare against every other fieldvector in the BCE |
| 269 | +@inline is_diagonal_bc(bc::Base.Broadcast.Broadcasted{FieldVectorStyle}) = |
| 270 | + _is_diagonal_bc_args(true, typeof(first_fieldvector_in_bc(bc)), bc.args) |
| 271 | + |
| 272 | +# Specialize on FieldVectorStyle to avoid inference failure |
| 273 | +# in fieldvector broadcast expressions: |
| 274 | +# https://github.com/JuliaArrays/BlockArrays.jl/issues/310 |
| 275 | +function Base.Broadcast.instantiate( |
| 276 | + bc::Base.Broadcast.Broadcasted{FieldVectorStyle}, |
| 277 | +) |
| 278 | + if bc.axes isa Nothing # Not done via dispatch to make it easier to extend instantiate(::Broadcasted{Style}) |
| 279 | + axes = Base.Broadcast.combine_axes(bc.args...) |
| 280 | + else |
| 281 | + axes = bc.axes |
| 282 | + # Base.Broadcast.check_broadcast_axes is type-unstable |
| 283 | + # for broadcast expressions with multiple fieldvectors. |
| 284 | + # So, let's statically elide this when we have "diagonal" |
| 285 | + # broadcast expressions: |
| 286 | + if !is_diagonal_bc(bc) |
| 287 | + Base.Broadcast.check_broadcast_axes(axes, bc.args...) |
| 288 | + end |
| 289 | + end |
| 290 | + return Base.Broadcast.Broadcasted(bc.style, bc.f, bc.args, axes) |
| 291 | +end |
| 292 | + |
206 | 293 | @inline function Base.copyto!(
|
207 | 294 | dest::FieldVector,
|
208 | 295 | bc::Base.Broadcast.Broadcasted{FieldVectorStyle},
|
|
0 commit comments