Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RFC: Introduce TypedCallable #55111

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
198 changes: 196 additions & 2 deletions base/opaque_closure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ the argument type may be fixed length even if the function is variadic.
This interface is experimental and subject to change or removal without notice.
"""
macro opaque(ex)
esc(Expr(:opaque_closure, nothing, nothing, nothing, ex))
esc(Expr(:opaque_closure, nothing, nothing, nothing, #= allow_partial =# true, ex))
end

macro opaque(ty, ex)
Expand All @@ -34,7 +34,7 @@ macro opaque(ty, ex)
end
AT = (AT !== :_) ? AT : nothing
RT = (RT !== :_) ? RT : nothing
return esc(Expr(:opaque_closure, AT, RT, RT, ex))
return esc(Expr(:opaque_closure, AT, RT, RT, #= allow_partial =# true, ex))
end

# OpaqueClosure construction from pre-inferred CodeInfo/IRCode
Expand Down Expand Up @@ -110,3 +110,197 @@ function generate_opaque_closure(@nospecialize(sig), @nospecialize(rt_lb), @nosp
return ccall(:jl_new_opaque_closure_from_code_info, Any, (Any, Any, Any, Any, Any, Cint, Any, Cint, Cint, Any, Cint, Cint),
sig, rt_lb, rt_ub, mod, src, lineno, file, nargs, isva, env, do_compile, isinferred)
end

struct Slot{T} end
struct Splat{T}
value::T
end

# Args is a Tuple{Vararg{Union{Slot{T},Some{T}}}} where Slot{T} represents
# an uncurried argument slot, and Some{T} represents an argument to curry.
@noinline @generated function Core.OpaqueClosure(Args::Tuple, ::Slot{RT}) where RT
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems unnecessarily complicated. Why not have AT be passthrough and specify (nreq, isva) as a Val?

Copy link
Member Author

@topolarity topolarity Jul 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The idea was to support syntaxes like @TypedCallable add_node(self, ::Node)::Nothing or @TypedCallable show(self)::Nothing where we close over more than just the first argument

Copy link
Member

@vtjnash vtjnash Jul 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is also banned in precompile, since @generated is not permitted to return a new :opaque_closure object. Can we do without this hand-generated complexity? Stuff like Args.parameters is typically not actually recommended in a generated function either, as it returns something with incorrect type identity (makes the transofrm not pure). I remember doing something like make(f, AT, RT) = (Base.compilerbarrier(:const, Base.Experimental.@opaque(AT->RT, (args...)->f(args...)))::Core.OpaqueClosure{A,R})

The Slot/Splat seems to be just a partial re-implementation of lambdas, but seems a bit less reliable since it has none of the normal lowering, and makes it so that the call is not a subtype of its argument signature?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is also banned in precompile, since @generated is not permitted to return a new :opaque_closure object.

this opts-out of PartialOpaque support via #54734 so that this is allowed in pre-compile

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is not inference that is banned, it is the construct itself, since it allocates new state (a Method) which is forbidden during pure operations (a generator)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah I see - I thought that was only forbidden during the execution of the generator, but guess it also applies to the side effects of lowering the generated expression?

AT = Any[]
call = Expr(:call)
extracted = Expr[]
closure_args = Expr(:tuple)
for (i, T) in enumerate(Args.parameters)
v = Symbol("arg", i)
is_splat = T <: Splat
if is_splat # TODO: check position
push!(call.args, :($v...))
T = T.parameters[1]
else
push!(call.args, v)
end
if T <: Some
push!(extracted, :($v = something(Args[$i])))
elseif T <: Slot
SlotT = T.parameters[1]
push!(AT, is_splat ? Vararg{SlotT} : SlotT)
push!(closure_args.args, call.args[end])
else @assert false end
end
AT = Tuple{AT...}
return Base.remove_linenums!(quote
$(extracted...)
$(Expr(:opaque_closure, AT, RT, RT, #= allow_partial =# false, :(($(closure_args))->@inline $(call))))
end)
end

"""
TypedCallable{AT,RT}

TypedCallable provides a wrapper for callable objects, with the following benefits:
1. Enforced type-stability (for concrete AT/RT types)
2. Fast calling convention (frequently < 10 ns / call)
3. Normal Julia dispatch semantics (sees new Methods, etc.) + invoke_latest
4. Full pre-compilation support (including `--trim` compatibility)

## Examples

