Skip to content

Commit 3f12fd9

Browse files
committed
Implement take, truncate, skip and drop
1 parent aabde2f commit 3f12fd9

File tree

3 files changed

+193
-21
lines changed

3 files changed

+193
-21
lines changed

src/FSharp.Control.TaskSeq/TaskSeq.fs

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,22 +6,15 @@ open System.Threading.Tasks
66

77
#nowarn "57"
88

9+
// Just for convenience
10+
module Internal = TaskSeqInternal
11+
912
[<AutoOpen>]
1013
module TaskSeqExtensions =
14+
// these need to be in a module, not a type for proper auto-initialization of generic values
1115
module TaskSeq =
16+
let empty<'T> = Internal.empty<'T>
1217

13-
let empty<'T> =
14-
{ new IAsyncEnumerable<'T> with
15-
member _.GetAsyncEnumerator(_) =
16-
{ new IAsyncEnumerator<'T> with
17-
member _.MoveNextAsync() = ValueTask.False
18-
member _.Current = Unchecked.defaultof<'T>
19-
member _.DisposeAsync() = ValueTask.CompletedTask
20-
}
21-
}
22-
23-
// Just for convenience
24-
module Internal = TaskSeqInternal
2518

2619
[<Sealed; AbstractClass>]
2720
type TaskSeq private () =
@@ -289,18 +282,27 @@ type TaskSeq private () =
289282

290283
static member choose chooser source = Internal.choose (TryPick chooser) source
291284
static member chooseAsync chooser source = Internal.choose (TryPickAsync chooser) source
285+
292286
static member filter predicate source = Internal.filter (Predicate predicate) source
293287
static member filterAsync predicate source = Internal.filter (PredicateAsync predicate) source
288+
289+
static member skip count source = Internal.skipOrTake Skip count source
290+
static member drop count source = Internal.skipOrTake Drop count source
291+
static member take count source = Internal.skipOrTake Take count source
292+
static member truncate count source = Internal.skipOrTake Truncate count source
293+
294294
static member takeWhile predicate source = Internal.takeWhile Exclusive (Predicate predicate) source
295295
static member takeWhileAsync predicate source = Internal.takeWhile Exclusive (PredicateAsync predicate) source
296296
static member takeWhileInclusive predicate source = Internal.takeWhile Inclusive (Predicate predicate) source
297297
static member takeWhileInclusiveAsync predicate source = Internal.takeWhile Inclusive (PredicateAsync predicate) source
298+
298299
static member tryPick chooser source = Internal.tryPick (TryPick chooser) source
299300
static member tryPickAsync chooser source = Internal.tryPick (TryPickAsync chooser) source
300301
static member tryFind predicate source = Internal.tryFind (Predicate predicate) source
301302
static member tryFindAsync predicate source = Internal.tryFind (PredicateAsync predicate) source
302303
static member tryFindIndex predicate source = Internal.tryFindIndex (Predicate predicate) source
303304
static member tryFindIndexAsync predicate source = Internal.tryFindIndex (PredicateAsync predicate) source
305+
304306
static member except itemsToExclude source = Internal.except itemsToExclude source
305307
static member exceptOfSeq itemsToExclude source = Internal.exceptOfSeq itemsToExclude source
306308

