Skip to content

Commit ac0a69e

Browse files
committed
improved function partial application design
This replaces `Fix` (xref JuliaLang#54653) with `fix`. The usage is similar: use `fix(i)(f, x)` instead of `Fix{i}(f, x)`. Benefits: * Improved type safety: creating an invalid type such as `Fix{:some_symbol}` or `Fix{-7}` is not possible. * The design should be friendlier to future extensions. E.g., suppose that publicly-facing functionality for fixing a keyword (instead of positional) argument was desired, it could be achieved by adding a new method to `fix` taking a `Symbol`, instead of adding new public names. Lots of changes are shared with PR JuliaLang#56425, if one of them gets merged the other will be greatly simplified.
1 parent afdba95 commit ac0a69e

File tree

11 files changed

+354
-72
lines changed

11 files changed

+354
-72
lines changed

NEWS.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ New library functions
8787
* `waitany(tasks; throw=false)` and `waitall(tasks; failfast=false, throw=false)` which wait multiple tasks at once ([#53341]).
8888
* `uuid7()` creates an RFC 9652 compliant UUID with version 7 ([#54834]).
8989
* `insertdims(array; dims)` allows to insert singleton dimensions into an array which is the inverse operation to `dropdims`
90-
* The new `Fix` type is a generalization of `Fix1/Fix2` for fixing a single argument ([#54653]).
90+
* `Fix1`/`Fix2` are now generalized by `fix` ([#54653], [#56518]).
9191

9292
New library features
9393
--------------------

base/Base_compiler.jl

+1
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,7 @@ include("error.jl")
213213
include("bool.jl")
214214
include("number.jl")
215215
include("int.jl")
216+
include("typedomainnumbers.jl")
216217
include("operators.jl")
217218
include("pointer.jl")
218219
include("refvalue.jl")

base/operators.jl

+104-32
Original file line numberDiff line numberDiff line change
@@ -1153,55 +1153,127 @@ julia> filter(!isletter, str)
11531153
!(f::Function) = (!) f
11541154
!(f::ComposedFunction{typeof(!)}) = f.inner #allows !!f === f
11551155

1156+
struct PartiallyAppliedFunction{Position <: _TypeDomainNumbers.PositiveIntegers.PositiveInteger, Func, Arg} <: Function
1157+
partially_applied_argument_position::Position
1158+
f::Func
1159+
x::Arg
1160+
1161+
global function new_partially_applied_function(pos::_TypeDomainNumbers.PositiveIntegers.PositiveInteger, f::Func, x) where {Func}
1162+
new{typeof(pos), _stable_typeof(f), _stable_typeof(x)}(pos, f, x)
1163+
end
1164+
end
1165+
1166+
function (::Type{PartiallyAppliedFunction{Position}})(func::Func, arg) where {Position <: _TypeDomainNumbers.PositiveIntegers.PositiveInteger, Func}
1167+
pos = (Position::DataType).instance
1168+
new_partially_applied_function(pos, func, arg)
1169+
end
1170+
1171+
function getproperty((@nospecialize v::PartiallyAppliedFunction), s::Symbol)
1172+
getfield(v, s)
1173+
end # avoid overspecialization
1174+
1175+
function Base.show((@nospecialize io::Base.IO), @nospecialize unused::Type{PartiallyAppliedFunction{Position}}) where {Position <: _TypeDomainNumbers.PositiveIntegers.PositiveInteger}
1176+
if Position isa DataType
1177+
print(io, "fix(")
1178+
show(io, Position.instance)
1179+
print(io, ')')
1180+
else
1181+
show(io, PartiallyAppliedFunction)
1182+
print(io, '{')
1183+
show(io, Position)
1184+
print(io, '}')
1185+
end
1186+
end
1187+
1188+
function Base.show((@nospecialize io::Base.IO), @nospecialize p::PartiallyAppliedFunction)
1189+
print(io, "fix(")
1190+
show(io, p.partially_applied_argument_position)
1191+
print(io, ")(")
1192+
show(io, p.f)
1193+
print(io, ", ")
1194+
show(io, p.x)
1195+
print(io, ')')
1196+
end
1197+
1198+
function _partially_applied_function_check(m::Int, nm1::Int)
1199+
if m < nm1
1200+
throw(ArgumentError(LazyString("expected at least ", nm1, " arguments to `fix(", nm1 + 1, ")`, but got ", m)))
1201+
end
1202+
end
1203+
1204+
function (partial::PartiallyAppliedFunction)(args::Vararg{Any,M}; kws...) where {M}
1205+
n = partial.partially_applied_argument_position
1206+
nm1 = _TypeDomainNumbers.PositiveIntegers.natural_predecessor(n)
1207+
_partially_applied_function_check(M, Int(nm1))
1208+
(args_left, args_right) = _TypeDomainNumberTupleUtils.split_tuple(args, nm1)
1209+
partial.f(args_left..., partial.x, args_right...; kws...)
1210+
end
1211+
11561212
"""
1157-
Fix{N}(f, x)
1213+
fix(::Integer)::UnionAll
1214+
1215+
Return a [`UnionAll`](@ref) type such that:
1216+
* It's a constructor taking two arguments:
1217+
1. A function to be partially applied
1218+
2. An argument of the above function to be fixed
1219+
* Its instances are partial applications of the function, with one positional argument fixed. The argument to `fix` is the one-based index of the position argument to be fixed.
1220+
1221+
For example, `fix(3)(f, x)` behaves similarly to `(y1, y2, y3...; kws...) -> f(y1, y2, x, y3...; kws...)`.
11581222
1159-
A type representing a partially-applied version of a function `f`, with the argument
1160-
`x` fixed at position `N::Int`. In other words, `Fix{3}(f, x)` behaves similarly to
1161-
`(y1, y2, y3...; kws...) -> f(y1, y2, x, y3...; kws...)`.
1223+
See also: [`Fix1`](@ref), [`Fix2`](@ref).
11621224
11631225
!!! compat "Julia 1.12"
1164-
This general functionality requires at least Julia 1.12, while `Fix1` and `Fix2`
1165-
are available earlier.
1226+
Requires at least Julia 1.12 (`Fix1` and `Fix2` are available earlier, too).
11661227
11671228
!!! note
1168-
When nesting multiple `Fix`, note that the `N` in `Fix{N}` is _relative_ to the current
1229+
When nesting multiple `fix`, note that the `n` in `fix(n)` is _relative_ to the current
11691230
available arguments, rather than an absolute ordering on the target function. For example,
1170-
`Fix{1}(Fix{2}(f, 4), 4)` fixes the first and second arg, while `Fix{2}(Fix{1}(f, 4), 4)`
1231+
`fix(1)(fix(2)(f, 4), 4)` fixes the first and second arg, while `fix(2)(fix(1)(f, 4), 4)`
11711232
fixes the first and third arg.
1172-
"""
1173-
struct Fix{N,F,T} <: Function
1174-
f::F
1175-
x::T
11761233
1177-
function Fix{N}(f::F, x) where {N,F}
1178-
if !(N isa Int)
1179-
throw(ArgumentError(LazyString("expected type parameter in `Fix` to be `Int`, but got `", N, "::", typeof(N), "`")))
1180-
elseif N < 1
1181-
throw(ArgumentError(LazyString("expected `N` in `Fix{N}` to be integer greater than 0, but got ", N)))
1182-
end
1183-
new{N,_stable_typeof(f),_stable_typeof(x)}(f, x)
1184-
end
1185-
end
1234+
### Examples
11861235
1187-
function (f::Fix{N})(args::Vararg{Any,M}; kws...) where {N,M}
1188-
M < N-1 && throw(ArgumentError(LazyString("expected at least ", N-1, " arguments to `Fix{", N, "}`, but got ", M)))
1189-
return f.f(args[begin:begin+(N-2)]..., f.x, args[begin+(N-1):end]...; kws...)
1190-
end
1236+
```jldoctest
1237+
julia> Base.fix(2)(-, 3)(7)
1238+
4
11911239
1192-
# Special cases for improved constant propagation
1193-
(f::Fix{1})(arg; kws...) = f.f(f.x, arg; kws...)
1194-
(f::Fix{2})(arg; kws...) = f.f(arg, f.x; kws...)
1240+
julia> Base.fix(2) === Base.Fix2
1241+
true
1242+
1243+
julia> Base.fix(1)(Base.fix(2)(muladd, 3), 2)(5) === (x -> muladd(2, 3, x))(5)
1244+
true
1245+
```
1246+
"""
1247+
function fix(@nospecialize m::Integer)
1248+
n = Int(m)::Int
1249+
if n 0
1250+
throw(ArgumentError("the index of the partially applied argument must be positive"))
1251+
end
1252+
k = _TypeDomainNumbers.Utils.from_abs_int(n)
1253+
PartiallyAppliedFunction{typeof(k)}
1254+
end
11951255

11961256
"""
1197-
Alias for `Fix{1}`. See [`Fix`](@ref Base.Fix).
1257+
Fix1::UnionAll
1258+
1259+
[`fix(1)`](@ref Base.fix).
11981260
"""
1199-
const Fix1{F,T} = Fix{1,F,T}
1261+
const Fix1 = fix(1)
12001262

12011263
"""
1202-
Alias for `Fix{2}`. See [`Fix`](@ref Base.Fix).
1264+
Fix2::UnionAll
1265+
1266+
[`fix(2)`](@ref Base.fix).
12031267
"""
1204-
const Fix2{F,T} = Fix{2,F,T}
1268+
const Fix2 = fix(2)
1269+
1270+
# Special cases for improved constant propagation
1271+
function (partial::Fix1)(x; kws...)
1272+
partial.f(partial.x, x; kws...)
1273+
end
1274+
function (partial::Fix2)(x; kws...)
1275+
partial.f(x, partial.x; kws...)
1276+
end
12051277

12061278

12071279
"""

base/public.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ public
1414
AsyncCondition,
1515
CodeUnits,
1616
Event,
17-
Fix,
17+
fix,
1818
Fix1,
1919
Fix2,
2020
Generator,

base/tuple.jl

+14-5
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,18 @@
11
# This file is a part of Julia. License is MIT: https://julialang.org/license
22

3+
module _TupleTypeByLength
4+
export Tuple1OrMore, Tuple2OrMore, Tuple32OrMore
5+
const Tuple1OrMore = Tuple{Any, Vararg}
6+
const Tuple2OrMore = Tuple{Any, Any, Vararg}
7+
const Tuple32OrMore = Tuple{
8+
Any, Any, Any, Any, Any, Any, Any, Any,
9+
Any, Any, Any, Any, Any, Any, Any, Any,
10+
Any, Any, Any, Any, Any, Any, Any, Any,
11+
Any, Any, Any, Any, Any, Any, Any, Any,
12+
Vararg{Any, N},
13+
} where {N}
14+
end
15+
316
# Document NTuple here where we have everything needed for the doc system
417
"""
518
NTuple{N, T}
@@ -358,11 +371,7 @@ map(f, t::Tuple{Any, Any}) = (@inline; (f(t[1]), f(t[2])))
358371
map(f, t::Tuple{Any, Any, Any}) = (@inline; (f(t[1]), f(t[2]), f(t[3])))
359372
map(f, t::Tuple) = (@inline; (f(t[1]), map(f,tail(t))...))
360373
# stop inlining after some number of arguments to avoid code blowup
361-
const Any32{N} = Tuple{Any,Any,Any,Any,Any,Any,Any,Any,
362-
Any,Any,Any,Any,Any,Any,Any,Any,
363-
Any,Any,Any,Any,Any,Any,Any,Any,
364-
Any,Any,Any,Any,Any,Any,Any,Any,
365-
Vararg{Any,N}}
374+
const Any32 = _TupleTypeByLength.Tuple32OrMore
366375
const All32{T,N} = Tuple{T,T,T,T,T,T,T,T,
367376
T,T,T,T,T,T,T,T,
368377
T,T,T,T,T,T,T,T,

base/typedomainnumbers.jl

+167
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
# This file is a part of Julia. License is MIT: https://julialang.org/license
2+
3+
# Adapted from the TypeDomainNaturalNumbers.jl package.
4+
module _TypeDomainNumbers
5+
module Zeros
6+
export Zero
7+
struct Zero end
8+
end
9+
10+
module PositiveIntegers
11+
module RecursiveStep
12+
using ...Zeros
13+
export recursive_step
14+
function recursive_step(@nospecialize t::Type)
15+
Union{Zero, t}
16+
end
17+
end
18+
module UpperBounds
19+
using ..RecursiveStep
20+
abstract type A end
21+
abstract type B{P <: recursive_step(A)} <: A end
22+
abstract type C{P <: recursive_step(B)} <: B{P} end
23+
abstract type D{P <: recursive_step(C)} <: C{P} end
24+
end
25+
using .RecursiveStep
26+
const PositiveIntegerUpperBound = UpperBounds.A
27+
const PositiveIntegerUpperBoundTighter = UpperBounds.D
28+
export
29+
natural_successor, natural_predecessor,
30+
NonnegativeInteger, NonnegativeIntegerUpperBound,
31+
PositiveInteger, PositiveIntegerUpperBound
32+
struct PositiveInteger{
33+
Predecessor <: recursive_step(PositiveIntegerUpperBoundTighter),
34+
} <: PositiveIntegerUpperBoundTighter{Predecessor}
35+
predecessor::Predecessor
36+
global const NonnegativeInteger = recursive_step(PositiveInteger)
37+
global const NonnegativeIntegerUpperBound = recursive_step(PositiveIntegerUpperBound)
38+
global function natural_successor(p::P) where {P <: NonnegativeInteger}
39+
new{P}(p)
40+
end
41+
end
42+
function natural_predecessor(@nospecialize o::PositiveInteger)
43+
getfield(o, :predecessor) # avoid specializing `getproperty` for each number
44+
end
45+
end
46+
47+
module IntegersGreaterThanOne
48+
using ..PositiveIntegers
49+
export
50+
IntegerGreaterThanOne, IntegerGreaterThanOneUpperBound,
51+
natural_predecessor_predecessor
52+
const IntegerGreaterThanOne = let t = PositiveInteger
53+
t{P} where {P <: t}
54+
end
55+
const IntegerGreaterThanOneUpperBound = let t = PositiveIntegerUpperBound
56+
PositiveIntegers.UpperBounds.B{P} where {P <: t}
57+
end
58+
function natural_predecessor_predecessor(@nospecialize x::IntegerGreaterThanOne)
59+
natural_predecessor(natural_predecessor(x))
60+
end
61+
end
62+
63+
module Constants
64+
using ..Zeros, ..PositiveIntegers
65+
export n0, n1
66+
const n0 = Zero()
67+
const n1 = natural_successor(n0)
68+
end
69+
70+
module Utils
71+
using ..PositiveIntegers, ..IntegersGreaterThanOne, ..Constants
72+
using Base: @_foldable_meta
73+
function subtracted_nonnegative((@nospecialize l::NonnegativeInteger), @nospecialize r::NonnegativeInteger)
74+
@_foldable_meta
75+
if r isa PositiveIntegerUpperBound
76+
let a = natural_predecessor(l), b = natural_predecessor(r)
77+
subtracted_nonnegative(a, b)
78+
end
79+
else
80+
l
81+
end
82+
end
83+
function abs_decrement(n::Int)
84+
@_foldable_meta
85+
if signbit(n)
86+
n + true
87+
else
88+
n - true
89+
end
90+
end
91+
function to_int(@nospecialize o::NonnegativeInteger)
92+
@_foldable_meta
93+
if o isa PositiveIntegerUpperBound
94+
let p = natural_predecessor(o), t = to_int(p)
95+
t + true
96+
end
97+
else
98+
0
99+
end
100+
end
101+
function from_abs_int(n::Int)
102+
@_foldable_meta
103+
ret = n0
104+
while !iszero(n)
105+
n = abs_decrement(n)
106+
ret = natural_successor(ret)
107+
end
108+
ret
109+
end
110+
end
111+
112+
module Overloads
113+
using ..PositiveIntegers, ..Utils
114+
function (::Type{Int})(@nospecialize o::NonnegativeInteger)
115+
Utils.to_int(o)
116+
end
117+
function Base.show((@nospecialize io::Base.IO), @nospecialize n::NonnegativeInteger)
118+
i = Int(n)
119+
Base.show(io, i)
120+
end
121+
end
122+
end
123+
124+
module _TypeDomainNumberTupleUtils
125+
using
126+
.._TypeDomainNumbers.PositiveIntegers, .._TypeDomainNumbers.IntegersGreaterThanOne,
127+
.._TypeDomainNumbers.Constants, .._TypeDomainNumbers.Utils, .._TupleTypeByLength
128+
using Base: @_total_meta, @_foldable_meta, front, tail
129+
export tuple_type_domain_length, split_tuple, skip_from_front, skip_from_tail
130+
function tuple_type_domain_length(@nospecialize tup::Tuple)
131+
@_total_meta
132+
if tup isa Tuple1OrMore
133+
let t = tail(tup), rec = tuple_type_domain_length(t)
134+
natural_successor(rec)
135+
end
136+
else
137+
n0
138+
end
139+
end
140+
function skip_from_front((@nospecialize tup::Tuple), @nospecialize skip_count::NonnegativeInteger)
141+
@_foldable_meta
142+
if skip_count isa PositiveIntegerUpperBound
143+
let cm1 = natural_predecessor(skip_count), t = tail(tup)
144+
@inline skip_from_front(t, cm1)
145+
end
146+
else
147+
tup
148+
end
149+
end
150+
function skip_from_tail((@nospecialize tup::Tuple), @nospecialize skip_count::NonnegativeInteger)
151+
@_foldable_meta
152+
if skip_count isa PositiveIntegerUpperBound
153+
let cm1 = natural_predecessor(skip_count), t = front(tup)
154+
@inline skip_from_tail(t, cm1)
155+
end
156+
else
157+
tup
158+
end
159+
end
160+
function split_tuple((@nospecialize tup::Tuple), @nospecialize len_l::NonnegativeInteger)
161+
len = tuple_type_domain_length(tup)
162+
len_r = Utils.subtracted_nonnegative(len, len_l)
163+
tup_l = skip_from_tail(tup, len_r)
164+
tup_r = skip_from_front(tup, len_l)
165+
(tup_l, tup_r)
166+
end
167+
end

doc/src/base/base.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@ Base.:(|>)
285285
Base.:(∘)
286286
Base.ComposedFunction
287287
Base.splat
288-
Base.Fix
288+
Base.fix
289289
Base.Fix1
290290
Base.Fix2
291291
```

0 commit comments

Comments
 (0)