Skip to content

Commit fe066d6

Browse files
committed
Implement intersectBySorted API
1 parent 4bc714f commit fe066d6

File tree

4 files changed

+124
-5
lines changed

4 files changed

+124
-5
lines changed

src/Streamly/Internal/Data/Stream/IsStream/Top.hs

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ module Streamly.Internal.Data.Stream.IsStream.Top
2828
-- | These are not exactly set operations because streams are not
2929
-- necessarily sets, they may have duplicated elements.
3030
, intersectBy
31-
, mergeIntersectBy
31+
, intersectBySorted
3232
, differenceBy
3333
, mergeDifferenceBy
3434
, unionBy
@@ -65,6 +65,7 @@ import Streamly.Internal.Data.Stream.IsStream.Common (concatM)
6565
import Streamly.Internal.Data.Stream.IsStream.Type
6666
(IsStream(..), adapt, foldl', fromList)
6767
import Streamly.Internal.Data.Stream.Serial (SerialT)
68+
--import Streamly.Internal.Data.Stream.StreamD (fromStreamD, toStreamD)
6869
import Streamly.Internal.Data.Time.Units (NanoSecond64(..), toRelTime64)
6970

7071
import qualified Data.List as List
@@ -79,6 +80,7 @@ import qualified Streamly.Internal.Data.Stream.IsStream.Expand as Stream
7980
import qualified Streamly.Internal.Data.Stream.IsStream.Reduce as Stream
8081
import qualified Streamly.Internal.Data.Stream.IsStream.Transform as Stream
8182
import qualified Streamly.Internal.Data.Stream.IsStream.Type as IsStream
83+
import qualified Streamly.Internal.Data.Stream.StreamD as StreamD
8284

8385
import Prelude hiding (filter, zipWith, concatMap, concat)
8486

@@ -580,11 +582,12 @@ intersectBy eq s1 s2 =
580582
--
581583
-- Time: O(m+n)
582584
--
583-
-- /Unimplemented/
584-
{-# INLINE mergeIntersectBy #-}
585-
mergeIntersectBy :: -- (IsStream t, Monad m) =>
585+
-- /Pre-release/
586+
{-# INLINE intersectBySorted #-}
587+
intersectBySorted :: (IsStream t, MonadIO m, Eq a) =>
586588
(a -> a -> Ordering) -> t m a -> t m a -> t m a
587-
mergeIntersectBy _eq _s1 _s2 = undefined
589+
intersectBySorted eq s1 =
590+
IsStream.fromStreamD . StreamD.intersectBySorted eq (IsStream.toStreamD s1) . IsStream.toStreamD
588591

589592
-- Roughly joinLeft s1 s2 = s1 `difference` s2 + s1 `intersection` s2
590593

src/Streamly/Internal/Data/Stream/StreamD/Nesting.hs

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ module Streamly.Internal.Data.Stream.StreamD.Nesting
142142
-- | Opposite to compact in ArrayStream
143143
, splitInnerBy
144144
, splitInnerBySuffix
145+
, intersectBySorted
145146
)
146147
where
147148

@@ -482,6 +483,59 @@ mergeBy
482483
=> (a -> a -> Ordering) -> Stream m a -> Stream m a -> Stream m a
483484
mergeBy cmp = mergeByM (\a b -> return $ cmp a b)
484485

486+
-------------------------------------------------------------------------------
487+
-- Intersection of sorted streams ---------------------------------------------
488+
-------------------------------------------------------------------------------
489+
{-# INLINE_NORMAL intersectBySorted #-}
490+
intersectBySorted
491+
:: (MonadIO m, Eq a)
492+
=> (a -> a -> Ordering) -> Stream m a -> Stream m a -> Stream m a
493+
intersectBySorted cmp (Stream stepa ta) (Stream stepb tb) =
494+
Stream step (Just ta, Just tb, Nothing, Nothing, Nothing)
495+
496+
where
497+
{-# INLINE_LATE step #-}
498+
499+
-- step 1
500+
step gst (Just sa, sb, Nothing, b, Nothing) = do
501+
r <- stepa gst sa
502+
return $ case r of
503+
Yield a sa' -> Skip (Just sa', sb, Just a, b, Nothing)
504+
Skip sa' -> Skip (Just sa', sb, Nothing, b, Nothing)
505+
Stop -> Stop
506+
507+
-- step 2
508+
step gst (sa, Just sb, a, Nothing, Nothing) = do
509+
r <- stepb gst sb
510+
return $ case r of
511+
Yield b sb' -> Skip (sa, Just sb', a, Just b, Nothing)
512+
Skip sb' -> Skip (sa, Just sb', a, Nothing, Nothing)
513+
Stop -> Stop
514+
515+
-- step 3
516+
-- both the values are available compare it
517+
step _ (sa, sb, Just a, Just b, Nothing) = do
518+
let res = cmp a b
519+
return $ case res of
520+
GT -> Skip (sa, sb, Just a, Nothing, Nothing)
521+
LT -> Skip (sa, sb, Nothing, Just b, Nothing)
522+
EQ -> Yield a (sa, sb, Nothing, Just a, Just b) -- step 4
523+
524+
-- step 4
525+
-- Matching element
526+
step gst (Just sa, Just sb, Nothing, Just _, Just b) = do
527+
r1 <- stepa gst sa
528+
return $ case r1 of
529+
Yield a' sa' -> do
530+
if a' == b -- match with prev a
531+
then Yield a' (Just sa', Just sb, Nothing, Just b, Just b) --step 1
532+
else Skip (Just sa', Just sb, Just a', Nothing, Nothing)
533+
534+
Skip sa' -> Skip (Just sa', Just sb, Nothing, Nothing, Nothing)
535+
Stop -> Stop
536+
537+
step _ (_, _, _, _, _) = return Stop
538+
485539
------------------------------------------------------------------------------
486540
-- Combine N Streams - unfoldMany
487541
------------------------------------------------------------------------------
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
module Main (main)
2+
where
3+
4+
import Data.List (intersect, sort)
5+
import Test.QuickCheck
6+
( Gen
7+
, Property
8+
, choose
9+
, forAll
10+
, listOf
11+
)
12+
import Test.QuickCheck.Monadic (monadicIO, assert, run)
13+
import qualified Streamly.Prelude as S
14+
import qualified Streamly.Internal.Data.Stream.IsStream.Top as Top
15+
16+
import Prelude hiding
17+
(maximum, minimum, elem, notElem, null, product, sum, head, last, take)
18+
import Test.Hspec as H
19+
import Test.Hspec.QuickCheck
20+
21+
min_value :: Int
22+
min_value = 0
23+
24+
max_value :: Int
25+
max_value = 10000
26+
27+
chooseInt :: (Int, Int) -> Gen Int
28+
chooseInt = choose
29+
30+
intersectBySorted :: Property
31+
intersectBySorted =
32+
forAll (listOf (chooseInt (min_value, max_value))) $ \ls0 ->
33+
forAll (listOf (chooseInt (min_value, max_value))) $ \ls1 ->
34+
monadicIO $ action (sort ls0) (sort ls1)
35+
36+
where
37+
38+
action ls0 ls1 = do
39+
v1 <-
40+
run
41+
$ S.toList
42+
$ Top.intersectBySorted
43+
compare
44+
(S.fromList ls0)
45+
(S.fromList ls1)
46+
let v2 = intersect ls0 ls1
47+
assert (v1 == sort v2)
48+
49+
-------------------------------------------------------------------------------
50+
moduleName :: String
51+
moduleName = "Data.Stream.Top"
52+
53+
main :: IO ()
54+
main = hspec $ do
55+
describe moduleName $ do
56+
-- intersect
57+
prop "intersectBySorted" Main.intersectBySorted

test/streamly-tests.cabal

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -434,3 +434,8 @@ test-suite version-bounds
434434
import: test-options
435435
type: exitcode-stdio-1.0
436436
main-is: version-bounds.hs
437+
438+
test-suite Data.Stream.Top
439+
import: test-options
440+
type: exitcode-stdio-1.0
441+
main-is: Streamly/Test/Data/Stream/Top.hs

0 commit comments

Comments
 (0)