@@ -7,6 +7,7 @@ using ChainRulesCore:
7
7
ZeroTangent,
8
8
Tangent,
9
9
@thunk ,
10
+ unthunk,
10
11
canonicalize
11
12
using .. OperatorEnumModule: OperatorEnum
12
13
using .. NodeModule: AbstractExpressionNode, with_type_parameters, tree_mapreduce
@@ -52,7 +53,8 @@ struct EvalPullback{N,A,O} <: Function
52
53
end
53
54
54
55
# TODO : Preferable to use the primal in the pullback somehow
55
- function (e:: EvalPullback )((dY, _))
56
+ function (e:: EvalPullback )((thunked_dY, _))
57
+ dY = unthunk (thunked_dY)
56
58
_, dX_constants_dY, complete = eval_grad_tree_array (
57
59
e. tree, e. X, e. operators; variable= Val (:both )
58
60
)
@@ -66,10 +68,10 @@ function (e::EvalPullback)((dY, _))
66
68
dconstants_dY = @view dX_constants_dY[(nfeatures + 1 ): end , :]
67
69
68
70
dtree = NodeTangent (
69
- e. tree, sum (j -> dconstants_dY[:, j] * dY[j], eachindex (axes (dconstants_dY, 2 )))
71
+ e. tree, sum (j -> dconstants_dY[:, j] * dY[j], eachindex (dY, axes (dconstants_dY, 2 )))
70
72
)
71
73
72
- dX = dX_dY .* reshape (dY, 1 , size (dconstants_dY, 2 ))
74
+ dX = dX_dY .* reshape (dY, 1 , length (dY ))
73
75
74
76
return (NoTangent (), dtree, dX, NoTangent ())
75
77
end
0 commit comments