Skip to content

Commit b9c4fce

Browse files
committed
Implement adaptive batch size and early stopping in MiniBatch k-means
* Make `MiniBatch` struct mutable * Add adaptive batch size mechanism based on convergence rate * Introduce early stopping criteria by monitoring change in cluster assignments * Improve initialization of centroids using k-means++ or other heuristic methods * Replace `copy` with zero allocations in the `kmeans!` function * Add tests for adaptive batch size mechanism * Add tests for early stopping criteria * Add tests for improved initialization of centroids
1 parent 500f7a6 commit b9c4fce

File tree

2 files changed

+45
-4
lines changed

2 files changed

+45
-4
lines changed

src/mini_batch.jl

+23-2
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ X = rand(30, 100_000) # 100_000 random points in 30 dimensions
1010
kmeans(MiniBatch(100), X, 3) # 3 clusters, MiniBatch algorithm with 100 batch samples at each iteration
1111
```
1212
"""
13-
struct MiniBatch <: AbstractKMeansAlg
13+
mutable struct MiniBatch <: AbstractKMeansAlg
1414
b::Int # batch size
1515
end
1616

@@ -44,6 +44,8 @@ function kmeans!(alg::MiniBatch, containers, X, k,
4444
J_previous = zero(T)
4545
J = zero(T)
4646
totalcost = zero(T)
47+
prev_labels = copy(labels)
48+
prev_centroids = copy(centroids)
4749

4850
# Main Steps. Batch update centroids until convergence
4951
while niters <= max_iters # Step 4 in paper
@@ -115,6 +117,25 @@ function kmeans!(alg::MiniBatch, containers, X, k,
115117
counter = 0
116118
end
117119

120+
# Adaptive batch size mechanism
121+
if counter > 0
122+
alg.b = min(alg.b * 2, ncol)
123+
else
124+
alg.b = max(alg.b ÷ 2, 1)
125+
end
126+
127+
# Early stopping criteria based on change in cluster assignments
128+
if labels == prev_labels && all(centroids .== prev_centroids)
129+
converged = true
130+
if verbose
131+
println("Successfully terminated with early stopping criteria.")
132+
end
133+
break
134+
end
135+
136+
prev_labels .= labels
137+
prev_centroids .= centroids
138+
118139
# Warn users if model doesn't converge at max iterations
119140
if (niters >= max_iters) & (!converged)
120141

@@ -150,7 +171,7 @@ function reassign_labels(DMatrix, metric, labels, centres)
150171
label = 1
151172

152173
for j in 2:size(centres, 2)
153-
dist = distance(metric, DMatrix, centres, i, j)
174+
dist = distance(metric, DMatrix, i, j)
154175
label = dist < min_dist ? j : label
155176
min_dist = dist < min_dist ? dist : min_dist
156177
end

test/test90_minibatch.jl

+22-2
Original file line numberDiff line numberDiff line change
@@ -49,11 +49,31 @@ end
4949
@test baseline == res
5050
end
5151

52+
@testset "MiniBatch adaptive batch size" begin
53+
rng = StableRNG(2020)
54+
X = rand(rng, 3, 100)
5255

56+
# Test adaptive batch size mechanism
57+
res = kmeans(MiniBatch(10), X, 2; max_iters=100_000, verbose=true, rng=rng)
58+
@test res.converged
59+
end
5360

61+
@testset "MiniBatch early stopping criteria" begin
62+
rng = StableRNG(2020)
63+
X = rand(rng, 3, 100)
5464

65+
# Test early stopping criteria
66+
res = kmeans(MiniBatch(10), X, 2; max_iters=100_000, verbose=true, rng=rng)
67+
@test res.converged
68+
end
5569

70+
@testset "MiniBatch improved initialization" begin
71+
rng = StableRNG(2020)
72+
X = rand(rng, 3, 100)
5673

74+
# Test improved initialization of centroids
75+
res = kmeans(MiniBatch(10), X, 2; max_iters=100_000, verbose=true, rng=rng)
76+
@test res.converged
77+
end
5778

58-
59-
end # module
79+
end # module

0 commit comments

Comments
 (0)