Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions src/messages.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,12 @@ static_assert(sizeof(TMessage) == 8);

struct TLogEntry: public TMessage {
static constexpr EMessageType MessageType = EMessageType::LOG_ENTRY;
enum EFlags {
ENone = 0,
EStub = 1,
};
uint64_t Term = 1;
uint64_t Flags = 0;
char Data[0];
};

Expand Down Expand Up @@ -87,18 +92,21 @@ struct TCommandRequest: public TMessage {
EWrite = 1,
};
uint32_t Flags = ENone;
uint32_t Cookie = 0;
char Data[0];
};

static_assert(sizeof(TCommandRequest) == sizeof(TMessage) + 4);
static_assert(sizeof(TCommandRequest) == sizeof(TMessage) + 8);

struct TCommandResponse: public TMessage {
static constexpr EMessageType MessageType = EMessageType::COMMAND_RESPONSE;
uint64_t Index;
uint64_t Index = 0;
uint32_t Cookie = 0;
uint32_t ErrorCode = 0;
char Data[0];
};

static_assert(sizeof(TCommandResponse) == sizeof(TMessage) + 8);
static_assert(sizeof(TCommandResponse) == sizeof(TMessage) + 16);

struct TTimeout {
static constexpr std::chrono::milliseconds Election = std::chrono::milliseconds(5000);
Expand Down
48 changes: 35 additions & 13 deletions src/raft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -271,13 +271,25 @@ void TRaft::OnAppendEntries(TMessageHolder<TAppendEntriesResponse> message) {
}

void TRaft::OnCommandRequest(TMessageHolder<TCommandRequest> command, const std::shared_ptr<INode>& replyTo) {
if (command->Flags & TCommandRequest::EWrite) {
auto entry = Rsm->Prepare(command, State->CurrentTerm);
State->Append(std::move(entry));
}
auto index = State->LastLogIndex;
if (replyTo) {
Waiting.emplace(TWaiting{index, std::move(command), replyTo});
// TODO: move this logic to separate class
if (StateName == EState::LEADER) {
if (command->Flags & TCommandRequest::EWrite) {
auto entry = Rsm->Prepare(command, State->CurrentTerm);
State->Append(std::move(entry));
}
auto index = State->LastLogIndex;
if (replyTo) {
Waiting.emplace(TWaiting{index, std::move(command), replyTo});
}
} else if (StateName == EState::FOLLOWER && replyTo) {
if (command->Flags & TCommandRequest::EWrite) {
// TODO: send error code
replyTo->Send(NewHoldedMessage(TCommandResponse {.Index = 0}));
} else {
Waiting.emplace(TWaiting{State->LastLogIndex, std::move(command), replyTo});
}
} else if (StateName == EState::CANDIDATE && replyTo) {
// wait
}
}

Expand Down Expand Up @@ -338,16 +350,14 @@ void TRaft::Candidate(ITimeSource::Time now, TMessageHolder<TMessage> message) {
}
}

void TRaft::Leader(ITimeSource::Time now, TMessageHolder<TMessage> message, const std::shared_ptr<INode>& replyTo) {
void TRaft::Leader(ITimeSource::Time now, TMessageHolder<TMessage> message) {
if (auto maybeAppendEntries = message.Maybe<TAppendEntriesResponse>()) {
OnAppendEntries(std::move(maybeAppendEntries.Cast()));
} else if (auto maybeCommandRequest = message.Maybe<TCommandRequest>()) {
OnCommandRequest(std::move(maybeCommandRequest.Cast()), replyTo);
} else if (auto maybeVoteRequest = message.Maybe<TRequestVoteRequest>()) {
OnRequestVote(now, std::move(maybeVoteRequest.Cast()));
} else if (auto maybeAppendEntries = message.Maybe<TAppendEntriesRequest>()) {
OnAppendEntries(now, std::move(maybeAppendEntries.Cast()));
}
}
}

void TRaft::Become(EState newStateName) {
Expand All @@ -357,6 +367,11 @@ void TRaft::Become(EState newStateName) {
}

void TRaft::Process(ITimeSource::Time now, TMessageHolder<TMessage> message, const std::shared_ptr<INode>& replyTo) {
// client request
if (auto maybeCommandRequest = message.Maybe<TCommandRequest>()) {
return OnCommandRequest(std::move(maybeCommandRequest.Cast()), replyTo);
}

if (message.IsEx()) {
auto messageEx = message.Cast<TMessageEx>();
if (messageEx->Term > State->CurrentTerm) {
Expand All @@ -369,6 +384,7 @@ void TRaft::Process(ITimeSource::Time now, TMessageHolder<TMessage> message, con
}
}
}

switch (StateName) {
case EState::FOLLOWER:
Follower(now, std::move(message));
Expand All @@ -377,7 +393,7 @@ void TRaft::Process(ITimeSource::Time now, TMessageHolder<TMessage> message, con
Candidate(now, std::move(message));
break;
case EState::LEADER:
Leader(now, std::move(message), replyTo);
Leader(now, std::move(message));
break;
default:
throw std::logic_error("Unknown state");
Expand All @@ -387,7 +403,11 @@ void TRaft::Process(ITimeSource::Time now, TMessageHolder<TMessage> message, con
void TRaft::ProcessCommitted() {
auto commitIndex = VolatileState->CommitIndex;
for (auto i = VolatileState->LastApplied+1; i <= commitIndex; i++) {
auto reply = Rsm->Write(State->Get(i-1), i);
auto entry = State->Get(i-1);
if (entry->Flags == TLogEntry::EStub) {
continue;
}
auto reply = Rsm->Write(entry, i);
WriteAnswers.emplace(TAnswer {
.Index = i,
.Reply = reply ? reply : NewHoldedMessage(TCommandResponse {.Index = i})
Expand Down Expand Up @@ -422,6 +442,7 @@ void TRaft::FollowerTimeout(ITimeSource::Time now) {
}

ProcessCommitted();
ProcessWaiting(); // For forwarded requests
}

void TRaft::CandidateTimeout(ITimeSource::Time now) {
Expand Down Expand Up @@ -497,6 +518,7 @@ void TRaft::ProcessTimeout(ITimeSource::Time now) {
{
auto empty = NewHoldedMessage<TLogEntry>();
empty->Term = State->CurrentTerm;
empty->Flags = TLogEntry::EStub;
State->Append(std::move(empty));
}
}
Expand Down
3 changes: 2 additions & 1 deletion src/raft.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,12 +120,13 @@ class TRaft {
private:
void Candidate(ITimeSource::Time now, TMessageHolder<TMessage> message);
void Follower(ITimeSource::Time now, TMessageHolder<TMessage> message);
void Leader(ITimeSource::Time now, TMessageHolder<TMessage> message, const std::shared_ptr<INode>& replyTo);
void Leader(ITimeSource::Time now, TMessageHolder<TMessage> message);

void OnRequestVote(ITimeSource::Time now, TMessageHolder<TRequestVoteRequest> message);
void OnRequestVote(TMessageHolder<TRequestVoteResponse> message);
void OnAppendEntries(ITimeSource::Time now, TMessageHolder<TAppendEntriesRequest> message);
void OnAppendEntries(TMessageHolder<TAppendEntriesResponse> message);

void OnCommandRequest(TMessageHolder<TCommandRequest> message, const std::shared_ptr<INode>& replyTo);

void LeaderTimeout(ITimeSource::Time now);
Expand Down
2 changes: 1 addition & 1 deletion test/test_raft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ void test_message_create(void** state) {
}

void test_message_cast(void** state) {
TMessageHolder<TMessage> mes = NewHoldedMessage<TLogEntry>(16);
TMessageHolder<TMessage> mes = NewHoldedMessage(TLogEntry{});
auto casted = mes.Cast<TLogEntry>();
assert_true(mes.RawData == casted.RawData);
assert_true(mes->Len == casted->Len);
Expand Down
Loading