Skip to content

Commit 1f1ad6c

Browse files
authored
Merge pull request #61 from SymbolicML:bump-alloc3
Bump allocator version of expression evaluation
2 parents 17f04ad + 15844ad commit 1f1ad6c

18 files changed

+562
-175
lines changed

Project.toml

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ version = "0.15.0"
66
[deps]
77
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
88
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
9-
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
109
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
1110
PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
1211
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
@@ -16,17 +15,22 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1615
TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
1716

1817
[weakdeps]
18+
Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e"
19+
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
1920
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
2021
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
2122
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2223

2324
[extensions]
25+
DynamicExpressionsBumperExt = "Bumper"
26+
DynamicExpressionsLoopVectorizationExt = "LoopVectorization"
2427
DynamicExpressionsOptimExt = "Optim"
2528
DynamicExpressionsSymbolicUtilsExt = "SymbolicUtils"
2629
DynamicExpressionsZygoteExt = "Zygote"
2730

2831
[compat]
2932
Aqua = "0.7"
33+
Bumper = "0.6"
3034
Compat = "3.37, 4"
3135
Enzyme = "^0.11.12"
3236
LoopVectorization = "0.12"
@@ -41,8 +45,10 @@ julia = "1.6"
4145

4246
[extras]
4347
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
48+
Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e"
4449
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
4550
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
51+
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
4652
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
4753
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
4854
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
@@ -52,4 +58,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
5258
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
5359

5460
[targets]
55-
test = ["Test", "SafeTestsets", "Aqua", "Enzyme", "Optim", "ForwardDiff", "SpecialFunctions", "StaticArrays", "SymbolicUtils", "Zygote"]
61+
test = ["Test", "SafeTestsets", "Aqua", "Bumper", "Enzyme", "ForwardDiff", "LoopVectorization", "Optim", "SpecialFunctions", "StaticArrays", "SymbolicUtils", "Zygote"]

benchmark/Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
[deps]
22
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
3+
Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e"
4+
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
35
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
6+
StrideArrays = "d1fa6d79-ef01-42a6-86c9-f7c551f8593b"
47
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

benchmark/benchmarks.jl

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
using DynamicExpressions, BenchmarkTools, Random
22
using DynamicExpressions.EquationUtilsModule: is_constant
3+
4+
# Trigger extensions:
5+
using LoopVectorization
6+
using Bumper
7+
using StrideArrays
38
using Zygote
9+
410
if PACKAGE_VERSION < v"0.14.0"
511
@eval using DynamicExpressions: Node as GraphNode
612
else
@@ -27,27 +33,46 @@ function benchmark_evaluation()
2733
n = 1_000
2834

