Skip to content

Commit 8109f9c

Browse files
authored
Merge pull request #56 from SymbolicML/MilesCranmer/issue14
Graph-like expressions
2 parents f23ed22 + f8f9678 commit 8109f9c

28 files changed

+1620
-820
lines changed

Project.toml

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DynamicExpressions"
22
uuid = "a40a106e-89c9-4ca8-8020-a735e8728b6b"
33
authors = ["MilesCranmer <[email protected]>"]
4-
version = "0.13.1"
4+
version = "0.14.0"
55

66
[deps]
77
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
@@ -24,6 +24,7 @@ DynamicExpressionsSymbolicUtilsExt = "SymbolicUtils"
2424
DynamicExpressionsZygoteExt = "Zygote"
2525

2626
[compat]
27+
Aqua = "0.7"
2728
Compat = "3.37, 4"
2829
LoopVectorization = "0.12"
2930
MacroTools = "0.4, 0.5"

benchmark/benchmarks.jl

+37-10
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
using DynamicExpressions, BenchmarkTools, Random
22
using DynamicExpressions.EquationUtilsModule: is_constant
33
using Zygote
4+
if PACKAGE_VERSION < v"0.14.0"
5+
@eval using DynamicExpressions: Node as GraphNode
6+
else
7+
@eval using DynamicExpressions: GraphNode
8+
end
49

510
include("benchmark_utils.jl")
611

@@ -66,13 +71,15 @@ end
6671

6772
# These macros make the benchmarks work on older versions:
6873
#! format: off
69-
@generated function _convert(::Type{N}, t; preserve_sharing) where {N<:Node}
74+
@generated function _convert(::Type{N}, t; preserve_sharing) where {N}
7075
PACKAGE_VERSION < v"0.7.0" && return :(convert(N, t))
71-
return :(convert(N, t; preserve_sharing=preserve_sharing))
76+
PACKAGE_VERSION < v"0.14.0" && return :(convert(N, t; preserve_sharing=preserve_sharing))
77+
return :(convert(N, t)) # Assume type used to infer sharing
7278
end
7379
@generated function _copy_node(t; preserve_sharing)
7480
PACKAGE_VERSION < v"0.7.0" && return :(copy_node(t; preserve_topology=preserve_sharing))
75-
return :(copy_node(t; preserve_sharing=preserve_sharing))
81+
PACKAGE_VERSION < v"0.14.0" && return :(copy_node(t; preserve_sharing=preserve_sharing))
82+
return :(copy_node(t)) # Assume type used to infer sharing
7683
end
7784
@generated function get_set_constants!(tree)
7885
!(@isdefined set_constants!) && return :(set_constants(tree, get_constants(tree)))
@@ -101,13 +108,36 @@ function benchmark_utilities()
101108
:index_constants,
102109
:string_tree,
103110
)
111+
has_both_modes = [:copy, :convert]
112+
if PACKAGE_VERSION >= v"0.14.0"
113+
append!(
114+
has_both_modes,
115+
[
116+
:simplify_tree,
117+
:count_nodes,
118+
:count_constants,
119+
:get_set_constants!,
120+
:index_constants,
121+
:string_tree,
122+
],
123+
)
124+
end
104125

105126
operators = OperatorEnum(; binary_operators=[+, -, /, *], unary_operators=[cos, exp])
106127
for func_k in all_funcs
107128
suite[func_k] = let s = BenchmarkGroup()
108-
for k in (:break_sharing, :preserve_sharing)
109-
has_both_modes = func_k in (:copy, :convert)
110-
k == :preserve_sharing && !has_both_modes && continue
129+
for k in (
130+
if func_k in has_both_modes
131+
[:break_sharing, :preserve_sharing]
132+
else
133+
[:break_sharing]
134+
end
135+
)
136+
preprocess = if k == :preserve_sharing && PACKAGE_VERSION >= v"0.14.0"
137+
tree -> GraphNode(tree)
138+
else
139+
identity
140+
end
111141

112142
f = if func_k == :copy
113143
tree -> _copy_node(tree; preserve_sharing=(k == :preserve_sharing))
@@ -132,12 +162,9 @@ function benchmark_utilities()
132162
setup=(
133163
ntrees=100;
134164
n=20;
135-
trees=[gen_random_tree_fixed_size(n, $operators, 5, Float32) for _ in 1:ntrees]
165+
trees=[$preprocess(gen_random_tree_fixed_size(n, $operators, 5, Float32)) for _ in 1:ntrees]
136166
)
137167
)
138-
if !has_both_modes
139-
s = s[k]
140-
end
141168
#! format: on
142169
end
143170
s

