Skip to content

Commit 0e28d19

Browse files
committed
RFC: Introduce TypedCallable
TypedCallable provides a wrapper for callable objects, with the following benefits: 1. Enforced type-stability (for concrete AT/RT types) 2. Fast calling convention (frequently < 10 ns / call) 3. Normal Julia dispatch semantics (sees new Methods, etc.) + invoke_latest 4. Pre-compilation support (including `--trim` compatibility) It can be used like this: ```julia callbacks = @TypedCallable{(::Int,::Int)->Bool}[] register_callback!(callbacks, f::F) where {F<:Function} = push!(callbacks, @TypedCallable f(::Int,::Int)::Bool) register_callback!(callbacks, (x,y)->(x == y)) register_callback!(callbacks, (x,y)->(x != y)) @Btime callbacks[rand(1:2)](1,1) ``` This is very similar to the existing `FunctionWrappers.jl`, but there are a few key differences: - Better type support: TypedCallable supports the full range of Julia types (incl. Varargs), and it has access to all of Julia's "internal" calling conventions so calls are fast (and allocation-free) for a wider range of input types - Improved dispatch handling: The `@cfunction` functionality used by FunctionWrappers has several dispatch bugs, which cause wrappers to occasionally not see new Methods. These bugs are fixed (or soon to be fixed) for TypedCallable. - Pre-compilation support including for `juliac` / `--trim` (#55047) Many of the improvements here are actually thanks to the `OpaqueClosure` introduced by @Keno - This type just builds on top of OpaqueClosure to provide an interface with Julia's usual dispatch semantics.
1 parent 2b140ba commit 0e28d19

File tree

1 file changed

+194
-0
lines changed

1 file changed

+194
-0
lines changed

base/opaque_closure.jl

