Skip to content

Commit 1ac9dfc

Browse files
committed
feat(ssh): add throttle package to network connections
1 parent 7783750 commit 1ac9dfc

File tree

2 files changed

+394
-0
lines changed

2 files changed

+394
-0
lines changed

ssh/pkg/dialer/throttle.go

Lines changed: 247 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,247 @@
1+
package dialer
2+
3+
import (
4+
"context"
5+
"errors"
6+
"io"
7+
"net"
8+
"sync"
9+
"time"
10+
11+
"golang.org/x/time/rate"
12+
)
13+
14+
// ErrNegativeLimit is returned when attempting to set a negative limit.
15+
var ErrNegativeLimit = errors.New("negative throttle limit")
16+
17+
// Option configures a Throttler.
18+
type Option func(*Throttler)
19+
20+
// WithReadLimit sets the read bytes-per-second limit and burst.
21+
// If bps <= 0 => unlimited. If burst <=0 it defaults to bps.
22+
func WithReadLimit(bps int, burst int) Option {
23+
return func(t *Throttler) {
24+
t.setLimiter(&t.readMu, &t.readLimiter, bps, burst)
25+
}
26+
}
27+
28+
// WithWriteLimit sets the write bytes-per-second limit and burst.
29+
// If bps <= 0 => unlimited. If burst <=0 it defaults to bps.
30+
func WithWriteLimit(bps int, burst int) Option {
31+
return func(t *Throttler) {
32+
t.setLimiter(&t.writeMu, &t.writeLimiter, bps, burst)
33+
}
34+
}
35+
36+
// Throttler wraps an underlying io.Reader / io.Writer (optionally both) and
37+
// enforces directional byte-per-second limits using token buckets.
38+
// It is safe for concurrent use of Read and Write.
39+
type Throttler struct {
40+
// Underlying read side (may be nil if only writing).
41+
R io.Reader
42+
// Underlying write side (may be nil if only reading).
43+
W io.Writer
44+
45+
readMu sync.RWMutex
46+
readLimiter *rate.Limiter
47+
48+
writeMu sync.RWMutex
49+
writeLimiter *rate.Limiter
50+
}
51+
52+
func NewThrottler(r io.Reader, w io.Writer, opts ...Option) *Throttler {
53+
t := &Throttler{R: r, W: w}
54+
55+
for _, o := range opts {
56+
o(t)
57+
}
58+
59+
return t
60+
}
61+
62+
// setLimiter (internal) creates or clears a limiter based on bps.
63+
func (t *Throttler) setLimiter(mu *sync.RWMutex, lim **rate.Limiter, bps int, burst int) {
64+
mu.Lock()
65+
defer mu.Unlock()
66+
67+
if bps <= 0 {
68+
*lim = nil
69+
70+
return
71+
}
72+
73+
if burst <= 0 {
74+
burst = bps
75+
}
76+
77+
*lim = rate.NewLimiter(rate.Limit(bps), burst)
78+
}
79+
80+
// UpdateReadLimit dynamically changes the read limit.
81+
func (t *Throttler) UpdateReadLimit(bps int, burst int) error {
82+
if bps < 0 || burst < 0 {
83+
return ErrNegativeLimit
84+
}
85+
86+
t.setLimiter(&t.readMu, &t.readLimiter, bps, burst)
87+
88+
return nil
89+
}
90+
91+
// UpdateWriteLimit dynamically changes the write limit.
92+
func (t *Throttler) UpdateWriteLimit(bps int, burst int) error {
93+
if bps < 0 || burst < 0 {
94+
return ErrNegativeLimit
95+
}
96+
97+
t.setLimiter(&t.writeMu, &t.writeLimiter, bps, burst)
98+
99+
return nil
100+
}
101+
102+
// Read implements io.Reader with throttling.
103+
func (t *Throttler) Read(p []byte) (int, error) {
104+
if t.R == nil {
105+
return 0, errors.New("read not supported (nil underlying Reader)")
106+
}
107+
108+
lim := t.getReadLimiter()
109+
110+
if lim == nil {
111+
return t.R.Read(p)
112+
}
113+
114+
maxChunk := lim.Burst()
115+
if maxChunk <= 0 {
116+
maxChunk = 32 * 1024
117+
}
118+
119+
total := 0
120+
for total < len(p) {
121+
remaining := len(p) - total
122+
chunk := min(remaining, maxChunk)
123+
124+
if err := lim.WaitN(context.Background(), chunk); err != nil {
125+
if total > 0 {
126+
return total, err
127+
}
128+
129+
return 0, err
130+
}
131+
132+
n, err := t.R.Read(p[total : total+chunk])
133+
total += n
134+
if err != nil || n == 0 {
135+
return total, err
136+
}
137+
138+
if n < chunk {
139+
break
140+
}
141+
}
142+
143+
return total, nil
144+
}
145+
146+
// Write implements io.Writer with throttling.
147+
func (t *Throttler) Write(p []byte) (int, error) {
148+
if t.W == nil {
149+
return 0, errors.New("write not supported (nil underlying Writer)")
150+
}
151+
152+
lim := t.getWriteLimiter()
153+
154+
if lim == nil {
155+
return t.W.Write(p)
156+
}
157+
158+
maxChunk := lim.Burst()
159+
if maxChunk <= 0 {
160+
maxChunk = 32 * 1024
161+
}
162+
163+
total := 0
164+
for total < len(p) {
165+
remaining := len(p) - total
166+
chunk := min(remaining, maxChunk)
167+
168+
if err := lim.WaitN(context.Background(), chunk); err != nil {
169+
if total > 0 {
170+
return total, err
171+
}
172+
173+
return 0, err
174+
}
175+
176+
n, err := t.W.Write(p[total : total+chunk])
177+
total += n
178+
if err != nil || n == 0 {
179+
return total, err
180+
}
181+
182+
if n < chunk {
183+
break
184+
}
185+
}
186+
187+
return total, nil
188+
}
189+
190+
// Helper getters with read locks for concurrency.
191+
func (t *Throttler) getReadLimiter() *rate.Limiter {
192+
t.readMu.RLock()
193+
defer t.readMu.RUnlock()
194+
195+
return t.readLimiter
196+
}
197+
198+
func (t *Throttler) getWriteLimiter() *rate.Limiter {
199+
t.writeMu.RLock()
200+
defer t.writeMu.RUnlock()
201+
202+
return t.writeLimiter
203+
}
204+
205+
type ConnThrottler struct {
206+
Conn net.Conn
207+
Throttler *Throttler
208+
}
209+
210+
func (c *ConnThrottler) Close() error {
211+
return c.Conn.Close()
212+
}
213+
214+
func (c *ConnThrottler) LocalAddr() net.Addr {
215+
return c.Conn.LocalAddr()
216+
}
217+
218+
func (c *ConnThrottler) Read(b []byte) (n int, err error) {
219+
return c.Throttler.Read(b)
220+
}
221+
222+
func (c *ConnThrottler) RemoteAddr() net.Addr {
223+
return c.Conn.RemoteAddr()
224+
}
225+
226+
func (c *ConnThrottler) SetDeadline(t time.Time) error {
227+
return c.Conn.SetDeadline(t)
228+
}
229+
230+
func (c *ConnThrottler) SetReadDeadline(t time.Time) error {
231+
return c.Conn.SetReadDeadline(t)
232+
}
233+
234+
func (c *ConnThrottler) SetWriteDeadline(t time.Time) error {
235+
return c.Conn.SetWriteDeadline(t)
236+
}
237+
238+
func (c *ConnThrottler) Write(b []byte) (n int, err error) {
239+
return c.Throttler.Write(b)
240+
}
241+
242+
func NewConnThrottler(conn net.Conn, readBps, readBurst, writeBps, writeBurst int) net.Conn {
243+
return &ConnThrottler{
244+
Conn: conn,
245+
Throttler: NewThrottler(conn, conn, WithReadLimit(readBps, readBurst), WithWriteLimit(writeBps, writeBurst)),
246+
}
247+
}

