Skip to content

Commit cd1fc0c

Browse files
authored
Merge pull request #50 from SymbolicML/lighter-types
Switch to UInt8/UInt16 for Node fields
2 parents 5836ba2 + e79953b commit cd1fc0c

11 files changed

+59
-63
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DynamicExpressions"
22
uuid = "a40a106e-89c9-4ca8-8020-a735e8728b6b"
33
authors = ["MilesCranmer <miles.cranmer@gmail.com>"]
4-
version = "0.12.3"
4+
version = "0.13.0"
55

66
[deps]
77
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"

benchmark/benchmark_utils.jl

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,11 @@ function random_node(tree::Node{T})::Node{T} where {T}
88
if tree.degree == 0
99
return tree
1010
end
11-
b = 0
12-
c = 0
13-
if tree.degree >= 1
14-
b = count_nodes(tree.l)
15-
end
16-
if tree.degree == 2
17-
c = count_nodes(tree.r)
11+
b = count_nodes(tree.l)
12+
c = if tree.degree == 2
13+
count_nodes(tree.r)
14+
else
15+
0
1816
end
1917

2018
i = rand(1:(1 + b + c))
@@ -27,7 +25,7 @@ function random_node(tree::Node{T})::Node{T} where {T}
2725
return random_node(tree.r)
2826
end
2927

30-
function make_random_leaf(nfeatures::Int, ::Type{T})::Node{T} where {T}
28+
function make_random_leaf(nfeatures::Integer, ::Type{T})::Node{T} where {T}
3129
if rand() > 0.5
3230
return Node(; val=randn(T))
3331
else
@@ -37,7 +35,7 @@ end
3735

