Skip to content

Commit 4aae8c2

Browse files
authored
Merge pull request #208 from FluxML/image-builder
Add preliminary Metalhead.jl integration
2 parents 14b39b1 + 4c8de77 commit 4aae8c2

16 files changed

+344
-138
lines changed

Project.toml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ ColorTypes = "3da002f7-5984-5a60-b8a6-cbb66c0b333f"
99
ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3"
1010
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
1111
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
12+
Metalhead = "dbeba491-748d-5e0e-a39e-b530a07fa0cc"
1213
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
1314
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1415
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
@@ -18,15 +19,15 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
1819
CategoricalArrays = "0.10"
1920
ColorTypes = "0.10.3, 0.11"
2021
ComputationalResources = "0.3.2"
21-
Flux = "0.10.4, 0.11, 0.12, 0.13"
22+
Flux = "0.13"
23+
Metalhead = "0.7"
2224
MLJModelInterface = "1.1.1"
2325
ProgressMeter = "1.7.1"
2426
Tables = "1.0"
2527
julia = "1.6"
2628

2729
[extras]
2830
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
29-
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
3031
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
3132
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
3233
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
@@ -35,4 +36,4 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
3536
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3637

3738
[targets]
38-
test = ["LinearAlgebra", "MLDatasets", "MLJBase", "Random", "StableRNGs", "Statistics", "StatsBase", "Test"]
39+
test = ["LinearAlgebra", "MLJBase", "Random", "StableRNGs", "Statistics", "StatsBase", "Test"]

src/MLJFlux.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,15 @@ using Statistics
1313
using ColorTypes
1414
using ComputationalResources
1515
using Random
16+
import Metalhead
1617

18+
include("utilities.jl")
1719
const MMI=MLJModelInterface
1820

1921
include("penalizers.jl")
2022
include("core.jl")
2123
include("builders.jl")
24+
include("metalhead.jl")
2225
include("types.jl")
2326
include("regressor.jl")
2427
include("classifier.jl")
@@ -27,6 +30,7 @@ include("mlj_model_interface.jl")
2730

2831
export NeuralNetworkRegressor, MultitargetNeuralNetworkRegressor
2932
export NeuralNetworkClassifier, ImageClassifier
33+
export CUDALibs, CPU1
3034

3135

3236

src/builders.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
## BUILDING CHAINS A FROM HYPERPARAMETERS + INPUT/OUTPUT SHAPE
1+
# # BUILDING CHAINS A FROM HYPERPARAMETERS + INPUT/OUTPUT SHAPE
22

33
# We introduce chain builders as a way of exposing neural network
44
# hyperparameters (describing, architecture, dropout rates, etc) to
@@ -9,7 +9,7 @@
99
# input/output dimensions/shape.
1010

1111
# Below n or (n1, n2) etc refers to network inputs, while m or (m1,
12-
# m2) etc refers to outputs.
12+
# m2) etc refers to outputs.
1313

1414
abstract type Builder <: MLJModelInterface.MLJType end
1515

@@ -38,7 +38,7 @@ using `n_hidden` nodes in the hidden layer and the specified `dropout`
3838
(defaulting to 0.5). An activation function `σ` is applied between the
3939
hidden and final layers. If `n_hidden=0` (the default) then `n_hidden`
4040
is the geometric mean of the number of input and output nodes. The
41-
number of input and output nodes is determined from the data.
41+
number of input and output nodes is determined from the data.
4242
4343
The each layer is initialized using `Flux.glorot_uniform(rng)`. If
4444
`rng` is an integer, it is instead used as the seed for a
@@ -96,6 +96,8 @@ function MLJFlux.build(mlp::MLP, rng, n_in, n_out)
9696
end
9797

9898

99+
# # BUILER MACRO
100+
99101
struct GenericBuilder{F} <: Builder
100102
apply::F
101103
end

src/core.jl

