Skip to content

Commit 3349879

Browse files
authored
Merge pull request #17 from SymbolicML/cleanup-iddict
Clean up IdDict use with macro
2 parents 46dc417 + ea3e91c commit 3349879

13 files changed

+685
-419
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
name = "DynamicExpressions"
22
uuid = "a40a106e-89c9-4ca8-8020-a735e8728b6b"
33
authors = ["MilesCranmer <miles.cranmer@gmail.com>"]
4-
version = "0.6.1"
4+
version = "0.7.0"
55

66
[deps]
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
88
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
9+
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
910
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
1011
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1112
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
@@ -16,6 +17,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1617

1718
[compat]
1819
LoopVectorization = "0.12"
20+
MacroTools = "0.4, 0.5"
1921
Reexport = "1"
2022
PrecompileTools = "1"
2123
SymbolicUtils = "0.19, ^1.0.5"

benchmark/Project.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
11
[deps]
22
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
3-
DynamicExpressions = "a40a106e-89c9-4ca8-8020-a735e8728b6b"
43
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

benchmark/benchmark_utils.jl

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
import DynamicExpressions:
2+
Node, copy_node, set_node!, count_nodes, has_constants, has_operators
3+
4+
# This code is copied from SymbolicRegression.jl and modified
5+
6+
# Return a random node from the tree
7+
function random_node(tree::Node{T})::Node{T} where {T}
8+
if tree.degree == 0
9+
return tree
10+
end
11+
b = 0
12+
c = 0
13+
if tree.degree >= 1
14+
b = count_nodes(tree.l)
15+
end
16+
if tree.degree == 2
17+
c = count_nodes(tree.r)
18+
end
19+
20+
i = rand(1:(1 + b + c))
21+
if i <= b
22+
return random_node(tree.l)
23+
elseif i == b + 1
24+
return tree
25+
end
26+
27+
return random_node(tree.r)
28+
end
29+
30+
function make_random_leaf(nfeatures::Int, ::Type{T})::Node{T} where {T}
31+
if rand() > 0.5
32+
return Node(; val=randn(T))
33+
else
34+
return Node(T; feature=rand(1:nfeatures))
35+
end
36+
end
37+
38+
# Add a random unary/binary operation to the end of a tree
39+
function append_random_op(
40+
tree::Node{T}, operators, nfeatures::Int; makeNewBinOp::Union{Bool,Nothing}=nothing
41+
)::Node{T} where {T}
42+
nuna = length(operators.unaops)
43+
nbin = length(operators.binops)
44+
45+
node = random_node(tree)
46+
while node.degree != 0
47+
node = random_node(tree)
48+
end
49+
50+
if makeNewBinOp === nothing
51+
choice = rand()
52+
makeNewBinOp = choice < nbin / (nuna + nbin)
53+
end
54+
55+
if makeNewBinOp
56+
newnode = Node(
57+
rand(1:nbin), make_random_leaf(nfeatures, T), make_random_leaf(nfeatures, T)
58+
)
59+
else
60+
newnode = Node(rand(1:nuna), make_random_leaf(nfeatures, T))
61+
end
62+
63+
set_node!(node, newnode)
64+
65+
return tree
66+
end
67+
68+
function gen_random_tree_fixed_size(
69+
node_count::Int, operators, nfeatures::Int, ::Type{T}
70+
)::Node{T} where {T}
71+
tree = make_random_leaf(nfeatures, T)
72+
cur_size = count_nodes(tree)
73+
while cur_size < node_count
74+
if cur_size == node_count - 1 # only unary operator allowed.
75+
length(operators.unaops) == 0 && break # We will go over the requested amount, so we must break.
76+
tree = append_random_op(tree, operators, nfeatures; makeNewBinOp=false)
77+
else
78+
tree = append_random_op(tree, operators, nfeatures)
79+
end
80+
cur_size = count_nodes(tree)
81+
end
82+
return tree
83+
end

benchmark/benchmarks.jl

Lines changed: 107 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,71 +1,122 @@
11
using DynamicExpressions, BenchmarkTools, Random
2+
using DynamicExpressions: copy_node
23

3-
const v_PACKAGE_VERSION = try
4-
VersionNumber(PACKAGE_VERSION)
5-
catch
6-
VersionNumber("v0.0.0")
7-
end
4+
include("benchmark_utils.jl")
85

96
const SUITE = BenchmarkGroup()
107