src/FSharp.Control.TaskSeq/TaskSeq.fsi

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -725,6 +725,65 @@ type TaskSeq =
725725
/// <exception cref="T:ArgumentNullException">Thrown when the input task sequence is null.</exception>
726726
static member filterAsync: predicate: ('T -> #Task<bool>) -> source: TaskSeq<'T> -> TaskSeq<'T>
727727

728+
/// <summary>
729+
/// Returns a task sequence that, when iterated, skips <paramref name="count" /> elements of the
730+
/// underlying sequence, and then returns the remainder of the elements. Raises an exception if there are not enough
731+
/// elements in the sequence. See <see cref="drop" /> for a version that does not raise an exception.
732+
/// See also <see cref="take" /> for the inverse of this operation.
733+
/// </summary>
734+
///
735+
/// <param name="count">The number of items to skip.</param>
736+
/// <param name="source">The input task sequence.</param>
737+
/// <returns>The resulting task sequence.</returns>
738+
/// <exception cref="T:ArgumentNullException">Thrown when the input task sequence is null.</exception>
739+
/// <exception cref="T:ArgumentException">Thrown when <paramref name="count" /> is less than zero.</exception>
740+
/// <exception cref="T:InvalidOperationException">Thrown when count exceeds the number of elements in the sequence.</exception>
741+
static member skip: count: int -> source: TaskSeq<'T> -> TaskSeq<'T>
742+
743+
744+
/// <summary>
745+
/// Returns a task sequence that, when iterated, drops at most <paramref name="count" /> elements of the
746+
/// underlying sequence, and then returns the remainder of the elements, if any.
747+
/// See <see cref="skip" /> for a version that raises an exception if there
748+
/// are not enough elements. See also <see cref="truncate" /> for the inverse of this operation.
749+
/// </summary>
750+
///
751+
/// <param name="count">The number of items to drop.</param>
752+
/// <param name="source">The input task sequence.</param>
753+
/// <returns>The resulting task sequence.</returns>
754+
/// <exception cref="T:ArgumentNullException">Thrown when the input task sequence is null.</exception>
755+
/// <exception cref="T:ArgumentException">Thrown when <paramref name="count" /> is less than zero.</exception>
756+
static member drop: count: int -> source: TaskSeq<'T> -> TaskSeq<'T>
757+
758+
/// <summary>
759+
/// Returns a task sequence that, when iterated, yields <paramref name="count" /> elements of the
760+
/// underlying sequence, and then returns no further elements. Raises an exception if there are not enough
761+
/// elements in the sequence. See <see cref="truncate" /> for a version that does not raise an exception.
762+
/// See also <see cref="skip" /> for the inverse of this operation.
763+
/// </summary>
764+
///
765+
/// <param name="count">The number of items to take.</param>
766+
/// <param name="source">The input task sequence.</param>
767+
/// <returns>The resulting task sequence.</returns>
768+
/// <exception cref="T:ArgumentNullException">Thrown when the input task sequence is null.</exception>
769+
/// <exception cref="T:ArgumentException">Thrown when <paramref name="count" /> is less than zero.</exception>
770+
/// <exception cref="T:InvalidOperationException">Thrown when count exceeds the number of elements in the sequence.</exception>
771+
static member take: count: int -> source: TaskSeq<'T> -> TaskSeq<'T>
772+
773+
/// <summary>
774+
/// Returns a task sequence that, when iterated, yields at most <paramref name="count" /> elements of the underlying
775+
/// sequence, truncating the remainder, if any.
776+
/// See <see cref="take" /> for a version that raises an exception if there are not enough elements in the
777+
/// sequence. See also <see cref="drop" /> for the inverse of this operation.
778+
/// </summary>
779+
///
780+
/// <param name="count">The maximum number of items to enumerate.</param>
781+
/// <param name="source">The input task sequence.</param>
782+
/// <returns>The resulting task sequence.</returns>
783+
/// <exception cref="T:ArgumentNullException">Thrown when the input task sequence is null.</exception>
784+
/// <exception cref="T:ArgumentException">Thrown when <paramref name="count" /> is less than zero.</exception>
785+
static member truncate: count: int -> source: TaskSeq<'T> -> TaskSeq<'T>
786+
728787
/// <summary>
729788
/// Returns a task sequence that, when iterated, yields elements of the underlying sequence while the
730789
/// given function <paramref name="predicate" /> returns <see cref="true" />, and then returns no further elements.

src/FSharp.Control.TaskSeq/TaskSeqInternal.fs

Lines changed: 120 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,17 @@ type internal WhileKind =
1818
/// The item under test is always excluded
1919
| Exclusive
2020

21+
[<Struct>]
22+
type internal TakeOrSkipKind =
23+
/// use the Seq.take semantics, raises exception if not enough elements
24+
| Take
25+
/// use the Seq.skip semantics, raises exception if not enough elements
26+
| Skip
27+
/// use the Seq.truncate semantics, safe operation, returns all if count exceeds the seq
28+
| Truncate
29+
/// no Seq equiv, but like Stream.drop in Scala: safe operation, return empty if not enough elements
30+
| Drop
31+
2132
[<Struct>]
2233
type internal Action<'T, 'U, 'TaskU when 'TaskU :> Task<'U>> =
2334
| CountableAction of countable_action: (int -> 'T -> 'U)
@@ -51,20 +62,15 @@ module internal TaskSeqInternal =
5162
if isNull arg then
5263
nullArg argName
5364

54-
let inline raiseEmptySeq () =
55-
ArgumentException("The asynchronous input sequence was empty.", "source")
56-
|> raise
65+
let inline raiseEmptySeq () = invalidArg "source" "The input task sequence was empty."
5766

58-
let inline raiseCannotBeNegative (name: string) =
59-
ArgumentException("The value cannot be negative", name)
60-
|> raise
67+
let inline raiseCannotBeNegative name = invalidArg name "The value must be non-negative"
6168

6269
let inline raiseInsufficient () =
63-
ArgumentException("The asynchronous input sequence was has an insufficient number of elements.", "source")
64-
|> raise
70+
invalidArg "source" "The input task sequence was has an insufficient number of elements."
6571

6672
let inline raiseNotFound () =
67-
KeyNotFoundException("The predicate function or index did not satisfy any item in the async sequence.")
73+
KeyNotFoundException("The predicate function or index did not satisfy any item in the task sequence.")
6874
|> raise
6975

7076
let isEmpty (source: TaskSeq<_>) =
@@ -76,6 +82,16 @@ module internal TaskSeqInternal =
7682
return not step
7783
}
7884

