Skip to content

Commit 462ee30

Browse files
committed
Cleanup, by applying more in-built python functionality for iterators
1 parent 8ae76d2 commit 462ee30

File tree

4 files changed

+24
-54
lines changed

4 files changed

+24
-54
lines changed

src/pipedata/core/chain.py

Lines changed: 15 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from dataclasses import dataclass
3+
from itertools import islice
44
from typing import (
55
Any,
66
Callable,
@@ -16,16 +16,13 @@
1616
overload,
1717
)
1818

19-
from pipedata.core.itertools import take_next, take_up_to_n
20-
2119
TStart = TypeVar("TStart")
2220
TEnd = TypeVar("TEnd")
2321
TOther = TypeVar("TOther")
2422

2523

2624
def _identity(input_iterator: Iterator[TEnd]) -> Iterator[TEnd]:
27-
while (element := take_next(input_iterator)) is not None:
28-
yield element
25+
yield from input_iterator
2926

3027

3128
class CountingIterator(Iterator[TStart]):
@@ -48,38 +45,31 @@ def get_count(self) -> int:
4845
return self._count
4946

5047

51-
class CountedFunc(Generic[TStart, TEnd]):
48+
class ChainLink(Generic[TStart, TEnd]):
5249
def __init__(
5350
self,
5451
func: Callable[[Iterator[TStart]], Iterator[TEnd]],
5552
) -> None:
5653
self._func = func
57-
self._counting_input: Optional[CountingIterator[TStart]] = None
58-
self._counting_output: Optional[CountingIterator[TEnd]] = None
54+
self._input: Optional[CountingIterator[TStart]] = None
55+
self._output: Optional[CountingIterator[TEnd]] = None
5956

6057
@property
6158
def __name__(self) -> str: # noqa: A003
6259
return self._func.__name__
6360

6461
def __call__(self, input_iterator: Iterator[TStart]) -> Iterator[TEnd]:
65-
self._counting_input = CountingIterator(input_iterator)
66-
self._counting_output = CountingIterator(self._func(self._counting_input))
67-
return self._counting_output
62+
self._input = CountingIterator(input_iterator)
63+
self._output = CountingIterator(self._func(self._input))
64+
return self._output
6865

6966
def get_counts(self) -> Tuple[int, int]:
7067
return (
71-
0 if self._counting_input is None else self._counting_input.get_count(),
72-
0 if self._counting_output is None else self._counting_output.get_count(),
68+
0 if self._input is None else self._input.get_count(),
69+
0 if self._output is None else self._output.get_count(),
7370
)
7471

7572

76-
@dataclass
77-
class StepCount:
78-
name: str
79-
inputs: int
80-
outputs: int
81-
82-
8373
class Chain(Generic[TStart, TEnd]):
8474
@overload
8575
def __init__(
@@ -106,11 +96,11 @@ def __init__(
10696
],
10797
) -> None:
10898
self._previous_steps = previous_steps
109-
self._func = CountedFunc(func)
99+
self._func = ChainLink(func)
110100

111101
def __call__(self, input_iterator: Iterator[TStart]) -> Iterator[TEnd]:
112102
if self._previous_steps is None:
113-
func = cast(CountedFunc[TStart, TEnd], self._func)
103+
func = cast(ChainLink[TStart, TEnd], self._func)
114104
return func(input_iterator)
115105

116106
return self._func(self._previous_steps(input_iterator)) # type: ignore
@@ -133,9 +123,7 @@ def filter(self, func: Callable[[TEnd], bool]) -> Chain[TStart, TEnd]: # noqa:
133123
"""
134124

135125
def new_action(previous_step: Iterator[TEnd]) -> Iterator[TEnd]:
136-
while (element := take_next(previous_step)) is not None:
137-
if func(element) is True:
138-
yield element
126+
return filter(func, previous_step)
139127

140128
new_action.__name__ = func.__name__
141129
return self.flat_map(new_action)
@@ -148,8 +136,7 @@ def map( # noqa: A003
148136
"""
149137

150138
def new_action(previous_step: Iterator[TEnd]) -> Iterator[TOther]:
151-
while (element := take_next(previous_step)) is not None:
152-
yield func(element)
139+
return map(func, previous_step)
153140

