Skip to content

Commit d7778be

Browse files
committed
Add ability to skip publishing to partial message capable peers
1 parent 10c55de commit d7778be

File tree

4 files changed

+127
-4
lines changed

4 files changed

+127
-4
lines changed

gossipsub.go

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1183,6 +1183,13 @@ func (gs *GossipSubRouter) PublishBatch(messages []*Message, opts *BatchPublishO
11831183
}
11841184
}
11851185

1186+
func (gs *GossipSubRouter) skipPartialMessageCapablePeers(topicID string) bool {
1187+
if t, ok := gs.p.myTopics[topicID]; ok {
1188+
return t.skipPublishingToPartialMessageCapablePeers
1189+
}
1190+
return false
1191+
}
1192+
11861193
func (gs *GossipSubRouter) Publish(msg *Message) {
11871194
for p, rpc := range gs.rpcs(msg) {
11881195
gs.sendRPC(p, rpc, false)
@@ -1259,8 +1266,9 @@ func (gs *GossipSubRouter) rpcs(msg *Message) iter.Seq2[peer.ID, *RPC] {
12591266
}
12601267

12611268
out := rpcWithMessages(msg.Message)
1269+
skipPMPeers := gs.skipPartialMessageCapablePeers(msg.GetTopic())
12621270
for pid := range tosend {
1263-
if pid == from || pid == peer.ID(msg.GetFrom()) {
1271+
if pid == from || pid == peer.ID(msg.GetFrom()) || (skipPMPeers && gs.extensions.peerExtensions[pid].PartialMessages) {
12641272
continue
12651273
}
12661274

gossipsub_test.go

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4377,5 +4377,102 @@ outer:
43774377
if len(missing) != 0 {
43784378
t.Errorf("Expected no missing parts, got %v", missing)
43794379
}
4380+
}
4381+
4382+
func TestSkipPublishingToPeersWithPartialMessageSupport(t *testing.T) {
4383+
topicName := "test-topic"
4384+
4385+
// 3 hosts.
4386+
// hosts[0]: Publisher. Supports partial messages
4387+
// hosts[1]: Subscriber. Supports partial messages
4388+
// hosts[2]: Alternate publisher. Does not support partial messages. Only
4389+
// connected to hosts[0]
4390+
hosts := getDefaultHosts(t, 3)
4391+
4392+
partialExt := make([]*partialmessages.PartialMessageExtension, 2)
4393+
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelDebug}))
4394+
4395+
for i := range partialExt {
4396+
partialExt[i] = &partialmessages.PartialMessageExtension{
4397+
Logger: logger,
4398+
ValidateRPC: func(from peer.ID, rpc *pb.PartialMessagesExtension) error {
4399+
return nil
4400+
},
4401+
EagerIWantLimitPerHeartbeat: 0,
4402+
IWantLimitPerHeartbeat: 1,
4403+
NewPartialMessage: func(topic string, groupID []byte) (partialmessages.PartialMessage, error) {
4404+
return &minimalTestPartialMessage{
4405+
Group: groupID,
4406+
onExtended: func(m *minimalTestPartialMessage) {
4407+
t.Logf("Received new part and extended partial message")
4408+
},
4409+
}, nil
4410+
},
4411+
}
4412+
}
4413+
4414+
psubs := make([]*PubSub, 0, len(hosts)-1)
4415+
for i, h := range hosts[:2] {
4416+
psub := getGossipsub(context.Background(), h, WithPartialMessagesExtension(partialExt[i]))
4417+
psubs = append(psubs, psub)
4418+
}
4419+
4420+
nonPartialPubsub := getGossipsub(context.Background(), hosts[2])
4421+
4422+
denseConnect(t, hosts[:2])
4423+
time.Sleep(2 * time.Second)
43804424

4425+
// Connect nonPartialPubsub to the publisher
4426+
connect(t, hosts[0], hosts[2])
4427+
4428+
var topics []*Topic
4429+
var subs []*Subscription
4430+
for _, psub := range psubs {
4431+
topic, err := psub.Join(topicName, WithSkipPublishingToPartialMessageCapablePeers())
4432+
if err != nil {
4433+
t.Fatal(err)
4434+
}
4435+
topics = append(topics, topic)
4436+
s, err := topic.Subscribe()
4437+
if err != nil {
4438+
t.Fatal(err)
4439+
}
4440+
subs = append(subs, s)
4441+
}
4442+
4443+
topicForNonPartial, err := nonPartialPubsub.Join(topicName)
4444+
if err != nil {
4445+
t.Fatal(err)
4446+
}
4447+
4448+
// Wait for subscriptions to propagate
4449+
time.Sleep(time.Second)
4450+
4451+
topics[0].Publish(context.Background(), []byte("Hello"))
4452+
4453+
// Publish from another peer, the publisher (psub[0]) should not forward this to psub[1].
4454+
topicForNonPartial.Publish(context.Background(), []byte("from non-partial"))
4455+
4456+
recvdMessage := make(chan struct{}, 1)
4457+
ctx, cancel := context.WithCancel(context.Background())
4458+
defer cancel()
4459+
go func() {
4460+
msg, err := subs[1].Next(ctx)
4461+
if err == context.Canceled {
4462+
return
4463+
}
4464+
if err != nil {
4465+
t.Log(err)
4466+
t.Fail()
4467+
return
4468+
}
4469+
t.Log("Received msg", string(msg.Data))
4470+
recvdMessage <- struct{}{}
4471+
}()
4472+
4473+
select {
4474+
case <-recvdMessage:
4475+
t.Fatal("Received message")
4476+
case <-time.After(2 * time.Second):
4477+
}
43814478
}

pubsub.go

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1365,7 +1365,7 @@ func (p *PubSub) handleIncomingRPC(rpc *RPC) {
13651365
continue
13661366
}
13671367

1368-
msg := &Message{pmsg, "", rpc.from, nil, false}
1368+
msg := &Message{Message: pmsg, ID: "", ReceivedFrom: rpc.from, ValidatorData: nil, Local: false}
13691369
if p.shouldPush(msg) {
13701370
toPush = append(toPush, msg)
13711371
}
@@ -1504,7 +1504,16 @@ type rmTopicReq struct {
15041504
resp chan error
15051505
}
15061506

1507-
type TopicOptions struct{}
1507+
type TopicOptions struct {
1508+
SkipPublishingToPartialMessageCapablePeers bool
1509+
}
1510+
1511+
func WithSkipPublishingToPartialMessageCapablePeers() TopicOpt {
1512+
return func(t *Topic) error {
1513+
t.skipPublishingToPartialMessageCapablePeers = true
1514+
return nil
1515+
}
1516+
}
15081517

15091518
type TopicOpt func(t *Topic) error
15101519

topic.go

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ type Topic struct {
3232

3333
mux sync.RWMutex
3434
closed bool
35+
36+
skipPublishingToPartialMessageCapablePeers bool
3537
}
3638

3739
// String returns the topic associated with t
@@ -348,7 +350,14 @@ func (t *Topic) validate(ctx context.Context, data []byte, opts ...PubOpt) (*Mes
348350
}
349351
}
350352

351-
msg := &Message{m, "", t.p.host.ID(), pub.validatorData, pub.local}
353+
msg := &Message{
354+
Message: m,
355+
ID: "",
356+
ReceivedFrom: t.p.host.ID(),
357+
ValidatorData: pub.validatorData,
358+
Local: pub.local,
359+
}
360+
352361
t.p.eval <- func() {
353362
t.p.rt.Preprocess(t.p.host.ID(), []*Message{msg})
354363
}

0 commit comments

Comments
 (0)