Skip to content

Commit 3df9fef

Browse files
committed
Turn Leaf struct into a frequency map
1 parent be7a715 commit 3df9fef

File tree

3 files changed

+32
-17
lines changed

3 files changed

+32
-17
lines changed

src/DecisionTree.jl

+17-12
Original file line numberDiff line numberDiff line change
@@ -28,16 +28,18 @@ export InfoNode, InfoLeaf, wrap
2828
###########################
2929
########## Types ##########
3030

31-
struct Leaf{T}
32-
majority :: T
33-
values :: Vector{T}
31+
struct Leaf{T, N}
32+
features :: NTuple{N, T}
33+
majority :: Int
34+
values :: NTuple{N, Int}
35+
total :: Int
3436
end
3537

36-
struct Node{S, T}
38+
struct Node{S, T, N}
3739
featid :: Int
3840
featval :: S
39-
left :: Union{Leaf{T}, Node{S, T}}
40-
right :: Union{Leaf{T}, Node{S, T}}
41+
left :: Union{Leaf{T, N}, Node{S, T, N}}
42+
right :: Union{Leaf{T, N}, Node{S, T, N}}
4143
end
4244

4345
const LeafOrNode{S, T} = Union{Leaf{T}, Node{S, T}}
@@ -46,11 +48,15 @@ struct Ensemble{S, T}
4648
trees :: Vector{LeafOrNode{S, T}}
4749
end
4850

51+
Leaf(features::NTuple{T, N}) where {T, N} =
52+
Leaf(features, 0, Tuple(zeros(T, N)), 0)
53+
4954
is_leaf(l::Leaf) = true
5055
is_leaf(n::Node) = false
5156

5257
zero(String) = ""
53-
convert(::Type{Node{S, T}}, lf::Leaf{T}) where {S, T} = Node(0, zero(S), lf, Leaf(zero(T), [zero(T)]))
58+
convert(::Type{Node{S, T}}, lf::Leaf{T}) where {S, T} =
59+
Node(0, zero(S), lf, Leaf(lf.features))
5460
promote_rule(::Type{Node{S, T}}, ::Type{Leaf{T}}) where {S, T} = Node{S, T}
5561
promote_rule(::Type{Leaf{T}}, ::Type{Node{S, T}}) where {S, T} = Node{S, T}
5662

@@ -81,9 +87,8 @@ depth(leaf::Leaf) = 0
8187
depth(tree::Node) = 1 + max(depth(tree.left), depth(tree.right))
8288

8389
function print_tree(io::IO, leaf::Leaf, depth=-1, indent=0; feature_names=nothing)
84-
n_matches = count(leaf.values .== leaf.majority)
85-
ratio = string(n_matches, "/", length(leaf.values))
86-
println(io, "$(leaf.majority) : $(ratio)")
90+
println(io, "$(leaf.features[leaf.majority]) : ",
91+
leaf.values[leaf.majority], '/', leaf.total)
8792
end
8893
function print_tree(leaf::Leaf, depth=-1, indent=0; feature_names=nothing)
8994
return print_tree(stdout, leaf, depth, indent; feature_names=feature_names)
@@ -139,8 +144,8 @@ end
139144

140145
function show(io::IO, leaf::Leaf)
141146
println(io, "Decision Leaf")
142-
println(io, "Majority: $(leaf.majority)")
143-
print(io, "Samples: $(length(leaf.values))")
147+
println(io, "Majority: ", leaf.features[leaf.majority])
148+
print(io, "Samples: ", leaf.total)
144149
end
145150

146151
function show(io::IO, tree::Node)

src/classification/main.jl

+11-4
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,10 @@ function _convert(
4141
) where {S, T}
4242

4343
if node.is_leaf
44-
return Leaf{T}(list[node.label], labels[node.region])
44+
features = Tuple(unique(labels))
45+
featfreq = Tuple(sum(labels[node.region] .== f) for f in features)
46+
return Leaf{T, length(features)}(
47+
features, argmax(featfreq), featfreq, length(node.region))
4548
else
4649
left = _convert(node.l, list, labels)
4750
right = _convert(node.r, list, labels)
@@ -120,7 +123,10 @@ function prune_tree(tree::LeafOrNode{S, T}, purity_thresh=1.0) where {S, T}
120123
matches = findall(all_labels .== majority)
121124
purity = length(matches) / length(all_labels)
122125
if purity >= purity_thresh
123-
return Leaf{T}(majority, all_labels)
126+
features = Tuple(unique(all_labels))
127+
featfreq = Tuple(sum(all_labels .== f) for f in features)
128+
return Leaf{T}(features, argmax(featfreq),
129+
featfreq, length(all_labels))
124130
else
125131
return tree
126132
end
@@ -139,7 +145,8 @@ function prune_tree(tree::LeafOrNode{S, T}, purity_thresh=1.0) where {S, T}
139145
end
140146

141147

142-
apply_tree(leaf::Leaf, feature::AbstractVector) = leaf.majority
148+
apply_tree(leaf::Leaf{T}, feature::AbstractVector{S}) where {S, T} =
149+
leaf.features[leaf.majority]
143150

144151
function apply_tree(tree::Node{S, T}, features::AbstractVector{S}) where {S, T}
145152
if tree.featid == 0
@@ -173,7 +180,7 @@ n_labels` matrix of probabilities, each row summing up to 1.
173180
(eg. ["versicolor", "virginica", "setosa"]). It specifies the column ordering
174181
of the output matrix. """
175182
apply_tree_proba(leaf::Leaf{T}, features::AbstractVector{S}, labels) where {S, T} =
176-
compute_probabilities(labels, leaf.values)
183+
collect(leaf.values ./ leaf.total)
177184

178185
function apply_tree_proba(tree::Node{S, T}, features::AbstractVector{S}, labels) where {S, T}
179186
if tree.featval === nothing

src/regression/main.jl

+4-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@ include("tree.jl")
22

33
function _convert(node::treeregressor.NodeMeta{S}, labels::Array{T}) where {S, T <: Float64}
44
if node.is_leaf
5-
return Leaf{T}(node.label, labels[node.region])
5+
features = Tuple(unique(labels))
6+
featfreq = Tuple(sum(labels[node.region] .== f) for f in features)
7+
return Leaf{T, length(features)}(
8+
features, argmax(featfreq), featfreq, length(node.region))
69
else
710
left = _convert(node.l, labels)
811
right = _convert(node.r, labels)

0 commit comments

Comments
 (0)