1
1
module DynamicExpressionsSymbolicUtilsExt
2
2
3
3
using SymbolicUtils
4
- import DynamicExpressions. EquationModule: Node, DEFAULT_NODE_TYPE
4
+ import DynamicExpressions. EquationModule:
5
+ AbstractExpressionNode, Node, constructorof, DEFAULT_NODE_TYPE
5
6
import DynamicExpressions. OperatorEnumModule: AbstractOperatorEnum
6
7
import DynamicExpressions. UtilsModule: isgood, isbad, @return_on_false , deprecate_varmap
7
8
import DynamicExpressions. ExtensionInterfaceModule: node_to_symbolic, symbolic_to_node
19
20
subs_bad (x) = isgood (x) ? x : Inf
20
21
21
22
function parse_tree_to_eqs (
22
- tree:: Node{T} , operators:: AbstractOperatorEnum , index_functions:: Bool = false
23
+ tree:: AbstractExpressionNode{T} ,
24
+ operators:: AbstractOperatorEnum ,
25
+ index_functions:: Bool = false ,
23
26
) where {T}
24
27
if tree. degree == 0
25
28
# Return constant if needed
26
29
tree. constant && return subs_bad (tree. val:: T )
27
30
return SymbolicUtils. Sym {LiteralReal} (Symbol (" x$(tree. feature) " ))
28
31
end
29
32
# Collect the next children
33
+ # TODO : Type instability!
30
34
children = tree. degree == 2 ? (tree. l, tree. r) : (tree. l,)
31
35
# Get the operation
32
36
op = tree. degree == 2 ? operators. binops[tree. op] : operators. unaops[tree. op]
@@ -66,11 +70,12 @@ convert_to_function(x, operators::AbstractOperatorEnum) = x
66
70
function split_eq (
67
71
op,
68
72
args,
69
- operators:: AbstractOperatorEnum ;
73
+ operators:: AbstractOperatorEnum ,
74
+ :: Type{N} = Node;
70
75
variable_names:: Union{Array{String,1},Nothing} = nothing ,
71
76
# Deprecated:
72
77
varMap= nothing ,
73
- )
78
+ ) where {N <: AbstractExpressionNode }
74
79
variable_names = deprecate_varmap (variable_names, varMap, :split_eq )
75
80
! (op ∈ (sum, prod, + , * )) && throw (error (" Unsupported operation $op in expression!" ))
76
81
if Symbol (op) == Symbol (sum)
@@ -80,10 +85,10 @@ function split_eq(
80
85
else
81
86
ind = findoperation (op, operators. binops)
82
87
end
83
- return Node (
88
+ return constructorof (N) (
84
89
ind,
85
- convert (Node , args[1 ], operators; variable_names= variable_names),
86
- convert (Node , op (args[2 : end ]. .. ), operators; variable_names= variable_names),
90
+ convert (N , args[1 ], operators; variable_names= variable_names),
91
+ convert (N , op (args[2 : end ]. .. ), operators; variable_names= variable_names),
87
92
)
88
93
end
89
94
96
101
97
102
function Base. convert (
98
103
:: typeof (SymbolicUtils. Symbolic),
99
- tree:: Node ,
104
+ tree:: AbstractExpressionNode ,
100
105
operators:: AbstractOperatorEnum ;
101
106
variable_names:: Union{Array{String,1},Nothing} = nothing ,
102
107
index_functions:: Bool = false ,
@@ -109,20 +114,22 @@ function Base.convert(
109
114
)
110
115
end
111
116
112
- function Base. convert (:: typeof (Node), x:: Number , operators:: AbstractOperatorEnum ; kws... )
113
- return Node (; val= DEFAULT_NODE_TYPE (x))
117
+ function Base. convert (
118
+ :: Type{N} , x:: Number , operators:: AbstractOperatorEnum ; kws...
119
+ ) where {N<: AbstractExpressionNode }
120
+ return constructorof (N)(; val= DEFAULT_NODE_TYPE (x))
114
121
end
115
122
116
123
function Base. convert (
117
- :: typeof (Node) ,
124
+ :: Type{N} ,
118
125
expr:: SymbolicUtils.Symbolic ,
119
126
operators:: AbstractOperatorEnum ;
120
127
variable_names:: Union{Array{String,1},Nothing} = nothing ,
121
- )
128
+ ) where {N <: AbstractExpressionNode }
122
129
variable_names = deprecate_varmap (variable_names, nothing , :convert )
123
130
if ! SymbolicUtils. istree (expr)
124
- variable_names === nothing && return Node (String (expr. name))
125
- return Node (String (expr. name), variable_names)
131
+ variable_names === nothing && return constructorof (N) (String (expr. name))
132
+ return constructorof (N) (String (expr. name), variable_names)
126
133
end
127
134
128
135
# First, we remove integer powers:
@@ -134,20 +141,21 @@ function Base.convert(
134
141
op = convert_to_function (SymbolicUtils. operation (expr), operators)
135
142
args = SymbolicUtils. arguments (expr)
136
143
137
- length (args) > 2 && return split_eq (op, args, operators; variable_names= variable_names)
144
+ length (args) > 2 &&
145
+ return split_eq (op, args, operators, N; variable_names= variable_names)
138
146
ind = if length (args) == 2
139
147
findoperation (op, operators. binops)
140
148
else
141
149
findoperation (op, operators. unaops)
142
150
end
143
151
144
- return Node (
145
- ind, map (x -> convert (Node , x, operators; variable_names= variable_names), args)...
152
+ return constructorof (N) (
153
+ ind, map (x -> convert (N , x, operators; variable_names= variable_names), args)...
146
154
)
147
155
end
148
156
149
157
"""
150
- node_to_symbolic(tree::Node , operators::AbstractOperatorEnum;
158
+ node_to_symbolic(tree::AbstractExpressionNode , operators::AbstractOperatorEnum;
151
159
variable_names::Union{Array{String, 1}, Nothing}=nothing,
152
160
index_functions::Bool=false)
153
161
@@ -156,17 +164,17 @@ will generate a symbolic equation in SymbolicUtils.jl format.
156
164
157
165
## Arguments
158
166
159
- - `tree::Node `: The equation to convert.
167
+ - `tree::AbstractExpressionNode `: The equation to convert.
160
168
- `operators::AbstractOperatorEnum`: OperatorEnum, which contains the operators used in the equation.
161
169
- `variable_names::Union{Array{String, 1}, Nothing}=nothing`: What variable names to use for
162
170
each feature. Default is [x1, x2, x3, ...].
163
171
- `index_functions::Bool=false`: Whether to generate special names for the
164
- operators, which then allows one to convert back to a `Node ` format
172
+ operators, which then allows one to convert back to a `AbstractExpressionNode ` format
165
173
using `symbolic_to_node`.
166
174
(CURRENTLY UNAVAILABLE - See https://github.com/MilesCranmer/SymbolicRegression.jl/pull/84).
167
175
"""
168
176
function node_to_symbolic (
169
- tree:: Node ,
177
+ tree:: AbstractExpressionNode ,
170
178
operators:: AbstractOperatorEnum ;
171
179
variable_names:: Union{Array{String,1},Nothing} = nothing ,
172
180
index_functions:: Bool = false ,
@@ -192,13 +200,14 @@ end
192
200
193
201
function symbolic_to_node (
194
202
eqn:: SymbolicUtils.Symbolic ,
195
- operators:: AbstractOperatorEnum ;
203
+ operators:: AbstractOperatorEnum ,
204
+ :: Type{N} = Node;
196
205
variable_names:: Union{Array{String,1},Nothing} = nothing ,
197
206
# Deprecated:
198
207
varMap= nothing ,
199
- ):: Node
208
+ ) where {N <: AbstractExpressionNode }
200
209
variable_names = deprecate_varmap (variable_names, varMap, :symbolic_to_node )
201
- return convert (Node , eqn, operators; variable_names= variable_names)
210
+ return convert (N , eqn, operators; variable_names= variable_names)
202
211
end
203
212
204
213
function multiply_powers (eqn:: Number ):: Tuple{SYMBOLIC_UTILS_TYPES,Bool}
0 commit comments