Skip to content

Commit 3549ca3

Browse files
committed
improved function partial application design
This replaces `Fix` (xref JuliaLang#54653) with `fix`. The usage is similar: use `fix(i)(f, x)` instead of `Fix{i}(f, x)`. Benefits: * Improved type safety: creating an invalid type such as `Fix{:some_symbol}` or `Fix{-7}` is not possible. * The design should be friendlier to future extensions. E.g., suppose that publicly-facing functionality for fixing a keyword (instead of positional) argument was desired, it could be achieved by adding a new method to `fix` taking a `Symbol`, instead of adding new public names. Lots of changes are shared with PR JuliaLang#56425, if one of them gets merged the other will be greatly simplified.
1 parent afdba95 commit 3549ca3

File tree

11 files changed

+381
-74
lines changed

11 files changed

+381
-74
lines changed

NEWS.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ New library functions
8787
* `waitany(tasks; throw=false)` and `waitall(tasks; failfast=false, throw=false)` which wait multiple tasks at once ([#53341]).
8888
* `uuid7()` creates an RFC 9652 compliant UUID with version 7 ([#54834]).
8989
* `insertdims(array; dims)` allows to insert singleton dimensions into an array which is the inverse operation to `dropdims`
90-
* The new `Fix` type is a generalization of `Fix1/Fix2` for fixing a single argument ([#54653]).
90+
* `Fix1`/`Fix2` are now generalized by `fix` ([#54653], [#56518]).
9191

9292
New library features
9393
--------------------

base/Base_compiler.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,7 @@ include("error.jl")
213213
include("bool.jl")
214214
include("number.jl")
215215
include("int.jl")
216+
include("typedomainnumbers.jl")
216217
include("operators.jl")
217218
include("pointer.jl")
218219
include("refvalue.jl")

base/operators.jl

Lines changed: 129 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1153,55 +1153,152 @@ julia> filter(!isletter, str)
11531153
!(f::Function) = (!) f
11541154
!(f::ComposedFunction{typeof(!)}) = f.inner #allows !!f === f
11551155

1156+
const _PositiveInteger = _TypeDomainNumbers.PositiveIntegers.PositiveInteger
1157+
1158+
struct PartiallyAppliedFunction{Position <: _PositiveInteger, Func, Arg} <: Function
1159+
partially_applied_argument_position::Position
1160+
f::Func
1161+
x::Arg
1162+
1163+
global function new_partially_applied_function(pos::_PositiveInteger, f::Func, x) where {Func}
1164+
new{typeof(pos), _stable_typeof(f), _stable_typeof(x)}(pos, f, x)
1165+
end
1166+
end
1167+
1168+
function (::Type{PartiallyAppliedFunction{Position}})(func::Func, arg) where {Position <: _PositiveInteger, Func}
1169+
pos = (Position::DataType).instance
1170+
new_partially_applied_function(pos, func, arg)
1171+
end
1172+
1173+
function getproperty((@nospecialize v::PartiallyAppliedFunction), s::Symbol)
1174+
getfield(v, s)
1175+
end # avoid overspecialization
1176+
1177+
function Base.show(
1178+
(@nospecialize io::Base.IO),
1179+
(@nospecialize unused::Type{PartiallyAppliedFunction{Position}}),
1180+
) where {Position <: _PositiveInteger}
1181+
if Position isa DataType
1182+
print(io, "fix(")
1183+
show(io, Position.instance)
1184+
print(io, ')')
1185+
else
1186+
show(io, PartiallyAppliedFunction)
1187+
print(io, '{')
1188+
show(io, Position)
1189+
print(io, '}')
1190+
end
1191+
end
1192+
1193+
function Base.show(
1194+
(@nospecialize io::Base.IO),
1195+
(@nospecialize unused::Type{PartiallyAppliedFunction{Position, Func}}),
1196+
) where {Position <: _PositiveInteger, Func}
1197+
show(io, PartiallyAppliedFunction{Position})
1198+
print(io, '{')
1199+
show(io, Func)
1200+
print(io, '}')
1201+
end
1202+
1203+
function Base.show(
1204+
(@nospecialize io::Base.IO),
1205+
(@nospecialize unused::Type{PartiallyAppliedFunction{Position, Func, Arg}}),
1206+
) where {Position <: _PositiveInteger, Func, Arg}
1207+
show(io, PartiallyAppliedFunction{Position, Func})
1208+
print(io, '{')
1209+
show(io, Arg)
1210+
print(io, '}')
1211+
end
1212+
1213+
function Base.show((@nospecialize io::Base.IO), @nospecialize p::PartiallyAppliedFunction)
1214+
print(io, "fix(")
1215+
show(io, p.partially_applied_argument_position)
1216+
print(io, ")(")
1217+
show(io, p.f)
1218+
print(io, ", ")
1219+
show(io, p.x)
1220+
print(io, ')')
1221+
end
1222+
1223+
function _partially_applied_function_check(m::Int, nm1::Int)
1224+
if m < nm1
1225+
throw(ArgumentError(LazyString("expected at least ", nm1, " arguments to `fix(", nm1 + 1, ")`, but got ", m)))
1226+
end
1227+
end
1228+
1229+
function (partial::PartiallyAppliedFunction)(args::Vararg{Any,M}; kws...) where {M}
1230+
n = partial.partially_applied_argument_position
1231+
nm1 = _TypeDomainNumbers.PositiveIntegers.natural_predecessor(n)
1232+
_partially_applied_function_check(M, Int(nm1))
1233+
(args_left, args_right) = _TypeDomainNumberTupleUtils.split_tuple(args, nm1)
1234+
partial.f(args_left..., partial.x, args_right...; kws...)
1235+
end
1236+
11561237
"""
1157-
Fix{N}(f, x)
1238+
fix(::Integer)::UnionAll
1239+
1240+
Return a [`UnionAll`](@ref) type such that:
1241+
* It's a constructor taking two arguments:
1242+
1. A function to be partially applied
1243+
2. An argument of the above function to be fixed
1244+
* Its instances are partial applications of the function, with one positional argument fixed. The argument to `fix` is the one-based index of the position argument to be fixed.
1245+
1246+
For example, `fix(3)(f, x)` behaves similarly to `(y1, y2, y3...; kws...) -> f(y1, y2, x, y3...; kws...)`.
11581247
1159-
A type representing a partially-applied version of a function `f`, with the argument
1160-
`x` fixed at position `N::Int`. In other words, `Fix{3}(f, x)` behaves similarly to
1161-
`(y1, y2, y3...; kws...) -> f(y1, y2, x, y3...; kws...)`.
1248+
See also: [`Fix1`](@ref), [`Fix2`](@ref).
11621249
11631250
!!! compat "Julia 1.12"
1164-
This general functionality requires at least Julia 1.12, while `Fix1` and `Fix2`
1165-
are available earlier.
1251+
Requires at least Julia 1.12 (`Fix1` and `Fix2` are available earlier, too).
11661252
11671253
!!! note
1168-
When nesting multiple `Fix`, note that the `N` in `Fix{N}` is _relative_ to the current
1254+
When nesting multiple `fix`, note that the `n` in `fix(n)` is _relative_ to the current
11691255
available arguments, rather than an absolute ordering on the target function. For example,
1170-
`Fix{1}(Fix{2}(f, 4), 4)` fixes the first and second arg, while `Fix{2}(Fix{1}(f, 4), 4)`
1256+
`fix(1)(fix(2)(f, 4), 4)` fixes the first and second arg, while `fix(2)(fix(1)(f, 4), 4)`
11711257
fixes the first and third arg.
1172-
"""
1173-
struct Fix{N,F,T} <: Function
1174-
f::F
1175-
x::T
11761258
1177-
function Fix{N}(f::F, x) where {N,F}
1178-
if !(N isa Int)
1179-
throw(ArgumentError(LazyString("expected type parameter in `Fix` to be `Int`, but got `", N, "::", typeof(N), "`")))
1180-
elseif N < 1
1181-
throw(ArgumentError(LazyString("expected `N` in `Fix{N}` to be integer greater than 0, but got ", N)))
1182-
end
1183-
new{N,_stable_typeof(f),_stable_typeof(x)}(f, x)
1184-
end
1185-
end
1259+
### Examples
11861260
1187-
function (f::Fix{N})(args::Vararg{Any,M}; kws...) where {N,M}
1188-
M < N-1 && throw(ArgumentError(LazyString("expected at least ", N-1, " arguments to `Fix{", N, "}`, but got ", M)))
1189-
return f.f(args[begin:begin+(N-2)]..., f.x, args[begin+(N-1):end]...; kws...)
1190-
end
1261+
```jldoctest
1262+
julia> Base.fix(2)(-, 3)(7)
1263+
4
11911264
1192-
# Special cases for improved constant propagation
1193-
(f::Fix{1})(arg; kws...) = f.f(f.x, arg; kws...)
1194-
(f::Fix{2})(arg; kws...) = f.f(arg, f.x; kws...)
1265+
julia> Base.fix(2) === Base.Fix2
1266+
true
1267+
1268+
julia> Base.fix(1)(Base.fix(2)(muladd, 3), 2)(5) === (x -> muladd(2, 3, x))(5)
1269+
true
1270+
```
1271+
"""
1272+
function fix(@nospecialize m::Integer)
1273+
n = Int(m)::Int
1274+
if n 0
1275+
throw(ArgumentError("the index of the partially applied argument must be positive"))
1276+
end
1277+
k = _TypeDomainNumbers.Utils.from_abs_int(n)
1278+
PartiallyAppliedFunction{typeof(k)}
1279+
end
11951280

11961281
"""
1197-
Alias for `Fix{1}`. See [`Fix`](@ref Base.Fix).
1282+
Fix1::UnionAll
1283+
1284+
[`fix(1)`](@ref Base.fix).
11981285
"""
1199-
const Fix1{F,T} = Fix{1,F,T}
1286+
const Fix1 = fix(1)
12001287

12011288
"""
1202-
Alias for `Fix{2}`. See [`Fix`](@ref Base.Fix).
1289+
Fix2::UnionAll
1290+
1291+
[`fix(2)`](@ref Base.fix).
12031292
"""
1204-
const Fix2{F,T} = Fix{2,F,T}
1293+
const Fix2 = fix(2)
1294+
1295+
# Special cases for improved constant propagation
1296+
function (partial::Fix1)(x; kws...)
1297+
partial.f(partial.x, x; kws...)
1298+
end
1299+
function (partial::Fix2)(x; kws...)
1300+
partial.f(x, partial.x; kws...)
1301+
end
12051302

12061303

12071304
"""

base/public.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ public
1414
AsyncCondition,
1515
CodeUnits,
1616
Event,
17-
Fix,
17+
fix,
1818
Fix1,
1919
Fix2,
2020
Generator,

base/tuple.jl

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,18 @@
11
# This file is a part of Julia. License is MIT: https://julialang.org/license
22

3+
module _TupleTypeByLength
4+
export Tuple1OrMore, Tuple2OrMore, Tuple32OrMore
5+
const Tuple1OrMore = Tuple{Any, Vararg}
6+
const Tuple2OrMore = Tuple{Any, Any, Vararg}
7+
const Tuple32OrMore = Tuple{
8+
Any, Any, Any, Any, Any, Any, Any, Any,
9+
Any, Any, Any, Any, Any, Any, Any, Any,
10+
Any, Any, Any, Any, Any, Any, Any, Any,
11+
Any, Any, Any, Any, Any, Any, Any, Any,
12+
Vararg{Any, N},
13+
} where {N}
14+
end
15+
316
# Document NTuple here where we have everything needed for the doc system
417
"""
518
NTuple{N, T}
@@ -358,11 +371,7 @@ map(f, t::Tuple{Any, Any}) = (@inline; (f(t[1]), f(t[2])))
358371
map(f, t::Tuple{Any, Any, Any}) = (@inline; (f(t[1]), f(t[2]), f(t[3])))
359372
map(f, t::Tuple) = (@inline; (f(t[1]), map(f,tail(t))...))
360373
# stop inlining after some number of arguments to avoid code blowup
361-
const Any32{N} = Tuple{Any,Any,Any,Any,Any,Any,Any,Any,
362-
Any,Any,Any,Any,Any,Any,Any,Any,
363-
Any,Any,Any,Any,Any,Any,Any,Any,
364-
Any,Any,Any,Any,Any,Any,Any,Any,
365-
Vararg{Any,N}}
374+
const Any32 = _TupleTypeByLength.Tuple32OrMore
366375
const All32{T,N} = Tuple{T,T,T,T,T,T,T,T,
367376
T,T,T,T,T,T,T,T,
368377
T,T,T,T,T,T,T,T,

0 commit comments

Comments
 (0)