3836
# Add a random unary/binary operation to the end of a tree
3937
function append_random_op(
40-
tree::Node{T}, operators, nfeatures::Int; makeNewBinOp::Union{Bool,Nothing}=nothing
38+
tree::Node{T}, operators, nfeatures::Integer; makeNewBinOp::Union{Bool,Nothing}=nothing
4139
)::Node{T} where {T}
4240
nuna = length(operators.unaops)
4341
nbin = length(operators.binops)
@@ -66,7 +64,7 @@ function append_random_op(
6664
end
6765

6866
function gen_random_tree_fixed_size(
69-
node_count::Int, operators, nfeatures::Int, ::Type{T}
67+
node_count::Integer, operators, nfeatures::Integer, ::Type{T}
7068
)::Node{T} where {T}
7169
tree = make_random_leaf(nfeatures, T)
7270
cur_size = count_nodes(tree)

docs/src/eval.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ all variables (or, all constants). Both use forward-mode automatic, but use
9090
`Zygote.jl` to compute derivatives of each operator, so this is very efficient.
9191

9292
```@docs
93-
eval_diff_tree_array(tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum, direction::Int) where {T<:Number}
93+
eval_diff_tree_array(tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum, direction::Integer) where {T<:Number}
9494
eval_grad_tree_array(tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum; turbo::Bool=false, variable::Bool=false) where {T<:Number}
9595
```
9696

docs/src/types.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,8 @@ There are a variety of constructors for `Node` objects, including:
5555

5656
```@docs
5757
Node(::Type{T}; val=nothing, feature::Integer=nothing) where {T}
58-
Node(op::Int, l::Node)
59-
Node(op::Int, l::Node, r::Node)
58+
Node(op::Integer, l::Node)
59+
Node(op::Integer, l::Node, r::Node)
6060
Node(var_string::String)
6161
```
6262

ext/DynamicExpressionsSymbolicUtilsExt.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@ function parse_tree_to_eqs(
2727
return SymbolicUtils.Sym{LiteralReal}(Symbol("x$(tree.feature)"))
2828
end
2929
# Collect the next children
30-
children = tree.degree >= 2 ? (tree.l, tree.r) : (tree.l,)
30+
children = tree.degree == 2 ? (tree.l, tree.r) : (tree.l,)
3131
# Get the operation
32-
op = tree.degree > 1 ? operators.binops[tree.op] : operators.unaops[tree.op]
32+
op = tree.degree == 2 ? operators.binops[tree.op] : operators.unaops[tree.op]
3333
# Create an N tuple of Numbers for each argument
3434
dtypes = map(x -> Number, 1:(tree.degree))
3535
#
@@ -228,7 +228,7 @@ function multiply_powers(
228228
@return_on_false complete eqn
229229
@return_on_false isgood(l) eqn
230230
n = args[2]
231-
if typeof(n) <: Int
231+
if typeof(n) <: Integer
232232
if n == 1
233233
return l, true
234234
elseif n == -1

src/Equation.jl

Lines changed: 21 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import ..UtilsModule: @memoize_on, @with_memoize, deprecate_varmap
55

66
const DEFAULT_NODE_TYPE = Float32
77

8+
#! format: off
89
"""
910
Node{T}
1011
@@ -15,16 +16,16 @@ nodes, you can evaluate or print a given expression.
1516
1617
# Fields
1718
18-
- `degree::Int`: Degree of the node. 0 for constants, 1 for
19+
- `degree::UInt8`: Degree of the node. 0 for constants, 1 for
1920
unary operators, 2 for binary operators.
2021
- `constant::Bool`: Whether the node is a constant.
2122
- `val::T`: Value of the node. If `degree==0`, and `constant==true`,
2223
this is the value of the constant. It has a type specified by the
2324
overall type of the `Node` (e.g., `Float64`).
24-
- `feature::Int` (optional): Index of the feature to use in the
25+
- `feature::UInt16`: Index of the feature to use in the
2526
case of a feature node. Only used if `degree==0` and `constant==false`.
2627
Only defined if `degree == 0 && constant == false`.
27-
- `op::Int`: If `degree==1`, this is the index of the operator
28+
- `op::UInt8`: If `degree==1`, this is the index of the operator
2829
in `operators.unaops`. If `degree==2`, this is the index of the
2930
operator in `operators.binops`. In other words, this is an enum
3031
of the operators, and is dependent on the specific `OperatorEnum`
@@ -36,36 +37,32 @@ nodes, you can evaluate or print a given expression.
3637
argument to the binary operator.
3738
"""
3839
mutable struct Node{T}
39-
degree::Int # 0 for constant/variable, 1 for cos/sin, 2 for +/* etc.
40+
degree::UInt8 # 0 for constant/variable, 1 for cos/sin, 2 for +/* etc.
4041
constant::Bool # false if variable
4142
val::Union{T,Nothing} # If is a constant, this stores the actual value
4243
# ------------------- (possibly undefined below)
43-
feature::Int # If is a variable (e.g., x in cos(x)), this stores the feature index.
44-
op::Int # If operator, this is the index of the operator in operators.binary_operators, or operators.unary_operators
44+
feature::UInt16 # If is a variable (e.g., x in cos(x)), this stores the feature index.
45+
op::UInt8 # If operator, this is the index of the operator in operators.binops, or operators.unaops
4546
l::Node{T} # Left child node. Only defined for degree=1 or degree=2.
4647
r::Node{T} # Right child node. Only defined for degree=2.
4748

4849
#################
4950
## Constructors:
5051
#################
51-
Node(d::Int, c::Bool, v::_T) where {_T} = new{_T}(d, c, v)
52-
Node(::Type{_T}, d::Int, c::Bool, v::_T) where {_T} = new{_T}(d, c, v)
53-
Node(::Type{_T}, d::Int, c::Bool, v::Nothing, f::Int) where {_T} = new{_T}(d, c, v, f)
54-
function Node(d::Int, c::Bool, v::Nothing, f::Int, o::Int, l::Node{_T}) where {_T}
55-
return new{_T}(d, c, v, f, o, l)
56-
end
57-
function Node(
58-
d::Int, c::Bool, v::Nothing, f::Int, o::Int, l::Node{_T}, r::Node{_T}
59-
) where {_T}
60-
return new{_T}(d, c, v, f, o, l, r)
61-
end
52+
Node(d::Integer, c::Bool, v::_T) where {_T} = new{_T}(UInt8(d), c, v)
53+
Node(::Type{_T}, d::Integer, c::Bool, v::_T) where {_T} = new{_T}(UInt8(d), c, v)
54+
Node(::Type{_T}, d::Integer, c::Bool, v::Nothing, f::Integer) where {_T} = new{_T}(UInt8(d), c, v, UInt16(f))
55+
Node(d::Integer, c::Bool, v::Nothing, f::Integer, o::Integer, l::Node{_T}) where {_T} = new{_T}(UInt8(d), c, v, UInt16(f), UInt8(o), l)
56+
Node(d::Integer, c::Bool, v::Nothing, f::Integer, o::Integer, l::Node{_T}, r::Node{_T}) where {_T} = new{_T}(UInt8(d), c, v, UInt16(f), UInt8(o), l, r)
57+
6258
end
6359
################################################################################
60+
#! format: on
6461

6562
include("base.jl")
6663

6764
"""
68-
Node([::Type{T}]; val=nothing, feature::Int=nothing) where {T}
65+
Node([::Type{T}]; val=nothing, feature::Union{Integer,Nothing}=nothing) where {T}
6966
7067
Create a leaf node: either a constant, or a variable.
7168
@@ -115,18 +112,18 @@ function Node(
115112
end
116113

117114
"""
118-
Node(op::Int, l::Node)
115+
Node(op::Integer, l::Node)
119116
120117
Apply unary operator `op` (enumerating over the order given) to `Node` `l`
121118
"""
122-
Node(op::Int, l::Node{T}) where {T} = Node(1, false, nothing, 0, op, l)
119+
Node(op::Integer, l::Node{T}) where {T} = Node(1, false, nothing, 0, op, l)
123120

124121
"""
125-
Node(op::Int, l::Node, r::Node)
122+
Node(op::Integer, l::Node, r::Node)
126123
127124
Apply binary operator `op` (enumerating over the order given) to `Node`s `l` and `r`
128125
"""
129-
function Node(op::Int, l::Node{T1}, r::Node{T2}) where {T1,T2}
126+
function Node(op::Integer, l::Node{T1}, r::Node{T2}) where {T1,T2}
130127
# Get highest type:
131128
if T1 != T2
132129
T = promote_type(T1, T2)
@@ -141,7 +138,7 @@ end
141138
142139
Create a variable node, using the format `"x1"` to mean feature 1
143140
"""
144-
Node(var_string::String) = Node(; feature=parse(Int, var_string[2:end]))
141+
Node(var_string::String) = Node(; feature=parse(UInt16, var_string[2:end]))
145142

146143
"""
147144
Node(var_string::String, variable_names::Array{String, 1})
@@ -261,7 +258,7 @@ Convert an equation to a string.
261258
262259
# Keyword Arguments
263260
- `bracketed`: (optional) whether to put brackets around the outside.
264-
- `f_variable`: (optional) function to convert a variable to a string, of the form `(feature::Int, variable_names)`.
261+
- `f_variable`: (optional) function to convert a variable to a string, of the form `(feature::UInt8, variable_names)`.
265262
- `f_constant`: (optional) function to convert a constant to a string, of the form `(val, bracketed::Bool)`
266263
- `variable_names::Union{Array{String, 1}, Nothing}=nothing`: (optional) what variables to print for each feature.
267264
"""

src/EquationUtils.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,15 +46,15 @@ has_constants(tree::Node) = any(is_node_constant, tree)
4646
4747
Check if a tree has any operators.
4848
"""
49-
has_operators(tree::Node) = tree.degree !== 0
49+
has_operators(tree::Node) = tree.degree != 0
5050

5151
"""
5252
is_constant(tree::Node)::Bool
5353
5454
Check if an expression is a constant numerical value, or
5555
whether it depends on input features.
5656
"""
57-
is_constant(tree::Node) = all(t -> t.degree !== 0 || t.constant, tree)
57+
is_constant(tree::Node) = all(t -> t.degree != 0 || t.constant, tree)
5858

5959
"""
6060
get_constants(tree::Node{T})::Vector{T} where {T}
@@ -92,25 +92,25 @@ end
9292
# This will mirror a Node struct, rather
9393
# than adding a new attribute to Node.
9494
mutable struct NodeIndex
95-
constant_index::Int # Index of this constant (if a constant exists here)
95+
constant_index::UInt16 # Index of this constant (if a constant exists here)
9696
l::NodeIndex
9797
r::NodeIndex
9898

9999
NodeIndex() = new()
100100
end
101101

102102
function index_constants(tree::Node)::NodeIndex
103-
return index_constants(tree, 0)
103+
return index_constants(tree, UInt16(0))
104104
end
105105

106-
function index_constants(tree::Node, left_index::Int)::NodeIndex
106+
function index_constants(tree::Node, left_index)::NodeIndex
107107
index_tree = NodeIndex()
108108
index_constants!(tree, index_tree, left_index)
109109
return index_tree
110110
end
111111

112112
# Count how many constants to the left of this node, and put them in a tree
113-
function index_constants!(tree::Node, index_tree::NodeIndex, left_index::Int)
113+
function index_constants!(tree::Node, index_tree::NodeIndex, left_index)
114114
if tree.degree == 0
115115
if tree.constant
116116
index_tree.constant_index = left_index + 1

src/EvaluateEquationDerivative.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ function assert_autodiff_enabled(operators::OperatorEnum)
1818
end
1919

2020
"""
21-
eval_diff_tree_array(tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum, direction::Int; turbo::Bool=false)
21+
eval_diff_tree_array(tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum, direction::Integer; turbo::Bool=false)
2222
2323
Compute the forward derivative of an expression, using a similar
2424
structure and optimization to eval_tree_array. `direction` is the index of a particular
@@ -31,7 +31,7 @@ respect to `x1`.
3131
- `cX::AbstractMatrix{T}`: The data matrix, with each column being a data point.
3232
- `operators::OperatorEnum`: The operators used to create the `tree`. Note that `operators.enable_autodiff`
3333
must be `true`. This is needed to create the derivative operations.
34-
- `direction::Int`: The index of the variable to take the derivative with respect to.
34+
- `direction::Integer`: The index of the variable to take the derivative with respect to.
3535
- `turbo::Bool`: Use `LoopVectorization.@turbo` for faster evaluation.
3636
3737
# Returns
@@ -43,7 +43,7 @@ function eval_diff_tree_array(
4343
tree::Node{T},
4444
cX::AbstractMatrix{T},
4545
operators::OperatorEnum,
46-
direction::Int;
46+
direction::Integer;
4747
turbo::Bool=false,
4848
)::Tuple{AbstractVector{T},AbstractVector{T},Bool} where {T<:Number}
4949
assert_autodiff_enabled(operators)
@@ -57,7 +57,7 @@ function eval_diff_tree_array(
5757
tree::Node{T1},
5858
cX::AbstractMatrix{T2},
5959
operators::OperatorEnum,
60-
direction::Int;
60+
direction::Integer;
6161
turbo::Bool=false,
6262
) where {T1<:Number,T2<:Number}
6363
T = promote_type(T1, T2)
@@ -71,7 +71,7 @@ function _eval_diff_tree_array(
7171
tree::Node{T},
7272
cX::AbstractMatrix{T},
7373
operators::OperatorEnum,
74-
direction::Int,
74+
direction::Integer,
7575
::Val{turbo},
7676
)::Tuple{AbstractVector{T},AbstractVector{T},Bool} where {T<:Number,turbo}
7777
evaluation, derivative, complete = if tree.degree == 0
@@ -102,7 +102,7 @@ function _eval_diff_tree_array(
102102
end
103103

104104
function diff_deg0_eval(
105-
tree::Node{T}, cX::AbstractMatrix{T}, direction::Int
105+
tree::Node{T}, cX::AbstractMatrix{T}, direction::Integer
106106
)::Tuple{AbstractVector{T},AbstractVector{T},Bool} where {T<:Number}
107107
const_part = deg0_eval(tree, cX)[1]
108108
derivative_part = if ((!tree.constant) && tree.feature == direction)
@@ -119,7 +119,7 @@ function diff_deg1_eval(
119119
op::F,
120120
diff_op::dF,
121121
operators::OperatorEnum,
122-
direction::Int,
122+
direction::Integer,
123123
::Val{turbo},
124124
)::Tuple{AbstractVector{T},AbstractVector{T},Bool} where {T<:Number,F,dF,turbo}
125125
(cumulator, dcumulator, complete) = _eval_diff_tree_array(
@@ -144,7 +144,7 @@ function diff_deg2_eval(
144144
op::F,
145145
diff_op::dF,
146146
operators::OperatorEnum,
147-
direction::Int,
147+
direction::Integer,
148148
::Val{turbo},
149149
)::Tuple{AbstractVector{T},AbstractVector{T},Bool} where {T<:Number,F,dF,turbo}
150150
(cumulator, dcumulator, complete) = _eval_diff_tree_array(
@@ -200,7 +200,7 @@ function eval_grad_tree_array(
200200
)::Tuple{AbstractVector{T},AbstractMatrix{T},Bool} where {T<:Number}
201201
assert_autodiff_enabled(operators)
202202
n_gradients = variable ? size(cX, 1) : count_constants(tree)
203-
index_tree = index_constants(tree, 0)
203+
index_tree = index_constants(tree, UInt16(0))
204204
return eval_grad_tree_array(
205205
tree,
206206
Val(n_gradients),

src/OperatorEnumConstruction.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ import ..EvaluateEquationDerivativeModule: eval_grad_tree_array, _zygote_gradien
77
import ..EvaluationHelpersModule: _grad_evaluator
88

99
"""Used to set a default value for `operators` for ease of use."""
10-
@enum AvailableOperatorTypes begin
10+
@enum AvailableOperatorTypes::UInt8 begin
1111
IsNothing
1212
IsOperatorEnum
1313
IsGenericOperatorEnum
@@ -19,8 +19,8 @@ end
1919

2020
const LATEST_OPERATORS = Ref{Union{Nothing,AbstractOperatorEnum}}(nothing)
2121
const LATEST_OPERATORS_TYPE = Ref{AvailableOperatorTypes}(IsNothing)
22-
const LATEST_UNARY_OPERATOR_MAPPING = Dict{Function,Int}()
23-
const LATEST_BINARY_OPERATOR_MAPPING = Dict{Function,Int}()
22+
const LATEST_UNARY_OPERATOR_MAPPING = Dict{Function,fieldtype(Node{Float64}, :op)}()
23+
const LATEST_BINARY_OPERATOR_MAPPING = Dict{Function,fieldtype(Node{Float64}, :op)}()
2424
const ALREADY_DEFINED_UNARY_OPERATORS = (;
2525
operator_enum=Dict{Function,Bool}(), generic_operator_enum=Dict{Function,Bool}()
2626
)

src/base.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,9 @@ julia> tree_mapreduce(t -> 1, (p, c...) -> p + max(c...), tree) # compute depth
5252
5
5353
5454
julia> tree_mapreduce(vcat, tree) do t
55-
t.degree == 2 ? [t.op] : Int[]
55+
t.degree == 2 ? [t.op] : UInt8[]
5656
end # Get list of binary operators used. (regular mapreduce also works)
57-
2-element Vector{Int64}:
57+
2-element Vector{UInt8}:
5858
1
5959
2
6060

test/test_base.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,8 +114,9 @@ end
114114
@test length(unique(map(objectid, copy_node(tree; preserve_sharing=true)))) == 24 - 3
115115
map(t -> (t.degree == 0 && t.constant) ? (t.val *= 2) : nothing, ctree)
116116
@test sum(t -> t.val, filter(t -> t.degree == 0 && t.constant, ctree)) == 11.6 * 2
117-
@test typeof(map(t -> t.degree, ctree, Int)) == Vector{Int}
118-
@test first(map(t -> t.degree, ctree, Int)) == 2
117+
local T = fieldtype(typeof(ctree), :degree)
118+
@test typeof(map(t -> t.degree, ctree, T)) == Vector{T}
119+
@test first(map(t -> t.degree, ctree, T)) == 2
119120
end
120121

121122
@testset "in" begin

0 commit comments

Comments
 (0)