Skip to content

Commit 2fe393e

Browse files
committed
Restore type stability of conv_transpose_dims
1 parent df468ba commit 2fe393e

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

src/layers/conv.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -313,8 +313,9 @@ end
313313

314314
function conv_transpose_dims(c::ConvTranspose, x::AbstractArray)
315315
# Calculate size of "input", from ∇conv_data()'s perspective...
316-
combined_pad = (c.pad[1:2:end] .+ c.pad[2:2:end])
317-
I = (size(x)[1:end-2] .- 1).*c.stride .+ 1 .+ (size(c.weight)[1:end-2] .- 1).*c.dilation .- combined_pad
316+
calc_dim(xsz, wsz, stride, dilation, pad) = (xsz - 1) * stride + 1 + (wsz - 1) * dilation - pad
317+
combined_pad = ntuple(i -> c.pad[2i-1] + c.pad[2i], length(c.pad) ÷ 2)
318+
I = map(calc_dim, size(x)[1:end-2], size(c.weight)[1:end-2], c.stride, c.dilation, combined_pad)
318319
C_in = size(c.weight)[end-1] * c.groups
319320
batch_size = size(x)[end]
320321
# Create DenseConvDims() that looks like the corresponding conv()

0 commit comments

Comments
 (0)