Skip to content

Commit 46eb0de

Browse files
authored
Merge pull request #114 from SymbolicML/allocs-functions
Create preallocation utility functions for expressions
2 parents e7955d6 + 16c5ef0 commit 46eb0de

10 files changed

+222
-93
lines changed

src/DynamicExpressions.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ using DispatchDoctor: @stable, @unstable
99
include("OperatorEnum.jl")
1010
include("Node.jl")
1111
include("NodeUtils.jl")
12+
include("NodePreallocation.jl")
1213
include("Strings.jl")
1314
include("Evaluate.jl")
1415
include("EvaluateDerivative.jl")
@@ -41,11 +42,11 @@ import .ValueInterfaceModule:
4142
GraphNode,
4243
Node,
4344
copy_node,
44-
copy_node!,
4545
set_node!,
4646
tree_mapreduce,
4747
filter_map,
4848
filter_map!
49+
import .NodePreallocationModule: allocate_container, copy_into!
4950
import .NodeModule:
5051
constructorof,
5152
with_type_parameters,

src/Expression.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import ..NodeUtilsModule:
1919
count_scalar_constants,
2020
get_scalar_constants,
2121
set_scalar_constants!
22+
import ..NodePreallocationModule: copy_into!, allocate_container
2223
import ..EvaluateModule: eval_tree_array, differentiable_eval_tree_array
2324
import ..EvaluateDerivativeModule: eval_grad_tree_array
2425
import ..EvaluationHelpersModule: _grad_evaluator
@@ -502,4 +503,23 @@ function (ex::AbstractExpression)(
502503
return get_tree(ex)(X, get_operators(ex, operators); kws...)
503504
end
504505

506+
# We don't require users to overload this, as it's not part of the required interface.
507+
# Also, there's no way to generally do this from the required interface, so for backwards
508+
# compatibility, we just return nothing.
509+
# COV_EXCL_START
510+
function copy_into!(::Nothing, src::AbstractExpression)
511+
return copy(src)
512+
end
513+
function allocate_container(::AbstractExpression, ::Union{Nothing,Integer}=nothing)
514+
return nothing
515+
end
516+
# COV_EXCL_STOP
517+
function allocate_container(prototype::Expression, n::Union{Nothing,Integer}=nothing)
518+
return (; tree=allocate_container(get_contents(prototype), n))
519+
end
520+
function copy_into!(dest::NamedTuple, src::Expression)
521+
tree = copy_into!(dest.tree, get_contents(src))
522+
return with_contents(src, tree)
523+
end
524+
505525
end

src/Interfaces.jl

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,10 @@ using ..NodeModule:
1313
default_allocator,
1414
with_type_parameters,
1515
leaf_copy,
16-
leaf_copy!,
1716
leaf_convert,
1817
leaf_hash,
1918
leaf_equal,
2019
branch_copy,
21-
branch_copy!,
2220
branch_convert,
2321
branch_hash,
2422
branch_equal,
@@ -38,6 +36,8 @@ using ..NodeUtilsModule:
3836
has_constants,
3937
get_scalar_constants,
4038
set_scalar_constants!
39+
using ..NodePreallocationModule:
40+
copy_into!, leaf_copy_into!, branch_copy_into!, allocate_container
4141
using ..StringsModule: string_tree
4242
using ..EvaluateModule: eval_tree_array
4343
using ..EvaluateDerivativeModule: eval_grad_tree_array
@@ -96,6 +96,11 @@ function _check_with_metadata(ex::AbstractExpression)
9696
end
9797

9898
## optional
99+
function _check_copy_into!(ex::AbstractExpression)
100+
container = allocate_container(ex)
101+
prealloc_ex = copy_into!(container, ex)
102+
return container !== nothing && prealloc_ex == ex && prealloc_ex !== ex
103+
end
99104
function _check_count_nodes(ex::AbstractExpression)
100105
return count_nodes(ex) isa Int64
101106
end
@@ -156,6 +161,7 @@ ei_components = (
156161
with_metadata = "returns the expression with different metadata" => _check_with_metadata,
157162
),
158163
optional = (
164+
copy_into! = "copies an expression into a preallocated container" => _check_copy_into!,
159165
count_nodes = "counts the number of nodes in the expression tree" => _check_count_nodes,
160166
count_constant_nodes = "counts the number of constant nodes in the expression tree" => _check_count_constant_nodes,
161167
count_depth = "calculates the depth of the expression tree" => _check_count_depth,
@@ -260,14 +266,19 @@ function _check_tree_mapreduce(tree::AbstractExpressionNode)
260266
end
261267

262268
## optional
269+
function _check_copy_into!(tree::AbstractExpressionNode)
270+
container = allocate_container(tree)
271+
prealloc_tree = copy_into!(container, tree)
272+
return container !== nothing && prealloc_tree == tree && prealloc_tree !== container
273+
end
263274
function _check_leaf_copy(tree::AbstractExpressionNode)
264275
tree.degree != 0 && return true
265276
return leaf_copy(tree) isa typeof(tree)
266277
end
267-
function _check_leaf_copy!(tree::AbstractExpressionNode{T}) where {T}
278+
function _check_leaf_copy_into!(tree::AbstractExpressionNode{T}) where {T}
268279
tree.degree != 0 && return true
269280
new_leaf = constructorof(typeof(tree))(; val=zero(T))
270-
ret = leaf_copy!(new_leaf, tree)
281+
ret = leaf_copy_into!(new_leaf, tree)
271282
return new_leaf == tree && ret === new_leaf
272283
end
273284
function _check_leaf_convert(tree::AbstractExpressionNode)
@@ -292,16 +303,16 @@ function _check_branch_copy(tree::AbstractExpressionNode)
292303
return branch_copy(tree, tree.l, tree.r) isa typeof(tree)
293304
end
294305
end
295-
function _check_branch_copy!(tree::AbstractExpressionNode{T}) where {T}
306+
function _check_branch_copy_into!(tree::AbstractExpressionNode{T}) where {T}
296307
if tree.degree == 0
297308
return true
298309
end
299310
new_branch = constructorof(typeof(tree))(; val=zero(T))
300311
if tree.degree == 1
301-
ret = branch_copy!(new_branch, tree, copy(tree.l))
312+
ret = branch_copy_into!(new_branch, tree, copy(tree.l))
302313
return new_branch == tree && ret === new_branch
303314
else
304-
ret = branch_copy!(new_branch, tree, copy(tree.l), copy(tree.r))
315+
ret = branch_copy_into!(new_branch, tree, copy(tree.l), copy(tree.r))
305316
return new_branch == tree && ret === new_branch
306317
end
307318
end
@@ -372,13 +383,14 @@ ni_components = (
372383
tree_mapreduce = "applies a function across the tree" => _check_tree_mapreduce
373384
),
374385
optional = (
386+
copy_into! = "copies a node into a preallocated container" => _check_copy_into!,
375387
leaf_copy = "copies a leaf node" => _check_leaf_copy,
376-
leaf_copy! = "copies a leaf node in-place" => _check_leaf_copy!,
388+
leaf_copy_into! = "copies a leaf node in-place" => _check_leaf_copy_into!,
377389
leaf_convert = "converts a leaf node" => _check_leaf_convert,
378390
leaf_hash = "computes the hash of a leaf node" => _check_leaf_hash,
379391
leaf_equal = "checks equality of two leaf nodes" => _check_leaf_equal,
380392
branch_copy = "copies a branch node" => _check_branch_copy,
381-
branch_copy! = "copies a branch node in-place" => _check_branch_copy!,
393+
branch_copy_into! = "copies a branch node in-place" => _check_branch_copy_into!,
382394
branch_convert = "converts a branch node" => _check_branch_convert,
383395
branch_hash = "computes the hash of a branch node" => _check_branch_hash,
384396
branch_equal = "checks equality of two branch nodes" => _check_branch_equal,
@@ -419,7 +431,7 @@ ni_description = (
419431
[Arguments()]
420432
)
421433
@implements(
422-
NodeInterface{all_ni_methods_except((:leaf_copy!, :branch_copy!))},
434+
NodeInterface{all_ni_methods_except(())},
423435
GraphNode,
424436
[Arguments()]
425437
)

src/Node.jl

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -321,23 +321,12 @@ function Base.promote_rule(::Type{GraphNode{T1}}, ::Type{GraphNode{T2}}) where {
321321
return GraphNode{promote_type(T1, T2)}
322322
end
323323

324-
# TODO: Verify using this helps with garbage collection
325-
create_dummy_node(::Type{N}) where {N<:AbstractExpressionNode} = N()
326-
327324
"""
328325
set_node!(tree::AbstractExpressionNode{T}, new_tree::AbstractExpressionNode{T}) where {T}
329326
330327
Set every field of `tree` equal to the corresponding field of `new_tree`.
331328
"""
332329
function set_node!(tree::AbstractExpressionNode, new_tree::AbstractExpressionNode)
333-
# First, ensure we free some memory:
334-
if new_tree.degree < 2 && tree.degree == 2
335-
tree.r = create_dummy_node(typeof(tree))
336-
end
337-
if new_tree.degree < 1 && tree.degree >= 1
338-
tree.l = create_dummy_node(typeof(tree))
339-
end
340-
341330
tree.degree = new_tree.degree
342331
if new_tree.degree == 0
343332
tree.constant = new_tree.constant

src/NodePreallocation.jl

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
module NodePreallocationModule
2+
3+
using ..NodeModule:
4+
AbstractExpressionNode,
5+
with_type_parameters,
6+
tree_mapreduce,
7+
leaf_copy,
8+
branch_copy,
9+
set_node!
10+
11+
"""
12+
allocate_container(prototype::AbstractExpressionNode, n=nothing)
13+
14+
Preallocate an array of `n` empty nodes matching the type of `prototype`.
15+
If `n` is not provided, it will be computed from `length(prototype)`.
16+
17+
A given return value of this will be passed to `copy_into!` as the first argument,
18+
so it should be compatible.
19+
"""
20+
function allocate_container(
21+
prototype::N, n::Union{Nothing,Integer}=nothing
22+
) where {T,N<:AbstractExpressionNode{T}}
23+
num_nodes = @something(n, length(prototype))
24+
return N[with_type_parameters(N, T)() for _ in 1:num_nodes]
25+
end
26+
27+
"""
28+
copy_into!(dest::AbstractArray{N}, src::N) where {N<:AbstractExpressionNode}
29+
30+
Copy a node, recursively copying all children nodes, in-place to a preallocated container.
31+
This should result in no extra allocations.
32+
"""
33+
function copy_into!(
34+
dest::AbstractArray{N}, src::N; ref::Union{Nothing,Base.RefValue{<:Integer}}=nothing
35+
) where {N<:AbstractExpressionNode}
36+
_ref = if ref === nothing
37+
Ref(0)
38+
else
39+
ref.x = 0
40+
ref
41+
end
42+
return tree_mapreduce(
43+
leaf -> leaf_copy_into!(@inbounds(dest[_ref.x += 1]), leaf),
44+
identity,
45+
((p, c::Vararg{Any,M}) where {M}) ->
46+
branch_copy_into!(@inbounds(dest[_ref.x += 1]), p, c...),
47+
src,
48+
N,
49+
)
50+
end
51+
# COV_EXCL_START
52+
function leaf_copy_into!(dest::N, src::N) where {N<:AbstractExpressionNode}
53+
set_node!(dest, src)
54+
return dest
55+
end
56+
# COV_EXCL_STOP
57+
function branch_copy_into!(
58+
dest::N, src::N, children::Vararg{N,M}
59+
) where {N<:AbstractExpressionNode,M}
60+
dest.degree = M
61+
dest.op = src.op
62+
dest.l = children[1]
63+
if M == 2
64+
dest.r = children[2]
65+
end
66+
return dest
67+
end
68+
69+
end

src/ParametricExpression.jl

Lines changed: 46 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,19 @@ using ChainRulesCore: ChainRulesCore as CRC, NoTangent, @thunk
55

66
using ..OperatorEnumModule: AbstractOperatorEnum, OperatorEnum
77
using ..NodeModule: AbstractExpressionNode, Node, tree_mapreduce
8-
using ..ExpressionModule: AbstractExpression, Metadata
8+
using ..ExpressionModule: AbstractExpression, Metadata, with_contents, with_metadata
99
using ..ChainRulesModule: NodeTangent
1010

1111
import ..NodeModule:
1212
constructorof,
1313
with_type_parameters,
1414
preserve_sharing,
1515
leaf_copy,
16-
leaf_copy!,
1716
leaf_convert,
1817
leaf_hash,
1918
leaf_equal,
20-
branch_copy!
19+
set_node!
20+
import ..NodePreallocationModule: copy_into!, allocate_container
2121
import ..NodeUtilsModule:
2222
count_constant_nodes,
2323
index_constant_nodes,
@@ -124,21 +124,29 @@ function leaf_copy(t::ParametricNode{T}) where {T}
124124
return n
125125
end
126126
end
127-
function leaf_copy!(dest::N, src::N) where {T,N<:ParametricNode{T}}
128-
dest.degree = 0
129-
if src.constant
130-
dest.constant = true
131-
dest.val = src.val
132-
elseif !src.is_parameter
133-
dest.constant = false
134-
dest.is_parameter = false
135-
dest.feature = src.feature
127+
function set_node!(tree::ParametricNode, new_tree::ParametricNode)
128+
tree.degree = new_tree.degree
129+
if new_tree.degree == 0
130+
if new_tree.constant
131+
tree.constant = true
132+
tree.val = new_tree.val
133+
elseif !new_tree.is_parameter
134+
tree.constant = false
135+
tree.is_parameter = false
136+
tree.feature = new_tree.feature
137+
else
138+
tree.constant = false
139+
tree.is_parameter = true
140+
tree.parameter = new_tree.parameter
141+
end
136142
else
137-
dest.constant = false
138-
dest.is_parameter = true
139-
dest.parameter = src.parameter
143+
tree.op = new_tree.op
144+
tree.l = new_tree.l
145+
if new_tree.degree == 2
146+
tree.r = new_tree.r
147+
end
140148
end
141-
return dest
149+
return nothing
142150
end
143151
function leaf_convert(::Type{N}, t::ParametricNode) where {T,N<:ParametricNode{T}}
144152
if t.constant
@@ -444,6 +452,28 @@ end
444452
return node_type(; val=ex)
445453
end
446454
end
455+
function allocate_container(
456+
prototype::ParametricExpression, n::Union{Nothing,Integer}=nothing
457+
)
458+
return (;
459+
tree=allocate_container(get_contents(prototype), n),
460+
parameters=similar(get_metadata(prototype).parameters),
461+
)
462+
end
463+
function copy_into!(dest::NamedTuple, src::ParametricExpression)
464+
new_tree = copy_into!(dest.tree, get_contents(src))
465+
metadata = get_metadata(src)
466+
new_parameters = dest.parameters
467+
new_parameters .= metadata.parameters
468+
new_metadata = Metadata((;
469+
operators=metadata.operators,
470+
variable_names=metadata.variable_names,
471+
parameters=new_parameters,
472+
parameter_names=metadata.parameter_names,
473+
))
474+
# TODO: Better interface for this^
475+
return with_metadata(with_contents(src, new_tree), new_metadata)
476+
end
447477
###############################################################################
448478

449479
end

src/StructuredExpression.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@ using ..ExpressionModule: AbstractExpression, Metadata, node_type
66
using ..ChainRulesModule: NodeTangent
77

88
import ..NodeModule: constructorof
9+
import ..NodePreallocationModule: copy_into!, allocate_container
910
import ..ExpressionModule:
1011
get_contents,
1112
get_metadata,
1213
get_tree,
1314
get_operators,
1415
get_variable_names,
16+
with_contents,
1517
Metadata,
1618
_copy,
1719
_data,
@@ -164,4 +166,16 @@ function set_scalar_constants!(e::AbstractStructuredExpression, constants, refs)
164166
return e
165167
end
166168

169+
function allocate_container(
170+
e::AbstractStructuredExpression, n::Union{Nothing,Integer}=nothing
171+
)
172+
ts = get_contents(e)
173+
return (; trees=NamedTuple{keys(ts)}(map(t -> allocate_container(t, n), values(ts))))
174+
end
175+
function copy_into!(dest::NamedTuple, src::AbstractStructuredExpression)
176+
ts = get_contents(src)
177+
new_contents = NamedTuple{keys(ts)}(map(copy_into!, values(dest.trees), values(ts)))
178+
return with_contents(src, new_contents)
179+
end
180+
167181
end

0 commit comments

Comments
 (0)