docs/src/types.md

+70-15
Original file line numberDiff line numberDiff line change
@@ -48,16 +48,7 @@ Equations are specified as binary trees with the `Node` type, defined
4848
as follows:
4949

5050
```@docs
51-
Node{T}
52-
```
53-
54-
There are a variety of constructors for `Node` objects, including:
55-
56-
```@docs
57-
Node(::Type{T}; val=nothing, feature::Integer=nothing) where {T}
58-
Node(op::Integer, l::Node)
59-
Node(op::Integer, l::Node, r::Node)
60-
Node(var_string::String)
51+
Node
6152
```
6253

6354
When you create an `Options` object, the operators
@@ -69,23 +60,87 @@ When using these node constructors, types will automatically be promoted.
6960
You can convert the type of a node using `convert`:
7061

7162
```@docs
72-
convert(::Type{Node{T1}}, tree::Node{T2}) where {T1, T2}
63+
convert(::Type{AbstractExpressionNode{T1}}, tree::AbstractExpressionNode{T2}) where {T1, T2}
7364
```
7465

7566
You can set a `tree` (in-place) with `set_node!`:
7667

7768
```@docs
78-
set_node!(tree::Node{T}, new_tree::Node{T}) where {T}
69+
set_node!
7970
```
8071

8172
You can create a copy of a node with `copy_node`:
8273

8374
```@docs
84-
copy_node(tree::Node)
75+
copy_node
76+
```
77+
78+
## Graph-Like Equations
79+
80+
You can describe an equation as a *graph* rather than a tree
81+
by using the `GraphNode` type:
82+
83+
```@docs
84+
GraphNode{T}
85+
```
86+
87+
This makes it so you can have multiple parents for a given node,
88+
and share parts of an expression. For example:
89+
90+
```julia
91+
julia> operators = OperatorEnum(;
92+
binary_operators=[+, -, *], unary_operators=[cos, sin, exp]
93+
);
94+
95+
julia> x1, x2 = GraphNode(feature=1), GraphNode(feature=2)
96+
(x1, x2)
97+
98+
julia> y = sin(x1) + 1.5
99+
sin(x1) + 1.5
100+
101+
julia> z = exp(y) + y
102+
exp(sin(x1) + 1.5) + {(sin(x1) + 1.5)}
103+
```
104+
105+
Here, the curly braces `{}` indicate that the node
106+
is shared by another (or more) parent node.
107+
108+
This means that we only need to change it once
109+
to have changes propagate across the expression:
110+
111+
```julia
112+
julia> y.r.val *= 0.9
113+
1.35
114+
115+
julia> z
116+
exp(sin(x1) + 1.35) + {(sin(x1) + 1.35)}
117+
```
118+
119+
This also means there are fewer nodes to describe an expression:
120+
121+
```julia
122+
julia> length(z)
123+
6
124+
125+
julia> length(convert(Node, z))
126+
10
127+
```
128+
129+
where we have converted the `GraphNode` to a `Node` type,
130+
which breaks shared connections into separate nodes.
131+
132+
## Abstract Types
133+
134+
Both the `Node` and `GraphNode` types are subtypes of the abstract type:
135+
136+
```@docs
137+
AbstractExpressionNode{T}
85138
```
86139

87-
There is also an abstract type `AbstractNode` which is a supertype of `Node`:
140+
which can be used to create additional expression-like types.
141+
The supertype of this abstract type is the `AbstractNode` type,
142+
which is more generic but does not have all of the same methods:
88143

89144
```@docs
90-
AbstractNode
145+
AbstractNode{T}
91146
```

ext/DynamicExpressionsSymbolicUtilsExt.jl

+33-24
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
module DynamicExpressionsSymbolicUtilsExt
22

