@@ -12,109 +12,77 @@ store those as fields, just for convenience, and to allow for non-breaking chang
1212we decide we _do_ want to specialize on those values. We always want to specialize on
1313things like stride, padding, dilation, and kernel flipping though.
1414"""
15- abstract type ConvDims{N, S, P, D, F} end
1615
17- # Hack to get rid of type parameters
18- function basetype (:: Type{C} ) where {C <: ConvDims }
19- if C <: DenseConvDims
20- return DenseConvDims
21- elseif C <: PoolDims
22- return PoolDims
23- else
24- return nothing
25- end
16+ struct ConvDims{N,K,C_in,C_out,S,P,D,F,G} <: AbstractDims{N,S,P,D,F}
17+ I:: NTuple{N,Int}
2618end
2719
28- # Obvious getter definitions for the type system-level definitions
29- spatial_dims (c:: ConvDims{N,S,P,D,F} ) where {N, S, P, D, F} = N
30- stride (c:: ConvDims{N,S,P,D,F} ) where {N, S, P, D, F} = S
31- padding (c:: ConvDims{N,S,P,D,F} ) where {N, S, P, D, F} = P
32- dilation (c:: ConvDims{N,S,P,D,F} ) where {N, S, P, D, F} = D
33- flipkernel (c:: ConvDims{N,S,P,D,F} ) where {N, S, P, D, F} = F
34-
35- """
36- im2col_dims(c::ConvDims)
37-
38- im2col calculates, for each output pixel, the "convolution" of N kernels where N is the
39- number of output channels, by doing a matrix multiply. The dimensions of that matrix
40- are given by this function.
41- """
42- im2col_dims (c:: ConvDims ) = (prod (output_size (c)), prod (kernel_size (c))* channels_in (c))
4320
44- # Protect your skin, kids. Also do common validation of stride, padding, etc...
45- function check_spdf (x_size:: NTuple{N} , w_size:: NTuple{N} , stride, padding, dilation) where {N}
46- # Number of spatial dimensions in `x` and `w`.
47- nd = N - 2
48-
49- # Given a number, duplicate it out to have `nd` length. If it's already a collection,
50- # just splat it out into a tuple so it's always a tuple. We'll lint length later.
51- expand_size (p:: Number ) = ntuple (_ -> Int (p), nd)
52- expand_size (p) = tuple (p... )
53-
54- # Convert stride, padding, dilation, etc.. to fully-specified tuples
55- pstride = expand_size (stride)
56- pdilation = expand_size (dilation)
57- ppadding = expand_size (padding)
58-
59- if length (pstride) != nd
60- throw (DimensionMismatch (" Stride $(length (stride)) d, should be $(nd) d!" ))
61- end
62- if length (pdilation) != nd
63- throw (DimensionMismatch (" Dilation $(length (pdilation)) d, should be $(nd) d!" ))
21+ # Getters for the fields
22+ input_size (c:: ConvDims ) = c. I
23+ kernel_size (c:: ConvDims{N,K,C_in,C_out,S,P,D,F,G} ) where {N,K,C_in,C_out,S,P,D,F,G} = K
24+ channels_in (c:: ConvDims{N,K,C_in,C_out,S,P,D,F,G} ) where {N,K,C_in,C_out,S,P,D,F,G} = C_in
25+ channels_out (c:: ConvDims{N,K,C_in,C_out,S,P,D,F,G} ) where {N,K,C_in,C_out,S,P,D,F,G} = C_out
26+ group_count (c:: ConvDims{N,K,C_in,C_out,S,P,D,F,G} ) where {N,K,C_in,C_out,S,P,D,F,G} = G
27+
28+ # Convenience wrapper to create ConvDims objects
29+ function ConvDims (x_size:: NTuple{M} , w_size:: NTuple{M} ;
30+ stride= 1 , padding= 0 , dilation= 1 , flipkernel:: Bool = false , groupcount= 1 ) where M
31+ # Do common parameter validation
32+ stride, padding, dilation = check_spdf (x_size, w_size, stride, padding, dilation)
33+
34+ # Ensure channels are equal
35+ if x_size[M- 1 ] != w_size[M- 1 ]* groupcount
36+ xs = x_size[M- 1 ]
37+ ws = w_size[M- 1 ]* groupcount
38+ throw (DimensionMismatch (" Input channels must match! ($xs vs. $ws )" ))
6439 end
6540
66- # padding is kind of a special case; we allow it to be either 2-length or 4-length,
67- # since we support asymmetrical padding
68- if length (ppadding) != 2 * nd
69- if length (ppadding) == nd
70- # Do this repeat dance so that we get lo/hi symmetrical padding
71- ppadding = tuple (repeat (collect (ppadding), inner= 2 )... )
72- else
73- throw (DimensionMismatch (" Padding $(length (ppadding)) d, should be either $(nd) d or $(2 * nd) d!" ))
74- end
75- end
41+ # The type parameters are what
42+ return ConvDims{
43+ M - 2 ,
44+ w_size[1 : M- 2 ],
45+ x_size[M- 1 ],
46+ w_size[M],
47+ stride,
48+ padding,
49+ dilation,
50+ flipkernel,
51+ groupcount
52+ }(
53+ # Input spatial size
54+ x_size[1 : M- 2 ],
55+ )
56+ end
7657
77- # Assert that kernel size * dilation is <= padded input size
78- for idx in 1 : nd
79- Is = x_size[idx]
80- Pl = ppadding[(idx - 1 )* 2 + 1 ]
81- Ph = ppadding[(idx - 1 )* 2 + 2 ]
82- Ks = w_size[idx]
83- Ds = pdilation[idx]
84- if Is + Pl + Ph < (Ks - 1 )* Ds + 1
85- throw (DimensionMismatch (" Kernel * dilation (($Ks - 1) * $Ds + 1) cannot be larger than input + padding ($Is + $Pl + $Ph )!" ))
86- end
58+ # Auto-extract sizes and sub out to big brother above
59+ function ConvDims (x:: AbstractArray , w:: AbstractArray ; kwargs... )
60+ if ndims (x) != ndims (w)
61+ throw (DimensionMismatch (" Rank of x and w must match! ($(ndims (x)) vs. $(ndims (w)) )" ))
8762 end
88-
89- return pstride, ppadding, pdilation
63+ return ConvDims (size (x), size (w); kwargs... )
9064end
9165
92- """
93- output_size(c::ConvDims)
66+ # Useful for constructing a new ConvDims that has only a few elements different
67+ # from the original progenitor object that it inherits shapes from.
68+ function ConvDims (c:: AbstractDims ; N= spatial_dims (c), I= input_size (c), K= kernel_size (c),
69+ C_in= channels_in (c), C_out= channels_out (c), S= stride (c),
70+ P= padding (c), D= dilation (c), F= flipkernel (c), G= group_count (c))
71+ return ConvDims {N, K, C_in, C_out, S, P, D, F, G} (I)
72+ end
9473
95- Calculate the output (spatial) dimensions of the convolution. Get channel count via
96- `channels_out(c)`, and batch count is unknowable.
97- """
98- function output_size (c:: ConvDims )
99- I = input_size (c)
100- K = kernel_size (c)
101- S = stride (c)
102- P = padding (c)
103- D = dilation (c)
74+ function check_dims (x:: NTuple{M} , w:: NTuple{M} , y:: NTuple{M} , cdims:: ConvDims ) where {M}
75+ # First, check that channel counts are all correct:
76+ @assert x[M- 1 ] == channels_in (cdims) DimensionMismatch (" Data input channel count ($(x[M- 1 ]) vs. $(channels_in (cdims)) )" )
77+ @assert y[M- 1 ] == channels_out (cdims) DimensionMismatch (" Data output channel count ($(y[M- 1 ]) vs. $(channels_out (cdims)) )" )
78+ @assert w[M- 1 ] == channels_in (cdims)/ group_count (cdims) DimensionMismatch (" Kernel input channel count ($(w[M- 1 ]) vs. $(channels_in (cdims)/ group_count (cdims)) )" )
79+ @assert w[M] == channels_out (cdims) DimensionMismatch (" Kernel output channel count ($(w[M]) vs. $(channels_out (cdims)) )" )
10480
105- return ntuple ( spatial_dims (c)) do i
106- return div (I[i] + P[(i - 1 ) * 2 + 1 ] + P[(i - 1 ) * 2 + 2 ] - (K[i] - 1 ) * D[i] - 1 , S[i ]) + 1
107- end
108- end
81+ # Next, check that the spatial dimensions match up
82+ @assert x[ 1 : M - 2 ] == input_size (cdims) DimensionMismatch ( " Data input spatial size ( $(x[ 1 : M - 2 ]) vs. $( input_size (cdims)) ) " )
83+ @assert y[ 1 : M - 2 ] == output_size (cdims) DimensionMismatch ( " Data output spatial size ( $(y[ 1 : M - 2 ]) vs. $( output_size (cdims)) ) " )
84+ @assert w[ 1 : M - 2 ] == kernel_size (cdims) DimensionMismatch ( " Kernel spatial size ( $(w[ 1 : M - 2 ]) vs. $( kernel_size (cdims)) ) " )
10985
110- # Override show() for these beauties
111- function Base. show (io:: IO , cdims:: C ) where {C <: ConvDims }
112- I = (input_size (cdims)... , channels_in (cdims))
113- O = (output_size (cdims)... , channels_out (cdims))
114- K = kernel_size (cdims)
115- S = stride (cdims)
116- P = padding (cdims)
117- D = dilation (cdims)
118- F = flipkernel (cdims)
119- print (io, " $(basetype (C)) : $I * $K -> $O , stride: $S pad: $P , dil: $D , flip: $F " )
86+ # Finally, check that the batch size matches
87+ @assert x[M] == y[M] DimensionMismatch (" Batch size ($(x[M]) vs. $(y[M]) )" )
12088end
0 commit comments