|
1 | 1 | using DynamicExpressions, BenchmarkTools, Random
|
| 2 | +using DynamicExpressions: copy_node |
2 | 3 |
|
3 |
| -const v_PACKAGE_VERSION = try |
4 |
| - VersionNumber(PACKAGE_VERSION) |
5 |
| -catch |
6 |
| - VersionNumber("v0.0.0") |
7 |
| -end |
| 4 | +include("benchmark_utils.jl") |
8 | 5 |
|
9 | 6 | const SUITE = BenchmarkGroup()
|
10 | 7 |
|
11 |
| -SUITE["OperatorEnum"] = BenchmarkGroup() |
12 |
| - |
13 |
| -operators = OperatorEnum(; |
14 |
| - binary_operators=[+, -, /, *], unary_operators=[cos, exp], enable_autodiff=true |
15 |
| -) |
16 |
| -simple_tree = Node( |
17 |
| - 2, |
18 |
| - Node( |
19 |
| - 1, |
20 |
| - Node(3, Node(1, Node(; val=1.0f0), Node(; feature=2)), Node(2, Node(; val=-1.0f0))), |
21 |
| - Node(1, Node(; feature=3), Node(; feature=4)), |
22 |
| - ), |
23 |
| - Node( |
24 |
| - 4, |
25 |
| - Node(3, Node(1, Node(; val=1.0f0), Node(; feature=2)), Node(2, Node(; val=-1.0f0))), |
26 |
| - Node(1, Node(; feature=3), Node(; feature=4)), |
27 |
| - ), |
28 |
| -) |
29 |
| -for T in (ComplexF32, ComplexF64, Float32, Float64) |
30 |
| - if !(T <: Real) && v_PACKAGE_VERSION < v"0.5.0" && v_PACKAGE_VERSION != v"0.0.0" |
31 |
| - continue |
32 |
| - end |
33 |
| - evals = 10 |
34 |
| - samples = 1_000 |
35 |
| - n = 1_000 |
36 |
| - |
37 |
| - #! format: off |
38 |
| - if !haskey(SUITE["OperatorEnum"], T) |
39 |
| - SUITE["OperatorEnum"][T] = BenchmarkGroup() |
40 |
| - end |
41 |
| - |
42 |
| - for turbo in (false, true) |
43 |
| - if turbo && !(T in (Float32, Float64)) |
| 8 | +function benchmark_evaluation() |
| 9 | + suite = BenchmarkGroup() |
| 10 | + operators = OperatorEnum(; |
| 11 | + binary_operators=[+, -, /, *], unary_operators=[cos, exp], enable_autodiff=true |
| 12 | + ) |
| 13 | + for T in (ComplexF32, ComplexF64, Float32, Float64) |
| 14 | + if !(T <: Real) && PACKAGE_VERSION < v"0.5.0" && PACKAGE_VERSION != v"0.0.0" |
44 | 15 | continue
|
45 | 16 | end
|
46 |
| - extra_key = turbo ? "_turbo" : "" |
47 |
| - SUITE["OperatorEnum"][T]["evaluation$(extra_key)"] = @benchmarkable( |
48 |
| - eval_tree_array(tree, X, $operators; turbo=$turbo), |
49 |
| - evals=evals, |
50 |
| - samples=samples, |
51 |
| - seconds=5.0, |
52 |
| - setup=( |
53 |
| - X=randn(MersenneTwister(0), $T, 5, $n); |
54 |
| - tree=convert(Node{$T}, copy_node($simple_tree)) |
| 17 | + suite[T] = BenchmarkGroup() |
| 18 | + |
| 19 | + n = 1_000 |
| 20 | + |
| 21 | + #! format: off |
| 22 | + for turbo in (false, true) |
| 23 | + if turbo && !(T in (Float32, Float64)) |
| 24 | + continue |
| 25 | + end |
| 26 | + extra_key = turbo ? "_turbo" : "" |
| 27 | + eval_tree_array( |
| 28 | + gen_random_tree_fixed_size(20, operators, 5, T), |
| 29 | + randn(MersenneTwister(0), T, 5, n), |
| 30 | + operators; |
| 31 | + turbo=turbo |
55 | 32 | )
|
56 |
| - ) |
57 |
| - if T <: Real |
58 |
| - SUITE["OperatorEnum"][T]["derivative$(extra_key)"] = @benchmarkable( |
59 |
| - eval_grad_tree_array(tree, X, $operators; variable=true, turbo=$turbo), |
60 |
| - evals=evals, |
61 |
| - samples=samples, |
62 |
| - seconds=5.0, |
| 33 | + suite[T]["evaluation$(extra_key)"] = @benchmarkable( |
| 34 | + [eval_tree_array(tree, X, $operators; turbo=$turbo) for tree in trees], |
63 | 35 | setup=(
|
64 | 36 | X=randn(MersenneTwister(0), $T, 5, $n);
|
65 |
| - tree=convert(Node{$T}, copy_node($simple_tree)) |
| 37 | + treesize=20; |
| 38 | + ntrees=100; |
| 39 | + trees=[gen_random_tree_fixed_size(treesize, $operators, 5, $T) for _ in 1:ntrees] |
66 | 40 | )
|
67 | 41 | )
|
| 42 | + if T <: Real |
| 43 | + eval_grad_tree_array( |
| 44 | + gen_random_tree_fixed_size(20, operators, 5, T), |
| 45 | + randn(MersenneTwister(0), T, 5, n), |
| 46 | + operators; |
| 47 | + variable=true, |
| 48 | + turbo=turbo |
| 49 | + ) |
| 50 | + suite[T]["derivative$(extra_key)"] = @benchmarkable( |
| 51 | + [eval_grad_tree_array(tree, X, $operators; variable=true, turbo=$turbo) for tree in trees], |
| 52 | + setup=( |
| 53 | + X=randn(MersenneTwister(0), $T, 5, $n); |
| 54 | + treesize=20; |
| 55 | + ntrees=100; |
| 56 | + trees=[gen_random_tree_fixed_size(treesize, $operators, 5, $T) for _ in 1:ntrees] |
| 57 | + ) |
| 58 | + ) |
| 59 | + end |
| 60 | + end |
| 61 | + #! format: on |
| 62 | + end |
| 63 | + return suite |
| 64 | +end |
| 65 | + |
| 66 | +# These macros make the benchmarks work on older versions: |
| 67 | +#! format: off |
| 68 | +@generated function _convert(::Type{N}, t; preserve_sharing) where {N<:Node} |
| 69 | + PACKAGE_VERSION < v"0.7.0" && return :(convert(N, t)) |
| 70 | + return :(convert(N, t; preserve_sharing=preserve_sharing)) |
| 71 | +end |
| 72 | +@generated function _copy_node(t; preserve_sharing) |
| 73 | + PACKAGE_VERSION < v"0.7.0" && return :(copy_node(t; preserve_topology=preserve_sharing)) |
| 74 | + return :(copy_node(t; preserve_sharing=preserve_sharing)) |
| 75 | +end |
| 76 | +#! format: on |
| 77 | + |
| 78 | +function benchmark_utilities() |
| 79 | + suite = BenchmarkGroup() |
| 80 | + operators = OperatorEnum(; binary_operators=[+, -, /, *], unary_operators=[cos, exp]) |
| 81 | + for func_k in ("copy", "convert", "simplify_tree", "combine_operators") |
| 82 | + suite[func_k] = let s = BenchmarkGroup() |
| 83 | + for k in ("break_sharing", "preserve_sharing") |
| 84 | + k == "preserve_sharing" && |
| 85 | + func_k in ("simplify_tree", "combine_operators") && |
| 86 | + continue |
| 87 | + |
| 88 | + f = if func_k == "copy" |
| 89 | + tree -> _copy_node(tree; preserve_sharing=(k == "preserve_sharing")) |
| 90 | + elseif func_k == "convert" |
| 91 | + tree -> _convert( |
| 92 | + Node{Float64}, |
| 93 | + tree; |
| 94 | + preserve_sharing=(k == "preserve_sharing"), |
| 95 | + ) |
| 96 | + elseif func_k == "simplify_tree" |
| 97 | + tree -> simplify_tree(tree, operators) |
| 98 | + elseif func_k == "combine_operators" |
| 99 | + tree -> combine_operators(tree, operators) |
| 100 | + end |
| 101 | + |
| 102 | + #! format: off |
| 103 | + s[k] = @benchmarkable( |
| 104 | + [$(f)(tree) for tree in trees], |
| 105 | + seconds=10.0, |
| 106 | + setup=( |
| 107 | + ntrees=100; |
| 108 | + n=20; |
| 109 | + trees=[gen_random_tree_fixed_size(n, $operators, 5, Float32) for _ in 1:ntrees] |
| 110 | + ) |
| 111 | + ) |
| 112 | + #! format: on |
| 113 | + end |
| 114 | + s |
68 | 115 | end
|
69 | 116 | end
|
70 |
| - #! format: on |
| 117 | + |
| 118 | + return suite |
71 | 119 | end
|
| 120 | + |
| 121 | +SUITE["eval"] = benchmark_evaluation() |
| 122 | +SUITE["utils"] = benchmark_utilities() |
0 commit comments