@@ -4,19 +4,23 @@ import cats.Applicative
44import cats .data .EitherT
55import cats .effect .{Async , Sync }
66import cats .implicits .*
7+ import fs2 .Stream
78import fs2 .compression .Compression
89import fs2 .io .{readOutputStream , toInputStreamResource }
910import fs2 .text .decodeWithCharset
1011import org .http4s .headers .{`Content-Encoding` , `Content-Type` }
11- import org .http4s .{Charset , ContentCoding , DecodeResult , Entity , EntityDecoder , EntityEncoder , Media , MediaRange , MediaType }
12+ import org .http4s .{ContentCoding , DecodeResult , Entity , EntityDecoder , EntityEncoder , Headers , MediaRange , MediaType }
1213import org .ivovk .connect_rpc_scala .ConnectRpcHttpRoutes .getClass
1314import org .slf4j .{Logger , LoggerFactory }
1415import scalapb .json4s .{JsonFormat , Printer }
1516import scalapb .{GeneratedMessage as Message , GeneratedMessageCompanion as Companion }
1617
18+ import java .net .URLDecoder
19+ import java .util .Base64
20+
1721object MessageCodec {
1822 given [F [_] : Applicative , A <: Message ](using codec : MessageCodec [F ], cmp : Companion [A ]): EntityDecoder [F , A ] =
19- EntityDecoder .decodeBy(MediaRange .`*/*`)(codec.decode)
23+ EntityDecoder .decodeBy(MediaRange .`*/*`)(m => codec.decode( RequestEntity (m)) )
2024
2125 given [F [_], A <: Message ](using codec : MessageCodec [F ]): EntityEncoder [F , A ] =
2226 EntityEncoder .encodeBy(`Content-Type`(codec.mediaType))(codec.encode)
@@ -26,27 +30,33 @@ trait MessageCodec[F[_]] {
2630
2731 val mediaType : MediaType
2832
29- def decode [A <: Message ](m : Media [F ])(using cmp : Companion [A ]): DecodeResult [F , A ]
33+ def decode [A <: Message ](m : RequestEntity [F ])(using cmp : Companion [A ]): DecodeResult [F , A ]
3034
3135 def encode [A <: Message ](message : A ): Entity [F ]
3236
3337}
3438
35- class JsonMessageCodec [F [_] : Sync : Compression ](jsonPrinter : Printer ) extends MessageCodec [F ] {
39+ class JsonMessageCodec [F [_] : Sync : Compression ](printer : Printer ) extends MessageCodec [F ] {
3640
3741 private val logger : Logger = LoggerFactory .getLogger(getClass)
3842
3943 override val mediaType : MediaType = MediaTypes .`application/json`
4044
41- override def decode [A <: Message ](m : Media [F ])(using cmp : Companion [A ]): DecodeResult [F , A ] = {
42- val charset = m.charset.getOrElse(Charset .`UTF-8`).nioCharset
45+ override def decode [A <: Message ](entity : RequestEntity [F ])(using cmp : Companion [A ]): DecodeResult [F , A ] = {
46+ val charset = entity.charset.nioCharset
47+ val string = entity.message match {
48+ case str : String =>
49+ Sync [F ].delay(URLDecoder .decode(str, charset))
50+ case stream : Stream [F , Byte ] =>
51+ decompressed(entity.headers, stream)
52+ .through(decodeWithCharset(charset))
53+ .compile.string
54+ }
4355
44- val f = decompressed(m)
45- .through(decodeWithCharset(charset))
46- .compile.string
56+ val f = string
4757 .flatMap { str =>
4858 if (logger.isTraceEnabled) {
49- logger.trace(s " >>> Headers: ${m .headers}" )
59+ logger.trace(s " >>> Headers: ${entity .headers}" )
5060 logger.trace(s " >>> JSON: $str" )
5161 }
5262
@@ -57,7 +67,7 @@ class JsonMessageCodec[F[_] : Sync : Compression](jsonPrinter: Printer) extends
5767 }
5868
5969 override def encode [A <: Message ](message : A ): Entity [F ] = {
60- val string = jsonPrinter .print(message)
70+ val string = printer .print(message)
6171
6272 if (logger.isTraceEnabled) {
6373 logger.trace(s " <<< JSON: $string" )
@@ -72,23 +82,28 @@ class ProtoMessageCodec[F[_] : Async : Compression] extends MessageCodec[F] {
7282
7383 private val logger : Logger = LoggerFactory .getLogger(getClass)
7484
75- override val mediaType : MediaType = MediaTypes .`application/proto`
85+ private val base64dec = Base64 .getUrlDecoder
7686
77- override def decode [A <: Message ](m : Media [F ])(using cmp : Companion [A ]): DecodeResult [F , A ] = {
78- val f = toInputStreamResource(decompressed(m)).use { is =>
79- Async [F ].delay {
80- val message = cmp.parseFrom(is)
87+ override val mediaType : MediaType = MediaTypes .`application/proto`
8188
82- if (logger.isTraceEnabled) {
83- logger.trace(s " >>> Headers: ${m.headers}" )
84- logger.trace(s " >>> Proto: ${message.toProtoString}" )
85- }
89+ override def decode [A <: Message ](entity : RequestEntity [F ])(using cmp : Companion [A ]): DecodeResult [F , A ] = {
90+ val f = entity.message match {
91+ case str : String =>
92+ Async [F ].delay(base64dec.decode(str.getBytes(entity.charset.nioCharset)))
93+ .flatMap(arr => Async [F ].delay(cmp.parseFrom(arr)))
94+ case stream : Stream [F , Byte ] =>
95+ toInputStreamResource(decompressed(entity.headers, stream))
96+ .use(is => Async [F ].delay(cmp.parseFrom(is)))
97+ }
8698
87- message
99+ EitherT .right(f.map { message =>
100+ if (logger.isTraceEnabled) {
101+ logger.trace(s " >>> Headers: ${entity.headers}" )
102+ logger.trace(s " >>> Proto: ${message.toProtoString}" )
88103 }
89- }
90104
91- EitherT .right(f)
105+ message
106+ })
92107 }
93108
94109 override def encode [A <: Message ](message : A ): Entity [F ] = {
@@ -104,10 +119,10 @@ class ProtoMessageCodec[F[_] : Async : Compression] extends MessageCodec[F] {
104119
105120}
106121
107- def decompressed [F [_] : Compression ](m : Media [ F ]): fs2. Stream [F , Byte ] = {
108- val encoding = m. headers.get[`Content-Encoding`].map(_.contentCoding)
122+ def decompressed [F [_] : Compression ](headers : Headers , body : Stream [ F , Byte ]): Stream [F , Byte ] = {
123+ val encoding = headers.get[`Content-Encoding`].map(_.contentCoding)
109124
110- m. body.through(encoding match {
125+ body.through(encoding match {
111126 case Some (ContentCoding .gzip) =>
112127 Compression [F ].gunzip().andThen(_.flatMap(_.content))
113128 case _ =>
0 commit comments