Skip to content

Commit 5678e3d

Browse files
committed
cleanup implementation
1 parent fb1141f commit 5678e3d

File tree

3 files changed

+70
-9
lines changed

3 files changed

+70
-9
lines changed

src/overdub.jl

Lines changed: 45 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -319,20 +319,59 @@ function overdub_pass!(reflection::Reflection,
319319
end
320320

321321
#=== mark all `llvmcall`s as nooverdub ===#
322-
# TODO: this only works for: `Intrinsics.llvmcall` and not `Core.Intrinsics.llvmcall`
323-
# since there is a getproperty call in the way.
322+
function unravel_intrinsics(x)
323+
stmt = Base.Meta.isexpr(x, :(=)) ? x.args[2] : x
324+
if Base.Meta.isexpr(stmt, :call)
325+
applycall = is_ir_element(stmt.args[1], GlobalRef(Core, :_apply), overdubbed_code)
326+
f = applycall ? stmt.args[2] : stmt.args[1]
327+
f = ir_element(f, overdubbed_code)
328+
if f isa Expr && Base.Meta.isexpr(f, :call) &&
329+
is_ir_element(f.args[1], GlobalRef(Base, :getproperty), overdubbed_code)
330+
331+
# resolve getproperty here
332+
mod = ir_element(f.args[2], overdubbed_code)
333+
if mod isa GlobalRef
334+
mod = resolve_early(mod) # returns nothing if fails
335+
mod === nothing && return nothing
336+
end
337+
fname = ir_element(f.args[3], overdubbed_code)
338+
if fname isa QuoteNode
339+
fname = fname.value
340+
end
341+
f = GlobalRef(mod, fname)
342+
end
343+
if f isa GlobalRef
344+
f = resolve_early(f)
345+
end
346+
return f
347+
end
348+
return nothing
349+
end
350+
324351
# TODO: Need to fix for `istaggingenabled == true`
352+
# TODO: add user-facing flag to do this for all intrinsics
325353
if !iskwfunc && !istaggingenabled
326354
insert_statements!(overdubbed_code, overdubbed_codelocs,
327355
(x, i) -> begin
328-
if Base.Meta.isexpr(x, :call) &&
329-
is_ir_element(x.args[1], GlobalRef(Core.Intrinsics, :llvmcall), overdubbed_code)
356+
intrinsic = unravel_intrinsics(x)
357+
if intrinsic === nothing
358+
return nothing
359+
end
360+
if intrinsic === Core.Intrinsics.llvmcall
330361
return 1
331362
end
332-
return nothing
333363
end,
334364
(x, i) -> begin
335-
[Expr(:call, Expr(:nooverdub, GlobalRef(Core.Intrinsics, :llvmcall)), x.args[2:end]...)]
365+
stmt = Base.Meta.isexpr(x, :(=)) ? x.args[2] : x
366+
applycall = is_ir_element(stmt.args[1], GlobalRef(Core, :_apply), overdubbed_code)
367+
intrinsic = unravel_intrinsics(x)
368+
if applycall
369+
# using stmt.args[2] instead of `intrinsic` leads to a bug
370+
stmt.args[2] = Expr(:nooverdub, intrinsic)
371+
else
372+
stmt.args[1] = Expr(:nooverdub, intrinsic)
373+
end
374+
[x]
336375
end)
337376
end
338377

src/pass.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,3 +197,27 @@ function is_ir_element(x, y, code::Vector)
197197
end
198198
return result
199199
end
200+
201+
"""
202+
ir_element(x, code::Vector)
203+
204+
Follows the series of `SSAValue` that define `x`.
205+
206+
See also: [`is_ir_element`](@ref)
207+
"""
208+
function ir_element(x, code::Vector)
209+
while isa(x, Core.SSAValue)
210+
x = code[x.id]
211+
end
212+
return x
213+
end
214+
215+
function resolve_early(ref::GlobalRef)
216+
mod = ref.mod
217+
name = ref.name
218+
if Base.isbindingresolved(mod, name) && Base.isdefined(mod, name)
219+
return getfield(mod, name)
220+
else
221+
return nothing
222+
end
223+
end

test/misctests.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -686,10 +686,8 @@ Cassette.@context LLVMCallCtx
686686
end
687687
end
688688

689-
import Core.Intrinsics
690689
function llvm_sin(x::Float64)
691-
# Needs fix for Core.Intrinsics.llvmcall
692-
Intrinsics.llvmcall(
690+
Core.Intrinsics.llvmcall(
693691
(
694692
"""declare double @llvm.sin.f64(double)""",
695693
"""%2 = call double @llvm.sin.f64(double %0)

0 commit comments

Comments
 (0)