154141
new_action.__name__ = func.__name__
155142
return self.flat_map(new_action)
@@ -165,7 +152,7 @@ def map_tuple(
165152
"""
166153

167154
def new_action(previous_step: Iterator[TEnd]) -> Iterator[TOther]:
168-
while elements := take_up_to_n(previous_step, n):
155+
while elements := tuple(islice(previous_step, n)):
169156
yield func(elements)
170157

171158
new_action.__name__ = func.__name__

src/pipedata/core/itertools.py

Lines changed: 0 additions & 15 deletions
This file was deleted.

tests/core/test_chain.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from typing import Iterator, Tuple
22

33
from pipedata.core import Chain, ChainStart
4-
from pipedata.core.itertools import take_next
54

65

76
def test_chain() -> None:
@@ -86,7 +85,7 @@ def is_even(value: int) -> bool:
8685

8786
def test_chain_flat_map() -> None:
8887
def add_one(input_iterator: Iterator[int]) -> Iterator[int]:
89-
while (element := take_next(input_iterator)) is not None:
88+
for element in input_iterator:
9089
yield element + 1
9190

9291
chain = ChainStart[int]().flat_map(add_one)
@@ -100,11 +99,11 @@ def add_one(input_iterator: Iterator[int]) -> Iterator[int]:
10099

101100
def test_chain_multiple_operations() -> None:
102101
def add_one(input_iterator: Iterator[int]) -> Iterator[int]:
103-
while (element := take_next(input_iterator)) is not None:
102+
for element in input_iterator:
104103
yield element + 1
105104

106105
def multiply_two(input_iterator: Iterator[int]) -> Iterator[int]:
107-
while (element := take_next(input_iterator)) is not None:
106+
for element in input_iterator:
108107
yield element * 2
109108

110109
def is_even(value: int) -> bool:

tests/core/test_stream.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1+
from itertools import islice
12
from typing import Iterable, Iterator, List
23

34
from pipedata.core import ChainStart, StreamStart
4-
from pipedata.core.itertools import take_next, take_up_to_n
55

66

77
def test_stream_to_list() -> None:
@@ -81,20 +81,19 @@ def is_even(value: int) -> bool:
8181

8282
def test_stream_flat_map_identity() -> None:
8383
def identity(input_iterator: Iterator[int]) -> Iterator[int]:
84-
while (element := take_next(input_iterator)) is not None:
85-
yield element
84+
yield from input_iterator
8685

8786
result = StreamStart([0, 1, 2, 3]).flat_map(identity).to_list()
8887
assert result == [0, 1, 2, 3]
8988

9089

9190
def test_stream_flat_map_chain() -> None:
9291
def add_one(input_iterator: Iterator[int]) -> Iterator[int]:
93-
while (element := take_next(input_iterator)) is not None:
92+
for element in input_iterator:
9493
yield element + 1
9594

9695
def multiply_two(input_iterator: Iterator[int]) -> Iterator[int]:
97-
while (element := take_next(input_iterator)) is not None:
96+
for element in input_iterator:
9897
yield element * 2
9998

10099
result = (
@@ -105,7 +104,7 @@ def multiply_two(input_iterator: Iterator[int]) -> Iterator[int]:
105104

106105
def test_stream_flat_map_growing() -> None:
107106
def add_element(input_iterator: Iterator[int]) -> Iterator[int]:
108-
while (element := take_next(input_iterator)) is not None:
107+
for element in input_iterator:
109108
yield element
110109
yield element + 1
111110

@@ -115,7 +114,7 @@ def add_element(input_iterator: Iterator[int]) -> Iterator[int]:
115114

116115
def test_stream_flat_map_shrinking() -> None:
117116
def add_two_values(input_iterator: Iterator[int]) -> Iterator[int]:
118-
while batch := take_up_to_n(input_iterator, 2):
117+
while batch := tuple(islice(input_iterator, 2)):
119118
yield sum(batch)
120119

121120
result = StreamStart([0, 1, 2, 3, 4]).flat_map(add_two_values).to_list()

0 commit comments

Comments
 (0)