Description
This is a follow-up to this discussion in JuliaDiff/ChainRules.jl#336.
JuliaDiff/ChainRules.jl#336 improves the array rules for `sum` by changing the code (e.g. in the case of `sum(abs2, x)`) from `2 .* real.(ȳ) .* x` to
```julia
InplaceableThunk(
    @thunk(2 .* real.(ȳ) .* x),      # val
    dx -> dx .+= 2 .* real.(ȳ) .* x  # add!(dx)
)
```
This makes two improvements:
- (1) the `val` computation `2 .* real.(ȳ) .* x` is now thunked: `@thunk(2 .* real.(ȳ) .* x)`
- (2) the `add!` accumulation function is now `dx -> dx .+= 2 .* real.(ȳ) .* x`
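Improvement (1) means the tangent array is only built if something actually asks for it. A minimal sketch of that deferral follows, assuming `@thunk(expr)` behaves roughly like a stored zero-argument closure `() -> expr`; `DeferredVal`, `force` and `abs2_sum_tangent` are hypothetical stand-ins, not ChainRulesCore's `Thunk`/`unthunk`.

```julia
# Hypothetical stand-in for `@thunk`: it only stores a zero-argument closure,
# so the tangent array is neither computed nor allocated until `force` is called.
struct DeferredVal{F}
    f::F
end
force(t::DeferredVal) = t.f()

const n_evaluated = Ref(0)  # counts how many times the tangent array is built

# Tangent of sum(abs2, x) with respect to x, wrapped so it is computed lazily.
function abs2_sum_tangent(ȳ, x)
    return DeferredVal(() -> begin
        n_evaluated[] += 1
        2 .* real.(ȳ) .* x
    end)
end

x = [1.0, 2.0, 3.0]
t = abs2_sum_tangent(1.0, x)  # nothing computed yet: n_evaluated[] == 0
force(t)                      # [2.0, 4.0, 6.0]; n_evaluated[] is now 1
```

If no caller ever forces `t`, the `2 .* real.(ȳ) .* x` array is never allocated at all.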
It took me a while to work out why (2) was an improvement. The docs on `InplaceableThunk`s say `add!` should be defined such that `ithunk.add!(Δ) = Δ .+= ithunk.val`, but that it should do this more efficiently than simply doing this directly.
Looking at the code above, where `val = 2 .* real.(ȳ) .* x`, why is `add!(dx) = dx .+= 2 .* real.(ȳ) .* x` "more efficient" than `add!(dx) = dx .+= val`?
By copying the code for `val` into the `add!` function we get a single broadcast expression, which Julia fuses into one loop that writes directly into `dx`. That avoids allocating an intermediate `val = 2 .* real.(ȳ) .* x` array.
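The fusion argument can be checked directly in plain Julia, without ChainRulesCore. The sketch below writes `add!` both ways for the tangent of `sum(abs2, x)` and measures allocations with `@allocated` after a warm-up call; the function names `unfused_add!`, `fused_add!` and `allocations` are illustrative only.

```julia
# Unfused: materialise `val` first, as `dx .+= ithunk.val` would.
function unfused_add!(dx, ȳ, x)
    val = 2 .* real.(ȳ) .* x  # allocates a temporary array the size of x
    dx .+= val
    return dx
end

# Fused: the whole right-hand side is one broadcast, written straight into dx.
function fused_add!(dx, ȳ, x)
    dx .+= 2 .* real.(ȳ) .* x
    return dx
end

# Bytes allocated by one call, measured after a warm-up call so that
# compilation is not counted.
function allocations(f!, dx, ȳ, x)
    f!(dx, ȳ, x)
    return @allocated f!(dx, ȳ, x)
end

x = rand(1_000)
ȳ = 1.5
fused_bytes = allocations(fused_add!, zeros(1_000), ȳ, x)      # no temporary array
unfused_bytes = allocations(unfused_add!, zeros(1_000), ȳ, x)  # includes the 1_000-element `val`
```

Both functions leave `dx` with the same contents; the only difference is whether the temporary `val` array exists.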
So that's cool! (Aside: there are some good blog posts about Julia's loop fusion and broadcast magic.)
But it did mean we had to copy code. This issue is to ask: can we do this without having to copy code? That is, it's about API design, user-friendliness, reducing code, and syntax (which might in turn make this performance improvement more widely used in our array rules).
I see two options, but perhaps there are others:
(A) create a macro like `@inplaceable_thunk`
If we did this, code such as
```julia
x_thunk = InplaceableThunk(
    @thunk(2 .* real.(ȳ) .* x),
    dx -> dx .+= 2 .* real.(ȳ) .* x
)
```
could instead be written more succinctly as
```julia
x_thunk = @inplaceable_thunk(2 .* real.(ȳ) .* x)
```
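A minimal sketch of how option (A) could work: the macro receives the expression once and splices it into both the deferred `val` and the in-place `add!`. To keep it self-contained, it targets hypothetical `SketchThunk`/`SketchInplaceableThunk` types rather than ChainRulesCore's `Thunk`/`InplaceableThunk` (whose constructor argument order is not assumed here), and `@inplaceable_thunk` itself is the proposed, not an existing, macro.

```julia
# Hypothetical stand-ins for ChainRulesCore's Thunk / InplaceableThunk.
struct SketchThunk{F}
    f::F
end
force(t::SketchThunk) = t.f()

struct SketchInplaceableThunk{T<:SketchThunk,A}
    val::T   # deferred value
    add!::A  # in-place accumulator
end

# Proposed macro: `expr` is written once in the rule and spliced twice here.
macro inplaceable_thunk(expr)
    dx = gensym(:dx)  # fresh argument name so it cannot clash with user variables
    return esc(:(SketchInplaceableThunk(
        SketchThunk(() -> $expr),  # val: computed only when forced
        $dx -> ($dx .+= $expr),    # add!: same dotted expression, fused into dx
    )))
end

ȳ, x = 1.0, [1.0, 2.0, 3.0]
x_thunk = @inplaceable_thunk(2 .* real.(ȳ) .* x)
force(x_thunk.val)     # [2.0, 4.0, 6.0]
x_thunk.add!(ones(3))  # [3.0, 5.0, 7.0]
```

The code is still duplicated, but only in the macro expansion rather than in every rule. Because the spliced expression keeps its dotted calls, the `.+=` in `add!` should fuse just as the hand-written version does.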
(B) have `@thunk` always return an `InplaceableThunk` with the `add!` function defined like above (i.e. copying in the code for `val`)
I'm not sure if (B) is a valid option. But perhaps it is, if users are expected to go via the `add!!` function (which checks `is_inplaceable_destination`).
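One way to see why (B) might be safe: if every thunk carries an `add!`, the danger is calling it on something that cannot be mutated (a range, an integer array, a scalar). A guard in the style of `add!!` avoids that by only taking the in-place path when the destination allows it. The sketch below uses hypothetical names (`SketchInplaceableThunk`, `sketch_is_inplaceable_destination`, `sketch_add!!`) and a simplified rule for which destinations count as in-placeable; it is not ChainRulesCore's actual definition of `add!!` or `is_inplaceable_destination`.

```julia
# Hypothetical thunk under option (B): both pieces generated from one expression.
struct SketchInplaceableThunk{V,A}
    val::V   # zero-argument closure returning the tangent
    add!::A  # in-place accumulator: dx -> dx .+= <same expression>
end

# Assumption for this sketch: only dense `Array`s with a floating-point
# eltype are safe to mutate.
sketch_is_inplaceable_destination(::Any) = false
sketch_is_inplaceable_destination(::Array{<:AbstractFloat}) = true

function sketch_add!!(acc, t::SketchInplaceableThunk)
    if sketch_is_inplaceable_destination(acc)
        return t.add!(acc)    # fused, in place: no temporary for val
    else
        return acc + t.val()  # out of place: force val, allocate a new result
    end
end

ȳ, x = 1.0, [1.0, 2.0, 3.0]
t = SketchInplaceableThunk(() -> 2 .* real.(ȳ) .* x, dx -> (dx .+= 2 .* real.(ȳ) .* x))

acc = zeros(3)
sketch_add!!(acc, t)  # mutates acc in place; acc == [2.0, 4.0, 6.0]
sketch_add!!(1:3, t)  # a range cannot be mutated, so this returns [3.0, 6.0, 9.0]
```

Under (B), the hand-written constructor call above is what `@thunk` itself would generate, so callers that only ever go through an `add!!`-style function would never hit an invalid in-place update.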