33
using SymbolicUtils
4-
import DynamicExpressions.EquationModule: Node, DEFAULT_NODE_TYPE
4+
import DynamicExpressions.EquationModule:
5+
AbstractExpressionNode, Node, constructorof, DEFAULT_NODE_TYPE
56
import DynamicExpressions.OperatorEnumModule: AbstractOperatorEnum
67
import DynamicExpressions.UtilsModule: isgood, isbad, @return_on_false, deprecate_varmap
78
import DynamicExpressions.ExtensionInterfaceModule: node_to_symbolic, symbolic_to_node
@@ -19,14 +20,17 @@ end
1920
subs_bad(x) = isgood(x) ? x : Inf
2021

2122
function parse_tree_to_eqs(
22-
tree::Node{T}, operators::AbstractOperatorEnum, index_functions::Bool=false
23+
tree::AbstractExpressionNode{T},
24+
operators::AbstractOperatorEnum,
25+
index_functions::Bool=false,
2326
) where {T}
2427
if tree.degree == 0
2528
# Return constant if needed
2629
tree.constant && return subs_bad(tree.val::T)
2730
return SymbolicUtils.Sym{LiteralReal}(Symbol("x$(tree.feature)"))
2831
end
2932
# Collect the next children
33+
# TODO: Type instability!
3034
children = tree.degree == 2 ? (tree.l, tree.r) : (tree.l,)
3135
# Get the operation
3236
op = tree.degree == 2 ? operators.binops[tree.op] : operators.unaops[tree.op]
@@ -66,11 +70,12 @@ convert_to_function(x, operators::AbstractOperatorEnum) = x
6670
function split_eq(
6771
op,
6872
args,
69-
operators::AbstractOperatorEnum;
73+
operators::AbstractOperatorEnum,
74+
::Type{N}=Node;
7075
variable_names::Union{Array{String,1},Nothing}=nothing,
7176
# Deprecated:
7277
varMap=nothing,
73-
)
78+
) where {N<:AbstractExpressionNode}
7479
variable_names = deprecate_varmap(variable_names, varMap, :split_eq)
7580
!(op (sum, prod, +, *)) && throw(error("Unsupported operation $op in expression!"))
7681
if Symbol(op) == Symbol(sum)
@@ -80,10 +85,10 @@ function split_eq(
8085
else
8186
ind = findoperation(op, operators.binops)
8287
end
83-
return Node(
88+
return constructorof(N)(
8489
ind,
85-
convert(Node, args[1], operators; variable_names=variable_names),
86-
convert(Node, op(args[2:end]...), operators; variable_names=variable_names),
90+
convert(N, args[1], operators; variable_names=variable_names),
91+
convert(N, op(args[2:end]...), operators; variable_names=variable_names),
8792
)
8893
end
8994

@@ -96,7 +101,7 @@ end
96101

97102
function Base.convert(
98103
::typeof(SymbolicUtils.Symbolic),
99-
tree::Node,
104+
tree::AbstractExpressionNode,
100105
operators::AbstractOperatorEnum;
101106
variable_names::Union{Array{String,1},Nothing}=nothing,
102107
index_functions::Bool=false,
@@ -109,20 +114,22 @@ function Base.convert(
109114
)
110115
end
111116

112-
function Base.convert(::typeof(Node), x::Number, operators::AbstractOperatorEnum; kws...)
113-
return Node(; val=DEFAULT_NODE_TYPE(x))
117+
function Base.convert(
118+
::Type{N}, x::Number, operators::AbstractOperatorEnum; kws...
119+
) where {N<:AbstractExpressionNode}
120+
return constructorof(N)(; val=DEFAULT_NODE_TYPE(x))
114121
end
115122

116123
function Base.convert(
117-
::typeof(Node),
124+
::Type{N},
118125
expr::SymbolicUtils.Symbolic,
119126
operators::AbstractOperatorEnum;
120127
variable_names::Union{Array{String,1},Nothing}=nothing,
121-
)
128+
) where {N<:AbstractExpressionNode}
122129
variable_names = deprecate_varmap(variable_names, nothing, :convert)
123130
if !SymbolicUtils.istree(expr)
124-
variable_names === nothing && return Node(String(expr.name))
125-
return Node(String(expr.name), variable_names)
131+
variable_names === nothing && return constructorof(N)(String(expr.name))
132+
return constructorof(N)(String(expr.name), variable_names)
126133
end
127134

128135
# First, we remove integer powers:
@@ -134,20 +141,21 @@ function Base.convert(
134141
op = convert_to_function(SymbolicUtils.operation(expr), operators)
135142
args = SymbolicUtils.arguments(expr)
136143

137-
length(args) > 2 && return split_eq(op, args, operators; variable_names=variable_names)
144+
length(args) > 2 &&
145+
return split_eq(op, args, operators, N; variable_names=variable_names)
138146
ind = if length(args) == 2
139147
findoperation(op, operators.binops)
140148
else
141149
findoperation(op, operators.unaops)
142150
end
143151

144-
return Node(
145-
ind, map(x -> convert(Node, x, operators; variable_names=variable_names), args)...
152+
return constructorof(N)(
153+
ind, map(x -> convert(N, x, operators; variable_names=variable_names), args)...
146154
)
147155
end
148156

149157
"""
150-
node_to_symbolic(tree::Node, operators::AbstractOperatorEnum;
158+
node_to_symbolic(tree::AbstractExpressionNode, operators::AbstractOperatorEnum;
151159
variable_names::Union{Array{String, 1}, Nothing}=nothing,
152160
index_functions::Bool=false)
153161
@@ -156,17 +164,17 @@ will generate a symbolic equation in SymbolicUtils.jl format.
156164
157165
## Arguments
158166
159-
- `tree::Node`: The equation to convert.
167+
- `tree::AbstractExpressionNode`: The equation to convert.
160168
- `operators::AbstractOperatorEnum`: OperatorEnum, which contains the operators used in the equation.
161169
- `variable_names::Union{Array{String, 1}, Nothing}=nothing`: What variable names to use for
162170
each feature. Default is [x1, x2, x3, ...].
163171
- `index_functions::Bool=false`: Whether to generate special names for the
164-
operators, which then allows one to convert back to a `Node` format
172+
operators, which then allows one to convert back to a `AbstractExpressionNode` format
165173
using `symbolic_to_node`.
166174
(CURRENTLY UNAVAILABLE - See https://github.com/MilesCranmer/SymbolicRegression.jl/pull/84).
167175
"""
168176
function node_to_symbolic(
169-
tree::Node,
177+
tree::AbstractExpressionNode,
170178
operators::AbstractOperatorEnum;
171179
variable_names::Union{Array{String,1},Nothing}=nothing,
172180
index_functions::Bool=false,
@@ -192,13 +200,14 @@ end
192200

193201
function symbolic_to_node(
194202
eqn::SymbolicUtils.Symbolic,
195-
operators::AbstractOperatorEnum;
203+
operators::AbstractOperatorEnum,
204+
::Type{N}=Node;
196205
variable_names::Union{Array{String,1},Nothing}=nothing,
197206
# Deprecated:
198207
varMap=nothing,
199-
)::Node
208+
) where {N<:AbstractExpressionNode}
200209
variable_names = deprecate_varmap(variable_names, varMap, :symbolic_to_node)
201-
return convert(Node, eqn, operators; variable_names=variable_names)
210+
return convert(N, eqn, operators; variable_names=variable_names)
202211
end
203212

204213
function multiply_powers(eqn::Number)::Tuple{SYMBOLIC_UTILS_TYPES,Bool}

src/DynamicExpressions.jl

+4-1
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,16 @@ import PackageExtensionCompat: @require_extensions
1515
import Reexport: @reexport
1616
@reexport import .EquationModule:
1717
AbstractNode,
18+
AbstractExpressionNode,
19+
GraphNode,
1820
Node,
1921
string_tree,
2022
print_tree,
2123
copy_node,
2224
set_node!,
2325
tree_mapreduce,
2426
filter_map
27+
import .EquationModule: constructorof, preserve_sharing
2528
@reexport import .EquationUtilsModule:
2629
count_nodes,
2730
count_constants,
@@ -38,7 +41,7 @@ import Reexport: @reexport
3841
@reexport import .EvaluateEquationModule: eval_tree_array, differentiable_eval_tree_array
3942
@reexport import .EvaluateEquationDerivativeModule:
4043
eval_diff_tree_array, eval_grad_tree_array
41-
@reexport import .SimplifyEquationModule: combine_operators, simplify_tree
44+
@reexport import .SimplifyEquationModule: combine_operators, simplify_tree!
4245
@reexport import .EvaluationHelpersModule
4346
@reexport import .ExtensionInterfaceModule: node_to_symbolic, symbolic_to_node
4447

0 commit comments

Comments
 (0)