
Helper for creating InplaceableThunks when it is just broadcast + (#274)

Open
@nickrobinson251


This is a follow-up to this discussion in JuliaDiff/ChainRules.jl#336.

JuliaDiff/ChainRules.jl#336 improves the array rules for sum by changing the code (e.g. in the case of sum(abs2, x)) from 2 .* real.(ȳ) .* x to

InplaceableThunk(
    @thunk(2 .* real.(ȳ) .* x),        # val
    dx -> dx .+= 2 .* real.(ȳ) .* x    # add!(dx)
)

This makes two improvements:

  • (1) the val computation 2 .* real.(ȳ) .* x is now thunked @thunk(2 .* real.(ȳ) .* x)
  • (2) the add! accumulation function is now dx -> dx .+= 2 .* real.(ȳ) .* x
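To make the two improvements concrete, here is a minimal, package-free sketch. `Thunk`, `unthunk` and `InplaceableThunk` below are hypothetical local stand-ins that only mimic the shape used in the snippet above (value first, `add!` second); they are not ChainRulesCore's real definitions. The counter `val_calls` is also purely illustrative: it shows that the thunked value is only computed on demand, and that `add!` accumulates into `dx` without ever computing it.

```julia
# Hypothetical stand-ins for ChainRulesCore's Thunk and InplaceableThunk,
# defined locally so this sketch runs without any packages.
struct Thunk{F}
    f::F
end
unthunk(t::Thunk) = t.f()    # run the deferred computation

struct InplaceableThunk{T<:Thunk,F}
    val::T     # the lazily computed gradient
    add!::F    # accumulator: add!(dx) should leave dx equal to dx .+ unthunk(val)
end

const val_calls = Ref(0)     # counts how often the val computation actually runs

# Gradient of sum(abs2, x) w.r.t. x, written in the same style as the snippet above.
function sum_abs2_grad(ȳ, x)
    val = Thunk(() -> (val_calls[] += 1; 2 .* real.(ȳ) .* x))
    add! = dx -> (dx .+= 2 .* real.(ȳ) .* x)   # same expression, one fused broadcast
    return InplaceableThunk(val, add!)
end

x = [1.0, 2.0, 3.0]
x_thunk = sum_abs2_grad(1.0, x)   # nothing computed yet: val_calls[] == 0

dx = zeros(3)
x_thunk.add!(dx)                  # dx == [2.0, 4.0, 6.0], still val_calls[] == 0
unthunk(x_thunk.val)              # [2.0, 4.0, 6.0], now val_calls[] == 1
```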

It took me a while to work out why (2) was an improvement. The docs on InplaceableThunks say:

add! should be defined such that: ithunk.add!(Δ) = Δ .+= ithunk.val but it should do this more efficiently than simply doing this directly.

Looking at the code above, where val = 2 .* real.(ȳ) .* x, why is add!(dx) = dx .+= 2 .* real.(ȳ) .* x "more efficient" than add!(dx) = dx .+= val? Copying the code for val into the add! function gives a single broadcast expression, which Julia can "fuse" into one loop, thereby avoiding the allocation of an intermediate val = 2 .* real.(ȳ) .* x array.
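The allocation difference can be checked directly with Base's `@allocated`. This is a small sketch; the function names `add_unfused!`, `add_fused!` and `bytes_allocated` are made up for illustration. Measuring inside a function after a warm-up call keeps compilation out of the numbers; the exact byte count for the unfused version depends on the Julia version, but it should be on the order of one `Float64` array of length `n`, while the fused version typically allocates nothing.

```julia
# Unfused: materialize val first, then add it (allocates a temporary array).
function add_unfused!(dx, ȳ, x)
    val = 2 .* real.(ȳ) .* x   # intermediate array allocated here
    dx .+= val
    return dx
end

# Fused: one broadcast expression, so Julia runs a single loop writing into dx
# and never allocates the intermediate array.
function add_fused!(dx, ȳ, x)
    dx .+= 2 .* real.(ȳ) .* x
    return dx
end

# Measure allocations inside a function, after a warm-up call, so that
# compilation does not pollute the numbers.
function bytes_allocated(f!, dx, ȳ, x)
    f!(dx, ȳ, x)                 # warm up (compile)
    return @allocated f!(dx, ȳ, x)
end

x = rand(1_000)
ȳ = 0.5
dx1 = zeros(1_000)
dx2 = zeros(1_000)
add_unfused!(dx1, ȳ, x)
add_fused!(dx2, ȳ, x)            # same result as the unfused version

bytes_allocated(add_unfused!, zeros(1_000), ȳ, x)   # roughly 8 bytes per element
bytes_allocated(add_fused!, zeros(1_000), ȳ, x)     # typically 0
```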

So that's cool! (Aside: there are some good blog posts about Julia's loop fusion and broadcast magic.)

But it did mean we had to copy code. This issue is to ask "can we do this without having to copy code?" i.e. it's about API / user-friendliness / reducing code / syntactic stuff (which might in turn make this performance improvement more widely used in our array rules).

I see two options, but perhaps there are others:

(A) create a macro like @inplaceable_thunk

If we did this, code such as

x_thunk = InplaceableThunk(
    @thunk(2 .* real.(ȳ) .* x),
    dx -> dx .+= 2 .* real.(ȳ) .* x
)

could instead be written more succinctly as

x_thunk = @inplaceable_thunk(2 .* real.(ȳ) .* x)
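A rough sketch of what option (A) could look like. Everything here is hypothetical: `@inplaceable_thunk` does not exist yet (that is what this issue proposes), and `Thunk`, `unthunk` and `InplaceableThunk` are local stand-ins with the argument order used in the snippets above, not ChainRulesCore's real types. The macro simply splices the same expression into both the lazy value and the fused `add!` closure, which is exactly the copy we currently do by hand.

```julia
# Hypothetical, self-contained stand-ins (NOT ChainRulesCore's real types).
struct Thunk{F}
    f::F
end
unthunk(t::Thunk) = t.f()

struct InplaceableThunk{T<:Thunk,F}
    val::T
    add!::F
end

# Hypothetical @inplaceable_thunk: splices the same expression `ex` into
# both the lazy value and the fused in-place accumulator.
macro inplaceable_thunk(ex)
    dx = gensym(:dx)   # fresh name so it cannot clash with variables in `ex`
    # The whole expression is escaped, so Thunk/InplaceableThunk resolve in the
    # caller's module (fine here, since everything lives in Main).
    return esc(:(InplaceableThunk(Thunk(() -> $ex), $dx -> ($dx .+= $ex))))
end

ȳ = 3.0
x = [1.0, 2.0, 3.0]
x_thunk = @inplaceable_thunk(2 .* real.(ȳ) .* x)
unthunk(x_thunk.val)    # [6.0, 12.0, 18.0]
dx = ones(3)
x_thunk.add!(dx)        # dx is now [7.0, 13.0, 19.0]
```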

(B) have @thunk always return an InplaceableThunk with the add! function defined like above (i.e. copying in the code for val)

I'm not sure if (B) is a valid option. But perhaps it is, if users are expected to go via the add!! function (which checks is_inplaceable_destination).
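Option (B) only seems safe if callers never invoke `add!` on a destination that cannot be mutated, which is the job of an `add!!`-style function. The sketch below is hypothetical throughout: `@thunk`, `sketch_is_inplaceable` and `sketch_add!!` are local stand-ins (the check is a crude `acc isa Array`, not ChainRulesCore's `is_inplaceable_destination`), and the types mimic the argument order used above. It shows that the copied `add!` is only used when in-place accumulation is possible, with an out-of-place fallback otherwise.

```julia
# Hypothetical stand-ins (NOT ChainRulesCore's real types or macros).
struct Thunk{F}
    f::F
end
unthunk(t::Thunk) = t.f()

struct InplaceableThunk{T<:Thunk,F}
    val::T
    add!::F
end

# Hypothetical option (B): @thunk itself always builds an InplaceableThunk,
# copying `ex` into the add! closure.
macro thunk(ex)
    dx = gensym(:dx)
    return esc(:(InplaceableThunk(Thunk(() -> $ex), $dx -> ($dx .+= $ex))))
end

# Crude stand-in for a check like is_inplaceable_destination.
sketch_is_inplaceable(acc) = acc isa Array

# add!!-style accumulation: mutate only when that is possible.
function sketch_add!!(acc, t::InplaceableThunk)
    if sketch_is_inplaceable(acc)
        return t.add!(acc)              # fused, in place: no temporary array
    else
        return acc .+ unthunk(t.val)    # out of place: never calls t.add!
    end
end

ȳ = 3.0
x = [1.0, 2.0, 3.0]
acc = ones(3)
sketch_add!!(acc, @thunk(2 .* real.(ȳ) .* x))   # mutates acc to [7.0, 13.0, 19.0]
sketch_add!!(1.0, @thunk(2 * real(ȳ) * 1.5))     # 10.0; a Float64 cannot be updated in place
```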

Metadata

Assignees: No one assigned

Labels: enhancement (New feature or request); inplace accumulation (for things relating to inplace accumulation of gradients); rule definition helper (relating to helpers for declaring rules)

Type: No type

Projects: No projects

Milestone: No milestone

Relationships: None yet

Development: No branches or pull requests