|
1 | 1 | package store
|
2 | 2 |
|
3 | 3 | import (
|
| 4 | + "context" |
| 5 | + "fmt" |
| 6 | + "github.com/ByteStorage/FlyDB/config" |
| 7 | + "github.com/ByteStorage/FlyDB/lib/encoding" |
| 8 | + raftPB "github.com/ByteStorage/FlyDB/lib/proto/raft" |
4 | 9 | "github.com/hashicorp/raft"
|
5 | 10 | "google.golang.org/grpc"
|
6 | 11 | "io"
|
| 12 | + "sync" |
| 13 | + "time" |
7 | 14 | )
|
8 | 15 |
|
9 |
| -// transport implements raft.Transport interface |
| 16 | +type ClientConn struct { |
| 17 | + conn *grpc.ClientConn |
| 18 | + client raftPB.RaftServiceClient |
| 19 | + mtx sync.Mutex |
| 20 | +} |
| 21 | +type raftPipeline struct { |
| 22 | + stream raftPB.RaftService_AppendEntriesPipelineClient |
| 23 | + cancel func() |
| 24 | + inflightChMtx sync.Mutex |
| 25 | + inflightCh chan *appendFuture |
| 26 | + doneCh chan raft.AppendFuture |
| 27 | +} |
| 28 | + |
| 29 | +type appendFuture struct { |
| 30 | + raft.AppendFuture |
| 31 | + start time.Time |
| 32 | + request *raft.AppendEntriesRequest |
| 33 | + response raft.AppendEntriesResponse |
| 34 | + err error |
| 35 | + done chan struct{} |
| 36 | +} |
| 37 | + |
| 38 | +// Transport implements raft.Transport interface |
10 | 39 | // we can use it to send rpc to other raft nodes
|
11 | 40 | // and receive rpc from other raft nodes
|
12 |
| -type transport struct { |
| 41 | +type Transport struct { |
13 | 42 | //implement me
|
14 |
| - localAddr raft.ServerAddress |
15 |
| - consumer chan raft.RPC |
16 |
| - clients map[raft.ServerAddress]*grpc.ClientConn |
17 |
| - server *grpc.Server |
| 43 | + localAddr raft.ServerAddress |
| 44 | + consumer chan raft.RPC |
| 45 | + clients map[raft.ServerAddress]*ClientConn |
| 46 | + server *grpc.Server |
| 47 | + heartbeatFn func(raft.RPC) |
| 48 | + dialOptions []grpc.DialOption |
| 49 | + heartbeatTimeout time.Duration |
| 50 | + sync.RWMutex |
18 | 51 | }
|
19 | 52 |
|
20 | 53 | // NewTransport returns a new transport, it needs start a grpc server
|
21 |
| -func newTransport() raft.Transport { |
22 |
| - return &transport{} |
| 54 | +func newTransport(conf config.Config) raft.Transport { |
| 55 | + return &Transport{ |
| 56 | + localAddr: conf.LocalAddress, |
| 57 | + dialOptions: []grpc.DialOption{grpc.WithInsecure()}, |
| 58 | + heartbeatTimeout: conf.HeartbeatTimeout, |
| 59 | + consumer: make(chan raft.RPC), |
| 60 | + clients: map[raft.ServerAddress]*ClientConn{}, |
| 61 | + } |
23 | 62 | }
|
24 | 63 |
|
25 |
| -func (t *transport) AppendEntriesPipeline(id raft.ServerID, target raft.ServerAddress) (raft.AppendPipeline, error) { |
26 |
| - //TODO implement me |
27 |
| - panic("implement me") |
| 64 | +// AppendEntriesPipeline returns an interface that can be used to pipeline |
| 65 | +// AppendEntries requests. |
| 66 | +func (t *Transport) AppendEntriesPipeline(id raft.ServerID, target raft.ServerAddress) (raft.AppendPipeline, error) { |
| 67 | + c, err := t.getPeer(target) |
| 68 | + if err != nil { |
| 69 | + return nil, err |
| 70 | + } |
| 71 | + ctx := context.TODO() |
| 72 | + ctx, cancel := context.WithCancel(ctx) |
| 73 | + stream, err := c.AppendEntriesPipeline(ctx) |
| 74 | + if err != nil { |
| 75 | + cancel() |
| 76 | + return nil, err |
| 77 | + } |
| 78 | + rpa := raftPipeline{ |
| 79 | + stream: stream, |
| 80 | + cancel: cancel, |
| 81 | + inflightCh: make(chan *appendFuture, 20), |
| 82 | + doneCh: make(chan raft.AppendFuture, 20), |
| 83 | + } |
| 84 | + go rpa.receiver() |
| 85 | + return &rpa, nil |
28 | 86 | }
|
29 | 87 |
|
30 |
| -func (t *transport) AppendEntries(id raft.ServerID, target raft.ServerAddress, args *raft.AppendEntriesRequest, resp *raft.AppendEntriesResponse) error { |
31 |
| - //TODO implement me |
32 |
| - panic("implement me") |
| 88 | +// AppendEntries sends the appropriate RPC to the target node. |
| 89 | +func (t *Transport) AppendEntries(id raft.ServerID, target raft.ServerAddress, args *raft.AppendEntriesRequest, resp *raft.AppendEntriesResponse) error { |
| 90 | + c, err := t.getPeer(target) |
| 91 | + if err != nil { |
| 92 | + return err |
| 93 | + } |
| 94 | + ctx := context.TODO() |
| 95 | + if t.heartbeatTimeout > 0 && isHeartbeat(args) { |
| 96 | + var cancel context.CancelFunc |
| 97 | + ctx, cancel = context.WithTimeout(ctx, t.heartbeatTimeout) |
| 98 | + defer cancel() |
| 99 | + } |
| 100 | + ret, err := c.AppendEntries(ctx, encoding.EncodeAppendEntriesRequest(args)) |
| 101 | + if err != nil { |
| 102 | + return err |
| 103 | + } |
| 104 | + *resp = *encoding.DecodeAppendEntriesResponse(ret) |
| 105 | + return nil |
33 | 106 | }
|
34 | 107 |
|
35 |
| -func (t *transport) RequestVote(id raft.ServerID, target raft.ServerAddress, args *raft.RequestVoteRequest, resp *raft.RequestVoteResponse) error { |
36 |
| - //TODO implement me |
37 |
| - panic("implement me") |
| 108 | +// RequestVote sends the appropriate RPC to the target node. |
| 109 | +func (t *Transport) RequestVote(id raft.ServerID, target raft.ServerAddress, args *raft.RequestVoteRequest, resp *raft.RequestVoteResponse) error { |
| 110 | + c, err := t.getPeer(target) |
| 111 | + if err != nil { |
| 112 | + return err |
| 113 | + } |
| 114 | + vote, err := c.RequestVote(context.TODO(), encoding.EncodeRequestVoteRequest(args)) |
| 115 | + if err != nil { |
| 116 | + return err |
| 117 | + } |
| 118 | + *resp = *encoding.DecodeRequestVoteResponse(vote) |
| 119 | + return nil |
38 | 120 | }
|
39 | 121 |
|
40 |
| -func (t *transport) InstallSnapshot(id raft.ServerID, target raft.ServerAddress, args *raft.InstallSnapshotRequest, resp *raft.InstallSnapshotResponse, data io.Reader) error { |
41 |
| - //TODO implement me |
42 |
| - panic("implement me") |
| 122 | +// InstallSnapshot is used to push a snapshot down to a follower. The data is read from |
| 123 | +// the ReadCloser and streamed to the client. |
| 124 | +func (t *Transport) InstallSnapshot(id raft.ServerID, target raft.ServerAddress, args *raft.InstallSnapshotRequest, resp *raft.InstallSnapshotResponse, data io.Reader) error { |
| 125 | + c, err := t.getPeer(target) |
| 126 | + if err != nil { |
| 127 | + return err |
| 128 | + } |
| 129 | + inSnap, err := c.InstallSnapshot(context.TODO(), encoding.EncodeInstallSnapshotRequest(args)) |
| 130 | + if err != nil { |
| 131 | + return err |
| 132 | + } |
| 133 | + |
| 134 | + *resp = *encoding.DecodeInstallSnapshotResponse(inSnap) |
| 135 | + return nil |
43 | 136 | }
|
44 | 137 |
|
45 |
| -func (t *transport) TimeoutNow(id raft.ServerID, target raft.ServerAddress, args *raft.TimeoutNowRequest, resp *raft.TimeoutNowResponse) error { |
46 |
| - //TODO implement me |
47 |
| - panic("implement me") |
| 138 | +// TimeoutNow is used to start a leadership transfer to the target node. |
| 139 | +func (t *Transport) TimeoutNow(id raft.ServerID, target raft.ServerAddress, args *raft.TimeoutNowRequest, resp *raft.TimeoutNowResponse) error { |
| 140 | + c, err := t.getPeer(target) |
| 141 | + if err != nil { |
| 142 | + return err |
| 143 | + } |
| 144 | + ret, err := c.TimeoutNow(context.TODO(), encoding.EncodeTimeoutNowRequest(args)) |
| 145 | + if err != nil { |
| 146 | + return err |
| 147 | + } |
| 148 | + *resp = *encoding.DecodeTimeoutNowResponse(ret) |
| 149 | + return nil |
48 | 150 | }
|
49 | 151 |
|
50 |
| -func (t *transport) Consumer() <-chan raft.RPC { |
51 |
| - //implement me |
52 |
| - panic("implement me") |
| 152 | +// Consumer returns a channel that can be used to |
| 153 | +// consume and respond to RPC requests. |
| 154 | +func (t *Transport) Consumer() <-chan raft.RPC { |
| 155 | + return t.consumer |
53 | 156 | }
|
54 | 157 |
|
55 |
| -func (t *transport) LocalAddr() raft.ServerAddress { |
56 |
| - //implement me |
57 |
| - panic("implement me") |
| 158 | +// LocalAddr is used to return our local address to distinguish from our peers. |
| 159 | +func (t *Transport) LocalAddr() raft.ServerAddress { |
| 160 | + return t.localAddr |
58 | 161 | }
|
59 | 162 |
|
60 |
| -func (t *transport) EncodePeer(id raft.ServerID, addr raft.ServerAddress) []byte { |
61 |
| - //implement me |
62 |
| - panic("implement me") |
| 163 | +// EncodePeer is used to serialize a peer's address. |
| 164 | +func (t *Transport) EncodePeer(id raft.ServerID, addr raft.ServerAddress) []byte { |
| 165 | + return []byte(addr) |
63 | 166 | }
|
64 | 167 |
|
65 |
| -func (t *transport) DecodePeer([]byte) raft.ServerAddress { |
66 |
| - //implement me |
67 |
| - panic("implement me") |
| 168 | +// DecodePeer is used to deserialize a peer's address. |
| 169 | +func (t *Transport) DecodePeer(p []byte) raft.ServerAddress { |
| 170 | + return raft.ServerAddress(p) |
68 | 171 | }
|
69 | 172 |
|
70 |
| -func (t *transport) SetHeartbeatHandler(handler func(rpc raft.RPC)) { |
71 |
| - //implement me |
72 |
| - panic("implement me") |
| 173 | +// SetHeartbeatHandler is used to setup a heartbeat handler |
| 174 | +// as a fast-pass. This is to avoid head-of-line blocking from |
| 175 | +// disk IO. If Transport does not support this, it can simply |
| 176 | +// ignore the call, and push the heartbeat onto the Consumer channel. |
| 177 | +func (t *Transport) SetHeartbeatHandler(handler func(rpc raft.RPC)) { |
| 178 | + t.RWMutex.RLock() |
| 179 | + defer t.RWMutex.RUnlock() |
| 180 | + t.heartbeatFn = handler |
| 181 | +} |
| 182 | + |
| 183 | +func (t *Transport) getPeer(target raft.ServerAddress) (raftPB.RaftServiceClient, error) { |
| 184 | + t.RWMutex.Lock() // Locking here |
| 185 | + defer t.RWMutex.Unlock() // Unlocking after the map access is done |
| 186 | + |
| 187 | + c, ok := t.clients[target] |
| 188 | + |
| 189 | + if !ok { |
| 190 | + c = &ClientConn{} |
| 191 | + c.mtx.Lock() |
| 192 | + defer c.mtx.Unlock() // We know that Lock was obtained and can use defer here |
| 193 | + |
| 194 | + t.clients[target] = c |
| 195 | + |
| 196 | + if c.conn == nil { |
| 197 | + conn, err := grpc.Dial(string(target), t.dialOptions...) |
| 198 | + if err != nil { |
| 199 | + return nil, err |
| 200 | + } |
| 201 | + c.conn = conn |
| 202 | + c.client = raftPB.NewRaftServiceClient(conn) |
| 203 | + } |
| 204 | + } |
| 205 | + |
| 206 | + return c.client, nil |
| 207 | +} |
| 208 | +func isHeartbeat(command interface{}) bool { |
| 209 | + req, ok := command.(*raft.AppendEntriesRequest) |
| 210 | + if !ok { |
| 211 | + return false |
| 212 | + } |
| 213 | + if req == nil { |
| 214 | + return false |
| 215 | + } |
| 216 | + return req.Term != 0 && |
| 217 | + len(req.Leader) != 0 && |
| 218 | + req.PrevLogEntry == 0 && |
| 219 | + req.PrevLogTerm == 0 && |
| 220 | + len(req.Entries) == 0 && |
| 221 | + req.LeaderCommitIndex == 0 |
| 222 | +} |
| 223 | + |
| 224 | +func (af *appendFuture) Error() error { |
| 225 | + <-af.done |
| 226 | + return af.err |
| 227 | +} |
| 228 | +func (af *appendFuture) Start() time.Time { |
| 229 | + return af.start |
| 230 | +} |
| 231 | + |
| 232 | +func (af *appendFuture) Request() *raft.AppendEntriesRequest { |
| 233 | + return af.request |
| 234 | +} |
| 235 | +func (af *appendFuture) Response() *raft.AppendEntriesResponse { |
| 236 | + return &af.response |
| 237 | +} |
| 238 | + |
| 239 | +// AppendEntries is used to add another request to the pipeline. |
| 240 | +// The send may block which is an effective form of back-pressure. |
| 241 | +func (r *raftPipeline) AppendEntries(req *raft.AppendEntriesRequest, resp *raft.AppendEntriesResponse) (raft.AppendFuture, error) { |
| 242 | + af := &appendFuture{ |
| 243 | + start: time.Now(), |
| 244 | + request: req, |
| 245 | + done: make(chan struct{}), |
| 246 | + } |
| 247 | + if err := r.stream.Send(encoding.EncodeAppendEntriesRequest(req)); err != nil { |
| 248 | + return nil, err |
| 249 | + } |
| 250 | + select { |
| 251 | + case <-r.stream.Context().Done(): |
| 252 | + return nil, r.stream.Context().Err() |
| 253 | + case r.inflightCh <- af: |
| 254 | + default: |
| 255 | + return nil, fmt.Errorf("failed to send request to inflightCh") |
| 256 | + } |
| 257 | + |
| 258 | + return af, nil |
| 259 | +} |
| 260 | + |
| 261 | +// Consumer returns a channel that can be used to consume |
| 262 | +// response futures when they are ready. |
| 263 | +func (r *raftPipeline) Consumer() <-chan raft.AppendFuture { |
| 264 | + return r.doneCh |
| 265 | +} |
| 266 | + |
| 267 | +// Close closes the pipeline and cancels all inflight RPCs |
| 268 | +func (r *raftPipeline) Close() error { |
| 269 | + r.cancel() |
| 270 | + r.inflightChMtx.Lock() |
| 271 | + defer r.inflightChMtx.Unlock() |
| 272 | + close(r.inflightCh) |
| 273 | + return nil |
| 274 | +} |
| 275 | + |
| 276 | +func (r *raftPipeline) receiver() { |
| 277 | + for af := range r.inflightCh { |
| 278 | + af.processMessage(r) |
| 279 | + } |
| 280 | +} |
| 281 | + |
| 282 | +// processMessage processes the appendFuture message. |
| 283 | +func (af *appendFuture) processMessage(r *raftPipeline) { |
| 284 | + msg, err := r.stream.Recv() |
| 285 | + if err != nil { |
| 286 | + af.err = err |
| 287 | + } else if msg != nil { |
| 288 | + af.response = *encoding.DecodeAppendEntriesResponse(msg) |
| 289 | + } |
| 290 | + close(af.done) |
| 291 | + r.doneCh <- af |
73 | 292 | }
|
0 commit comments