Skip to content

Commit badad9d

Browse files
authored
Implements flatmap (#44792)
flatmap is the composition of map and flatten. It is important for functional programming patterns. Some tasks that can be easily attained with list-comprehensions, including the composition of filter and mapping, or flattening a list of computed lists, can only be attained with do-syntax style if a flatmap functor is available. (Or appending a `|> flatten`, etc.) Filtering can be implemented by outputing empty lists or singleton lists for the values to be removed or kept.
1 parent ad047d0 commit badad9d

File tree

2 files changed

+47
-1
lines changed

2 files changed

+47
-1
lines changed

base/iterators.jl

+24-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ import .Base:
2222
getindex, setindex!, get, iterate,
2323
popfirst!, isdone, peek
2424

25-
export enumerate, zip, rest, countfrom, take, drop, takewhile, dropwhile, cycle, repeated, product, flatten, partition
25+
export enumerate, zip, rest, countfrom, take, drop, takewhile, dropwhile, cycle, repeated, product, flatten, partition, flatmap
2626

2727
"""
2828
Iterators.map(f, iterators...)
@@ -1162,6 +1162,29 @@ end
11621162
reverse(f::Flatten) = Flatten(reverse(itr) for itr in reverse(f.it))
11631163
last(f::Flatten) = last(last(f.it))
11641164

1165+
"""
1166+
Iterators.flatmap(f, iterators...)
1167+
1168+
Equivalent to `flatten(map(f, iterators...))`.
1169+
1170+
# Examples
1171+
```jldoctest
1172+
julia> Iterators.flatmap(n->-n:2:n, 1:3) |> collect
1173+
9-element Vector{Int64}:
1174+
-1
1175+
1
1176+
-2
1177+
0
1178+
2
1179+
-3
1180+
-1
1181+
1
1182+
3
1183+
```
1184+
"""
1185+
# flatmap = flatten ∘ map
1186+
flatmap(f, c...) = flatten(map(f, c...))
1187+
11651188
"""
11661189
partition(collection, n)
11671190

test/iterators.jl

+23
Original file line numberDiff line numberDiff line change
@@ -476,6 +476,29 @@ end
476476
# see #29112, #29464, #29548
477477
@test Base.return_types(Base.IteratorEltype, Tuple{Array}) == [Base.HasEltype]
478478

479+
# flatmap
480+
# -------
481+
@test flatmap(1:3) do j flatmap(1:3) do k
482+
j!=k ? ((j,k),) : ()
483+
end end |> collect == [(j,k) for j in 1:3 for k in 1:3 if j!=k]
484+
# Test inspired by the monad associativity law
485+
fmf(x) = x<0 ? () : (x^2,)
486+
fmg(x) = x<1 ? () : (x/2,)
487+
fmdata = -2:0.75:2
488+
fmv1 = flatmap(tuple.(fmdata)) do h
489+
flatmap(h) do x
490+
gx = fmg(x)
491+
flatmap(gx) do x
492+
fmf(x)
493+
end
494+
end
495+
end
496+
fmv2 = flatmap(tuple.(fmdata)) do h
497+
gh = flatmap(h) do x fmg(x) end
498+
flatmap(gh) do x fmf(x) end
499+
end
500+
@test all(fmv1 .== fmv2)
501+
479502
# partition(c, n)
480503
let v = collect(partition([1,2,3,4,5], 1))
481504
@test all(i->v[i][1] == i, v)

0 commit comments

Comments
 (0)