diff --git a/src/messages.h b/src/messages.h index 57ae6b6..85e982f 100644 --- a/src/messages.h +++ b/src/messages.h @@ -30,7 +30,12 @@ static_assert(sizeof(TMessage) == 8); struct TLogEntry: public TMessage { static constexpr EMessageType MessageType = EMessageType::LOG_ENTRY; + enum EFlags { + ENone = 0, + EStub = 1, + }; uint64_t Term = 1; + uint64_t Flags = 0; char Data[0]; }; @@ -87,18 +92,21 @@ struct TCommandRequest: public TMessage { EWrite = 1, }; uint32_t Flags = ENone; + uint32_t Cookie = 0; char Data[0]; }; -static_assert(sizeof(TCommandRequest) == sizeof(TMessage) + 4); +static_assert(sizeof(TCommandRequest) == sizeof(TMessage) + 8); struct TCommandResponse: public TMessage { static constexpr EMessageType MessageType = EMessageType::COMMAND_RESPONSE; - uint64_t Index; + uint64_t Index = 0; + uint32_t Cookie = 0; + uint32_t ErrorCode = 0; char Data[0]; }; -static_assert(sizeof(TCommandResponse) == sizeof(TMessage) + 8); +static_assert(sizeof(TCommandResponse) == sizeof(TMessage) + 16); struct TTimeout { static constexpr std::chrono::milliseconds Election = std::chrono::milliseconds(5000); diff --git a/src/raft.cpp b/src/raft.cpp index aed6e58..400a759 100644 --- a/src/raft.cpp +++ b/src/raft.cpp @@ -271,13 +271,25 @@ void TRaft::OnAppendEntries(TMessageHolder message) { } void TRaft::OnCommandRequest(TMessageHolder command, const std::shared_ptr& replyTo) { - if (command->Flags & TCommandRequest::EWrite) { - auto entry = Rsm->Prepare(command, State->CurrentTerm); - State->Append(std::move(entry)); - } - auto index = State->LastLogIndex; - if (replyTo) { - Waiting.emplace(TWaiting{index, std::move(command), replyTo}); + // TODO: move this logic to separate class + if (StateName == EState::LEADER) { + if (command->Flags & TCommandRequest::EWrite) { + auto entry = Rsm->Prepare(command, State->CurrentTerm); + State->Append(std::move(entry)); + } + auto index = State->LastLogIndex; + if (replyTo) { + Waiting.emplace(TWaiting{index, std::move(command), replyTo}); + } + } else if (StateName == EState::FOLLOWER && replyTo) { + if (command->Flags & TCommandRequest::EWrite) { + // TODO: send error code + replyTo->Send(NewHoldedMessage(TCommandResponse {.Index = 0})); + } else { + Waiting.emplace(TWaiting{State->LastLogIndex, std::move(command), replyTo}); + } + } else if (StateName == EState::CANDIDATE && replyTo) { + // wait } } @@ -338,16 +350,14 @@ void TRaft::Candidate(ITimeSource::Time now, TMessageHolder message) { } } -void TRaft::Leader(ITimeSource::Time now, TMessageHolder message, const std::shared_ptr& replyTo) { +void TRaft::Leader(ITimeSource::Time now, TMessageHolder message) { if (auto maybeAppendEntries = message.Maybe()) { OnAppendEntries(std::move(maybeAppendEntries.Cast())); - } else if (auto maybeCommandRequest = message.Maybe()) { - OnCommandRequest(std::move(maybeCommandRequest.Cast()), replyTo); } else if (auto maybeVoteRequest = message.Maybe()) { OnRequestVote(now, std::move(maybeVoteRequest.Cast())); } else if (auto maybeAppendEntries = message.Maybe()) { OnAppendEntries(now, std::move(maybeAppendEntries.Cast())); - } + } } void TRaft::Become(EState newStateName) { @@ -357,6 +367,11 @@ void TRaft::Become(EState newStateName) { } void TRaft::Process(ITimeSource::Time now, TMessageHolder message, const std::shared_ptr& replyTo) { + // client request + if (auto maybeCommandRequest = message.Maybe()) { + return OnCommandRequest(std::move(maybeCommandRequest.Cast()), replyTo); + } + if (message.IsEx()) { auto messageEx = message.Cast(); if (messageEx->Term > State->CurrentTerm) { @@ -369,6 +384,7 @@ void TRaft::Process(ITimeSource::Time now, TMessageHolder message, con } } } + switch (StateName) { case EState::FOLLOWER: Follower(now, std::move(message)); @@ -377,7 +393,7 @@ void TRaft::Process(ITimeSource::Time now, TMessageHolder message, con Candidate(now, std::move(message)); break; case EState::LEADER: - Leader(now, std::move(message), replyTo); + Leader(now, std::move(message)); break; default: throw std::logic_error("Unknown state"); @@ -387,7 +403,11 @@ void TRaft::Process(ITimeSource::Time now, TMessageHolder message, con void TRaft::ProcessCommitted() { auto commitIndex = VolatileState->CommitIndex; for (auto i = VolatileState->LastApplied+1; i <= commitIndex; i++) { - auto reply = Rsm->Write(State->Get(i-1), i); + auto entry = State->Get(i-1); + if (entry->Flags == TLogEntry::EStub) { + continue; + } + auto reply = Rsm->Write(entry, i); WriteAnswers.emplace(TAnswer { .Index = i, .Reply = reply ? reply : NewHoldedMessage(TCommandResponse {.Index = i}) @@ -422,6 +442,7 @@ void TRaft::FollowerTimeout(ITimeSource::Time now) { } ProcessCommitted(); + ProcessWaiting(); // For forwarded requests } void TRaft::CandidateTimeout(ITimeSource::Time now) { @@ -497,6 +518,7 @@ void TRaft::ProcessTimeout(ITimeSource::Time now) { { auto empty = NewHoldedMessage(); empty->Term = State->CurrentTerm; + empty->Flags = TLogEntry::EStub; State->Append(std::move(empty)); } } diff --git a/src/raft.h b/src/raft.h index c80c481..f068a52 100644 --- a/src/raft.h +++ b/src/raft.h @@ -120,12 +120,13 @@ class TRaft { private: void Candidate(ITimeSource::Time now, TMessageHolder message); void Follower(ITimeSource::Time now, TMessageHolder message); - void Leader(ITimeSource::Time now, TMessageHolder message, const std::shared_ptr& replyTo); + void Leader(ITimeSource::Time now, TMessageHolder message); void OnRequestVote(ITimeSource::Time now, TMessageHolder message); void OnRequestVote(TMessageHolder message); void OnAppendEntries(ITimeSource::Time now, TMessageHolder message); void OnAppendEntries(TMessageHolder message); + void OnCommandRequest(TMessageHolder message, const std::shared_ptr& replyTo); void LeaderTimeout(ITimeSource::Time now); diff --git a/test/test_raft.cpp b/test/test_raft.cpp index bd8268d..7d3454e 100644 --- a/test/test_raft.cpp +++ b/test/test_raft.cpp @@ -127,7 +127,7 @@ void test_message_create(void** state) { } void test_message_cast(void** state) { - TMessageHolder mes = NewHoldedMessage(16); + TMessageHolder mes = NewHoldedMessage(TLogEntry{}); auto casted = mes.Cast(); assert_true(mes.RawData == casted.RawData); assert_true(mes->Len == casted->Len);