julia> m2(x) == (m2[:dec] ∘ m2[:enc])(x)
true
```

A chain may be called with multiple arguments, which is equivalent to calling it
with one tuple of these arguments. Such a tuple is understood by [`Parallel`](@ref)
to mean the same as several arguments:

```jldoctest
julia> Chain(println, println)(1, 2, 3)  # three arguments become a tuple
(1, 2, 3)
nothing

julia> Chain(x->@show(x), Parallel(+, inv, abs2))(4, 5)  # returns 1/4 + 5^2
x = (4, 5)
25.25
```

For large models, there is a special type-unstable path which can reduce compilation
times. This can be used by supplying a vector of layers `Chain([layer1, layer2, ...])`.
This feature is somewhat experimental, beware!
# Forward collection-style methods to the wrapped tuple/NamedTuple of layers,
# so a Chain can be indexed, measured, and iterated like its `layers` field.
@forward Chain.layers Base.getindex, Base.length, Base.first, Base.last,
  Base.iterate, Base.lastindex, Base.keys, Base.firstindex
@layer :expand Chain  # the option :expand opts-in to container-style pretty-printing
(c::Chain)(x) = _applychain(c.layers, x)
# Calling a Chain with several arguments bundles them into one tuple,
# which is then passed to the first layer.
(c::Chain)(x, ys...) = _applychain(c.layers, (x, ys...))
@generated function _applychain (layers:: Tuple{Vararg{Any,N}} , x) where {N}
54
69
symbols = vcat (:x , [gensym () for _ in 1 : N])
# Slicing a Chain with an array of indices returns a new Chain of those layers.
Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i])
# The NamedTuple case also preserves the corresponding layer names.
Base.getindex(c::Chain{<:NamedTuple}, i::AbstractArray) =
  Chain(NamedTuple{keys(c)[i]}(Tuple(c.layers)[i]))
71
87
function Base. show (io:: IO , c:: Chain )
72
88
print (io, " Chain(" )
73
89
_show_layers (io, c. layers)
Create a layer which passes an input array to each path in
`layers`, before reducing the output with `connection`.

Obeys similar rules to broadcasting:
* Called with one input `x`, this is equivalent to `connection([l(x) for l in layers]...)`.
* With multiple `inputs` and just one layer, it is instead `connection([layer(x) for x in inputs]...)`.
* With multiple inputs and multiple layers, one input is passed to each layer,
  thus `Parallel(+, f, g)(x, y) = f(x) + g(y)`.

Like [`Chain`](@ref), its sub-layers may be given names using the keyword constructor.
These can be accessed by indexing: `m[1] == m[:name]` is the first layer.
and [`Maxout`](@ref) which reduces by broadcasting `max`.

# Examples

```jldoctest
julia> p = Parallel(+, abs2, sqrt);

julia> p(3, 4)  # == 3^2 + √4, two functions two inputs
11.0

julia> p((3, 4))  # tuple is always splatted
11.0

julia> p(4)  # == 4^2 + √4, one input used twice
18.0

julia> Parallel(hcat, inv)(1, 2, 4)  # one function three inputs
1×3 Matrix{Float64}:
 1.0  0.5  0.25
```

With Flux layers:

```jldoctest
julia> model = Chain(Dense(3 => 5),
                     Parallel(vcat, Dense(5 => 4), Chain(Dense(5 => 7), Dense(7 => 4))),
@@ -516,35 +554,47 @@ struct Parallel{F, T<:Union{Tuple, NamedTuple}}
516
554
layers:: T
517
555
end
518
556
# Alias for a Parallel holding exactly one sub-layer (a 1-tuple or a 1-element
# NamedTuple); used to dispatch the one-layer / many-inputs cases below.
_ParallelONE{T} = Parallel{T, <:Union{Tuple{Any}, NamedTuple{<:Any, <:Tuple{Any}}}}
Parallel(connection, layers...) = Parallel(connection, layers)

# Keyword constructor: `Parallel(+; a = f, b = g)` names the sub-layers.
function Parallel(connection; kw...)
  layers = NamedTuple(kw)
  if :layers in keys(layers) || :connection in keys(layers)
    # These names would clash with the struct's own fields when indexing.
    throw(ArgumentError("a Parallel layer cannot have a named sub-layer called `connection` or `layers`"))
  end
  Parallel(connection, layers)
end

# A Parallel with no sub-layers is ill-defined, so reject it at construction.
Parallel(connection, layers::Union{Tuple{}, @NamedTuple{}}) =
  throw(ArgumentError("cannot construct a Parallel layer with no sub-layers"))
@layer :expand Parallel  # :expand opts-in to container-style pretty-printing
# With one argument, the same input is passed to every sub-layer,
# and the results are reduced by `connection`.
(m::Parallel)(x) = m.connection(map(f -> f(x), Tuple(m.layers))...)  # one argument
# Validate that the number of inputs matches the number of sub-layers.
# Throws an ArgumentError on mismatch; returns nothing otherwise.
function _parallel_check(layers, xs)
  nl = length(layers)
  @assert nl > 1  # dispatch handles nl==1 cases
  nx = length(xs)
  if (nl != nx)
    throw(ArgumentError(lazy"Parallel with $nl > 1 sub-layers can take one input or $nl inputs, but got $nx inputs"))
  end
end
# The arity check contains no differentiable computation; opt it out of AD.
ChainRulesCore.@non_differentiable _parallel_check(nl, nx)
# With several arguments and several layers, pair them up one-to-one:
# Parallel(+, f, g)(x, y) == f(x) + g(y).
function (m::Parallel)(x, ys...)
  xs = (x, ys...)
  _parallel_check(m.layers, xs)
  m.connection(map(|>, xs, Tuple(m.layers))...)  # multiple arguments & multiple layers
end
# One sub-layer but several arguments: apply that single layer to each argument.
(m::_ParallelONE)(x, ys...) =
  m.connection(map(z -> only(m.layers)(z), (x, ys...))...)  # multiple arguments, one layer

(m::Parallel)(xs::Tuple) = m(xs...)  # tuple is always splatted
(m::_ParallelONE)(xs::Tuple) = m(xs...)  # solves an ambiguity

(m::Parallel)() = throw(ArgumentError("Parallel layer cannot take 0 inputs"))
# Indexing a Parallel accesses its sub-layers; vector indexing returns
# a new Parallel with the same connection.
Base.getindex(m::Parallel, i) = m.layers[i]
Base.getindex(m::Parallel, i::AbstractVector) = Parallel(m.connection, m.layers[i])
600
Base. getindex (m:: Parallel{<:Any, <:NamedTuple} , i:: AbstractVector ) =
0 commit comments