Skip to content

Commit 1bc64c2

Browse files
authored
Merge pull request #10 from SymbolicML/cleaner-barriers
Reduce specialization in evaluation methods
2 parents 7334f79 + 9acf199 commit 1bc64c2

9 files changed

+251
-220
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DynamicExpressions"
22
uuid = "a40a106e-89c9-4ca8-8020-a735e8728b6b"
33
authors = ["MilesCranmer <miles.cranmer@gmail.com>"]
4-
version = "0.4.0"
4+
version = "0.4.1"
55

66
[deps]
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

src/DynamicExpressions.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,4 +39,8 @@ const PACKAGE_VERSION = let
3939
VersionNumber(project["version"])
4040
end
4141

42+
macro ignore(args...) end
43+
# To get LanguageServer to register library within tests
44+
@ignore include("../test/runtests.jl")
45+
4246
end

src/EvaluateEquation.jl

Lines changed: 78 additions & 123 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module EvaluateEquationModule
33
import LoopVectorization: @turbo, indices
44
import ..EquationModule: Node, string_tree
55
import ..OperatorEnumModule: OperatorEnum, GenericOperatorEnum
6-
import ..UtilsModule: @return_on_false, @maybe_turbo, is_bad_array, vals
6+
import ..UtilsModule: @return_on_false, @maybe_turbo, is_bad_array
77
import ..EquationUtilsModule: is_constant
88

99
macro return_on_check(val, T, n)
@@ -28,7 +28,7 @@ macro return_on_nonfinite_array(array, T, n)
2828
end
2929