```julia
const callbacks = @TypedCallable{(::Int,::Int)->Bool}[]

register_callback!(callbacks, f::F) where {F<:Function} =
push!(callbacks, @TypedCallable f(::Int,::Int)::Bool)

register_callback!(callbacks, (x,y)->(x == y))
register_callback!(callbacks, (x,y)->(x != y))

# Calling a random (or runtime-known) callback is fast!
@btime callbacks[rand(1:2)](1,1)
```

# Extended help

### As an invalidation barrier

TypedCallable can also be used as an "invalidation barrier", since the caller of a
TypedCallable is not affected by any invalidations of its callee(s). This doesn't
completely cure the original invalidation, but it stops it from propagating all the
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe useful to note that this will also block other information propagation? I am thinking constant-propagation, effects, escape-analysis etc...

way through your code.

This can be especially helpful, e.g., when calling back to user-provided functions
whose invalidations you may have no control over.
"""
mutable struct TypedCallable{AT,RT}
@atomic oc::Base.RefValue{Core.OpaqueClosure{AT,RT}}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why the extra ref?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is to keep the atomic operations small, since otherwise the OC is inlined and we start emitting jl_(un)lock_value, etc.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we just refuse to inline @atomic annotated structs that are larger than our max atomic size?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I'd be in support of that - #51495 (comment) is also related

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The main advantage of that is we could do atomic reads without needing to take the lock, so reading may be more scalable. Changing the implementation to use a seqlock would also fix that, still without requiring the extra allocation of this. The test in #51495 was benchmarking a store of a large object with not using a large object, so it wasn't directly comparing equivalent things.

const task::Union{Task,Nothing}
const build_oc::Function
end

function Base.show(io::IO, tc::Base.Experimental.TypedCallable)
A, R = typeof(tc).parameters
Base.print(io, "@TypedCallable{")
Base.show_tuple_as_call(io, Symbol(""), A; hasfirst=false)
Base.print(io, "->◌::", R, "}()")
end

function rebuild_in_world!(@nospecialize(self::TypedCallable), world::UInt)
oc = Base.invoke_in_world(world, self.build_oc)
@atomic :release self.oc = Base.Ref(oc)
return oc
end

@inline function (self::TypedCallable{AT,RT})(args...) where {AT,RT}
invoke_world = if self.task === nothing
Base.get_world_counter() # Base.unsafe_load(cglobal(:jl_world_counter, UInt), :acquire) ?
elseif self.task === Base.current_task()
Base.tls_world_age()
else
error("TypedCallable{...} was called from a different task than it was created in.")
end
oc = (@atomic :acquire self.oc)[]
if oc.world != invoke_world
oc = @noinline rebuild_in_world!(self, invoke_world)::Core.OpaqueClosure{AT,RT}
end
return oc(args...)
end

function _TypedCallable_type(ex)
type_err = "Invalid @TypedCallable expression: $(ex)\nExpected \"@TypedCallable{(::T,::U,...)->RT}\""

# Unwrap {...}
(length(ex.args) != 1) && error(type_err)
ex = ex.args[1]

# Unwrap (...)->RT
!(Base.isexpr(ex, :->) && length(ex.args) == 2) && error(type_err)
tuple_, rt = ex.args
if !(Base.isexpr(tuple_, :tuple) && all((x)->Base.isexpr(x, :(::)), tuple_.args))
# note: (arg::T, ...) is specifically allowed (the "arg" part is unused)
error(type_err)
end
!Base.isexpr(rt, :block) && error(type_err)

# Remove any LineNumberNodes inserted by lowering
filter!((x)->!isa(x,Core.LineNumberNode), rt.args)
(length(rt.args) != 1) && error(type_err)

# Build args
AT = Expr[esc(last(x.args)) for x in tuple_.args]
RT = rt.args[1]

# Unwrap ◌::T to T
if Base.isexpr(RT, :(::)) && length(RT.args) == 2 && RT.args[1] == :◌
RT = RT.args[2]
end

return :($TypedCallable{Tuple{$(AT...)}, $(esc(RT))})
end

function _TypedCallable_closure(ex)
if Base.isexpr(ex, :call)
error("""
Invalid @TypedCallable expression: $(ex)
An explicit return type assert is required (e.g. "@TypedCallable f(...)::RT")
""")
end

call_, RT = ex.args
if !Base.isexpr(call_, :call)
error("""Invalid @TypedCallable expression: $(ex)
The supported syntax is:
@TypedCallable{(::T,::U,...)->RT} (to construct the type)
@TypedCallable f(x,::T,...)::RT (to construct the TypedCallable)
""")
end
oc_args = map(call_.args) do arg
is_splat = Base.isexpr(arg, :(...))
arg = is_splat ? arg.args[1] : arg
transformed = if Base.isexpr(arg, :(::))
if length(arg.args) == 1 # it's a "slot"
slot_ty = esc(only(arg.args))
:(Slot{$slot_ty}())
elseif length(arg.args) == 2
(arg, ty) = arg.args
:(Some{$(esc(ty))}($(esc(arg))))
else @assert false end
else
:(Some($(esc(arg))))
end
return is_splat ? Expr(:call, Splat, transformed) : transformed
end
# TODO: kwargs support
RT = :(Slot{$(esc(RT))}())
invoke_latest = true # expose as flag?
task = invoke_latest ? nothing : :(Base.current_task())
return quote
build_oc = ()->Core.OpaqueClosure(($(oc_args...),), $(RT))
$(TypedCallable)(Ref(build_oc()), $task, build_oc)
end
end

macro TypedCallable(ex)
if Base.isexpr(ex, :braces)
return _TypedCallable_type(ex)
elseif Base.isexpr(ex, :call) || (Base.isexpr(ex, :(::)) && length(ex.args) == 2)
return _TypedCallable_closure(ex)
else
error("""Invalid @TypedCallable expression: $(ex)
The supported syntax is:
@TypedCallable{(::T,::U,...)->RT} (to construct the type)
@TypedCallable f(x,::T,...)::RT (to construct the TypedCallable)
""")
end
end
10 changes: 6 additions & 4 deletions src/julia-syntax.scm
Original file line number Diff line number Diff line change
Expand Up @@ -2435,7 +2435,8 @@
(let* ((argt (something (list (expand-forms (cadr e)) #f)))
(rt_lb (something (list (expand-forms (caddr e)) #f)))
(rt_ub (something (list (expand-forms (cadddr e)) #f)))
(F (caddddr e))
(allow-partial (caddddr e))
(F (cadddddr e))
(isva (let* ((arglist (function-arglist F))
(lastarg (and (pair? arglist) (last arglist))))
(if (and argt (any (lambda (arg)
Expand All @@ -2460,7 +2461,7 @@
(let* ((argtype (foldl (lambda (var ex) `(call (core UnionAll) ,var ,ex))
(expand-forms `(curly (core Tuple) ,@argtypes))
(reverse tvars))))
`(_opaque_closure ,(or argt argtype) ,rt_lb ,rt_ub ,isva ,(length argtypes) ,functionloc ,lam))))
`(_opaque_closure ,(or argt argtype) ,rt_lb ,rt_ub ,isva ,(length argtypes) ,allow-partial ,functionloc ,lam))))

'block
(lambda (e)
Expand Down Expand Up @@ -4028,7 +4029,8 @@ f(x) = yt(x)
((_opaque_closure)
(let* ((isva (car (cddddr e)))
(nargs (cadr (cddddr e)))
(functionloc (caddr (cddddr e)))
(allow-partial (caddr (cddddr e)))
(functionloc (cadddr (cddddr e)))
(lam2 (last e))
(vis (lam:vinfo lam2))
(cvs (map car (cadr vis))))
Expand All @@ -4040,7 +4042,7 @@ f(x) = yt(x)
v)))
cvs)))
`(new_opaque_closure
,(cadr e) ,(or (caddr e) '(call (core apply_type) (core Union))) ,(or (cadddr e) '(core Any)) (true)
,(cadr e) ,(or (caddr e) '(call (core apply_type) (core Union))) ,(or (cadddr e) '(core Any)) ,allow-partial
(opaque_closure_method (null) ,nargs ,isva ,functionloc ,(convert-lambda lam2 (car (lam:args lam2)) #f '() (symbol-to-idx-map cvs)))
,@var-exprs))))
((method)
Expand Down
7 changes: 7 additions & 0 deletions test/precompile.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1995,6 +1995,13 @@ precompile_test_harness("Generated Opaque") do load_path
Expr(:opaque_closure_method, nothing, 0, false, lno, ci))
end
@assert oc_re_generated_no_partial()() === 1
@generated function oc_re_generated_no_partial_macro()
AT = nothing
RT = nothing
allow_partial = false # makes this legal to generate during pre-compile
return Expr(:opaque_closure, AT, RT, RT, allow_partial, :(()->const_int_barrier()))
end
@assert oc_re_generated_no_partial_macro()() === 1
end
""")
Base.compilecache(Base.PkgId("GeneratedOpaque"))
Expand Down