Lines changed: 3 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,8 @@
11
## EXPOSE OPTIMISERS TO MLJ (for eg, tuning)
22

3-
# Here we make the optimiser structs "transparent" so that their
4-
# field values are exposed by calls to MLJ.params
5-
6-
for opt in (:Descent,
7-
:Momentum,
8-
:Nesterov,
9-
:RMSProp,
10-
:ADAM,
11-
:RADAM,
12-
:AdaMax,
13-
:OADAM,
14-
:ADAGrad,
15-
:ADADelta,
16-
:AMSGrad,
17-
:NADAM,
18-
:AdaBelief,
19-
:Optimiser,
20-
:InvDecay, :ExpDecay, :WeightDecay,
21-
:ClipValue,
22-
:ClipNorm) # last updated: Flux.jl 0.12.3
23-
24-
@eval begin
25-
MLJModelInterface.istransparent(m::Flux.$opt) = true
26-
end
27-
end
3+
# make the optimiser structs "transparent" so that their field values
4+
# are exposed by calls to MLJ.params:
5+
MLJModelInterface.istransparent(m::Flux.Optimise.AbstractOptimiser) = true
286

297

308
## GENERAL METHOD TO OPTIMIZE A CHAIN

src/metalhead.jl

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
#=
2+
3+
TODO: After https://github.com/FluxML/Metalhead.jl/issues/176:
4+
5+
- Export and externally document `image_builder` method
6+
7+
- Delete definition of `ResNetHack` below
8+
9+
- Change default builder in ImageClassifier (see /src/types.jl) from
10+
`image_builder(ResNetHack)` to `image_builder(Metalhead.ResNet)`.
11+
12+
=#
13+
14+
const DISALLOWED_KWARGS = [:imsize, :inchannels, :nclasses]
15+
const human_disallowed_kwargs = join(map(s->"`$s`", DISALLOWED_KWARGS), ", ", " and ")
16+
const ERR_METALHEAD_DISALLOWED_KWARGS = ArgumentError(
17+
"Keyword arguments $human_disallowed_kwargs are disallowed "*
18+
"as their values are inferred from data. "
19+
)
20+
21+
# # WRAPPING
22+
23+
struct MetalheadBuilder{F} <: MLJFlux.Builder
24+
metalhead_constructor::F
25+
args
26+
kwargs
27+
end
28+
29+
function Base.show(io::IO, ::MIME"text/plain", w::MetalheadBuilder)
30+
println(io, "builder wrapping $(w.metalhead_constructor)")
31+
if !isempty(w.args)
32+
println(io, " args:")
33+
for (i, arg) in enumerate(w.args)
34+
println(io, " 1: $arg")
35+
end
36+
end
37+
if !isempty(w.kwargs)
38+
println(io, " kwargs:")
39+
for kwarg in w.kwargs
40+
println(io, " $(first(kwarg)) = $(last(kwarg))")
41+
end
42+
end
43+
end
44+
45+
Base.show(io::IO, w::MetalheadBuilder) =
46+
print(io, "image_builder($(repr(w.metalhead_constructor)), …)")
47+
48+
49+
"""
50+
image_builder(metalhead_constructor, args...; kwargs...)
51+
52+
Return an MLJFlux builder object based on the Metalhead.jl constructor/type
53+
`metalhead_constructor` (eg, `Metalhead.ResNet`). Here `args` and `kwargs` are
54+
passed to the `MetalheadType` constructor at "build time", along with
55+
the extra keyword specifiers `imsize=...`, `inchannels=...` and
56+
`nclasses=...`, with values inferred from the data.
57+
58+
# Example
59+
60+
If in Metalhead.jl you would do
61+
62+
```julia
63+
using Metalhead
64+
model = ResNet(50, pretrain=true, inchannels=1, nclasses=10)
65+
```
66+
67+
then in MLJFlux, it suffices to do
68+
69+
```julia
70+
using MLJFlux, Metalhead
71+
builder = image_builder(ResNet, 50, pretrain=true)
72+
```
73+
74+
which can be used in `ImageClassifier` as in
75+
76+
```julia
77+
clf = ImageClassifier(
78+
builder=builder,
79+
epochs=500,
80+
optimiser=Flux.Adam(0.001),
81+
loss=Flux.crossentropy,
82+
batch_size=5,
83+
)
84+
```
85+
86+
The keyord arguments `imsize`, `inchannels` and `nclasses` are
87+
dissallowed in `kwargs` (see above).
88+
89+
"""
90+
function image_builder(
91+
metalhead_constructor,
92+
args...;
93+
kwargs...
94+
)
95+
kw_names = keys(kwargs)
96+
isempty(intersect(kw_names, DISALLOWED_KWARGS)) ||
97+
throw(ERR_METALHEAD_DISALLOWED_KWARGS)
98+
return MetalheadBuilder(metalhead_constructor, args, kwargs)
99+
end
100+
101+
MLJFlux.build(
102+
b::MetalheadBuilder,
103+
rng,
104+
n_in,
105+
n_out,
106+
n_channels
107+
) = b.metalhead_constructor(
108+
b.args...;
109+
b.kwargs...,
110+
imsize=n_in,
111+
inchannels=n_channels,
112+
nclasses=n_out
113+
)
114+
115+
# See above "TODO" list.
116+
function VGGHack(
117+
depth::Integer=16;
118+
imsize=(242,242),
119+
inchannels=3,
120+
nclasses=1000,
121+
batchnorm=false,
122+
pretrain=false,
123+
)
124+
125+
# Adapted from
126+
# https://github.com/FluxML/Metalhead.jl/blob/9edff63222720ff84671b8087dd71eb370a6c35a/src/convnets/vgg.jl#L165
127+
# But we do not ignore `imsize`.
128+
129+
@assert(
130+
depth in keys(Metalhead.vgg_config),
131+
"depth must be from one in $(sort(collect(keys(Metalhead.vgg_config))))"
132+
)
133+
model = Metalhead.VGG(imsize;
134+
config = Metalhead.vgg_conv_config[Metalhead.vgg_config[depth]],
135+
inchannels,
136+
batchnorm,
137+
nclasses,
138+
fcsize = 4096,
139+
dropout = 0.5)
140+
if pretrain && !batchnorm
141+
Metalhead.loadpretrain!(model, string("VGG", depth))
142+
elseif pretrain
143+
Metalhead.loadpretrain!(model, "VGG$(depth)-BN)")
144+
end
145+
return model
146+
end

