|
1 |
| -package mqttcodec |
| 1 | +package codec |
2 | 2 |
|
3 | 3 | import (
|
4 | 4 | "bytes"
|
5 | 5 | "fmt"
|
6 | 6 | "io"
|
| 7 | + |
| 8 | + mqttproto "github.com/grepplabs/mqtt-proxy/pkg/mqtt/codec/proto" |
| 9 | + mqtt311 "github.com/grepplabs/mqtt-proxy/pkg/mqtt/codec/v311" |
7 | 10 | )
|
8 | 11 |
|
9 |
| -type ControlPacket interface { |
10 |
| - Write(io.Writer) error |
11 |
| - Unpack(io.Reader) error |
12 |
| - String() string |
13 |
| - Type() byte |
14 |
| - Name() string |
| 12 | +func ReadPacket(r io.Reader, protocolVersion byte) (mqttproto.ControlPacket, error) { |
| 13 | + if protocolVersion == 0 { |
| 14 | + var ( |
| 15 | + err error |
| 16 | + buf bytes.Buffer |
| 17 | + ) |
| 18 | + versionReader := io.TeeReader(r, &buf) |
| 19 | + protocolVersion, err = readConnectVersion(versionReader) |
| 20 | + if err != nil { |
| 21 | + return nil, err |
| 22 | + } |
| 23 | + r = io.MultiReader(bytes.NewReader(buf.Bytes()), r) |
| 24 | + } |
| 25 | + switch protocolVersion { |
| 26 | + case mqttproto.MQTT_3_1_1: |
| 27 | + return mqtt311.ReadPacket(r) |
| 28 | + case mqttproto.MQTT_5: |
| 29 | + return nil, mqtt311.NewConnAckError(mqttproto.RefusedUnacceptableProtocolVersion, "mqtt5 is not supported yet") |
| 30 | + default: |
| 31 | + return nil, mqtt311.NewConnAckError(mqttproto.RefusedUnacceptableProtocolVersion, fmt.Sprintf("unsupported protocol version %v", protocolVersion)) |
| 32 | + } |
15 | 33 | }
|
16 | 34 |
|
17 |
| -func ReadPacket(r io.Reader) (ControlPacket, error) { |
18 |
| - var fh FixedHeader |
| 35 | +func readConnectVersion(r io.Reader) (byte, error) { |
| 36 | + // fixed header |
| 37 | + var fh mqttproto.FixedHeader |
19 | 38 | b1 := make([]byte, 1)
|
20 | 39 | _, err := io.ReadFull(r, b1)
|
21 | 40 | if err != nil {
|
22 |
| - return nil, err |
| 41 | + return 0, err |
23 | 42 | }
|
24 |
| - |
25 |
| - err = fh.unpack(b1[0], r) |
| 43 | + err = fh.Unpack(b1[0], r) |
26 | 44 | if err != nil {
|
27 |
| - return nil, err |
| 45 | + return 0, err |
28 | 46 | }
|
29 |
| - |
30 |
| - err = fh.validate() |
| 47 | + err = fh.Validate() |
31 | 48 | if err != nil {
|
32 |
| - return nil, err |
| 49 | + return 0, err |
33 | 50 | }
|
34 |
| - |
35 |
| - cp, err := NewControlPacketWithHeader(fh) |
36 |
| - if err != nil { |
37 |
| - return nil, err |
| 51 | + if fh.MessageType != mqttproto.CONNECT { |
| 52 | + return 0, fmt.Errorf("expected CONNECT packet but got type 0x%x", fh.MessageType) |
38 | 53 | }
|
39 |
| - |
40 |
| - packetBytes := make([]byte, fh.RemainingLength) |
41 |
| - n, err := io.ReadFull(r, packetBytes) |
| 54 | + // variable header |
| 55 | + // Protocol Name |
| 56 | + _, err = mqttproto.DecodeString(r) |
42 | 57 | if err != nil {
|
43 |
| - return nil, err |
| 58 | + return 0, err |
44 | 59 | }
|
45 |
| - if n != fh.RemainingLength { |
46 |
| - return nil, fmt.Errorf("failed to read encoded data, read %d from %d", n, fh.RemainingLength) |
47 |
| - } |
48 |
| - err = cp.Unpack(bytes.NewBuffer(packetBytes)) |
49 |
| - return cp, err |
50 |
| -} |
51 |
| - |
52 |
| -func NewControlPacket(packetType byte) ControlPacket { |
53 |
| - switch packetType { |
54 |
| - case CONNECT: |
55 |
| - return &ConnectPacket{FixedHeader: FixedHeader{MessageType: CONNECT}} |
56 |
| - case CONNACK: |
57 |
| - return &ConnackPacket{FixedHeader: FixedHeader{MessageType: CONNACK}} |
58 |
| - case PUBLISH: |
59 |
| - return &PublishPacket{FixedHeader: FixedHeader{MessageType: PUBLISH}} |
60 |
| - case PUBACK: |
61 |
| - return &PubackPacket{FixedHeader: FixedHeader{MessageType: PUBACK}} |
62 |
| - case PUBREC: |
63 |
| - return &PubrecPacket{FixedHeader: FixedHeader{MessageType: PUBREC}} |
64 |
| - case PUBREL: |
65 |
| - return &PubrelPacket{FixedHeader: FixedHeader{MessageType: PUBREL}} |
66 |
| - case PUBCOMP: |
67 |
| - return &PubcompPacket{FixedHeader: FixedHeader{MessageType: PUBCOMP}} |
68 |
| - case SUBSCRIBE: |
69 |
| - return &SubscribePacket{FixedHeader: FixedHeader{MessageType: SUBSCRIBE}} |
70 |
| - case SUBACK: |
71 |
| - return &SubackPacket{FixedHeader: FixedHeader{MessageType: SUBACK}} |
72 |
| - case UNSUBSCRIBE: |
73 |
| - return &UnsubscribePacket{FixedHeader: FixedHeader{MessageType: UNSUBSCRIBE}} |
74 |
| - case UNSUBACK: |
75 |
| - return &UnsubackPacket{FixedHeader: FixedHeader{MessageType: UNSUBACK}} |
76 |
| - case PINGREQ: |
77 |
| - return &PingreqPacket{FixedHeader: FixedHeader{MessageType: PINGREQ}} |
78 |
| - case PINGRESP: |
79 |
| - return &PingrespPacket{FixedHeader: FixedHeader{MessageType: PINGRESP}} |
80 |
| - case DISCONNECT: |
81 |
| - return &DisconnectPacket{FixedHeader: FixedHeader{MessageType: DISCONNECT}} |
82 |
| - |
83 |
| - } |
84 |
| - return nil |
85 |
| -} |
86 |
| - |
87 |
| -func NewControlPacketWithHeader(fh FixedHeader) (ControlPacket, error) { |
88 |
| - switch fh.MessageType { |
89 |
| - case CONNECT: |
90 |
| - return &ConnectPacket{FixedHeader: fh}, nil |
91 |
| - case CONNACK: |
92 |
| - return &ConnackPacket{FixedHeader: fh}, nil |
93 |
| - case PUBLISH: |
94 |
| - return &PublishPacket{FixedHeader: fh}, nil |
95 |
| - case PUBACK: |
96 |
| - return &PubackPacket{FixedHeader: fh}, nil |
97 |
| - case PUBREC: |
98 |
| - return &PubrecPacket{FixedHeader: fh}, nil |
99 |
| - case PUBREL: |
100 |
| - return &PubrelPacket{FixedHeader: fh}, nil |
101 |
| - case PUBCOMP: |
102 |
| - return &PubcompPacket{FixedHeader: fh}, nil |
103 |
| - case SUBSCRIBE: |
104 |
| - return &SubscribePacket{FixedHeader: fh}, nil |
105 |
| - case SUBACK: |
106 |
| - return &SubackPacket{FixedHeader: fh}, nil |
107 |
| - case UNSUBSCRIBE: |
108 |
| - return &UnsubscribePacket{FixedHeader: fh}, nil |
109 |
| - case UNSUBACK: |
110 |
| - return &UnsubackPacket{FixedHeader: fh}, nil |
111 |
| - case PINGREQ: |
112 |
| - return &PingreqPacket{FixedHeader: fh}, nil |
113 |
| - case PINGRESP: |
114 |
| - return &PingrespPacket{FixedHeader: fh}, nil |
115 |
| - case DISCONNECT: |
116 |
| - return &DisconnectPacket{FixedHeader: fh}, nil |
117 |
| - default: |
118 |
| - return nil, fmt.Errorf("unsupported packet type 0x%x", fh.MessageType) |
| 60 | + // Protocol Version |
| 61 | + version, err := mqttproto.DecodeByte(r) |
| 62 | + if err != nil { |
| 63 | + return 0, err |
119 | 64 | }
|
| 65 | + return version, nil |
120 | 66 | }
|
0 commit comments