Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions examples/kv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ TMessageHolder<TLogEntry> TKv::Prepare(TMessageHolder<TCommandRequest> command)
}

template<typename TPoller, typename TSocket>
NNet::TFuture<void> Client(TPoller& poller, TSocket socket) {
NNet::TFuture<void> Client(TPoller& poller, TSocket socket, uint32_t flags) {
using TFileHandle = typename TPoller::TFileHandle;
TFileHandle input{0, poller}; // stdin
co_await socket.Connect();
Expand Down Expand Up @@ -112,11 +112,14 @@ NNet::TFuture<void> Client(TPoller& poller, TSocket socket) {
auto key = strtok(nullptr, sep);
auto size = strlen(key);
auto mes = NewHoldedMessage<TReadKv>(sizeof(TReadKv) + size);
mes->Flags = flags;
mes->KeySize = size;
memcpy(mes->Data, key, size);
req = mes;
} else if (!strcmp(prefix, "list")) {
req = NewHoldedMessage<TReadKv>(sizeof(TReadKv));
auto mes = NewHoldedMessage<TReadKv>(sizeof(TReadKv));
mes->Flags = flags;
req = mes;
} else if (!strcmp(prefix, "del")) {
auto key = strtok(nullptr, sep);
auto size = strlen(key);
Expand Down Expand Up @@ -145,7 +148,7 @@ NNet::TFuture<void> Client(TPoller& poller, TSocket socket) {
}

// Writes the command-line synopsis to stderr and terminates the process with
// exit status 0.
//
// prog: program name, emitted verbatim in front of the option list.
//
// NOTE(review): nothing separates `prog` from the leading "--client" in the
// output, so the two run together — confirm whether a space is intended.
// NOTE(review): two synopsis statements follow that differ only in the trailing
// "[--stale] [--consistent]"; this looks like the before/after pair of a diff
// hunk — confirm which one should remain.
void usage(const char* prog) {
std::cerr << prog << "--client|--server --id myid --node ip:port:id [--node ip:port:id ...] [--persist]" << "\n";
std::cerr << prog << "--client|--server --id myid --node ip:port:id [--node ip:port:id ...] [--persist] [--stale] [--consistent]" << "\n";
exit(0);
}

Expand All @@ -157,6 +160,7 @@ int main(int argc, char** argv) {
uint32_t id = 0;
bool server = false;
bool persist = false;
uint32_t flags = 0;
for (int i = 1; i < argc; i++) {
if (!strcmp(argv[i], "--server")) {
server = true;
Expand All @@ -167,6 +171,10 @@ int main(int argc, char** argv) {
id = atoi(argv[++i]);
} else if (!strcmp(argv[i], "--persist")) {
persist = true;
} else if (!strcmp(argv[i], "--stale")) {
flags |= TCommandRequest::EStale;
} else if (!strcmp(argv[i], "--consistent")) {
flags |= TCommandRequest::EConsistent;
} else if (!strcmp(argv[i], "--help")) {
usage(argv[0]);
}
Expand Down Expand Up @@ -212,7 +220,7 @@ int main(int argc, char** argv) {
NNet::TAddress addr{hosts[0].Address, hosts[0].Port};
NNet::TSocket socket(std::move(addr), loop.Poller());

auto h = Client(loop.Poller(), std::move(socket));
auto h = Client(loop.Poller(), std::move(socket), flags);
while (!h.done()) {
loop.Step();
}
Expand Down
15 changes: 11 additions & 4 deletions src/messages.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@ enum class EMessageType : uint32_t {
REQUEST_VOTE_RESPONSE = 3,
APPEND_ENTRIES_REQUEST = 4,
APPEND_ENTRIES_RESPONSE = 5,
COMMAND_REQUEST = 6,
COMMAND_RESPONSE = 7,
INSTALL_SNAPSHOT_REQUEST = 6, // TODO: not implemented it
INSTALL_SNAPSHOT_RESPONSE = 7, // TODO: not implemented it
COMMAND_REQUEST = 8,
COMMAND_RESPONSE = 9,
};

struct TMessage {
Expand All @@ -43,9 +45,10 @@ struct TMessageEx: public TMessage {
uint32_t Src = 0;
uint32_t Dst = 0;
uint64_t Term = 0;
uint64_t Seqno = 0;
};

static_assert(sizeof(TMessageEx) == sizeof(TMessage)+16);
static_assert(sizeof(TMessageEx) == sizeof(TMessage)+24);

struct TRequestVoteRequest: public TMessageEx {
static constexpr EMessageType MessageType = EMessageType::REQUEST_VOTE_REQUEST;
Expand Down Expand Up @@ -89,7 +92,11 @@ struct TCommandRequest: public TMessage {
static constexpr EMessageType MessageType = EMessageType::COMMAND_REQUEST;
enum EFlags {
ENone = 0,
EWrite = 1,
EWrite = 1, //

// read semantics, default: read from leader w/o ping check, possible stale reads if there are 2 leaders
EStale = 2, // stale read, can read from follower
EConsistent = 4, // strong consistent read (wait for pings, low latency, no stale read)
};
uint32_t Flags = ENone;
uint32_t Cookie = 0;
Expand Down
123 changes: 88 additions & 35 deletions src/raft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ TVolatileState& TVolatileState::Vote(uint32_t nodeId)
return *this;
}

TVolatileState& TVolatileState::CommitAdvance(int nservers, const IState& state)
TVolatileState& TVolatileState::CommitAdvance(int nservers, const IState& state, uint64_t seqno)
{
auto lastIndex = state.LastLogIndex;
Indices.clear(); Indices.reserve(nservers);
Expand All @@ -82,9 +82,9 @@ TVolatileState& TVolatileState::CommitAdvance(int nservers, const IState& state)
std::sort(Indices.begin(), Indices.end());
auto commitIndex = std::max(CommitIndex, Indices[nservers / 2]);
if (state.LogTerm(commitIndex) == state.CurrentTerm) {
CommitSeqno = std::max(CommitSeqno, seqno);
CommitIndex = commitIndex;
}
// TODO: If state.LogTerm(commitIndex) < state.CurrentTerm need to append empty message to log
return *this;
}

Expand Down Expand Up @@ -196,6 +196,7 @@ void TRaft::OnAppendEntries(ITimeSource::Time now, TMessageHolder<TAppendEntries
.Src = Id,
.Dst = message->Src,
.Term = State->CurrentTerm,
.Seqno = message->Seqno,
},
TAppendEntriesResponse {
.MatchIndex = 0,
Expand Down Expand Up @@ -234,7 +235,7 @@ void TRaft::OnAppendEntries(ITimeSource::Time now, TMessageHolder<TAppendEntries
}

auto reply = NewHoldedMessage(
TMessageEx {.Src = Id, .Dst = message->Src, .Term = State->CurrentTerm},
TMessageEx {.Src = Id, .Dst = message->Src, .Term = State->CurrentTerm, .Seqno = message->Seqno},
TAppendEntriesResponse {.MatchIndex = matchIndex, .Success = success});

(*VolatileState)
Expand All @@ -260,7 +261,7 @@ void TRaft::OnAppendEntries(TMessageHolder<TAppendEntriesResponse> message) {
.SetRpcDue(nodeId, ITimeSource::Time{})
.SetBatchSize(nodeId, 1024)
.SetBackOff(nodeId, 1)
.CommitAdvance(Nservers, *State);
.CommitAdvance(Nservers, *State, message->Seqno);
} else {
auto backOff = std::max(VolatileState->BackOff[nodeId], 1);
auto nextIndex = VolatileState->NextIndex[nodeId] > backOff
Expand Down Expand Up @@ -294,7 +295,7 @@ TMessageHolder<TAppendEntriesRequest> TRaft::CreateAppendEntries(uint32_t nodeId
}

auto mes = NewHoldedMessage(
TMessageEx {.Src = Id, .Dst = nodeId, .Term = State->CurrentTerm},
TMessageEx {.Src = Id, .Dst = nodeId, .Term = State->CurrentTerm, .Seqno = Seqno++},
TAppendEntriesRequest {
.PrevLogIndex = prevIndex,
.PrevLogTerm = State->LogTerm(prevIndex),
Expand Down Expand Up @@ -338,7 +339,7 @@ void TRaft::Leader(ITimeSource::Time now, TMessageHolder<TMessage> message) {
OnRequestVote(now, std::move(maybeVoteRequest.Cast()));
} else if (auto maybeAppendEntries = message.Maybe<TAppendEntriesRequest>()) {
OnAppendEntries(now, std::move(maybeAppendEntries.Cast()));
}
}
}

void TRaft::Become(EState newStateName) {
Expand Down Expand Up @@ -408,6 +409,18 @@ void TRaft::LeaderTimeout(ITimeSource::Time now) {
}
}

// Sends an AppendEntries request to every node in Nodes and returns the value
// Seqno held before any of those requests was created.
//
// Returns: the pre-broadcast sequence number. The request processor stores it
// with a pending consistent read and serves that read once CommitSeqno() has
// reached it.
uint64_t TRaft::ApproveRead() {
    // Seqno is a 64-bit counter; snapshotting it into an `int` would truncate
    // it (and sign-extend on return) once it grows past INT_MAX.
    const uint64_t seqno = Seqno;
    // CreateAppendEntries stamps each request with Seqno and then increments
    // it, so every request sent here carries a value >= seqno.
    for (auto& [id, node] : Nodes) {
        node->Send(CreateAppendEntries(id));
    }
    return seqno;
}

uint64_t TRaft::CommitSeqno() const {
return VolatileState->CommitSeqno;
}

void TRaft::ProcessTimeout(ITimeSource::Time now) {
if (StateName == EState::CANDIDATE || StateName == EState::FOLLOWER) {
if (VolatileState->ElectionDue <= now) {
Expand Down Expand Up @@ -508,33 +521,9 @@ void TRequestProcessor::CheckStateChange() {
}
}

void TRequestProcessor::OnCommandRequest(TMessageHolder<TCommandRequest> command, const std::shared_ptr<INode>& replyTo) {
auto stateName = Raft->CurrentStateName();
auto leaderId = Raft->GetLeaderId();

// read request
if (! (command->Flags & TCommandRequest::EWrite)) {
if (replyTo) {
// TODO: possible stale read, should use max(LastIndex, LeaderLastIndex)
assert(Waiting.empty() || Waiting.back().Index <= Raft->GetLastIndex());
Waiting.emplace(TWaiting{Raft->GetLastIndex(), std::move(command), replyTo});
}
return;
}

// write request
if (stateName == EState::LEADER) {
auto index = Raft->Append(std::move(Rsm->Prepare(command)));
if (replyTo) {
assert(Waiting.empty() || Waiting.back().Index <= index);
Waiting.emplace(TWaiting{index, std::move(command), replyTo});
}
return;
}

// forwarding write request
void TRequestProcessor::Forward(TMessageHolder<TCommandRequest> command, const std::shared_ptr<INode>& replyTo)
{
if (!replyTo) {
// nothing to forward
return;
}

Expand All @@ -543,9 +532,11 @@ void TRequestProcessor::OnCommandRequest(TMessageHolder<TCommandRequest> command
replyTo->Send(NewHoldedMessage(TCommandResponse{.Cookie = command->Cookie, .ErrorCode = 1}));
return;
}


auto stateName = Raft->CurrentStateName();
auto leaderId = Raft->GetLeaderId();
if (stateName == EState::CANDIDATE || leaderId == 0) {
WaitingStateChange.emplace(TWaiting{0, std::move(command), replyTo});
WaitingStateChange.emplace(TWaiting{0, 0, std::move(command), replyTo});
return;
}

Expand All @@ -563,6 +554,56 @@ void TRequestProcessor::OnCommandRequest(TMessageHolder<TCommandRequest> command
assert(false && "Wrong state");
}

// Routes a read (non-EWrite) command according to its flags and this node's
// current Raft state.
//
//  - EStale set, or no EConsistent flag while this node is leader: the command
//    is queued in Waiting at the current last log index.
//  - Otherwise, if this node is not the leader: the command is handed to
//    Forward().
//  - Otherwise (consistent read on the leader): Raft->ApproveRead() is called
//    and the command is queued in StrongWaiting together with the sequence
//    number it returned.
//
// command: the read request; ownership moves into the queue or into Forward().
// replyTo: node the answer must be sent to. When it is null the request is
//          dropped, because nothing could receive the reply.
void TRequestProcessor::OnReadRequest(TMessageHolder<TCommandRequest> command, const std::shared_ptr<INode>& replyTo)
{
    auto stateName = Raft->CurrentStateName();
    auto flags = command->Flags;
    assert(!(flags & TCommandRequest::EWrite));

    // ProcessWaiting() dereferences ReplyTo unconditionally, so a queued entry
    // without a recipient must never be created. OnWriteRequest() and
    // Forward() apply the same guard.
    if (!replyTo) {
        return;
    }

    const bool isLeader = stateName == EState::LEADER;
    const bool staleAllowed = (flags & TCommandRequest::EStale) != 0;
    const bool consistentWanted = (flags & TCommandRequest::EConsistent) != 0;

    // stale read, default read (from leader)
    if (staleAllowed || (!consistentWanted && isLeader)) {
        const auto lastIndex = Raft->GetLastIndex();
        assert(Waiting.empty() || Waiting.back().Index <= lastIndex);
        Waiting.emplace(TWaiting{lastIndex, 0, std::move(command), replyTo});
        return;
    }

    if (!isLeader) {
        Forward(std::move(command), replyTo);
        return;
    }

    // Consistent read
    const auto seqno = Raft->ApproveRead();
    const auto lastIndex = Raft->GetLastIndex();
    // Entries must be enqueued in non-decreasing (Index, Seqno) order.
    assert(StrongWaiting.empty() || (StrongWaiting.back().Index <= lastIndex && StrongWaiting.back().Seqno <= seqno));
    StrongWaiting.emplace(TWaiting{lastIndex, seqno, std::move(command), replyTo});
}

// Handles a command carrying the EWrite flag.
//
// On the leader, the command is turned into a log entry via Rsm->Prepare() and
// appended through Raft->Append(); if replyTo is set, the command is then
// queued in Waiting at the returned index. On any other state the command is
// passed to Forward().
//
// command: the write request.
// replyTo: node expecting the answer; may be null, in which case the entry is
//          still appended on the leader but nothing is queued for a reply.
void TRequestProcessor::OnWriteRequest(TMessageHolder<TCommandRequest> command, const std::shared_ptr<INode>& replyTo) {
    const bool isLeader = Raft->CurrentStateName() == EState::LEADER;
    assert((command->Flags & TCommandRequest::EWrite));

    if (!isLeader) {
        Forward(std::move(command), replyTo);
        return;
    }

    auto entry = Rsm->Prepare(command);
    const uint64_t logIndex = Raft->Append(std::move(entry));
    if (!replyTo) {
        return;
    }

    // Entries must be enqueued in non-decreasing index order.
    assert(Waiting.empty() || Waiting.back().Index <= logIndex);
    // TODO: cleanup these queues on state change
    Waiting.emplace(TWaiting{logIndex, 0, std::move(command), replyTo});
}

// Entry point for client commands: dispatches on the EWrite flag to
// OnWriteRequest() or OnReadRequest(), passing command and replyTo through.
void TRequestProcessor::OnCommandRequest(TMessageHolder<TCommandRequest> command, const std::shared_ptr<INode>& replyTo) {
    const bool isWrite = (command->Flags & TCommandRequest::EWrite) != 0;

    if (isWrite) {
        OnWriteRequest(std::move(command), replyTo);
        return;
    }
    OnReadRequest(std::move(command), replyTo);
}

void TRequestProcessor::OnCommandResponse(TMessageHolder<TCommandResponse> command) {
// forwarded
auto it = Cookie2Client.find(command->Cookie);
Expand Down Expand Up @@ -611,6 +652,7 @@ void TRequestProcessor::ProcessWaiting() {
while (!Waiting.empty() && Waiting.back().Index <= lastApplied) {
auto w = Waiting.back(); Waiting.pop();
TMessageHolder<TCommandResponse> reply;
auto cookie = w.Command->Cookie;;
if (w.Command->Flags & TCommandRequest::EWrite) {
while (!WriteAnswers.empty() && WriteAnswers.front().Index < w.Index) {
WriteAnswers.pop();
Expand All @@ -622,7 +664,18 @@ void TRequestProcessor::ProcessWaiting() {
} else {
reply = Rsm->Read(std::move(w.Command), w.Index).Cast<TCommandResponse>();
}
reply->Cookie = w.Command->Cookie;
reply->Cookie = cookie;
w.ReplyTo->Send(std::move(reply));
}

auto seqno = Raft->CommitSeqno();
while (!StrongWaiting.empty() && StrongWaiting.back().Index <= lastApplied && StrongWaiting.back().Seqno <= seqno) {
auto w = StrongWaiting.back(); StrongWaiting.pop();
TMessageHolder<TCommandResponse> reply;
assert (!(w.Command->Flags & TCommandRequest::EWrite));
auto cookie = w.Command->Cookie;
reply = Rsm->Read(std::move(w.Command), w.Index).Cast<TCommandResponse>();
reply->Cookie = cookie;
w.ReplyTo->Send(std::move(reply));
}
}
Expand Down
12 changes: 11 additions & 1 deletion src/raft.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ using TNodeDict = std::unordered_map<uint32_t, std::shared_ptr<INode>>;

struct TVolatileState {
uint64_t CommitIndex = 0;
uint64_t CommitSeqno = 0;
uint32_t LeaderId = 0;
std::unordered_map<uint32_t, uint64_t> NextIndex;
std::unordered_map<uint32_t, uint64_t> MatchIndex;
Expand All @@ -54,7 +55,7 @@ struct TVolatileState {
std::vector<uint64_t> Indices;

TVolatileState& Vote(uint32_t id);
TVolatileState& CommitAdvance(int nservers, const IState& state);
TVolatileState& CommitAdvance(int nservers, const IState& state, uint64_t seqno = 0);
TVolatileState& SetCommitIndex(int index);
TVolatileState& SetElectionDue(ITimeSource::Time);
TVolatileState& SetNextIndex(uint32_t id, uint64_t nextIndex);
Expand Down Expand Up @@ -89,6 +90,8 @@ class TRaft {
uint64_t Append(TMessageHolder<TLogEntry> entry);
uint32_t GetLeaderId() const;
uint64_t GetLastIndex() const;
uint64_t ApproveRead();
uint64_t CommitSeqno() const;

// ut
const auto& GetState() const {
Expand Down Expand Up @@ -146,6 +149,7 @@ class TRaft {
int Nservers;
std::shared_ptr<IState> State;
std::unique_ptr<TVolatileState> VolatileState;
uint64_t Seqno = 0; // for matching responses

EState StateName;
uint32_t Seed = 31337;
Expand All @@ -167,16 +171,22 @@ class TRequestProcessor {
void CleanUp(const std::shared_ptr<INode>& replyTo);

private:
void Forward(TMessageHolder<TCommandRequest> message, const std::shared_ptr<INode>& replyTo);
void OnReadRequest(TMessageHolder<TCommandRequest> message, const std::shared_ptr<INode>& replyTo);
void OnWriteRequest(TMessageHolder<TCommandRequest> message, const std::shared_ptr<INode>& replyTo);

std::shared_ptr<TRaft> Raft;
std::shared_ptr<IRsm> Rsm;
TNodeDict Nodes;

struct TWaiting {
uint64_t Index;
uint64_t Seqno = 0;
TMessageHolder<TCommandRequest> Command;
std::shared_ptr<INode> ReplyTo;
};
std::queue<TWaiting> Waiting;
std::queue<TWaiting> StrongWaiting;
std::queue<TWaiting> WaitingStateChange;

struct TAnswer {
Expand Down
6 changes: 4 additions & 2 deletions src/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ NNet::TVoidTask TRaftServer<TSocket>::InboundConnection(TSocket socket) {
Nodes.insert(client);
while (true) {
auto mes = co_await TMessageReader(client->Sock()).Read();
// client request
// client request
if (auto maybeCommandRequest = mes.template Maybe<TCommandRequest>()) {
RequestProcessor->OnCommandRequest(std::move(maybeCommandRequest.Cast()), client);
} else if (auto maybeCommandResponse = mes.template Maybe<TCommandResponse>()) {
Expand Down Expand Up @@ -170,10 +170,12 @@ NNet::TVoidTask TRaftServer<TSocket>::OutboundServe(std::shared_ptr<TNode<TSocke
while (true) {
bool error = false;
try {
if (!node->IsConnected()) {
throw std::runtime_error("Not connected");
}
auto mes = co_await TMessageReader(node->Sock()).Read();
if (auto maybeCommandResponse = mes.template Maybe<TCommandResponse>()) {
RequestProcessor->OnCommandResponse(std::move(maybeCommandResponse.Cast()));
RequestProcessor->ProcessWaiting();
DrainNodes();
} else {
std::cerr << "Wrong message type: " << mes->Type << std::endl;
Expand Down
Loading