Skip to content

Commit c943854

Browse files
committed
Fix client logic for backpressuring streams with prefetch
1 parent 8301b09 commit c943854

File tree

7 files changed

+181
-163
lines changed

7 files changed

+181
-163
lines changed

core/src/main/scalajvm/scalapb/zio_grpc/client/StreamingClientCallListener.scala

Lines changed: 25 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -9,34 +9,36 @@ class StreamingClientCallListener[Res](
99
prefetch: Option[Int],
1010
runtime: Runtime[Any],
1111
call: ZClientCall[?, Res],
12-
queue: Queue[ResponseFrame[Res]],
13-
buffered: Ref[Int]
12+
queue: Queue[ResponseFrame[Res]]
1413
) extends ClientCall.Listener[Res] {
15-
private val increment = if (prefetch.isDefined) buffered.update(_ + 1) else ZIO.unit
16-
private val fetchOne = if (prefetch.isDefined) ZIO.unit else call.request(1)
17-
private val fetchMore = prefetch match {
18-
case None => ZIO.unit
19-
case Some(n) => buffered.get.flatMap(b => call.request(n - b).when(n > b))
20-
}
14+
private val fetchOne =
15+
if (prefetch.isDefined) ZIO.unit else call.request(1)
16+
17+
private def fetchMore(n: Int) =
18+
if (prefetch.isDefined) call.request(n) else ZIO.unit
2119

2220
private def unsafeRun(task: IO[Any, Unit]): Unit =
2321
Unsafe.unsafe(implicit u => runtime.unsafe.run(task).getOrThrowFiberFailure())
2422

25-
private def handle(promise: Promise[StatusException, Unit])(
26-
chunk: Chunk[ResponseFrame[Res]]
27-
) = (chunk.lastOption match {
28-
case Some(ResponseFrame.Trailers(status, trailers)) =>
29-
val exit = if (status.isOk) Exit.unit else Exit.fail(new StatusException(status, trailers))
30-
promise.done(exit) *> queue.shutdown
31-
case _ =>
32-
buffered.update(_ - chunk.size) *> fetchMore
33-
}).as(chunk)
23+
private def handle(promise: Promise[StatusException, Unit])(chunk: Chunk[ResponseFrame[Res]]) =
24+
if (chunk.isEmpty) ZIO.unit
25+
else {
26+
chunk.last match {
27+
case ResponseFrame.Trailers(status, trailers) =>
28+
val exit =
29+
if (status.isOk) Exit.unit
30+
else Exit.fail(new StatusException(status, trailers))
31+
promise.done(exit) *> queue.shutdown
32+
case _ =>
33+
fetchMore(chunk.size)
34+
}
35+
}
3436

3537
override def onHeaders(headers: Metadata): Unit =
36-
unsafeRun(queue.offer(ResponseFrame.Headers(headers)) *> increment)
38+
unsafeRun(queue.offer(ResponseFrame.Headers(headers)).unit)
3739

3840
override def onMessage(message: Res): Unit =
39-
unsafeRun(queue.offer(ResponseFrame.Message(message)) *> increment *> fetchOne)
41+
unsafeRun(queue.offer(ResponseFrame.Message(message)) *> fetchOne)
4042

4143
override def onClose(status: Status, trailers: Metadata): Unit =
4244
unsafeRun(queue.offer(ResponseFrame.Trailers(status, trailers)).unit)
@@ -45,15 +47,14 @@ class StreamingClientCallListener[Res](
4547
ZStream.fromZIO(Promise.make[StatusException, Unit]).flatMap { promise =>
4648
ZStream
4749
.fromQueue(queue, prefetch.getOrElse(ZStream.DefaultChunkSize))
48-
.mapChunksZIO(handle(promise))
50+
.tapChunks(handle(promise))
4951
.concat(ZStream.execute(promise.await))
5052
}
5153
}
5254

5355
object StreamingClientCallListener {
5456
def make[Res](call: ZClientCall[?, Res], prefetch: Option[Int]): UIO[StreamingClientCallListener[Res]] = for {
55-
runtime <- ZIO.runtime[Any]
56-
queue <- Queue.unbounded[ResponseFrame[Res]]
57-
buffered <- Ref.make(0)
58-
} yield new StreamingClientCallListener(prefetch, runtime, call, queue, buffered)
57+
runtime <- ZIO.runtime[Any]
58+
queue <- Queue.unbounded[ResponseFrame[Res]]
59+
} yield new StreamingClientCallListener(prefetch, runtime, call, queue)
5960
}

e2e/protos/src/main/protobuf/testservice.proto

Lines changed: 20 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5,38 +5,37 @@ package scalapb.zio_grpc;
55
import "scalapb/scalapb.proto";
66

77
message Request {
8-
enum Scenario {
9-
OK = 0;
10-
ERROR_NOW = 1; // fail with an error
11-
ERROR_AFTER = 2; // for server streaming, error after two responses
12-
DELAY = 3; // do not return a response. for testing cancellations
13-
DIE = 4; // fail
14-
UNAVAILABLE = 5; // fail with UNAVAILABLE, to test client retries
15-
}
16-
Scenario scenario = 1;
17-
int32 in = 2;
8+
enum Scenario {
9+
OK = 0;
10+
ERROR_NOW = 1; // fail with an error
11+
ERROR_AFTER = 2; // for server streaming, error after two responses
12+
DELAY = 3; // do not return a response, to test cancellations
13+
LARGE_STREAM = 4; // stream of large elements, to test backpressure
14+
DIE = 5; // fail
15+
UNAVAILABLE = 6; // fail with UNAVAILABLE, to test client retries
16+
}
17+
Scenario scenario = 1;
18+
int32 in = 2;
1819
}
1920

20-
message Response {
21-
string out = 1;
22-
}
21+
message Response { string out = 1; }
2322

2423
message ResponseTypeMapped {
25-
option (scalapb.message).type = "scalapb.zio_grpc.WrappedString";
24+
option (scalapb.message).type = "scalapb.zio_grpc.WrappedString";
2625

27-
string out = 1;
26+
string out = 1;
2827
}
2928

3029
service TestService {
31-
rpc Unary(Request) returns (Response);
30+
rpc Unary(Request) returns (Response);
3231

33-
rpc UnaryTypeMapped(Request) returns (ResponseTypeMapped);
32+
rpc UnaryTypeMapped(Request) returns (ResponseTypeMapped);
3433

35-
rpc ServerStreaming(Request) returns (stream Response);
34+
rpc ServerStreaming(Request) returns (stream Response);
3635

37-
rpc ServerStreamingTypeMapped(Request) returns (stream ResponseTypeMapped);
36+
rpc ServerStreamingTypeMapped(Request) returns (stream ResponseTypeMapped);
3837

39-
rpc ClientStreaming(stream Request) returns (Response);
38+
rpc ClientStreaming(stream Request) returns (Response);
4039

41-
rpc BidiStreaming(stream Request) returns (stream Response);
40+
rpc BidiStreaming(stream Request) returns (stream Response);
4241
}

0 commit comments

Comments
 (0)