@@ -4,11 +4,14 @@ import (
44 "context"
55 "encoding/json"
66 http "net/http"
7+ "reflect"
78 "strings"
89
910 "github.com/ccheers/xpkg/sync/errgroup"
1011 "github.com/gorilla/websocket"
1112 "github.com/pkg/errors"
13+ "google.golang.org/protobuf/encoding/protojson"
14+ "google.golang.org/protobuf/proto"
1215)
1316
1417type WebSocket [T any , R any ] struct {
@@ -21,6 +24,7 @@ type Encoding struct {
2124 errorEncodeFunc func (w http.ResponseWriter , err error )
2225 requestDecodeFunc func (r * http.Request , req interface {}) error
2326
27+ wsReqDecodeFunc func (bs []byte , req interface {}) error
2428 replyEncodeFunc func (ws * websocket.Conn , resp interface {})
2529 replyErrorEncodeFunc func (ws * websocket.Conn , err error )
2630}
@@ -54,6 +58,9 @@ func defaultWSOptions() WSOptions {
5458 requestDecodeFunc : func (r * http.Request , req interface {}) error {
5559 return json .NewDecoder (r .Body ).Decode (req )
5660 },
61+ wsReqDecodeFunc : func (bs []byte , req interface {}) error {
62+ return unmarshalJSON (bs , req )
63+ },
5764 replyEncodeFunc : func (ws * websocket.Conn , resp interface {}) {
5865 _ = ws .WriteJSON (map [string ]interface {}{
5966 "code" : 200 ,
@@ -99,6 +106,12 @@ func WithRequestDecodeFunc(fn func(r *http.Request, req interface{}) error) WSOp
99106 }
100107}
101108
109+ func WithWsReqDecodeFunc (fn func (bs []byte , req interface {}) error ) WSOptionFunc {
110+ return func (options * WSOptions ) {
111+ options .encoding .wsReqDecodeFunc = fn
112+ }
113+ }
114+
102115func WithReplyEncodeFunc (fn func (ws * websocket.Conn , resp interface {})) WSOptionFunc {
103116 return func (options * WSOptions ) {
104117 options .encoding .replyEncodeFunc = fn
@@ -193,7 +206,7 @@ func (x *WebSocket[T, R]) readLoop(ctx context.Context, ws *websocket.Conn) {
193206 default :
194207 }
195208 var dst R
196- err := ws .ReadJSON ( & dst )
209+ _ , bs , err := ws .ReadMessage ( )
197210 if err != nil {
198211 if strings .Contains (err .Error (), "connection reset by peer" ) {
199212 return
@@ -204,6 +217,12 @@ func (x *WebSocket[T, R]) readLoop(ctx context.Context, ws *websocket.Conn) {
204217 continue
205218 }
206219
220+ err = x .options .encoding .wsReqDecodeFunc (bs , & dst )
221+ if err != nil {
222+ x .options .encoding .replyErrorEncodeFunc (ws , err )
223+ continue
224+ }
225+
207226 if validate , ok := (interface {})(& dst ).(interface { Validate () error }); ok {
208227 err := validate .Validate ()
209228 if err != nil {
@@ -237,3 +256,31 @@ func (x *WebSocket[T, R]) writeLoop(ctx context.Context, ws *websocket.Conn) {
237256 x .options .encoding .replyEncodeFunc (ws , resp )
238257 }
239258}
259+
260+ var (
261+ // unmarshalOptions is a configurable JSON format parser.
262+ unmarshalOptions = protojson.UnmarshalOptions {
263+ DiscardUnknown : true ,
264+ }
265+ )
266+
267+ func unmarshalJSON (data []byte , v interface {}) error {
268+ switch m := v .(type ) {
269+ case json.Unmarshaler :
270+ return m .UnmarshalJSON (data )
271+ case proto.Message :
272+ return unmarshalOptions .Unmarshal (data , m )
273+ default :
274+ rv := reflect .ValueOf (v )
275+ for rv := rv ; rv .Kind () == reflect .Ptr ; {
276+ if rv .IsNil () {
277+ rv .Set (reflect .New (rv .Type ().Elem ()))
278+ }
279+ rv = rv .Elem ()
280+ }
281+ if m , ok := reflect .Indirect (rv ).Interface ().(proto.Message ); ok {
282+ return unmarshalOptions .Unmarshal (data , m )
283+ }
284+ return json .Unmarshal (data , m )
285+ }
286+ }
0 commit comments