
summing many differentials (variadic +) #275

Open
@willtebbutt


Zygote regularly emits things like

accum(a, b, c, d)

which is equivalent to

+(a, b, c, d)

in ChainRules language. This means that we (a) need to support this and (b) could optimise it 🥳.

A simple default implementation already exists provided that +(a, b) is defined, and for simple things like Arrays this is optimised already. For example:

using BenchmarkTools
a = randn(10);
b = randn(10);
c = randn(10);
@benchmark $a + $b + $c

yields

BenchmarkTools.Trial:
  memory estimate:  160 bytes
  allocs estimate:  1
  --------------
  minimum time:     51.717 ns (0.00% GC)
  median time:      52.820 ns (0.00% GC)
  mean time:        55.692 ns (1.67% GC)
  maximum time:     443.181 ns (82.46% GC)
  --------------
  samples:          10000
  evals/sample:     980

Only a single array was allocated (the result, with no intermediate for a + b), which is great. However, we don't have this optimisation for e.g. Composites. Consider

julia> using ChainRulesCore

julia> a = Composite{Any}(a);

julia> b = Composite{Any}(b);

julia> c = Composite{Any}(c);

julia> @benchmark $a + $b + $c
BenchmarkTools.Trial:
  memory estimate:  320 bytes
  allocs estimate:  2
  --------------
  minimum time:     84.682 ns (0.00% GC)
  median time:      88.059 ns (0.00% GC)
  mean time:        93.438 ns (2.04% GC)
  maximum time:     644.918 ns (85.62% GC)
  --------------
  samples:          10000
  evals/sample:     947
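One plausible reading of the two allocations is that only a binary + is defined for the type, so a + b + c falls through to Base's generic left fold and materialises a + b before adding c. A minimal sketch of that behaviour, using a hypothetical ToyDiff type that is not part of ChainRulesCore:

```julia
# Hypothetical toy differential type for illustration; NOT a ChainRulesCore type.
struct ToyDiff
    x::Vector{Float64}
end

# Only the binary method is defined here.
Base.:+(p::ToyDiff, q::ToyDiff) = ToyDiff(p.x + q.x)

a, b, c, d = ntuple(_ -> ToyDiff(randn(10)), 4)

# Base's generic fallback +(a, b, c, xs...) evaluates this as ((a + b) + c) + d,
# so each binary + allocates its own result vector.
s = +(a, b, c, d)
```

The result is correct, but every step of the fold creates a vector that is discarded immediately afterwards.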

It would be nice if we could obtain the same benefits here as in the Array case.
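As a rough sketch of one way to get there, a variadic + could pass all corresponding fields to a single + call, so that array fields reach Base's one-pass variadic array method. ToyComposite below is a hypothetical tuple-backed stand-in, not the ChainRulesCore Composite, and the approach is an assumption rather than a settled design:

```julia
# Hypothetical stand-in for a tuple-backed Composite; NOT the ChainRulesCore type.
struct ToyComposite{T<:Tuple}
    backing::T
end

# Binary +: add the backing fields pairwise.
Base.:+(p::ToyComposite, q::ToyComposite) =
    ToyComposite(map(+, p.backing, q.backing))

# Variadic + (three or more arguments): gather the backings and call + once per
# field with all operands, instead of folding over pairs.
function Base.:+(p::ToyComposite, q::ToyComposite, r::ToyComposite, rest::ToyComposite...)
    backings = map(t -> t.backing, (p, q, r, rest...))
    return ToyComposite(map(+, backings...))
end

a = ToyComposite((randn(10),))
b = ToyComposite((randn(10),))
c = ToyComposite((randn(10),))

fused = a + b + c        # parsed as +(a, b, c), so it hits the variadic method
unfused = (a + b) + c    # two binary calls, with an intermediate ToyComposite
```

Because Julia parses a + b + c as the single call +(a, b, c), no change is needed at the call site. Whether this actually removes the extra allocation for the real type would need to be confirmed with @benchmark.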

Loosely related to #226 and #113


Labels: enhancement (new feature or request), inplace accumulation (for things relating to inplace accumulation of gradients)
