Skip to content

Commit 3bfb345

Browse files
Fix vgg16
1 parent 1ce5893 commit 3bfb345

File tree

2 files changed

+11
-13
lines changed

2 files changed

+11
-13
lines changed

src/Torch/Compose/Models.hs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,11 @@ import Data.HList
2323

2424
vgg16Spec numClass =
2525
let maxPool2dSpec = MaxPool2dSpec
26-
{ kernelSize = (3,3)
26+
{ kernelSize = (2,2)
2727
, stride = (2,2)
28-
, padding = (1,1)
29-
, dilation = (0,0)
30-
, ceilMode = Ceil
28+
, padding = (0,0)
29+
, dilation = (1,1)
30+
, ceilMode = Floor
3131
}
3232
vggClassifierSpec =
3333
LinearSpec (512 * 7 * 7) 4096 .*.
@@ -39,7 +39,7 @@ vgg16Spec numClass =
3939
LinearSpec 4096 numClass .*.
4040
HNil
4141
conv2dSpec inChannel outChannel kernelHeight kernelWidth =
42-
Conv2dSpec' inChannel outChannel kernelHeight kernelWidth (1,1) (0,0)
42+
Conv2dSpec' inChannel outChannel kernelHeight kernelWidth (1,1) (1,1)
4343
in
4444
conv2dSpec 3 64 3 3 .*.
4545
conv2dSpec 64 64 3 3 .*.
@@ -55,7 +55,7 @@ vgg16Spec numClass =
5555
conv2dSpec 512 512 3 3 .*.
5656
conv2dSpec 512 512 3 3 .*.
5757
AdaptiveAvgPool2dSpec (7,7) .*.
58-
ReshapeSpec [1,512*7*7] .*.
58+
ReshapeSpec [-1,512*7*7] .*.
5959
vggClassifierSpec
6060

6161
-- resnetSpec numClass =

test/Spec.hs

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -106,12 +106,10 @@ main = hspec $ do
106106
model <- sample (vgg16Spec 10)
107107
let input = ones' [2,3,128,128]
108108
outputShapes = toMaybeOutputShapes model input
109-
-- output = forward model input
110109
exp =
111-
Nothing .*. Nothing .*. Nothing .*. Nothing .*. Nothing .*.
112-
Nothing .*. Nothing .*. Nothing .*. Nothing .*. Nothing .*.
113-
Nothing .*. Nothing .*. Nothing .*. Nothing .*. Nothing .*.
114-
Nothing .*. Nothing .*. Nothing .*. Nothing .*. Nothing .*.
115-
Nothing .*. Nothing .*. HNil
116-
110+
Just [2,64,128,128] .*. Just [2,64,128,128] .*. Just [2,64,64,64] .*. Just [2,128,64,64] .*. Just [2,128,64,64] .*.
111+
Just [2,128,32,32] .*. Just [2,256,32,32] .*. Just [2,256,32,32] .*. Just [2,256,32,32] .*. Just [2,256,16,16] .*.
112+
Just [2,512,16,16] .*. Just [2,512,16,16] .*. Just [2,512,16,16] .*. Just [2,512,7,7] .*. Just [2,25088] .*.
113+
Just [2,4096] .*. Just [2,4096] .*. Just [2,4096] .*. Just [2,4096] .*. Just [2,4096] .*.
114+
Just [2,4096] .*. Just [2,10] .*. HNil
117115
outputShapes `shouldBe` exp

0 commit comments

Comments
 (0)