85+
let empty<'T> =
86+
{ new IAsyncEnumerable<'T> with
87+
member _.GetAsyncEnumerator(_) =
88+
{ new IAsyncEnumerator<'T> with
89+
member _.MoveNextAsync() = ValueTask.False
90+
member _.Current = Unchecked.defaultof<'T>
91+
member _.DisposeAsync() = ValueTask.CompletedTask
92+
}
93+
}
94+
7995
let singleton (value: 'T) =
8096
{ new IAsyncEnumerable<'T> with
8197
member _.GetAsyncEnumerator(_) =
@@ -613,6 +629,101 @@ module internal TaskSeqInternal =
613629
| false -> ()
614630
}
615631

632+
633+
let skipOrTake skipOrTake count (source: TaskSeq<_>) =
634+
checkNonNull (nameof source) source
635+
636+
if count < 0 then
637+
raiseCannotBeNegative (nameof count)
638+
639+
match skipOrTake with
640+
| Skip ->
641+
// don't create a new sequence if count = 0
642+
if count = 0 then
643+
source
644+
else
645+
taskSeq {
646+
use e = source.GetAsyncEnumerator CancellationToken.None
647+
648+
for _ in 1..count do
649+
let! step = e.MoveNextAsync()
650+
651+
if not step then
652+
raiseInsufficient ()
653+
654+
let mutable cont = true
655+
656+
while cont do
657+
yield e.Current
658+
let! moveNext = e.MoveNextAsync()
659+
cont <- moveNext
660+
661+
}
662+
| Drop ->
663+
// don't create a new sequence if count = 0
664+
if count = 0 then
665+
source
666+
else
667+
taskSeq {
668+
use e = source.GetAsyncEnumerator CancellationToken.None
669+
670+
let! step = e.MoveNextAsync()
671+
let mutable cont = step
672+
let mutable pos = 0
673+
674+
// skip, or stop looping if we reached the end
675+
while cont do
676+
pos <- pos + 1
677+
let! moveNext = e.MoveNextAsync()
678+
cont <- moveNext && pos <= count
679+
680+
// return the rest
681+
while cont do
682+
yield e.Current
683+
let! moveNext = e.MoveNextAsync()
684+
cont <- moveNext
685+
686+
}
687+
| Take ->
688+
// don't initialize an empty task sequence
689+
if count = 0 then
690+
empty
691+
else
692+
taskSeq {
693+
use e = source.GetAsyncEnumerator CancellationToken.None
694+
695+
for _ in count .. - 1 .. 1 do
696+
let! step = e.MoveNextAsync()
697+
698+
if not step then
699+
raiseInsufficient ()
700+
701+
yield e.Current
702+
}
703+
704+
| Truncate ->
705+
// don't create a new sequence if count = 0
706+
if count = 0 then
707+
empty
708+
else
709+
taskSeq {
710+
use e = source.GetAsyncEnumerator CancellationToken.None
711+
712+
let! step = e.MoveNextAsync()
713+
let mutable cont = step
714+
let mutable pos = 0
715+
716+
// return items until we've exhausted the seq
717+
// report this line, weird error:
718+
//while! e.MoveNextAsync() && pos < 1 do
719+
while cont do
720+
yield e.Current
721+
pos <- pos + 1
722+
let! moveNext = e.MoveNextAsync()
723+
cont <- moveNext && pos <= count
724+
725+
}
726+
616727
let takeWhile whileKind predicate (source: TaskSeq<_>) =
617728
checkNonNull (nameof source) source
618729

0 commit comments

Comments
 (0)