Skip to content

Commit 4fa28d4

Browse files
committed
Make pretrain condition explicit
1 parent de079bc commit 4fa28d4

File tree

6 files changed

+33
-10
lines changed

6 files changed

+33
-10
lines changed

src/convnets/alexnet.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,9 @@ end
4949

5050
function AlexNet(; pretrain = false, nclasses = 1000)
5151
layers = alexnet(; nclasses = nclasses)
52-
pretrain && loadpretrain!(layers, "AlexNet")
52+
if pretrain
53+
loadpretrain!(layers, "AlexNet")
54+
end
5355
return AlexNet(layers)
5456
end
5557

src/convnets/densenet.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,8 @@ See also [`Metalhead.densenet`](#).
162162
function DenseNet(config::Integer = 121; pretrain = false, nclasses = 1000)
163163
@assert config in keys(densenet_config) "`config` must be one out of $(sort(collect(keys(densenet_config))))."
164164
model = DenseNet(densenet_config[config]; nclasses = nclasses)
165-
pretrain && loadpretrain!(model, string("DenseNet", config))
165+
if pretrain
166+
loadpretrain!(model, string("DenseNet", config))
167+
end
166168
return model
167169
end

src/convnets/googlenet.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,9 @@ end
8686

8787
function GoogLeNet(; pretrain = false, nclasses = 1000)
8888
layers = googlenet(; nclasses = nclasses)
89-
pretrain && loadpretrain!(layers, "GoogLeNet")
89+
if pretrain
90+
loadpretrain!(layers, "GoogLeNet")
91+
end
9092
return GoogLeNet(layers)
9193
end
9294

src/convnets/inception.jl

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,9 @@ end
182182

183183
function Inceptionv3(; pretrain = false, nclasses = 1000)
184184
layers = inceptionv3(; nclasses = nclasses)
185-
pretrain && loadpretrain!(layers, "Inceptionv3")
185+
if pretrain
186+
loadpretrain!(layers, "Inceptionv3")
187+
end
186188
return Inceptionv3(layers)
187189
end
188190

@@ -341,7 +343,9 @@ end
341343

342344
function Inceptionv4(; pretrain = false, inchannels = 3, drop_rate = 0.0, nclasses = 1000)
343345
layers = inceptionv4(; inchannels, drop_rate, nclasses)
344-
pretrain && loadpretrain!(layers, "Inceptionv4")
346+
if pretrain
347+
loadpretrain!(layers, "Inceptionv4")
348+
end
345349
return Inceptionv4(layers)
346350
end
347351

@@ -476,7 +480,9 @@ end
476480
function InceptionResNetv2(; pretrain = false, inchannels = 3, drop_rate = 0.0,
477481
nclasses = 1000)
478482
layers = inceptionresnetv2(; inchannels, drop_rate, nclasses)
479-
pretrain && loadpretrain!(layers, "InceptionResNetv2")
483+
if pretrain
484+
loadpretrain!(layers, "InceptionResNetv2")
485+
end
480486
return InceptionResNetv2(layers)
481487
end
482488

@@ -584,7 +590,9 @@ Creates an Xception model.
584590
"""
585591
function Xception(; pretrain = false, inchannels = 3, drop_rate = 0.0, nclasses = 1000)
586592
layers = xception(; inchannels, drop_rate, nclasses)
587-
pretrain && loadpretrain!(layers, "xception")
593+
if pretrain
594+
loadpretrain!(layers, "xception")
595+
end
588596
return Xception(layers)
589597
end
590598

src/convnets/mobilenet.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,9 @@ end
9090
function MobileNetv1(width_mult::Number = 1; inchannels = 3, pretrain = false,
9191
nclasses = 1000)
9292
layers = mobilenetv1(width_mult, mobilenetv1_configs; inchannels, nclasses)
93-
pretrain && loadpretrain!(layers, string("MobileNetv1"))
93+
if pretrain
94+
loadpretrain!(layers, string("MobileNetv1"))
95+
end
9496
return MobileNetv1(layers)
9597
end
9698

@@ -189,6 +191,9 @@ function MobileNetv2(width_mult::Number = 1; inchannels = 3, pretrain = false,
189191
nclasses = 1000)
190192
layers = mobilenetv2(width_mult, mobilenetv2_configs; inchannels, nclasses)
191193
pretrain && loadpretrain!(layers, string("MobileNetv2"))
194+
if pretrain
195+
loadpretrain!(layers, string("MobileNetv2"))
196+
end
192197
return MobileNetv2(layers)
193198
end
194199

@@ -319,7 +324,9 @@ function MobileNetv3(mode::Symbol = :small, width_mult::Number = 1; inchannels =
319324
max_width = (mode == :large) ? 1280 : 1024
320325
layers = mobilenetv3(width_mult, mobilenetv3_configs[mode]; inchannels, max_width,
321326
nclasses)
322-
pretrain && loadpretrain!(layers, string("MobileNetv3", mode))
327+
if pretrain
328+
loadpretrain!(layers, string("MobileNetv3", mode))
329+
end
323330
return MobileNetv3(layers)
324331
end
325332

src/convnets/squeezenet.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,9 @@ end
6868

6969
function SqueezeNet(; pretrain = false)
7070
layers = squeezenet()
71-
pretrain && loadpretrain!(layers, "SqueezeNet")
71+
if pretrain
72+
loadpretrain!(layers, "SqueezeNet")
73+
end
7274
return SqueezeNet(layers)
7375
end
7476

0 commit comments

Comments
 (0)