Skip to content

Commit fb1141f

Browse files
committed
WIP: Fix llvmcall recursion
1 parent b4b88c9 commit fb1141f

File tree

2 files changed

+57
-0
lines changed

2 files changed

+57
-0
lines changed

src/overdub.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,24 @@ function overdub_pass!(reflection::Reflection,
318318
end)
319319
end
320320

321+
#=== mark all `llvmcall`s as nooverdub ===#
322+
# TODO: this only works for: `Intrinsics.llvmcall` and not `Core.Intrinsics.llvmcall`
323+
# since there is a getproperty call in the way.
324+
# TODO: Need to fix for `istaggingenabled == true`
325+
if !iskwfunc && !istaggingenabled
326+
insert_statements!(overdubbed_code, overdubbed_codelocs,
327+
(x, i) -> begin
328+
if Base.Meta.isexpr(x, :call) &&
329+
is_ir_element(x.args[1], GlobalRef(Core.Intrinsics, :llvmcall), overdubbed_code)
330+
return 1
331+
end
332+
return nothing
333+
end,
334+
(x, i) -> begin
335+
[Expr(:call, Expr(:nooverdub, GlobalRef(Core.Intrinsics, :llvmcall)), x.args[2:end]...)]
336+
end)
337+
end
338+
321339
#=== untag all `foreigncall` SSAValue/SlotNumber arguments if tagging is enabled ===#
322340

323341
if istaggingenabled && !iskwfunc

test/misctests.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -633,8 +633,12 @@ callback()
633633

634634
println("done (took ", time() - before_time, " seconds)")
635635

636+
#############################################################################################
636637
# Test overdubbing of a call overload invoke
637638

639+
print(" running CtxCallOverload test...")
640+
before_time = time()
641+
638642
using LinearAlgebra
639643

640644
struct Dense{F,S,T}
@@ -663,3 +667,38 @@ let d = Dense(3,3)
663667
data = rand(3)
664668
Cassette.overdub(CtxCallOverload(), d, data)
665669
end
670+
671+
println("done (took ", time() - before_time, " seconds)")
672+
673+
#############################################################################################
674+
675+
print(" running LLVMCallCtx test...")
676+
before_time = time()
677+
using Cassette
678+
Cassette.@context LLVMCallCtx
679+
680+
# This overdub does nothing
681+
@noinline function Cassette.overdub(ctx::LLVMCallCtx, f, args...)
682+
if Cassette.canrecurse(ctx, f, args...)
683+
Cassette.recurse(ctx, f, args...)
684+
else
685+
Cassette.fallback(ctx, f, args...)
686+
end
687+
end
688+
689+
import Core.Intrinsics
690+
function llvm_sin(x::Float64)
691+
# Needs fix for Core.Intrinsics.llvmcall
692+
Intrinsics.llvmcall(
693+
(
694+
"""declare double @llvm.sin.f64(double)""",
695+
"""%2 = call double @llvm.sin.f64(double %0)
696+
ret double %2"""
697+
),
698+
Float64, Tuple{Float64}, x
699+
)
700+
end
701+
702+
Cassette.@overdub LLVMCallCtx() llvm_sin(4.0)
703+
704+
println("done (took ", time() - before_time, " seconds)")

0 commit comments

Comments
 (0)