Skip to content

rebroadcast empty votes until a round advances #154

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 10 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 28 additions & 4 deletions epoch.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,10 @@ const (
DefaultMaxRoundWindow = 10
DefaultMaxPendingBlocks = 20

DefaultMaxProposalWaitTime = 5 * time.Second
DefaultReplicationRequestTimeout = 5 * time.Second
DefaultMaxProposalWaitTime = 5 * time.Second
DefaultReplicationRequestTimeout = 5 * time.Second
DefaultEmptyVoteRebroadcastTimeout = 5 * time.Second
EmptyVoteTimeoutID = "rebroadcast_empty_vote"
)

type EmptyVoteSet struct {
Expand Down Expand Up @@ -54,6 +56,7 @@ func NewRound(block VerifiedBlock) *Round {

type EpochConfig struct {
MaxProposalWait time.Duration
MaxRebrodcastWait time.Duration
QCDeserializer QCDeserializer
Logger Logger
ID NodeID
Expand Down Expand Up @@ -92,8 +95,8 @@ type Epoch struct {
monitor *Monitor
haltedError error
cancelWaitForBlockNotarization context.CancelFunc

replicationState *ReplicationState
timeoutHandler *TimeoutHandler
replicationState *ReplicationState
}

func NewEpoch(conf EpochConfig) (*Epoch, error) {
Expand All @@ -107,6 +110,7 @@ func NewEpoch(conf EpochConfig) (*Epoch, error) {
func (e *Epoch) AdvanceTime(t time.Time) {
e.monitor.AdvanceTime(t)
e.replicationState.AdvanceTime(t)
e.timeoutHandler.Tick(t)
}

// HandleMessage notifies the engine about a reception of a message.
Expand Down Expand Up @@ -176,6 +180,7 @@ func (e *Epoch) init() error {
e.eligibleNodeIDs = make(map[string]struct{}, len(e.nodes))
e.futureMessages = make(messagesFromNode, len(e.nodes))
e.replicationState = NewReplicationState(e.Logger, e.Comm, e.ID, e.maxRoundWindow, e.ReplicationEnabled, e.StartTime)
e.timeoutHandler = NewTimeoutHandler(e.Logger, e.StartTime, e.nodes)

for _, node := range e.nodes {
e.futureMessages[string(node)] = make(map[uint64]*messagesForRound)
Expand Down Expand Up @@ -1955,12 +1960,29 @@ func (e *Epoch) triggerProposalWaitTimeExpired(round uint64) {

e.Comm.Broadcast(&Message{EmptyVoteMessage: &signedEV})

e.addEmptyVoteRebroadcastTimeout(&signedEV)

if err := e.maybeAssembleEmptyNotarization(); err != nil {
e.Logger.Error("Failed assembling empty notarization", zap.Error(err))
e.haltedError = err
}
}

// addEmptyVoteRebroadcastTimeout schedules a timeout task that rebroadcasts
// the given empty vote once MaxRebrodcastWait elapses. The task re-arms itself
// after every broadcast, so the vote keeps being rebroadcast until the task is
// removed when the round advances.
func (e *Epoch) addEmptyVoteRebroadcastTimeout(vote *EmptyVote) {
	deadline := e.timeoutHandler.GetTime().Add(e.EpochConfig.MaxRebrodcastWait)

	e.timeoutHandler.AddTask(&TimeoutTask{
		NodeID:   e.ID,
		TaskID:   EmptyVoteTimeoutID,
		Deadline: deadline,
		Task: func() {
			e.Logger.Debug("Rebroadcasting empty vote because round has not advanced", zap.Uint64("round", vote.Vote.Round))
			e.Comm.Broadcast(&Message{EmptyVoteMessage: vote})
			// Re-arm so the rebroadcast repeats until the round advances.
			e.addEmptyVoteRebroadcastTimeout(vote)
		},
	})
}

func (e *Epoch) monitorProgress(round uint64) {
e.Logger.Debug("Monitoring progress", zap.Uint64("round", round))
ctx, cancelContext := context.WithCancel(context.Background())
Expand Down Expand Up @@ -2103,6 +2125,8 @@ func (e *Epoch) increaseRound() {
// we advanced to the next round.
e.cancelWaitForBlockNotarization()

// remove the rebroadcast empty vote task
e.timeoutHandler.RemoveTask(e.ID, EmptyVoteTimeoutID)
e.deleteEmptyVoteForPreviousRound()

leader := LeaderForRound(e.nodes, e.round)
Expand Down
113 changes: 113 additions & 0 deletions epoch_failover_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package simplex_test
import (
"context"
"fmt"
"simplex"
. "simplex"
"simplex/testutil"
"sync/atomic"
Expand Down Expand Up @@ -765,6 +766,118 @@ func TestEpochLeaderFailoverNotNeeded(t *testing.T) {
require.False(t, timedOut.Load())
}

// rebroadcastCounterComm is a Communication stub that counts broadcast empty
// votes and remembers the round of the most recently broadcast one.
type rebroadcastCounterComm struct {
	nodes           []NodeID
	count           atomic.Uint64
	mostRecentRound atomic.Uint64
}

// ListNodes returns the fixed node set the stub was configured with.
func (r *rebroadcastCounterComm) ListNodes() []NodeID {
	return r.nodes
}

// SendMessage is a no-op; the tests only observe broadcasts.
func (r *rebroadcastCounterComm) SendMessage(*Message, NodeID) {
}

// Broadcast records empty-vote broadcasts, ignoring all other message types.
func (r *rebroadcastCounterComm) Broadcast(msg *Message) {
	ev := msg.EmptyVoteMessage
	if ev == nil {
		return
	}
	r.mostRecentRound.Store(ev.Vote.Round)
	r.count.Add(1)
}

// TestEpochRebroadcastsEmptyVote verifies that a node which timed out waiting
// for a proposal keeps rebroadcasting its empty vote on every rebroadcast
// interval, stops once the round advances via an empty notarization, and
// resumes rebroadcasting for the next round that times out.
func TestEpochRebroadcastsEmptyVote(t *testing.T) {
	l := testutil.MakeLogger(t, 2)
	bb := &testBlockBuilder{out: make(chan *testBlock, 1), blockShouldBeBuilt: make(chan struct{}, 1)}
	storage := newInMemStorage()

	nodes := []NodeID{{1}, {2}, {3}, {4}}

	wal := newTestWAL(t)

	epochTime := time.Now()
	comm := &rebroadcastCounterComm{
		nodes: nodes,
	}
	conf := EpochConfig{
		MaxProposalWait:     DefaultMaxProposalWaitTime,
		MaxRebrodcastWait:   500 * time.Millisecond,
		StartTime:           epochTime,
		Logger:              l,
		ID:                  nodes[3], // so we are not the leader
		Signer:              &testSigner{},
		WAL:                 wal,
		Verifier:            &testVerifier{},
		Storage:             storage,
		Comm:                comm,
		BlockBuilder:        bb,
		SignatureAggregator: &testSignatureAggregator{},
	}

	e, err := NewEpoch(conf)
	require.NoError(t, err)

	require.NoError(t, e.Start())
	require.Equal(t, uint64(0), e.Metadata().Round)
	require.False(t, wal.containsEmptyVote(0))

	// Let the proposal timer expire so we write and broadcast an empty vote.
	bb.blockShouldBeBuilt <- struct{}{}
	time.Sleep(10 * time.Millisecond)
	epochTime = epochTime.Add(DefaultMaxProposalWaitTime)
	e.AdvanceTime(epochTime)
	wal.assertEmptyVote(0)
	wal.assertWALSize(1)
	require.Equal(t, uint64(1), comm.count.Load())

	// reset to get rebroadcast count
	comm.count.Store(0)
	for i := range 10 {
		epochTime = epochTime.Add(e.MaxRebrodcastWait)
		e.AdvanceTime(epochTime)
		time.Sleep(10 * time.Millisecond)
		require.Equal(t, uint64(i+1), comm.count.Load())
		require.Equal(t, uint64(0), comm.mostRecentRound.Load())
		// The WAL should not grow: rebroadcasts reuse the persisted vote.
		wal.assertWALSize(1)
	}

	// Advance the round with an empty notarization for round 0.
	emptyNotarization := newEmptyNotarization(nodes, 0, 0)
	e.HandleMessage(&simplex.Message{
		EmptyNotarization: emptyNotarization,
	}, nodes[2])

	wal.assertNotarization(0)

	comm.count.Store(0)
	// ensure the rebroadcast was canceled
	for range 10 {
		epochTime = epochTime.Add(e.MaxRebrodcastWait)
		e.AdvanceTime(epochTime)
		time.Sleep(10 * time.Millisecond)
		require.Equal(t, uint64(0), comm.count.Load())
	}

	// assert that we continue to rebroadcast, but for a different round now
	bb.blockShouldBeBuilt <- struct{}{}
	time.Sleep(10 * time.Millisecond)
	epochTime = epochTime.Add(DefaultMaxProposalWaitTime)
	e.AdvanceTime(epochTime)
	wal.assertEmptyVote(1)
	wal.assertWALSize(3)

	// reset to get rebroadcast count
	comm.count.Store(0)
	for i := range 10 {
		epochTime = epochTime.Add(e.MaxRebrodcastWait)
		e.AdvanceTime(epochTime)
		time.Sleep(10 * time.Millisecond)
		require.Equal(t, uint64(i+1), comm.count.Load())
		require.Equal(t, uint64(1), comm.mostRecentRound.Load())
		wal.assertWALSize(3)
	}
}

func runCrashAndRestartExecution(t *testing.T, e *Epoch, bb *testBlockBuilder, wal *testWAL, storage *InMemStorage, f epochExecution) {
// Split the test into two scenarios:
// 1) The node proceeds as usual.
Expand Down
23 changes: 23 additions & 0 deletions epoch_multinode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,29 @@ func (tw *testWAL) assertNotarizationOrFinalization(round uint64, qc QCDeseriali

}

// assertEmptyVote blocks until the WAL contains an empty-vote record for the
// given round, waiting on the WAL's signal between scans.
func (tw *testWAL) assertEmptyVote(round uint64) {
	tw.lock.Lock()
	defer tw.lock.Unlock()

	for {
		rawRecords, err := tw.WriteAheadLog.ReadAll()
		require.NoError(tw.t, err)

		for _, raw := range rawRecords {
			// The record type lives in the first two big-endian bytes.
			if binary.BigEndian.Uint16(raw[:2]) != record.EmptyVoteRecordType {
				continue
			}
			vote, err := ParseEmptyVoteRecord(raw)
			require.NoError(tw.t, err)
			if vote.Round == round {
				return
			}
		}

		// Not found yet; wait for the WAL to be appended to and rescan.
		tw.signal.Wait()
	}
}

func (tw *testWAL) containsEmptyVote(round uint64) bool {
tw.lock.Lock()
defer tw.lock.Unlock()
Expand Down
32 changes: 28 additions & 4 deletions replication.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,10 @@ func (r *ReplicationState) createReplicationTimeoutTask(start, end uint64, nodes
r.sendRequestToNode(start, end, nodes, (index+1)%len(nodes))
}
timeoutTask := &TimeoutTask{
Start: start,
End: end,
Data: &ReplicationTimeoutData{
Start: start,
End: end,
},
NodeID: nodes[index],
TaskID: getTimeoutID(start, end),
Task: taskFunc,
Expand All @@ -193,15 +195,17 @@ func (r *ReplicationState) receivedReplicationResponse(data []QuorumRound, node

slices.Sort(seqs)

task := r.timeoutHandler.FindTask(node, seqs)
task := FindTask(r.timeoutHandler, node, seqs)
if task == nil {
r.logger.Debug("Could not find a timeout task associated with the replication response", zap.Stringer("from", node))
return
}
r.timeoutHandler.RemoveTask(node, task.TaskID)

taskData := task.Data
replicationData := taskData.(*ReplicationTimeoutData)
// we found the timeout, now make sure all seqs were returned
missing := findMissingNumbersInRange(task.Start, task.End, seqs)
missing := findMissingNumbersInRange(replicationData.Start, replicationData.End, seqs)
if len(missing) == 0 {
return
}
Expand Down Expand Up @@ -322,3 +326,23 @@ func (r *ReplicationState) GetQuroumRoundWithSeq(seq uint64) *QuorumRound {
}
return nil
}

// FindTask returns the first TimeoutTask assigned to [node] whose replication
// range contains any sequence in [seqs], or nil if there is none.
// A sequence is considered "contained" if it falls between a task's Start (inclusive) and End (inclusive).
func FindTask(t *TimeoutHandler, node NodeID, seqs []uint64) *TimeoutTask {
	t.lock.Lock()
	defer t.lock.Unlock()

	for _, task := range t.tasks[string(node)] {
		// Tasks other than replication timeouts (e.g. the empty-vote
		// rebroadcast task) carry no *ReplicationTimeoutData; skip them
		// instead of panicking on the type assertion.
		replicationData, ok := task.Data.(*ReplicationTimeoutData)
		if !ok {
			continue
		}
		for _, seq := range seqs {
			if seq >= replicationData.Start && seq <= replicationData.End {
				return task
			}
		}
	}

	return nil
}
25 changes: 6 additions & 19 deletions timeout_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,15 @@ import (
"go.uber.org/zap"
)

type ReplicationTimeoutData struct {
Start uint64
End uint64
}

type TimeoutTask struct {
NodeID NodeID
TaskID string
Start uint64
End uint64
Data any
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wanted to make TimeoutTask more generic

Task func()
Deadline time.Time

Expand Down Expand Up @@ -168,23 +172,6 @@ func (t *TimeoutHandler) Close() {
}
}

// FindTask returns the first TimeoutTask assigned to [node] that contains any sequence in [seqs].
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why did you move this to replication? Just because it doesn't have a th receiver?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yep

// A sequence is considered "contained" if it falls between a task's Start (inclusive) and End (inclusive).
func (t *TimeoutHandler) FindTask(node NodeID, seqs []uint64) *TimeoutTask {
t.lock.Lock()
defer t.lock.Unlock()

for _, seq := range seqs {
for _, t := range t.tasks[string(node)] {
if seq >= t.Start && seq <= t.End {
return t
}
}
}

return nil
}

const delimiter = "_"

func getTimeoutID(start, end uint64) string {
Expand Down
26 changes: 17 additions & 9 deletions timeout_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -231,29 +231,37 @@ func TestFindTask(t *testing.T) {
task1 := &simplex.TimeoutTask{
TaskID: "task1",
NodeID: nodes[0],
Start: 5,
End: 10,
Data: &simplex.ReplicationTimeoutData{
Start: 5,
End: 10,
},
}

taskSameRangeDiffNode := &simplex.TimeoutTask{
TaskID: "taskSameDiff",
NodeID: nodes[1],
Start: 5,
End: 10,
Data: &simplex.ReplicationTimeoutData{
Start: 5,
End: 10,
},
}

task3 := &simplex.TimeoutTask{
TaskID: "task3",
NodeID: nodes[1],
Start: 25,
End: 30,
Data: &simplex.ReplicationTimeoutData{
Start: 25,
End: 30,
},
}

task4 := &simplex.TimeoutTask{
TaskID: "task4",
NodeID: nodes[1],
Start: 31,
End: 36,
Data: &simplex.ReplicationTimeoutData{
Start: 31,
End: 36,
},
}

// Add tasks to handler
Expand Down Expand Up @@ -320,7 +328,7 @@ func TestFindTask(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := handler.FindTask(tt.node, tt.seqs)
result := simplex.FindTask(handler, tt.node, tt.seqs)
require.Equal(t, tt.expected, result)
})
}
Expand Down
Loading