ssh/pkg/dialer/throttle_test.go

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
package dialer
2+
3+
import (
4+
"bytes"
5+
"io"
6+
"net"
7+
"testing"
8+
"time"
9+
10+
"github.com/stretchr/testify/assert"
11+
)
12+
13+
func expectedMinDuration(total, bps, burst int) time.Duration {
14+
if bps <= 0 {
15+
return 0
16+
}
17+
18+
remaining := total - burst
19+
if remaining <= 0 {
20+
return 0
21+
}
22+
23+
secs := float64(remaining) / float64(bps)
24+
25+
return time.Duration(secs * float64(time.Second))
26+
}
27+
28+
func TestThrottler_TableDriven(t *testing.T) {
29+
cases := []struct {
30+
name string
31+
run func(t *testing.T)
32+
}{
33+
{
34+
name: "UnlimitedReadFast",
35+
run: func(t *testing.T) {
36+
data := bytes.Repeat([]byte("x"), 1024)
37+
r := bytes.NewReader(data)
38+
39+
th := NewThrottler(r, nil) // no limits
40+
41+
buf := make([]byte, len(data))
42+
start := time.Now()
43+
n, err := th.Read(buf)
44+
dur := time.Since(start)
45+
46+
assert.Truef(t, err == nil || err == io.EOF, "unexpected read error: %v", err)
47+
assert.Equal(t, len(data), n, "read bytes mismatch")
48+
assert.LessOrEqual(t, dur, 100*time.Millisecond, "unlimited read took too long")
49+
},
50+
},
51+
{
52+
name: "NegativeLimitValidation",
53+
run: func(t *testing.T) {
54+
th := NewThrottler(nil, nil)
55+
err := th.UpdateReadLimit(-1, 1)
56+
assert.Equal(t, ErrNegativeLimit, err)
57+
err = th.UpdateWriteLimit(-1, 1)
58+
assert.Equal(t, ErrNegativeLimit, err)
59+
},
60+
},
61+
{
62+
name: "ReadRateEnforced",
63+
run: func(t *testing.T) {
64+
total := 200
65+
bps := 50
66+
burst := 10
67+
68+
data := bytes.Repeat([]byte("r"), total)
69+
r := bytes.NewReader(data)
70+
th := NewThrottler(r, nil, WithReadLimit(bps, burst))
71+
72+
buf := make([]byte, total)
73+
start := time.Now()
74+
n, err := th.Read(buf)
75+
dur := time.Since(start)
76+
77+
assert.Truef(t, err == nil || err == io.EOF, "unexpected read error: %v", err)
78+
assert.Equal(t, total, n, "read bytes mismatch")
79+
80+
expect := expectedMinDuration(total, bps, burst)
81+
// allow 20% timing slack for scheduler and test flakiness
82+
slack := expect / 5
83+
assert.Truef(t, dur+slack >= expect, "read duration = %v; want at least ~%v (with slack %v)", dur, expect, slack)
84+
},
85+
},
86+
{
87+
name: "WriteRateEnforced",
88+
run: func(t *testing.T) {
89+
total := 200
90+
bps := 50
91+
burst := 10
92+
93+
var bufOut bytes.Buffer
94+
th := NewThrottler(nil, &bufOut, WithWriteLimit(bps, burst))
95+
96+
data := bytes.Repeat([]byte("w"), total)
97+
start := time.Now()
98+
n, err := th.Write(data)
99+
dur := time.Since(start)
100+
101+
assert.NoError(t, err, "unexpected write error")
102+
assert.Equal(t, total, n, "written bytes mismatch")
103+
104+
expect := expectedMinDuration(total, bps, burst)
105+
slack := expect / 5
106+
assert.Truef(t, dur+slack >= expect, "write duration = %v; want at least ~%v (with slack %v)", dur, expect, slack)
107+
},
108+
},
109+
{
110+
name: "ConnThrottlerPassthrough",
111+
run: func(t *testing.T) {
112+
c1, c2 := net.Pipe()
113+
t.Cleanup(func() { c1.Close(); c2.Close() })
114+
115+
// Wrap c2 with unlimited throttler
116+
thrConn := NewConnThrottler(c2, 0, 0, 0, 0)
117+
118+
// write from c1, read from thrConn
119+
msg := []byte("hello-throttle")
120+
121+
done := make(chan error, 1)
122+
go func() {
123+
defer c1.Close()
124+
_, err := c1.Write(msg)
125+
done <- err
126+
}()
127+
128+
// read on wrapped conn
129+
got := make([]byte, len(msg))
130+
n, err := thrConn.Read(got)
131+
assert.Truef(t, err == nil || err == io.EOF, "conn read error: %v", err)
132+
assert.Equal(t, len(msg), n, "conn read bytes mismatch")
133+
assert.Equal(t, msg, got, "conn read data mismatch")
134+
135+
// ensure writer had no error
136+
err = <-done
137+
assert.NoError(t, err, "writer error")
138+
},
139+
},
140+
}
141+
142+
for _, tc := range cases {
143+
t.Run(tc.name, func(t *testing.T) {
144+
tc.run(t)
145+
})
146+
}
147+
}

0 commit comments

Comments
 (0)