Skip to content

Commit 8b2e9ef

Browse files
committed
improved function partial application design
This replaces `Fix` (xref JuliaLang#54653) with `fix`. The usage is similar: use `fix(i)(f, x)` instead of `Fix{i}(f, x)`. Benefits: * Improved type safety: creating an invalid type such as `Fix{:some_symbol}` or `Fix{-7}` is not possible. * The design should be friendlier to future extensions. E.g., suppose that publicly-facing functionality for fixing a keyword (instead of positional) argument was desired, it could be achieved by adding a new method to `fix` taking a `Symbol`, instead of adding new public names. Lots of changes are shared with PR JuliaLang#56425, if one of them gets merged the other will be greatly simplified.
1 parent afdba95 commit 8b2e9ef

File tree

11 files changed

+378
-74
lines changed

11 files changed

+378
-74
lines changed

NEWS.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ New library functions
8787
* `waitany(tasks; throw=false)` and `waitall(tasks; failfast=false, throw=false)` which wait multiple tasks at once ([#53341]).
8888
* `uuid7()` creates an RFC 9652 compliant UUID with version 7 ([#54834]).
8989
* `insertdims(array; dims)` allows to insert singleton dimensions into an array which is the inverse operation to `dropdims`
90-
* The new `Fix` type is a generalization of `Fix1/Fix2` for fixing a single argument ([#54653]).
90+
* `Fix1`/`Fix2` are now generalized by `fix` ([#54653], [#56518]).
9191

9292
New library features
9393
--------------------

base/Base_compiler.jl

+1
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,7 @@ include("error.jl")
213213
include("bool.jl")
214214
include("number.jl")
215215
include("int.jl")
216+
include("typedomainnumbers.jl")
216217
include("operators.jl")
217218
include("pointer.jl")
218219
include("refvalue.jl")

base/operators.jl

+126-32
Original file line numberDiff line numberDiff line change
@@ -1153,55 +1153,149 @@ julia> filter(!isletter, str)
11531153
!(f::Function) = (!) f
11541154
!(f::ComposedFunction{typeof(!)}) = f.inner #allows !!f === f
11551155

1156+
const _PositiveInteger = _TypeDomainNumbers.PositiveIntegers.PositiveInteger
1157+
1158+
struct PartiallyAppliedFunction{Position <: _PositiveInteger, Func, Arg} <: Function
1159+
partially_applied_argument_position::Position
1160+
f::Func
1161+
x::Arg
1162+
1163+
function (::Type{PartiallyAppliedFunction{Position}})(func::Func, arg) where {Position <: _PositiveInteger, Func}
1164+
Pos = Position::DataType
1165+
pos = Pos.instance
1166+
new{Pos, _stable_typeof(func), _stable_typeof(arg)}(pos, func, arg)
1167+
end
1168+
end
1169+
1170+
function getproperty((@nospecialize v::PartiallyAppliedFunction), s::Symbol)
1171+
getfield(v, s)
1172+
end # avoid overspecialization
1173+
1174+
function Base.show(
1175+
(@nospecialize io::Base.IO),
1176+
(@nospecialize unused::Type{PartiallyAppliedFunction{Position}}),
1177+
) where {Position <: _PositiveInteger}
1178+
if Position isa DataType
1179+
print(io, "fix(")
1180+
show(io, Position.instance)
1181+
print(io, ')')
1182+
else
1183+
show(io, PartiallyAppliedFunction)
1184+
print(io, '{')
1185+
show(io, Position)
1186+
print(io, '}')
1187+
end
1188+
end
1189+
1190+
function Base.show(
1191+
(@nospecialize io::Base.IO),
1192+
(@nospecialize unused::Type{PartiallyAppliedFunction{Position, Func}}),
1193+
) where {Position <: _PositiveInteger, Func}
1194+
show(io, PartiallyAppliedFunction{Position})
1195+
print(io, '{')
1196+
show(io, Func)
1197+
print(io, '}')
1198+
end
1199+
1200+
function Base.show(
1201+
(@nospecialize io::Base.IO),
1202+
(@nospecialize unused::Type{PartiallyAppliedFunction{Position, Func, Arg}}),
1203+
) where {Position <: _PositiveInteger, Func, Arg}
1204+
show(io, PartiallyAppliedFunction{Position, Func})
1205+
print(io, '{')
1206+
show(io, Arg)
1207+
print(io, '}')
1208+
end
1209+
1210+
function Base.show((@nospecialize io::Base.IO), @nospecialize p::PartiallyAppliedFunction)
1211+
print(io, "fix(")
1212+
show(io, p.partially_applied_argument_position)
1213+
print(io, ")(")
1214+
show(io, p.f)
1215+
print(io, ", ")
1216+
show(io, p.x)
1217+
print(io, ')')
1218+
end
1219+
1220+
function _partially_applied_function_check(m::Int, nm1::Int)
1221+
if m < nm1
1222+
throw(ArgumentError(LazyString("expected at least ", nm1, " arguments to `fix(", nm1 + 1, ")`, but got ", m)))
1223+
end
1224+
end
1225+
1226+
function (partial::PartiallyAppliedFunction)(args::Vararg{Any,M}; kws...) where {M}
1227+
n = partial.partially_applied_argument_position
1228+
nm1 = _TypeDomainNumbers.PositiveIntegers.natural_predecessor(n)
1229+
_partially_applied_function_check(M, Int(nm1))
1230+
(args_left, args_right) = _TypeDomainNumberTupleUtils.split_tuple(args, nm1)
1231+
partial.f(args_left..., partial.x, args_right...; kws...)
1232+
end
1233+
11561234
"""
1157-
Fix{N}(f, x)
1235+
fix(::Integer)::UnionAll
1236+
1237+
Return a [`UnionAll`](@ref) type such that:
1238+
* It's a constructor taking two arguments:
1239+
1. A function to be partially applied
1240+
2. An argument of the above function to be fixed
1241+
* Its instances are partial applications of the function, with one positional argument fixed. The argument to `fix` is the one-based index of the position argument to be fixed.
11581242
1159-
A type representing a partially-applied version of a function `f`, with the argument
1160-
`x` fixed at position `N::Int`. In other words, `Fix{3}(f, x)` behaves similarly to
1161-
`(y1, y2, y3...; kws...) -> f(y1, y2, x, y3...; kws...)`.
1243+
For example, `fix(3)(f, x)` behaves similarly to `(y1, y2, y3...; kws...) -> f(y1, y2, x, y3...; kws...)`.
1244+
1245+
See also: [`Fix1`](@ref), [`Fix2`](@ref).
11621246
11631247
!!! compat "Julia 1.12"
1164-
This general functionality requires at least Julia 1.12, while `Fix1` and `Fix2`
1165-
are available earlier.
1248+
Requires at least Julia 1.12 (`Fix1` and `Fix2` are available earlier, too).
11661249
11671250
!!! note
1168-
When nesting multiple `Fix`, note that the `N` in `Fix{N}` is _relative_ to the current
1251+
When nesting multiple `fix`, note that the `n` in `fix(n)` is _relative_ to the current
11691252
available arguments, rather than an absolute ordering on the target function. For example,
1170-
`Fix{1}(Fix{2}(f, 4), 4)` fixes the first and second arg, while `Fix{2}(Fix{1}(f, 4), 4)`
1253+
`fix(1)(fix(2)(f, 4), 4)` fixes the first and second arg, while `fix(2)(fix(1)(f, 4), 4)`
11711254
fixes the first and third arg.
1172-
"""
1173-
struct Fix{N,F,T} <: Function
1174-
f::F
1175-
x::T
11761255
1177-
function Fix{N}(f::F, x) where {N,F}
1178-
if !(N isa Int)
1179-
throw(ArgumentError(LazyString("expected type parameter in `Fix` to be `Int`, but got `", N, "::", typeof(N), "`")))
1180-
elseif N < 1
1181-
throw(ArgumentError(LazyString("expected `N` in `Fix{N}` to be integer greater than 0, but got ", N)))
1182-
end
1183-
new{N,_stable_typeof(f),_stable_typeof(x)}(f, x)
1184-
end
1185-
end
1256+
### Examples
11861257
1187-
function (f::Fix{N})(args::Vararg{Any,M}; kws...) where {N,M}
1188-
M < N-1 && throw(ArgumentError(LazyString("expected at least ", N-1, " arguments to `Fix{", N, "}`, but got ", M)))
1189-
return f.f(args[begin:begin+(N-2)]..., f.x, args[begin+(N-1):end]...; kws...)
1190-
end
1258+
```jldoctest
1259+
julia> Base.fix(2)(-, 3)(7)
1260+
4
11911261
1192-
# Special cases for improved constant propagation
1193-
(f::Fix{1})(arg; kws...) = f.f(f.x, arg; kws...)
1194-
(f::Fix{2})(arg; kws...) = f.f(arg, f.x; kws...)
1262+
julia> Base.fix(2) === Base.Fix2
1263+
true
11951264
1265+
julia> Base.fix(1)(Base.fix(2)(muladd, 3), 2)(5) === (x -> muladd(2, 3, x))(5)
1266+
true
1267+
```
11961268
"""
1197-
Alias for `Fix{1}`. See [`Fix`](@ref Base.Fix).
1269+
function fix(@nospecialize m::Integer)
1270+
n = Int(m)::Int
1271+
if n 0
1272+
throw(ArgumentError("the index of the partially applied argument must be positive"))
1273+
end
1274+
k = _TypeDomainNumbers.Utils.from_abs_int(n)
1275+
PartiallyAppliedFunction{typeof(k)}
1276+
end
1277+
1278+
"""
1279+
Fix1::UnionAll
1280+
1281+
[`fix(1)`](@ref Base.fix).
11981282
"""
1199-
const Fix1{F,T} = Fix{1,F,T}
1283+
const Fix1 = fix(1)
12001284

12011285
"""
1202-
Alias for `Fix{2}`. See [`Fix`](@ref Base.Fix).
1286+
Fix2::UnionAll
1287+
1288+
[`fix(2)`](@ref Base.fix).
12031289
"""
1204-
const Fix2{F,T} = Fix{2,F,T}
1290+
const Fix2 = fix(2)
1291+
1292+
# Special cases for improved constant propagation
1293+
function (partial::Fix1)(x; kws...)
1294+
partial.f(partial.x, x; kws...)
1295+
end
1296+
function (partial::Fix2)(x; kws...)
1297+
partial.f(x, partial.x; kws...)
1298+
end
12051299

12061300

12071301
"""

base/public.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ public
1414
AsyncCondition,
1515
CodeUnits,
1616
Event,
17-
Fix,
17+
fix,
1818
Fix1,
1919
Fix2,
2020
Generator,

base/tuple.jl

+14-5
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,18 @@
11
# This file is a part of Julia. License is MIT: https://julialang.org/license
22

3+
module _TupleTypeByLength
4+
export Tuple1OrMore, Tuple2OrMore, Tuple32OrMore
5+
const Tuple1OrMore = Tuple{Any, Vararg}
6+
const Tuple2OrMore = Tuple{Any, Any, Vararg}
7+
const Tuple32OrMore = Tuple{
8+
Any, Any, Any, Any, Any, Any, Any, Any,
9+
Any, Any, Any, Any, Any, Any, Any, Any,
10+
Any, Any, Any, Any, Any, Any, Any, Any,
11+
Any, Any, Any, Any, Any, Any, Any, Any,
12+
Vararg{Any, N},
13+
} where {N}
14+
end
15+
316
# Document NTuple here where we have everything needed for the doc system
417
"""
518
NTuple{N, T}
@@ -358,11 +371,7 @@ map(f, t::Tuple{Any, Any}) = (@inline; (f(t[1]), f(t[2])))
358371
map(f, t::Tuple{Any, Any, Any}) = (@inline; (f(t[1]), f(t[2]), f(t[3])))
359372
map(f, t::Tuple) = (@inline; (f(t[1]), map(f,tail(t))...))
360373
# stop inlining after some number of arguments to avoid code blowup
361-
const Any32{N} = Tuple{Any,Any,Any,Any,Any,Any,Any,Any,
362-
Any,Any,Any,Any,Any,Any,Any,Any,
363-
Any,Any,Any,Any,Any,Any,Any,Any,
364-
Any,Any,Any,Any,Any,Any,Any,Any,
365-
Vararg{Any,N}}
374+
const Any32 = _TupleTypeByLength.Tuple32OrMore
366375
const All32{T,N} = Tuple{T,T,T,T,T,T,T,T,
367376
T,T,T,T,T,T,T,T,
368377
T,T,T,T,T,T,T,T,

0 commit comments

Comments
 (0)