Skip to content

Commit b63d10a

Browse files
committed
fix: add throttled check when connection closed by peer
1 parent d5b0914 commit b63d10a

File tree

6 files changed

+150
-7
lines changed

6 files changed

+150
-7
lines changed

connection_reactor.go

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -156,21 +156,31 @@ func (c *connection) pauseWrite() {
156156
// pauseRead removed the monitoring of read events.
157157
// pauseRead used in poller
158158
func (c *connection) pauseRead() {
159+
// Note that the poller ensure that every fd should read all left data in socket buffer before detach it.
160+
// So the operator mode should never be ophup.
161+
var changeTo PollEvent
159162
switch c.operator.getMode() {
160163
case opread:
161-
c.operator.Control(PollR2Hup)
164+
changeTo = PollR2Hup
162165
case opreadwrite:
163-
c.operator.Control(PollRW2W)
166+
changeTo = PollRW2W
167+
}
168+
if changeTo > 0 && atomic.CompareAndSwapInt32(&c.operator.throttled, 0, 1) {
169+
c.operator.Control(changeTo)
164170
}
165171
}
166172

167173
// resumeRead add the monitoring of read events.
168174
// resumeRead used by users
169175
func (c *connection) resumeRead() {
176+
var changeTo PollEvent
170177
switch c.operator.getMode() {
171178
case ophup:
172-
c.operator.Control(PollHup2R)
179+
changeTo = PollHup2R
173180
case opwrite:
174-
c.operator.Control(PollW2RW)
181+
changeTo = PollW2RW
182+
}
183+
if changeTo > 0 && atomic.CompareAndSwapInt32(&c.operator.throttled, 1, 0) {
184+
c.operator.Control(changeTo)
175185
}
176186
}

connection_test.go

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -763,3 +763,65 @@ func TestConnectionReadThreshold(t *testing.T) {
763763

764764
wg.Wait()
765765
}
766+
767+
func TestConnectionReadThresholdWithClosed(t *testing.T) {
768+
var readThreshold int64 = 1024 * 100
769+
var opts = &options{}
770+
var trigger = make(chan struct{})
771+
opts.onRequest = func(ctx context.Context, connection Connection) error {
772+
if int64(connection.Reader().Len()) < readThreshold {
773+
return nil
774+
}
775+
Equal(t, connection.Reader().Len(), int(readThreshold))
776+
trigger <- struct{}{} // let client send final msg and close
777+
<-trigger // wait for client send and close
778+
779+
// read non-throttled data
780+
buf, err := connection.Reader().Next(int(readThreshold))
781+
Equal(t, int64(len(buf)), readThreshold)
782+
MustNil(t, err)
783+
err = connection.Reader().Release()
784+
MustNil(t, err)
785+
t.Logf("read non-throttled data")
786+
787+
// continue read throttled data
788+
buf, err = connection.Reader().Next(5)
789+
MustNil(t, err)
790+
t.Logf("read throttled data: [%s]", buf)
791+
Equal(t, len(buf), 5)
792+
MustNil(t, err)
793+
err = connection.Reader().Release()
794+
MustNil(t, err)
795+
Equal(t, connection.Reader().Len(), 0)
796+
797+
_, err = connection.Reader().Next(1)
798+
Assert(t, errors.Is(err, ErrEOF))
799+
trigger <- struct{}{}
800+
return nil
801+
}
802+
803+
WithReadBufferThreshold(readThreshold).f(opts)
804+
r, w := GetSysFdPairs()
805+
rconn, wconn := &connection{}, &connection{}
806+
rconn.init(&netFD{fd: r}, opts)
807+
wconn.init(&netFD{fd: w}, opts)
808+
Assert(t, rconn.readBufferThreshold == readThreshold)
809+
810+
msg := make([]byte, readThreshold)
811+
_, err := wconn.Writer().WriteBinary(msg)
812+
MustNil(t, err)
813+
err = wconn.Writer().Flush()
814+
MustNil(t, err)
815+
816+
<-trigger
817+
_, err = wconn.Writer().WriteString("hello")
818+
MustNil(t, err)
819+
err = wconn.Writer().Flush()
820+
MustNil(t, err)
821+
t.Logf("flush final msg")
822+
err = wconn.Close()
823+
MustNil(t, err)
824+
trigger <- struct{}{}
825+
826+
<-trigger
827+
}

fd_operator.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,8 @@ type FDOperator struct {
5151
// poll is the registered location of the file descriptor.
5252
poll Poll
5353

54-
mode int32
54+
mode int32
55+
throttled int32
5556

5657
// private, used by operatorCache
5758
next *FDOperator
@@ -112,4 +113,5 @@ func (op *FDOperator) reset() {
112113
op.Outputs, op.OutputAck = nil, nil
113114
op.poll = nil
114115
op.mode = 0
116+
op.throttled = 0
115117
}

netpoll_test.go

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -505,6 +505,73 @@ func TestReadThresholdOption(t *testing.T) {
505505
wg.Wait()
506506
}
507507

508+
func TestReadThresholdClosed(t *testing.T) {
509+
/*
510+
client => server: 102400 bytes + 5 bytes
511+
client => server: close connection
512+
server cached: 102400 bytes, and throttled
513+
server read: 102400 bytes, and unthrottled
514+
server cached: 5 bytes
515+
server read: 5 bytes
516+
*/
517+
readThreshold := 1024 * 100
518+
trigger := make(chan struct{})
519+
msg1 := make([]byte, readThreshold)
520+
msg2 := []byte("hello")
521+
522+
// server
523+
ln, err := CreateListener("tcp", ":12345")
524+
MustNil(t, err)
525+
svr, _ := NewEventLoop(func(ctx context.Context, connection Connection) error {
526+
if connection.Reader().Len() < readThreshold {
527+
return nil
528+
}
529+
// server read
530+
t.Logf("server reading msg1")
531+
trigger <- struct{}{} // let client send msg2
532+
<-trigger // ensure client send msg2 and closed
533+
total := 0
534+
for {
535+
msg, err := connection.Reader().Next(1)
536+
total += len(msg)
537+
if errors.Is(err, ErrEOF) {
538+
break
539+
}
540+
_ = msg
541+
}
542+
Equal(t, total, readThreshold+5)
543+
close(trigger)
544+
return nil
545+
}, WithReadBufferThreshold(int64(readThreshold)))
546+
defer svr.Shutdown(context.Background())
547+
go func() {
548+
svr.Serve(ln)
549+
}()
550+
time.Sleep(time.Millisecond * 100)
551+
552+
// client write
553+
dialer := NewDialer(WithReadBufferThreshold(int64(readThreshold)))
554+
cli, err := dialer.DialConnection("tcp", "127.0.0.1:12345", time.Second)
555+
MustNil(t, err)
556+
t.Logf("client writing msg1")
557+
_, err = cli.Writer().WriteBinary(msg1)
558+
MustNil(t, err)
559+
err = cli.Writer().Flush()
560+
MustNil(t, err)
561+
<-trigger
562+
time.Sleep(time.Millisecond * 100)
563+
t.Logf("client writing msg2")
564+
_, err = cli.Writer().WriteBinary(msg2)
565+
MustNil(t, err)
566+
err = cli.Writer().Flush()
567+
MustNil(t, err)
568+
err = cli.Close()
569+
MustNil(t, err)
570+
t.Logf("client closed")
571+
trigger <- struct{}{}
572+
<-trigger
573+
}
574+
508575
func createTestListener(network, address string) (Listener, error) {
509576
for {
510577
ln, err := CreateListener(network, address)

poll_default_bsd.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,8 @@ func (p *defaultPoll) Wait() error {
115115
}
116116
}
117117
if triggerHup {
118-
if triggerRead && operator.Inputs != nil {
118+
// if peer closed with throttled state, we should ensure we read all left data to avoid data loss
119+
if (triggerRead || atomic.LoadInt32(&operator.throttled) > 0) && operator.Inputs != nil {
119120
var leftRead int
120121
// read all left data if peer send and close
121122
if leftRead, err = readall(operator, barriers[i]); err != nil && !errors.Is(err, ErrEOF) {

poll_default_linux.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,8 @@ func (p *defaultPoll) handler(events []epollevent) (closed bool) {
168168
}
169169
}
170170
if triggerHup {
171-
if triggerRead && operator.Inputs != nil {
171+
// if peer closed with throttled state, we should ensure we read all left data to avoid data loss
172+
if (triggerRead || atomic.LoadInt32(&operator.throttled) > 0) && operator.Inputs != nil {
172173
// read all left data if peer send and close
173174
var leftRead int
174175
// read all left data if peer send and close

0 commit comments

Comments
 (0)