Skip to content

Commit 4b22d2c

Browse files
committed
fix: unthunk according to JuliaDiff/ChainRulesCore.jl#687
1 parent d2d7d07 commit 4b22d2c

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

src/ChainRules.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ using ChainRulesCore:
77
ZeroTangent,
88
Tangent,
99
@thunk,
10+
unthunk,
1011
canonicalize
1112
using ..OperatorEnumModule: OperatorEnum
1213
using ..NodeModule: AbstractExpressionNode, with_type_parameters, tree_mapreduce
@@ -52,7 +53,8 @@ struct EvalPullback{N,A,O} <: Function
5253
end
5354

5455
# TODO: Preferable to use the primal in the pullback somehow
55-
function (e::EvalPullback)((dY, _))
56+
function (e::EvalPullback)((thunked_dY, _))
57+
dY = unthunk(thunked_dY)
5658
_, dX_constants_dY, complete = eval_grad_tree_array(
5759
e.tree, e.X, e.operators; variable=Val(:both)
5860
)
@@ -66,10 +68,10 @@ function (e::EvalPullback)((dY, _))
6668
dconstants_dY = @view dX_constants_dY[(nfeatures + 1):end, :]
6769

6870
dtree = NodeTangent(
69-
e.tree, sum(j -> dconstants_dY[:, j] * dY[j], eachindex(axes(dconstants_dY, 2)))
71+
e.tree, sum(j -> dconstants_dY[:, j] * dY[j], eachindex(dY, axes(dconstants_dY, 2)))
7072
)
7173

72-
dX = dX_dY .* reshape(dY, 1, size(dconstants_dY, 2))
74+
dX = dX_dY .* reshape(dY, 1, length(dY))
7375

7476
return (NoTangent(), dtree, dX, NoTangent())
7577
end

0 commit comments

Comments
 (0)