diff --git a/base/opaque_closure.jl b/base/opaque_closure.jl index 779cbf55ceaf3..0ae9fdd469434 100644 --- a/base/opaque_closure.jl +++ b/base/opaque_closure.jl @@ -18,7 +18,7 @@ the argument type may be fixed length even if the function is variadic. This interface is experimental and subject to change or removal without notice. """ macro opaque(ex) - esc(Expr(:opaque_closure, nothing, nothing, nothing, ex)) + esc(Expr(:opaque_closure, nothing, nothing, nothing, #= allow_partial =# true, ex)) end macro opaque(ty, ex) @@ -34,7 +34,7 @@ macro opaque(ty, ex) end AT = (AT !== :_) ? AT : nothing RT = (RT !== :_) ? RT : nothing - return esc(Expr(:opaque_closure, AT, RT, RT, ex)) + return esc(Expr(:opaque_closure, AT, RT, RT, #= allow_partial =# true, ex)) end # OpaqueClosure construction from pre-inferred CodeInfo/IRCode @@ -110,3 +110,197 @@ function generate_opaque_closure(@nospecialize(sig), @nospecialize(rt_lb), @nosp return ccall(:jl_new_opaque_closure_from_code_info, Any, (Any, Any, Any, Any, Any, Cint, Any, Cint, Cint, Any, Cint, Cint), sig, rt_lb, rt_ub, mod, src, lineno, file, nargs, isva, env, do_compile, isinferred) end + +struct Slot{T} end +struct Splat{T} + value::T +end + +# Args is a Tuple{Vararg{Union{Slot{T},Some{T}}}} where Slot{T} represents +# an uncurried argument slot, and Some{T} represents an argument to curry. +@noinline @generated function Core.OpaqueClosure(Args::Tuple, ::Slot{RT}) where RT + AT = Any[] + call = Expr(:call) + extracted = Expr[] + closure_args = Expr(:tuple) + for (i, T) in enumerate(Args.parameters) + v = Symbol("arg", i) + is_splat = T <: Splat + if is_splat # TODO: check position + push!(call.args, :($v...)) + T = T.parameters[1] + else + push!(call.args, v) + end + if T <: Some + push!(extracted, :($v = something(Args[$i]))) + elseif T <: Slot + SlotT = T.parameters[1] + push!(AT, is_splat ? Vararg{SlotT} : SlotT) + push!(closure_args.args, call.args[end]) + else @assert false end + end + AT = Tuple{AT...} + return Base.remove_linenums!(quote + $(extracted...) + $(Expr(:opaque_closure, AT, RT, RT, #= allow_partial =# false, :(($(closure_args))->@inline $(call)))) + end) +end + +""" + TypedCallable{AT,RT} + +TypedCallable provides a wrapper for callable objects, with the following benefits: + 1. Enforced type-stability (for concrete AT/RT types) + 2. Fast calling convention (frequently < 10 ns / call) + 3. Normal Julia dispatch semantics (sees new Methods, etc.) + invoke_latest + 4. Full pre-compilation support (including `--trim` compatibility) + +## Examples + +```julia +const callbacks = @TypedCallable{(::Int,::Int)->Bool}[] + +register_callback!(callbacks, f::F) where {F<:Function} = + push!(callbacks, @TypedCallable f(::Int,::Int)::Bool) + +register_callback!(callbacks, (x,y)->(x == y)) +register_callback!(callbacks, (x,y)->(x != y)) + +# Calling a random (or runtime-known) callback is fast! +@btime callbacks[rand(1:2)](1,1) +``` + +# Extended help + +### As an invalidation barrier + +TypedCallable can also be used as an "invalidation barrier", since the caller of a +TypedCallable is not affected by any invalidations of its callee(s). This doesn't +completely cure the original invalidation, but it stops it from propagating all the +way through your code. + +This can be especially helpful, e.g., when calling back to user-provided functions +whose invalidations you may have no control over. +""" +mutable struct TypedCallable{AT,RT} + @atomic oc::Base.RefValue{Core.OpaqueClosure{AT,RT}} + const task::Union{Task,Nothing} + const build_oc::Function +end + +function Base.show(io::IO, tc::Base.Experimental.TypedCallable) + A, R = typeof(tc).parameters + Base.print(io, "@TypedCallable{") + Base.show_tuple_as_call(io, Symbol(""), A; hasfirst=false) + Base.print(io, "->◌::", R, "}()") +end + +function rebuild_in_world!(@nospecialize(self::TypedCallable), world::UInt) + oc = Base.invoke_in_world(world, self.build_oc) + @atomic :release self.oc = Base.Ref(oc) + return oc +end + +@inline function (self::TypedCallable{AT,RT})(args...) where {AT,RT} + invoke_world = if self.task === nothing + Base.get_world_counter() # Base.unsafe_load(cglobal(:jl_world_counter, UInt), :acquire) ? + elseif self.task === Base.current_task() + Base.tls_world_age() + else + error("TypedCallable{...} was called from a different task than it was created in.") + end + oc = (@atomic :acquire self.oc)[] + if oc.world != invoke_world + oc = @noinline rebuild_in_world!(self, invoke_world)::Core.OpaqueClosure{AT,RT} + end + return oc(args...) +end + +function _TypedCallable_type(ex) + type_err = "Invalid @TypedCallable expression: $(ex)\nExpected \"@TypedCallable{(::T,::U,...)->RT}\"" + + # Unwrap {...} + (length(ex.args) != 1) && error(type_err) + ex = ex.args[1] + + # Unwrap (...)->RT + !(Base.isexpr(ex, :->) && length(ex.args) == 2) && error(type_err) + tuple_, rt = ex.args + if !(Base.isexpr(tuple_, :tuple) && all((x)->Base.isexpr(x, :(::)), tuple_.args)) + # note: (arg::T, ...) is specifically allowed (the "arg" part is unused) + error(type_err) + end + !Base.isexpr(rt, :block) && error(type_err) + + # Remove any LineNumberNodes inserted by lowering + filter!((x)->!isa(x,Core.LineNumberNode), rt.args) + (length(rt.args) != 1) && error(type_err) + + # Build args + AT = Expr[esc(last(x.args)) for x in tuple_.args] + RT = rt.args[1] + + # Unwrap ◌::T to T + if Base.isexpr(RT, :(::)) && length(RT.args) == 2 && RT.args[1] == :◌ + RT = RT.args[2] + end + + return :($TypedCallable{Tuple{$(AT...)}, $(esc(RT))}) +end + +function _TypedCallable_closure(ex) + if Base.isexpr(ex, :call) + error(""" + Invalid @TypedCallable expression: $(ex) + An explicit return type assert is required (e.g. "@TypedCallable f(...)::RT") + """) + end + + call_, RT = ex.args + if !Base.isexpr(call_, :call) + error("""Invalid @TypedCallable expression: $(ex) + The supported syntax is: + @TypedCallable{(::T,::U,...)->RT} (to construct the type) + @TypedCallable f(x,::T,...)::RT (to construct the TypedCallable) + """) + end + oc_args = map(call_.args) do arg + is_splat = Base.isexpr(arg, :(...)) + arg = is_splat ? arg.args[1] : arg + transformed = if Base.isexpr(arg, :(::)) + if length(arg.args) == 1 # it's a "slot" + slot_ty = esc(only(arg.args)) + :(Slot{$slot_ty}()) + elseif length(arg.args) == 2 + (arg, ty) = arg.args + :(Some{$(esc(ty))}($(esc(arg)))) + else @assert false end + else + :(Some($(esc(arg)))) + end + return is_splat ? Expr(:call, Splat, transformed) : transformed + end + # TODO: kwargs support + RT = :(Slot{$(esc(RT))}()) + invoke_latest = true # expose as flag? + task = invoke_latest ? nothing : :(Base.current_task()) + return quote + build_oc = ()->Core.OpaqueClosure(($(oc_args...),), $(RT)) + $(TypedCallable)(Ref(build_oc()), $task, build_oc) + end +end + +macro TypedCallable(ex) + if Base.isexpr(ex, :braces) + return _TypedCallable_type(ex) + elseif Base.isexpr(ex, :call) || (Base.isexpr(ex, :(::)) && length(ex.args) == 2) + return _TypedCallable_closure(ex) + else + error("""Invalid @TypedCallable expression: $(ex) + The supported syntax is: + @TypedCallable{(::T,::U,...)->RT} (to construct the type) + @TypedCallable f(x,::T,...)::RT (to construct the TypedCallable) + """) + end +end diff --git a/src/julia-syntax.scm b/src/julia-syntax.scm index 5636caa48e6e6..a2d3ffdd66f67 100644 --- a/src/julia-syntax.scm +++ b/src/julia-syntax.scm @@ -2435,7 +2435,8 @@ (let* ((argt (something (list (expand-forms (cadr e)) #f))) (rt_lb (something (list (expand-forms (caddr e)) #f))) (rt_ub (something (list (expand-forms (cadddr e)) #f))) - (F (caddddr e)) + (allow-partial (caddddr e)) + (F (cadddddr e)) (isva (let* ((arglist (function-arglist F)) (lastarg (and (pair? arglist) (last arglist)))) (if (and argt (any (lambda (arg) @@ -2460,7 +2461,7 @@ (let* ((argtype (foldl (lambda (var ex) `(call (core UnionAll) ,var ,ex)) (expand-forms `(curly (core Tuple) ,@argtypes)) (reverse tvars)))) - `(_opaque_closure ,(or argt argtype) ,rt_lb ,rt_ub ,isva ,(length argtypes) ,functionloc ,lam)))) + `(_opaque_closure ,(or argt argtype) ,rt_lb ,rt_ub ,isva ,(length argtypes) ,allow-partial ,functionloc ,lam)))) 'block (lambda (e) @@ -4028,7 +4029,8 @@ f(x) = yt(x) ((_opaque_closure) (let* ((isva (car (cddddr e))) (nargs (cadr (cddddr e))) - (functionloc (caddr (cddddr e))) + (allow-partial (caddr (cddddr e))) + (functionloc (cadddr (cddddr e))) (lam2 (last e)) (vis (lam:vinfo lam2)) (cvs (map car (cadr vis)))) @@ -4040,7 +4042,7 @@ f(x) = yt(x) v))) cvs))) `(new_opaque_closure - ,(cadr e) ,(or (caddr e) '(call (core apply_type) (core Union))) ,(or (cadddr e) '(core Any)) (true) + ,(cadr e) ,(or (caddr e) '(call (core apply_type) (core Union))) ,(or (cadddr e) '(core Any)) ,allow-partial (opaque_closure_method (null) ,nargs ,isva ,functionloc ,(convert-lambda lam2 (car (lam:args lam2)) #f '() (symbol-to-idx-map cvs))) ,@var-exprs)))) ((method) diff --git a/test/precompile.jl b/test/precompile.jl index 21a17e0778496..0c918ae46a733 100644 --- a/test/precompile.jl +++ b/test/precompile.jl @@ -1995,6 +1995,13 @@ precompile_test_harness("Generated Opaque") do load_path Expr(:opaque_closure_method, nothing, 0, false, lno, ci)) end @assert oc_re_generated_no_partial()() === 1 + @generated function oc_re_generated_no_partial_macro() + AT = nothing + RT = nothing + allow_partial = false # makes this legal to generate during pre-compile + return Expr(:opaque_closure, AT, RT, RT, allow_partial, :(()->const_int_barrier())) + end + @assert oc_re_generated_no_partial_macro()() === 1 end """) Base.compilecache(Base.PkgId("GeneratedOpaque"))