1- export DepthwiseConvDims
2-
3- """
4- DepthwiseConvDims
5-
6- Concrete subclass of `ConvDims` for a depthwise convolution. Differs primarily due to
7- characterization by C_in, C_mult, rather than C_in, C_out. Useful to be separate from
8- DenseConvDims primarily for channel calculation differences.
9- """
10- struct DepthwiseConvDims{N,S,P,D,F} <: ConvDims{N,S,P,D,F}
11- I:: NTuple{N, Int}
12- K:: NTuple{N, Int}
13- C_in:: Int
14- C_mult:: Int
15- end
16-
17- # Getters for the fields
18- input_size (c:: DepthwiseConvDims ) = c. I
19- kernel_size (c:: DepthwiseConvDims ) = c. K
20- channels_in (c:: DepthwiseConvDims ) = c. C_in
21- channels_out (c:: DepthwiseConvDims ) = c. C_in * channel_multiplier (c)
22- channel_multiplier (c:: DepthwiseConvDims ) = c. C_mult
23-
24-
25- # Convenience wrapper to create DepthwiseConvDims objects
26- function DepthwiseConvDims (x_size:: NTuple{M} , w_size:: NTuple{M} ;
27- stride= 1 , padding= 0 , dilation= 1 , flipkernel:: Bool = false ) where M
28- # Do common parameter validation
29- stride, padding, dilation = check_spdf (x_size, w_size, stride, padding, dilation)
30-
31- # Ensure channels are equal
32- if x_size[end - 1 ] != w_size[end ]
33- xs = x_size[end - 1 ]
34- ws = w_size[end ]
35- throw (DimensionMismatch (" Input channels must match! ($xs vs. $ws )" ))
36- end
37-
38- return DepthwiseConvDims{
39- M - 2 ,
40- stride,
41- padding,
42- dilation,
43- flipkernel
44- }(
45- # Image spatial size
46- x_size[1 : end - 2 ],
47-
48- # Kernel spatial size
49- w_size[1 : end - 2 ],
50-
51- # Input channels
52- x_size[end - 1 ],
53-
54- # Channel multiplier
55- w_size[end - 1 ],
56- )
57- end
58-
59- # Auto-extract sizes and just pass those directly in
60- function DepthwiseConvDims (x:: AbstractArray , w:: AbstractArray ; kwargs... )
61- if ndims (x) != ndims (w)
62- throw (DimensionMismatch (" Rank of x and w must match! ($(ndims (x)) vs. $(ndims (w)) )" ))
63- end
64- return DepthwiseConvDims (size (x), size (w); kwargs... )
65- end
66-
67- # Useful for constructing a new DepthwiseConvDims that has only a few elements different
68- # from the original progenitor object.
69- function DepthwiseConvDims (c:: DepthwiseConvDims ; N= spatial_dims (c), I= input_size (c), K= kernel_size (c),
70- C_in= channels_in (c), C_m= channel_multiplier (c), S= stride (c),
71- P= padding (c), D= dilation (c), F= flipkernel (c))
72- return DepthwiseConvDims {N, S, P, D, F} (I, K, C_in, C_m)
73- end
74-
75- # This one is basically the same as for DenseConvDims, we only change a few lines for kernel channel count
76- function check_dims (x:: NTuple{M} , w:: NTuple{M} , y:: NTuple{M} , cdims:: DepthwiseConvDims ) where {M}
77- # First, check that channel counts are all correct:
78- @assert x[end - 1 ] == channels_in (cdims) DimensionMismatch (" Data input channel count ($(x[end - 1 ]) vs. $(channels_in (cdims)) )" )
79- @assert y[end - 1 ] == channels_out (cdims) DimensionMismatch (" Data output channel count ($(y[end - 1 ]) vs. $(channels_out (cdims)) )" )
80- @assert w[end - 1 ] == channel_multiplier (cdims) DimensionMismatch (" Kernel multiplier channel count ($(w[end - 1 ]) vs. $(channel_multiplier (cdims)) " )
81- @assert w[end ] == channels_in (cdims) DimensionMismatch (" Kernel input channel count ($(w[end ]) vs. $(channels_in (cdims)) )" )
82-
83- # Next, check that the spatial dimensions match up
84- @assert x[1 : end - 2 ] == input_size (cdims) DimensionMismatch (" Data input spatial size ($(x[1 : end - 2 ]) vs. $(input_size (cdims)) )" )
85- @assert y[1 : end - 2 ] == output_size (cdims) DimensionMismatch (" Data output spatial size ($(y[1 : end - 2 ]) vs. $(output_size (cdims)) )" )
86- @assert w[1 : end - 2 ] == kernel_size (cdims) DimensionMismatch (" Kernel spatial size ($(w[1 : end - 2 ]) vs. $(kernel_size (cdims)) )" )
87-
88- # Finally, check that the batch size matches
89- @assert x[end ] == y[end ] DimensionMismatch (" Batch size ($(x[end ]) vs. $(y[end ]) )" )
90- end
1+ # export DepthwiseConvDims
2+ #
3+ # """
4+ # DepthwiseConvDims
5+ #
6+ # Concrete subclass of `ConvDims` for a depthwise convolution. Differs primarily due to
7+ # characterization by C_in, C_mult, rather than C_in, C_out. Useful to be separate from
8+ # DenseConvDims primarily for channel calculation differences.
9+ # """
10+ # struct DepthwiseConvDims{N,S,P,D,F} <: ConvDims{N,S,P,D,F}
11+ # I::NTuple{N, Int}
12+ # K::NTuple{N, Int}
13+ # C_in::Int
14+ # C_mult::Int
15+ # end
16+ #
17+ # # Getters for the fields
18+ # input_size(c::DepthwiseConvDims) = c.I
19+ # kernel_size(c::DepthwiseConvDims) = c.K
20+ # channels_in(c::DepthwiseConvDims) = c.C_in
21+ # channels_out(c::DepthwiseConvDims) = c.C_in * channel_multiplier(c)
22+ # channel_multiplier(c::DepthwiseConvDims) = c.C_mult
23+ #
24+ #
25+ # # Convenience wrapper to create DepthwiseConvDims objects
26+ # function DepthwiseConvDims(x_size::NTuple{M}, w_size::NTuple{M};
27+ # stride=1, padding=0, dilation=1, flipkernel::Bool=false) where M
28+ # # Do common parameter validation
29+ # stride, padding, dilation = check_spdf(x_size, w_size, stride, padding, dilation)
30+ #
31+ # # Ensure channels are equal
32+ # if x_size[end-1] != w_size[end]
33+ # xs = x_size[end-1]
34+ # ws = w_size[end]
35+ # throw(DimensionMismatch("Input channels must match! ($xs vs. $ws)"))
36+ # end
37+ #
38+ # return DepthwiseConvDims{
39+ # M - 2,
40+ # stride,
41+ # padding,
42+ # dilation,
43+ # flipkernel
44+ # }(
45+ # # Image spatial size
46+ # x_size[1:end-2],
47+ #
48+ # # Kernel spatial size
49+ # w_size[1:end-2],
50+ #
51+ # # Input channels
52+ # x_size[end-1],
53+ #
54+ # # Channel multiplier
55+ # w_size[end-1],
56+ # )
57+ # end
58+ #
59+ # # Auto-extract sizes and just pass those directly in
60+ # function DepthwiseConvDims(x::AbstractArray, w::AbstractArray; kwargs...)
61+ # if ndims(x) != ndims(w)
62+ # throw(DimensionMismatch("Rank of x and w must match! ($(ndims(x)) vs. $(ndims(w)))"))
63+ # end
64+ # return DepthwiseConvDims(size(x), size(w); kwargs...)
65+ # end
66+ #
67+ # # Useful for constructing a new DepthwiseConvDims that has only a few elements different
68+ # # from the original progenitor object.
69+ # function DepthwiseConvDims(c::DepthwiseConvDims; N=spatial_dims(c), I=input_size(c), K=kernel_size(c),
70+ # C_in=channels_in(c), C_m=channel_multiplier(c), S=stride(c),
71+ # P=padding(c), D=dilation(c), F=flipkernel(c))
72+ # return DepthwiseConvDims{N, S, P, D, F}(I, K, C_in, C_m)
73+ # end
74+ #
75+ # # This one is basically the same as for DenseConvDims, we only change a few lines for kernel channel count
76+ # function check_dims(x::NTuple{M}, w::NTuple{M}, y::NTuple{M}, cdims::DepthwiseConvDims) where {M}
77+ # # First, check that channel counts are all correct:
78+ # @assert x[end-1] == channels_in(cdims) DimensionMismatch("Data input channel count ($(x[end-1]) vs. $(channels_in(cdims)))")
79+ # @assert y[end-1] == channels_out(cdims) DimensionMismatch("Data output channel count ($(y[end-1]) vs. $(channels_out(cdims)))")
80+ # @assert w[end-1] == channel_multiplier(cdims) DimensionMismatch("Kernel multiplier channel count ($(w[end-1]) vs. $(channel_multiplier(cdims))")
81+ # @assert w[end] == channels_in(cdims) DimensionMismatch("Kernel input channel count ($(w[end]) vs. $(channels_in(cdims)))")
82+ #
83+ # # Next, check that the spatial dimensions match up
84+ # @assert x[1:end-2] == input_size(cdims) DimensionMismatch("Data input spatial size ($(x[1:end-2]) vs. $(input_size(cdims)))")
85+ # @assert y[1:end-2] == output_size(cdims) DimensionMismatch("Data output spatial size ($(y[1:end-2]) vs. $(output_size(cdims)))")
86+ # @assert w[1:end-2] == kernel_size(cdims) DimensionMismatch("Kernel spatial size ($(w[1:end-2]) vs. $(kernel_size(cdims)))")
87+ #
88+ # # Finally, check that the batch size matches
89+ # @assert x[end] == y[end] DimensionMismatch("Batch size ($(x[end]) vs. $(y[end]))")
90+ # end
0 commit comments