Skip to content

Commit 6d6ff0f

Browse files
Add support for training classifiers with one class present (#29)
* skip fitting the classifier if there is only one class
* test for ambiguities, non row_tables
* bump to `v0.3.3`
1 parent e7f56b7 commit 6d6ff0f

File tree

7 files changed

+82
-83
lines changed

7 files changed

+82
-83
lines changed

.github/workflows/CI.yml

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,8 @@ jobs:
4040
- os: macos-latest
4141
version: '1'
4242
arch: x64
43-
env:
44-
PYTHON: ''
4543
steps:
46-
- uses: actions/checkout@v3
44+
- uses: actions/checkout@v4
4745
with:
4846
fetch-depth: 0
4947
- uses: julia-actions/setup-julia@v1
@@ -54,6 +52,6 @@ jobs:
5452
- uses: julia-actions/julia-buildpkg@v1
5553
- uses: julia-actions/julia-runtest@v1
5654
- uses: julia-actions/julia-processcoverage@v1
57-
- uses: codecov/codecov-action@v2
55+
- uses: codecov/codecov-action@v3
5856
with:
5957
file: lcov.info

Project.toml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "CatBoost"
22
uuid = "e2e10f9a-a85d-4fa9-b6b2-639a32100a12"
33
authors = ["Beacon Biosignals, Inc."]
4-
version = "0.3.2"
4+
version = "0.3.3"
55

66
[deps]
77
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
@@ -10,11 +10,15 @@ PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
1010
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
1111

1212
[compat]
13-
Aqua = "0.5"
13+
Aqua = "0.8"
14+
DataFrames = "1.6"
15+
MLJBase = "1"
1416
MLJModelInterface = "1"
17+
MLJTestInterface = "0.2"
1518
OrderedCollections = "1.4"
1619
PythonCall = "0.9"
1720
Tables = "1.4"
21+
Test = "1.6"
1822
julia = "1.6"
1923

2024
[extras]

format/Manifest.toml

Lines changed: 42 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -11,82 +11,61 @@ uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
1111

1212
[[CSTParser]]
1313
deps = ["Tokenize"]
14-
git-tree-sha1 = "60e9121d9ea044c30a04397e59b00c5d9eb826ee"
14+
git-tree-sha1 = "3ddd48d200eb8ddf9cb3e0189fc059fd49b97c1f"
1515
uuid = "00ebfdb7-1f24-5e51-bd34-a7502290713f"
16-
version = "2.5.0"
16+
version = "3.3.6"
1717

1818
[[CommonMark]]
19-
deps = ["Crayons", "JSON", "URIs"]
20-
git-tree-sha1 = "7632afc57f92720a01d9aedf23f413f4e5e21015"
19+
deps = ["Crayons", "JSON", "PrecompileTools", "URIs"]
20+
git-tree-sha1 = "532c4185d3c9037c0237546d817858b23cf9e071"
2121
uuid = "a80b9123-70ca-4bc0-993e-6e3bcb318db6"
22-
version = "0.8.1"
22+
version = "0.8.12"
2323

2424
[[Compat]]
25-
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
26-
git-tree-sha1 = "0a817fbe51c976de090aa8c997b7b719b786118d"
25+
deps = ["Dates", "LinearAlgebra", "UUIDs"]
26+
git-tree-sha1 = "886826d76ea9e72b35fcd000e535588f7b60f21d"
2727
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
28-
version = "3.28.0"
28+
version = "4.10.1"
2929

3030
[[Crayons]]
31-
git-tree-sha1 = "3f71217b538d7aaee0b69ab47d9b7724ca8afa0d"
31+
git-tree-sha1 = "249fe38abf76d48563e2f4556bebd215aa317e15"
3232
uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f"
33-
version = "4.0.4"
33+
version = "4.1.1"
3434

3535
[[DataStructures]]
3636
deps = ["Compat", "InteractiveUtils", "OrderedCollections"]
37-
git-tree-sha1 = "4437b64df1e0adccc3e5d1adbc3ac741095e4677"
37+
git-tree-sha1 = "ac67408d9ddf207de5cfa9a97e114352430f01ed"
3838
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
39-
version = "0.18.9"
39+
version = "0.18.16"
4040

4141
[[Dates]]
4242
deps = ["Printf"]
4343
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
4444

45-
[[DelimitedFiles]]
46-
deps = ["Mmap"]
47-
uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab"
48-
49-
[[Distributed]]
50-
deps = ["Random", "Serialization", "Sockets"]
51-
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
52-
53-
[[DocStringExtensions]]
54-
deps = ["LibGit2", "Markdown", "Pkg", "Test"]
55-
git-tree-sha1 = "9d4f64f79012636741cf01133158a54b24924c32"
56-
uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
57-
version = "0.8.4"
58-
59-
[[Documenter]]
60-
deps = ["Base64", "Dates", "DocStringExtensions", "IOCapture", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "REPL", "Test", "Unicode"]
61-
git-tree-sha1 = "3ebb967819b284dc1e3c0422229b58a40a255649"
62-
uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
63-
version = "0.26.3"
64-
6545
[[Downloads]]
6646
deps = ["ArgTools", "LibCURL", "NetworkOptions"]
6747
uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
6848

69-
[[IOCapture]]
70-
deps = ["Logging"]
71-
git-tree-sha1 = "377252859f740c217b936cebcd918a44f9b53b59"
72-
uuid = "b5f81e59-6552-4d32-b1f0-c071b021bf89"
73-
version = "0.1.1"
49+
[[Glob]]
50+
git-tree-sha1 = "97285bbd5230dd766e9ef6749b80fc617126d496"
51+
uuid = "c27321d9-0574-5035-807b-f59d2c89b15c"
52+
version = "1.3.1"
7453

7554
[[InteractiveUtils]]
7655
deps = ["Markdown"]
7756
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
7857

7958
[[JSON]]
8059
deps = ["Dates", "Mmap", "Parsers", "Unicode"]
81-
git-tree-sha1 = "81690084b6198a2e1da36fcfda16eeca9f9f24e4"
60+
git-tree-sha1 = "31e996f0a15c7b280ba9f76636b3ff9e2ae58c9a"
8261
uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
83-
version = "0.21.1"
62+
version = "0.21.4"
8463

8564
[[JuliaFormatter]]
86-
deps = ["CSTParser", "CommonMark", "DataStructures", "Documenter", "Pkg", "Tokenize"]
87-
git-tree-sha1 = "b947b46a3477e4c1ea32a7db66905d6f63dd7076"
65+
deps = ["CSTParser", "CommonMark", "DataStructures", "Glob", "Pkg", "PrecompileTools", "Tokenize"]
66+
git-tree-sha1 = "8f5295e46f594ad2d8652f1098488a77460080cd"
8867
uuid = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
89-
version = "0.13.10"
68+
version = "1.0.45"
9069

9170
[[LibCURL]]
9271
deps = ["LibCURL_jll", "MozillaCACerts_jll"]
@@ -132,20 +111,32 @@ uuid = "14a3606d-f60d-562e-9121-12d972cd8159"
132111
uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"
133112

134113
[[OrderedCollections]]
135-
git-tree-sha1 = "4fa2ba51070ec13fcc7517db714445b4ab986bdf"
114+
git-tree-sha1 = "dfdf5519f235516220579f949664f1bf44e741c5"
136115
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
137-
version = "1.4.0"
116+
version = "1.6.3"
138117

139118
[[Parsers]]
140-
deps = ["Dates"]
141-
git-tree-sha1 = "c8abc88faa3f7a3950832ac5d6e690881590d6dc"
119+
deps = ["Dates", "PrecompileTools", "UUIDs"]
120+
git-tree-sha1 = "8489905bcdbcfac64d1daa51ca07c0d8f0283821"
142121
uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0"
143-
version = "1.1.0"
122+
version = "2.8.1"
144123

145124
[[Pkg]]
146125
deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"]
147126
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
148127

128+
[[PrecompileTools]]
129+
deps = ["Preferences"]
130+
git-tree-sha1 = "03b4c25b43cb84cee5c90aa9b5ea0a78fd848d2f"
131+
uuid = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
132+
version = "1.2.0"
133+
134+
[[Preferences]]
135+
deps = ["TOML"]
136+
git-tree-sha1 = "00805cd429dcb4870060ff49ef443486c262e38e"
137+
uuid = "21216c6a-2e73-6563-6e65-726566657250"
138+
version = "1.4.1"
139+
149140
[[Printf]]
150141
deps = ["Unicode"]
151142
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"
@@ -164,21 +155,9 @@ uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
164155
[[Serialization]]
165156
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
166157

167-
[[SharedArrays]]
168-
deps = ["Distributed", "Mmap", "Random", "Serialization"]
169-
uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383"
170-
171158
[[Sockets]]
172159
uuid = "6462fe0b-24de-5631-8697-dd941f90decc"
173160

174-
[[SparseArrays]]
175-
deps = ["LinearAlgebra", "Random"]
176-
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
177-
178-
[[Statistics]]
179-
deps = ["LinearAlgebra", "SparseArrays"]
180-
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
181-
182161
[[TOML]]
183162
deps = ["Dates"]
184163
uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
@@ -187,19 +166,15 @@ uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
187166
deps = ["ArgTools", "SHA"]
188167
uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e"
189168

190-
[[Test]]
191-
deps = ["InteractiveUtils", "Logging", "Random", "Serialization"]
192-
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
193-
194169
[[Tokenize]]
195-
git-tree-sha1 = "45b1932b0ec576159181bf75df71d6d86aa9c850"
170+
git-tree-sha1 = "3ac1ac11b09e8033ec93a7993acdb9b68252be6d"
196171
uuid = "0796e94c-ce3b-5d07-9a54-7f471281c624"
197-
version = "0.5.13"
172+
version = "0.5.27"
198173

199174
[[URIs]]
200-
git-tree-sha1 = "97bbe755a53fe859669cd907f2d96aee8d2c1355"
175+
git-tree-sha1 = "67db6cc7b3821e19ebe75791a9dd19c9b1188f2b"
201176
uuid = "5c2747f8-b7ea-4ff2-ba2e-563bfd36b1d4"
202-
version = "1.3.0"
177+
version = "1.5.1"
203178

204179
[[UUIDs]]
205180
deps = ["Random", "SHA"]

src/mlj_catboostclassifier.jl

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,14 +73,22 @@ function model_init(mlj_model::CatBoostClassifier; kw...)
7373
end
7474

7575
function MMI.fit(mlj_model::CatBoostClassifier, verbosity::Int, data_pool, y_first)
76-
verbose = verbosity <= 1 ? false : true
76+
# Check whether the labels stored in data_pool contain only one unique class
77+
unique_classes = pyconvert(Vector, numpy.unique(data_pool.get_label()))
78+
if length(unique_classes) == 1
79+
# Skip training and store the single class
80+
fitresult = (model=nothing, single_class=unique_classes[1], y_first=y_first)
81+
cache = (; mlj_model=deepcopy(mlj_model))
82+
report = (feature_importances=[],) # No feature importances in this case
83+
return (fitresult, cache, report)
84+
end
7785

86+
verbose = verbosity <= 1 ? false : true
7887
model = model_init(mlj_model; verbose)
7988
model.fit(data_pool)
8089

8190
cache = (; mlj_model=deepcopy(mlj_model))
8291
report = (feature_importances=feature_importance(model),)
83-
8492
fitresult = (model, y_first)
8593

8694
return (fitresult, cache, report)
@@ -90,6 +98,14 @@ MMI.fitted_params(::CatBoostClassifier, model) = (model=model,)
9098
MMI.reports_feature_importances(::Type{<:CatBoostClassifier}) = true
9199

92100
function MMI.predict(mlj_model::CatBoostClassifier, fitresult, X_pool)
101+
if fitresult[1] === nothing
102+
# Always predict the single class
103+
n = nrow(X_pool)
104+
classes = [fitresult.single_class]
105+
probs = ones(n, 1)
106+
return MMI.UnivariateFinite(classes, probs; pool=fitresult.y_first)
107+
end
108+
93109
model, y_first = fitresult
94110
classes = pyconvert(Array, model.classes_.tolist())
95111
py_preds = predict_proba(model, X_pool)
@@ -98,6 +114,12 @@ function MMI.predict(mlj_model::CatBoostClassifier, fitresult, X_pool)
98114
end
99115

100116
function MMI.predict_mode(mlj_model::CatBoostClassifier, fitresult, X_pool)
117+
if fitresult[1] === nothing
118+
# Return probability 1 for the single class
119+
n = nrow(X_pool)
120+
return hcat(ones(n), zeros(n))
121+
end
122+
101123
model, y_first = fitresult
102124
py_preds = predict(model, X_pool)
103125
preds = pyconvert(Array, py_preds)

src/wrapper.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,9 @@ function feature_importance(py_model::Py)
9393
py_df_importance["importance"] = py_model.feature_importances_
9494
tbl_importance = pandas_to_tbl(py_df_importance)
9595
n_features = size(tbl_importance.name, 1)
96-
feat_importance = [Symbol(tbl_importance.name[i]) => tbl_importance.importance[i] for i in
97-
1:n_features]
96+
feat_importance = [Symbol(tbl_importance.name[i]) => tbl_importance.importance[i]
97+
for i in
98+
1:n_features]
9899
return feat_importance
99100
end
100101

test/mlj_interface.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,12 @@
4949
@test isempty(failures)
5050
end
5151
@testset "CatBoostClassifier" begin
52-
for data in [MLJTestInterface.make_binary(), MLJTestInterface.make_multiclass()]
52+
for data in [MLJTestInterface.make_binary(),
53+
MLJTestInterface.make_multiclass(),
54+
MLJTestInterface.make_binary(; row_table=true),
55+
MLJTestInterface.make_multiclass(; row_table=false)]
5356
X = data[1]
5457
y = data[2]
55-
# catboost fails if only 1 class is present when training
56-
# MLJTestInterface splits the data down the middle, so the binary
57-
# data only has one class during training
58-
y[1] = y[end]
5958
failures, summary = MLJTestInterface.test([CatBoostClassifier], X, y;
6059
mod=@__MODULE__, verbosity=0, # bump to debug
6160
throw=false)

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,4 @@ const MLJ_EXAMPLES_DIR = joinpath(@__DIR__, "..", "examples/mlj")
1111
include("wrapper.jl")
1212
include("mlj_interface.jl")
1313

14-
Aqua.test_all(CatBoost; ambiguities=false)
14+
Aqua.test_all(CatBoost)

0 commit comments

Comments
 (0)