Skip to content

Commit cc313ad

Browse files
authored
Merge pull request #7 from SymbolicML/equality-operator
Add equality operator
2 parents 6cdd2a5 + d4b951b commit cc313ad

File tree

5 files changed

+86
-1
lines changed

5 files changed

+86
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DynamicExpressions"
22
uuid = "a40a106e-89c9-4ca8-8020-a735e8728b6b"
33
authors = ["MilesCranmer <miles.cranmer@gmail.com>"]
4-
version = "0.2.3"
4+
version = "0.3.0"
55

66
[deps]
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

src/DynamicExpressions.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,4 +31,11 @@ using Reexport
3131
@reexport import .InterfaceSymbolicUtilsModule: node_to_symbolic, symbolic_to_node
3232
@reexport import .SimplifyEquationModule: combine_operators, simplify_tree
3333

34+
import TOML: parsefile
35+
36+
const PACKAGE_VERSION = let
37+
project = parsefile(joinpath(pkgdir(@__MODULE__), "Project.toml"))
38+
VersionNumber(project["version"])
39+
end
40+
3441
end

src/Equation.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -396,4 +396,34 @@ function Base.hash(tree::Node{T})::UInt where {T}
396396
end
397397
end
398398

399+
function is_equal(a::Node{T}, b::Node{T})::Bool where {T}
400+
if a.degree == 0
401+
b.degree != 0 && return false
402+
if a.constant
403+
!(b.constant) && return false
404+
return a.val::T == b.val::T
405+
else
406+
b.constant && return false
407+
return a.feature == b.feature
408+
end
409+
elseif a.degree == 1
410+
b.degree != 1 && return false
411+
a.op != b.op && return false
412+
return is_equal(a.l, b.l)
413+
else
414+
b.degree != 2 && return false
415+
a.op != b.op && return false
416+
return is_equal(a.l, b.l) && is_equal(a.r, b.r)
417+
end
418+
end
419+
420+
function Base.:(==)(a::Node{T}, b::Node{T})::Bool where {T}
421+
return is_equal(a, b)
422+
end
423+
424+
function Base.:(==)(a::Node{T1}, b::Node{T2})::Bool where {T1,T2}
425+
T = promote_type(T1, T2)
426+
return is_equal(convert(Node{T}, a), convert(Node{T}, b))
427+
end
428+
399429
end

test/test_equality.jl

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
using DynamicExpressions
2+
using Test
3+
4+
operators = OperatorEnum(;
5+
binary_operators=[+, *, -, /], unary_operators=[sin, cos, exp, log]
6+
)
7+
8+
# Create a big expression, using those operators:
9+
x1 = Node(; feature=1)
10+
x2 = Node(; feature=2)
11+
x3 = Node(; feature=3)
12+
13+
tree = x1 + x2 * x3 - log(x2 * 3.2) + 1.5 * cos(x2 / x1)
14+
same_tree = x1 + x2 * x3 - log(x2 * 3.2) + 1.5 * cos(x2 / x1)
15+
@test tree == same_tree
16+
17+
copied_tree = copy_node(tree; preserve_topology=true)
18+
@test tree == copied_tree
19+
20+
copied_tree2 = copy_node(tree; preserve_topology=false)
21+
@test tree == copied_tree2
22+
23+
modifed_tree = x1 + x2 * x1 - log(x2 * 3.2) + 1.5 * cos(x2 / x1)
24+
@test tree != modifed_tree
25+
modifed_tree2 = x1 + x2 * x3 - log(x2 * 3.1) + 1.5 * cos(x2 / x1)
26+
@test tree != modifed_tree2
27+
modifed_tree3 = x1 + x2 * x3 - exp(x2 * 3.2) + 1.5 * cos(x2 / x1)
28+
@test tree != modifed_tree3
29+
modified_tree4 = x1 + x2 * x3 - log(x2 * 3.2) + 1.5 * cos(x2 * x1)
30+
@test tree != modified_tree4
31+
32+
# Order matters!
33+
modified_tree5 = 1.5 * cos(x2 * x1) + x1 + x2 * x3 - log(x2 * 3.2)
34+
@test tree != modified_tree5
35+
36+
# Type should not matter if equivalent in the promoted type:
37+
f64_tree = x1 + x2 * x3 - log(x2 * 3.0) + 1.5 * cos(x2 / x1)
38+
f32_tree = x1 + x2 * x3 - log(x2 * 3.0f0) + 1.5f0 * cos(x2 / x1)
39+
@test typeof(f64_tree) == Node{Float64}
40+
@test typeof(f32_tree) == Node{Float32}
41+
42+
@test convert(Node{Float64}, f32_tree) == f64_tree
43+
44+
@test f64_tree == f32_tree

test/unittest.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,3 +55,7 @@ end
5555
@safetestset "Test error handling" begin
5656
include("test_error_handling.jl")
5757
end
58+
59+
@safetestset "Test equality operator" begin
60+
include("test_equality.jl")
61+
end

0 commit comments

Comments
 (0)