11-
SUITE["OperatorEnum"] = BenchmarkGroup()
12-
13-
operators = OperatorEnum(;
14-
binary_operators=[+, -, /, *], unary_operators=[cos, exp], enable_autodiff=true
15-
)
16-
simple_tree = Node(
17-
2,
18-
Node(
19-
1,
20-
Node(3, Node(1, Node(; val=1.0f0), Node(; feature=2)), Node(2, Node(; val=-1.0f0))),
21-
Node(1, Node(; feature=3), Node(; feature=4)),
22-
),
23-
Node(
24-
4,
25-
Node(3, Node(1, Node(; val=1.0f0), Node(; feature=2)), Node(2, Node(; val=-1.0f0))),
26-
Node(1, Node(; feature=3), Node(; feature=4)),
27-
),
28-
)
29-
for T in (ComplexF32, ComplexF64, Float32, Float64)
30-
if !(T <: Real) && v_PACKAGE_VERSION < v"0.5.0" && v_PACKAGE_VERSION != v"0.0.0"
31-
continue
32-
end
33-
evals = 10
34-
samples = 1_000
35-
n = 1_000
36-
37-
#! format: off
38-
if !haskey(SUITE["OperatorEnum"], T)
39-
SUITE["OperatorEnum"][T] = BenchmarkGroup()
40-
end
41-
42-
for turbo in (false, true)
43-
if turbo && !(T in (Float32, Float64))
8+
function benchmark_evaluation()
9+
suite = BenchmarkGroup()
10+
operators = OperatorEnum(;
11+
binary_operators=[+, -, /, *], unary_operators=[cos, exp], enable_autodiff=true
12+
)
13+
for T in (ComplexF32, ComplexF64, Float32, Float64)
14+
if !(T <: Real) && PACKAGE_VERSION < v"0.5.0" && PACKAGE_VERSION != v"0.0.0"
4415
continue
4516
end
46-
extra_key = turbo ? "_turbo" : ""
47-
SUITE["OperatorEnum"][T]["evaluation$(extra_key)"] = @benchmarkable(
48-
eval_tree_array(tree, X, $operators; turbo=$turbo),
49-
evals=evals,
50-
samples=samples,
51-
seconds=5.0,
52-
setup=(
53-
X=randn(MersenneTwister(0), $T, 5, $n);
54-
tree=convert(Node{$T}, copy_node($simple_tree))
17+
suite[T] = BenchmarkGroup()
18+
19+
n = 1_000
20+
21+
#! format: off
22+
for turbo in (false, true)
23+
if turbo && !(T in (Float32, Float64))
24+
continue
25+
end
26+
extra_key = turbo ? "_turbo" : ""
27+
eval_tree_array(
28+
gen_random_tree_fixed_size(20, operators, 5, T),
29+
randn(MersenneTwister(0), T, 5, n),
30+
operators;
31+
turbo=turbo
5532
)
56-
)
57-
if T <: Real
58-
SUITE["OperatorEnum"][T]["derivative$(extra_key)"] = @benchmarkable(
59-
eval_grad_tree_array(tree, X, $operators; variable=true, turbo=$turbo),
60-
evals=evals,
61-
samples=samples,
62-
seconds=5.0,
33+
suite[T]["evaluation$(extra_key)"] = @benchmarkable(
34+
[eval_tree_array(tree, X, $operators; turbo=$turbo) for tree in trees],
6335
setup=(
6436
X=randn(MersenneTwister(0), $T, 5, $n);
65-
tree=convert(Node{$T}, copy_node($simple_tree))
37+
treesize=20;
38+
ntrees=100;
39+
trees=[gen_random_tree_fixed_size(treesize, $operators, 5, $T) for _ in 1:ntrees]
6640
)
6741
)
42+
if T <: Real
43+
eval_grad_tree_array(
44+
gen_random_tree_fixed_size(20, operators, 5, T),
45+
randn(MersenneTwister(0), T, 5, n),
46+
operators;
47+
variable=true,
48+
turbo=turbo
49+
)
50+
suite[T]["derivative$(extra_key)"] = @benchmarkable(
51+
[eval_grad_tree_array(tree, X, $operators; variable=true, turbo=$turbo) for tree in trees],
52+
setup=(
53+
X=randn(MersenneTwister(0), $T, 5, $n);
54+
treesize=20;
55+
ntrees=100;
56+
trees=[gen_random_tree_fixed_size(treesize, $operators, 5, $T) for _ in 1:ntrees]
57+
)
58+
)
59+
end
60+
end
61+
#! format: on
62+
end
63+
return suite
64+
end
65+
66+
# These macros make the benchmarks work on older versions:
67+
#! format: off
68+
@generated function _convert(::Type{N}, t; preserve_sharing) where {N<:Node}
69+
PACKAGE_VERSION < v"0.7.0" && return :(convert(N, t))
70+
return :(convert(N, t; preserve_sharing=preserve_sharing))
71+
end
72+
@generated function _copy_node(t; preserve_sharing)
73+
PACKAGE_VERSION < v"0.7.0" && return :(copy_node(t; preserve_topology=preserve_sharing))
74+
return :(copy_node(t; preserve_sharing=preserve_sharing))
75+
end
76+
#! format: on
77+
78+
function benchmark_utilities()
79+
suite = BenchmarkGroup()
80+
operators = OperatorEnum(; binary_operators=[+, -, /, *], unary_operators=[cos, exp])
81+
for func_k in ("copy", "convert", "simplify_tree", "combine_operators")
82+
suite[func_k] = let s = BenchmarkGroup()
83+
for k in ("break_sharing", "preserve_sharing")
84+
k == "preserve_sharing" &&
85+
func_k in ("simplify_tree", "combine_operators") &&
86+
continue
87+
88+
f = if func_k == "copy"
89+
tree -> _copy_node(tree; preserve_sharing=(k == "preserve_sharing"))
90+
elseif func_k == "convert"
91+
tree -> _convert(
92+
Node{Float64},
93+
tree;
94+
preserve_sharing=(k == "preserve_sharing"),
95+
)
96+
elseif func_k == "simplify_tree"
97+
tree -> simplify_tree(tree, operators)
98+
elseif func_k == "combine_operators"
99+
tree -> combine_operators(tree, operators)
100+
end
101+
102+
#! format: off
103+
s[k] = @benchmarkable(
104+
[$(f)(tree) for tree in trees],
105+
seconds=10.0,
106+
setup=(
107+
ntrees=100;
108+
n=20;
109+
trees=[gen_random_tree_fixed_size(n, $operators, 5, Float32) for _ in 1:ntrees]
110+
)
111+
)
112+
#! format: on
113+
end
114+
s
68115
end
69116
end
70-
#! format: on
117+
118+
return suite
71119
end
120+
121+
SUITE["eval"] = benchmark_evaluation()
122+
SUITE["utils"] = benchmark_utilities()

0 commit comments

Comments
 (0)