3030
"""
31-
eval_tree_array(tree::Node, cX::AbstractMatrix{T}, operators::OperatorEnum; turbo::Bool)
31+
eval_tree_array(tree::Node, cX::AbstractMatrix{T}, operators::OperatorEnum; turbo::Bool=false)
3232
3333
Evaluate a binary tree (equation) over a given input data matrix. The
3434
operators contain all of the operators used. This function fuses doublets
@@ -88,6 +88,7 @@ end
8888
function _eval_tree_array(
8989
tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum, ::Val{turbo}
9090
)::Tuple{AbstractVector{T},Bool} where {T<:Real,turbo}
91+
n = size(cX, 2)
9192
# First, we see if there are only constants in the tree - meaning
9293
# we can just return the constant result.
9394
if tree.degree == 0
@@ -98,104 +99,88 @@ function _eval_tree_array(
9899
!flag && return Array{T,1}(undef, size(cX, 2)), false
99100
return fill(result, size(cX, 2)), true
100101
elseif tree.degree == 1
102+
op = operators.unaops[tree.op]
101103
if tree.l.degree == 2 && tree.l.l.degree == 0 && tree.l.r.degree == 0
102104
# op(op2(x, y)), where x, y, z are constants or variables.
103-
return deg1_l2_ll0_lr0_eval(
104-
tree, cX, vals[tree.op], vals[tree.l.op], operators, Val(turbo)
105-
)
105+
op_l = operators.binops[tree.l.op]
106+
return deg1_l2_ll0_lr0_eval(tree, cX, op, op_l, Val(turbo))
106107
elseif tree.l.degree == 1 && tree.l.l.degree == 0
107108
# op(op2(x)), where x is a constant or variable.
108-
return deg1_l1_ll0_eval(
109-
tree, cX, vals[tree.op], vals[tree.l.op], operators, Val(turbo)
110-
)
111-
else
112-
# op(x), for any x.
113-
return deg1_eval(tree, cX, vals[tree.op], operators, Val(turbo))
109+
op_l = operators.unaops[tree.l.op]
110+
return deg1_l1_ll0_eval(tree, cX, op, op_l, Val(turbo))
114111
end
112+
113+
# op(x), for any x.
114+
(cumulator, complete) = _eval_tree_array(tree.l, cX, operators, Val(turbo))
115+
@return_on_false complete cumulator
116+
@return_on_nonfinite_array cumulator T n
117+
return deg1_eval(cumulator, op, Val(turbo))
118+
115119
elseif tree.degree == 2
120+
op = operators.binops[tree.op]
116121
# TODO - add op(op2(x, y), z) and op(x, op2(y, z))
122+
# op(x, y), where x, y are constants or variables.
117123
if tree.l.degree == 0 && tree.r.degree == 0
118-
# op(x, y), where x, y are constants or variables.
119-
return deg2_l0_r0_eval(tree, cX, vals[tree.op], operators, Val(turbo))
120-
elseif tree.l.degree == 0
121-
# op(x, y), where x is a constant or variable but y is not.
122-
return deg2_l0_eval(tree, cX, vals[tree.op], operators, Val(turbo))
124+
return deg2_l0_r0_eval(tree, cX, op, Val(turbo))
123125
elseif tree.r.degree == 0
126+
(cumulator_l, complete) = _eval_tree_array(tree.l, cX, operators, Val(turbo))
127+
@return_on_false complete cumulator_l
128+
@return_on_nonfinite_array cumulator_l T n
124129
# op(x, y), where y is a constant or variable but x is not.
125-
return deg2_r0_eval(tree, cX, vals[tree.op], operators, Val(turbo))
126-
else
127-
# op(x, y), for any x or y
128-
return deg2_eval(tree, cX, vals[tree.op], operators, Val(turbo))
130+
return deg2_r0_eval(tree, cumulator_l, cX, op, Val(turbo))
131+
elseif tree.l.degree == 0
132+
(cumulator_r, complete) = _eval_tree_array(tree.r, cX, operators, Val(turbo))
133+
@return_on_false complete cumulator_r
134+
@return_on_nonfinite_array cumulator_r T n
135+
# op(x, y), where x is a constant or variable but y is not.
136+
return deg2_l0_eval(tree, cumulator_r, cX, op, Val(turbo))
129137
end
138+
(cumulator_l, complete) = _eval_tree_array(tree.l, cX, operators, Val(turbo))
139+
@return_on_false complete cumulator_l
140+
@return_on_nonfinite_array cumulator_l T n
141+
(cumulator_r, complete) = _eval_tree_array(tree.r, cX, operators, Val(turbo))
142+
@return_on_false complete cumulator_r
143+
@return_on_nonfinite_array cumulator_r T n
144+
# op(x, y), for any x or y
145+
return deg2_eval(cumulator_l, cumulator_r, op, Val(turbo))
130146
end
131147
end
132148

133149
function deg2_eval(
134-
tree::Node{T},
135-
cX::AbstractMatrix{T},
136-
::Val{op_idx},
137-
operators::OperatorEnum,
138-
::Val{turbo},
139-
)::Tuple{AbstractVector{T},Bool} where {T<:Real,op_idx,turbo}
140-
n = size(cX, 2)
141-
(cumulator, complete) = _eval_tree_array(tree.l, cX, operators, Val(turbo))
142-
@return_on_false complete cumulator
143-
@return_on_nonfinite_array cumulator T n
144-
(array2, complete2) = _eval_tree_array(tree.r, cX, operators, Val(turbo))
145-
@return_on_false complete2 cumulator
146-
@return_on_nonfinite_array array2 T n
147-
op = operators.binops[op_idx]
148-
149-
# We check inputs (and intermediates), not outputs.
150-
@maybe_turbo turbo for j in indices(cumulator)
151-
x = op(cumulator[j], array2[j])::T
152-
cumulator[j] = x
150+
cumulator_l::AbstractVector{T}, cumulator_r::AbstractVector{T}, op::F, ::Val{turbo}
151+
)::Tuple{AbstractVector{T},Bool} where {T<:Real,F,turbo}
152+
@maybe_turbo turbo for j in indices(cumulator_l)
153+
x = op(cumulator_l[j], cumulator_r[j])::T
154+
cumulator_l[j] = x
153155
end
154-
# return (cumulator, finished_loop) #
155-
return (cumulator, true)
156+
return (cumulator_l, true)
156157
end
157158

158159
function deg1_eval(
159-
tree::Node{T},
160-
cX::AbstractMatrix{T},
161-
::Val{op_idx},
162-
operators::OperatorEnum,
163-
::Val{turbo},
164-
)::Tuple{AbstractVector{T},Bool} where {T<:Real,op_idx,turbo}
165-
n = size(cX, 2)
166-
(cumulator, complete) = _eval_tree_array(tree.l, cX, operators, Val(turbo))
167-
@return_on_false complete cumulator
168-
@return_on_nonfinite_array cumulator T n
169-
op = operators.unaops[op_idx]
160+
cumulator::AbstractVector{T}, op::F, ::Val{turbo}
161+
)::Tuple{AbstractVector{T},Bool} where {T<:Real,F,turbo}
170162
@maybe_turbo turbo for j in indices(cumulator)
171163
x = op(cumulator[j])::T
172164
cumulator[j] = x
173165
end
174-
return (cumulator, true) #
166+
return (cumulator, true)
175167
end
176168

177169
function deg0_eval(
178170
tree::Node{T}, cX::AbstractMatrix{T}
179171
)::Tuple{AbstractVector{T},Bool} where {T<:Real}
180-
n = size(cX, 2)
181172
if tree.constant
173+
n = size(cX, 2)
182174
return (fill(tree.val::T, n), true)
183175
else
184176
return (cX[tree.feature, :], true)
185177
end
186178
end
187179

188180
function deg1_l2_ll0_lr0_eval(
189-
tree::Node{T},
190-
cX::AbstractMatrix{T},
191-
::Val{op_idx},
192-
::Val{op_l_idx},
193-
operators::OperatorEnum,
194-
::Val{turbo},
195-
)::Tuple{AbstractVector{T},Bool} where {T<:Real,op_idx,op_l_idx,turbo}
181+
tree::Node{T}, cX::AbstractMatrix{T}, op::F, op_l::F2, ::Val{turbo}
182+
)::Tuple{AbstractVector{T},Bool} where {T<:Real,F,F2,turbo}
196183
n = size(cX, 2)
197-
op = operators.unaops[op_idx]
198-
op_l = operators.binops[op_l_idx]
199184
if tree.l.l.constant && tree.l.r.constant
200185
val_ll = tree.l.l.val::T
201186
val_lr = tree.l.r.val::T
@@ -243,16 +228,9 @@ end
243228

244229
# op(op2(x)) for x variable or constant
245230
function deg1_l1_ll0_eval(
246-
tree::Node{T},
247-
cX::AbstractMatrix{T},
248-
::Val{op_idx},
249-
::Val{op_l_idx},
250-
operators::OperatorEnum,
251-
::Val{turbo},
252-
)::Tuple{AbstractVector{T},Bool} where {T<:Real,op_idx,op_l_idx,turbo}
231+
tree::Node{T}, cX::AbstractMatrix{T}, op::F, op_l::F2, ::Val{turbo}
232+
)::Tuple{AbstractVector{T},Bool} where {T<:Real,F,F2,turbo}
253233
n = size(cX, 2)
254-
op = operators.unaops[op_idx]
255-
op_l = operators.unaops[op_l_idx]
256234
if tree.l.l.constant
257235
val_ll = tree.l.l.val::T
258236
@return_on_check val_ll T n
@@ -275,14 +253,9 @@ end
275253

276254
# op(x, y) for x and y variable/constant
277255
function deg2_l0_r0_eval(
278-
tree::Node{T},
279-
cX::AbstractMatrix{T},
280-
::Val{op_idx},
281-
operators::OperatorEnum,
282-
::Val{turbo},
283-
)::Tuple{AbstractVector{T},Bool} where {T<:Real,op_idx,turbo}
256+
tree::Node{T}, cX::AbstractMatrix{T}, op::F, ::Val{turbo}
257+
)::Tuple{AbstractVector{T},Bool} where {T<:Real,F,turbo}
284258
n = size(cX, 2)
285-
op = operators.binops[op_idx]
286259
if tree.l.constant && tree.r.constant
287260
val_l = tree.l.val::T
288261
@return_on_check val_l T n
@@ -323,17 +296,9 @@ end
323296

324297
# op(x, y) for x variable/constant, y arbitrary
325298
function deg2_l0_eval(
326-
tree::Node{T},
327-
cX::AbstractMatrix{T},
328-
::Val{op_idx},
329-
operators::OperatorEnum,
330-
::Val{turbo},
331-
)::Tuple{AbstractVector{T},Bool} where {T<:Real,op_idx,turbo}
299+
tree::Node{T}, cumulator::AbstractVector{T}, cX::AbstractArray{T}, op::F, ::Val{turbo}
300+
)::Tuple{AbstractVector{T},Bool} where {T<:Real,F,turbo}
332301
n = size(cX, 2)
333-
(cumulator, complete) = _eval_tree_array(tree.r, cX, operators, Val(turbo))
334-
@return_on_false complete cumulator
335-
@return_on_nonfinite_array cumulator T n
336-
op = operators.binops[op_idx]
337302
if tree.l.constant
338303
val = tree.l.val::T
339304
@return_on_check val T n
@@ -353,17 +318,9 @@ end
353318

354319
# op(x, y) for x arbitrary, y variable/constant
355320
function deg2_r0_eval(
356-
tree::Node{T},
357-
cX::AbstractMatrix{T},
358-
::Val{op_idx},
359-
operators::OperatorEnum,
360-
::Val{turbo},
361-
)::Tuple{AbstractVector{T},Bool} where {T<:Real,op_idx,turbo}
321+
tree::Node{T}, cumulator::AbstractVector{T}, cX::AbstractArray{T}, op::F, ::Val{turbo}
322+
)::Tuple{AbstractVector{T},Bool} where {T<:Real,F,turbo}
362323
n = size(cX, 2)
363-
(cumulator, complete) = _eval_tree_array(tree.l, cX, operators, Val(turbo))
364-
@return_on_false complete cumulator
365-
@return_on_nonfinite_array cumulator T n
366-
op = operators.binops[op_idx]
367324
if tree.r.constant
368325
val = tree.r.val::T
369326
@return_on_check val T n
@@ -394,9 +351,9 @@ function _eval_constant_tree(
394351
if tree.degree == 0
395352
return deg0_eval_constant(tree)
396353
elseif tree.degree == 1
397-
return deg1_eval_constant(tree, vals[tree.op], operators)
354+
return deg1_eval_constant(tree, operators.unaops[tree.op], operators)
398355
else
399-
return deg2_eval_constant(tree, vals[tree.op], operators)
356+
return deg2_eval_constant(tree, operators.binops[tree.op], operators)
400357
end
401358
end
402359

@@ -405,19 +362,17 @@ end
405362
end
406363

407364
function deg1_eval_constant(
408-
tree::Node{T}, ::Val{op_idx}, operators::OperatorEnum
409-
)::Tuple{T,Bool} where {T<:Real,op_idx}
410-
op = operators.unaops[op_idx]
365+
tree::Node{T}, op::F, operators::OperatorEnum
366+
)::Tuple{T,Bool} where {T<:Real,F}
411367
(cumulator, complete) = _eval_constant_tree(tree.l, operators)
412368
!complete && return zero(T), false
413369
output = op(cumulator)::T
414370
return output, isfinite(output)
415371
end
416372

417373
function deg2_eval_constant(
418-
tree::Node{T}, ::Val{op_idx}, operators::OperatorEnum
419-
)::Tuple{T,Bool} where {T<:Real,op_idx}
420-
op = operators.binops[op_idx]
374+
tree::Node{T}, op::F, operators::OperatorEnum
375+
)::Tuple{T,Bool} where {T<:Real,F}
421376
(cumulator, complete) = _eval_constant_tree(tree.l, operators)
422377
!complete && return zero(T), false
423378
(cumulator2, complete2) = _eval_constant_tree(tree.r, operators)
@@ -442,31 +397,29 @@ function differentiable_eval_tree_array(
442397
return (cX[tree.feature, :], true)
443398
end
444399
elseif tree.degree == 1
445-
return deg1_diff_eval(tree, cX, vals[tree.op], operators)
400+
return deg1_diff_eval(tree, cX, operators.unaops[tree.op], operators)
446401
else
447-
return deg2_diff_eval(tree, cX, vals[tree.op], operators)
402+
return deg2_diff_eval(tree, cX, operators.binops[tree.op], operators)
448403
end
449404
end
450405

451406
function deg1_diff_eval(
452-
tree::Node{T1}, cX::AbstractMatrix{T}, ::Val{op_idx}, operators::OperatorEnum
453-
)::Tuple{AbstractVector{T},Bool} where {T<:Real,op_idx,T1}
407+
tree::Node{T1}, cX::AbstractMatrix{T}, op::F, operators::OperatorEnum
408+
)::Tuple{AbstractVector{T},Bool} where {T<:Real,F,T1}
454409
(left, complete) = differentiable_eval_tree_array(tree.l, cX, operators)
455410
@return_on_false complete left
456-
op = operators.unaops[op_idx]
457411
out = op.(left)
458412
no_nans = !any(x -> (!isfinite(x)), out)
459413
return (out, no_nans)
460414
end
461415

462416
function deg2_diff_eval(
463-
tree::Node{T1}, cX::AbstractMatrix{T}, ::Val{op_idx}, operators::OperatorEnum
464-
)::Tuple{AbstractVector{T},Bool} where {T<:Real,op_idx,T1}
417+
tree::Node{T1}, cX::AbstractMatrix{T}, op::F, operators::OperatorEnum
418+
)::Tuple{AbstractVector{T},Bool} where {T<:Real,F,T1}
465419
(left, complete) = differentiable_eval_tree_array(tree.l, cX, operators)
466420
@return_on_false complete left
467421
(right, complete2) = differentiable_eval_tree_array(tree.r, cX, operators)
468422
@return_on_false complete2 left
469-
op = operators.binops[op_idx]
470423
out = op.(left, right)
471424
no_nans = !any(x -> (!isfinite(x)), out)
472425
return (out, no_nans)
@@ -557,30 +510,32 @@ function _eval_tree_array_generic(
557510
end
558511
end
559512
elseif tree.degree == 1
560-
return deg1_eval_generic(tree, cX, vals[tree.op], operators, Val(throw_errors))
513+
return deg1_eval_generic(
514+
tree, cX, operators.unaops[tree.op], operators, Val(throw_errors)
515+
)
561516
else
562-
return deg2_eval_generic(tree, cX, vals[tree.op], operators, Val(throw_errors))
517+
return deg2_eval_generic(
518+
tree, cX, operators.binops[tree.op], operators, Val(throw_errors)
519+
)
563520
end
564521
end
565522

566523
function deg1_eval_generic(
567-
tree, cX, ::Val{op_idx}, operators::GenericOperatorEnum, ::Val{throw_errors}
568-
) where {op_idx,throw_errors}
524+
tree, cX, op::F, operators::GenericOperatorEnum, ::Val{throw_errors}
525+
) where {F,throw_errors}
569526
left, complete = eval_tree_array(tree.l, cX, operators)
570527
!throw_errors && !complete && return nothing, false
571-
op = operators.unaops[op_idx]
572528
!throw_errors && !hasmethod(op, Tuple{typeof(left)}) && return nothing, false
573529
return op(left), true
574530
end
575531

576532
function deg2_eval_generic(
577-
tree, cX, ::Val{op_idx}, operators::GenericOperatorEnum, ::Val{throw_errors}
578-
) where {op_idx,throw_errors}
533+
tree, cX, op::F, operators::GenericOperatorEnum, ::Val{throw_errors}
534+
) where {F,throw_errors}
579535
left, complete = eval_tree_array(tree.l, cX, operators)
580536
!throw_errors && !complete && return nothing, false
581537
right, complete = eval_tree_array(tree.r, cX, operators)
582538
!throw_errors && !complete && return nothing, false
583-
op = operators.binops[op_idx]
584539
!throw_errors &&
585540
!hasmethod(op, Tuple{typeof(left),typeof(right)}) &&
586541
return nothing, false

0 commit comments

Comments
 (0)