@@ -28,16 +28,18 @@ export InfoNode, InfoLeaf, wrap
28
28
# ##########################
29
29
# ######### Types ##########
30
30
31
- struct Leaf{T}
32
- majority :: T
33
- values :: Vector{T}
31
+ struct Leaf{T, N}
32
+ features :: NTuple{N, T}
33
+ majority :: Int
34
+ values :: NTuple{N, Int}
35
+ total :: Int
34
36
end
35
37
36
- struct Node{S, T}
38
+ struct Node{S, T, N }
37
39
featid :: Int
38
40
featval :: S
39
- left :: Union{Leaf{T}, Node{S, T}}
40
- right :: Union{Leaf{T}, Node{S, T}}
41
+ left :: Union{Leaf{T, N }, Node{S, T, N }}
42
+ right :: Union{Leaf{T, N }, Node{S, T, N }}
41
43
end
42
44
43
45
const LeafOrNode{S, T} = Union{Leaf{T}, Node{S, T}}
@@ -46,11 +48,15 @@ struct Ensemble{S, T}
46
48
trees :: Vector{LeafOrNode{S, T}}
47
49
end
48
50
51
+ Leaf (features:: NTuple{T, N} ) where {T, N} =
52
+ Leaf (features, 0 , Tuple (zeros (T, N)), 0 )
53
+
49
54
is_leaf (l:: Leaf ) = true
50
55
is_leaf (n:: Node ) = false
51
56
52
57
zero (String) = " "
53
- convert (:: Type{Node{S, T}} , lf:: Leaf{T} ) where {S, T} = Node (0 , zero (S), lf, Leaf (zero (T), [zero (T)]))
58
+ convert (:: Type{Node{S, T}} , lf:: Leaf{T} ) where {S, T} =
59
+ Node (0 , zero (S), lf, Leaf (lf. features))
54
60
promote_rule (:: Type{Node{S, T}} , :: Type{Leaf{T}} ) where {S, T} = Node{S, T}
55
61
promote_rule (:: Type{Leaf{T}} , :: Type{Node{S, T}} ) where {S, T} = Node{S, T}
56
62
@@ -81,9 +87,8 @@ depth(leaf::Leaf) = 0
81
87
depth (tree:: Node ) = 1 + max (depth (tree. left), depth (tree. right))
82
88
83
89
function print_tree (io:: IO , leaf:: Leaf , depth= - 1 , indent= 0 ; feature_names= nothing )
84
- n_matches = count (leaf. values .== leaf. majority)
85
- ratio = string (n_matches, " /" , length (leaf. values))
86
- println (io, " $(leaf. majority) : $(ratio) " )
90
+ println (io, " $(leaf. features[leaf. majority]) : " ,
91
+ leaf. values[leaf. majority], ' /' , leaf. total)
87
92
end
88
93
function print_tree (leaf:: Leaf , depth= - 1 , indent= 0 ; feature_names= nothing )
89
94
return print_tree (stdout , leaf, depth, indent; feature_names= feature_names)
139
144
140
145
function show (io:: IO , leaf:: Leaf )
141
146
println (io, " Decision Leaf" )
142
- println (io, " Majority: $( leaf. majority) " )
143
- print (io, " Samples: $( length ( leaf. values)) " )
147
+ println (io, " Majority: " , leaf. features[leaf . majority] )
148
+ print (io, " Samples: " , leaf. total )
144
149
end
145
150
146
151
function show (io:: IO , tree:: Node )
0 commit comments