src/mlj_model_interface.jl

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@ end
4040

4141
# # FIT AND UPDATE
4242

43+
const ERR_BUILDER =
44+
"Builder does not appear to build an architecture compatible with supplied data. "
45+
4346
true_rng(model) = model.rng isa Integer ? MersenneTwister(model.rng) : model.rng
4447

4548
function MLJModelInterface.fit(model::MLJFluxModel,
@@ -51,10 +54,24 @@ function MLJModelInterface.fit(model::MLJFluxModel,
5154

5255
rng = true_rng(model)
5356
shape = MLJFlux.shape(model, X, y)
54-
chain = build(model, rng, shape) |> move
57+
58+
chain = try
59+
build(model, rng, shape) |> move
60+
catch ex
61+
@error ERR_BUILDER
62+
end
63+
5564
penalty = Penalty(model)
5665
data = move.(collate(model, X, y))
5766

67+
x = data |> first |> first
68+
try
69+
chain(x)
70+
catch ex
71+
@error ERR_BUILDER
72+
throw(ex)
73+
end
74+
5875
optimiser = deepcopy(model.optimiser)
5976

6077
chain, history = fit!(model.loss,

src/types.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@ const MLJFluxModel = Union{MLJFluxProbabilistic,MLJFluxDeterministic}
55

66
for Model in [:NeuralNetworkClassifier, :ImageClassifier]
77

8+
default_builder_ex =
9+
Model == :ImageClassifier ? :(image_builder(VGGHack)) : Short()
10+
811
ex = quote
912
mutable struct $Model{B,F,O,L} <: MLJFluxProbabilistic
1013
builder::B
@@ -20,7 +23,7 @@ for Model in [:NeuralNetworkClassifier, :ImageClassifier]
2023
acceleration::AbstractResource # eg, `CPU1()` or `CUDALibs()`
2124
end
2225

23-
function $Model(; builder::B = Short()
26+
function $Model(; builder::B = $default_builder_ex
2427
, finaliser::F = Flux.softmax
2528
, optimiser::O = Flux.Optimise.Adam()
2629
, loss::L = Flux.crossentropy
@@ -108,12 +111,9 @@ for Model in [:NeuralNetworkRegressor, :MultitargetNeuralNetworkRegressor]
108111

109112
end
110113

111-
112-
113114
const Regressor =
114115
Union{NeuralNetworkRegressor, MultitargetNeuralNetworkRegressor}
115116

116-
117117
MMI.metadata_pkg.(
118118
(
119119
NeuralNetworkRegressor,

src/utilities.jl

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# # IMAGE COERCION
2+
3+
# Taken from ScientificTypes.jl to avoid as dependency.
4+
5+
_4Dcollection = AbstractArray{<:Real, 4}
6+
7+
function coerce(y::_4Dcollection, T2::Type{GrayImage})
8+
size(y, 3) == 1 || error("Multiple color channels encountered. "*
9+
"Perhaps you want to use `coerce(image_collection, ColorImage)`.")
10+
y = dropdims(y, dims=3)
11+
return [ColorTypes.Gray.(y[:,:,idx]) for idx=1:size(y,3)]
12+
end
13+
14+
function coerce(y::_4Dcollection, T2::Type{ColorImage})
15+
return [broadcast(ColorTypes.RGB, y[:,:,1, idx], y[:,:,2,idx], y[:,:,3, idx]) for idx=1:size(y,4)]
16+
end
17+
18+
19+
# # SYNTHETIC IMAGES
20+
21+
"""
22+
make_images(rng; image_size=(6, 6), n_classes=33, n_images=50, color=false, noise=0.05)
23+
24+
Return synthetic data of the form `(images, labels)` suitable for use
25+
with MLJ's `ImageClassifier` model. All `images` are distortions of
26+
`n_classes` fixed images. Two images with the same label correspond to
27+
the same undistorted image.
28+
29+
"""
30+
function make_images(rng; image_size=(6, 6), n_classes=33, n_images=50, color=false, noise=0.05)
31+
n_channels = color ? 3 : 1
32+
image_bag = map(1:n_classes) do _
33+
rand(rng, Float32, image_size..., n_channels)
34+
end
35+
labels = rand(rng, 1:3, n_images)
36+
images = map(labels) do j
37+
image_bag[j] + noise*rand(rng, Float32, image_size..., n_channels)
38+
end
39+
T = color ? ColorImage : GrayImage
40+
X = coerce(cat(images...; dims=4), T)
41+
y = categorical(labels)
42+
return X, y
43+
end

test/builders.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,11 @@ end
5252
end
5353

5454
@testset_accelerated "@builder" accel begin
55-
builder = MLJFlux.@builder(Flux.Chain(Flux.Dense(n_in, 4,
56-
init = (out, in) -> randn(rng, out, in)),
57-
Flux.Dense(4, n_out)))
55+
builder = MLJFlux.@builder(Flux.Chain(Flux.Dense(
56+
n_in,
57+
4,
58+
init = (out, in) -> randn(rng, out, in)
59+
), Flux.Dense(4, n_out)))
5860
rng = StableRNGs.StableRNG(123)
5961
chain = MLJFlux.build(builder, rng, 5, 3)
6062
ps = Flux.params(chain)

0 commit comments

Comments
 (0)