Skip to content

Commit 908701f

Browse files
committed
Directly operate on leaf tuples
1 parent 3df9fef commit 908701f

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

src/classification/main.jl

+8-3
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ n_labels` matrix of probabilities, each row summing up to 1.
180180
(eg. ["versicolor", "virginica", "setosa"]). It specifies the column ordering
181181
of the output matrix. """
182182
apply_tree_proba(leaf::Leaf{T}, features::AbstractVector{S}, labels) where {S, T} =
183-
collect(leaf.values ./ leaf.total)
183+
leaf.values ./ leaf.total
184184

185185
function apply_tree_proba(tree::Node{S, T}, features::AbstractVector{S}, labels) where {S, T}
186186
if tree.featval === nothing
@@ -192,8 +192,13 @@ function apply_tree_proba(tree::Node{S, T}, features::AbstractVector{S}, labels)
192192
end
193193
end
194194

195-
apply_tree_proba(tree::LeafOrNode{S, T}, features::AbstractMatrix{S}, labels) where {S, T} =
196-
stack_function_results(row->apply_tree_proba(tree, row, labels), features)
195+
function apply_tree_proba(tree::LeafOrNode{S, T}, features::AbstractMatrix{S}, labels) where {S, T}
196+
predictions = Vector{NTuple{length(labels), Float64}}(undef, size(features, 1))
197+
for i in 1:size(features, 1)
198+
predictions[i] = apply_tree_proba(tree, view(features, i, :), labels)
199+
end
200+
reinterpret(reshape, Float64, predictions) |> transpose |> Matrix
201+
end
197202

198203
function build_forest(
199204
labels :: AbstractVector{T},

0 commit comments

Comments
 (0)