11from __future__ import annotations
22
3- from itertools import islice
3+ import functools
4+ import itertools
45from typing import (
56 Any ,
67 Callable ,
2122TOther = TypeVar ("TOther" )
2223
2324
25+ def _batched (iterable : Iterator [TEnd ], n : Optional [int ]) -> Iterator [Tuple [TEnd , ...]]:
26+ """Can be replaced by itertools.batched once using Python 3.12+."""
27+ while (elements := tuple (itertools .islice (iterable , n ))) != ():
28+ yield elements
29+
30+
2431def _identity (input_iterator : Iterator [TEnd ]) -> Iterator [TEnd ]:
2532 yield from input_iterator
2633
@@ -122,10 +129,10 @@ def filter(self, func: Callable[[TEnd], bool]) -> Chain[TStart, TEnd]: # noqa:
122129 Remove elements from the stream that do not pass the filter function.
123130 """
124131
132+ @functools .wraps (func )
125133 def new_action (previous_step : Iterator [TEnd ]) -> Iterator [TEnd ]:
126134 return filter (func , previous_step )
127135
128- new_action .__name__ = func .__name__
129136 return self .flat_map (new_action )
130137
131138 def map ( # noqa: A003
@@ -135,13 +142,13 @@ def map( # noqa: A003
135142 Return a single transformed element from each input element.
136143 """
137144
145+ @functools .wraps (func )
138146 def new_action (previous_step : Iterator [TEnd ]) -> Iterator [TOther ]:
139147 return map (func , previous_step )
140148
141- new_action .__name__ = func .__name__
142149 return self .flat_map (new_action )
143150
144- def map_tuple (
151+ def batched_map (
145152 self , func : Callable [[Tuple [TEnd , ...]], TOther ], n : Optional [int ] = None
146153 ) -> Chain [TStart , TOther ]:
147154 """
@@ -151,11 +158,11 @@ def map_tuple(
151158 an iterator of 1 element.
152159 """
153160
161+ @functools .wraps (func )
154162 def new_action (previous_step : Iterator [TEnd ]) -> Iterator [TOther ]:
155- while elements := tuple ( islice ( previous_step , n ) ):
163+ for elements in _batched ( previous_step , n ):
156164 yield func (elements )
157165
158- new_action .__name__ = func .__name__
159166 return self .flat_map (new_action )
160167
161168 def get_counts (self ) -> List [Dict [str , Any ]]:
0 commit comments