diff --git a/Project.toml b/Project.toml index a2f70565..6ce654f4 100644 --- a/Project.toml +++ b/Project.toml @@ -9,6 +9,7 @@ ColorTypes = "3da002f7-5984-5a60-b8a6-cbb66c0b333f" ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" +Metalhead = "dbeba491-748d-5e0e-a39e-b530a07fa0cc" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" @@ -18,7 +19,8 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" CategoricalArrays = "0.10" ColorTypes = "0.10.3, 0.11" ComputationalResources = "0.3.2" -Flux = "0.10.4, 0.11, 0.12, 0.13" +Flux = "0.13" +Metalhead = "0.7" MLJModelInterface = "1.1.1" ProgressMeter = "1.7.1" Tables = "1.0" @@ -26,7 +28,6 @@ julia = "1.6" [extras] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" @@ -35,4 +36,4 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["LinearAlgebra", "MLDatasets", "MLJBase", "Random", "StableRNGs", "Statistics", "StatsBase", "Test"] +test = ["LinearAlgebra", "MLJBase", "Random", "StableRNGs", "Statistics", "StatsBase", "Test"] diff --git a/src/MLJFlux.jl b/src/MLJFlux.jl index 0c1b84f1..f46f88ef 100644 --- a/src/MLJFlux.jl +++ b/src/MLJFlux.jl @@ -13,12 +13,15 @@ using Statistics using ColorTypes using ComputationalResources using Random +import Metalhead +include("utilities.jl") const MMI=MLJModelInterface include("penalizers.jl") include("core.jl") include("builders.jl") +include("metalhead.jl") include("types.jl") include("regressor.jl") include("classifier.jl") @@ -27,6 +30,7 @@ include("mlj_model_interface.jl") export NeuralNetworkRegressor, MultitargetNeuralNetworkRegressor export NeuralNetworkClassifier, ImageClassifier +export CUDALibs, CPU1 diff --git a/src/builders.jl b/src/builders.jl index 2c417c20..b106058a 100644 --- a/src/builders.jl +++ b/src/builders.jl @@ -1,4 +1,4 @@ -## BUILDING CHAINS A FROM HYPERPARAMETERS + INPUT/OUTPUT SHAPE +# # BUILDING CHAINS FROM HYPERPARAMETERS + INPUT/OUTPUT SHAPE # We introduce chain builders as a way of exposing neural network # hyperparameters (describing, architecture, dropout rates, etc) to @@ -9,7 +9,7 @@ # input/output dimensions/shape. # Below n or (n1, n2) etc refers to network inputs, while m or (m1, -# m2) etc refers to outputs. +# m2) etc refers to outputs. abstract type Builder <: MLJModelInterface.MLJType end @@ -38,7 +38,7 @@ using `n_hidden` nodes in the hidden layer and the specified `dropout` (defaulting to 0.5). An activation function `σ` is applied between the hidden and final layers. If `n_hidden=0` (the default) then `n_hidden` is the geometric mean of the number of input and output nodes. The -number of input and output nodes is determined from the data. +number of input and output nodes is determined from the data. Each layer is initialized using `Flux.glorot_uniform(rng)`.
If `rng` is an integer, it is instead used as the seed for a @@ -96,6 +96,8 @@ function MLJFlux.build(mlp::MLP, rng, n_in, n_out) end +# # BUILDER MACRO + struct GenericBuilder{F} <: Builder apply::F end diff --git a/src/core.jl b/src/core.jl index de2a982d..d94d9f22 100644 --- a/src/core.jl +++ b/src/core.jl @@ -1,30 +1,8 @@ ## EXPOSE OPTIMISERS TO MLJ (for eg, tuning) -# Here we make the optimiser structs "transparent" so that their -# field values are exposed by calls to MLJ.params - -for opt in (:Descent, - :Momentum, - :Nesterov, - :RMSProp, - :ADAM, - :RADAM, - :AdaMax, - :OADAM, - :ADAGrad, - :ADADelta, - :AMSGrad, - :NADAM, - :AdaBelief, - :Optimiser, - :InvDecay, :ExpDecay, :WeightDecay, - :ClipValue, - :ClipNorm) # last updated: Flux.jl 0.12.3 - - @eval begin - MLJModelInterface.istransparent(m::Flux.$opt) = true - end -end +# make the optimiser structs "transparent" so that their field values +# are exposed by calls to MLJ.params: +MLJModelInterface.istransparent(m::Flux.Optimise.AbstractOptimiser) = true ## GENERAL METHOD TO OPTIMIZE A CHAIN diff --git a/src/metalhead.jl b/src/metalhead.jl new file mode 100644 index 00000000..48d2eaa0 --- /dev/null +++ b/src/metalhead.jl @@ -0,0 +1,146 @@ +#= + +TODO: After https://github.com/FluxML/Metalhead.jl/issues/176: + +- Export and externally document `image_builder` method + +- Delete definition of `VGGHack` below + +- Change default builder in ImageClassifier (see /src/types.jl) from + `image_builder(VGGHack)` to `image_builder(Metalhead.ResNet)`. + +=# + +const DISALLOWED_KWARGS = [:imsize, :inchannels, :nclasses] +const human_disallowed_kwargs = join(map(s->"`$s`", DISALLOWED_KWARGS), ", ", " and ") +const ERR_METALHEAD_DISALLOWED_KWARGS = ArgumentError( + "Keyword arguments $human_disallowed_kwargs are disallowed "* + "as their values are inferred from data. " +) + +# # WRAPPING + +struct MetalheadBuilder{F} <: MLJFlux.Builder + metalhead_constructor::F + args + kwargs +end + +function Base.show(io::IO, ::MIME"text/plain", w::MetalheadBuilder) + println(io, "builder wrapping $(w.metalhead_constructor)") + if !isempty(w.args) + println(io, " args:") + for (i, arg) in enumerate(w.args) + println(io, " $i: $arg") + end + end + if !isempty(w.kwargs) + println(io, " kwargs:") + for kwarg in w.kwargs + println(io, " $(first(kwarg)) = $(last(kwarg))") + end + end +end + +Base.show(io::IO, w::MetalheadBuilder) = + print(io, "image_builder($(repr(w.metalhead_constructor)), …)") + + +""" + image_builder(metalhead_constructor, args...; kwargs...) + +Return an MLJFlux builder object based on the Metalhead.jl constructor/type +`metalhead_constructor` (eg, `Metalhead.ResNet`). Here `args` and `kwargs` are +passed to the `metalhead_constructor` at "build time", along with +the extra keyword specifiers `imsize=...`, `inchannels=...` and +`nclasses=...`, with values inferred from the data. + +# Example + +If in Metalhead.jl you would do + +```julia +using Metalhead +model = ResNet(50, pretrain=true, inchannels=1, nclasses=10) +``` + +then in MLJFlux, it suffices to do + +```julia +using MLJFlux, Metalhead +builder = image_builder(ResNet, 50, pretrain=true) +``` + +which can be used in `ImageClassifier` as in + +```julia +clf = ImageClassifier( + builder=builder, + epochs=500, + optimiser=Flux.Adam(0.001), + loss=Flux.crossentropy, + batch_size=5, +) +``` + +The keyword arguments `imsize`, `inchannels` and `nclasses` are +disallowed in `kwargs` (see above). + +""" +function image_builder( + metalhead_constructor, + args...; + kwargs...
+) + kw_names = keys(kwargs) + isempty(intersect(kw_names, DISALLOWED_KWARGS)) || + throw(ERR_METALHEAD_DISALLOWED_KWARGS) + return MetalheadBuilder(metalhead_constructor, args, kwargs) +end + +MLJFlux.build( + b::MetalheadBuilder, + rng, + n_in, + n_out, + n_channels +) = b.metalhead_constructor( + b.args...; + b.kwargs..., + imsize=n_in, + inchannels=n_channels, + nclasses=n_out +) + +# See above "TODO" list. +function VGGHack( + depth::Integer=16; + imsize=(242,242), + inchannels=3, + nclasses=1000, + batchnorm=false, + pretrain=false, +) + + # Adapted from + # https://github.com/FluxML/Metalhead.jl/blob/9edff63222720ff84671b8087dd71eb370a6c35a/src/convnets/vgg.jl#L165 + # But we do not ignore `imsize`. + + @assert( + depth in keys(Metalhead.vgg_config), + "depth must be from one in $(sort(collect(keys(Metalhead.vgg_config))))" + ) + model = Metalhead.VGG(imsize; + config = Metalhead.vgg_conv_config[Metalhead.vgg_config[depth]], + inchannels, + batchnorm, + nclasses, + fcsize = 4096, + dropout = 0.5) + if pretrain && !batchnorm + Metalhead.loadpretrain!(model, string("VGG", depth)) + elseif pretrain + Metalhead.loadpretrain!(model, "VGG$(depth)-BN") + end + return model +end diff --git a/src/mlj_model_interface.jl b/src/mlj_model_interface.jl index bfac2987..1f2e09d4 100644 --- a/src/mlj_model_interface.jl +++ b/src/mlj_model_interface.jl @@ -40,6 +40,9 @@ end # # FIT AND UPDATE +const ERR_BUILDER = + "Builder does not appear to build an architecture compatible with supplied data. " + true_rng(model) = model.rng isa Integer ? MersenneTwister(model.rng) : model.rng function MLJModelInterface.fit(model::MLJFluxModel, @@ -51,10 +54,25 @@ function MLJModelInterface.fit(model::MLJFluxModel, rng = true_rng(model) shape = MLJFlux.shape(model, X, y) - chain = build(model, rng, shape) |> move + + chain = try + build(model, rng, shape) |> move + catch ex + @error ERR_BUILDER + throw(ex) + end + penalty = Penalty(model) data = move.(collate(model, X, y)) + x = data |> first |> first + try + chain(x) + catch ex + @error ERR_BUILDER + throw(ex) + end + optimiser = deepcopy(model.optimiser) chain, history = fit!(model.loss, diff --git a/src/types.jl b/src/types.jl index 3df93bc7..2e7958ef 100644 --- a/src/types.jl +++ b/src/types.jl @@ -5,6 +5,9 @@ const MLJFluxModel = Union{MLJFluxProbabilistic,MLJFluxDeterministic} for Model in [:NeuralNetworkClassifier, :ImageClassifier] + default_builder_ex = + Model == :ImageClassifier ? :(image_builder(VGGHack)) : Short() + ex = quote mutable struct $Model{B,F,O,L} <: MLJFluxProbabilistic builder::B @@ -20,7 +23,7 @@ for Model in [:NeuralNetworkClassifier, :ImageClassifier] acceleration::AbstractResource # eg, `CPU1()` or `CUDALibs()` end - function $Model(; builder::B = Short() + function $Model(; builder::B = $default_builder_ex , finaliser::F = Flux.softmax , optimiser::O = Flux.Optimise.Adam() , loss::L = Flux.crossentropy @@ -108,12 +111,9 @@ for Model in [:NeuralNetworkRegressor, :MultitargetNeuralNetworkRegressor] end - - const Regressor = Union{NeuralNetworkRegressor, MultitargetNeuralNetworkRegressor} - MMI.metadata_pkg.( ( NeuralNetworkRegressor, diff --git a/src/utilities.jl b/src/utilities.jl new file mode 100644 index 00000000..815cef31 --- /dev/null +++ b/src/utilities.jl @@ -0,0 +1,43 @@ +# # IMAGE COERCION + +# Taken from ScientificTypes.jl to avoid adding it as a dependency. + +_4Dcollection = AbstractArray{<:Real, 4} + +function coerce(y::_4Dcollection, T2::Type{GrayImage}) + size(y, 3) == 1 || error("Multiple color channels encountered. 
"* + "Perhaps you want to use `coerce(image_collection, ColorImage)`.") + y = dropdims(y, dims=3) + return [ColorTypes.Gray.(y[:,:,idx]) for idx=1:size(y,3)] +end + +function coerce(y::_4Dcollection, T2::Type{ColorImage}) + return [broadcast(ColorTypes.RGB, y[:,:,1, idx], y[:,:,2,idx], y[:,:,3, idx]) for idx=1:size(y,4)] +end + + +# # SYNTHETIC IMAGES + +""" + make_images(rng; image_size=(6, 6), n_classes=33, n_images=50, color=false, noise=0.05) + +Return synthetic data of the form `(images, labels)` suitable for use +with MLJ's `ImageClassifier` model. All `images` are distortions of +`n_classes` fixed images. Two images with the same label correspond to +the same undistorted image. + +""" +function make_images(rng; image_size=(6, 6), n_classes=33, n_images=50, color=false, noise=0.05) + n_channels = color ? 3 : 1 + image_bag = map(1:n_classes) do _ + rand(rng, Float32, image_size..., n_channels) + end + labels = rand(rng, 1:3, n_images) + images = map(labels) do j + image_bag[j] + noise*rand(rng, Float32, image_size..., n_channels) + end + T = color ? ColorImage : GrayImage + X = coerce(cat(images...; dims=4), T) + y = categorical(labels) + return X, y +end diff --git a/test/builders.jl b/test/builders.jl index 030cbfa0..8aafa862 100644 --- a/test/builders.jl +++ b/test/builders.jl @@ -52,9 +52,11 @@ end end @testset_accelerated "@builder" accel begin - builder = MLJFlux.@builder(Flux.Chain(Flux.Dense(n_in, 4, - init = (out, in) -> randn(rng, out, in)), - Flux.Dense(4, n_out))) + builder = MLJFlux.@builder(Flux.Chain(Flux.Dense( + n_in, + 4, + init = (out, in) -> randn(rng, out, in) + ), Flux.Dense(4, n_out))) rng = StableRNGs.StableRNG(123) chain = MLJFlux.build(builder, rng, 5, 3) ps = Flux.params(chain) diff --git a/test/classifier.jl b/test/classifier.jl index 135c3020..55bade43 100644 --- a/test/classifier.jl +++ b/test/classifier.jl @@ -19,7 +19,7 @@ end |> categorical; # TODO: replace Short2 -> Short when # https://github.com/FluxML/Flux.jl/issues/1372 is resolved: builder = Short2() -optimiser = Flux.Optimise.ADAM(0.03) +optimiser = Flux.Optimise.Adam(0.03) losses = [] diff --git a/test/core.jl b/test/core.jl index 75e03636..823ca16d 100644 --- a/test/core.jl +++ b/test/core.jl @@ -4,7 +4,7 @@ stable_rng = StableRNGs.StableRNG(123) rowvec(y) = y rowvec(y::Vector) = reshape(y, 1, length(y)) -@test MLJFlux.MLJModelInterface.istransparent(Flux.ADAM(0.1)) +@test MLJFlux.MLJModelInterface.istransparent(Flux.Adam(0.1)) @testset "nrows" begin Xmatrix = rand(stable_rng, 10, 3) @@ -112,7 +112,7 @@ epochs = 10 _chain_yes_drop, history = MLJFlux.fit!(model.loss, penalty, chain_yes_drop, - Flux.Optimise.ADAM(0.001), + Flux.Optimise.Adam(0.001), epochs, 0, data[1], @@ -124,7 +124,7 @@ epochs = 10 _chain_no_drop, history = MLJFlux.fit!(model.loss, penalty, chain_no_drop, - Flux.Optimise.ADAM(0.001), + Flux.Optimise.Adam(0.001), epochs, 0, data[1], diff --git a/test/image.jl b/test/image.jl index 1866b1ed..fd038472 100644 --- a/test/image.jl +++ b/test/image.jl @@ -1,4 +1,4 @@ -## BASIC IMAGE TESTS GREY +# # BASIC IMAGE TESTS GREY Random.seed!(123) stable_rng = StableRNGs.StableRNG(123) @@ -18,16 +18,9 @@ function MLJFlux.build(model::MyNeuralNetwork, rng, ip, op, n_channels) end builder = MyNeuralNetwork((2,2), (2,2)) - -# collection of gray images as a 4D array in WHCN format: -raw_images = rand(stable_rng, Float32, 6, 6, 1, 50); - -# as a vector of Matrix{<:AbstractRGB} -images = coerce(raw_images, GrayImage); -@test scitype(images) == AbstractVector{GrayImage{6,6}} -labels = 
categorical(rand(stable_rng, 1:5, 50)); - +images, labels = MLJFlux.make_images(stable_rng) losses = [] + @testset_accelerated "ImageClassifier basic tests" accel begin Random.seed!(123) @@ -74,76 +67,12 @@ reference = losses[1] @test all(x->abs(x - reference)/reference < 5e-4, losses[2:end]) -## MNIST IMAGES TEST - -mutable struct MyConvBuilder <: MLJFlux.Builder end - -using MLDatasets - -ENV["DATADEPS_ALWAYS_ACCEPT"] = true -images, labels = MNIST.traindata() -images = coerce(images, GrayImage); -labels = categorical(labels); - -function flatten(x::AbstractArray) - return reshape(x, :, size(x)[end]) -end - -function MLJFlux.build(builder::MyConvBuilder, rng, n_in, n_out, n_channels) - cnn_output_size = [3,3,32] - init = Flux.glorot_uniform(rng) - return Chain( - Conv((3, 3), n_channels=>16, pad=(1,1), relu, init=init), - MaxPool((2,2)), - Conv((3, 3), 16=>32, pad=(1,1), relu, init=init), - MaxPool((2,2)), - Conv((3, 3), 32=>32, pad=(1,1), relu, init=init), - MaxPool((2,2)), - flatten, - Dense(prod(cnn_output_size), n_out, init=init)) -end - -losses = [] - -@testset_accelerated "Image MNIST" accel begin - - Random.seed!(123) - stable_rng = StableRNGs.StableRNG(123) - - model = MLJFlux.ImageClassifier(builder=MyConvBuilder(), - acceleration=accel, - batch_size=50, - rng=stable_rng) - - @time fitresult, cache, _report = - MLJBase.fit(model, 0, images[1:500], labels[1:500]); - first_last_training_loss = _report[1][[1, end]] - push!(losses, first_last_training_loss[2]) -# @show first_last_training_loss +# # BASIC IMAGE TESTS COLOR - pred = mode.(MLJBase.predict(model, fitresult, images[501:600])); - error = misclassification_rate(pred, labels[501:600]) - @test error < 0.2 - -end - -# check different resources (CPU1, CUDALibs, etc)) give about the same loss: -reference = losses[1] -@info "Losses for each computational resource: $losses" -@test all(x->abs(x - reference)/reference < 0.05, losses[2:end]) - - -## BASIC IMAGE TESTS COLOR +# Here we again use the custom builder defined above (the default builder gets its own smoke test below) builder = MyNeuralNetwork((2,2), (2,2)) - -# collection of color images as a 4D array in WHCN format: -raw_images = rand(stable_rng, Float32, 6, 6, 3, 50); - -images = coerce(raw_images, ColorImage); -@test scitype(images) == AbstractVector{ColorImage{6,6}} -labels = categorical(rand(1:5, 50)); - +images, labels = MLJFlux.make_images(stable_rng, color=true) losses = [] @testset_accelerated "ColorImages" accel begin @@ -155,20 +84,18 @@ losses = [] epochs=10, acceleration=accel, rng=stable_rng) - # tests update logic, etc (see test_utililites.jl): @test basictest(MLJFlux.ImageClassifier, images, labels, model.builder, model.optimiser, 0.95, accel) - @time fitresult, cache, _report = MLJBase.fit(model, 0, images, labels) + @time fitresult, cache, _report = MLJBase.fit(model, 0, images, labels); pred = MLJBase.predict(model, fitresult, images[1:6]) first_last_training_loss = _report[1][[1, end]] push!(losses, first_last_training_loss[2]) -# @show first_last_training_loss # try with batch_size > 1: - model = MLJFlux.ImageClassifier(builder=builder, - epochs=10, + model = MLJFlux.ImageClassifier(epochs=10, + builder=builder, batch_size=2, acceleration=accel, rng=stable_rng) @@ -184,4 +111,18 @@ reference = losses[1] @info "Losses for each computational resource: $losses" @test all(x->abs(x - reference)/reference < 1e-5, losses[2:end]) + +# # SMOKE TEST FOR DEFAULT BUILDER + +images, labels = MLJFlux.make_images(stable_rng, image_size=(32, 32), n_images=12, noise=0.2, color=true); + +@testset_accelerated "ImageClassifier 
basic tests" accel begin + model = MLJFlux.ImageClassifier(epochs=10, + batch_size=4, + acceleration=accel, + rng=stable_rng) + fitresult, _, _ = MLJBase.fit(model, 0, images, labels); + predict(model, fitresult, images) +end + true diff --git a/test/metalhead.jl b/test/metalhead.jl new file mode 100644 index 00000000..4260ff78 --- /dev/null +++ b/test/metalhead.jl @@ -0,0 +1,61 @@ +using StableRNGs +using MLJFlux +const Metalhead = MLJFlux.Metalhead + +@testset "display" begin + io = IOBuffer() + builder = MLJFlux.image_builder(MLJFlux.Metalhead.ResNet, 50, pretrain=false) + show(io, MIME("text/plain"), builder) + @test String(take!(io)) == + "builder wrapping Metalhead.ResNet\n args:\n"* + " 1: 50\n kwargs:\n pretrain = false\n" + show(io, builder) + @test String(take!(io)) == "image_builder(Metalhead.ResNet, …)" + close(io) +end + +@testset "disallowed kwargs" begin + @test_throws( + MLJFlux.ERR_METALHEAD_DISALLOWED_KWARGS, + MLJFlux.image_builder(MLJFlux.Metalhead.VGG, imsize=(241, 241)), + ) + @test_throws( + MLJFlux.ERR_METALHEAD_DISALLOWED_KWARGS, + MLJFlux.image_builder(MLJFlux.Metalhead.VGG, inchannels=2), + ) + @test_throws( + MLJFlux.ERR_METALHEAD_DISALLOWED_KWARGS, + MLJFlux.image_builder(MLJFlux.Metalhead.VGG, nclasses=10), + ) +end + +@testset "constructors" begin + depth = 16 + imsize = (128, 128) + nclasses = 10 + inchannels = 1 + builder = MLJFlux.image_builder( + Metalhead.VGG, + depth, + batchnorm=true + ) + @test builder.metalhead_constructor == Metalhead.VGG + @test builder.args == (depth, ) + @test (; builder.kwargs...) == (; batchnorm=true) + ref_chain = Metalhead.VGG( + imsize; + config = Metalhead.vgg_conv_config[Metalhead.vgg_config[depth]], + inchannels, + batchnorm=true, + nclasses, + fcsize = 4096, + dropout = 0.5 + ) + # needs https://github.com/FluxML/Metalhead.jl/issues/176 + # chain = + # MLJFlux.build(builder, StableRNGs.StableRNG(123), imsize, nclasses, inchannels) + # @test length.(MLJFlux.Flux.params(ref_chain)) == + # length.(MLJFlux.Flux.params(chain)) +end + +true diff --git a/test/mlj_model_interface.jl b/test/mlj_model_interface.jl index 6b15aca4..24b9a59e 100644 --- a/test/mlj_model_interface.jl +++ b/test/mlj_model_interface.jl @@ -6,10 +6,6 @@ ModelType = MLJFlux.NeuralNetworkRegressor @test model == clone clone.optimiser.eta *= 10 @test model != clone - - clone = deepcopy(model) - clone.builder.dropout *= 0.5 - @test clone != model end @testset "clean!" begin diff --git a/test/regressor.jl b/test/regressor.jl index 0b6c7c7f..0f05ee72 100644 --- a/test/regressor.jl +++ b/test/regressor.jl @@ -6,7 +6,7 @@ X = MLJBase.table(randn(Float32, N, 5)); # TODO: replace Short2 -> Short when # https://github.com/FluxML/Flux.jl/pull/1618 is resolved: builder = Short2(σ=identity) -optimiser = Flux.Optimise.ADAM() +optimiser = Flux.Optimise.Adam() losses = [] diff --git a/test/runtests.jl b/test/runtests.jl index ab44a92f..fc235899 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -45,34 +45,49 @@ seed!(123) include("test_utils.jl") -@testset "penalizers" begin +# enable conditional testing of modules by providing test_args +# e.g. 
`Pkg.test("MLJBase", test_args=["misc"])` +RUN_ALL_TESTS = isempty(ARGS) +macro conditional_testset(name, expr) + name = string(name) + esc(quote + if RUN_ALL_TESTS || $name in ARGS + @testset $name $expr + end + end) +end +@conditional_testset "penalizers" begin include("penalizers.jl") end -@testset "core" begin +@conditional_testset "core" begin include("core.jl") end -@testset "builders" begin +@conditional_testset "builders" begin include("builders.jl") end -@testset "mlj_model_interface" begin +@conditional_testset "metalhead" begin + include("metalhead.jl") +end + +@conditional_testset "mlj_model_interface" begin include("mlj_model_interface.jl") end -@testset "regressor" begin +@conditional_testset "regressor" begin include("regressor.jl") end -@testset "classifier" begin +@conditional_testset "classifier" begin include("classifier.jl") end -@testset "image" begin +@conditional_testset "image" begin include("image.jl") end -@testset "integration" begin +@conditional_testset "integration" begin include("integration.jl") end