@@ -9,34 +9,36 @@ class StreamingClientCallListener[Res](
99 prefetch : Option [Int ],
1010 runtime : Runtime [Any ],
1111 call : ZClientCall [? , Res ],
12- queue : Queue [ResponseFrame [Res ]],
13- buffered : Ref [Int ]
12+ queue : Queue [ResponseFrame [Res ]]
1413) extends ClientCall .Listener [Res ] {
15- private val increment = if (prefetch.isDefined) buffered.update(_ + 1 ) else ZIO .unit
16- private val fetchOne = if (prefetch.isDefined) ZIO .unit else call.request(1 )
17- private val fetchMore = prefetch match {
18- case None => ZIO .unit
19- case Some (n) => buffered.get.flatMap(b => call.request(n - b).when(n > b))
20- }
14+ private val fetchOne =
15+ if (prefetch.isDefined) ZIO .unit else call.request(1 )
16+
17+ private def fetchMore (n : Int ) =
18+ if (prefetch.isDefined) call.request(n) else ZIO .unit
2119
2220 private def unsafeRun (task : IO [Any , Unit ]): Unit =
2321 Unsafe .unsafe(implicit u => runtime.unsafe.run(task).getOrThrowFiberFailure())
2422
25- private def handle (promise : Promise [StatusException , Unit ])(
26- chunk : Chunk [ResponseFrame [Res ]]
27- ) = (chunk.lastOption match {
28- case Some (ResponseFrame .Trailers (status, trailers)) =>
29- val exit = if (status.isOk) Exit .unit else Exit .fail(new StatusException (status, trailers))
30- promise.done(exit) *> queue.shutdown
31- case _ =>
32- buffered.update(_ - chunk.size) *> fetchMore
33- }).as(chunk)
23+ private def handle (promise : Promise [StatusException , Unit ])(chunk : Chunk [ResponseFrame [Res ]]) =
24+ if (chunk.isEmpty) ZIO .unit
25+ else {
26+ chunk.last match {
27+ case ResponseFrame .Trailers (status, trailers) =>
28+ val exit =
29+ if (status.isOk) Exit .unit
30+ else Exit .fail(new StatusException (status, trailers))
31+ promise.done(exit) *> queue.shutdown
32+ case _ =>
33+ fetchMore(chunk.size)
34+ }
35+ }
3436
3537 override def onHeaders (headers : Metadata ): Unit =
36- unsafeRun(queue.offer(ResponseFrame .Headers (headers)) *> increment )
38+ unsafeRun(queue.offer(ResponseFrame .Headers (headers)).unit )
3739
3840 override def onMessage (message : Res ): Unit =
39- unsafeRun(queue.offer(ResponseFrame .Message (message)) *> increment *> fetchOne)
41+ unsafeRun(queue.offer(ResponseFrame .Message (message)) *> fetchOne)
4042
4143 override def onClose (status : Status , trailers : Metadata ): Unit =
4244 unsafeRun(queue.offer(ResponseFrame .Trailers (status, trailers)).unit)
@@ -45,15 +47,14 @@ class StreamingClientCallListener[Res](
4547 ZStream .fromZIO(Promise .make[StatusException , Unit ]).flatMap { promise =>
4648 ZStream
4749 .fromQueue(queue, prefetch.getOrElse(ZStream .DefaultChunkSize ))
48- .mapChunksZIO (handle(promise))
50+ .tapChunks (handle(promise))
4951 .concat(ZStream .execute(promise.await))
5052 }
5153}
5254
5355object StreamingClientCallListener {
5456 def make [Res ](call : ZClientCall [? , Res ], prefetch : Option [Int ]): UIO [StreamingClientCallListener [Res ]] = for {
55- runtime <- ZIO .runtime[Any ]
56- queue <- Queue .unbounded[ResponseFrame [Res ]]
57- buffered <- Ref .make(0 )
58- } yield new StreamingClientCallListener (prefetch, runtime, call, queue, buffered)
57+ runtime <- ZIO .runtime[Any ]
58+ queue <- Queue .unbounded[ResponseFrame [Res ]]
59+ } yield new StreamingClientCallListener (prefetch, runtime, call, queue)
5960}
0 commit comments