Skip to content

Commit 9e95f05

Browse files
authored
Merge pull request #84 from SymbolicML/parametric-expressions2
Fix additional ambiguous methods for Expression interface
2 parents 905ebc0 + 135b9db commit 9e95f05

File tree

5 files changed

+12
-8
lines changed

5 files changed

+12
-8
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DynamicExpressions"
22
uuid = "a40a106e-89c9-4ca8-8020-a735e8728b6b"
33
authors = ["MilesCranmer <miles.cranmer@gmail.com>"]
4-
version = "0.18.1"
4+
version = "0.18.2"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/Interfaces.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ using ..NodeModule:
2626
filter_map!
2727
using ..NodeUtilsModule:
2828
NodeIndex,
29+
is_node_constant,
2930
count_constants,
3031
count_depth,
3132
index_constants,
@@ -273,6 +274,9 @@ end
273274
function _check_count_depth(tree::AbstractExpressionNode)
274275
return count_depth(tree) isa Int64
275276
end
277+
function _check_is_node_constant(tree::AbstractExpressionNode)
278+
return is_node_constant(tree) isa Bool
279+
end
276280
function _check_count_constants(tree::AbstractExpressionNode)
277281
return count_constants(tree) isa Int64
278282
end
@@ -324,6 +328,7 @@ ni_components = (
324328
branch_hash = "computes the hash of a branch node" => _check_branch_hash,
325329
branch_equal = "checks equality of two branch nodes" => _check_branch_equal,
326330
count_depth = "calculates the depth of the tree" => _check_count_depth,
331+
is_node_constant = "checks if the node is a constant" => _check_is_node_constant,
327332
count_constants = "counts the number of constants" => _check_count_constants,
328333
filter_map = "applies a filter and map function to the tree" => _check_filter_map,
329334
has_constants = "checks if the tree has constants" => _check_has_constants,

src/ParametricExpression.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ import ..NodeUtilsModule:
1515
get_constants,
1616
set_constants!
1717
import ..StringsModule: string_tree
18-
import ..SimplifyModule: combine_operators, simplify_tree!
1918
import ..EvaluateModule: eval_tree_array
2019
import ..EvaluateDerivativeModule: eval_grad_tree_array
2120
import ..EvaluationHelpersModule: _grad_evaluator

src/PatchMethods.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
module PatchMethodsModule
22

3+
using DynamicExpressions: get_contents, with_contents
34
using ..OperatorEnumModule: AbstractOperatorEnum
45
using ..NodeModule: constructorof
56
using ..ExpressionModule: Expression, get_tree, get_operators
@@ -11,17 +12,15 @@ function combine_operators(
1112
ex::Union{Expression{T,N},ParametricExpression{T,N}},
1213
operators::Union{AbstractOperatorEnum,Nothing}=nothing,
1314
) where {T,N}
14-
return constructorof(typeof(ex))(
15-
combine_operators(get_tree(ex)::N, get_operators(ex, operators)), ex.metadata
15+
return with_contents(
16+
ex, combine_operators(get_contents(ex), get_operators(ex, operators))
1617
)
1718
end
1819
function simplify_tree!(
1920
ex::Union{Expression{T,N},ParametricExpression{T,N}},
2021
operators::Union{AbstractOperatorEnum,Nothing}=nothing,
2122
) where {T,N}
22-
return constructorof(typeof(ex))(
23-
simplify_tree!(get_tree(ex)::N, get_operators(ex, operators)), ex.metadata
24-
)
23+
return with_contents(ex, simplify_tree!(get_contents(ex), get_operators(ex, operators)))
2524
end
2625

2726
end

src/Simplify.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@ is_commutative(_) = false
1515
is_subtraction(::typeof(-)) = true
1616
is_subtraction(_) = false
1717

18-
# This is only defined for `Node` as it is not possible for
18+
combine_operators(tree::AbstractExpressionNode, ::AbstractOperatorEnum) = tree
19+
# This is only defined for `Node` as it is not possible for, e.g.,
1920
# `GraphNode`.
2021
function combine_operators(tree::Node{T}, operators::AbstractOperatorEnum) where {T}
2122
# NOTE: (const (+*-) const) already accounted for. Call simplify_tree! before.

0 commit comments

Comments
 (0)