2935
#! format: off
30-
for turbo in (false, true)
31-
if turbo && !(T in (Float32, Float64))
32-
continue
36+
for turbo in (false, true), bumper in (false, true)
37+
38+
(turbo || bumper) && !(T in (Float32, Float64)) && continue
39+
if bumper
40+
try
41+
eval_tree_array(Node{T}(val=1.0), ones(T, 5, n), operators; turbo, bumper)
42+
catch e
43+
isa(e, MethodError) || rethrow(e)
44+
@warn "Skipping bumper tests"
45+
continue # Assume its not available
46+
end
47+
end
48+
49+
extra_key = if turbo && bumper
50+
"_turbo_bumper"
51+
elseif turbo
52+
"_turbo"
53+
elseif bumper
54+
"_bumper"
55+
else
56+
""
3357
end
34-
extra_key = turbo ? "_turbo" : ""
58+
extra_kws = bumper ? (; bumper=Val(true)) : ()
3559
eval_tree_array(
3660
gen_random_tree_fixed_size(20, operators, 5, T),
3761
randn(MersenneTwister(0), T, 5, n),
3862
operators;
39-
turbo=turbo
63+
turbo,
64+
extra_kws...
4065
)
4166
suite[T]["evaluation$(extra_key)"] = @benchmarkable(
42-
[eval_tree_array(tree, X, $operators; turbo=$turbo) for tree in trees],
67+
[eval_tree_array(tree, X, $operators; turbo=$turbo, $extra_kws...) for tree in trees],
4368
setup=(
4469
X=randn(MersenneTwister(0), $T, 5, $n);
4570
treesize=20;
4671
ntrees=100;
4772
trees=[gen_random_tree_fixed_size(treesize, $operators, 5, $T) for _ in 1:ntrees]
4873
)
4974
)
50-
if T <: Real
75+
if T <: Real && !bumper
5176
eval_grad_tree_array(
5277
gen_random_tree_fixed_size(20, operators, 5, T),
5378
randn(MersenneTwister(0), T, 5, n),

docs/src/eval.md

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,25 +13,23 @@ Assuming you are only using a single `OperatorEnum`, you can also use
1313
the following shorthand by using the expression as a function:
1414

1515
```
16-
(tree::Node)(X::AbstractMatrix, operators::GenericOperatorEnum; throw_errors::Bool=true)
16+
(tree::AbstractExpressionNode)(X::AbstractMatrix{T}, operators::OperatorEnum; turbo::Union{Bool,Val}=false, bumper::Union{Bool,Val}=Val(false))
17+
18+
Evaluate a binary tree (equation) over a given input data matrix. The
19+
operators contain all of the operators used. This function fuses doublets
20+
and triplets of operations for lower memory usage.
1721
1822
# Arguments
19-
- `X::AbstractArray`: The input data to evaluate the tree on.
20-
- `operators::GenericOperatorEnum`: The operators used in the tree.
21-
- `throw_errors::Bool=true`: Whether to throw errors
22-
if they occur during evaluation. Otherwise,
23-
MethodErrors will be caught before they happen and
24-
evaluation will return `nothing`,
25-
rather than throwing an error. This is useful in cases
26-
where you are unsure if a particular tree is valid or not,
27-
and would prefer to work with `nothing` as an output.
23+
- `tree::AbstractExpressionNode`: The root node of the tree to evaluate.
24+
- `cX::AbstractMatrix{T}`: The input data to evaluate the tree on.
25+
- `operators::OperatorEnum`: The operators used in the tree.
26+
- `turbo::Union{Bool,Val}`: Use LoopVectorization.jl for faster evaluation.
27+
- `bumper::Union{Bool,Val}`: Use Bumper.jl for faster evaluation.
2828
2929
# Returns
30-
- `output`: the result of the evaluation.
31-
If evaluation failed, `nothing` will be returned for the first argument.
32-
A `false` complete means an operator was called on input types
33-
that it was not defined for. You can change this behavior by
34-
setting `throw_errors=false`.
30+
- `output::AbstractVector{T}`: the result, which is a 1D array.
31+
Any NaN, Inf, or other failure during the evaluation will result in the entire
32+
output array being set to NaN.
3533
```
3634

3735
For example,
@@ -98,7 +96,7 @@ all variables (or, all constants). Both use forward-mode automatic, but use
9896

9997
```@docs
10098
eval_diff_tree_array(tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum, direction::Integer) where {T<:Number}
101-
eval_grad_tree_array(tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum; turbo::Bool=false, variable::Bool=false) where {T<:Number}
99+
eval_grad_tree_array(tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum) where {T<:Number}
102100
```
103101

104102
You can compute gradients this with shorthand notation as well (which by default computes

ext/DynamicExpressionsBumperExt.jl

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
module DynamicExpressionsBumperExt
2+
3+
using Bumper: @no_escape, @alloc
4+
using DynamicExpressions: OperatorEnum, AbstractExpressionNode, tree_mapreduce
5+
using DynamicExpressions.UtilsModule: ResultOk, counttuple, is_bad_array
6+
7+
import DynamicExpressions.ExtensionInterfaceModule:
8+
bumper_eval_tree_array, bumper_kern1!, bumper_kern2!
9+
10+
function bumper_eval_tree_array(
11+
tree::AbstractExpressionNode{T},
12+
cX::AbstractMatrix{T},
13+
operators::OperatorEnum,
14+
::Val{turbo},
15+
) where {T,turbo}
16+
result = similar(cX, axes(cX, 2))
17+
n = size(cX, 2)
18+
all_ok = Ref(false)
19+
@no_escape begin
20+
_result_ok = tree_mapreduce(
21+
# Leaf nodes, we create an allocation and fill
22+
# it with the value of the leaf:
23+
leaf_node -> begin
24+
ar = @alloc(T, n)
25+
ok = if leaf_node.constant
26+
v = leaf_node.val::T
27+
ar .= v
28+
isfinite(v)
29+
else
30+
ar .= view(cX, leaf_node.feature, :)
31+
true
32+
end
33+
ResultOk(ar, ok)
34+
end,
35+
# Branch nodes, we simply pass them to the evaluation kernel:
36+
branch_node -> branch_node,
37+
# In the evaluation kernel, we combine the branch nodes
38+
# with the arrays created by the leaf nodes:
39+
((args::Vararg{Any,M}) where {M}) ->
40+
dispatch_kerns!(operators, args..., Val(turbo)),
41+
tree;
42+
break_sharing=Val(true),
43+
)
44+
x = _result_ok.x
45+
result .= x
46+
all_ok[] = _result_ok.ok
47+
end
48+
return (result, all_ok[])
49+
end
50+
51+
function dispatch_kerns!(operators, branch_node, cumulator, ::Val{turbo}) where {turbo}
52+
cumulator.ok || return cumulator
53+
54+
out = dispatch_kern1!(operators.unaops, branch_node.op, cumulator.x, Val(turbo))
55+
return ResultOk(out, !is_bad_array(out))
56+
end
57+
function dispatch_kerns!(
58+
operators, branch_node, cumulator1, cumulator2, ::Val{turbo}
59+
) where {turbo}
60+
cumulator1.ok || return cumulator1
61+
cumulator2.ok || return cumulator2
62+
63+
out = dispatch_kern2!(
64+
operators.binops, branch_node.op, cumulator1.x, cumulator2.x, Val(turbo)
65+
)
66+
return ResultOk(out, !is_bad_array(out))
67+
end
68+
69+
@generated function dispatch_kern1!(unaops, op_idx, cumulator, ::Val{turbo}) where {turbo}
70+
nuna = counttuple(unaops)
71+
quote
72+
Base.@nif(
73+
$nuna,
74+
i -> i == op_idx,
75+
i -> let op = unaops[i]
76+
return bumper_kern1!(op, cumulator, Val(turbo))
77+
end,
78+
)
79+
end
80+
end
81+
@generated function dispatch_kern2!(
82+
binops, op_idx, cumulator1, cumulator2, ::Val{turbo}
83+
) where {turbo}
84+
nbin = counttuple(binops)
85+
quote
86+
Base.@nif(
87+
$nbin,
88+
i -> i == op_idx,
89+
i -> let op = binops[i]
90+
return bumper_kern2!(op, cumulator1, cumulator2, Val(turbo))
91+
end,
92+
)
93+
end
94+
end
95+
function bumper_kern1!(op::F, cumulator, ::Val{false}) where {F}
96+
@. cumulator = op(cumulator)
97+
return cumulator
98+
end
99+
function bumper_kern2!(op::F, cumulator1, cumulator2, ::Val{false}) where {F}
100+
@. cumulator1 = op(cumulator1, cumulator2)
101+
return cumulator1
102+
end
103+
104+
end

0 commit comments

Comments
 (0)