Commit 7be1ca7

Allow Parallel(+, f)(x, y, z) to work like broadcasting, and enable Chain(identity, Parallel(+, f))(x, y, z) (#2393)
* let Parallel(+, f)(x, y, z) work like broadcasting
* add (::Chain)(xs...) method
* more examples
* correction
* change implementation to dispatch
* nicer errors when called on zero inputs
* disallow zero layers, let's try this out
1 parent: 7525499
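In short, `Parallel` with a single sub-layer now broadcasts that layer over several inputs, and `Chain` bundles several arguments into one tuple. A minimal sketch of the new behaviour, assuming a Flux build that includes this commit (the values follow from the doctests in the diff below):

```julia
using Flux

# One sub-layer, three inputs: the layer is applied to each input,
# then the connection (+) reduces the results, like broadcasting.
Parallel(+, inv)(1, 2, 4)                   # inv(1) + inv(2) + inv(4) == 1.75

# Chain now accepts several arguments, bundling them into a tuple,
# which Parallel splats back into separate inputs.
Chain(identity, Parallel(+, inv))(1, 2, 4)  # also 1.75
```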

File tree

2 files changed: +77 −16

src/layers/basic.jl
test/layers/basic.jl
src/layers/basic.jl

Lines changed: 60 additions & 10 deletions
@@ -28,6 +28,20 @@ julia> m2(x) == (m2[:dec] ∘ m2[:enc])(x)
 true
 ```

+A chain may be called with multiple arguments, which is equivalent to calling it
+with one tuple of these arguments. Such a tuple is understood by [`Parallel`](@ref)
+to mean the same as several arguments:
+
+```jldoctest
+julia> Chain(println, println)(1, 2, 3) # three arguments become a tuple
+(1, 2, 3)
+nothing
+
+julia> Chain(x->@show(x), Parallel(+, inv, abs2))(4, 5) # returns 1/4 + 5^2
+x = (4, 5)
+25.25
+```
+
 For large models, there is a special type-unstable path which can reduce compilation
 times. This can be used by supplying a vector of layers `Chain([layer1, layer2, ...])`.
 This feature is somewhat experimental, beware!
@@ -46,9 +60,10 @@ end
 @forward Chain.layers Base.getindex, Base.length, Base.first, Base.last,
   Base.iterate, Base.lastindex, Base.keys, Base.firstindex

-@layer :expand Chain # the + opts-in to container-style pretty-printing
+@layer :expand Chain # the option :expand opts-in to container-style pretty-printing

 (c::Chain)(x) = _applychain(c.layers, x)
+(c::Chain)(x, ys...) = _applychain(c.layers, (x, ys...))

 @generated function _applychain(layers::Tuple{Vararg{Any,N}}, x) where {N}
   symbols = vcat(:x, [gensym() for _ in 1:N])
@@ -68,6 +83,7 @@ end
 Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i])
 Base.getindex(c::Chain{<:NamedTuple}, i::AbstractArray) =
   Chain(NamedTuple{keys(c)[i]}(Tuple(c.layers)[i]))
+
 function Base.show(io::IO, c::Chain)
   print(io, "Chain(")
   _show_layers(io, c.layers)
@@ -475,8 +491,11 @@ end
 Create a layer which passes an input array to each path in
 `layers`, before reducing the output with `connection`.

-Called with one input `x`, this is equivalent to `connection([l(x) for l in layers]...)`.
-If called with multiple inputs, one is passed to each layer, thus `Parallel(+, f, g)(x, y) = f(x) + g(y)`.
+Obeys similar rules to broadcasting:
+* Called with one input `x`, this is equivalent to `connection([l(x) for l in layers]...)`.
+* With multiple `inputs` and just one layer, it is instead `connection([layer(x) for x in inputs]...)`.
+* With multiple inputs and multiple layers, one input is passed to each layer,
+  thus `Parallel(+, f, g)(x, y) = f(x) + g(y)`.

 Like [`Chain`](@ref), its sub-layers may be given names using the keyword constructor.
 These can be accessed by indexing: `m[1] == m[:name]` is the first layer.
@@ -486,6 +505,25 @@ and [`Maxout`](@ref) which reduces by broadcasting `max`.

 # Examples

+```jldoctest
+julia> p = Parallel(+, abs2, sqrt);
+
+julia> p(3, 4) # == 3^2 + √4, two functions two inputs
+11.0
+
+julia> p((3, 4)) # tuple is always splatted
+11.0
+
+julia> p(4) # == 4^2 + √4, one input used twice
+18.0
+
+julia> Parallel(hcat, inv)(1, 2, 4) # one function three inputs
+1×3 Matrix{Float64}:
+ 1.0  0.5  0.25
+```
+
+With Flux layers:
+
 ```jldoctest
 julia> model = Chain(Dense(3 => 5),
                      Parallel(vcat, Dense(5 => 4), Chain(Dense(5 => 7), Dense(7 => 4))),
@@ -516,35 +554,47 @@ struct Parallel{F, T<:Union{Tuple, NamedTuple}}
   layers::T
 end

+_ParallelONE{T} = Parallel{T, <:Union{Tuple{Any}, NamedTuple{<:Any, <:Tuple{Any}}}}
+
 Parallel(connection, layers...) = Parallel(connection, layers)
 function Parallel(connection; kw...)
   layers = NamedTuple(kw)
   if :layers in keys(layers) || :connection in keys(layers)
     throw(ArgumentError("a Parallel layer cannot have a named sub-layer called `connection` or `layers`"))
   end
-  isempty(layers) && return Parallel(connection, ())
   Parallel(connection, layers)
 end
+Parallel(connection, layers::Union{Tuple{}, @NamedTuple{}}) =
+  throw(ArgumentError("cannot construct a Parallel layer with no sub-layers"))

 @layer :expand Parallel

-(m::Parallel)(x) = m.connection(map(f -> f(x), Tuple(m.layers))...)
-(m::Parallel)(xs::Tuple) = m(xs...)
+(m::Parallel)(x) = m.connection(map(f -> f(x), Tuple(m.layers))...) # one argument

 function _parallel_check(layers, xs)
   nl = length(layers)
-  nx = length(xs)
+  @assert nl > 1 # dispatch handles nl==1 cases
+  nx = length(xs)
   if (nl != nx)
-    throw(ArgumentError(lazy"Parallel with $nl sub-layers can take one input or $nl inputs, but got $nx inputs"))
+    throw(ArgumentError(lazy"Parallel with $nl > 1 sub-layers can take one input or $nl inputs, but got $nx inputs"))
   end
 end
 ChainRulesCore.@non_differentiable _parallel_check(nl, nx)

-function (m::Parallel)(xs...)
+function (m::Parallel)(x, ys...)
+  xs = (x, ys...)
   _parallel_check(m.layers, xs)
-  m.connection(map(|>, xs, Tuple(m.layers))...)
+  m.connection(map(|>, xs, Tuple(m.layers))...) # multiple arguments & multiple layers
 end

+(m::_ParallelONE)(x, ys...) =
+  m.connection(map(z -> only(m.layers)(z), (x, ys...))...) # multiple arguments, one layer
+
+(m::Parallel)(xs::Tuple) = m(xs...) # tuple is always splatted
+(m::_ParallelONE)(xs::Tuple) = m(xs...) # solves an ambiguity
+
+(m::Parallel)() = throw(ArgumentError("Parallel layer cannot take 0 inputs"))
+
 Base.getindex(m::Parallel, i) = m.layers[i]
 Base.getindex(m::Parallel, i::AbstractVector) = Parallel(m.connection, m.layers[i])
 Base.getindex(m::Parallel{<:Any, <:NamedTuple}, i::AbstractVector) =
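With these methods in place, dispatch separates three call patterns: one input (every sub-layer receives it), many inputs with one sub-layer (the layer is mapped over the inputs, via `_ParallelONE`), and many inputs with many sub-layers (zipped pairwise). A short usage sketch mirroring the doctests above, assuming this version of Flux:

```julia
using Flux

p = Parallel(+, abs2, sqrt)

p(3, 4)   # 11.0: abs2(3) + sqrt(4), two layers zipped with two inputs
p((3, 4)) # 11.0: a tuple argument is always splatted first
p(4)      # 18.0: one input, so every layer receives it, abs2(4) + sqrt(4)

Parallel(hcat, inv)(1, 2, 4) # one layer mapped over three inputs: [1.0 0.5 0.25]
```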

test/layers/basic.jl

Lines changed: 17 additions & 6 deletions
@@ -35,6 +35,8 @@ using Flux: activations
   c = Chain(Dense(10, 5, σ), Dense(5, 2), Dense(2, 1, relu))
   @test c[1] == c[begin]
   @test c[3] == c[end]
+
+  @test Chain(identity)(1,2,3) == (1,2,3) # multiple args become a tuple
 end

 @testset "Activations" begin
@@ -228,17 +230,20 @@ using Flux: activations
 end

 @testset "concat size" begin
-  input = randn(10, 2)
+  input = randn32(10, 2)
   @test size(Parallel((a, b) -> cat(a, b; dims=2), Dense(10, 10), identity)(input)) == (10, 4)
   @test size(Parallel(hcat, one = Dense(10, 10), two = identity)(input)) == (10, 4)
 end

 @testset "vararg input" begin
-  inputs = randn(10), randn(5), randn(4)
+  inputs = randn32(10), randn32(5), randn32(4)
   @test size(Parallel(+, Dense(10, 2), Dense(5, 2), Dense(4, 2))(inputs)) == (2,)
   @test size(Parallel(+; a = Dense(10, 2), b = Dense(5, 2), c = Dense(4, 2))(inputs)) == (2,)
   @test_throws ArgumentError Parallel(+, sin, cos)(1,2,3) # wrong number of inputs
-  @test Parallel(+, sin, cos)(pi/2) ≈ 1
+  @test Parallel(+, sin, cos)(pi/2) ≈ 1 # one input, several layers
+  @test Parallel(/, abs)(3, -4) ≈ 3/4 # one layer, several inputs
+  @test Parallel(/, abs)((3, -4)) ≈ 3/4
+  @test Parallel(/; f=abs)(3, -4) ≈ 3/4
 end

 @testset "named access" begin
@@ -256,9 +261,13 @@ using Flux: activations
 end

 @testset "trivial cases" begin
-  @test Parallel(hcat) isa Parallel{typeof(hcat), Tuple{}} # not a NamedTuple
-  @test Parallel(hcat)(1) == hcat()
-  @test Parallel(hcat, inv)(2) == hcat(1/2) # still calls connection once.
+  # zero inputs, always an error
+  @test_throws ArgumentError Parallel(hcat)()
+  @test_throws ArgumentError Parallel(hcat, inv)()
+  @test_throws ArgumentError Parallel(hcat, inv, sqrt)()
+
+  # zero layers -- not useful... now made an error
+  @test_throws ArgumentError Parallel(hcat)
 end

 @testset "connection is called once" begin
@@ -270,6 +279,8 @@ using Flux: activations
   @test CNT[] == 2
   Parallel(f_cnt, sin)(1)
   @test CNT[] == 3
+  Parallel(f_cnt, sin)(1,2,3)
+  @test CNT[] == 4
 end

 # Ref https://github.com/FluxML/Flux.jl/issues/1673
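The error paths exercised in the "trivial cases" testset correspond to the new messages added in `src/layers/basic.jl`. A runnable sketch of the failure modes, again assuming a Flux build with this commit:

```julia
using Flux, Test

# Mismatched counts: 2 sub-layers cannot take 3 inputs.
# ArgumentError: Parallel with 2 > 1 sub-layers can take one input or 2 inputs, but got 3 inputs
@test_throws ArgumentError Parallel(+, sin, cos)(1, 2, 3)

# Zero sub-layers are now rejected at construction time.
# ArgumentError: cannot construct a Parallel layer with no sub-layers
@test_throws ArgumentError Parallel(hcat)

# Zero inputs are rejected at call time.
# ArgumentError: Parallel layer cannot take 0 inputs
@test_throws ArgumentError Parallel(+, sin, cos)()
```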