+194
Original file line numberDiff line numberDiff line change
@@ -110,3 +110,197 @@ function generate_opaque_closure(@nospecialize(sig), @nospecialize(rt_lb), @nosp
110110
return ccall(:jl_new_opaque_closure_from_code_info, Any, (Any, Any, Any, Any, Any, Cint, Any, Cint, Cint, Any, Cint, Cint),
111111
sig, rt_lb, rt_ub, mod, src, lineno, file, nargs, isva, env, do_compile, isinferred)
112112
end
113+
114+
struct Slot{T} end
115+
struct Splat{T}
116+
value::T
117+
end
118+
119+
# Args is a Tuple{Vararg{Union{Slot{T},Some{T}}}} where Slot{T} represents
120+
# an uncurried argument slot, and Some{T} represents an argument to curry.
121+
@noinline @generated function Core.OpaqueClosure(Args::Tuple, ::Slot{RT}) where RT
122+
AT = Any[]
123+
call = Expr(:call)
124+
extracted = Expr[]
125+
closure_args = Expr(:tuple)
126+
for (i, T) in enumerate(Args.parameters)
127+
v = Symbol("arg", i)
128+
is_splat = T <: Splat
129+
if is_splat # TODO: check position
130+
push!(call.args, :($v...))
131+
T = T.parameters[1]
132+
else
133+
push!(call.args, v)
134+
end
135+
if T <: Some
136+
push!(extracted, :($v = something(Args[$i])))
137+
elseif T <: Slot
138+
SlotT = T.parameters[1]
139+
push!(AT, is_splat ? Vararg{SlotT} : SlotT)
140+
push!(closure_args.args, call.args[end])
141+
else @assert false end
142+
end
143+
AT = Tuple{AT...}
144+
return Base.remove_linenums!(quote
145+
$(extracted...)
146+
$(Expr(:opaque_closure, AT, RT, RT, #= allow_partial =# false, :(($(closure_args))->@inline $(call))))
147+
end)
148+
end
149+
150+
"""
151+
TypedCallable{AT,RT}
152+
153+
TypedCallable provides a wrapper for callable objects, with the following benefits:
154+
1. Enforced type-stability (for concrete AT/RT types)
155+
2. Fast calling convention (frequently < 10 ns / call)
156+
3. Normal Julia dispatch semantics (sees new Methods, etc.) + invoke_latest
157+
4. Full pre-compilation support (including `--trim` compatibility)
158+
159+
## Examples
160+
161+
```julia
162+
const callbacks = @TypedCallable{(::Int,::Int)->Bool}[]
163+
164+
register_callback!(callbacks, f::F) where {F<:Function} =
165+
push!(callbacks, @TypedCallable f(::Int,::Int)::Bool)
166+
167+
register_callback!(callbacks, (x,y)->(x == y))
168+
register_callback!(callbacks, (x,y)->(x != y))
169+
170+
# Calling a random (or runtime-known) callback is fast!
171+
@btime callbacks[rand(1:2)](1,1)
172+
```
173+
174+
# Extended help
175+
176+
### As an invalidation barrier
177+
178+
TypedCallable can also be used as an "invalidation barrier", since the caller of a
179+
TypedCallable is not affected by any invalidations of its callee(s). This doesn't
180+
completely cure the original invalidation, but it stops it from propagating all the
181+
way through your code.
182+
183+
This can be especially helpful, e.g., when calling back to user-provided functions
184+
whose invalidations you may have no control over.
185+
"""
186+
mutable struct TypedCallable{AT,RT}
187+
@atomic oc::Base.RefValue{Core.OpaqueClosure{AT,RT}}
188+
const task::Union{Task,Nothing}
189+
const build_oc::Function
190+
end
191+
192+
function Base.show(io::IO, tc::Base.Experimental.TypedCallable)
193+
A, R = typeof(tc).parameters
194+
Base.print(io, "@TypedCallable{")
195+
Base.show_tuple_as_call(io, Symbol(""), A; hasfirst=false)
196+
Base.print(io, "->◌::", R, "}()")
197+
end
198+
199+
function rebuild_in_world!(@nospecialize(self::TypedCallable), world::UInt)
200+
oc = Base.invoke_in_world(world, self.build_oc)
201+
@atomic :release self.oc = Base.Ref(oc)
202+
return oc
203+
end
204+
205+
@inline function (self::TypedCallable{AT,RT})(args...) where {AT,RT}
206+
invoke_world = if self.task === nothing
207+
Base.get_world_counter() # Base.unsafe_load(cglobal(:jl_world_counter, UInt), :acquire) ?
208+
elseif self.task === Base.current_task()
209+
Base.tls_world_age()
210+
else
211+
error("TypedCallable{...} was called from a different task than it was created in.")
212+
end
213+
oc = (@atomic :acquire self.oc)[]
214+
if oc.world != invoke_world
215+
oc = @noinline rebuild_in_world!(self, invoke_world)::Core.OpaqueClosure{AT,RT}
216+
end
217+
return oc(args...)
218+
end
219+
220+
function _TypedCallable_type(ex)
221+
type_err = "Invalid @TypedCallable expression: $(ex)\nExpected \"@TypedCallable{(::T,::U,...)->RT}\""
222+
223+
# Unwrap {...}
224+
(length(ex.args) != 1) && error(type_err)
225+
ex = ex.args[1]
226+
227+
# Unwrap (...)->RT
228+
!(Base.isexpr(ex, :->) && length(ex.args) == 2) && error(type_err)
229+
tuple_, rt = ex.args
230+
if !(Base.isexpr(tuple_, :tuple) && all((x)->Base.isexpr(x, :(::)), tuple_.args))
231+
# note: (arg::T, ...) is specifically allowed (the "arg" part is unused)
232+
error(type_err)
233+
end
234+
!Base.isexpr(rt, :block) && error(type_err)
235+
236+
# Remove any LineNumberNodes inserted by lowering
237+
filter!((x)->!isa(x,Core.LineNumberNode), rt.args)
238+
(length(rt.args) != 1) && error(type_err)
239+
240+
# Build args
241+
AT = Expr[esc(last(x.args)) for x in tuple_.args]
242+
RT = rt.args[1]
243+
244+
# Unwrap ◌::T to T
245+
if Base.isexpr(RT, :(::)) && length(RT.args) == 2 && RT.args[1] == :◌
246+
RT = RT.args[2]
247+
end
248+
249+
return :($TypedCallable{Tuple{$(AT...)}, $(esc(RT))})
250+
end
251+
252+
function _TypedCallable_closure(ex)
253+
if Base.isexpr(ex, :call)
254+
error("""
255+
Invalid @TypedCallable expression: $(ex)
256+
An explicit return type assert is required (e.g. "@TypedCallable f(...)::RT")
257+
""")
258+
end
259+
260+
call_, RT = ex.args
261+
if !Base.isexpr(call_, :call)
262+
error("""Invalid @TypedCallable expression: $(ex)
263+
The supported syntax is:
264+
@TypedCallable{(::T,::U,...)->RT} (to construct the type)
265+
@TypedCallable f(x,::T,...)::RT (to construct the TypedCallable)
266+
""")
267+
end
268+
oc_args = map(call_.args) do arg
269+
is_splat = Base.isexpr(arg, :(...))
270+
arg = is_splat ? arg.args[1] : arg
271+
transformed = if Base.isexpr(arg, :(::))
272+
if length(arg.args) == 1 # it's a "slot"
273+
slot_ty = esc(only(arg.args))
274+
:(Slot{$slot_ty}())
275+
elseif length(arg.args) == 2
276+
(arg, ty) = arg.args
277+
:(Some{$(esc(ty))}($(esc(arg))))
278+
else @assert false end
279+
else
280+
:(Some($(esc(arg))))
281+
end
282+
return is_splat ? Expr(:call, Splat, transformed) : transformed
283+
end
284+
# TODO: kwargs support
285+
RT = :(Slot{$(esc(RT))}())
286+
invoke_latest = true # expose as flag?
287+
task = invoke_latest ? nothing : :(Base.current_task())
288+
return quote
289+
build_oc = ()->Core.OpaqueClosure(($(oc_args...),), $(RT))
290+
$(TypedCallable)(Ref(build_oc()), $task, build_oc)
291+
end
292+
end
293+
294+
macro TypedCallable(ex)
295+
if Base.isexpr(ex, :braces)
296+
return _TypedCallable_type(ex)
297+
elseif Base.isexpr(ex, :call) || (Base.isexpr(ex, :(::)) && length(ex.args) == 2)
298+
return _TypedCallable_closure(ex)
299+
else
300+
error("""Invalid @TypedCallable expression: $(ex)
301+
The supported syntax is:
302+
@TypedCallable{(::T,::U,...)->RT} (to construct the type)
303+
@TypedCallable f(x,::T,...)::RT (to construct the TypedCallable)
304+
""")
305+
end
306+
end

0 commit comments

Comments
 (0)