diff --git a/CHANGELOG.md b/CHANGELOG.md index dff1d2de4..06e13dfae 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,7 +6,6 @@ All notable changes to this project will be documented in this file. ### Added -- Validate AccessPass before client connection (CLI) ([#1356](https://github.com/malbeclabs/doublezero/issues/1356)) - Onchain programs - Check if `accesspass.owner` is equal to system program ([malbeclabs/doublezero#2088](https://github.com/malbeclabs/doublezero/pull/2088)) @@ -14,8 +13,10 @@ All notable changes to this project will be documented in this file. - Add a new sub-type field to interface definitions to support CYOA and DIA interfaces. This sub-type allows the system to distinguish between standard Physical/Loopback interfaces and specialized CYOA/DIA interfaces, enabling proper classification, validation, and configuration handling across the DZD. - Improve error message when connecting to a device that is at capacity or has max_users=0. Users now receive "Device is not accepting more users (at capacity or max_users=0)" instead of the confusing "Device not found" error when explicitly specifying an ineligible device. - Add `link latency` command to display latency statistics from the telemetry program. Supports filtering by percentile (p50, p90, p95, p99, mean, min, max, stddev, all), querying by link code or all links, and filtering by epoch. Resolves: [#1942](https://github.com/malbeclabs/doublezero/issues/1942) - -- Added `--contributor | -c` filter to `device list`, `interface list`, and `link list` commands. (#1274) + - Added `--contributor | -c` filter to `device list`, `interface list`, and `link list` commands. (#1274) + - Validate AccessPass before client connection ([#1356](https://github.com/malbeclabs/doublezero/issues/1356)) +- Client + - Add initial route liveness probing, initially disabled for rollout ### Breaking diff --git a/Makefile b/Makefile index a0a4ccae6..5676873a5 100644 --- a/Makefile +++ b/Makefile @@ -50,7 +50,7 @@ nocontainertest: .PHONY: go-fuzz go-fuzz: cd tools/twamp && $(MAKE) fuzz - cd tools/uping && $(MAKE) fuzz + cd client/doublezerod && $(MAKE) fuzz .PHONY: go-container-test go-container-test: diff --git a/client/doublezerod/Makefile b/client/doublezerod/Makefile index 460b6b9fa..324623b43 100644 --- a/client/doublezerod/Makefile +++ b/client/doublezerod/Makefile @@ -18,3 +18,11 @@ lint: .PHONY: build build: CGO_ENABLED=0 go build -v $(LDFLAGS) -o bin/doublezerod cmd/doublezerod/main.go + +FUZZTIME ?= 10s +.PHONY: fuzz +fuzz: + @for f in $$(go test ./internal/liveness -list=Fuzz | grep '^Fuzz'); do \ + echo "==> Fuzzing $$f"; \ + go test ./internal/liveness -run=^$$ -fuzz=$$f -fuzztime=$(FUZZTIME) || exit 1; \ + done diff --git a/client/doublezerod/cmd/doublezerod/main.go b/client/doublezerod/cmd/doublezerod/main.go index 55f464480..20afcc9ca 100644 --- a/client/doublezerod/cmd/doublezerod/main.go +++ b/client/doublezerod/cmd/doublezerod/main.go @@ -11,7 +11,9 @@ import ( "os" "os/signal" "syscall" + "time" + "github.com/malbeclabs/doublezero/client/doublezerod/internal/liveness" "github.com/malbeclabs/doublezero/client/doublezerod/internal/runtime" "github.com/malbeclabs/doublezero/config" "github.com/prometheus/client_golang/prometheus" @@ -34,12 +36,38 @@ var ( metricsAddr = flag.String("metrics-addr", "localhost:0", "Address to listen on for prometheus metrics") routeConfigPath = flag.String("route-config", "/var/lib/doublezerod/route-config.json", "path to route config file (unstable)") + // Route liveness configuration flags. 
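+	// These follow BFD-style timing semantics: a session's detection timeout
+	// is the detect multiplier times the negotiated RX interval, so the
+	// defaults below (300ms tx/rx, detect mult 3) detect failure in roughly 900ms.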
+	routeLivenessTxMin         = flag.Duration("route-liveness-tx-min", defaultRouteLivenessTxMin, "route liveness tx min")
+	routeLivenessRxMin         = flag.Duration("route-liveness-rx-min", defaultRouteLivenessRxMin, "route liveness rx min")
+	routeLivenessDetectMult    = flag.Uint("route-liveness-detect-mult", defaultRouteLivenessDetectMult, "route liveness detect mult")
+	routeLivenessMinTxFloor    = flag.Duration("route-liveness-min-tx-floor", defaultRouteLivenessMinTxFloor, "route liveness min tx floor")
+	routeLivenessMaxTxCeil     = flag.Duration("route-liveness-max-tx-ceil", defaultRouteLivenessMaxTxCeil, "route liveness max tx ceil")
+
+	// TODO(snormore): These flags are temporary for initial rollout testing.
+	// They will be superseded by a single `route-liveness-enable` flag, where false means
+	// passive-mode and true means active-mode.
+	routeLivenessEnablePassive = flag.Bool("route-liveness-enable-passive", false, "enables route liveness in passive mode (experimental)")
+	routeLivenessEnableActive  = flag.Bool("route-liveness-enable-active", false, "enables route liveness in active mode (experimental)")
+
 	// set by LDFLAGS
 	version = "dev"
 	commit  = "none"
 	date    = "unknown"
 )
 
+const (
+	defaultRouteLivenessTxMin      = 300 * time.Millisecond
+	defaultRouteLivenessRxMin      = 300 * time.Millisecond
+	defaultRouteLivenessDetectMult = 3
+	defaultRouteLivenessMinTxFloor = 50 * time.Millisecond
+	defaultRouteLivenessMaxTxCeil  = 1 * time.Second
+
+	// The liveness port is not configurable since clients need to use the same one so they know
+	// how to connect to each other.
+	defaultRouteLivenessPort   = 44880
+	defaultRouteLivenessBindIP = "0.0.0.0"
+)
+
 func main() {
 	flag.Parse()
 
@@ -112,10 +140,33 @@ func main() {
 		}()
 	}
 
+	// If either passive or active mode is enabled, create a manager config.
+	// If neither is enabled, completely disable the liveness subsystem.
+	// TODO(snormore): The scenario where the liveness subsystem is completely disabled is
+	// temporary for initial rollout testing.
+	var lmc *liveness.ManagerConfig
+	if *routeLivenessEnablePassive || *routeLivenessEnableActive {
+		lmc = &liveness.ManagerConfig{
+			Logger: slog.Default(),
+			BindIP: defaultRouteLivenessBindIP,
+			Port:   defaultRouteLivenessPort,
+
+			// The manager only models passive mode; active mode is simply its
+			// negation, so enabling active mode sets PassiveMode to false.
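+			// Note: if both enable flags are set, active mode wins, since
+			// PassiveMode is derived solely from the active flag.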
+ PassiveMode: !*routeLivenessEnableActive, + + TxMin: *routeLivenessTxMin, + RxMin: *routeLivenessRxMin, + DetectMult: uint8(*routeLivenessDetectMult), + MinTxFloor: *routeLivenessMinTxFloor, + MaxTxCeil: *routeLivenessMaxTxCeil, + } + } + ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) defer stop() - if err := runtime.Run(ctx, *sockFile, *routeConfigPath, *enableLatencyProbing, *enableLatencyMetrics, *programId, *rpcEndpoint, *probeInterval, *cacheUpdateInterval); err != nil { + if err := runtime.Run(ctx, *sockFile, *routeConfigPath, *enableLatencyProbing, *enableLatencyMetrics, *programId, *rpcEndpoint, *probeInterval, *cacheUpdateInterval, lmc); err != nil { slog.Error("runtime error", "error", err) os.Exit(1) } diff --git a/client/doublezerod/internal/bgp/bgp.go b/client/doublezerod/internal/bgp/bgp.go index eb046979c..28c8f79ac 100644 --- a/client/doublezerod/internal/bgp/bgp.go +++ b/client/doublezerod/internal/bgp/bgp.go @@ -10,6 +10,7 @@ import ( "sync" "github.com/jwhited/corebgp" + "github.com/malbeclabs/doublezero/client/doublezerod/internal/liveness" "github.com/malbeclabs/doublezero/client/doublezerod/internal/routing" ) @@ -90,15 +91,17 @@ type RouteReaderWriter interface { } type PeerConfig struct { - LocalAddress net.IP - RemoteAddress net.IP - LocalAs uint32 - RemoteAs uint32 - Port int - RouteSrc net.IP - RouteTable int - FlushRoutes bool - NoInstall bool + LocalAddress net.IP + RemoteAddress net.IP + LocalAs uint32 + RemoteAs uint32 + Port int + RouteSrc net.IP + RouteTable int + FlushRoutes bool + NoInstall bool + Interface string + LivenessEnabled bool } type BgpServer struct { @@ -107,9 +110,10 @@ type BgpServer struct { peerStatus map[string]Session peerStatusLock sync.Mutex routeReaderWriter RouteReaderWriter + livenessManager *liveness.Manager } -func NewBgpServer(routerID net.IP, r RouteReaderWriter) (*BgpServer, error) { +func NewBgpServer(routerID net.IP, r RouteReaderWriter, lm *liveness.Manager) (*BgpServer, error) { corebgp.SetLogger(log.Print) srv, err := corebgp.NewServer(netip.MustParseAddr(routerID.String())) if err != nil { @@ -121,6 +125,7 @@ func NewBgpServer(routerID net.IP, r RouteReaderWriter) (*BgpServer, error) { peerStatus: make(map[string]Session), peerStatusLock: sync.Mutex{}, routeReaderWriter: r, + livenessManager: lm, }, nil } @@ -142,7 +147,11 @@ func (b *BgpServer) AddPeer(p *PeerConfig, advertised []NLRI) error { if p.Port != 0 { peerOpts = append(peerOpts, corebgp.WithPort(p.Port)) } - plugin := NewBgpPlugin(advertised, p.RouteSrc, p.RouteTable, b.peerStatusChan, p.FlushRoutes, p.NoInstall, b.routeReaderWriter) + rrw := b.routeReaderWriter + if p.LivenessEnabled && b.livenessManager != nil { + rrw = liveness.NewRouteReaderWriter(b.livenessManager, b.routeReaderWriter, p.Interface) + } + plugin := NewBgpPlugin(advertised, p.RouteSrc, p.RouteTable, b.peerStatusChan, p.FlushRoutes, p.NoInstall, rrw) err := b.server.AddPeer(corebgp.PeerConfig{ RemoteAddress: netip.MustParseAddr(p.RemoteAddress.String()), LocalAS: p.LocalAs, diff --git a/client/doublezerod/internal/bgp/bgp_test.go b/client/doublezerod/internal/bgp/bgp_test.go index 6e08e138a..247c1a2b6 100644 --- a/client/doublezerod/internal/bgp/bgp_test.go +++ b/client/doublezerod/internal/bgp/bgp_test.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "log" + "log/slog" "net" "net/netip" "strings" @@ -15,9 +16,11 @@ import ( "github.com/google/go-cmp/cmp" "github.com/jwhited/corebgp" "github.com/malbeclabs/doublezero/client/doublezerod/internal/bgp" + 
"github.com/malbeclabs/doublezero/client/doublezerod/internal/liveness" "github.com/malbeclabs/doublezero/client/doublezerod/internal/routing" gobgp "github.com/osrg/gobgp/pkg/packet/bgp" "github.com/prometheus/client_golang/prometheus/testutil" + "github.com/stretchr/testify/require" "golang.org/x/sys/unix" ) @@ -114,7 +117,20 @@ func (p *dummyPlugin) handleUpdate(peer corebgp.PeerConfig, u []byte) *corebgp.N func TestBgpServer(t *testing.T) { nlr := &mockRouteReaderWriter{} - b, err := bgp.NewBgpServer(net.IP{1, 1, 1, 1}, nlr) + lm, err := liveness.NewManager(t.Context(), &liveness.ManagerConfig{ + Logger: slog.Default(), + Netlinker: nlr, + BindIP: "127.0.0.1", + Port: 0, + TxMin: 100 * time.Millisecond, + RxMin: 100 * time.Millisecond, + DetectMult: 3, + MinTxFloor: 50 * time.Millisecond, + MaxTxCeil: 1 * time.Second, + }) + require.NoError(t, err) + t.Cleanup(func() { _ = lm.Close() }) + b, err := bgp.NewBgpServer(net.IP{1, 1, 1, 1}, nlr, lm) if err != nil { t.Fatalf("error creating bgp server: %v", err) } diff --git a/client/doublezerod/internal/bgp/plugin.go b/client/doublezerod/internal/bgp/plugin.go index 36ffb4f38..39c31d0b6 100644 --- a/client/doublezerod/internal/bgp/plugin.go +++ b/client/doublezerod/internal/bgp/plugin.go @@ -93,11 +93,11 @@ func (p *Plugin) OnClose(peer corebgp.PeerConfig) { protocol := unix.RTPROT_BGP // 186 routes, err := p.RouteReaderWriter.RouteByProtocol(protocol) if err != nil { - slog.Error("routes: error getting routes by protocol", "protocol", protocol) + slog.Error("routes: error getting routes by protocol on peer close", "protocol", protocol, "error", err) } for _, route := range routes { if err := p.RouteReaderWriter.RouteDelete(route); err != nil { - slog.Error("Error deleting route", "route", route) + slog.Error("routes: error deleting route on peer close", "route", route.String(), "error", err) continue } } @@ -126,7 +126,7 @@ func (p *Plugin) handleUpdate(peer corebgp.PeerConfig, u []byte) *corebgp.Notifi slog.Info("routes: removing route from table", "table", p.RouteTable, "dz route", route.String()) err := p.RouteReaderWriter.RouteDelete(route) if err != nil { - slog.Error("routes: error removing route from table", "table", p.RouteTable, "error", err) + slog.Error("routes: error removing route from table", "table", p.RouteTable, "error", err, "route", route.String()) } } @@ -152,7 +152,7 @@ func (p *Plugin) handleUpdate(peer corebgp.PeerConfig, u []byte) *corebgp.Notifi Protocol: unix.RTPROT_BGP} slog.Info("routes: writing route", "table", p.RouteTable, "dz route", route.String()) if err := p.RouteReaderWriter.RouteAdd(route); err != nil { - slog.Error("routes: error writing route", "table", p.RouteTable, "error", err) + slog.Error("routes: error writing route", "table", p.RouteTable, "error", err, "route", route.String()) } } return nil diff --git a/client/doublezerod/internal/liveness/manager.go b/client/doublezerod/internal/liveness/manager.go new file mode 100644 index 000000000..e23663c29 --- /dev/null +++ b/client/doublezerod/internal/liveness/manager.go @@ -0,0 +1,487 @@ +package liveness + +import ( + "context" + "crypto/rand" + "encoding/binary" + "errors" + "fmt" + "log/slog" + "net" + "sync" + "time" + + "github.com/malbeclabs/doublezero/client/doublezerod/internal/routing" +) + +const ( + // Default floors/ceilings for TX interval clamping; chosen to avoid + // overly chatty probes and to keep failure detection reasonably fast. 
+ defaultMinTxFloor = 50 * time.Millisecond + defaultMaxTxCeil = 1 * time.Second + + defaultMaxEvents = 10240 +) + +// Peer identifies a remote endpoint and the local interface context used to reach it. +// LocalIP is the IP on which we send/receive; PeerIP is the peer’s address. +type Peer struct { + Interface string + LocalIP string + PeerIP string +} + +func (p *Peer) String() string { + return fmt.Sprintf("interface: %s, localIP: %s, peerIP: %s", p.Interface, p.LocalIP, p.PeerIP) +} + +// RouteKey uniquely identifies a desired/installed route in the kernel. +// This is used as a stable key in Manager maps across lifecycle events. +type RouteKey struct { + Interface string + SrcIP string + Table int + DstPrefix string + NextHop string +} + +// ManagerConfig controls Manager behavior, routing integration, and liveness timings. +type ManagerConfig struct { + Logger *slog.Logger + Netlinker RouteReaderWriter + UDP *UDPService + + BindIP string // local bind address for the UDP socket (IPv4) + Port int // UDP port to listen/transmit on + + // PassiveMode: if true, Manager does NOT manage kernel routes automatically. + // Instead it defers to Netlinker calls made by the caller. This enables + // incremental rollout (observe liveness without changing dataplane). + PassiveMode bool + + // Local desired probe intervals and detection multiplier for new sessions. + TxMin time.Duration + RxMin time.Duration + DetectMult uint8 + + // Global bounds for interval clamping and exponential backoff. + MinTxFloor time.Duration + MaxTxCeil time.Duration + BackoffMax time.Duration + + // Maximum number of events to keep in the scheduler queue. + // This is an upper bound for safety to prevent unbounded + // memory usage in the event of regressions. + // suggested: 4 * expected number of sessions + // default: 10,240 + MaxEvents int +} + +// Validate fills defaults and enforces constraints for ManagerConfig. +// Returns a descriptive error when required fields are missing/invalid. +func (c *ManagerConfig) Validate() error { + if c.Logger == nil { + return errors.New("logger is required") + } + if c.Netlinker == nil { + return errors.New("netlinker is required") + } + if c.BindIP == "" { + return errors.New("bind IP is required") + } + if c.Port < 0 { + return errors.New("port must be greater than or equal to 0") + } + if c.TxMin <= 0 { + return errors.New("txMin must be greater than 0") + } + if c.RxMin <= 0 { + return errors.New("rxMin must be greater than 0") + } + if c.DetectMult <= 0 { + return errors.New("detectMult must be greater than 0") + } + if c.MinTxFloor == 0 { + c.MinTxFloor = defaultMinTxFloor + } + if c.MinTxFloor < 0 { + return errors.New("minTxFloor must be greater than 0") + } + if c.MaxTxCeil == 0 { + c.MaxTxCeil = defaultMaxTxCeil + } + if c.MaxTxCeil < 0 { + return errors.New("maxTxCeil must be greater than 0") + } + if c.MaxTxCeil < c.MinTxFloor { + return errors.New("maxTxCeil must be greater than minTxFloor") + } + if c.BackoffMax == 0 { + c.BackoffMax = c.MaxTxCeil + } + if c.BackoffMax < 0 { + return errors.New("backoffMax must be greater than 0") + } + if c.BackoffMax < c.MinTxFloor { + return errors.New("backoffMax must be greater than or equal to minTxFloor") + } + if c.MaxEvents == 0 { + c.MaxEvents = defaultMaxEvents + } + if c.MaxEvents < 0 { + return errors.New("maxEvents must be greater than 0") + } + return nil +} + +// Manager orchestrates liveness sessions per peer, integrates with routing, +// and runs the receiver/scheduler goroutines. It is safe for concurrent use. 
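+// In passive mode the Manager observes liveness only and leaves kernel route
+// changes to the caller; in active mode it installs a route when its session
+// reaches Up and withdraws it when the session goes Down.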
+type Manager struct { + ctx context.Context + cancel context.CancelFunc + wg sync.WaitGroup + errCh chan error + + log *slog.Logger + cfg *ManagerConfig + udp *UDPService // shared UDP transport + + sched *Scheduler // time-wheel/event-loop for TX/detect + recv *Receiver // UDP packet reader → HandleRx + + mu sync.Mutex + sessions map[Peer]*Session // active sessions keyed by Peer + desired map[RouteKey]*routing.Route // routes we want installed + installed map[RouteKey]bool // whether route is in kernel + + // Rate-limited warnings for packets from unknown peers (not in sessions). + unkownPeerErrWarnEvery time.Duration + unkownPeerErrWarnLast time.Time + unkownPeerErrWarnMu sync.Mutex +} + +// NewManager constructs a Manager, opens the UDP socket, and launches the +// receiver and scheduler loops. The context governs their lifetime. +func NewManager(ctx context.Context, cfg *ManagerConfig) (*Manager, error) { + if err := cfg.Validate(); err != nil { + return nil, fmt.Errorf("error validating manager config: %v", err) + } + + udp := cfg.UDP + if udp == nil { + var err error + udp, err = ListenUDP(cfg.BindIP, cfg.Port) + if err != nil { + return nil, fmt.Errorf("error creating UDP connection: %w", err) + } + } + + log := cfg.Logger + log.Info("liveness: manager starting", "localAddr", udp.LocalAddr().String(), "txMin", cfg.TxMin, "rxMin", cfg.RxMin, "detectMult", cfg.DetectMult, "passiveMode", cfg.PassiveMode) + + ctx, cancel := context.WithCancel(ctx) + m := &Manager{ + ctx: ctx, + cancel: cancel, + errCh: make(chan error, 10), + + log: log, + cfg: cfg, + udp: udp, + + sessions: make(map[Peer]*Session), + desired: make(map[RouteKey]*routing.Route), + installed: make(map[RouteKey]bool), + + unkownPeerErrWarnEvery: 5 * time.Second, + } + + // Wire up IO loops. + m.recv = NewReceiver(m.log, m.udp, m.HandleRx) + m.sched = NewScheduler(m.log, m.udp, m.onSessionDown, m.cfg.MaxEvents) + + // Receiver goroutine: parses control packets and dispatches to HandleRx. + m.wg.Add(1) + go func() { + defer m.wg.Done() + err := m.recv.Run(m.ctx) + if err != nil { + m.log.Error("liveness: error running receiver", "error", err) + cancel() + m.errCh <- err + } + }() + + // Scheduler goroutine: handles periodic TX and detect expirations. + m.wg.Add(1) + go func() { + defer m.wg.Done() + err := m.sched.Run(m.ctx) + if err != nil { + m.log.Error("liveness: error running scheduler", "error", err) + cancel() + m.errCh <- err + } + }() + + return m, nil +} + +// Err returns a channel that will receive any errors from the manager. +func (m *Manager) Err() chan error { + return m.errCh +} + +// RegisterRoute declares interest in monitoring reachability for route r via iface. +// It optionally installs the route immediately in PassiveMode, then creates or +// reuses a liveness Session and schedules immediate TX to begin handshake. +func (m *Manager) RegisterRoute(r *routing.Route, iface string) error { + // Check that the route src and dst are valid IPv4 addresses. + if r.Src == nil || r.Dst.IP == nil { + return fmt.Errorf("error registering route: nil source or destination IP") + } + if r.Src.To4() == nil || r.Dst.IP.To4() == nil { + return fmt.Errorf("error registering route: non-IPv4 source (%s) or destination IP (%s)", r.Src.String(), r.Dst.IP.String()) + } + srcIP := r.Src.To4().String() + dstIP := r.Dst.IP.To4().String() + + if m.cfg.PassiveMode { + // In passive-mode we still update the kernel immediately (caller’s policy), + // while also running liveness for observability. 
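+		// If the kernel rejects the add, we return before a session is
+		// created, so passive mode never probes a route it failed to install.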
+ if err := m.cfg.Netlinker.RouteAdd(r); err != nil { + return fmt.Errorf("error registering route: %v", err) + } + } + + peerAddr, err := net.ResolveUDPAddr("udp", peerAddrFor(r, m.cfg.Port)) + if err != nil { + return fmt.Errorf("error resolving peer address: %v", err) + } + + k := routeKeyFor(iface, r) + m.mu.Lock() + m.desired[k] = r + m.mu.Unlock() + + peer := Peer{Interface: iface, LocalIP: srcIP, PeerIP: dstIP} + m.log.Info("liveness: registering route", "route", r.String(), "peerAddr", peerAddr.String()) + + m.mu.Lock() + if _, ok := m.sessions[peer]; ok { + m.mu.Unlock() + return nil // session already exists + } + // Create a fresh session in Down with a random non-zero discriminator. + s := &Session{ + route: r, + localDiscr: rand32(), + state: StateDown, // Initial Phase: start Down until handshake + detectMult: m.cfg.DetectMult, // governs detect timeout = mult × rxInterval + localTxMin: m.cfg.TxMin, + localRxMin: m.cfg.RxMin, + peer: &peer, + peerAddr: peerAddr, + alive: true, // session is under management (TX/detect active) + minTxFloor: m.cfg.MinTxFloor, // clamp lower bound + maxTxCeil: m.cfg.MaxTxCeil, // clamp upper bound + backoffMax: m.cfg.BackoffMax, // cap for exponential backoff while Down + backoffFactor: 1, + } + m.sessions[peer] = s + // Kick off the first TX immediately; detect is armed after we see valid RX. + m.sched.scheduleTx(time.Now(), s) + m.mu.Unlock() + + return nil +} + +// WithdrawRoute removes interest in r via iface. It tears down the session, +// marks it not managed (alive=false), and withdraws the route if needed. +func (m *Manager) WithdrawRoute(r *routing.Route, iface string) error { + // Check that the route src and dst are valid IPv4 addresses. + if r.Src == nil || r.Dst.IP == nil { + return fmt.Errorf("error withdrawing route: nil source or destination IP") + } + if r.Src.To4() == nil || r.Dst.IP.To4() == nil { + return fmt.Errorf("error withdrawing route: non-IPv4 source (%s) or destination IP (%s)", r.Src.String(), r.Dst.IP.String()) + } + srcIP := r.Src.To4().String() + dstIP := r.Dst.IP.To4().String() + + m.log.Info("liveness: withdrawing route", "route", r.String(), "iface", iface) + + if m.cfg.PassiveMode { + // Passive-mode: caller wants immediate kernel update independent of liveness. + if err := m.cfg.Netlinker.RouteDelete(r); err != nil { + return fmt.Errorf("error withdrawing route: %v", err) + } + } + + rk := routeKeyFor(iface, r) + m.mu.Lock() + delete(m.desired, rk) + wasInstalled := m.installed[rk] + delete(m.installed, rk) + m.mu.Unlock() + + peer := Peer{Interface: iface, LocalIP: srcIP, PeerIP: dstIP} + + // Mark session no longer managed and drop it from tracking. + m.mu.Lock() + if sess := m.sessions[peer]; sess != nil { + sess.mu.Lock() + sess.alive = false + sess.mu.Unlock() + } + delete(m.sessions, peer) + m.mu.Unlock() + + // If we previously installed the route (and not in PassiveMode), remove it now. + if wasInstalled && !m.cfg.PassiveMode { + return m.cfg.Netlinker.RouteDelete(r) + } + return nil +} + +// LocalAddr exposes the bound UDP address if available (or nil if closed/unset). +func (m *Manager) LocalAddr() *net.UDPAddr { + m.mu.Lock() + defer m.mu.Unlock() + if m.udp == nil { + return nil + } + if addr, ok := m.udp.LocalAddr().(*net.UDPAddr); ok { + return addr + } + return nil +} + +// Close stops goroutines, waits for exit, and closes the UDP socket. +// Returns the last close error, if any. 
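+// It is safe to call Close more than once; after the socket has been closed
+// and nilled out, subsequent calls return nil.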
+func (m *Manager) Close() error { + m.cancel() + m.wg.Wait() + + var cerr error + m.mu.Lock() + if m.udp != nil { + if err := m.udp.Close(); err != nil && !errors.Is(err, net.ErrClosed) { + m.log.Warn("liveness: error closing connection", "error", err) + cerr = err + } + m.udp = nil + } + m.mu.Unlock() + + return cerr +} + +// HandleRx is the receiver callback: it routes an inbound control packet to the +// correct Session, drives its state machine, and schedules detect as needed. +func (m *Manager) HandleRx(ctrl *ControlPacket, peer Peer) { + now := time.Now() + + m.mu.Lock() + sess := m.sessions[peer] + if sess == nil { + // Throttle warnings for packets from unknown peers to avoid log spam. + m.unkownPeerErrWarnMu.Lock() + if m.unkownPeerErrWarnLast.IsZero() || time.Since(m.unkownPeerErrWarnLast) >= m.unkownPeerErrWarnEvery { + m.unkownPeerErrWarnLast = time.Now() + m.log.Warn("liveness: received control packet for unknown peer", "peer", peer.String(), "peerDiscrr", ctrl.peerDiscrr, "localDiscrr", ctrl.LocalDiscrr, "state", ctrl.State) + + } + m.unkownPeerErrWarnMu.Unlock() + + m.mu.Unlock() + return + } + + // Apply RX to the session FSM; only act when state actually changes. + changed := sess.HandleRx(now, ctrl) + + if changed { + switch sess.state { + case StateUp: + go m.onSessionUp(sess) + m.sched.scheduleDetect(now, sess) // keep detect armed while Up + case StateInit: + m.sched.scheduleDetect(now, sess) // arm detect; next >=Init promotes to Up + case StateDown: + // Transitioned to Down; withdraw and do NOT re-arm detect. + go m.onSessionDown(sess) + } + } else { + // No state change: just keep detect ticking for active states. + switch sess.state { + case StateUp, StateInit: + m.sched.scheduleDetect(now, sess) + default: + // Down/AdminDown: do nothing; avoid noisy logs. + } + } + m.mu.Unlock() +} + +// onSessionUp installs the route if it is desired and not already installed. +// In PassiveMode, install was already done at registration time. +func (m *Manager) onSessionUp(sess *Session) { + rk := routeKeyFor(sess.peer.Interface, sess.route) + m.mu.Lock() + route := m.desired[rk] + if route == nil || m.installed[rk] { + m.mu.Unlock() + return + } + m.installed[rk] = true + m.mu.Unlock() + if !m.cfg.PassiveMode { + err := m.cfg.Netlinker.RouteAdd(route) + if err != nil { + m.log.Error("liveness: error adding route on session up", "error", err, "route", route.String()) + } + } + m.log.Info("liveness: session up", "peer", sess.peer.String(), "route", sess.route.String()) +} + +// onSessionDown withdraws the route if currently installed (unless PassiveMode). +func (m *Manager) onSessionDown(sess *Session) { + rk := routeKeyFor(sess.peer.Interface, sess.route) + m.mu.Lock() + route := m.desired[rk] + wasInstalled := m.installed[rk] + m.installed[rk] = false + m.mu.Unlock() + if wasInstalled && route != nil { + if !m.cfg.PassiveMode { + err := m.cfg.Netlinker.RouteDelete(route) + if err != nil { + m.log.Error("liveness: error deleting route on session down", "error", err, "route", route.String()) + } + } + m.log.Info("liveness: session down", "peer", sess.peer.String(), "route", sess.route.String()) + } +} + +// rand32 returns a non-zero random uint32 for use as a discriminator. +// (BFD treats 0 as invalid; ensure we never emit 0.) +func rand32() uint32 { + var b [4]byte + _, _ = rand.Read(b[:]) + v := binary.BigEndian.Uint32(b[:]) + if v == 0 { + v = 1 + } + return v +} + +// routeKeyFor builds a RouteKey for map indexing based on interface + route fields. 
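+// IPs are stored in canonical IPv4 string form; note that the destination
+// mask is not part of the key.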
+func routeKeyFor(iface string, r *routing.Route) RouteKey {
+	return RouteKey{Interface: iface, SrcIP: r.Src.To4().String(), Table: r.Table, DstPrefix: r.Dst.IP.To4().String(), NextHop: r.NextHop.To4().String()}
+}
+
+// peerAddrFor returns the peer's "ip:port" dial address for UDP control messages.
+func peerAddrFor(r *routing.Route, port int) string {
+	return fmt.Sprintf("%s:%d", r.Dst.IP.To4().String(), port)
+}
diff --git a/client/doublezerod/internal/liveness/manager_test.go b/client/doublezerod/internal/liveness/manager_test.go
new file mode 100644
index 000000000..c73d04d2f
--- /dev/null
+++ b/client/doublezerod/internal/liveness/manager_test.go
@@ -0,0 +1,573 @@
+package liveness
+
+import (
+	"errors"
+	"log/slog"
+	"net"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/malbeclabs/doublezero/client/doublezerod/internal/routing"
+	"github.com/stretchr/testify/require"
+	"golang.org/x/sys/unix"
+)
+
+func TestClient_LivenessManager_ConfigValidate(t *testing.T) {
+	t.Parallel()
+	log := newTestLogger(t)
+
+	err := (&ManagerConfig{Netlinker: &MockRouteReaderWriter{}, BindIP: "127.0.0.1"}).Validate()
+	require.Error(t, err)
+
+	err = (&ManagerConfig{Logger: log, BindIP: "127.0.0.1"}).Validate()
+	require.Error(t, err)
+
+	err = (&ManagerConfig{Logger: log, Netlinker: &MockRouteReaderWriter{}, BindIP: ""}).Validate()
+	require.Error(t, err)
+
+	err = (&ManagerConfig{Logger: log, Netlinker: &MockRouteReaderWriter{}, BindIP: "127.0.0.1", MinTxFloor: -1}).Validate()
+	require.Error(t, err)
+	err = (&ManagerConfig{Logger: log, Netlinker: &MockRouteReaderWriter{}, BindIP: "127.0.0.1", MaxTxCeil: -1}).Validate()
+	require.Error(t, err)
+	err = (&ManagerConfig{Logger: log, Netlinker: &MockRouteReaderWriter{}, BindIP: "127.0.0.1", BackoffMax: -1}).Validate()
+	require.Error(t, err)
+
+	err = (&ManagerConfig{
+		Logger:     log,
+		Netlinker:  &MockRouteReaderWriter{},
+		BindIP:     "127.0.0.1",
+		TxMin:      100 * time.Millisecond,
+		RxMin:      100 * time.Millisecond,
+		DetectMult: 3,
+		MinTxFloor: 200 * time.Millisecond,
+		MaxTxCeil:  100 * time.Millisecond,
+		Port:       -1, // invalid port
+	}).Validate()
+	require.EqualError(t, err, "port must be greater than or equal to 0")
+
+	cfg := &ManagerConfig{
+		Logger:     log,
+		Netlinker:  &MockRouteReaderWriter{},
+		BindIP:     "127.0.0.1",
+		TxMin:      100 * time.Millisecond,
+		RxMin:      100 * time.Millisecond,
+		DetectMult: 3,
+		MinTxFloor: 50 * time.Millisecond,
+		MaxTxCeil:  1 * time.Second,
+	}
+	err = cfg.Validate()
+	require.NoError(t, err)
+	require.NotZero(t, cfg.MinTxFloor)
+	require.NotZero(t, cfg.MaxTxCeil)
+	require.NotZero(t, cfg.BackoffMax)
+	require.GreaterOrEqual(t, int64(cfg.MaxTxCeil), int64(cfg.MinTxFloor))
+	require.GreaterOrEqual(t, int64(cfg.BackoffMax), int64(cfg.MinTxFloor))
+}
+
+func TestClient_LivenessManager_NewManager_BindsAndLocalAddr(t *testing.T) {
+	t.Parallel()
+	m, err := newTestManager(t, nil)
+	require.NoError(t, err)
+	t.Cleanup(func() { _ = m.Close() })
+
+	la := m.LocalAddr()
+	require.NotNil(t, la)
+	require.Equal(t, "127.0.0.1", la.IP.String())
+	require.NotZero(t, la.Port)
+}
+
+func TestClient_LivenessManager_RegisterRoute_Deduplicates(t *testing.T) {
+	t.Parallel()
+	m, err := newTestManager(t, nil)
+	require.NoError(t, err)
+	t.Cleanup(func() { _ = m.Close() })
+
+	r := newTestRoute(func(r *routing.Route) {
+		r.Src = net.IPv4(127, 0, 0, 1)
+		r.Dst = &net.IPNet{IP: net.IPv4(127, 0, 0, 2), Mask: net.CIDRMask(32, 32)}
+	})
+
+	err = m.RegisterRoute(r, "lo")
+	require.NoError(t, err)
+	err = m.RegisterRoute(r, "lo")
+	require.NoError(t, err)
+
+	m.mu.Lock()
+	
require.Len(t, m.sessions, 1) + require.Contains(t, m.sessions, Peer{Interface: "lo", LocalIP: r.Src.String(), PeerIP: r.Dst.IP.String()}) + require.NotContains(t, m.sessions, Peer{Interface: "lo", LocalIP: r.Dst.IP.String(), PeerIP: r.Src.String()}) + m.mu.Unlock() +} + +func TestClient_LivenessManager_HandleRx_Transitions_AddAndDelete(t *testing.T) { + t.Parallel() + + addCh := make(chan *routing.Route, 1) + delCh := make(chan *routing.Route, 1) + + m, err := newTestManager(t, func(cfg *ManagerConfig) { + cfg.Netlinker = &MockRouteReaderWriter{ + RouteAddFunc: func(r *routing.Route) error { addCh <- r; return nil }, + RouteDeleteFunc: func(r *routing.Route) error { delCh <- r; return nil }, + RouteGetFunc: func(net.IP) ([]*routing.Route, error) { return nil, nil }, + RouteByProtocolFunc: func(int) ([]*routing.Route, error) { return nil, nil }, + } + }) + require.NoError(t, err) + t.Cleanup(func() { _ = m.Close() }) + + r := newTestRoute(func(r *routing.Route) { + r.Src = net.IPv4(127, 0, 0, 1) + r.Dst = &net.IPNet{IP: net.IPv4(127, 0, 0, 2), Mask: net.CIDRMask(32, 32)} + }) + require.NoError(t, m.RegisterRoute(r, "lo")) + + var sess *Session + var peer Peer + func() { + m.mu.Lock() + defer m.mu.Unlock() + for p, s := range m.sessions { + peer = p + sess = s + break + } + }() + require.NotNil(t, sess) + + m.HandleRx(&ControlPacket{peerDiscrr: 0, LocalDiscrr: 1234, State: StateDown}, peer) + func() { + sess.mu.Lock() + defer sess.mu.Unlock() + require.Equal(t, StateInit, sess.state) + require.EqualValues(t, 1234, sess.peerDiscr) + }() + + m.HandleRx(&ControlPacket{peerDiscrr: sess.localDiscr, LocalDiscrr: sess.peerDiscr, State: StateInit}, peer) + added := wait(t, addCh, 2*time.Second, "RouteAdd after Up") + require.Equal(t, r.Table, added.Table) + require.Equal(t, r.Src.String(), added.Src.String()) + require.Equal(t, r.Dst.String(), added.Dst.String()) + require.Equal(t, r.NextHop.String(), added.NextHop.String()) + + m.mu.Lock() + require.Len(t, m.sessions, 1) + require.Contains(t, m.sessions, peer) + require.NotContains(t, m.sessions, Peer{Interface: "lo", LocalIP: r.Dst.IP.String(), PeerIP: r.Src.String()}) + require.Equal(t, StateUp, sess.state) + m.mu.Unlock() + + m.HandleRx(&ControlPacket{peerDiscrr: sess.localDiscr, LocalDiscrr: sess.peerDiscr, State: StateDown}, peer) + deleted := wait(t, delCh, 2*time.Second, "RouteDelete after Down") + require.Equal(t, r.Table, deleted.Table) + require.Equal(t, r.Src.String(), deleted.Src.String()) + require.Equal(t, r.Dst.String(), deleted.Dst.String()) + + m.mu.Lock() + require.Len(t, m.sessions, 1) + require.Contains(t, m.sessions, peer) + require.NotContains(t, m.sessions, Peer{Interface: "lo", LocalIP: r.Dst.IP.String(), PeerIP: r.Src.String()}) + require.Equal(t, StateDown, sess.state) + m.mu.Unlock() +} + +func TestClient_LivenessManager_WithdrawRoute_RemovesSessionAndDeletesIfInstalled(t *testing.T) { + t.Parallel() + + addCh := make(chan *routing.Route, 1) + delCh := make(chan *routing.Route, 1) + nlr := &MockRouteReaderWriter{ + RouteAddFunc: func(r *routing.Route) error { addCh <- r; return nil }, + RouteDeleteFunc: func(r *routing.Route) error { delCh <- r; return nil }, + RouteGetFunc: func(net.IP) ([]*routing.Route, error) { return nil, nil }, + RouteByProtocolFunc: func(int) ([]*routing.Route, error) { return nil, nil }, + } + + m, err := newTestManager(t, func(cfg *ManagerConfig) { + cfg.Netlinker = nlr + }) + require.NoError(t, err) + t.Cleanup(func() { _ = m.Close() }) + + r := newTestRoute(func(r *routing.Route) { + r.Dst = 
&net.IPNet{IP: m.LocalAddr().IP, Mask: net.CIDRMask(32, 32)} + r.Src = m.LocalAddr().IP + }) + require.NoError(t, m.RegisterRoute(r, "lo")) + + var peer Peer + var sess *Session + func() { + m.mu.Lock() + defer m.mu.Unlock() + for p, s := range m.sessions { + peer, sess = p, s + break + } + }() + // Down -> Init (learn peerDiscr) + m.HandleRx(&ControlPacket{peerDiscrr: 0, LocalDiscrr: 1, State: StateInit}, peer) + // Init -> Up requires explicit echo (peerDiscrr == localDiscr) + m.HandleRx(&ControlPacket{peerDiscrr: sess.localDiscr, LocalDiscrr: sess.peerDiscr, State: StateInit}, peer) + wait(t, addCh, 2*time.Second, "RouteAdd before withdraw") + + require.NoError(t, m.WithdrawRoute(r, "lo")) + wait(t, delCh, 2*time.Second, "RouteDelete on withdraw") + + m.mu.Lock() + _, still := m.sessions[peer] + m.mu.Unlock() + require.False(t, still, "session should be removed after withdraw") + + sess.mu.Lock() + require.False(t, sess.alive) + sess.mu.Unlock() +} + +func TestClient_LivenessManager_Close_Idempotent(t *testing.T) { + t.Parallel() + m, err := newTestManager(t, func(cfg *ManagerConfig) { + cfg.Netlinker = &MockRouteReaderWriter{} + }) + require.NoError(t, err) + require.NoError(t, m.Close()) + require.NoError(t, m.Close()) +} + +func TestClient_LivenessManager_HandleRx_UnknownPeer_NoEffect(t *testing.T) { + t.Parallel() + + nlr := &MockRouteReaderWriter{ + RouteAddFunc: func(*routing.Route) error { return nil }, + RouteDeleteFunc: func(*routing.Route) error { return nil }, + RouteGetFunc: func(net.IP) ([]*routing.Route, error) { return nil, nil }, + RouteByProtocolFunc: func(int) ([]*routing.Route, error) { return nil, nil }, + } + + m, err := newTestManager(t, func(cfg *ManagerConfig) { + cfg.Netlinker = nlr + }) + require.NoError(t, err) + t.Cleanup(func() { _ = m.Close() }) + + // Register a real session to ensure maps are non-empty. + r := newTestRoute(func(r *routing.Route) { + r.Dst = &net.IPNet{IP: m.LocalAddr().IP, Mask: net.CIDRMask(32, 32)} + r.Src = m.LocalAddr().IP + }) + require.NoError(t, m.RegisterRoute(r, "lo")) + + m.mu.Lock() + prevSessions := len(m.sessions) + prevInstalled := len(m.installed) + m.mu.Unlock() + + // Construct a peer key that doesn't exist. + unknown := Peer{Interface: "lo", LocalIP: "127.0.0.2", PeerIP: "127.0.0.3"} + m.HandleRx(&ControlPacket{peerDiscrr: 0, LocalDiscrr: 1, State: StateInit}, unknown) + + // Assert no changes. + m.mu.Lock() + defer m.mu.Unlock() + require.Equal(t, prevSessions, len(m.sessions)) + require.Equal(t, prevInstalled, len(m.installed)) +} + +func TestClient_LivenessManager_NetlinkerErrors_NoCrash(t *testing.T) { + t.Parallel() + + addErr := errors.New("add boom") + delErr := errors.New("del boom") + nlr := &MockRouteReaderWriter{ + RouteAddFunc: func(*routing.Route) error { return addErr }, + RouteDeleteFunc: func(*routing.Route) error { return delErr }, + RouteGetFunc: func(net.IP) ([]*routing.Route, error) { return nil, nil }, + RouteByProtocolFunc: func(int) ([]*routing.Route, error) { return nil, nil }, + } + + m, err := newTestManager(t, func(cfg *ManagerConfig) { + cfg.Netlinker = nlr + }) + require.NoError(t, err) + t.Cleanup(func() { _ = m.Close() }) + + r := newTestRoute(func(r *routing.Route) { + r.Dst = &net.IPNet{IP: m.LocalAddr().IP, Mask: net.CIDRMask(32, 32)} + r.Src = m.LocalAddr().IP + }) + require.NoError(t, m.RegisterRoute(r, "lo")) + + // Grab session+peer key to inspect installed flags. 
+ var peer Peer + var sess *Session + func() { + m.mu.Lock() + defer m.mu.Unlock() + for p, s := range m.sessions { + peer, sess = p, s + break + } + }() + require.NotNil(t, sess) + + // Drive to Up (RouteAdd returns error but should not crash; installed set true). + m.HandleRx(&ControlPacket{peerDiscrr: 0, LocalDiscrr: 99, State: StateDown}, peer) // Down -> Init + m.HandleRx(&ControlPacket{peerDiscrr: sess.localDiscr, LocalDiscrr: sess.peerDiscr, State: StateUp}, peer) // Init -> Up + + rk := routeKeyFor(peer.Interface, sess.route) + time.Sleep(50 * time.Millisecond) // allow onSessionUp goroutine to run + + m.mu.Lock() + require.True(t, m.installed[rk], "installed should be true after Up even if RouteAdd errored") + m.mu.Unlock() + + // Drive to Down (RouteDelete returns error; should not crash; installed set false). + m.HandleRx(&ControlPacket{peerDiscrr: sess.localDiscr, LocalDiscrr: sess.peerDiscr, State: StateDown}, peer) + time.Sleep(50 * time.Millisecond) + + m.mu.Lock() + require.False(t, m.installed[rk], "installed should be false after Down even if RouteDelete errored") + m.mu.Unlock() +} + +func TestClient_LivenessManager_PassiveMode_ImmediateInstall_NoAutoWithdraw(t *testing.T) { + t.Parallel() + addCh := make(chan *routing.Route, 1) + delCh := make(chan *routing.Route, 1) + m, err := newTestManager(t, func(cfg *ManagerConfig) { + cfg.PassiveMode = true + cfg.Netlinker = &MockRouteReaderWriter{ + RouteAddFunc: func(r *routing.Route) error { addCh <- r; return nil }, + RouteDeleteFunc: func(r *routing.Route) error { delCh <- r; return nil }, + } + }) + require.NoError(t, err) + defer m.Close() + + r := newTestRoute(func(r *routing.Route) { + r.Src = net.IPv4(127, 0, 0, 1) + r.Dst = &net.IPNet{IP: net.IPv4(127, 0, 0, 2), Mask: net.CIDRMask(32, 32)} + }) + require.NoError(t, m.RegisterRoute(r, "lo")) + _ = wait(t, addCh, time.Second, "immediate RouteAdd in PassiveMode") + + // drive Up then Down; expect no RouteDelete (caller owns dataplane) + var peer Peer + var sess *Session + func() { + m.mu.Lock() + defer m.mu.Unlock() + for p, s := range m.sessions { + peer, sess = p, s + break + } + }() + m.HandleRx(&ControlPacket{peerDiscrr: 0, LocalDiscrr: 1, State: StateInit}, peer) + m.HandleRx(&ControlPacket{peerDiscrr: sess.localDiscr, LocalDiscrr: sess.peerDiscr, State: StateUp}, peer) + m.HandleRx(&ControlPacket{peerDiscrr: sess.localDiscr, LocalDiscrr: sess.peerDiscr, State: StateDown}, peer) + + select { + case <-delCh: + t.Fatalf("unexpected RouteDelete in PassiveMode") + case <-time.After(150 * time.Millisecond): + } +} + +func TestClient_LivenessManager_LocalAddrNilAfterClose(t *testing.T) { + t.Parallel() + m, err := newTestManager(t, nil) + require.NoError(t, err) + require.NoError(t, m.Close()) + require.Nil(t, m.LocalAddr()) +} + +func TestClient_LivenessManager_PeerKey_IPv4Canonicalization(t *testing.T) { + t.Parallel() + m, err := newTestManager(t, nil) + require.NoError(t, err) + defer m.Close() + + r := newTestRoute(func(r *routing.Route) { + r.Src = net.IPv4(127, 0, 0, 1) + r.Dst = &net.IPNet{IP: net.IPv4(127, 0, 0, 2), Mask: net.CIDRMask(32, 32)} + }) + require.NoError(t, m.RegisterRoute(r, "lo")) + m.mu.Lock() + _, ok := m.sessions[Peer{Interface: "lo", LocalIP: r.Src.To4().String(), PeerIP: r.Dst.IP.To4().String()}] + m.mu.Unlock() + require.True(t, ok, "peer key should use IPv4 string forms") +} + +func TestClient_Liveness_Manager_ReceiverFailure_PropagatesOnErr(t *testing.T) { + t.Parallel() + m, err := newTestManager(t, nil) + require.NoError(t, err) + defer 
func() { _ = m.Close() }() + + errCh := m.Err() + + // Close the UDP socket directly to force Receiver.Run to error out. + var udp *UDPService + m.mu.Lock() + udp = m.udp + m.mu.Unlock() + require.NotNil(t, udp) + _ = udp.Close() + + // Expect an error to surface on Err(). + select { + case e := <-errCh: + require.Error(t, e) + default: + select { + case e := <-errCh: + require.Error(t, e) + case <-time.After(2 * time.Second): + t.Fatalf("timeout waiting for error from manager.Err after UDP close") + } + } + + // Close should complete cleanly after the receiver failure. + require.NoError(t, m.Close()) +} + +func TestClient_Liveness_Manager_Close_NoErrOnErrCh(t *testing.T) { + t.Parallel() + m, err := newTestManager(t, nil) + require.NoError(t, err) + + // No spurious errors before close. + func() { + timer := time.NewTimer(200 * time.Millisecond) + defer timer.Stop() + select { + case <-timer.C: + return + case <-m.Err(): + t.Fatalf("unexpected error before Close") + } + }() + + require.NoError(t, m.Close()) + + // No spurious errors after close either. + func() { + timer := time.NewTimer(200 * time.Millisecond) + defer timer.Stop() + select { + case <-timer.C: + return + case <-m.Err(): + t.Fatalf("unexpected error after Close") + } + }() +} + +func newTestManager(t *testing.T, mutate func(*ManagerConfig)) (*Manager, error) { + cfg := &ManagerConfig{ + Logger: newTestLogger(t), + Netlinker: &MockRouteReaderWriter{}, + BindIP: "127.0.0.1", + Port: 0, + TxMin: 100 * time.Millisecond, + RxMin: 100 * time.Millisecond, + DetectMult: 3, + MinTxFloor: 50 * time.Millisecond, + MaxTxCeil: 1 * time.Second, + BackoffMax: 1 * time.Second, + } + if mutate != nil { + mutate(cfg) + } + return NewManager(t.Context(), cfg) +} + +type testWriter struct { + t *testing.T + mu sync.Mutex +} + +func (w *testWriter) Write(p []byte) (int, error) { + w.mu.Lock() + defer w.mu.Unlock() + w.t.Logf("%s", p) + return len(p), nil +} + +func newTestLogger(t *testing.T) *slog.Logger { + w := &testWriter{t: t} + h := slog.NewTextHandler(w, &slog.HandlerOptions{Level: slog.LevelInfo}) + return slog.New(h) +} + +func wait[T any](t *testing.T, ch <-chan T, d time.Duration, name string) T { + t.Helper() + select { + case v := <-ch: + return v + case <-time.After(d): + t.Fatalf("timeout waiting for %s", name) + var z T + return z + } +} + +func newTestRoute(mutate func(*routing.Route)) *routing.Route { + r := &routing.Route{ + Table: 100, + Src: net.IPv4(10, 4, 0, 1), + Dst: &net.IPNet{IP: net.IPv4(10, 4, 0, 11), Mask: net.CIDRMask(32, 32)}, + NextHop: net.IPv4(10, 5, 0, 1), + Protocol: unix.RTPROT_BGP, + } + if mutate != nil { + mutate(r) + } + return r +} + +type MockRouteReaderWriter struct { + RouteAddFunc func(*routing.Route) error + RouteDeleteFunc func(*routing.Route) error + RouteGetFunc func(net.IP) ([]*routing.Route, error) + RouteByProtocolFunc func(int) ([]*routing.Route, error) + + mu sync.Mutex +} + +func (m *MockRouteReaderWriter) RouteAdd(r *routing.Route) error { + m.mu.Lock() + defer m.mu.Unlock() + if m.RouteAddFunc == nil { + return nil + } + return m.RouteAddFunc(r) +} + +func (m *MockRouteReaderWriter) RouteDelete(r *routing.Route) error { + m.mu.Lock() + defer m.mu.Unlock() + if m.RouteDeleteFunc == nil { + return nil + } + return m.RouteDeleteFunc(r) +} + +func (m *MockRouteReaderWriter) RouteGet(ip net.IP) ([]*routing.Route, error) { + m.mu.Lock() + defer m.mu.Unlock() + if m.RouteGetFunc == nil { + return nil, nil + } + return m.RouteGetFunc(ip) +} + +func (m *MockRouteReaderWriter) 
RouteByProtocol(protocol int) ([]*routing.Route, error) { + m.mu.Lock() + defer m.mu.Unlock() + if m.RouteByProtocolFunc == nil { + return nil, nil + } + return m.RouteByProtocolFunc(protocol) +} diff --git a/client/doublezerod/internal/liveness/packet.go b/client/doublezerod/internal/liveness/packet.go new file mode 100644 index 000000000..6c2e25aef --- /dev/null +++ b/client/doublezerod/internal/liveness/packet.go @@ -0,0 +1,94 @@ +package liveness + +import ( + "encoding/binary" + "fmt" +) + +// State encodes the finite-state machine for a BFD-like session. +// The progression follows AdminDown → Down → Init → Up, with +// transitions driven by control messages and detect timeouts. +type State uint8 + +const ( + StateAdminDown State = iota // administratively disabled, no detection + StateDown // no session detected or timed out + StateInit // attempting to establish connectivity + StateUp // session fully established +) + +// ControlPacket represents the wire format of a minimal BFD control packet. +// Fields mirror RFC 5880 §4.1 in a compact form using microsecond units for timers. +type ControlPacket struct { + Version uint8 // protocol version; expected to be 1 + State State // sender's current session state + DetectMult uint8 // detection multiplier (used by peer for detect timeout) + Length uint8 // total length, always 40 for this fixed-size implementation + LocalDiscrr uint32 // sender's discriminator (unique session ID) + peerDiscrr uint32 // discriminator of the remote session (echo back) + DesiredMinTxUs uint32 // minimum TX interval desired by sender (microseconds) + RequiredMinRxUs uint32 // minimum RX interval the sender can handle (microseconds) +} + +// Marshal serializes a ControlPacket into its fixed 40-byte wire format. +// +// Field layout (Big Endian): +// +// 0: Version (3 high bits) | 5 bits unused (zero) +// 1: State (2 high bits) | 6 bits unused (zero) +// 2: DetectMult +// 3: Length (always 40) +// 4–7: LocalDiscrr +// 8–11: peerDiscrr +// +// 12–15: DesiredMinTxUs +// 16–19: RequiredMinRxUs +// 20–39: zero padding (unused / reserved) +// +// Only a subset of the full BFD header is implemented; authentication and +// optional fields are omitted for simplicity. +func (c *ControlPacket) Marshal() []byte { + b := make([]byte, 40) + // Version (3 bits) and State (2 bits in high order of next byte) + vd := (c.Version & 0x7) << 5 + sf := (uint8(c.State) & 0x3) << 6 + b[0], b[1], b[2], b[3] = vd, sf, c.DetectMult, 40 + be := binary.BigEndian + be.PutUint32(b[4:8], c.LocalDiscrr) + be.PutUint32(b[8:12], c.peerDiscrr) + be.PutUint32(b[12:16], c.DesiredMinTxUs) + be.PutUint32(b[16:20], c.RequiredMinRxUs) + // Remaining bytes [20:40] are reserved/padding → left zeroed + return b +} + +// UnmarshalControlPacket parses a 40-byte control message from the wire +// into a ControlPacket. It validates the version and length fields and +// extracts all header values using big-endian order. 
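+// Inputs longer than 40 bytes are accepted; only the first 40 bytes are read,
+// and any trailing data is ignored.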
+func UnmarshalControlPacket(b []byte) (*ControlPacket, error) { + if len(b) < 40 { + return nil, fmt.Errorf("short packet") + } + if b[3] != 40 { + return nil, fmt.Errorf("invalid length") + } + vd, sf := b[0], b[1] + ver := (vd >> 5) & 0x7 + if ver != 1 { + return nil, fmt.Errorf("unsupported version: %d", ver) + } + + c := &ControlPacket{ + Version: ver, + State: State((sf >> 6) & 0x3), + DetectMult: b[2], + Length: b[3], + } + + rd := func(off int) uint32 { return binary.BigEndian.Uint32(b[off : off+4]) } + c.LocalDiscrr = rd(4) + c.peerDiscrr = rd(8) + c.DesiredMinTxUs = rd(12) + c.RequiredMinRxUs = rd(16) + return c, nil +} diff --git a/client/doublezerod/internal/liveness/packet_fuzz_test.go b/client/doublezerod/internal/liveness/packet_fuzz_test.go new file mode 100644 index 000000000..1d03768a8 --- /dev/null +++ b/client/doublezerod/internal/liveness/packet_fuzz_test.go @@ -0,0 +1,13 @@ +package liveness + +import "testing" + +func FuzzClient_Liveness_Packet_Unmarshal_NoPanic(f *testing.F) { + f.Add(make([]byte, 40)) + f.Fuzz(func(t *testing.T, b []byte) { + if len(b) < 40 { + b = append(b, make([]byte, 40-len(b))...) + } + _, _ = UnmarshalControlPacket(b[:40]) + }) +} diff --git a/client/doublezerod/internal/liveness/packet_test.go b/client/doublezerod/internal/liveness/packet_test.go new file mode 100644 index 000000000..f749adad6 --- /dev/null +++ b/client/doublezerod/internal/liveness/packet_test.go @@ -0,0 +1,110 @@ +package liveness + +import ( + "bytes" + "encoding/binary" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestClient_Liveness_Packet_MarshalEncodesHeaderAndFields(t *testing.T) { + t.Parallel() + cp := &ControlPacket{ + Version: 5, + State: StateUp, + DetectMult: 3, + LocalDiscrr: 0x11223344, + peerDiscrr: 0x55667788, + DesiredMinTxUs: 0x01020304, + RequiredMinRxUs: 0x0A0B0C0D, + } + + b := cp.Marshal() + require.Len(t, b, 40) + require.Equal(t, uint8(40), b[3]) + require.Equal(t, uint8((5&0x7)<<5), b[0]) + require.Equal(t, uint8((uint8(StateUp)&0x3)<<6), b[1]) + require.Equal(t, uint8(3), b[2]) + + require.Equal(t, uint32(0x11223344), binary.BigEndian.Uint32(b[4:8])) + require.Equal(t, uint32(0x55667788), binary.BigEndian.Uint32(b[8:12])) + require.Equal(t, uint32(0x01020304), binary.BigEndian.Uint32(b[12:16])) + require.Equal(t, uint32(0x0A0B0C0D), binary.BigEndian.Uint32(b[16:20])) + + require.True(t, bytes.Equal(b[20:40], make([]byte, 20))) +} + +func TestClient_Liveness_Packet_UnmarshalRoundTrip(t *testing.T) { + t.Parallel() + orig := &ControlPacket{ + Version: 1, + State: StateInit, + DetectMult: 7, + LocalDiscrr: 1, + peerDiscrr: 2, + DesiredMinTxUs: 3, + RequiredMinRxUs: 4, + } + b := orig.Marshal() + got, err := UnmarshalControlPacket(b) + require.NoError(t, err) + + require.Equal(t, uint8(1), got.Version) + require.Equal(t, StateInit, got.State) + require.Equal(t, uint8(7), got.DetectMult) + require.Equal(t, uint8(40), got.Length) + require.Equal(t, uint32(1), got.LocalDiscrr) + require.Equal(t, uint32(2), got.peerDiscrr) + require.Equal(t, uint32(3), got.DesiredMinTxUs) + require.Equal(t, uint32(4), got.RequiredMinRxUs) +} + +func TestClient_Liveness_Packet_UnmarshalShort(t *testing.T) { + t.Parallel() + _, err := UnmarshalControlPacket(make([]byte, 39)) + require.EqualError(t, err, "short packet") +} + +func TestClient_Liveness_Packet_UnmarshalBadLength(t *testing.T) { + t.Parallel() + cp := (&ControlPacket{Version: 1}).Marshal() + cp[3] = 99 + _, err := UnmarshalControlPacket(cp) + require.EqualError(t, err, "invalid length") +} + 
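+// For reference, a packet with Version=1, State=Up (3), DetectMult=3 marshals
+// with b[0]=0x20 (1<<5), b[1]=0xC0 (3<<6), b[2]=0x03, b[3]=0x28 (length 40),
+// followed by the four big-endian uint32 fields and 20 bytes of zero padding.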
+func TestClient_Liveness_Packet_BitMaskingVersionAndState_MarshalOnly(t *testing.T) { + t.Parallel() + cp := &ControlPacket{ + Version: 0xFF, + State: State(7), + DetectMult: 1, + } + b := cp.Marshal() + require.Equal(t, uint8(0xE0), b[0]) + require.Equal(t, uint8(0xC0), b[1]) +} + +func TestClient_Liveness_Packet_UnmarshalUnsupportedVersion(t *testing.T) { + t.Parallel() + cp := (&ControlPacket{Version: 7, State: StateUp, DetectMult: 1}).Marshal() + _, err := UnmarshalControlPacket(cp) + require.EqualError(t, err, "unsupported version: 7") +} + +func TestClient_Liveness_Packet_UnmarshalStateMaskWithV1(t *testing.T) { + t.Parallel() + cp := (&ControlPacket{Version: 1, State: State(7), DetectMult: 1}).Marshal() + got, err := UnmarshalControlPacket(cp) + require.NoError(t, err) + require.Equal(t, uint8(1), got.Version) + require.Equal(t, StateUp, got.State) // state masked to 2 bits +} + +func TestClient_Liveness_Packet_PaddingRemainsZero(t *testing.T) { + t.Parallel() + cp := &ControlPacket{Version: 3, State: StateDown, DetectMult: 5} + b := cp.Marshal() + require.True(t, bytes.Equal(b[20:], make([]byte, 20))) +} diff --git a/client/doublezerod/internal/liveness/receiver.go b/client/doublezerod/internal/liveness/receiver.go new file mode 100644 index 000000000..ac39e15a2 --- /dev/null +++ b/client/doublezerod/internal/liveness/receiver.go @@ -0,0 +1,183 @@ +package liveness + +import ( + "context" + "errors" + "fmt" + "log/slog" + "net" + "sync" + "syscall" + "time" +) + +// Receiver is a long-lived goroutine that continuously reads UDP control packets +// from the shared transport socket and passes valid ones to a handler. +// +// It abstracts read-loop robustness: manages deadlines, throttles noisy logs, +// detects fatal network conditions, and honors context cancellation cleanly. +type Receiver struct { + log *slog.Logger // structured logger for debug and warnings + udp *UDPService // underlying socket with control message support + handleRx HandleRxFunc // callback invoked for each valid ControlPacket + + readErrWarnEvery time.Duration // min interval between repeated read warnings + readErrWarnLast time.Time // last time a warning was logged + readErrWarnMu sync.Mutex // guards readErrWarnLast +} + +// HandleRxFunc defines the handler signature for received control packets. +// The callback is invoked for every successfully decoded ControlPacket, +// along with a Peer descriptor identifying interface and IP context. +type HandleRxFunc func(pkt *ControlPacket, peer Peer) + +// NewReceiver constructs a new Receiver bound to the given UDPService and handler. +// By default, it throttles repeated read errors to once every 5 seconds. +func NewReceiver(log *slog.Logger, udp *UDPService, handleRx HandleRxFunc) *Receiver { + return &Receiver{ + log: log, + udp: udp, + handleRx: handleRx, + readErrWarnEvery: 5 * time.Second, + } +} + +// Run executes the main receive loop until ctx is canceled or the socket fails. +// It continually reads packets, unmarshals them into ControlPackets, and passes +// them to handleRx. Errors are rate-limited and fatal errors terminate the loop. +func (r *Receiver) Run(ctx context.Context) error { + r.log.Debug("liveness.recv: rx loop started") + buf := make([]byte, 1500) // typical MTU-sized buffer + + for { + // Early exit if caller canceled context. 
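+		// This poll is non-blocking; shutdown latency is instead bounded by
+		// the 500ms read deadline set below.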
+ select { + case <-ctx.Done(): + r.log.Debug("liveness.recv: rx loop stopped by context done", "reason", ctx.Err()) + return nil + default: + } + + // Periodically set a read deadline to make the loop interruptible. + if err := r.udp.SetReadDeadline(time.Now().Add(500 * time.Millisecond)); err != nil { + // Respect cancellation immediately if already stopped. + select { + case <-ctx.Done(): + r.log.Debug("liveness.recv: rx loop stopped by context done", "reason", ctx.Err()) + return nil + default: + } + if errors.Is(err, net.ErrClosed) { + r.log.Debug("liveness.recv: socket closed during SetReadDeadline; exiting") + return fmt.Errorf("socket closed during SetReadDeadline: %w", err) + } + + // Log throttled warnings for transient errors (e.g., bad FD state). + now := time.Now() + r.readErrWarnMu.Lock() + if r.readErrWarnLast.IsZero() || now.Sub(r.readErrWarnLast) >= r.readErrWarnEvery { + r.readErrWarnLast = now + r.readErrWarnMu.Unlock() + r.log.Warn("liveness.recv: SetReadDeadline error", "error", err) + } else { + r.readErrWarnMu.Unlock() + } + + // Exit for fatal kernel or network-level errors. + if isFatalNetErr(err) { + return fmt.Errorf("fatal network error during SetReadDeadline: %w", err) + } + + // Brief delay prevents a tight loop in persistent error states. + time.Sleep(50 * time.Millisecond) + continue + } + + // Perform the actual UDP read with control message extraction. + n, peerAddr, localIP, ifname, err := r.udp.ReadFrom(buf) + if err != nil { + // Stop cleanly on context cancellation. + select { + case <-ctx.Done(): + r.log.Debug("liveness.recv: rx loop stopped by context done", "reason", ctx.Err()) + return nil + default: + } + + // Deadline timeout: simply continue polling. + if ne, ok := err.(net.Error); ok && ne.Timeout() { + continue + } + + // Closed socket: terminate immediately. + if errors.Is(err, net.ErrClosed) { + r.log.Debug("liveness.recv: socket closed; exiting") + return fmt.Errorf("socket closed during ReadFrom: %w", err) + } + + // Log other transient read errors, throttled. + now := time.Now() + r.readErrWarnMu.Lock() + if r.readErrWarnLast.IsZero() || now.Sub(r.readErrWarnLast) >= r.readErrWarnEvery { + r.readErrWarnLast = now + r.readErrWarnMu.Unlock() + r.log.Warn("liveness.recv: non-timeout read error", "error", err) + } else { + r.readErrWarnMu.Unlock() + } + + if isFatalNetErr(err) { + return fmt.Errorf("fatal network error during ReadFrom: %w", err) + } + continue + } + + // Attempt to parse the received packet into a ControlPacket struct. + ctrl, err := UnmarshalControlPacket(buf[:n]) + if err != nil { + r.log.Error("liveness.recv: error parsing control packet", "error", err) + continue + } + + // Skip packets that are not IPv4. + if localIP.To4() == nil || peerAddr.IP.To4() == nil { + continue + } + + // Populate the peer descriptor: identifies which local interface/IP + // the packet arrived on and the remote endpoint that sent it. + peer := Peer{ + Interface: ifname, + LocalIP: localIP.To4().String(), + PeerIP: peerAddr.IP.To4().String(), + } + + // Delegate to session or higher-level handler for processing. + r.handleRx(ctrl, peer) + } +} + +// isFatalNetErr determines whether a network-related error is non-recoverable. +// It checks for known fatal errno codes and unwraps platform-specific net errors. +func isFatalNetErr(err error) bool { + // Closed socket explicitly fatal. + if errors.Is(err, net.ErrClosed) { + return true + } + + // Inspect underlying syscall errno for hardware or interface removal. 
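+	// EBADF indicates a stale descriptor, ENETDOWN an interface that went
+	// down, and ENODEV/ENXIO a device that no longer exists.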
+ var se syscall.Errno + if errors.As(err, &se) { + switch se { + case syscall.EBADF, syscall.ENETDOWN, syscall.ENODEV, syscall.ENXIO: + return true + } + } + + // On some systems, fatal syscall errors are wrapped in *net.OpError. + var oe *net.OpError + if errors.As(err, &oe) && !oe.Timeout() && !oe.Temporary() { + return true + } + return false +} diff --git a/client/doublezerod/internal/liveness/receiver_test.go b/client/doublezerod/internal/liveness/receiver_test.go new file mode 100644 index 000000000..021bb2f9b --- /dev/null +++ b/client/doublezerod/internal/liveness/receiver_test.go @@ -0,0 +1,162 @@ +package liveness + +import ( + "context" + "net" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestClient_Liveness_Receiver_CancelStopsLoop(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + + udp, err := ListenUDP("127.0.0.1", 0) + require.NoError(t, err) + defer udp.Close() + + rx := NewReceiver(newTestLogger(t), udp, func(*ControlPacket, Peer) {}) + + done := make(chan struct{}) + go func() { + err := rx.Run(ctx) + require.NoError(t, err) + close(done) + }() + + // Nudge the loop to ensure it has started by forcing one deadline cycle. + time.Sleep(50 * time.Millisecond) + + // Cancel and close to unblock any in-flight ReadFrom immediately. + cancel() + _ = udp.Close() + + require.Eventually(t, func() bool { + select { + case <-done: + return true + default: + return false + } + }, 3*time.Second, 25*time.Millisecond, "receiver did not exit after cancel+close") +} + +func TestClient_Liveness_Receiver_IgnoresMalformedPacket(t *testing.T) { + t.Parallel() + + udp, err := ListenUDP("127.0.0.1", 0) + require.NoError(t, err) + defer udp.Close() + + var calls int32 + rx := NewReceiver(newTestLogger(t), udp, func(*ControlPacket, Peer) { + atomic.AddInt32(&calls, 1) + }) + + ctx, cancel := context.WithCancel(t.Context()) + done := make(chan struct{}) + go func() { + err := rx.Run(ctx) + require.NoError(t, err) + close(done) + }() + + // Ensure loop is running: send malformed (<40 bytes) + cl, err := net.DialUDP("udp4", nil, udp.LocalAddr().(*net.UDPAddr)) + require.NoError(t, err) + _, err = cl.Write(make([]byte, 20)) + require.NoError(t, err) + _ = cl.Close() + + time.Sleep(25 * time.Millisecond) // tiny nudge + + // Cancel, then close socket to force immediate unblock + cancel() + _ = udp.Close() + + require.Eventually(t, func() bool { + select { + case <-done: + return true + default: + return false + } + }, 5*time.Second, 100*time.Millisecond, "receiver did not exit after cancel+close") + + require.Equal(t, int32(0), atomic.LoadInt32(&calls)) +} + +func TestClient_Liveness_Receiver_HandlerInvoked_WithPeerContext(t *testing.T) { + t.Parallel() + udp, err := ListenUDP("127.0.0.1", 0) + require.NoError(t, err) + defer udp.Close() + + var got Peer + calls := int32(0) + rx := NewReceiver(newTestLogger(t), udp, func(cp *ControlPacket, p Peer) { got = p; atomic.AddInt32(&calls, 1) }) + + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + done := make(chan struct{}) + go func() { require.NoError(t, rx.Run(ctx)); close(done) }() + + // send a valid control packet + cl, err := net.DialUDP("udp4", nil, udp.LocalAddr().(*net.UDPAddr)) + require.NoError(t, err) + defer cl.Close() + pkt := (&ControlPacket{Version: 1, State: StateInit, DetectMult: 1, Length: 40}).Marshal() + _, err = cl.Write(pkt) + require.NoError(t, err) + + require.Eventually(t, func() bool { return atomic.LoadInt32(&calls) == 
1 }, time.Second, 10*time.Millisecond) + require.NotEmpty(t, got.Interface) // usually lo/lo0 + require.Equal(t, "127.0.0.1", got.LocalIP) + require.Equal(t, "127.0.0.1", got.PeerIP) + + cancel() + _ = udp.Close() + <-done +} + +func TestClient_Liveness_Receiver_DeadlineTimeoutsAreSilent(t *testing.T) { + t.Parallel() + udp, err := ListenUDP("127.0.0.1", 0) + require.NoError(t, err) + defer udp.Close() + + rx := NewReceiver(newTestLogger(t), udp, func(*ControlPacket, Peer) {}) + + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + done := make(chan struct{}) + go func() { require.NoError(t, rx.Run(ctx)); close(done) }() + + // no traffic; ensure loop keeps running past a few deadlines + time.Sleep(600 * time.Millisecond) + cancel() + _ = udp.Close() + <-done +} + +func TestClient_Liveness_Receiver_SocketClosed_ReturnsError(t *testing.T) { + t.Parallel() + ctx := context.Background() + udp, err := ListenUDP("127.0.0.1", 0) + require.NoError(t, err) + + rx := NewReceiver(newTestLogger(t), udp, func(*ControlPacket, Peer) {}) + errCh := make(chan error, 1) + go func() { errCh <- rx.Run(ctx) }() + + time.Sleep(50 * time.Millisecond) + _ = udp.Close() + err = <-errCh + require.Error(t, err) + require.Contains(t, err.Error(), "socket closed") +} diff --git a/client/doublezerod/internal/liveness/routerw.go b/client/doublezerod/internal/liveness/routerw.go new file mode 100644 index 000000000..e9e37b9db --- /dev/null +++ b/client/doublezerod/internal/liveness/routerw.go @@ -0,0 +1,55 @@ +package liveness + +import ( + "github.com/malbeclabs/doublezero/client/doublezerod/internal/routing" +) + +// RouteReaderWriter is the minimal interface for interacting with the routing +// backend. It allows adding, deleting, and listing routes by protocol. +// The BGP plugin uses this to interact with the kernel routing table through +// the liveness subsystem, without depending on its internal implementation. +type RouteReaderWriter interface { + RouteAdd(*routing.Route) error + RouteDelete(*routing.Route) error + RouteByProtocol(int) ([]*routing.Route, error) +} + +// routeReaderWriter is an interface-specific adapter that connects a single +// network interface (iface) to the liveness Manager. It is typically created +// by the BGP plugin so that each managed interface has its own scoped view +// of route registration and withdrawal through the Manager. +type routeReaderWriter struct { + lm *Manager // liveness manager handling route lifecycle + rrw RouteReaderWriter // underlying netlink backend + iface string // interface name associated with these routes +} + +// NewRouteReaderWriter creates an interface-scoped RouteReaderWriter that +// wraps the liveness Manager and a concrete routing backend. This allows the +// BGP plugin to use standard routing calls while the Manager tracks route +// liveness on a per-interface basis. +func NewRouteReaderWriter(lm *Manager, rrw RouteReaderWriter, iface string) *routeReaderWriter { + return &routeReaderWriter{ + lm: lm, + rrw: rrw, + iface: iface, + } +} + +// RouteAdd registers the route with the liveness Manager for the given iface, +// enabling the Manager to monitor reachability before installation. +func (m *routeReaderWriter) RouteAdd(r *routing.Route) error { + return m.lm.RegisterRoute(r, m.iface) +} + +// RouteDelete withdraws the route and removes it from liveness tracking for +// the associated interface. 
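+// Like RouteAdd, this goes through the liveness Manager rather than calling
+// the routing backend directly.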
+func (m *routeReaderWriter) RouteDelete(r *routing.Route) error {
+	return m.lm.WithdrawRoute(r, m.iface)
+}
+
+// RouteByProtocol delegates to the underlying backend to list routes by
+// protocol ID without involving the Manager.
+func (m *routeReaderWriter) RouteByProtocol(protocol int) ([]*routing.Route, error) {
+	return m.rrw.RouteByProtocol(protocol)
+}
diff --git a/client/doublezerod/internal/liveness/scheduler.go b/client/doublezerod/internal/liveness/scheduler.go
new file mode 100644
index 000000000..e5d734cfb
--- /dev/null
+++ b/client/doublezerod/internal/liveness/scheduler.go
@@ -0,0 +1,359 @@
+package liveness
+
+import (
+	"container/heap"
+	"context"
+	"log/slog"
+	"net"
+	"sync"
+	"time"
+)
+
+// eventType distinguishes between scheduled transmit (TX) and detect-timeout (Detect) events.
+type eventType uint8
+
+const (
+	eventTypeTX     eventType = 1 // transmit control packet
+	eventTypeDetect eventType = 2 // detect timeout check
+)
+
+// event represents a single scheduled action tied to a session.
+// Each event is timestamped and sequence-numbered to ensure stable ordering in the heap.
+type event struct {
+	when      time.Time // time when the event should fire
+	eventType eventType // type of event (TX or Detect)
+	session   *Session  // owning session
+	seq       uint64    // sequence number for deterministic ordering
+}
+
+// EventQueue is a thread-safe priority queue of scheduled events.
+// It supports pushing events and popping those whose time has come (or is nearest).
+type EventQueue struct {
+	mu  sync.Mutex
+	pq  eventHeap // min-heap of events ordered by time then seq
+	seq uint64    // global sequence counter for tie-breaking
+}
+
+// NewEventQueue constructs an initialized empty heap-based event queue.
+func NewEventQueue() *EventQueue {
+	h := eventHeap{}
+	heap.Init(&h)
+	return &EventQueue{pq: h}
+}
+
+// Push inserts a new event into the queue and assigns it a sequence number.
+// Later events with identical timestamps are ordered by insertion.
+func (q *EventQueue) Push(e *event) {
+	q.mu.Lock()
+	q.seq++
+	e.seq = q.seq
+	heap.Push(&q.pq, e)
+	q.mu.Unlock()
+}
+
+// Pop removes and returns the next (earliest) event from the queue, or nil if empty.
+func (q *EventQueue) Pop() *event {
+	q.mu.Lock()
+	if q.pq.Len() == 0 {
+		q.mu.Unlock()
+		return nil
+	}
+	ev := heap.Pop(&q.pq).(*event)
+	q.mu.Unlock()
+	return ev
+}
+
+// PopIfDue returns the next event if its scheduled time is due (<= now).
+// Otherwise, it returns nil and the duration until the next event’s time,
+// allowing the caller to sleep until that deadline. If the queue is empty,
+// it returns nil with a short default wait of 10ms.
+func (q *EventQueue) PopIfDue(now time.Time) (*event, time.Duration) {
+	q.mu.Lock()
+	if q.pq.Len() == 0 {
+		q.mu.Unlock()
+		return nil, 10 * time.Millisecond
+	}
+	ev := q.pq[0]
+	if d := ev.when.Sub(now); d > 0 {
+		q.mu.Unlock()
+		return nil, d
+	}
+	ev = heap.Pop(&q.pq).(*event)
+	q.mu.Unlock()
+	return ev, 0
+}
+
+// CountFor returns the number of events in the queue for a given interface and local IP.
+func (q *EventQueue) CountFor(iface, localIP string) int {
+	q.mu.Lock()
+	defer q.mu.Unlock()
+	c := 0
+	for _, ev := range q.pq {
+		if ev != nil && ev.session != nil && ev.session.peer != nil {
+			p := ev.session.peer
+			if p.Interface == iface && p.LocalIP == localIP {
+				c++
+			}
+		}
+	}
+	return c
+}
+
+// Len returns the total number of events in the queue.
+func (q *EventQueue) Len() int {
+	q.mu.Lock()
+	defer q.mu.Unlock()
+	return q.pq.Len()
+}
+
+// eventHeap implements heap.Interface for event scheduling by time then seq.
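+// It is a min-heap: Less orders primarily by when, with seq breaking ties so
+// events scheduled for the same instant fire in insertion order.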
+type eventHeap []*event + +func (h eventHeap) Len() int { return len(h) } + +func (h eventHeap) Less(i, j int) bool { + if h[i].when.Equal(h[j].when) { + return h[i].seq < h[j].seq + } + return h[i].when.Before(h[j].when) +} + +func (h eventHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] } + +func (h *eventHeap) Push(x any) { *h = append(*h, x.(*event)) } +func (h *eventHeap) Pop() any { + old := *h + n := len(old) + x := old[n-1] + *h = old[:n-1] + return x +} + +// Scheduler drives session state progression and control message exchange. +// It runs a single event loop that processes transmit (TX) and detect events across sessions. +// New sessions schedule TX immediately; detect is armed/re-armed after valid RX during Init/Up. +type Scheduler struct { + log *slog.Logger // structured logger for observability + udp *UDPService // shared UDP transport for all sessions + onSessionDown func(s *Session) // callback invoked when a session transitions to Down + eq *EventQueue // global time-ordered event queue + maxEvents int // 0 = unlimited + + writeErrWarnEvery time.Duration // min interval between repeated write warnings + writeErrWarnLast time.Time // last time a warning was logged + writeErrWarnMu sync.Mutex // guards writeErrWarnLast +} + +// NewScheduler constructs a Scheduler bound to a UDP transport and logger. +// onSessionDown is called asynchronously whenever a session is detected as failed. +func NewScheduler(log *slog.Logger, udp *UDPService, onSessionDown func(s *Session), maxEvents int) *Scheduler { + eq := NewEventQueue() + return &Scheduler{ + log: log, + udp: udp, + onSessionDown: onSessionDown, + eq: eq, + writeErrWarnEvery: 5 * time.Second, + maxEvents: maxEvents, + } +} + +// Run executes the scheduler’s main loop until ctx is canceled. +// It continuously pops and processes due events, sleeping until the next one if necessary. +// Each TX event sends a control packet and re-schedules the next TX; +// each Detect event checks for timeout and invokes onSessionDown if expired. +func (s *Scheduler) Run(ctx context.Context) error { + s.log.Debug("liveness.scheduler: tx loop started") + + t := time.NewTimer(time.Hour) + defer t.Stop() + + for { + select { + case <-ctx.Done(): + s.log.Debug("liveness.scheduler: stopped by context done", "reason", ctx.Err()) + return nil + default: + } + + now := time.Now() + ev, wait := s.eq.PopIfDue(now) + if ev == nil { + // No due events — sleep until next one or timeout. + if wait <= 0 { + wait = 10 * time.Millisecond + } + if !t.Stop() { + select { + case <-t.C: + default: + } + } + t.Reset(wait) + select { + case <-ctx.Done(): + s.log.Debug("liveness.scheduler: stopped by context done", "reason", ctx.Err()) + return nil + case <-t.C: + continue + } + } + + switch ev.eventType { + case eventTypeTX: + ev.session.mu.Lock() + if ev.when.Equal(ev.session.nextTxScheduled) { + ev.session.nextTxScheduled = time.Time{} + } + ev.session.mu.Unlock() + s.doTX(ev.session) + s.scheduleTx(time.Now(), ev.session) + case eventTypeDetect: + // drop stale detect events + ev.session.mu.Lock() + if !ev.when.Equal(ev.session.detectDeadline) { + if ev.when.Equal(ev.session.nextDetectScheduled) { + ev.session.nextDetectScheduled = time.Time{} + } + ev.session.mu.Unlock() + continue + } + if ev.when.Equal(ev.session.nextDetectScheduled) { + ev.session.nextDetectScheduled = time.Time{} + } + ev.session.mu.Unlock() + + if s.tryExpire(ev.session) { + // Expiration triggers asynchronous session-down handling. 
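+				// The callback runs on its own goroutine so a slow handler
+				// (e.g. route withdrawal) cannot stall the event loop.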
+				go s.onSessionDown(ev.session)
+				continue
+			}
+			// Still active; re-arm detect timer for next interval.
+			ev.session.mu.Lock()
+			st := ev.session.state
+			ev.session.mu.Unlock()
+			if st == StateUp || st == StateInit {
+				s.scheduleDetect(time.Now(), ev.session)
+			}
+		}
+	}
+}
+
+func (s *Scheduler) maybeDropOnOverflow(et eventType) bool {
+	if s.maxEvents <= 0 {
+		return false
+	}
+	if s.eq.Len() < s.maxEvents {
+		return false
+	}
+	if et == eventTypeTX {
+		// never drop TX
+		return false
+	}
+	return true
+}
+
+// scheduleTx schedules the next transmit event for the given session.
+// Skips sessions that are not alive or are AdminDown; backoff is handled by ComputeNextTx.
+func (s *Scheduler) scheduleTx(now time.Time, sess *Session) {
+	// If TX already scheduled, bail without recomputing.
+	sess.mu.Lock()
+	if !sess.alive || sess.state == StateAdminDown || !sess.nextTxScheduled.IsZero() {
+		sess.mu.Unlock()
+		return
+	}
+	sess.mu.Unlock()
+
+	// Compute next (locks internally, updates sess.nextTx)
+	next := sess.ComputeNextTx(now, nil)
+
+	// Publish the scheduled marker (re-check in case of race).
+	sess.mu.Lock()
+	if !sess.alive || sess.state == StateAdminDown || !sess.nextTxScheduled.IsZero() {
+		sess.mu.Unlock()
+		return
+	}
+	sess.nextTxScheduled = next
+	sess.mu.Unlock()
+
+	s.eq.Push(&event{when: next, eventType: eventTypeTX, session: sess})
+}
+
+// scheduleDetect arms or re-arms a session’s detection timer and enqueues a detect event.
+// If the session is not alive or lacks a valid deadline, nothing is scheduled.
+func (s *Scheduler) scheduleDetect(now time.Time, sess *Session) {
+	ddl, ok := sess.ArmDetect(now)
+	if !ok {
+		return
+	}
+
+	sess.mu.Lock()
+	if sess.nextDetectScheduled.Equal(ddl) {
+		sess.mu.Unlock()
+		return // already scheduled for this exact deadline
+	}
+	sess.nextDetectScheduled = ddl
+	sess.mu.Unlock()
+
+	if s.maybeDropOnOverflow(eventTypeDetect) {
+		// undo marker since we didn’t enqueue
+		sess.mu.Lock()
+		if sess.nextDetectScheduled.Equal(ddl) {
+			sess.nextDetectScheduled = time.Time{}
+		}
+		sess.mu.Unlock()
+		return
+	}
+
+	s.eq.Push(&event{when: ddl, eventType: eventTypeDetect, session: sess})
+}
+
+// doTX builds and transmits a ControlPacket representing the session’s current state.
+// It reads protected fields under lock, serializes the packet, and sends via UDPService.
+// Transient send errors are logged as throttled warnings.
+func (s *Scheduler) doTX(sess *Session) {
+	sess.mu.Lock()
+	if !sess.alive || sess.state == StateAdminDown {
+		sess.mu.Unlock()
+		return
+	}
+	pkt := (&ControlPacket{
+		Version:         1,
+		State:           sess.state,
+		DetectMult:      sess.detectMult,
+		Length:          40,
+		LocalDiscrr:     sess.localDiscr,
+		peerDiscrr:      sess.peerDiscr,
+		DesiredMinTxUs:  uint32(sess.localTxMin / time.Microsecond),
+		RequiredMinRxUs: uint32(sess.localRxMin / time.Microsecond),
+	}).Marshal()
+	sess.mu.Unlock()
+	src := net.IP(nil)
+	if sess.route != nil {
+		src = sess.route.Src
+	}
+	_, err := s.udp.WriteTo(pkt, sess.peerAddr, sess.peer.Interface, src)
+	if err != nil {
+		// Log throttled warnings for transient errors (e.g., bad FD state).
+		now := time.Now()
+		s.writeErrWarnMu.Lock()
+		if s.writeErrWarnLast.IsZero() || now.Sub(s.writeErrWarnLast) >= s.writeErrWarnEvery {
+			s.writeErrWarnLast = now
+			s.writeErrWarnMu.Unlock()
+			s.log.Warn("liveness.scheduler: error writing UDP packet", "error", err)
+		} else {
+			s.writeErrWarnMu.Unlock()
+		}
+	}
+}
+
+// tryExpire checks whether the session’s detect deadline has passed.
+// If so, it transitions the session to Down, triggers an immediate TX +// to advertise the Down state, and returns true to signal expiration. +func (s *Scheduler) tryExpire(sess *Session) bool { + now := time.Now() + if sess.ExpireIfDue(now) { + s.eq.Push(&event{when: now, eventType: eventTypeTX, session: sess}) + return true + } + return false +} diff --git a/client/doublezerod/internal/liveness/scheduler_test.go b/client/doublezerod/internal/liveness/scheduler_test.go new file mode 100644 index 000000000..64d1e1737 --- /dev/null +++ b/client/doublezerod/internal/liveness/scheduler_test.go @@ -0,0 +1,480 @@ +package liveness + +import ( + "context" + "net" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestClient_Liveness_Scheduler_EventQueueOrdering(t *testing.T) { + t.Parallel() + + q := NewEventQueue() + now := time.Now() + e1 := &event{when: now} + e2 := &event{when: now} + e3 := &event{when: now.Add(5 * time.Millisecond)} + + q.Push(e1) + q.Push(e2) + q.Push(e3) + + // First PopIfDue returns first event immediately, zero wait + ev, wait := q.PopIfDue(now) + require.Equal(t, e1, ev) + require.Zero(t, wait) + + // Second PopIfDue returns second event immediately, still zero wait + ev, wait = q.PopIfDue(now) + require.Equal(t, e2, ev) + require.Zero(t, wait) + + // Third PopIfDue should not return yet, wait ~5ms + ev, wait = q.PopIfDue(now) + require.Nil(t, ev) + require.InDelta(t, 5*time.Millisecond, wait, float64(time.Millisecond)) +} + +func TestClient_Liveness_Scheduler_TryExpireEnqueuesImmediateTX(t *testing.T) { + t.Parallel() + + // minimal scheduler with a real EventQueue; udp/log not used here + s := &Scheduler{eq: NewEventQueue()} + sess := &Session{ + state: StateUp, + detectDeadline: time.Now().Add(-time.Millisecond), + alive: true, + detectMult: 1, + minTxFloor: time.Millisecond, + } + ok := s.tryExpire(sess) + require.True(t, ok) + + // first event should be immediate TX + ev := s.eq.Pop() + require.NotNil(t, ev) + require.Equal(t, eventTypeTX, ev.eventType) + + // and state flipped to Down, detect cleared + require.Equal(t, StateDown, sess.state) + require.True(t, sess.detectDeadline.IsZero()) +} + +func TestClient_Liveness_Scheduler_ScheduleDetect_NoArmNoEnqueue(t *testing.T) { + t.Parallel() + s := &Scheduler{eq: NewEventQueue()} + sess := &Session{alive: false} // ArmDetect will return false + + s.scheduleDetect(time.Now(), sess) + require.Nil(t, s.eq.Pop()) // queue stays empty +} + +func TestClient_Liveness_Scheduler_ScheduleDetect_EnqueuesDeadline(t *testing.T) { + t.Parallel() + s := &Scheduler{eq: NewEventQueue()} + sess := &Session{ + alive: true, + detectDeadline: time.Now().Add(50 * time.Millisecond), + detectMult: 1, + minTxFloor: time.Millisecond, + } + + s.scheduleDetect(time.Now(), sess) + ev := s.eq.Pop() + require.NotNil(t, ev) + require.Equal(t, eventTypeDetect, ev.eventType) +} + +func TestClient_Liveness_Scheduler_TryExpire_Idempotent(t *testing.T) { + t.Parallel() + s := &Scheduler{eq: NewEventQueue()} + sess := &Session{ + state: StateUp, + detectDeadline: time.Now().Add(-time.Millisecond), + alive: true, + detectMult: 1, + minTxFloor: time.Millisecond, + } + require.True(t, s.tryExpire(sess)) + require.False(t, s.tryExpire(sess)) // second call no effect +} + +func TestClient_Liveness_Scheduler_ScheduleTx_NoEnqueueWhenAdminDown(t *testing.T) { + t.Parallel() + s := &Scheduler{eq: NewEventQueue()} + sess := &Session{ + state: StateAdminDown, + alive: true, + detectMult: 1, + minTxFloor: time.Millisecond, + 
} + s.scheduleTx(time.Now(), sess) + require.Nil(t, s.eq.Pop(), "no TX should be scheduled while AdminDown") +} + +func TestClient_Liveness_Scheduler_ScheduleTx_AdaptiveBackoffWhenDown(t *testing.T) { + t.Parallel() + s := &Scheduler{eq: NewEventQueue()} + sess := &Session{ + state: StateDown, + alive: true, + detectMult: 1, + localTxMin: 20 * time.Millisecond, + localRxMin: 20 * time.Millisecond, + minTxFloor: 10 * time.Millisecond, + maxTxCeil: 1 * time.Second, + backoffMax: 150 * time.Millisecond, + backoffFactor: 1, + peer: &Peer{Interface: "eth0", LocalIP: "192.0.2.1"}, + } + + now := time.Now() + + // First schedule: should enqueue a TX and bump backoffFactor in ComputeNextTx. + s.scheduleTx(now, sess) + ev1 := s.eq.Pop() + require.NotNil(t, ev1) + require.Equal(t, eventTypeTX, ev1.eventType) + require.Greater(t, sess.backoffFactor, uint32(1)) // doubled to 2 + require.True(t, ev1.when.After(now)) + + // Simulate Run loop clearing the pending TX marker when the event is consumed. + sess.mu.Lock() + if ev1.when.Equal(sess.nextTxScheduled) { + sess.nextTxScheduled = time.Time{} + } + sess.mu.Unlock() + + // Second schedule: allowed now, should enqueue another TX and further backoff (up to cap). + s.scheduleTx(now.Add(time.Millisecond), sess) + ev2 := s.eq.Pop() + require.NotNil(t, ev2) + require.Equal(t, eventTypeTX, ev2.eventType) + require.GreaterOrEqual(t, sess.backoffFactor, uint32(4)) + require.True(t, ev2.when.After(now)) + + // Bound first interval by backoffMax (+ jitter slack) + require.LessOrEqual(t, time.Until(ev1.when), time.Duration(float64(150*time.Millisecond)*1.5)) +} + +func TestClient_Liveness_Scheduler_Run_SendsAndReschedules(t *testing.T) { + t.Parallel() + // real UDP to count packets + srv, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}) + require.NoError(t, err) + defer srv.Close() + r, _ := NewUDPService(srv) + cl, _ := net.ListenUDP("udp4", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}) + defer cl.Close() + w, _ := NewUDPService(cl) + + pkts := int32(0) + stop := make(chan struct{}) + go func() { + buf := make([]byte, 128) + _ = srv.SetReadDeadline(time.Now().Add(2 * time.Second)) + for { + _, _, _, _, err := r.ReadFrom(buf) + if err != nil { + return + } + atomic.AddInt32(&pkts, 1) + } + }() + + log := newTestLogger(t) + s := NewScheduler(log, w, func(*Session) {}, 0) + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + go func() { + require.NoError(t, s.Run(ctx)) + }() + + sess := &Session{ + state: StateInit, + alive: true, + localTxMin: 20 * time.Millisecond, + localRxMin: 20 * time.Millisecond, + minTxFloor: 10 * time.Millisecond, + maxTxCeil: 200 * time.Millisecond, + detectMult: 3, + peer: &Peer{Interface: "", LocalIP: cl.LocalAddr().(*net.UDPAddr).IP.String()}, + peerAddr: srv.LocalAddr().(*net.UDPAddr), + backoffMax: 200 * time.Millisecond, + backoffFactor: 1, + } + s.scheduleTx(time.Now(), sess) + time.Sleep(120 * time.Millisecond) + cancel() + close(stop) + + require.GreaterOrEqual(t, atomic.LoadInt32(&pkts), int32(2)) +} + +func TestClient_Liveness_Scheduler_ScheduleDetect_DedupSameDeadline(t *testing.T) { + t.Parallel() + + s := &Scheduler{eq: NewEventQueue()} + sess := &Session{ + alive: true, + detectMult: 1, + minTxFloor: time.Millisecond, + peer: &Peer{Interface: "eth0", LocalIP: "192.0.2.1"}, + } + + // Use a fixed 'now' strictly before detectDeadline so ArmDetect does not re-arm. 
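+	// (ArmDetect only replaces the stored deadline once it is no longer after 'now'.)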
+ fixedNow := time.Now() + sess.mu.Lock() + sess.detectDeadline = fixedNow.Add(50 * time.Millisecond) + sess.mu.Unlock() + + // First enqueue for the deadline. + s.scheduleDetect(fixedNow, sess) + // Spam scheduleDetect with the SAME fixed 'now'; must not enqueue duplicates. + for i := 0; i < 100; i++ { + s.scheduleDetect(fixedNow, sess) + } + + require.Equal(t, 1, s.eq.CountFor("eth0", "192.0.2.1")) + + ev := s.eq.Pop() + require.NotNil(t, ev) + require.Equal(t, eventTypeDetect, ev.eventType) + require.Nil(t, s.eq.Pop()) +} + +func TestClient_Liveness_Scheduler_ScheduleDetect_AllowsNewDeadlineButStillDedupsPerDeadline(t *testing.T) { + t.Parallel() + + s := &Scheduler{eq: NewEventQueue()} + sess := &Session{ + alive: true, + detectMult: 1, + minTxFloor: time.Millisecond, + peer: &Peer{Interface: "eth0", LocalIP: "192.0.2.1"}, + } + + base := time.Now() + d1 := base.Add(40 * time.Millisecond) + + // Phase 1: schedule for D1 with fixed time < D1 + sess.mu.Lock() + sess.detectDeadline = d1 + sess.mu.Unlock() + s.scheduleDetect(base, sess) + for i := 0; i < 10; i++ { + s.scheduleDetect(base, sess) + } + require.Equal(t, 1, s.eq.CountFor("eth0", "192.0.2.1")) + + // Phase 2: move to a new deadline D2; still call with fixed time < D2 + d2 := base.Add(90 * time.Millisecond) + sess.mu.Lock() + sess.detectDeadline = d2 + sess.mu.Unlock() + for i := 0; i < 10; i++ { + s.scheduleDetect(base, sess) + } + + // Exactly two detect events queued for this peer (D1 and D2) + require.Equal(t, 2, s.eq.CountFor("eth0", "192.0.2.1")) + + // Pop order must be D1 then D2 + ev1 := s.eq.Pop() + require.NotNil(t, ev1) + require.Equal(t, eventTypeDetect, ev1.eventType) + ev2 := s.eq.Pop() + require.NotNil(t, ev2) + require.Equal(t, eventTypeDetect, ev2.eventType) + require.True(t, ev1.when.Before(ev2.when) || ev1.when.Equal(ev2.when)) + + require.Nil(t, s.eq.Pop()) +} + +func TestClient_Liveness_Scheduler_ScheduleTx_DedupWhilePending(t *testing.T) { + t.Parallel() + + s := &Scheduler{eq: NewEventQueue()} + sess := &Session{ + state: StateInit, + alive: true, + localTxMin: 20 * time.Millisecond, + localRxMin: 20 * time.Millisecond, + minTxFloor: 10 * time.Millisecond, + maxTxCeil: 200 * time.Millisecond, + backoffMax: 200 * time.Millisecond, + backoffFactor: 1, + peer: &Peer{Interface: "eth0", LocalIP: "192.0.2.1"}, + } + + // First schedule should enqueue exactly one TX. + now := time.Now() + s.scheduleTx(now, sess) + + // Repeated schedules while a TX is already pending must NOT enqueue more. + for i := 0; i < 100; i++ { + s.scheduleTx(now.Add(time.Duration(i)*time.Millisecond), sess) + } + + require.Equal(t, 1, s.eq.CountFor("eth0", "192.0.2.1")) + + ev := s.eq.Pop() + require.NotNil(t, ev) + require.Equal(t, eventTypeTX, ev.eventType) + require.Nil(t, s.eq.Pop()) +} + +func TestClient_Liveness_Scheduler_ScheduleTx_AllowsRescheduleAfterPop(t *testing.T) { + t.Parallel() + + s := &Scheduler{eq: NewEventQueue()} + sess := &Session{ + state: StateInit, + alive: true, + localTxMin: 20 * time.Millisecond, + localRxMin: 20 * time.Millisecond, + minTxFloor: 10 * time.Millisecond, + maxTxCeil: 200 * time.Millisecond, + backoffMax: 200 * time.Millisecond, + backoffFactor: 1, + peer: &Peer{Interface: "eth0", LocalIP: "192.0.2.1"}, + } + + now := time.Now() + s.scheduleTx(now, sess) + ev := s.eq.Pop() + require.NotNil(t, ev) + require.Equal(t, eventTypeTX, ev.eventType) + + // Simulate the Run loop clearing the scheduled marker when the TX event is consumed. 
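+	// (Run's eventTypeTX case zeroes nextTxScheduled under the session lock
+	// before transmitting, which is what re-opens scheduling for the next TX.)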
+ sess.mu.Lock() + if ev.when.Equal(sess.nextTxScheduled) { + sess.nextTxScheduled = time.Time{} + } + sess.mu.Unlock() + + // Now we should be able to schedule the next TX. + s.scheduleTx(now.Add(5*time.Millisecond), sess) + require.Equal(t, 1, s.eq.CountFor("eth0", "192.0.2.1")) + + ev2 := s.eq.Pop() + require.NotNil(t, ev2) + require.Equal(t, eventTypeTX, ev2.eventType) + require.Nil(t, s.eq.Pop()) +} + +func TestClient_Liveness_Scheduler_ScheduleDetect_DropsOnOverflowAndClearsMarker(t *testing.T) { + t.Parallel() + + s := &Scheduler{eq: NewEventQueue(), maxEvents: 1} + sess := &Session{ + alive: true, + detectMult: 1, + minTxFloor: time.Millisecond, + peer: &Peer{Interface: "eth0", LocalIP: "192.0.2.1"}, + } + + // Fill the queue to the cap with an unrelated event + other := &Session{peer: &Peer{Interface: "ethX", LocalIP: "198.51.100.1"}} + s.eq.Push(&event{when: time.Now().Add(time.Second), eventType: eventTypeTX, session: other}) + require.Equal(t, 1, s.eq.Len()) + + // Try to schedule Detect; should be dropped due to overflow and marker cleared + now := time.Now() + sess.mu.Lock() + sess.detectDeadline = now.Add(50 * time.Millisecond) + sess.mu.Unlock() + + s.scheduleDetect(now, sess) + + require.Equal(t, 1, s.eq.Len(), "queue should remain at cap; detect dropped") + sess.mu.Lock() + require.True(t, sess.nextDetectScheduled.IsZero(), "dedupe marker must be cleared on drop") + sess.mu.Unlock() +} + +func TestClient_Liveness_Scheduler_Run_CullsStaleDetectAndClearsMarker(t *testing.T) { + t.Parallel() + + log := newTestLogger(t) + s := NewScheduler(log, nil, func(*Session) {}, 0) + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + + sess := &Session{ + alive: true, + detectMult: 1, + minTxFloor: time.Millisecond, + peer: &Peer{Interface: "eth0", LocalIP: "192.0.2.1"}, + } + + // Make a stale detect: queued deadline d1, but current detectDeadline is d2. + now := time.Now() + d1 := now.Add(-1 * time.Millisecond) // already due -> scheduler will pop immediately + d2 := now.Add(90 * time.Millisecond) // current detect deadline (different from d1) + + sess.mu.Lock() + sess.detectDeadline = d2 + sess.nextDetectScheduled = d1 // simulate prior scheduling for d1 + sess.mu.Unlock() + + // Enqueue the stale detect event. 
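+	// (Run treats a detect event as stale when its 'when' no longer matches the
+	// session's current detectDeadline, clearing the dedupe marker as it drops it.)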
+ s.eq.Push(&event{when: d1, eventType: eventTypeDetect, session: sess}) + require.Equal(t, 1, s.eq.Len()) + + done := make(chan struct{}) + go func() { _ = s.Run(ctx); close(done) }() + + // Wait until the queue is empty (stale event culled) or time out + deadline := time.Now().Add(200 * time.Millisecond) + for time.Now().Before(deadline) { + if s.eq.Len() == 0 { + break + } + time.Sleep(2 * time.Millisecond) + } + cancel() + <-done + + require.Equal(t, 0, s.eq.Len(), "stale detect should be culled without rescheduling") + + sess.mu.Lock() + require.True(t, sess.nextDetectScheduled.IsZero(), "marker must be cleared when stale event is dropped") + require.Equal(t, d2, sess.detectDeadline, "current deadline must remain unchanged") + sess.mu.Unlock() +} + +func TestClient_Liveness_Scheduler_ScheduleTx_NotDroppedByOverflow(t *testing.T) { + t.Parallel() + + s := &Scheduler{eq: NewEventQueue(), maxEvents: 1} + sess := &Session{ + state: StateInit, + alive: true, + localTxMin: 20 * time.Millisecond, + localRxMin: 20 * time.Millisecond, + minTxFloor: 10 * time.Millisecond, + maxTxCeil: 200 * time.Millisecond, + backoffMax: 200 * time.Millisecond, + backoffFactor: 1, + peer: &Peer{Interface: "eth0", LocalIP: "192.0.2.1"}, + } + + // Fill the queue to the cap with an unrelated event + other := &Session{peer: &Peer{Interface: "ethX", LocalIP: "198.51.100.1"}} + s.eq.Push(&event{when: time.Now().Add(time.Second), eventType: eventTypeDetect, session: other}) + require.Equal(t, 1, s.eq.Len()) + + // scheduleTx should still enqueue despite overflow (policy: never drop TX) + s.scheduleTx(time.Now(), sess) + require.Equal(t, 2, s.eq.Len(), "TX must not be dropped by the soft cap") + + // Clean up: pop both; first could be either depending on 'when' + require.NotNil(t, s.eq.Pop()) + require.NotNil(t, s.eq.Pop()) + require.Equal(t, 0, s.eq.Len()) +} diff --git a/client/doublezerod/internal/liveness/session.go b/client/doublezerod/internal/liveness/session.go new file mode 100644 index 000000000..6f8fdd1d9 --- /dev/null +++ b/client/doublezerod/internal/liveness/session.go @@ -0,0 +1,239 @@ +package liveness + +import ( + "math/rand" + "net" + "sync" + "time" + + "github.com/malbeclabs/doublezero/client/doublezerod/internal/routing" +) + +// Session models a single bidirectional liveness relationship with a peer, +// maintaining BFD-like state, timers, and randomized transmission scheduling. +type Session struct { + route *routing.Route + + localDiscr, peerDiscr uint32 // discriminators identify this session to each side + state State // current BFD state + + // detectMult scales the detection timeout relative to the receive interval; + // it defines how many consecutive RX intervals may elapse without traffic + // before declaring the session Down (e.g., 3 → tolerate ~3 missed intervals). 
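+	// For example, with an effective receive interval of 300ms and detectMult
+	// of 3, roughly 900ms of silence marks the session Down.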
+ detectMult uint8 + + localTxMin, localRxMin time.Duration // our minimum TX/RX intervals + peerTxMin, peerRxMin time.Duration // peer's advertised TX/RX intervals + + nextTx, detectDeadline, lastRx time.Time // computed next transmit time, detect timeout, last RX time + + peer *Peer + peerAddr *net.UDPAddr + + alive bool // manager lifecycle flag: whether this session is still managed + + minTxFloor, maxTxCeil time.Duration // global interval bounds + backoffMax time.Duration // upper bound for exponential backoff + backoffFactor uint32 // doubles when Down, resets when Up + + mu sync.Mutex // guards mutable session state + + // Scheduled time of the next enqueued detect and tx events (zero means nothing enqueued) + nextTxScheduled time.Time + nextDetectScheduled time.Time +} + +// ComputeNextTx picks the next transmit time based on current state, +// applies exponential backoff when Down, adds ±10% jitter, +// persists it to s.nextTx, and returns the chosen timestamp. +func (s *Session) ComputeNextTx(now time.Time, rnd *rand.Rand) time.Time { + s.mu.Lock() + + base := s.txInterval() + eff := base + if s.state == StateDown { + if s.backoffFactor < 1 { + s.backoffFactor = 1 + } + eff *= time.Duration(s.backoffFactor) + if s.backoffMax > 0 && eff > s.backoffMax { + eff = s.backoffMax + } + } + + j := eff / 10 + span := int64(2*j) + 1 + if span < 1 { + span = 1 + } + var off int64 + if rnd != nil { + off = rnd.Int63n(span) + } else { + off = rand.Int63n(span) + } + jit := time.Duration(off) - j + next := now.Add(eff + jit) + s.nextTx = next + + // Backoff doubles while Down; reset once Up or Init again. + if s.state == StateDown { + if s.backoffMax == 0 || eff < s.backoffMax { + if s.backoffFactor == 0 { + s.backoffFactor = 1 + } + s.backoffFactor *= 2 + } + } else { + s.backoffFactor = 1 + } + s.mu.Unlock() + return next +} + +// ArmDetect ensures the detection timer is active and not stale. +// If expired, it re-arms; if uninitialized, it returns false. +// Returns the deadline and whether detect should be (re)scheduled. +func (s *Session) ArmDetect(now time.Time) (time.Time, bool) { + s.mu.Lock() + defer s.mu.Unlock() + if !s.alive || s.detectDeadline.IsZero() { + return time.Time{}, false + } + ddl := s.detectDeadline + if !ddl.After(now) { + ddl = now.Add(s.detectTime()) + s.detectDeadline = ddl + } + return ddl, true +} + +// ExpireIfDue transitions an active session to Down if its detect timer +// has elapsed. Returns true if state changed (Up/Init → Down). +func (s *Session) ExpireIfDue(now time.Time) (expired bool) { + s.mu.Lock() + defer s.mu.Unlock() + if !s.alive { + return false + } + + if (s.state == StateUp || s.state == StateInit) && + !s.detectDeadline.IsZero() && + !now.Before(s.detectDeadline) { + s.state = StateDown + s.backoffFactor = 1 + s.detectDeadline = time.Time{} + return true + } + return false +} + +// HandleRx ingests an incoming control packet, validates discriminators, +// updates peer timers, re-arms detection, and performs state transitions +// according to a simplified BFD-like handshake. +func (s *Session) HandleRx(now time.Time, ctrl *ControlPacket) (changed bool) { + s.mu.Lock() + defer s.mu.Unlock() + + if s.state == StateAdminDown { + return false + } + if ctrl.peerDiscrr != 0 && ctrl.peerDiscrr != s.localDiscr { + return false + } + + prev := s.state + + // Learn peer discriminator if not yet known. + if s.peerDiscr == 0 && ctrl.LocalDiscrr != 0 { + s.peerDiscr = ctrl.LocalDiscrr + } + + // Update peer timing and clamp within floor/ceiling bounds. 
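+	// Advertised intervals arrive on the wire in microseconds; bounding them
+	// keeps a misbehaving peer from forcing pathologically fast or slow timers.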
+ rtx := time.Duration(ctrl.DesiredMinTxUs) * time.Microsecond + rrx := time.Duration(ctrl.RequiredMinRxUs) * time.Microsecond + if rtx < s.minTxFloor { + rtx = s.minTxFloor + } else if s.maxTxCeil > 0 && rtx > s.maxTxCeil { + rtx = s.maxTxCeil + } + if rrx < s.minTxFloor { + rrx = s.minTxFloor + } else if s.maxTxCeil > 0 && rrx > s.maxTxCeil { + rrx = s.maxTxCeil + } + s.peerTxMin, s.peerRxMin = rtx, rrx + s.lastRx = now + s.detectDeadline = now.Add(s.detectTime()) + + switch prev { + case StateDown: + // Move to Init once peer identified; Up after echo confirmation. + if s.peerDiscr != 0 { + if ctrl.State >= StateInit && ctrl.peerDiscrr == s.localDiscr { + s.state = StateUp + s.backoffFactor = 1 + } else { + s.state = StateInit + s.backoffFactor = 1 + } + } + + case StateInit: + // Promote to Up only after receiving echo referencing our localDiscr. + if s.peerDiscr != 0 && ctrl.State >= StateInit && ctrl.peerDiscrr == s.localDiscr { + s.state = StateUp + s.backoffFactor = 1 + } + + case StateUp: + // If peer advertises Down, immediately mirror it and pause detect. + if ctrl.State == StateDown { + s.state = StateDown + s.backoffFactor = 1 + s.detectDeadline = time.Time{} + } + } + + return s.state != prev +} + +// detectTime computes detection interval as detectMult × rxInterval(). +func (s *Session) detectTime() time.Duration { + return time.Duration(int64(s.detectMult) * int64(s.rxInterval())) +} + +// txInterval picks the effective transmit interval, bounded by floors/ceilings, +// using the greater of localTxMin and peerRxMin. +func (s *Session) txInterval() time.Duration { + iv := s.localTxMin + if s.peerRxMin > iv { + iv = s.peerRxMin + } + if iv < s.minTxFloor { + iv = s.minTxFloor + } + if iv > s.maxTxCeil { + iv = s.maxTxCeil + } + return iv +} + +// rxInterval picks the effective receive interval based on peer TX and +// our own desired RX, clamped to the same bounds. 
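+// Even when both sides still advertise zero, the minTxFloor clamp guarantees
+// a nonzero interval.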
+func (s *Session) rxInterval() time.Duration { + ref := s.peerTxMin + if s.localRxMin > ref { + ref = s.localRxMin + } + if ref == 0 { + ref = s.localRxMin + } + if ref < s.minTxFloor { + ref = s.minTxFloor + } + if ref > s.maxTxCeil { + ref = s.maxTxCeil + } + return ref +} diff --git a/client/doublezerod/internal/liveness/session_test.go b/client/doublezerod/internal/liveness/session_test.go new file mode 100644 index 000000000..d1cbb70fc --- /dev/null +++ b/client/doublezerod/internal/liveness/session_test.go @@ -0,0 +1,344 @@ +package liveness + +import ( + "math/rand" + "net" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func newSess() *Session { + return &Session{ + route: nil, + localDiscr: 0xAABBCCDD, + peerDiscr: 0, + state: StateDown, + detectMult: 3, + localTxMin: 20 * time.Millisecond, + localRxMin: 15 * time.Millisecond, + peerTxMin: 10 * time.Millisecond, + peerRxMin: 0, + minTxFloor: 5 * time.Millisecond, + maxTxCeil: 10 * time.Second, + alive: true, + peer: nil, + peerAddr: &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 9999}, + nextTx: time.Time{}, + detectDeadline: time.Time{}, + lastRx: time.Time{}, + backoffMax: 1 * time.Second, + backoffFactor: 1, + } +} + +func TestClient_Liveness_Session_ComputeNextTx_JitterWithinBoundsAndPersists(t *testing.T) { + t.Parallel() + s := newSess() + s.localTxMin = 100 * time.Millisecond + s.state = StateDown + now := time.Unix(0, 0) + r := rand.New(rand.NewSource(1)) + next := s.ComputeNextTx(now, r) + + base := s.txInterval() + j := base / 10 + min := now.Add(base - j) + max := now.Add(base + j) + + require.True(t, !next.Before(min) && !next.After(max), "next=%v min=%v max=%v", next, min, max) + require.Equal(t, next, s.nextTx) + require.Equal(t, uint32(2), s.backoffFactor, "backoff should double after scheduling while Down") +} + +func TestClient_Liveness_Session_TxIntervalRespectspeerRxMinFloorAndCeil(t *testing.T) { + t.Parallel() + s := newSess() + s.localTxMin = 20 * time.Millisecond + s.peerRxMin = 50 * time.Millisecond + s.minTxFloor = 60 * time.Millisecond + s.maxTxCeil = 40 * time.Millisecond + require.Equal(t, 40*time.Millisecond, s.txInterval()) +} + +func TestClient_Liveness_Session_RxRefPrefersMaxFloorAndCeil(t *testing.T) { + t.Parallel() + s := newSess() + s.peerTxMin = 10 * time.Millisecond + s.localRxMin = 20 * time.Millisecond + s.minTxFloor = 5 * time.Millisecond + require.Equal(t, 20*time.Millisecond, s.rxInterval()) + + s.peerTxMin = 0 + s.localRxMin = 0 + s.minTxFloor = 7 * time.Millisecond + require.Equal(t, 7*time.Millisecond, s.rxInterval()) + + // ceiling: cap overly large refs + s.peerTxMin = 5 * time.Second + s.localRxMin = 10 * time.Second + s.minTxFloor = 1 * time.Millisecond + s.maxTxCeil = 500 * time.Millisecond + require.Equal(t, 500*time.Millisecond, s.rxInterval()) +} + +func TestClient_Liveness_Session_DetectTimeIsDetectMultTimesRxRef(t *testing.T) { + t.Parallel() + s := newSess() + s.detectMult = 5 + s.peerTxMin = 11 * time.Millisecond + s.localRxMin = 13 * time.Millisecond // max with peerTxMin => 13ms + s.minTxFloor = 3 * time.Millisecond + require.Equal(t, 5*13*time.Millisecond, s.detectTime()) +} + +func TestClient_Liveness_Session_ArmDetectNotAliveOrZeroDeadlineReturnsFalse(t *testing.T) { + t.Parallel() + s := newSess() + s.alive = false + s.detectDeadline = time.Now().Add(1 * time.Second) + _, ok := s.ArmDetect(time.Now()) + require.False(t, ok) + + s = newSess() + s.alive = true + s.detectDeadline = time.Time{} + _, ok = s.ArmDetect(time.Now()) + require.False(t, ok) 
+} + +func TestClient_Liveness_Session_ArmDetectFutureDeadlineReturnsSameTrue(t *testing.T) { + t.Parallel() + s := newSess() + now := time.Now() + want := now.Add(500 * time.Millisecond) + s.detectDeadline = want + ddl, ok := s.ArmDetect(now) + require.True(t, ok) + require.Equal(t, want, ddl) + require.Equal(t, want, s.detectDeadline) +} + +func TestClient_Liveness_Session_ArmDetectPastDeadlineReschedules(t *testing.T) { + t.Parallel() + s := newSess() + now := time.Now() + s.detectDeadline = now.Add(-1 * time.Millisecond) + ddl, ok := s.ArmDetect(now) + require.True(t, ok) + require.True(t, ddl.After(now)) + require.Equal(t, ddl, s.detectDeadline) +} + +func TestClient_Liveness_Session_ExpireIfDueTransitionsToDownAndClearsDeadline(t *testing.T) { + t.Parallel() + s := newSess() + now := time.Now() + s.state = StateUp + s.detectDeadline = now.Add(-1 * time.Millisecond) + exp := s.ExpireIfDue(now) + require.True(t, exp) + require.Equal(t, StateDown, s.state) + require.True(t, s.detectDeadline.IsZero()) + require.Equal(t, uint32(1), s.backoffFactor, "backoff should reset after transition to Down") +} + +func TestClient_Liveness_Session_ExpireIfDueNoTransitionWhenNotDueOrNotAlive(t *testing.T) { + t.Parallel() + s := newSess() + now := time.Now() + s.state = StateInit + s.detectDeadline = now.Add(1 * time.Second) + require.False(t, s.ExpireIfDue(now)) + require.Equal(t, StateInit, s.state) + + s = newSess() + s.state = StateUp + s.alive = false + s.detectDeadline = now.Add(-1 * time.Millisecond) + require.False(t, s.ExpireIfDue(now)) + require.Equal(t, StateUp, s.state) +} + +func TestClient_Liveness_Session_HandleRxIgnoresMismatchedpeerDiscrr(t *testing.T) { + t.Parallel() + s := newSess() + s.localDiscr = 111 + now := time.Now() + cp := &ControlPacket{peerDiscrr: 222, LocalDiscrr: 333, State: StateInit} + changed := s.HandleRx(now, cp) + require.False(t, changed) + require.Equal(t, StateDown, s.state) + require.Zero(t, s.peerDiscr) +} + +func TestClient_Liveness_Session_HandleRxFromDownToInitOrUpAndArmsDetect(t *testing.T) { + t.Parallel() + s := newSess() + s.state = StateDown + s.localDiscr = 42 + + now := time.Now() + // Peer Down -> go Init + cpDown := &ControlPacket{ + peerDiscrr: 0, // acceptable (we only check mismatch if nonzero) + LocalDiscrr: 1001, // learn peer discr + State: StateDown, + DesiredMinTxUs: 30_000, // 30ms + RequiredMinRxUs: 40_000, // 40ms + } + changed := s.HandleRx(now, cpDown) + require.True(t, changed) + require.Equal(t, StateInit, s.state) + require.EqualValues(t, 1001, s.peerDiscr) + require.False(t, s.detectDeadline.IsZero()) + require.Equal(t, now, s.lastRx) + + // Next packet peer Init -> go Up + cpInit := &ControlPacket{ + peerDiscrr: 42, // matches our localDiscr (explicit echo required) + LocalDiscrr: 1001, + State: StateInit, + DesiredMinTxUs: 20_000, + RequiredMinRxUs: 20_000, + } + changed = s.HandleRx(now.Add(10*time.Millisecond), cpInit) + require.True(t, changed) + require.Equal(t, StateUp, s.state) + require.Equal(t, uint32(1), s.backoffFactor, "backoff should reset when leaving Down") +} + +func TestClient_Liveness_Session_HandleRxFromInitToUpOnPeerInitOrUp(t *testing.T) { + t.Parallel() + s := newSess() + s.state = StateInit + s.peerDiscr = 777 // already learned + now := time.Now() + + // Without explicit echo (peerDiscrr != localDiscr), do NOT promote. 
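+	// (Requiring the echo proves two-way packet flow, not just one-way
+	// reachability from the peer to us.)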
+ cpNoEcho := &ControlPacket{peerDiscrr: 0, LocalDiscrr: 777, State: StateUp} + changed := s.HandleRx(now, cpNoEcho) + require.False(t, changed) + require.Equal(t, StateInit, s.state) + + // With explicit echo (peerDiscrr == localDiscr), promote to Up. + cpEcho := &ControlPacket{peerDiscrr: s.localDiscr, LocalDiscrr: s.peerDiscr, State: StateUp} + changed = s.HandleRx(now, cpEcho) + require.True(t, changed) + require.Equal(t, StateUp, s.state) +} + +func TestClient_Liveness_Session_HandleRxFromUpToDownWhenPeerReportsDownAndStopDetect(t *testing.T) { + t.Parallel() + s := newSess() + s.state = StateUp + s.peerDiscr = 1 + now := time.Now() + s.detectDeadline = now.Add(10 * time.Second) + + cp := &ControlPacket{peerDiscrr: 0, LocalDiscrr: 1, State: StateDown} + changed := s.HandleRx(now, cp) + require.True(t, changed) + require.Equal(t, StateDown, s.state) + require.True(t, s.detectDeadline.IsZero()) + require.Equal(t, uint32(1), s.backoffFactor, "backoff should reset when entering Down") +} + +func TestClient_Liveness_Session_HandleRxSetsPeerTimersAndDetectDeadline(t *testing.T) { + t.Parallel() + s := newSess() + now := time.Now() + cp := &ControlPacket{ + peerDiscrr: 0, + LocalDiscrr: 9, + State: StateInit, + DesiredMinTxUs: 12_000, + RequiredMinRxUs: 34_000, + } + _ = s.HandleRx(now, cp) + require.Equal(t, 12*time.Millisecond, s.peerTxMin) + require.Equal(t, 34*time.Millisecond, s.peerRxMin) + require.False(t, s.detectDeadline.IsZero()) + require.Equal(t, now, s.lastRx) +} + +func TestClient_Liveness_Session_BackoffResetsWhenNotDown(t *testing.T) { + t.Parallel() + s := newSess() + s.state = StateDown + s.backoffFactor = 8 + s.backoffMax = 200 * time.Millisecond + _ = s.ComputeNextTx(time.Now(), nil) // will keep doubling (capped) while Down + s.state = StateUp + _ = s.ComputeNextTx(time.Now(), nil) // leaves Down -> resets + require.Equal(t, uint32(1), s.backoffFactor) +} + +func TestClient_Liveness_Session_HandleRxIgnoredWhenAdminDown(t *testing.T) { + t.Parallel() + s := newSess() + s.state = StateAdminDown + now := time.Now() + cp := &ControlPacket{peerDiscrr: 0, LocalDiscrr: 9, State: StateUp, DesiredMinTxUs: 1000, RequiredMinRxUs: 2000} + changed := s.HandleRx(now, cp) + require.False(t, changed) + require.Equal(t, StateAdminDown, s.state) + require.Zero(t, s.peerDiscr) +} + +func TestClient_Liveness_Session_HandleRxClampsTimersAndDetectMultZero(t *testing.T) { + t.Parallel() + s := newSess() + now := time.Now() + // Configure floors/ceils to make clamping observable. 
+	s.minTxFloor = 7 * time.Millisecond
+	s.maxTxCeil = 40 * time.Millisecond
+
+	cp := &ControlPacket{
+		peerDiscrr:      0,
+		LocalDiscrr:     9,
+		State:           StateInit,
+		DetectMult:      0,         // not read by HandleRx; the session's own detectMult governs detection
+		DesiredMinTxUs:  1_000,     // 1ms → clamp up to 7ms
+		RequiredMinRxUs: 1_000_000, // 1s → clamp down to 40ms
+	}
+	_ = s.HandleRx(now, cp)
+
+	require.Equal(t, 7*time.Millisecond, s.peerTxMin)
+	require.Equal(t, 40*time.Millisecond, s.peerRxMin)
+	require.False(t, s.detectDeadline.IsZero())
+}
+
+func TestClient_Liveness_Session_ComputeNextTx_LargeInterval_NoOverflow(t *testing.T) {
+	t.Parallel()
+	s := newSess()
+	s.localTxMin = 3 * time.Hour
+	s.state = StateUp
+	require.NotPanics(t, func() { _ = s.ComputeNextTx(time.Now(), rand.New(rand.NewSource(1))) })
+}
+
+func TestClient_Liveness_Session_HandleRx_NoChange_RearmsDetect(t *testing.T) {
+	t.Parallel()
+	s := newSess()
+	now := time.Now()
+	s.state = StateUp
+	s.detectDeadline = now.Add(100 * time.Millisecond)
+
+	callNow := now.Add(10 * time.Millisecond)
+	cp := &ControlPacket{
+		peerDiscrr:      s.localDiscr, // accepted (echo ok)
+		LocalDiscrr:     s.peerDiscr,  // may be 0; fine
+		State:           StateUp,
+		DesiredMinTxUs:  20000, // 20ms
+		RequiredMinRxUs: 20000,
+	}
+	changed := s.HandleRx(callNow, cp)
+	require.False(t, changed)
+
+	// Expect re-armed to ≈ callNow + detectTime()
+	wantMin := callNow.Add(s.detectTime() - 2*time.Millisecond) // tiny slack
+	wantMax := callNow.Add(s.detectTime() + 2*time.Millisecond)
+	require.True(t, !s.detectDeadline.Before(wantMin) && !s.detectDeadline.After(wantMax),
+		"got=%v want≈%v", s.detectDeadline, callNow.Add(s.detectTime()))
+}
diff --git a/client/doublezerod/internal/liveness/udp.go b/client/doublezerod/internal/liveness/udp.go
new file mode 100644
index 000000000..3be3ea7a7
--- /dev/null
+++ b/client/doublezerod/internal/liveness/udp.go
@@ -0,0 +1,133 @@
+package liveness
+
+import (
+	"errors"
+	"fmt"
+	"net"
+	"time"
+
+	"golang.org/x/net/ipv4"
+)
+
+// UDPService wraps an IPv4 UDP socket and provides helpers for reading and writing
+// datagrams while preserving local interface and destination address context.
+// It preconfigures IPv4 control message delivery (IP_PKTINFO equivalent) so that
+// each received packet includes metadata about which interface and destination IP
+// it arrived on, and outgoing packets can explicitly set source IP and interface.
+type UDPService struct {
+	raw *net.UDPConn     // the underlying UDP socket
+	pc4 *ipv4.PacketConn // ipv4-layer wrapper for control message access
+}
+
+// ListenUDP binds an IPv4 UDP socket to bindIP:port and returns a configured UDPService.
+// The returned connection is ready to read/write with control message support enabled.
+func ListenUDP(bindIP string, port int) (*UDPService, error) {
+	laddr, err := net.ResolveUDPAddr("udp4", fmt.Sprintf("%s:%d", bindIP, port))
+	if err != nil {
+		return nil, err
+	}
+	raw, err := net.ListenUDP("udp4", laddr)
+	if err != nil {
+		return nil, err
+	}
+	u, err := NewUDPService(raw)
+	if err != nil {
+		_ = raw.Close()
+		return nil, err
+	}
+	return u, nil
+}
+
+// NewUDPService wraps an existing *net.UDPConn and enables IPv4 control messages (IP_PKTINFO-like).
+// On RX we obtain the destination IP and interface index; on TX we can set source IP and interface.
+func NewUDPService(raw *net.UDPConn) (*UDPService, error) {
+	u := &UDPService{raw: raw, pc4: ipv4.NewPacketConn(raw)}
+	// Enable both RX and TX control messages: destination IP, source IP, and interface index.
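+	// A failure here surfaces at construction time instead of as silently
+	// missing per-packet metadata on every read.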
+	if err := u.pc4.SetControlMessage(ipv4.FlagInterface|ipv4.FlagDst|ipv4.FlagSrc, true); err != nil {
+		return nil, err
+	}
+	return u, nil
+}
+
+// Close shuts down the underlying UDP socket.
+func (u *UDPService) Close() error { return u.raw.Close() }
+
+// ReadFrom reads a single UDP datagram and returns:
+// - number of bytes read
+// - remote address (sender)
+// - local destination IP the packet was received on
+// - interface name where it arrived
+//
+// The caller should configure read deadlines via SetReadDeadline before calling.
+// This function extracts control message metadata (IP_PKTINFO) to provide per-packet context.
+func (u *UDPService) ReadFrom(buf []byte) (n int, remoteAddr *net.UDPAddr, localIP net.IP, ifname string, err error) {
+	n, cm4, raddr, err := u.pc4.ReadFrom(buf)
+	if err != nil {
+		return 0, nil, nil, "", err
+	}
+	if ua, ok := raddr.(*net.UDPAddr); ok {
+		remoteAddr = ua
+	}
+	if cm4 != nil {
+		if cm4.Dst != nil {
+			localIP = cm4.Dst
+		}
+		if cm4.IfIndex != 0 {
+			ifi, _ := net.InterfaceByIndex(cm4.IfIndex)
+			if ifi != nil {
+				ifname = ifi.Name
+			}
+		}
+	}
+	return n, remoteAddr, localIP, ifname, nil
+}
+
+// WriteTo transmits a UDP datagram to an IPv4 destination.
+// The caller may optionally provide:
+// - iface: name of the outgoing interface to bind transmission to
+// - src: source IP to use (if nil, the kernel selects one)
+//
+// Returns the number of bytes written or an error.
+// This uses an ipv4.ControlMessage to set per-packet src/interface hints.
+func (u *UDPService) WriteTo(pkt []byte, dst *net.UDPAddr, iface string, src net.IP) (int, error) {
+	if dst == nil || dst.IP == nil {
+		return 0, errors.New("nil dst")
+	}
+	ip4 := dst.IP.To4()
+	if ip4 == nil {
+		return 0, errors.New("ipv6 dst not supported")
+	}
+
+	var ifidx int
+	if iface != "" {
+		ifi, err := net.InterfaceByName(iface)
+		if err != nil {
+			return 0, err
+		}
+		ifidx = ifi.Index
+	}
+
+	var cm ipv4.ControlMessage
+	if ifidx != 0 {
+		cm.IfIndex = ifidx
+	}
+	if src != nil {
+		if s4 := src.To4(); s4 != nil {
+			cm.Src = s4
+		}
+		// Non-IPv4 src ignored silently in IPv4 mode.
+	}
+
+	return u.pc4.WriteTo(pkt, &cm, &net.UDPAddr{IP: ip4, Port: dst.Port, Zone: dst.Zone})
+}
+
+// SetReadDeadline forwards directly to the underlying UDP socket.
+// This controls how long ReadFrom will block before returning a timeout.
+func (u *UDPService) SetReadDeadline(t time.Time) error {
+	return u.raw.SetReadDeadline(t)
+}
+
+// LocalAddr returns the socket’s bound local address (IP and port).
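+// Tests use this to discover the kernel-assigned port after binding to :0.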
+func (u *UDPService) LocalAddr() net.Addr { + return u.raw.LocalAddr() +} diff --git a/client/doublezerod/internal/liveness/udp_test.go b/client/doublezerod/internal/liveness/udp_test.go new file mode 100644 index 000000000..26e2b1929 --- /dev/null +++ b/client/doublezerod/internal/liveness/udp_test.go @@ -0,0 +1,169 @@ +package liveness + +import ( + "net" + "runtime" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestClient_Liveness_UDP_WriteUDPWithNilDst(t *testing.T) { + t.Parallel() + uc, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}) + require.NoError(t, err) + defer uc.Close() + + u, err := NewUDPService(uc) + require.NoError(t, err) + + n, err := u.WriteTo([]byte("x"), nil, "", nil) + require.EqualError(t, err, "nil dst") + require.Equal(t, 0, n) +} + +func TestClient_Liveness_UDP_WriteUDPWithBadIface(t *testing.T) { + t.Parallel() + + srv, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}) + require.NoError(t, err) + defer srv.Close() + + cl, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}) + require.NoError(t, err) + defer cl.Close() + + w, err := NewUDPService(cl) + require.NoError(t, err) + + dst := srv.LocalAddr().(*net.UDPAddr) + _, err = w.WriteTo([]byte("payload"), dst, "definitely-not-an-interface", nil) + require.Error(t, err) +} + +func TestClient_Liveness_UDP_IPv4RoundtripWriteAndRead(t *testing.T) { + t.Parallel() + + srv, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}) + require.NoError(t, err) + defer srv.Close() + _ = srv.SetDeadline(time.Now().Add(2 * time.Second)) + + cl, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}) + require.NoError(t, err) + defer cl.Close() + _ = cl.SetDeadline(time.Now().Add(2 * time.Second)) + + r, err := NewUDPService(srv) + require.NoError(t, err) + w, err := NewUDPService(cl) + require.NoError(t, err) + + payload := []byte("hello-v4") + dst := srv.LocalAddr().(*net.UDPAddr) + + nw, err := w.WriteTo(payload, dst, "", nil) + require.NoError(t, err) + require.Equal(t, len(payload), nw) + + buf := make([]byte, 128) + nr, src, dstIP, ifname, err := r.ReadFrom(buf) + require.NoError(t, err) + require.Equal(t, len(payload), nr) + require.Equal(t, payload, buf[:nr]) + + require.NotNil(t, src) + + clientLocal := cl.LocalAddr().(*net.UDPAddr) + serverLocal := srv.LocalAddr().(*net.UDPAddr) + + // Must be the client's IP/port (fails if swapped) + require.True(t, src.IP.Equal(clientLocal.IP)) + require.Equal(t, clientLocal.Port, src.Port) + + // Must be the server's local IP (fails if swapped) + require.NotNil(t, dstIP) + require.True(t, dstIP.Equal(serverLocal.IP)) + + // ifname may be empty; if present, it should be loopback + lb := loopbackInterface(t) + if ifname != "" { + require.Equal(t, lb.Name, ifname) + } +} + +func TestClient_Liveness_UDP_WriteUDPWithSrcHintIPv4(t *testing.T) { + t.Parallel() + + // Binding to 0.0.0.0 then hinting src=127.0.0.1 should still succeed locally. 
+ srv, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}) + require.NoError(t, err) + defer srv.Close() + _ = srv.SetDeadline(time.Now().Add(2 * time.Second)) + + cl, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.ParseIP("0.0.0.0"), Port: 0}) + require.NoError(t, err) + defer cl.Close() + _ = cl.SetDeadline(time.Now().Add(2 * time.Second)) + + r, err := NewUDPService(srv) + require.NoError(t, err) + w, err := NewUDPService(cl) + require.NoError(t, err) + + payload := []byte("src-hint") + dst := srv.LocalAddr().(*net.UDPAddr) + + nw, err := w.WriteTo(payload, dst, "", net.ParseIP("127.0.0.1")) + // Some OSes may reject an impossible source; accept either success or specific error, but never hang. + if err != nil && runtime.GOOS == "windows" { + t.Skipf("src control message not supported on %s: %v", runtime.GOOS, err) + } + if err == nil { + require.Equal(t, len(payload), nw) + + buf := make([]byte, 128) + nr, _, _, _, err := r.ReadFrom(buf) + require.NoError(t, err) + require.Equal(t, payload, buf[:nr]) + } +} + +func TestClient_Liveness_UDP_WriteTo_RejectsIPv6(t *testing.T) { + t.Parallel() + uc, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}) + require.NoError(t, err) + defer uc.Close() + u, err := NewUDPService(uc) + require.NoError(t, err) + _, err = u.WriteTo([]byte("x"), &net.UDPAddr{IP: net.ParseIP("::1"), Port: 1}, "", nil) + require.EqualError(t, err, "ipv6 dst not supported") +} + +func TestClient_Liveness_UDP_ReadDeadline_TimesOut(t *testing.T) { + t.Parallel() + srv, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}) + require.NoError(t, err) + defer srv.Close() + r, err := NewUDPService(srv) + require.NoError(t, err) + require.NoError(t, r.SetReadDeadline(time.Now().Add(50*time.Millisecond))) + buf := make([]byte, 8) + _, _, _, _, err = r.ReadFrom(buf) + require.Error(t, err) + nerr, ok := err.(net.Error) + require.True(t, ok && nerr.Timeout()) +} + +func loopbackInterface(t *testing.T) net.Interface { + ifs, err := net.Interfaces() + require.NoError(t, err) + for _, ifi := range ifs { + if ifi.Flags&net.FlagLoopback != 0 && ifi.Flags&net.FlagUp != 0 { + return ifi + } + } + t.Skip("no up loopback interface found") + return net.Interface{} +} diff --git a/client/doublezerod/internal/runtime/run.go b/client/doublezerod/internal/runtime/run.go index 51428dd61..a7a0011c3 100644 --- a/client/doublezerod/internal/runtime/run.go +++ b/client/doublezerod/internal/runtime/run.go @@ -13,13 +13,14 @@ import ( "github.com/malbeclabs/doublezero/client/doublezerod/internal/api" "github.com/malbeclabs/doublezero/client/doublezerod/internal/bgp" "github.com/malbeclabs/doublezero/client/doublezerod/internal/latency" + "github.com/malbeclabs/doublezero/client/doublezerod/internal/liveness" "github.com/malbeclabs/doublezero/client/doublezerod/internal/manager" "github.com/malbeclabs/doublezero/client/doublezerod/internal/pim" "github.com/malbeclabs/doublezero/client/doublezerod/internal/routing" "golang.org/x/sys/unix" ) -func Run(ctx context.Context, sockFile string, routeConfigPath string, enableLatencyProbing, enableLatencyMetrics bool, programId string, rpcEndpoint string, probeInterval, cacheUpdateInterval int) error { +func Run(ctx context.Context, sockFile string, routeConfigPath string, enableLatencyProbing, enableLatencyMetrics bool, programId string, rpcEndpoint string, probeInterval, cacheUpdateInterval int, lmc *liveness.ManagerConfig) error { nlr := routing.Netlink{} var crw 
bgp.RouteReaderWriter if _, err := os.Stat(routeConfigPath); os.IsNotExist(err) { @@ -30,7 +31,23 @@ func Run(ctx context.Context, sockFile string, routeConfigPath string, enableLat return fmt.Errorf("error creating configured route reader writer: %v", err) } } - bgp, err := bgp.NewBgpServer(net.IPv4(1, 1, 1, 1), crw) + + // If the liveness manager config is not nil, create a new manager. + // Otherwise, completely disable the liveness subsystem. + // TODO(snormore): The scenario where the liveness subsystem is completely disabled is + // temporary for initial rollout testing. + var lm *liveness.Manager + if lmc != nil { + lmc.Netlinker = crw + var err error + lm, err = liveness.NewManager(ctx, lmc) + if err != nil { + return fmt.Errorf("error creating liveness manager: %v", err) + } + defer lm.Close() + } + + bgp, err := bgp.NewBgpServer(net.IPv4(1, 1, 1, 1), crw, lm) if err != nil { return fmt.Errorf("error creating bgp server: %v", err) } @@ -116,6 +133,14 @@ func Run(ctx context.Context, sockFile string, routeConfigPath string, enableLat errCh <- err }() + // The liveness manager can be nil if the liveness subsystem is disabled. + // TODO(snormore): The scenario where the liveness subsystem is completely disabled is + // temporary for initial rollout testing. + var lmErrCh <-chan error + if lm != nil { + lmErrCh = lm.Err() + } + select { case <-ctx.Done(): slog.Info("teardown: cleaning up and closing") @@ -124,5 +149,7 @@ func Run(ctx context.Context, sockFile string, routeConfigPath string, enableLat return nil case err := <-errCh: return err + case err := <-lmErrCh: + return err } } diff --git a/client/doublezerod/internal/runtime/run_test.go b/client/doublezerod/internal/runtime/run_test.go index 09f24bec9..bfa5eb906 100644 --- a/client/doublezerod/internal/runtime/run_test.go +++ b/client/doublezerod/internal/runtime/run_test.go @@ -10,6 +10,7 @@ import ( "fmt" "io" "log" + "log/slog" "net" "net/http" "net/netip" @@ -29,8 +30,10 @@ import ( "github.com/jwhited/corebgp" "github.com/malbeclabs/doublezero/client/doublezerod/internal/api" "github.com/malbeclabs/doublezero/client/doublezerod/internal/bgp" + "github.com/malbeclabs/doublezero/client/doublezerod/internal/liveness" "github.com/malbeclabs/doublezero/client/doublezerod/internal/pim" "github.com/malbeclabs/doublezero/client/doublezerod/internal/runtime" + "github.com/stretchr/testify/require" "golang.org/x/net/ipv4" "golang.org/x/sys/unix" @@ -159,7 +162,7 @@ func runIBRLTest(t *testing.T, userType api.UserType, provisioningRequest map[st sockFile := filepath.Join(rootPath, "doublezerod.sock") go func() { programId := "" - err := runtime.Run(ctx, sockFile, "", false, false, programId, "", 30, 30) + err := runtime.Run(ctx, sockFile, "", false, false, programId, "", 30, 30, newTestLivenessManagerConfig()) errChan <- err }() @@ -378,7 +381,7 @@ func runIBRLTest(t *testing.T, userType api.UserType, provisioningRequest map[st ctx, cancel = context.WithCancel(context.Background()) go func() { programId := "" - err := runtime.Run(ctx, sockFile, "", false, false, programId, "", 30, 30) + err := runtime.Run(ctx, sockFile, "", false, false, programId, "", 30, 30, newTestLivenessManagerConfig()) errChan <- err }() @@ -494,7 +497,7 @@ func TestEndToEnd_EdgeFiltering(t *testing.T) { sockFile := filepath.Join(rootPath, "doublezerod.sock") go func() { programId := "" - err := runtime.Run(ctx, sockFile, "", false, false, programId, "", 30, 30) + err := runtime.Run(ctx, sockFile, "", false, false, programId, "", 30, 30, 
newTestLivenessManagerConfig()) errChan <- err }() @@ -642,7 +645,7 @@ func TestEndToEnd_EdgeFiltering(t *testing.T) { ctx, cancel = context.WithCancel(context.Background()) go func() { programId := "" - err := runtime.Run(ctx, sockFile, "", false, false, programId, "", 30, 30) + err := runtime.Run(ctx, sockFile, "", false, false, programId, "", 30, 30, newTestLivenessManagerConfig()) errChan <- err }() @@ -863,7 +866,7 @@ func TestMulticastPublisher(t *testing.T) { sockFile := filepath.Join(rootPath, "doublezerod.sock") go func() { programId := "" - err := runtime.Run(ctx, sockFile, "", false, false, programId, "", 30, 30) + err := runtime.Run(ctx, sockFile, "", false, false, programId, "", 30, 30, newTestLivenessManagerConfig()) errChan <- err }() @@ -1024,7 +1027,7 @@ func TestMulticastPublisher(t *testing.T) { ctx, cancel = context.WithCancel(context.Background()) go func() { programId := "" - err := runtime.Run(ctx, sockFile, "", false, false, programId, "", 30, 30) + err := runtime.Run(ctx, sockFile, "", false, false, programId, "", 30, 30, newTestLivenessManagerConfig()) errChan <- err }() @@ -1231,7 +1234,7 @@ func TestMulticastSubscriber(t *testing.T) { sockFile := filepath.Join(rootPath, "doublezerod.sock") go func() { programId := "" - err := runtime.Run(ctx, sockFile, "", false, false, programId, "", 30, 30) + err := runtime.Run(ctx, sockFile, "", false, false, programId, "", 30, 30, newTestLivenessManagerConfig()) errChan <- err }() @@ -1492,7 +1495,7 @@ func TestMulticastSubscriber(t *testing.T) { ctx, cancel = context.WithCancel(context.Background()) go func() { programId := "" - err := runtime.Run(ctx, sockFile, "", false, false, programId, "", 30, 30) + err := runtime.Run(ctx, sockFile, "", false, false, programId, "", 30, 30, newTestLivenessManagerConfig()) errChan <- err }() @@ -1647,12 +1650,11 @@ func TestServiceNoCoExistence(t *testing.T) { }() errChan := make(chan error, 1) - ctx, _ := context.WithCancel(context.Background()) sockFile := filepath.Join(rootPath, "doublezerod.sock") go func() { programId := "" - err := runtime.Run(ctx, sockFile, "", false, false, programId, "", 30, 30) + err := runtime.Run(t.Context(), sockFile, "", false, false, programId, "", 30, 30, newTestLivenessManagerConfig()) errChan <- err }() @@ -1833,7 +1835,7 @@ func TestServiceCoexistence(t *testing.T) { sockFile := filepath.Join(rootPath, "doublezerod.sock") go func() { programId := "" - err := runtime.Run(ctx, sockFile, "", false, false, programId, "", 30, 30) + err := runtime.Run(ctx, sockFile, "", false, false, programId, "", 30, 30, newTestLivenessManagerConfig()) errChan <- err }() @@ -1950,7 +1952,7 @@ func TestServiceCoexistence(t *testing.T) { ctx, cancel = context.WithCancel(context.Background()) go func() { programId := "" - err := runtime.Run(ctx, sockFile, "", false, false, programId, "", 30, 30) + err := runtime.Run(ctx, sockFile, "", false, false, programId, "", 30, 30, newTestLivenessManagerConfig()) errChan <- err }() @@ -2050,6 +2052,120 @@ func TestServiceCoexistence(t *testing.T) { }) } +func TestRuntime_Run_ReturnsOnContextCancel(t *testing.T) { + errChan := make(chan error, 1) + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + + rootPath, err := os.MkdirTemp("", "doublezerod") + require.NoError(t, err) + defer os.RemoveAll(rootPath) + t.Setenv("XDG_STATE_HOME", rootPath) + + path := filepath.Join(rootPath, "doublezerod") + if err := os.Mkdir(path, 0766); err != nil { + t.Fatalf("error creating state dir: %v", err) + } + + sockFile := 
filepath.Join(rootPath, "doublezerod.sock") + go func() { + programId := "" + err := runtime.Run(ctx, sockFile, "", false, false, programId, "", 30, 30, newTestLivenessManagerConfig()) + errChan <- err + }() + + // Give the runtime a moment to start, then cancel the context to force exit. + select { + case err := <-errChan: + require.NoError(t, err) + case <-time.After(300 * time.Millisecond): + } + + cancel() + select { + case err := <-errChan: + require.NoError(t, err) + case <-time.After(5 * time.Second): + t.Fatalf("timed out waiting for runtime to exit after context cancel") + } +} + +func TestRuntime_Run_PropagatesLivenessStartupError(t *testing.T) { + errChan := make(chan error, 1) + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + + rootPath, err := os.MkdirTemp("", "doublezerod") + require.NoError(t, err) + defer os.RemoveAll(rootPath) + t.Setenv("XDG_STATE_HOME", rootPath) + + // Invalid liveness config (port < 0) -> NewManager.Validate() error. + bad := *newTestLivenessManagerConfig() + bad.Port = -1 + + sockFile := filepath.Join(rootPath, "doublezerod.sock") + go func() { + programId := "" + err := runtime.Run(ctx, sockFile, "", false, false, programId, "", 30, 30, &bad) + errChan <- err + }() + + select { + case err := <-errChan: + require.Error(t, err) + require.Contains(t, err.Error(), "port must be greater than or equal to 0") + case <-time.After(5 * time.Second): + t.Fatalf("expected startup error from runtime.Run with bad liveness config") + } +} + +func TestRuntime_Run_PropagatesLivenessError_FromUDPClosure(t *testing.T) { + errCh := make(chan error, 1) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Minimal state dir + socket path + rootPath, err := os.MkdirTemp("", "doublezerod") + if err != nil { + t.Fatalf("mktemp: %v", err) + } + defer os.RemoveAll(rootPath) + t.Setenv("XDG_STATE_HOME", rootPath) + sockFile := filepath.Join(rootPath, "doublezerod.sock") + + // Create a real UDPService we can close to induce a receiver error. + udp, err := liveness.ListenUDP("127.0.0.1", 0) + if err != nil { + t.Fatalf("ListenUDP: %v", err) + } + + // Build a liveness config that uses our injected UDP service. + cfg := newTestLivenessManagerConfig() + cfg.UDP = udp + cfg.PassiveMode = true + + // Start the runtime. + go func() { + programID := "" + errCh <- runtime.Run(ctx, sockFile, "", false, false, programID, "", 30, 30, cfg) + }() + + // Give the liveness receiver a moment to start, then close the UDP socket. + time.Sleep(200 * time.Millisecond) + _ = udp.Close() + + // The receiver should error, Manager should send on lm.Err(), and Run should return that error. 
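+ // The wiring assumed here, as a sketch (internal names like errCh are
+ // assumptions, not the manager's actual fields):
+ //
+ //	func (m *Manager) Err() <-chan error { return m.errCh }
+ //
+ //	// receiver goroutine, on a fatal read error:
+ //	select {
+ //	case m.errCh <- err:
+ //	default: // already failing; drop duplicate errors
+ //	}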
+ select { + case err := <-errCh: + if err == nil { + t.Fatalf("expected non-nil error propagated from liveness manager, got nil") + } + case <-time.After(5 * time.Second): + t.Fatalf("timeout waiting for runtime to return error from liveness manager") + } +} + func setupTest(t *testing.T) (func(), error) { abortIfLinksAreUp(t) rootPath, err := os.MkdirTemp("", "doublezerod") @@ -2434,3 +2550,17 @@ func abortIfLinksAreUp(t *testing.T) { } } } + +func newTestLivenessManagerConfig() *liveness.ManagerConfig { + return &liveness.ManagerConfig{ + Logger: slog.Default(), + BindIP: "0.0.0.0", + Port: 44880, + PassiveMode: true, + TxMin: 300 * time.Millisecond, + RxMin: 300 * time.Millisecond, + DetectMult: 3, + MinTxFloor: 50 * time.Millisecond, + MaxTxCeil: 1 * time.Second, + } +} diff --git a/client/doublezerod/internal/services/ibrl.go b/client/doublezerod/internal/services/ibrl.go index ae38f92fd..e2712da12 100644 --- a/client/doublezerod/internal/services/ibrl.go +++ b/client/doublezerod/internal/services/ibrl.go @@ -14,11 +14,12 @@ import ( ) type IBRLService struct { - bgp BGPReaderWriter - nl routing.Netlinker - db DBReaderWriter - Tunnel *routing.Tunnel - DoubleZeroAddr net.IP + bgp BGPReaderWriter + nl routing.Netlinker + db DBReaderWriter + Tunnel *routing.Tunnel + DoubleZeroAddr net.IP + livenessEnabled bool } func (s *IBRLService) UserType() api.UserType { return api.UserTypeIBRL } @@ -29,6 +30,8 @@ func NewIBRLService(bgp BGPReaderWriter, nl routing.Netlinker, db DBReaderWriter bgp: bgp, nl: nl, db: db, + + livenessEnabled: true, } } @@ -47,7 +50,7 @@ func (s *IBRLService) Setup(p *api.ProvisionRequest) error { err = createTunnelWithIP(s.nl, tun, p.DoubleZeroIP) flush = false default: - return fmt.Errorf("unsupported tunnel type: %v\n", p) + return fmt.Errorf("unsupported tunnel type: %v", p) } if err != nil { return fmt.Errorf("error creating tunnel interface: %v", err) @@ -57,13 +60,15 @@ func (s *IBRLService) Setup(p *api.ProvisionRequest) error { s.DoubleZeroAddr = p.DoubleZeroIP peer := &bgp.PeerConfig{ - RemoteAddress: s.Tunnel.RemoteOverlay, - LocalAddress: s.Tunnel.LocalOverlay, - LocalAs: p.BgpLocalAsn, - RemoteAs: p.BgpRemoteAsn, - RouteSrc: p.DoubleZeroIP, - RouteTable: syscall.RT_TABLE_MAIN, - FlushRoutes: flush, + RemoteAddress: s.Tunnel.RemoteOverlay, + LocalAddress: s.Tunnel.LocalOverlay, + LocalAs: p.BgpLocalAsn, + RemoteAs: p.BgpRemoteAsn, + RouteSrc: p.DoubleZeroIP, + RouteTable: syscall.RT_TABLE_MAIN, + FlushRoutes: flush, + LivenessEnabled: s.livenessEnabled, + Interface: "doublezero0", } nlri, err := bgp.NewNLRI([]uint32{peer.LocalAs}, s.Tunnel.LocalOverlay.String(), p.DoubleZeroIP.String(), 32) if err != nil { diff --git a/client/doublezerod/internal/services/services_test.go b/client/doublezerod/internal/services/services_test.go index 1bea2effa..4429ec00e 100644 --- a/client/doublezerod/internal/services/services_test.go +++ b/client/doublezerod/internal/services/services_test.go @@ -170,13 +170,15 @@ func TestServices(t *testing.T) { wantRulesAdded: nil, wantRoutesAdded: nil, wantPeerConfig: &bgp.PeerConfig{ - LocalAddress: net.IPv4(169, 254, 0, 1), - RemoteAddress: net.IPv4(169, 254, 0, 0), - LocalAs: 65000, - RemoteAs: 65001, - RouteSrc: net.IPv4(192, 168, 1, 1), - RouteTable: syscall.RT_TABLE_MAIN, - FlushRoutes: true, + LocalAddress: net.IPv4(169, 254, 0, 1), + RemoteAddress: net.IPv4(169, 254, 0, 0), + LocalAs: 65000, + RemoteAs: 65001, + RouteSrc: net.IPv4(192, 168, 1, 1), + RouteTable: syscall.RT_TABLE_MAIN, + FlushRoutes: true, + Interface: 
"doublezero0", + LivenessEnabled: true, }, wantTunRemoved: &routing.Tunnel{ Name: "doublezero0", @@ -225,6 +227,7 @@ func TestServices(t *testing.T) { RouteSrc: net.IPv4(192, 168, 1, 0), RouteTable: syscall.RT_TABLE_MAIN, FlushRoutes: false, + Interface: "doublezero0", }, wantTunRemoved: &routing.Tunnel{ Name: "doublezero0", diff --git a/e2e/docker/client/Dockerfile b/e2e/docker/client/Dockerfile index 549b4573d..6ed866974 100644 --- a/e2e/docker/client/Dockerfile +++ b/e2e/docker/client/Dockerfile @@ -4,7 +4,7 @@ FROM ${BASE_IMAGE} AS base FROM ubuntu:24.04 RUN apt-get update && \ - apt-get install -y curl jq iproute2 iputils-ping iproute2 net-tools tcpdump vim iperf fping iptables ethtool + apt-get install -y curl jq iproute2 iputils-ping iproute2 net-tools tcpdump tshark vim iperf fping iptables ethtool COPY --from=base /doublezero/bin/doublezero /usr/local/bin/ COPY --from=base /doublezero/bin/doublezerod /usr/local/bin/ diff --git a/e2e/docker/client/entrypoint.sh b/e2e/docker/client/entrypoint.sh index 247afd9ed..19bfd1206 100755 --- a/e2e/docker/client/entrypoint.sh +++ b/e2e/docker/client/entrypoint.sh @@ -77,4 +77,4 @@ for dev in $(ip -o link show | awk -F': ' '/^ *[0-9]+: eth[0-9]+/ {print $2}' | done # Start doublezerod. -doublezerod -program-id ${DZ_SERVICEABILITY_PROGRAM_ID} -solana-rpc-endpoint ${DZ_LEDGER_URL} -probe-interval 5 -cache-update-interval 3 +doublezerod -program-id ${DZ_SERVICEABILITY_PROGRAM_ID} -solana-rpc-endpoint ${DZ_LEDGER_URL} -probe-interval 5 -cache-update-interval 3 ${DZ_CLIENT_EXTRA_ARGS} diff --git a/e2e/ibrl_test.go b/e2e/ibrl_test.go index ab4a5de9a..b9de1f22a 100644 --- a/e2e/ibrl_test.go +++ b/e2e/ibrl_test.go @@ -45,13 +45,13 @@ func TestE2E_IBRL(t *testing.T) { if !t.Run("remove_ibgp_msdp_peer", func(t *testing.T) { dn.DeleteDeviceLoopbackInterface(t.Context(), "pit-dzd01", "Loopback255") time.Sleep(30 * time.Second) // Wait for the device to process the change - checkIbgpMsdpPeerRemoved(t, dn, device, client) + checkIbgpMsdpPeerRemoved(t, dn, device) }) { t.Fail() } } -func checkIbgpMsdpPeerRemoved(t *testing.T, dn *TestDevnet, device *devnet.Device, client *devnet.Client) { +func checkIbgpMsdpPeerRemoved(t *testing.T, dn *TestDevnet, device *devnet.Device) { dn.log.Info("==> Checking that iBGP/MSDP peers have been removed after peer's Loopback255 interface was removed") if !t.Run("wait_for_agent_config_after_peer_removal", func(t *testing.T) { diff --git a/e2e/internal/devnet/client.go b/e2e/internal/devnet/client.go index 10e87050d..18f8ba384 100644 --- a/e2e/internal/devnet/client.go +++ b/e2e/internal/devnet/client.go @@ -23,6 +23,17 @@ type ClientSpec struct { ContainerImage string KeypairPath string + // Route liveness passive/active mode flags. + // TODO(snormore): These flags are temporary for initial rollout testing. + // They will be superceded by a single `route-liveness-enable` flag, where false means passive-mode + // and true means active-mode. + RouteLivenessEnablePassive bool + RouteLivenessEnableActive bool + + // RouteLivenessEnable is a flag to enable or disable route liveness. False puts the system in + // passive-mode, and true puts it in active-mode. + // RouteLivenessEnable bool + // CYOANetworkIPHostID is the offset into the host portion of the subnet (must be < 2^(32 - prefixLen)). CYOANetworkIPHostID uint32 } @@ -152,6 +163,14 @@ func (c *Client) Start(ctx context.Context) error { // We need to set this here because dockerContainerName and dockerContainerHostname use it. 
c.Pubkey = pubkey + extraArgs := []string{} + if c.Spec.RouteLivenessEnablePassive { + extraArgs = append(extraArgs, "-route-liveness-enable-passive") + } + if c.Spec.RouteLivenessEnableActive { + extraArgs = append(extraArgs, "-route-liveness-enable-active") + } + // Start the client container. req := testcontainers.ContainerRequest{ Image: c.Spec.ContainerImage, @@ -163,6 +182,7 @@ func (c *Client) Start(ctx context.Context) error { "DZ_LEDGER_URL": c.dn.Ledger.InternalRPCURL, "DZ_LEDGER_WS": c.dn.Ledger.InternalRPCWSURL, "DZ_SERVICEABILITY_PROGRAM_ID": c.dn.Manager.ServiceabilityProgramID, + "DZ_CLIENT_EXTRA_ARGS": strings.Join(extraArgs, " "), }, Files: []testcontainers.ContainerFile{ { diff --git a/e2e/internal/devnet/cmd/add-client.go b/e2e/internal/devnet/cmd/add-client.go index d5ed433f5..698886a32 100644 --- a/e2e/internal/devnet/cmd/add-client.go +++ b/e2e/internal/devnet/cmd/add-client.go @@ -17,6 +17,8 @@ func NewAddClientCmd() *AddClientCmd { func (c *AddClientCmd) Command() *cobra.Command { var cyoaNetworkHostID uint32 var keypairPath string + var routeLivenessEnablePassive bool + var routeLivenessEnableActive bool cmd := &cobra.Command{ Use: "add-client", @@ -28,8 +30,10 @@ func (c *AddClientCmd) Command() *cobra.Command { } _, err = dn.AddClient(ctx, devnet.ClientSpec{ - CYOANetworkIPHostID: cyoaNetworkHostID, - KeypairPath: keypairPath, + CYOANetworkIPHostID: cyoaNetworkHostID, + KeypairPath: keypairPath, + RouteLivenessEnablePassive: routeLivenessEnablePassive, + RouteLivenessEnableActive: routeLivenessEnableActive, }) if err != nil { return fmt.Errorf("failed to add client: %w", err) @@ -43,6 +47,8 @@ func (c *AddClientCmd) Command() *cobra.Command { _ = cmd.MarkFlagRequired("cyoa-network-host-id") cmd.Flags().StringVar(&keypairPath, "keypair-path", "", "Path to the keypair file (optional)") + cmd.Flags().BoolVar(&routeLivenessEnablePassive, "route-liveness-enable-passive", false, "Enable route liveness in passive mode (experimental)") + cmd.Flags().BoolVar(&routeLivenessEnableActive, "route-liveness-enable-active", false, "Enable route liveness in active mode (experimental)") return cmd } diff --git a/e2e/internal/devnet/device.go b/e2e/internal/devnet/device.go index b7354415b..5f1b916b4 100644 --- a/e2e/internal/devnet/device.go +++ b/e2e/internal/devnet/device.go @@ -39,8 +39,8 @@ const ( const ( // Device container is more CPU and memory intensive than the others. 
- deviceContainerNanoCPUs = 4_000_000_000 // 4 cores - deviceContainerMemory = 4 * 1024 * 1024 * 1024 // 4GB + deviceContainerNanoCPUs = 4_000_000_000 // 4 cores + deviceContainerMemory = 4.5 * 1024 * 1024 * 1024 // 4.5GB defaultCYOANetworkAllocatablePrefix = 29 // 8 addresses ) diff --git a/e2e/internal/rpc/agent_test.go b/e2e/internal/rpc/agent_test.go index 6f752be1a..80c60aac4 100644 --- a/e2e/internal/rpc/agent_test.go +++ b/e2e/internal/rpc/agent_test.go @@ -7,6 +7,7 @@ import ( "net/http" "net/http/httptest" "os" + "os/exec" "testing" pb "github.com/malbeclabs/doublezero/e2e/proto/qa/gen/pb-go" @@ -74,6 +75,10 @@ func TestQAAgentConnectivity(t *testing.T) { }) t.Run("GetStatus", func(t *testing.T) { + if _, err := exec.LookPath("doublezero"); err != nil { + t.Skip("skipping test: doublezero binary not found") + } + statusResult, err := client.GetStatus(ctx, &emptypb.Empty{}) require.NoError(t, err) require.NotNil(t, statusResult) diff --git a/e2e/multi_client_test.go b/e2e/multi_client_test.go index 9f31fab1e..31a976e60 100644 --- a/e2e/multi_client_test.go +++ b/e2e/multi_client_test.go @@ -3,18 +3,27 @@ package e2e_test import ( + "errors" + "fmt" "log/slog" "os" "path/filepath" + "strconv" "strings" + "sync" "testing" "time" "github.com/malbeclabs/doublezero/e2e/internal/devnet" + "github.com/malbeclabs/doublezero/e2e/internal/docker" "github.com/malbeclabs/doublezero/e2e/internal/random" "github.com/stretchr/testify/require" ) +const ( + routeLivenessPort = 44880 +) + func TestE2E_MultiClient(t *testing.T) { t.Parallel() @@ -47,49 +56,65 @@ func TestE2E_MultiClient(t *testing.T) { _, err = linkNetwork.CreateIfNotExists(t.Context()) require.NoError(t, err) - // Add la2-dz01 device in xlax exchange. + var wg sync.WaitGroup deviceCode1 := "la2-dz01" - device1, err := dn.AddDevice(t.Context(), devnet.DeviceSpec{ - Code: deviceCode1, - Location: "lax", - Exchange: "xlax", - // .8/29 has network address .8, allocatable up to .14, and broadcast .15 - CYOANetworkIPHostID: 8, - CYOANetworkAllocatablePrefix: 29, - AdditionalNetworks: []string{linkNetwork.Name}, - Interfaces: map[string]string{ - "Ethernet2": "physical", - }, - LoopbackInterfaces: map[string]string{ - "Loopback255": "vpnv4", - "Loopback256": "ipv4", - }, - }) - require.NoError(t, err) - devicePK1 := device1.ID - log.Info("--> Device1 added", "deviceCode", deviceCode1, "devicePK", devicePK1) - - // Add ewr1-dz01 device in xewr exchange. deviceCode2 := "ewr1-dz01" - device2, err := dn.AddDevice(t.Context(), devnet.DeviceSpec{ - Code: deviceCode2, - Location: "ewr", - Exchange: "xewr", - // .16/29 has network address .16, allocatable up to .22, and broadcast .23 - CYOANetworkIPHostID: 16, - CYOANetworkAllocatablePrefix: 29, - AdditionalNetworks: []string{linkNetwork.Name}, - Interfaces: map[string]string{ - "Ethernet2": "physical", - }, - LoopbackInterfaces: map[string]string{ - "Loopback255": "vpnv4", - "Loopback256": "ipv4", - }, - }) - require.NoError(t, err) - devicePK2 := device2.ID - log.Info("--> Device2 added", "deviceCode", deviceCode2, "devicePK", devicePK2) + var devicePK1, devicePK2 string + + wg.Add(1) + go func() { + defer wg.Done() + + // Add la2-dz01 device in xlax exchange. 
+ device1, err := dn.AddDevice(t.Context(), devnet.DeviceSpec{
+ Code: deviceCode1,
+ Location: "lax",
+ Exchange: "xlax",
+ // .8/29 has network address .8, allocatable up to .14, and broadcast .15
+ CYOANetworkIPHostID: 8,
+ CYOANetworkAllocatablePrefix: 29,
+ AdditionalNetworks: []string{linkNetwork.Name},
+ Interfaces: map[string]string{
+ "Ethernet2": "physical",
+ },
+ LoopbackInterfaces: map[string]string{
+ "Loopback255": "vpnv4",
+ "Loopback256": "ipv4",
+ },
+ })
+ require.NoError(t, err)
+ devicePK1 = device1.ID
+ log.Info("--> Device1 added", "deviceCode", deviceCode1, "devicePK", devicePK1)
+ }()
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+
+ // Add ewr1-dz01 device in xewr exchange.
+ device2, err := dn.AddDevice(t.Context(), devnet.DeviceSpec{
+ Code: deviceCode2,
+ Location: "ewr",
+ Exchange: "xewr",
+ // .16/29 has network address .16, allocatable up to .22, and broadcast .23
+ CYOANetworkIPHostID: 16,
+ CYOANetworkAllocatablePrefix: 29,
+ AdditionalNetworks: []string{linkNetwork.Name},
+ Interfaces: map[string]string{
+ "Ethernet2": "physical",
+ },
+ LoopbackInterfaces: map[string]string{
+ "Loopback255": "vpnv4",
+ "Loopback256": "ipv4",
+ },
+ })
+ require.NoError(t, err)
+ devicePK2 = device2.ID
+ log.Info("--> Device2 added", "deviceCode", deviceCode2, "devicePK", devicePK2)
+ }()
+
+ // Wait for devices to be added.
+ wg.Wait()

 // Wait for devices to exist onchain.
 log.Info("==> Waiting for devices to exist onchain")
@@ -107,42 +132,69 @@
 require.NoError(t, err)
 log.Info("--> Link created onchain")

- // Add a client.
+ // Add client1.
 log.Info("==> Adding client1")
 client1, err := dn.AddClient(t.Context(), devnet.ClientSpec{
- CYOANetworkIPHostID: 100,
+ CYOANetworkIPHostID: 100,
+ RouteLivenessEnableActive: true,
 })
 require.NoError(t, err)
 log.Info("--> Client1 added", "client1Pubkey", client1.Pubkey, "client1IP", client1.CYOANetworkIP)

- // Add another client.
+ // Add client2.
 log.Info("==> Adding client2")
 client2, err := dn.AddClient(t.Context(), devnet.ClientSpec{
- CYOANetworkIPHostID: 110,
+ CYOANetworkIPHostID: 110,
+ RouteLivenessEnablePassive: true, // route liveness in passive mode for this client
 })
 require.NoError(t, err)
 log.Info("--> Client2 added", "client2Pubkey", client2.Pubkey, "client2IP", client2.CYOANetworkIP)

+ // Add client3.
+ log.Info("==> Adding client3")
+ client3, err := dn.AddClient(t.Context(), devnet.ClientSpec{
+ CYOANetworkIPHostID: 120,
+ RouteLivenessEnableActive: true, // route liveness in active mode for this client
+ })
+ require.NoError(t, err)
+ log.Info("--> Client3 added", "client3Pubkey", client3.Pubkey, "client3IP", client3.CYOANetworkIP)
+
+ // Add client4.
+ log.Info("==> Adding client4")
+ client4, err := dn.AddClient(t.Context(), devnet.ClientSpec{
+ CYOANetworkIPHostID: 130,
+ RouteLivenessEnablePassive: false, // route liveness subsystem is disabled for this client
+ RouteLivenessEnableActive: false,
+ })
+ require.NoError(t, err)
+ log.Info("--> Client4 added", "client4Pubkey", client4.Pubkey, "client4IP", client4.CYOANetworkIP)
+
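+ // Recap of the route-liveness configuration created above:
+ //
+ //	client1: active mode
+ //	client2: passive mode
+ //	client3: active mode
+ //	client4: subsystem disabled (neither flag set)
+
+ // Wait for client latency results.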
log.Info("==> Waiting for client latency results") err = client1.WaitForLatencyResults(t.Context(), devicePK1, 90*time.Second) require.NoError(t, err) err = client2.WaitForLatencyResults(t.Context(), devicePK2, 90*time.Second) require.NoError(t, err) + err = client3.WaitForLatencyResults(t.Context(), devicePK2, 90*time.Second) + require.NoError(t, err) + err = client4.WaitForLatencyResults(t.Context(), devicePK2, 90*time.Second) + require.NoError(t, err) log.Info("--> Finished waiting for client latency results") log.Info("==> Add clients to user Access Pass") - // Set access pass for the client. _, err = dn.Manager.Exec(t.Context(), []string{"bash", "-c", "doublezero access-pass set --accesspass-type prepaid --client-ip " + client1.CYOANetworkIP + " --user-payer " + client1.Pubkey}) require.NoError(t, err) - // Set access pass for the client. _, err = dn.Manager.Exec(t.Context(), []string{"bash", "-c", "doublezero access-pass set --accesspass-type prepaid --client-ip " + client2.CYOANetworkIP + " --user-payer " + client2.Pubkey}) require.NoError(t, err) + _, err = dn.Manager.Exec(t.Context(), []string{"bash", "-c", "doublezero access-pass set --accesspass-type prepaid --client-ip " + client3.CYOANetworkIP + " --user-payer " + client3.Pubkey}) + require.NoError(t, err) + _, err = dn.Manager.Exec(t.Context(), []string{"bash", "-c", "doublezero access-pass set --accesspass-type prepaid --client-ip " + client4.CYOANetworkIP + " --user-payer " + client4.Pubkey}) + require.NoError(t, err) log.Info("--> Clients added to user Access Pass") // Run IBRL workflow test. if !t.Run("ibrl", func(t *testing.T) { - runMultiClientIBRLWorkflowTest(t, log, dn, client1, client2, deviceCode1, deviceCode2) + runMultiClientIBRLWorkflowTest(t, log, dn, client1, client2, client3, client4, deviceCode1, deviceCode2) }) { t.Fail() } @@ -155,7 +207,7 @@ func TestE2E_MultiClient(t *testing.T) { } } -func runMultiClientIBRLWorkflowTest(t *testing.T, log *slog.Logger, dn *devnet.Devnet, client1 *devnet.Client, client2 *devnet.Client, deviceCode1 string, deviceCode2 string) { +func runMultiClientIBRLWorkflowTest(t *testing.T, log *slog.Logger, dn *devnet.Devnet, client1 *devnet.Client, client2 *devnet.Client, client3 *devnet.Client, client4 *devnet.Client, deviceCode1 string, deviceCode2 string) { // Check that the clients are disconnected and do not have a DZ IP allocated. log.Info("==> Checking that the clients are disconnected and do not have a DZ IP allocated") status, err := client1.GetTunnelStatus(t.Context()) @@ -168,23 +220,53 @@ func runMultiClientIBRLWorkflowTest(t *testing.T, log *slog.Logger, dn *devnet.D require.Len(t, status, 1, status) require.Nil(t, status[0].DoubleZeroIP, status) require.Equal(t, devnet.ClientSessionStatusDisconnected, status[0].DoubleZeroStatus.SessionStatus) + status, err = client3.GetTunnelStatus(t.Context()) + require.NoError(t, err) + require.Len(t, status, 1, status) + require.Nil(t, status[0].DoubleZeroIP, status) + require.Equal(t, devnet.ClientSessionStatusDisconnected, status[0].DoubleZeroStatus.SessionStatus) + status, err = client4.GetTunnelStatus(t.Context()) + require.NoError(t, err) + require.Len(t, status, 1, status) + require.Nil(t, status[0].DoubleZeroIP, status) + require.Equal(t, devnet.ClientSessionStatusDisconnected, status[0].DoubleZeroStatus.SessionStatus) log.Info("--> Confirmed clients are disconnected and do not have a DZ IP allocated") // Connect client1 in IBRL mode to device1 (xlax exchange). 
log.Info("==> Connecting client1 in IBRL mode to device1") _, err = client1.Exec(t.Context(), []string{"doublezero", "connect", "ibrl", "--client-ip", client1.CYOANetworkIP, "--device", deviceCode1}) require.NoError(t, err) - err = client1.WaitForTunnelUp(t.Context(), 90*time.Second) - require.NoError(t, err) log.Info("--> Client1 connected in IBRL mode to device1") // Connect client2 in IBRL mode to device2 (xewr exchange). log.Info("==> Connecting client2 in IBRL mode to device2") _, err = client2.Exec(t.Context(), []string{"doublezero", "connect", "ibrl", "--client-ip", client2.CYOANetworkIP, "--device", deviceCode2}) require.NoError(t, err) + log.Info("--> Client2 connected in IBRL mode to device2") + + // Connect client3 in IBRL mode to device2 (xewr exchange). + log.Info("==> Connecting client3 in IBRL mode to device2") + _, err = client3.Exec(t.Context(), []string{"doublezero", "connect", "ibrl", "--client-ip", client3.CYOANetworkIP, "--device", deviceCode2}) + require.NoError(t, err) + log.Info("--> Client3 connected in IBRL mode to device2") + + // Connect client4 in IBRL mode to device2 (xewr exchange). + log.Info("==> Connecting client4 in IBRL mode to device2") + _, err = client4.Exec(t.Context(), []string{"doublezero", "connect", "ibrl", "--client-ip", client4.CYOANetworkIP, "--device", deviceCode2}) + require.NoError(t, err) + log.Info("--> Client4 connected in IBRL mode to device2") + + // Wait for all clients to be connected. + log.Info("==> Waiting for all clients to be connected") + err = client1.WaitForTunnelUp(t.Context(), 90*time.Second) + require.NoError(t, err) err = client2.WaitForTunnelUp(t.Context(), 90*time.Second) require.NoError(t, err) - log.Info("--> Client2 connected in IBRL mode to device2") + err = client3.WaitForTunnelUp(t.Context(), 90*time.Second) + require.NoError(t, err) + err = client4.WaitForTunnelUp(t.Context(), 90*time.Second) + require.NoError(t, err) + log.Info("--> All clients connected") // Check that the clients have a DZ IP equal to their client IP when not configured to use an allocated IP. log.Info("==> Checking that the clients have a DZ IP as public IP when not configured to use an allocated IP") @@ -198,37 +280,269 @@ func runMultiClientIBRLWorkflowTest(t *testing.T, log *slog.Logger, dn *devnet.D client2DZIP := status[0].DoubleZeroIP.String() require.NoError(t, err) require.Equal(t, client2.CYOANetworkIP, client2DZIP) + status, err = client3.GetTunnelStatus(t.Context()) + require.Len(t, status, 1) + client3DZIP := status[0].DoubleZeroIP.String() + require.NoError(t, err) + require.Equal(t, client3.CYOANetworkIP, client3DZIP) + status, err = client4.GetTunnelStatus(t.Context()) + require.Len(t, status, 1) + client4DZIP := status[0].DoubleZeroIP.String() + require.NoError(t, err) + require.Equal(t, client4.CYOANetworkIP, client4DZIP) log.Info("--> Clients have a DZ IP as public IP when not configured to use an allocated IP") // Check that the clients have routes to each other. log.Info("==> Checking that the clients have routes to each other") + + // Client1 (on DZD1) should have routes to client2 (on DZD2) and client3 (on DZD2). 
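+ // The route checks below shell out to `ip r`; on client1 the expected output
+ // is roughly (addresses illustrative):
+ //
+ //	$ ip r list dev doublezero0
+ //	<client2 DZ IP> via <tunnel next hop> ...
+ //	<client3 DZ IP> via <tunnel next hop> ...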
+ log.Info("--> Client1 (on DZD1) should have routes to client2 (on DZD2) and client3 (on DZD2)") require.Eventually(t, func() bool { output, err := client1.Exec(t.Context(), []string{"ip", "r", "list", "dev", "doublezero0"}) - if err != nil { - return false - } - return strings.Contains(string(output), client2DZIP) + require.NoError(t, err) + return strings.Contains(string(output), client2DZIP) && strings.Contains(string(output), client3DZIP) }, 120*time.Second, 5*time.Second, "client1 should have route to client2") + + // Client2 (on DZD2) should have routes to client1 (on DZD1) only. + log.Info("--> Client2 (on DZD2) should have routes to client1 (on DZD1) only") require.Eventually(t, func() bool { output, err := client2.Exec(t.Context(), []string{"ip", "r", "list", "dev", "doublezero0"}) - if err != nil { - return false - } + require.NoError(t, err) return strings.Contains(string(output), client1DZIP) }, 120*time.Second, 5*time.Second, "client2 should have route to client1") + + // Client3 (on DZD2) should have routes to client1 (on DZD1) only. + log.Info("--> Client3 (on DZD2) should have routes to client1 (on DZD1) only") + require.Eventually(t, func() bool { + output, err := client3.Exec(t.Context(), []string{"ip", "r", "list", "dev", "doublezero0"}) + require.NoError(t, err) + return strings.Contains(string(output), client1DZIP) + }, 120*time.Second, 5*time.Second, "client3 should have route to client1") + + // Client2 (on DZD2) should not have routes to client3 (on DZD2). + log.Info("--> Client2 (on DZD2) should not have routes to client3 (on DZD2)") + require.Never(t, func() bool { + output, err := client2.Exec(t.Context(), []string{"ip", "r", "list", "dev", "doublezero0"}) + require.NoError(t, err) + return strings.Contains(string(output), client3DZIP) + }, 1*time.Second, 100*time.Millisecond, "client2 should not have route to client3") + + // Client3 (on DZD2) should not have routes to client2 (on DZD2). + log.Info("--> Client3 (on DZD2) should not have routes to client2 (on DZD2)") + require.Never(t, func() bool { + output, err := client3.Exec(t.Context(), []string{"ip", "r", "list", "dev", "doublezero0"}) + require.NoError(t, err) + return strings.Contains(string(output), client2DZIP) + }, 1*time.Second, 100*time.Millisecond, "client3 should not have route to client2") + + // Client4 (on DZD2) should have route to client1 (on DZD1). + log.Info("--> Client4 (on DZD2) should have route to client1 (on DZD1)") + require.Eventually(t, func() bool { + output, err := client4.Exec(t.Context(), []string{"ip", "r", "list", "dev", "doublezero0"}) + require.NoError(t, err) + return strings.Contains(string(output), client1DZIP) + }, 120*time.Second, 5*time.Second, "client4 should have routes to client1") + log.Info("--> Clients have routes to each other") // Check that the clients can reach each other via their DZ IPs, via ping. log.Info("==> Checking that the clients can reach each other via their DZ IPs") + + // Client1 can reach client2 and client3 over doublezero0 interface. _, err = client1.Exec(t.Context(), []string{"ping", "-I", "doublezero0", "-c", "3", client2DZIP, "-W", "1"}) require.NoError(t, err) + _, err = client1.Exec(t.Context(), []string{"ping", "-I", "doublezero0", "-c", "3", client3DZIP, "-W", "1"}) + require.NoError(t, err) + + // Client2 can reach client1 over doublezero0 interface. 
_, err = client2.Exec(t.Context(), []string{"ping", "-I", "doublezero0", "-c", "3", client1DZIP, "-W", "1"}) require.NoError(t, err) + // Client2 cannot reach client3 over doublezero0 interface. + _, err = client2.Exec(t.Context(), []string{"ping", "-I", "doublezero0", "-c", "3", client3DZIP, "-W", "1"}, docker.NoPrintOnError()) + require.Error(t, err) + + // Client3 can reach client1 over doublezero0 interface. + _, err = client3.Exec(t.Context(), []string{"ping", "-I", "doublezero0", "-c", "3", client1DZIP, "-W", "1"}) + require.NoError(t, err) + // Client3 cannot reach client2 over doublezero0 interface. + _, err = client3.Exec(t.Context(), []string{"ping", "-I", "doublezero0", "-c", "3", client2DZIP, "-W", "1"}, docker.NoPrintOnError()) + require.Error(t, err) + + // Client4 cannot reach client1 over doublezero0 interface, since client1 does not have a route to client4 and so replies over eth0/1. + _, err = client4.Exec(t.Context(), []string{"ping", "-I", "doublezero0", "-c", "3", client1DZIP, "-W", "1"}, docker.NoPrintOnError()) + require.Error(t, err) + + // Client1 can reach client2 and client3 without specifying the interface. _, err = client1.Exec(t.Context(), []string{"ping", "-c", "3", client2DZIP, "-W", "1"}) require.NoError(t, err) + _, err = client1.Exec(t.Context(), []string{"ping", "-c", "3", client3DZIP, "-W", "1"}) + require.NoError(t, err) + + // Client2 can reach client1 and client3 without specifying the interface. _, err = client2.Exec(t.Context(), []string{"ping", "-c", "3", client1DZIP, "-W", "1"}) require.NoError(t, err) + _, err = client2.Exec(t.Context(), []string{"ping", "-c", "3", client3DZIP, "-W", "1"}) + require.NoError(t, err) + + // Client3 can reach client1 and client2 without specifying the interface. + _, err = client3.Exec(t.Context(), []string{"ping", "-c", "3", client1DZIP, "-W", "1"}) + require.NoError(t, err) + _, err = client3.Exec(t.Context(), []string{"ping", "-c", "3", client2DZIP, "-W", "1"}) + require.NoError(t, err) + + // Client4 can reach client1, client2, and client3 without specifying the interface. 
+ _, err = client4.Exec(t.Context(), []string{"ping", "-c", "3", client1DZIP, "-W", "1"})
+ require.NoError(t, err)
+ _, err = client4.Exec(t.Context(), []string{"ping", "-c", "3", client2DZIP, "-W", "1"})
+ require.NoError(t, err)
+ _, err = client4.Exec(t.Context(), []string{"ping", "-c", "3", client3DZIP, "-W", "1"})
+ require.NoError(t, err)
+ log.Info("--> Clients can reach each other via their DZ IPs")
+
+ // --- Route liveness block matrix ---
+ log.Info("==> Route liveness: block each client independently and require expected route behavior")
+ const wait = 120 * time.Second
+ const tick = 5 * time.Second
+
+ doRouteLivenessBaseline := func() {
+ t.Helper()
+ // Baseline should already be:
+ // - c1 has routes to c2,c3
+ // - c2 has route to c1, NOT to c3
+ // - c3 has route to c1, NOT to c2
+ requireEventuallyRoute(t, client1, client2DZIP, true, wait, tick, "baseline c1->c2")
+ requireEventuallyRoute(t, client1, client3DZIP, true, wait, tick, "baseline c1->c3")
+ requireEventuallyRoute(t, client2, client1DZIP, true, wait, tick, "baseline c2->c1")
+ requireEventuallyRoute(t, client2, client3DZIP, false, wait, tick, "baseline c2->c3")
+ requireEventuallyRoute(t, client3, client1DZIP, true, wait, tick, "baseline c3->c1")
+ requireEventuallyRoute(t, client3, client2DZIP, false, wait, tick, "baseline c3->c2")
+
+ // Baseline liveness packets (dz0 present where peers exist; never on eth0/eth1)
+ requireUDPLivenessOnDZ0(t, client1, client2DZIP, true, "baseline c1 liveness packets -> c2 on dz0")
+ requireUDPLivenessOnDZ0(t, client1, client3DZIP, true, "baseline c1 liveness packets -> c3 on dz0")
+ requireUDPLivenessOnDZ0(t, client2, client1DZIP, true, "baseline c2 liveness packets -> c1 on dz0 (passive mode is routing-agnostic)")
+ requireUDPLivenessOnDZ0(t, client2, client3DZIP, false, "baseline c2 liveness packets -> c3 none")
+ requireUDPLivenessOnDZ0(t, client3, client1DZIP, true, "baseline c3 liveness packets -> c1 on dz0")
+ requireUDPLivenessOnDZ0(t, client3, client2DZIP, false, "baseline c3 liveness packets -> c2 none")
+ requireNoUDPLivenessOnEth01(t, client1, client2DZIP, "baseline no c1 liveness packets on eth0/1 -> c2")
+ requireNoUDPLivenessOnEth01(t, client1, client3DZIP, "baseline no c1 liveness packets on eth0/1 -> c3")
+ requireNoUDPLivenessOnEth01(t, client2, client1DZIP, "baseline no c2 liveness packets on eth0/1 -> c1")
+ requireNoUDPLivenessOnEth01(t, client3, client1DZIP, "baseline no c3 liveness packets on eth0/1 -> c1")
+ }
+
+ doRouteLivenessCaseA := func(pass int) {
+ t.Helper()
+ log.Info("==> Route liveness Case A (block client1)", "pass", pass)
+ blockUDPLiveness(t, client1)
+
+ // Routes
+ requireEventuallyRoute(t, client1, client2DZIP, false, wait, tick, fmt.Sprintf("pass %d: block c1: c1->c2 removed", pass))
+ requireEventuallyRoute(t, client1, client3DZIP, false, wait, tick, fmt.Sprintf("pass %d: block c1: c1->c3 removed", pass))
+ requireEventuallyRoute(t, client1, client4DZIP, false, wait, tick, fmt.Sprintf("pass %d: block c1: c1->c4 removed", pass))
+ requireEventuallyRoute(t, client3, client1DZIP, false, wait, tick, fmt.Sprintf("pass %d: block c1: c3->c1 removed", pass))
+ requireEventuallyRoute(t, client2, client1DZIP, true, wait, tick, fmt.Sprintf("pass %d: block c1: c2->c1 remains", pass))
+ requireEventuallyRoute(t, client2, client3DZIP, false, wait, tick, fmt.Sprintf("pass %d: block c1: c2->c3 remains absent", pass))
+ requireEventuallyRoute(t, client3, client2DZIP, false, wait, tick, fmt.Sprintf("pass %d: block c1: c3->c2 remains absent", pass))
+ requireEventuallyRoute(t, client4, client1DZIP, true, wait, tick, fmt.Sprintf("pass %d: block c1: c4->c1 remains", pass))
+
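+ // In short: the iptables rule drops only c1's inbound liveness traffic.
+ // c1 (active) stops hearing its peers and withdraws its routes; c3 (active)
+ // stops hearing c1 and withdraws its route to c1; c2 (passive) and c4
+ // (disabled) do not act on probe results, so their routes to c1 stay put.
+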
+ // Liveness packets on doublezero0, none on eth0/1
+ requireUDPLivenessOnDZ0(t, client1, client2DZIP, true, fmt.Sprintf("pass %d: block c1: c1 liveness packets -> c2 still on dz0 (tx is not blocked)", pass))
+ requireUDPLivenessOnDZ0(t, client1, client3DZIP, true, fmt.Sprintf("pass %d: block c1: c1 liveness packets -> c3 still on dz0 (tx is not blocked)", pass))
+ requireUDPLivenessOnDZ0(t, client3, client1DZIP, true, fmt.Sprintf("pass %d: block c1: c3 liveness packets -> c1 still on dz0 (only c1's inbound is dropped)", pass))
+ requireUDPLivenessOnDZ0(t, client2, client1DZIP, true, fmt.Sprintf("pass %d: block c1: c2 still shows liveness packets -> c1 on dz0", pass))
+ requireNoUDPLivenessOnEth01(t, client1, client2DZIP, fmt.Sprintf("pass %d: block c1: no c1 liveness packets on eth0/1 -> c2", pass))
+ requireNoUDPLivenessOnEth01(t, client1, client3DZIP, fmt.Sprintf("pass %d: block c1: no c1 liveness packets on eth0/1 -> c3", pass))
+ requireNoUDPLivenessOnEth01(t, client3, client1DZIP, fmt.Sprintf("pass %d: block c1: no c3 liveness packets on eth0/1 -> c1", pass))
+ requireNoUDPLivenessOnEth01(t, client2, client1DZIP, fmt.Sprintf("pass %d: block c1: no c2 liveness packets on eth0/1 -> c1", pass))
+
+ unblockUDPLiveness(t, client1)
+
+ // Routes restored
+ requireEventuallyRoute(t, client1, client2DZIP, true, wait, tick, fmt.Sprintf("pass %d: unblock c1: c1->c2 restored", pass))
+ requireEventuallyRoute(t, client1, client3DZIP, true, wait, tick, fmt.Sprintf("pass %d: unblock c1: c1->c3 restored", pass))
+ requireEventuallyRoute(t, client3, client1DZIP, true, wait, tick, fmt.Sprintf("pass %d: unblock c1: c3->c1 restored", pass))
+
+ // Liveness packets on dz0; none on eth0/1
+ requireUDPLivenessOnDZ0(t, client1, client2DZIP, true, fmt.Sprintf("pass %d: unblock c1: c1 liveness packets -> c2 on dz0", pass))
+ requireUDPLivenessOnDZ0(t, client1, client3DZIP, true, fmt.Sprintf("pass %d: unblock c1: c1 liveness packets -> c3 on dz0", pass))
+ requireUDPLivenessOnDZ0(t, client3, client1DZIP, true, fmt.Sprintf("pass %d: unblock c1: c3 liveness packets -> c1 on dz0", pass))
+ requireUDPLivenessOnDZ0(t, client2, client1DZIP, true, fmt.Sprintf("pass %d: unblock c1: c2 liveness packets -> c1 on dz0", pass))
+ requireNoUDPLivenessOnEth01(t, client1, client2DZIP, fmt.Sprintf("pass %d: unblock c1: none on eth0/1 -> c2", pass))
+ requireNoUDPLivenessOnEth01(t, client1, client3DZIP, fmt.Sprintf("pass %d: unblock c1: none on eth0/1 -> c3", pass))
+ }
+
+ doRouteLivenessCaseB := func(pass int) {
+ t.Helper()
+ log.Info("==> Route liveness Case B (block client2)", "pass", pass)
+ blockUDPLiveness(t, client2)
+
+ // Routes
+ requireEventuallyRoute(t, client1, client2DZIP, false, wait, tick, fmt.Sprintf("pass %d: block c2: c1->c2 removed", pass))
+ requireEventuallyRoute(t, client2, client1DZIP, true, wait, tick, fmt.Sprintf("pass %d: block c2: c2->c1 remains", pass))
+ requireEventuallyRoute(t, client2, client3DZIP, false, wait, tick, fmt.Sprintf("pass %d: block c2: c2->c3 remains absent", pass))
+ requireEventuallyRoute(t, client3, client2DZIP, false, wait, tick, fmt.Sprintf("pass %d: block c2: c3->c2 remains absent", pass))
+ requireEventuallyRoute(t, client1, client3DZIP, true, wait, tick, fmt.Sprintf("pass %d: block c2: c1->c3 remains", pass))
+ requireEventuallyRoute(t, client3, client1DZIP, true, wait, tick, fmt.Sprintf("pass %d: block c2: c3->c1 remains", pass))
+
+ // Liveness packets
+ requireUDPLivenessOnDZ0(t, client1, client2DZIP, true, fmt.Sprintf("pass %d: block c2: c1 liveness packets -> c2 on dz0 (route withdrawn)", pass))
+ requireUDPLivenessOnDZ0(t, client2, client1DZIP, true, fmt.Sprintf("pass %d: block c2: c2 still shows liveness packets -> c1 on dz0", pass))
+ requireNoUDPLivenessOnEth01(t, client1, client2DZIP, fmt.Sprintf("pass %d: block c2: no c1 liveness packets on eth0/1 -> c2", pass))
+ requireNoUDPLivenessOnEth01(t, client2, client1DZIP, fmt.Sprintf("pass %d: block c2: no c2 liveness packets on eth0/1 -> c1", pass))
+
+ unblockUDPLiveness(t, client2)
+
+ // Routes restored
+ requireEventuallyRoute(t, client1, client2DZIP, true, wait, tick, fmt.Sprintf("pass %d: unblock c2: c1->c2 restored", pass))
+
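+ // Case B in short: c2 runs liveness in passive mode, so the block only stops
+ // c2 from hearing c1's probes; c1 (active) withdraws its route to c2, while
+ // c2's own routes are unaffected. Unblocking lets c1 re-install the route.
+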
+ // Liveness packets on dz0; none on eth0/1
+ requireUDPLivenessOnDZ0(t, client1, client2DZIP, true, fmt.Sprintf("pass %d: unblock c2: c1 liveness packets -> c2 on dz0", pass))
+ requireUDPLivenessOnDZ0(t, client2, client1DZIP, true, fmt.Sprintf("pass %d: unblock c2: c2 liveness packets -> c1 on dz0", pass))
+ requireNoUDPLivenessOnEth01(t, client1, client2DZIP, fmt.Sprintf("pass %d: unblock c2: none on eth0/1 -> c2", pass))
+ }
+
+ doRouteLivenessCaseC := func(pass int) {
+ t.Helper()
+ log.Info("==> Route liveness Case C (block client3)", "pass", pass)
+ blockUDPLiveness(t, client3)
+
+ // Routes
+ requireEventuallyRoute(t, client1, client3DZIP, false, wait, tick, fmt.Sprintf("pass %d: block c3: c1->c3 removed", pass))
+ requireEventuallyRoute(t, client3, client1DZIP, false, wait, tick, fmt.Sprintf("pass %d: block c3: c3->c1 removed", pass))
+ requireEventuallyRoute(t, client1, client2DZIP, true, wait, tick, fmt.Sprintf("pass %d: block c3: c1->c2 remains", pass))
+ requireEventuallyRoute(t, client2, client1DZIP, true, wait, tick, fmt.Sprintf("pass %d: block c3: c2->c1 remains", pass))
+ requireEventuallyRoute(t, client2, client3DZIP, false, wait, tick, fmt.Sprintf("pass %d: block c3: c2->c3 remains absent", pass))
+ requireEventuallyRoute(t, client3, client2DZIP, false, wait, tick, fmt.Sprintf("pass %d: block c3: c3->c2 remains absent", pass))
+
+ // Liveness packets
+ requireUDPLivenessOnDZ0(t, client1, client3DZIP, true, fmt.Sprintf("pass %d: block c3: c1 liveness packets -> c3 on dz0", pass))
+ requireUDPLivenessOnDZ0(t, client3, client1DZIP, true, fmt.Sprintf("pass %d: block c3: c3 liveness packets -> c1 on dz0", pass))
+ requireUDPLivenessOnDZ0(t, client2, client1DZIP, true, fmt.Sprintf("pass %d: block c3: c2 still shows liveness packets -> c1 on dz0", pass))
+ requireNoUDPLivenessOnEth01(t, client1, client3DZIP, fmt.Sprintf("pass %d: block c3: no c1 liveness packets on eth0/1 -> c3", pass))
+ requireNoUDPLivenessOnEth01(t, client3, client1DZIP, fmt.Sprintf("pass %d: block c3: no c3 liveness packets on eth0/1 -> c1", pass))
+ requireNoUDPLivenessOnEth01(t, client2, client1DZIP, fmt.Sprintf("pass %d: block c3: no c2 liveness packets on eth0/1 -> c1", pass))
+
+ unblockUDPLiveness(t, client3)
+
+ // Routes restored
+ requireEventuallyRoute(t, client1, client3DZIP, true, wait, tick, fmt.Sprintf("pass %d: unblock c3: c1->c3 restored", pass))
+ requireEventuallyRoute(t, client3, client1DZIP, true, wait, tick, fmt.Sprintf("pass %d: unblock c3: c3->c1 restored", pass))
+
+ // Liveness packets on dz0; none on eth0/1
+ requireUDPLivenessOnDZ0(t, client1, client3DZIP, true, fmt.Sprintf("pass %d: unblock c3: c1 liveness packets -> c3 on dz0", pass))
+ requireUDPLivenessOnDZ0(t, client3, client1DZIP, true, fmt.Sprintf("pass %d: unblock c3: c3 liveness packets -> c1 on dz0", pass))
+ requireUDPLivenessOnDZ0(t, client2, client1DZIP, true, fmt.Sprintf("pass %d: unblock c3: c2 liveness packets -> c1 on dz0", pass))
+ requireNoUDPLivenessOnEth01(t, client1, client3DZIP, fmt.Sprintf("pass %d: unblock c3: none on eth0/1 -> c3", pass))
+ }
+
+ // Run the matrix, repeating Case A, to exercise a second iteration of the workflow.
+ doRouteLivenessBaseline()
+ doRouteLivenessCaseA(1)
+ doRouteLivenessCaseB(1)
+ doRouteLivenessCaseC(1)
+ doRouteLivenessCaseA(2)
+
+ log.Info("--> Route liveness block matrix (repeat) complete")

 // Disconnect client1.
 log.Info("==> Disconnecting client1 from IBRL")
@@ -242,6 +556,18 @@ func runMultiClientIBRLWorkflowTest(t *testing.T, log *slog.Logger, dn *devnet.D
 require.NoError(t, err)
 log.Info("--> Client2 disconnected from IBRL")

+ // Disconnect client3.
+ log.Info("==> Disconnecting client3 from IBRL")
+ _, err = client3.Exec(t.Context(), []string{"doublezero", "disconnect", "--client-ip", client3.CYOANetworkIP})
+ require.NoError(t, err)
+ log.Info("--> Client3 disconnected from IBRL")
+
+ // Disconnect client4.
+ log.Info("==> Disconnecting client4 from IBRL") + _, err = client4.Exec(t.Context(), []string{"doublezero", "disconnect", "--client-ip", client4.CYOANetworkIP}) + require.NoError(t, err) + log.Info("--> Client4 disconnected from IBRL") + // Wait for users to be deleted onchain. log.Info("==> Waiting for users to be deleted onchain") serviceabilityClient, err := dn.Ledger.GetServiceabilityClient() @@ -259,6 +585,10 @@ func runMultiClientIBRLWorkflowTest(t *testing.T, log *slog.Logger, dn *devnet.D require.NoError(t, err) err = client2.WaitForTunnelDisconnected(t.Context(), 60*time.Second) require.NoError(t, err) + err = client3.WaitForTunnelDisconnected(t.Context(), 60*time.Second) + require.NoError(t, err) + err = client4.WaitForTunnelDisconnected(t.Context(), 60*time.Second) + require.NoError(t, err) status, err = client1.GetTunnelStatus(t.Context()) require.NoError(t, err) require.Len(t, status, 1, status) @@ -267,6 +597,15 @@ func runMultiClientIBRLWorkflowTest(t *testing.T, log *slog.Logger, dn *devnet.D require.NoError(t, err) require.Len(t, status, 1, status) require.Nil(t, status[0].DoubleZeroIP, status) + status, err = client3.GetTunnelStatus(t.Context()) + require.NoError(t, err) + require.Len(t, status, 1, status) + require.Nil(t, status[0].DoubleZeroIP, status) + status, err = client4.GetTunnelStatus(t.Context()) + require.NoError(t, err) + require.Len(t, status, 1, status) + require.Nil(t, status[0].DoubleZeroIP, status) + require.Equal(t, devnet.ClientSessionStatusDisconnected, status[0].DoubleZeroStatus.SessionStatus) log.Info("--> Confirmed clients are disconnected and do not have a DZ IP allocated") } @@ -374,3 +713,67 @@ func runMultiClientIBRLWithAllocatedIPWorkflowTest(t *testing.T, log *slog.Logge require.Equal(t, devnet.ClientSessionStatusDisconnected, status[0].DoubleZeroStatus.SessionStatus) log.Info("--> Confirmed clients are disconnected and do not have a DZ IP allocated") } + +func blockUDPLiveness(t *testing.T, c *devnet.Client) { + t.Helper() + cmd := []string{"iptables", "-A", "INPUT", "-p", "udp", "--dport", strconv.Itoa(routeLivenessPort), "-j", "DROP"} + _, err := c.Exec(t.Context(), cmd) + require.NoError(t, err) +} + +func unblockUDPLiveness(t *testing.T, c *devnet.Client) { + t.Helper() + cmd := []string{"iptables", "-D", "INPUT", "-p", "udp", "--dport", strconv.Itoa(routeLivenessPort), "-j", "DROP"} + _, err := c.Exec(t.Context(), cmd) + require.NoError(t, err) +} + +func hasRoute(t *testing.T, from *devnet.Client, ip string) bool { + t.Helper() + out, err := from.Exec(t.Context(), []string{"ip", "r", "list", "dev", "doublezero0"}) + require.NoError(t, err) + return strings.Contains(string(out), ip) +} + +func requireEventuallyRoute(t *testing.T, from *devnet.Client, ip string, want bool, wait, tick time.Duration, msg string) { + t.Helper() + require.Eventually(t, func() bool { return hasRoute(t, from, ip) == want }, wait, tick, msg) +} + +func requireUDPLivenessOnDZ0(t *testing.T, c *devnet.Client, host string, want bool, msg string) { + t.Helper() + n, err := udpLivenessCaptureCount(t, c, []string{"doublezero0"}, host) + require.NoError(t, err) + require.Equal(t, want, n > 0, msg) +} + +func requireNoUDPLivenessOnEth01(t *testing.T, c *devnet.Client, host string, msg string) { + t.Helper() + n, err := udpLivenessCaptureCount(t, c, []string{"eth0", "eth1"}, host) + require.NoError(t, err) + require.Equal(t, 0, n, msg) +} + +func udpLivenessCaptureCount(t *testing.T, c *devnet.Client, ifaces []string, host string) (int, error) { + t.Helper() + var 
iargs []string + for _, i := range ifaces { + iargs = append(iargs, "-i", i) + } + cmd := fmt.Sprintf(`tshark %s -a duration:1 -Y "not gre && ip.addr==%s && udp.port==%d"`, strings.Join(iargs, " "), host, routeLivenessPort) + args := append([]string{"bash", "-lc"}, cmd) + out, err := c.Exec(t.Context(), args) + require.NoError(t, err) + // Expect a line like: "9 packets captured" + s := string(out) + for line := range strings.SplitSeq(s, "\n") { + if strings.Contains(line, " packets captured") { + idx := strings.LastIndex(line, " packets captured") + numStr := strings.TrimSpace(line[:idx]) + n, err := strconv.Atoi(numStr) + require.NoError(t, err) + return n, nil + } + } + return 0, errors.New("no capture count found in output") +} diff --git a/e2e/user_ban_test.go b/e2e/user_ban_test.go index 20b4c0f3b..8053d28b5 100644 --- a/e2e/user_ban_test.go +++ b/e2e/user_ban_test.go @@ -161,13 +161,13 @@ func TestE2E_UserBan(t *testing.T) { // Run IBRL workflow test. if !t.Run("user-ban-ibrl", func(t *testing.T) { - runUserBanIBRLWorkflowTest(t, log, client1, client2, client3, dn, device2, deviceCode1, deviceCode2) + runUserBanIBRLWorkflowTest(t, log, client1, client2, client3, dn, deviceCode1, deviceCode2) }) { t.Fail() } } -func runUserBanIBRLWorkflowTest(t *testing.T, log *slog.Logger, client1 *devnet.Client, client2 *devnet.Client, client3 *devnet.Client, dn *devnet.Devnet, device2 *devnet.Device, deviceCode1 string, deviceCode2 string) { +func runUserBanIBRLWorkflowTest(t *testing.T, log *slog.Logger, client1 *devnet.Client, client2 *devnet.Client, client3 *devnet.Client, dn *devnet.Devnet, deviceCode1 string, deviceCode2 string) { // Check that the clients are disconnected and do not have a DZ IP allocated. log.Info("==> Checking that the clients are disconnected and do not have a DZ IP allocated") status, err := client1.GetTunnelStatus(t.Context()) diff --git a/go.mod b/go.mod index 56eb6b9f9..03d9df2f6 100644 --- a/go.mod +++ b/go.mod @@ -32,7 +32,6 @@ require ( github.com/prometheus/client_model v0.6.2 github.com/prometheus/common v0.67.2 github.com/spf13/cobra v1.10.1 - github.com/spf13/pflag v1.0.10 github.com/stretchr/testify v1.11.1 github.com/testcontainers/testcontainers-go v0.40.0 github.com/testcontainers/testcontainers-go/modules/clickhouse v0.40.0 @@ -122,6 +121,7 @@ require ( github.com/shirou/gopsutil/v4 v4.25.6 // indirect github.com/shopspring/decimal v1.4.0 // indirect github.com/sirupsen/logrus v1.9.3 // indirect + github.com/spf13/pflag v1.0.10 // indirect github.com/streamingfast/logging v0.0.0-20230608130331-f22c91403091 // indirect github.com/tklauser/go-sysconf v0.3.15 // indirect github.com/tklauser/numcpus v0.10.0 // indirect diff --git a/tools/uping/Makefile b/tools/uping/Makefile deleted file mode 100644 index 581a1d001..000000000 --- a/tools/uping/Makefile +++ /dev/null @@ -1,25 +0,0 @@ -PREFIX:=github.com/malbeclabs/doublezero/tools/uping -BUILD:=`git rev-parse --short HEAD` -LDFLAGS=-ldflags "-X=$(PREFIX)/build.Build=$(BUILD)" - -.PHONY: test -test: - go test -race -v -coverprofile coverage.out ./... 
- -.PHONY: lint -lint: - golangci-lint run -c ../../.golangci.yaml - -.PHONY: build -build: - CGO_ENABLED=0 go build -v $(LDFLAGS) -o bin/uping-send cmd/uping-send/main.go - CGO_ENABLED=0 go build -v $(LDFLAGS) -o bin/uping-recv cmd/uping-recv/main.go - -FUZZTIME ?= 10s -.PHONY: fuzz -fuzz: - @for f in $$(go test ./pkg/uping -list=Fuzz | grep '^Fuzz'); do \ - echo "==> Fuzzing $$f"; \ - go test ./pkg/uping -run=^$$ -fuzz=$$f -fuzztime=$(FUZZTIME) || exit 1; \ - done - diff --git a/tools/uping/README.md b/tools/uping/README.md deleted file mode 100644 index d6656efd6..000000000 --- a/tools/uping/README.md +++ /dev/null @@ -1,21 +0,0 @@ -# uping - -Minimal Linux-only raw ICMP echo library and toolset for user-space liveness probing over doublezero interfaces, even when certain routes are not in the kernel routing table. - -## Components - -- **Listener**: Responds to ICMP echo requests on a specific interface and IPv4 address, providing consistent user-space replies for local or unroutable peers. -- **Sender**: Sends ICMP echo requests and measures round-trip times per interface, operating reliably even without kernel routing. Handles retries, timeouts, and context cancellation. - -## Example - -```bash -uping-recv --iface doublezero0 --ip 9.169.90.100 -uping-send --iface doublezero0 --src 9.169.90.100 --dst 9.169.90.110 -``` - -## Notes - -- IPv4 only -- Requires CAP_NET_RAW -- Socket egress/ingress is pinned to the selected interface diff --git a/tools/uping/cmd/uping-recv/main.go b/tools/uping/cmd/uping-recv/main.go deleted file mode 100644 index 3de1bf887..000000000 --- a/tools/uping/cmd/uping-recv/main.go +++ /dev/null @@ -1,89 +0,0 @@ -package main - -import ( - "context" - "fmt" - "log/slog" - "net" - "os" - "os/signal" - "syscall" - "time" - - "github.com/malbeclabs/doublezero/tools/uping/pkg/uping" - "github.com/spf13/pflag" -) - -func main() { - var ( - iface string - ipStr string - timeout time.Duration - verbose bool - ) - - pflag.StringVarP(&iface, "iface", "i", "", "interface to bind for RX/TX (required)") - pflag.StringVarP(&ipStr, "ip", "p", "", "IPv4 source address on that interface (required)") - pflag.DurationVarP(&timeout, "timeout", "t", 3*time.Second, "poll timeout") - pflag.BoolVarP(&verbose, "verbose", "v", false, "enable verbose logs") - pflag.Parse() - - fail := func(msg string, code int) { - fmt.Fprintf(os.Stderr, "error: %s\n", msg) - pflag.Usage() - os.Exit(code) - } - if iface == "" { - fail("missing --iface", 2) - } - if ipStr == "" { - fail("missing --ip", 2) - } - if timeout <= 0 { - fail("--timeout must be > 0", 2) - } - - ip := mustIPv4(ipStr) - - // Raw sockets + SO_BINDTODEVICE need caps; require if iface provided (always here). 
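-	// Without root this means CAP_NET_RAW plus CAP_NET_ADMIN; both can be granted
-	// to the binary, e.g. sudo setcap cap_net_raw,cap_net_admin+ep ./uping-recv.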
- if err := uping.RequirePrivileges(true); err != nil { - fmt.Fprintf(os.Stderr, "privileges check failed: %v\n", err) - os.Exit(1) - } - - level := slog.LevelInfo - if verbose { - level = slog.LevelDebug - } - log := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: level})) - - log.Info("uping-recv started", "iface", iface, "ip", ip.String(), "timeout", timeout) - - ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) - defer stop() - - ln, err := uping.NewListener(uping.ListenerConfig{ - Logger: log, - Interface: iface, - IP: ip, - Timeout: timeout, - }) - if err != nil { - fmt.Fprintf(os.Stderr, "failed to create listener: %v\n", err) - os.Exit(1) - } - - if err := ln.Listen(ctx); err != nil { - fmt.Fprintf(os.Stderr, "listen error: %v\n", err) - os.Exit(1) - } -} - -func mustIPv4(s string) net.IP { - ip := net.ParseIP(s).To4() - if ip == nil { - fmt.Fprintf(os.Stderr, "bad IPv4: %s\n", s) - os.Exit(2) - } - return ip -} diff --git a/tools/uping/cmd/uping-send/main.go b/tools/uping/cmd/uping-send/main.go deleted file mode 100644 index 9c0c2f2f4..000000000 --- a/tools/uping/cmd/uping-send/main.go +++ /dev/null @@ -1,115 +0,0 @@ -package main - -import ( - "context" - "fmt" - "log/slog" - "net" - "os" - "os/signal" - "syscall" - "time" - - "github.com/malbeclabs/doublezero/tools/uping/pkg/uping" - "github.com/spf13/pflag" -) - -func main() { - var ( - iface string - src string - dst string - count int - timeout time.Duration - verbose bool - ) - - pflag.StringVarP(&iface, "iface", "i", "", "bind sender to this interface (required)") - pflag.StringVarP(&src, "src", "s", "", "source IPv4 address (required)") - pflag.StringVarP(&dst, "dst", "d", "", "destination IPv4 address (required)") - pflag.IntVarP(&count, "count", "c", 4, "number of echo requests to send (>0)") - pflag.DurationVarP(&timeout, "timeout", "t", 5*time.Second, "per-echo timeout (e.g. 
800ms, 2s)") - pflag.BoolVarP(&verbose, "verbose", "v", false, "enable verbose logs") - pflag.Parse() - - if iface == "" { - fmt.Fprintln(os.Stderr, "error: --iface is required") - pflag.Usage() - os.Exit(2) - } - - if src == "" || dst == "" { - fmt.Fprintln(os.Stderr, "error: --src and --dst are required") - pflag.Usage() - os.Exit(2) - } - if count <= 0 { - fmt.Fprintln(os.Stderr, "error: --count must be > 0") - os.Exit(2) - } - if timeout <= 0 { - fmt.Fprintln(os.Stderr, "error: --timeout must be > 0") - os.Exit(2) - } - - srcIP := mustIPv4(src) - dstIP := mustIPv4(dst) - - if err := uping.RequirePrivileges(iface != ""); err != nil { - fmt.Fprintf(os.Stderr, "privileges check failed: %v\n", err) - os.Exit(1) - } - - level := slog.LevelInfo - if verbose { - level = slog.LevelDebug - } - log := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: level})) - - ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) - defer stop() - - sender, err := uping.NewSender(uping.SenderConfig{ - Logger: log, - Interface: iface, - Source: srcIP, - }) - if err != nil { - fmt.Fprintf(os.Stderr, "failed to create sender: %v\n", err) - os.Exit(1) - } - defer sender.Close() - - results, err := sender.Send(ctx, uping.SendConfig{ - Target: dstIP, - Count: count, - Timeout: timeout, - }) - if err != nil { - fmt.Fprintf(os.Stderr, "send error: %v\n", err) - os.Exit(1) - } - - allOK := true - for i, r := range results.Results { - seq := i + 1 - if r.Error != nil { - allOK = false - fmt.Printf("seq=%d error=%v\n", seq, r.Error) - continue - } - fmt.Printf("seq=%d rtt=%v\n", seq, r.RTT) - } - if !allOK { - os.Exit(1) - } -} - -func mustIPv4(s string) net.IP { - ip := net.ParseIP(s).To4() - if ip == nil { - fmt.Fprintf(os.Stderr, "bad IPv4: %s\n", s) - os.Exit(2) - } - return ip -} diff --git a/tools/uping/pkg/uping/fuzz_test.go b/tools/uping/pkg/uping/fuzz_test.go deleted file mode 100644 index 76f9c1361..000000000 --- a/tools/uping/pkg/uping/fuzz_test.go +++ /dev/null @@ -1,71 +0,0 @@ -//go:build linux - -package uping - -import ( - "encoding/binary" - "math/rand" - "testing" - - "golang.org/x/net/icmp" - "golang.org/x/net/ipv4" -) - -// Must not panic on arbitrary bytes (IPv4 or bare ICMP). -func FuzzUping_ValidateEchoReply_Malformed_NoPanic(f *testing.F) { - f.Add([]byte{}) - f.Add([]byte{0x45}) // minimal header byte - f.Add(make([]byte, 19)) // truncated IPv4 - f.Add([]byte{8, 0, 0, 0, 0, 0, 0}) // short ICMP - f.Fuzz(func(t *testing.T, pkt []byte) { - if len(pkt) > 1<<16 { - pkt = pkt[:1<<16] - } - _, _, _, _ = validateEchoReply(pkt, 0xBEEF, 1, 99) - }) -} - -// ICMP checksum property: set -> validates to zero; flip a byte -> non-zero. -func FuzzUping_ICMPChecksum_Roundtrip(f *testing.F) { - seed := fuzzEchoReply(0x1234, 7, 42, 8) - f.Add(seed) - f.Fuzz(func(t *testing.T, msg []byte) { - if len(msg) < 8 { - msg = append(msg, make([]byte, 8-len(msg))...) - } - if len(msg) > 2048 { - msg = msg[:2048] - } - binary.BigEndian.PutUint16(msg[2:], 0) - cs := icmpChecksum(msg) - binary.BigEndian.PutUint16(msg[2:], cs) - if icmpChecksum(msg) != 0 { - t.Fatalf("checksum not zero after set") - } - if len(msg) > 8 { - i := 8 + rand.Intn(len(msg)-8) - msg[i] ^= 0xFF - if icmpChecksum(msg) == 0 { - t.Fatalf("checksum still zero after flip") - } - } - }) -} - -// Bare helper: valid Echo Reply bytes. 
-func fuzzEchoReply(id, seq uint16, nonce uint64, extra int) []byte { - if extra < 0 { - extra = -extra - } - if extra > 1024 { - extra = 1024 - } - data := make([]byte, 8+extra) - binary.BigEndian.PutUint64(data[:8], nonce) - msg := &icmp.Message{ - Type: ipv4.ICMPTypeEchoReply, Code: 0, - Body: &icmp.Echo{ID: int(id), Seq: int(seq), Data: data}, - } - b, _ := msg.Marshal(nil) - return b -} diff --git a/tools/uping/pkg/uping/listener.go b/tools/uping/pkg/uping/listener.go deleted file mode 100644 index a21ccb601..000000000 --- a/tools/uping/pkg/uping/listener.go +++ /dev/null @@ -1,189 +0,0 @@ -//go:build linux - -package uping - -import ( - "context" - "fmt" - "log/slog" - "net" - "os" - "time" - - "golang.org/x/net/icmp" - "golang.org/x/net/ipv4" -) - -const defaultListenerTimeout = 1 * time.Second - -// ListenerConfig defines how the ICMP listener should bind and behave. -// Interface + IP pin the socket to a specific kernel network interface and address; Timeout bounds poll(). -type ListenerConfig struct { - Logger *slog.Logger - Interface string // required: Linux ifname (e.g. "eth0") - IP net.IP // required: IPv4 address on Interface - Timeout time.Duration // per-iteration poll timeout; 0 -> default -} - -func (cfg *ListenerConfig) Validate() error { - if cfg.Interface == "" { - return fmt.Errorf("interface is required") - } - if cfg.IP == nil || cfg.IP.To4() == nil { - return fmt.Errorf("IP must be an IPv4 address") - } - if cfg.Timeout == 0 { - cfg.Timeout = defaultListenerTimeout - } - if cfg.Timeout <= 0 { - return fmt.Errorf("timeout must be greater than 0") - } - return nil -} - -// Listener exposes a blocking receive/reply loop until ctx is done. -type Listener interface { - Listen(ctx context.Context) error -} - -type listener struct { - log *slog.Logger - cfg ListenerConfig - iface *net.Interface - ifIndex int - src4 net.IP // local IPv4 we will answer for (and source from) -} - -func NewListener(cfg ListenerConfig) (Listener, error) { - if err := cfg.Validate(); err != nil { - return nil, err - } - ifi, err := net.InterfaceByName(cfg.Interface) - if err != nil { - return nil, fmt.Errorf("lookup interface %q: %w", cfg.Interface, err) - } - return &listener{log: cfg.Logger, cfg: cfg, iface: ifi, ifIndex: ifi.Index, src4: cfg.IP.To4()}, nil -} - -func (l *listener) Listen(ctx context.Context) error { - // Instance tag helps spot duplicate listeners (pid/object address). - inst := fmt.Sprintf("%d/%p", os.Getpid(), l) - if l.log != nil { - l.log.Info("uping/recv: starting listener", "inst", inst, "iface", l.cfg.Interface, "src", l.src4) - } - - // Raw ICMPv4 via net.IPConn so we can pin to device and use control messages. - ipc, err := net.ListenIP("ip4:icmp", &net.IPAddr{IP: l.src4}) - if err != nil { - return fmt.Errorf("ListenIP: %w", err) - } - defer ipc.Close() - - // Wrap in ipv4.PacketConn so we can enable control messages (interface, dst). - ip4c := ipv4.NewPacketConn(ipc) - defer ip4c.Close() - if err := ip4c.SetControlMessage(ipv4.FlagInterface|ipv4.FlagDst, true); err != nil { - return fmt.Errorf("SetControlMessage: %w", err) - } - - // Pin the socket to the given interface for both RX and TX routing. - if err := bindToDevice(ipc, l.iface.Name); err != nil { - return fmt.Errorf("bind-to-device %q: %w", l.iface.Name, err) - } - - // Interrupt blocking reads immediately on ctx cancellation. 
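-	// A read deadline in the past makes any blocked ReadFrom return a timeout
-	// error, which the loop below converts into a clean return once ctx.Err()
-	// is non-nil.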
- go func() { - <-ctx.Done() - _ = ipc.SetReadDeadline(time.Now().Add(-time.Hour)) - }() - - buf := make([]byte, 65535) - - for { - // Use the smaller of ctx deadline or fallback timeout to bound reads. - if ms := pollTimeoutMs(ctx, l.cfg.Timeout); ms < 0 { - _ = ipc.SetReadDeadline(time.Time{}) - } else { - _ = ipc.SetReadDeadline(time.Now().Add(time.Duration(ms) * time.Millisecond)) - } - - n, cm, raddr, err := ip4c.ReadFrom(buf) - if ne, ok := err.(net.Error); ok && ne.Timeout() { - if ctx.Err() != nil { - return nil - } - continue - } - if err != nil { - if ctx.Err() != nil { - return nil - } - if l.log != nil { - l.log.Debug("uping/recv: read error", "err", err) - } - continue - } - - // Enforce ingress interface and destination. - if cm == nil || cm.IfIndex != l.ifIndex || !cm.Dst.Equal(l.src4) { - continue - } - - m, err := icmp.ParseMessage(1, buf[:n]) - if err != nil { - continue - } - if m.Type != ipv4.ICMPTypeEcho { - continue - } - echo, ok := m.Body.(*icmp.Echo) - if !ok || echo == nil { - continue - } - - // Build ICMP echo-reply (type 0), mirror id/seq/payload. - reply := &icmp.Message{ - Type: ipv4.ICMPTypeEchoReply, - Code: 0, - Body: &icmp.Echo{ID: echo.ID, Seq: echo.Seq, Data: echo.Data}, - } - wb, err := reply.Marshal(nil) - if err != nil { - continue - } - - // Send the reply; SO_BINDTODEVICE keeps egress on the bound interface. - dst := raddr.(*net.IPAddr) - if _, err := ip4c.WriteTo(wb, &ipv4.ControlMessage{IfIndex: l.ifIndex, Src: l.src4}, dst); err == nil { - if l.log != nil { - l.log.Info("uping/recv: replied", "inst", inst, "dst", dst.IP.String(), "id", echo.ID, "seq", echo.Seq, "iface", l.iface.Name, "src", l.src4) - } - } else if l.log != nil { - l.log.Debug("uping/recv: write failed", "err", err, "iface", l.iface.Name, "src", l.src4) - } - } -} - -// pollTimeoutMs returns a millisecond poll() timeout derived from ctx deadline -// or falls back to the provided duration. -1 means “infinite” for poll(). -func pollTimeoutMs(ctx context.Context, fallback time.Duration) int { - if dl, ok := ctx.Deadline(); ok { - rem := time.Until(dl) - if rem <= 0 { - return 0 - } - const max = int(^uint32(0) >> 1) - if rem > (1<<31-1)*time.Millisecond { - return max - } - return int(rem / time.Millisecond) - } - if fallback > 0 { - const max = int(^uint32(0) >> 1) - if fallback > (1<<31-1)*time.Millisecond { - return max - } - return int(fallback / time.Millisecond) - } - return -1 -} diff --git a/tools/uping/pkg/uping/listener_test.go b/tools/uping/pkg/uping/listener_test.go deleted file mode 100644 index 6d02bed58..000000000 --- a/tools/uping/pkg/uping/listener_test.go +++ /dev/null @@ -1,460 +0,0 @@ -//go:build linux - -package uping - -import ( - "context" - "crypto/rand" - "encoding/binary" - "net" - "testing" - "time" - - "github.com/stretchr/testify/require" - "golang.org/x/net/icmp" - "golang.org/x/net/ipv4" - "golang.org/x/sys/unix" -) - -// Ensures the listener pinned to loopback replies to echo requests and reports RTTs. 
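-// The probe side uses the real Sender, so this doubles as an end-to-end
-// request/reply smoke test over the loopback path.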
-func TestUping_Listener_Loopback_Responds(t *testing.T) { - t.Parallel() - requireRawSockets(t) - - l, err := NewListener(ListenerConfig{ - Interface: "lo", - IP: net.IPv4(127, 0, 0, 1), - Timeout: 200 * time.Millisecond, - }) - require.NoError(t, err) - - ctx, cancel := context.WithTimeout(t.Context(), 2*time.Second) - defer cancel() - go func() { _ = l.Listen(ctx) }() - time.Sleep(40 * time.Millisecond) - - s, err := NewSender(SenderConfig{Source: net.IPv4(127, 0, 0, 1), Interface: "lo"}) - require.NoError(t, err) - defer s.Close() - - res, err := s.Send(ctx, SendConfig{ - Target: net.IPv4(127, 0, 0, 1), - Count: 2, - Timeout: 600 * time.Millisecond, - }) - require.NoError(t, err) - require.Len(t, res.Results, 2) - for i, r := range res.Results { - require.NoErrorf(t, r.Error, "i=%d", i) - require.Greater(t, r.RTT, time.Duration(0)) - } -} - -// Verifies the listener exits promptly when the context is cancelled. -func TestUping_Listener_ContextCancel_Exits(t *testing.T) { - t.Parallel() - requireRawSockets(t) - - l, err := NewListener(ListenerConfig{ - Interface: "lo", - IP: net.IPv4(127, 0, 0, 1), - Timeout: 150 * time.Millisecond, - }) - require.NoError(t, err) - - ctx, cancel := context.WithCancel(t.Context()) - done := make(chan struct{}) - go func() { _ = l.Listen(ctx); close(done) }() - time.Sleep(30 * time.Millisecond) - cancel() - - select { - case <-done: - case <-time.After(500 * time.Millisecond): - t.Fatal("listener did not exit after cancel") - } -} - -// Confirms non-echo ICMP is ignored and that subsequent valid echo still gets a reply. -func TestUping_Listener_Ignores_NonEcho_Then_Replies(t *testing.T) { - t.Parallel() - requireRawSockets(t) - - l, err := NewListener(ListenerConfig{ - Interface: "lo", - IP: net.IPv4(127, 0, 0, 1), - Timeout: 200 * time.Millisecond, - }) - require.NoError(t, err) - - ctx, cancel := context.WithTimeout(t.Context(), time.Second) - defer cancel() - go func() { _ = l.Listen(ctx) }() - time.Sleep(40 * time.Millisecond) - - // Inject a non-echo ICMP (dest unreachable) using ipv4.PacketConn. - c, err := net.ListenIP("ip4:icmp", &net.IPAddr{IP: net.IPv4(127, 0, 0, 1)}) - require.NoError(t, err) - ip4c := ipv4.NewPacketConn(c) - defer func() { _ = ip4c.Close(); _ = c.Close() }() - _ = ip4c.SetTTL(64) - - nonEcho := &icmp.Message{Type: ipv4.ICMPTypeDestinationUnreachable, Code: 0, Body: &icmp.DstUnreach{}} - nb, err := nonEcho.Marshal(nil) - require.NoError(t, err) - _, err = ip4c.WriteTo(nb, &ipv4.ControlMessage{IfIndex: 1, Src: net.IPv4(127, 0, 0, 1)}, &net.IPAddr{IP: net.IPv4(127, 0, 0, 1)}) - require.NoError(t, err) - - // Now a real echo via our Sender should still get a reply. - s, err := NewSender(SenderConfig{Source: net.IPv4(127, 0, 0, 1), Interface: "lo"}) - require.NoError(t, err) - defer s.Close() - - res, err := s.Send(ctx, SendConfig{ - Target: net.IPv4(127, 0, 0, 1), - Count: 1, - Timeout: 600 * time.Millisecond, - }) - require.NoError(t, err) - require.Len(t, res.Results, 1) - require.NoError(t, res.Results[0].Error) -} - -// Validates config error paths for missing iface/IP and invalid timeout. 
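-// Covered cases: empty Interface, nil IP, a non-IPv4 (IPv6 loopback) IP, and
-// a negative Timeout; each must be rejected.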
-func TestUping_ListenerConfig_Validate_Errors(t *testing.T) { - t.Parallel() - - _, err := NewListener(ListenerConfig{IP: net.IPv4(127, 0, 0, 1), Timeout: time.Second}) - require.Error(t, err) - - _, err = NewListener(ListenerConfig{Interface: "lo", Timeout: time.Second}) - require.Error(t, err) - _, err = NewListener(ListenerConfig{Interface: "lo", IP: net.IPv6loopback, Timeout: time.Second}) - require.Error(t, err) - - cfg := ListenerConfig{Interface: "lo", IP: net.IPv4(127, 0, 0, 1), Timeout: -time.Second} - require.Error(t, cfg.Validate()) -} - -// Exercises large ICMP payloads and ensures the listener continues to reply. -func TestUping_Listener_LargePayload(t *testing.T) { - t.Parallel() - requireRawSockets(t) - - l, err := NewListener(ListenerConfig{ - Interface: "lo", - IP: net.IPv4(127, 0, 0, 1), - Timeout: 200 * time.Millisecond, - }) - require.NoError(t, err) - - ctx, cancel := context.WithTimeout(t.Context(), 3*time.Second) - defer cancel() - go func() { _ = l.Listen(ctx) }() - time.Sleep(40 * time.Millisecond) - - // Send a large echo request using ipv4.PacketConn to 127.0.0.1. - c, err := net.ListenIP("ip4:icmp", &net.IPAddr{IP: net.IPv4(127, 0, 0, 1)}) - require.NoError(t, err) - ip4c := ipv4.NewPacketConn(c) - defer func() { _ = ip4c.Close(); _ = c.Close() }() - _ = ip4c.SetTTL(64) - - payload := make([]byte, 4096) - _, _ = rand.Read(payload) - msg := &icmp.Message{ - Type: ipv4.ICMPTypeEcho, Code: 0, - Body: &icmp.Echo{ID: 0x4242, Seq: 0x0102, Data: payload}, - } - wb, err := msg.Marshal(nil) - require.NoError(t, err) - _, err = ip4c.WriteTo(wb, &ipv4.ControlMessage{IfIndex: 1, Src: net.IPv4(127, 0, 0, 1)}, &net.IPAddr{IP: net.IPv4(127, 0, 0, 1)}) - require.NoError(t, err) - - // Confirm we still get a reply using the Sender path. - s, err := NewSender(SenderConfig{Source: net.IPv4(127, 0, 0, 1), Interface: "lo"}) - require.NoError(t, err) - defer s.Close() - res, err := s.Send(ctx, SendConfig{Target: net.IPv4(127, 0, 0, 1), Count: 1, Timeout: 800 * time.Millisecond}) - require.NoError(t, err) - require.Len(t, res.Results, 1) - require.NoError(t, res.Results[0].Error) -} - -// Verifies truncated/invalid IPv4/ICMP inputs are ignored and normal operation resumes. -func TestUping_Listener_IgnoresTruncatedJunkAndKeepsWorking(t *testing.T) { - t.Parallel() - requireRawSockets(t) - - l, err := NewListener(ListenerConfig{ - Interface: "lo", - IP: net.IPv4(127, 0, 0, 1), - Timeout: 150 * time.Millisecond, - }) - require.NoError(t, err) - - ctx, cancel := context.WithTimeout(t.Context(), 2*time.Second) - defer cancel() - go func() { _ = l.Listen(ctx) }() - time.Sleep(40 * time.Millisecond) - - // For malformed frames, use raw unix socket (ipv4.PacketConn won’t craft broken IP). 
- fd, err := unix.Socket(unix.AF_INET, unix.SOCK_RAW, unix.IPPROTO_ICMP) - require.NoError(t, err) - defer unix.Close(fd) - - dst := &unix.SockaddrInet4{Addr: [4]byte{127, 0, 0, 1}} - - // Truncated IP header - require.NoError(t, unix.Sendto(fd, []byte{0x45, 0x00}, 0, dst)) - - // Non-ICMP protocol in IP header - ip := make([]byte, 20+8) - ip[0] = 0x45 - ip[9] = 6 - copy(ip[12:16], []byte{127, 0, 0, 1}) - copy(ip[16:20], []byte{127, 0, 0, 1}) - binary.BigEndian.PutUint16(ip[10:], icmpChecksum(ip[:20])) - require.NoError(t, unix.Sendto(fd, ip, 0, dst)) - - // Too-short ICMP payload - ip2 := make([]byte, 20+4) - ip2[0] = 0x45 - ip2[9] = 1 - copy(ip2[12:16], []byte{127, 0, 0, 1}) - copy(ip2[16:20], []byte{127, 0, 0, 1}) - binary.BigEndian.PutUint16(ip2[10:], icmpChecksum(ip2[:20])) - require.NoError(t, unix.Sendto(fd, ip2, 0, dst)) - - // Normal echo still works afterward. - s, err := NewSender(SenderConfig{Source: net.IPv4(127, 0, 0, 1), Interface: "lo"}) - require.NoError(t, err) - defer s.Close() - res, err := s.Send(ctx, SendConfig{Target: net.IPv4(127, 0, 0, 1), Count: 1, Timeout: 600 * time.Millisecond}) - require.NoError(t, err) - require.Len(t, res.Results, 1) - require.NoError(t, res.Results[0].Error) -} - -// Ensures echo requests with bad ICMP checksums are ignored; normal echo still works afterward. -func TestUping_Listener_Ignores_BadICMPChecksum_Then_Replies(t *testing.T) { - t.Parallel() - requireRawSockets(t) - - l, err := NewListener(ListenerConfig{ - Interface: "lo", - IP: net.IPv4(127, 0, 0, 1), - Timeout: 200 * time.Millisecond, - }) - require.NoError(t, err) - - ctx, cancel := context.WithTimeout(t.Context(), 2*time.Second) - defer cancel() - go func() { _ = l.Listen(ctx) }() - time.Sleep(40 * time.Millisecond) - - // Craft an echo with an intentionally bad checksum; inject via raw unix. - fd, err := unix.Socket(unix.AF_INET, unix.SOCK_RAW, unix.IPPROTO_ICMP) - require.NoError(t, err) - defer unix.Close(fd) - - payload := make([]byte, 64) - _, _ = rand.Read(payload) - req := make([]byte, 8+len(payload)) - req[0] = 8 - req[1] = 0 - binary.BigEndian.PutUint16(req[4:], 0xBEEF) - binary.BigEndian.PutUint16(req[6:], 0x0001) - copy(req[8:], payload) - sum := icmpChecksum(req) - sum ^= 0x00FF - binary.BigEndian.PutUint16(req[2:], sum) - - ip := make([]byte, 20+len(req)) - ip[0] = 0x45 - ip[9] = 1 - copy(ip[12:16], net.IPv4(127, 0, 0, 1).To4()) - copy(ip[16:20], net.IPv4(127, 0, 0, 1).To4()) - binary.BigEndian.PutUint16(ip[:20][10:], icmpChecksum(ip[:20])) - copy(ip[20:], req) - - require.NoError(t, unix.Sendto(fd, ip, 0, &unix.SockaddrInet4{Addr: [4]byte{127, 0, 0, 1}})) - - // Normal echo afterwards. - s, err := NewSender(SenderConfig{Source: net.IPv4(127, 0, 0, 1), Interface: "lo"}) - require.NoError(t, err) - defer s.Close() - res, err := s.Send(ctx, SendConfig{Target: net.IPv4(127, 0, 0, 1), Count: 1, Timeout: 800 * time.Millisecond}) - require.NoError(t, err) - require.Len(t, res.Results, 1) - require.NoError(t, res.Results[0].Error) -} - -// Validates pollTimeoutMs against deadline/fallback edge cases and infinite mode. 
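-// Expected contract: a ctx deadline takes precedence over the fallback, an
-// already-expired deadline yields 0, no deadline falls back to the provided
-// duration, and no deadline with a zero fallback yields -1 (poll forever).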
-func TestUping_Listener_pollTimeoutMs(t *testing.T) { - t.Parallel() - - { - ctx, cancel := context.WithTimeout(t.Context(), 50*time.Millisecond) - defer cancel() - ms := pollTimeoutMs(ctx, 500*time.Millisecond) - require.InDelta(t, 50, ms, 25) - } - - { - ctx := t.Context() - ms := pollTimeoutMs(ctx, 123*time.Millisecond) - require.InDelta(t, 123, ms, 10) - } - - { - ctx, cancel := context.WithTimeout(t.Context(), 1*time.Nanosecond) - time.Sleep(200 * time.Microsecond) - defer cancel() - ms := pollTimeoutMs(ctx, 5*time.Second) - require.Equal(t, 0, ms) - } - - { - ctx := t.Context() - ms := pollTimeoutMs(ctx, 0) - require.Equal(t, -1, ms) - } -} - -// Loopback listener; sender bound to a different (non-loopback) interface should NOT see replies. -func TestUping_Listener_RepliesStayOnLoopbackInterface(t *testing.T) { - t.Parallel() - requireRawSockets(t) // for listener (raw ICMP) - requirePingSocket(t) // for sender (ping datagram) - - // Start listener pinned to loopback. - l, err := NewListener(ListenerConfig{ - Interface: "lo", - IP: net.IPv4(127, 0, 0, 1), - Timeout: 150 * time.Millisecond, - }) - require.NoError(t, err) - - ctx, cancel := context.WithTimeout(t.Context(), 2*time.Second) - defer cancel() - go func() { _ = l.Listen(ctx) }() - time.Sleep(40 * time.Millisecond) - - // Find a non-loopback IPv4 + iface. - ip := pickNonLoopbackV4(t) - ifname := ifaceNameForIP(t, ip) - - // Sender is *not* on loopback; it should not receive the loopback reply. - sWAN, err := NewSender(SenderConfig{Source: ip, Interface: ifname}) - require.NoError(t, err) - defer sWAN.Close() - - res, err := sWAN.Send(ctx, SendConfig{ - Target: net.IPv4(127, 0, 0, 1), - Count: 1, - Timeout: 700 * time.Millisecond, - }) - // Either a transport-level error or a probe timeout is acceptable here. - if err == nil { - require.Len(t, res.Results, 1) - require.Error(t, res.Results[0].Error, "expected no reply across interfaces") - } -} - -// Verifies that the Listener replies to ICMP Echo Requests on the same non-loopback interface -// it’s bound to. 
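-// The echo request is injected via a raw socket pinned to that interface and
-// the reply is read on a second pinned raw socket; the PKTINFO control message
-// is used to assert the reply's ingress ifindex (or skip when unavailable).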
-func TestUping_Listener_RepliesStayOnNonLoopbackInterface_InjectRequest(t *testing.T) { - t.Parallel() - requireRawSockets(t) // RAW needed for listener/inject/receive - - src := pickNonLoopbackV4(t) - ifname := ifaceNameForIP(t, src) - ifi, err := net.InterfaceByName(ifname) - require.NoError(t, err) - - // Start the listener - l, err := NewListener(ListenerConfig{ - Interface: ifname, - IP: src, - Timeout: 150 * time.Millisecond, - }) - require.NoError(t, err) - - ctx, cancel := context.WithTimeout(context.Background(), 4*time.Second) - defer cancel() - - errCh := make(chan error, 1) - go func() { errCh <- l.Listen(ctx) }() - select { - case e := <-errCh: - require.NoErrorf(t, e, "listener exited immediately") - case <-time.After(100 * time.Millisecond): - } - - // Build Echo request with deterministic ID/Seq - const echoID = 0xBEEF - const seq = 0x0001 - payload := []byte{1, 2, 3, 4, 5, 6, 7, 8} - reqBytes, err := (&icmp.Message{ - Type: ipv4.ICMPTypeEcho, Code: 0, - Body: &icmp.Echo{ID: echoID, Seq: seq, Data: payload}, - }).Marshal(nil) - require.NoError(t, err) - - // Injector: RAW ip4:icmp, pinned to (src, ifname) - injIP, err := net.ListenIP("ip4:icmp", &net.IPAddr{IP: src}) - require.NoError(t, err) - defer injIP.Close() - require.NoError(t, bindToDevice(injIP, ifname)) - inj := ipv4.NewPacketConn(injIP) - defer inj.Close() - require.NoError(t, inj.SetControlMessage(ipv4.FlagInterface|ipv4.FlagDst, true)) - - // Receiver: separate RAW ip4:icmp, pinned to (src, ifname) - rcvIP, err := net.ListenIP("ip4:icmp", &net.IPAddr{IP: src}) - require.NoError(t, err) - defer rcvIP.Close() - require.NoError(t, bindToDevice(rcvIP, ifname)) - rcv := ipv4.NewPacketConn(rcvIP) - defer rcv.Close() - require.NoError(t, rcv.SetControlMessage(ipv4.FlagInterface|ipv4.FlagDst, true)) - - // Inject request as if it arrived on that iface - cm := &ipv4.ControlMessage{IfIndex: ifi.Index, Src: src} - _, err = inj.WriteTo(reqBytes, cm, &net.IPAddr{IP: src}) - require.NoError(t, err, "failed to inject echo request") - - // Wait for the Echo reply on RAW receiver - _ = rcvIP.SetReadDeadline(time.Now().Add(1500 * time.Millisecond)) - buf := make([]byte, 4096) - for { - n, cmin, _, err := rcv.ReadFrom(buf) - if ne, ok := err.(net.Error); ok && ne.Timeout() { - t.Fatalf("timeout waiting for echo reply on %s (%s)", ifname, src) - } - require.NoError(t, err) - - // rcv gets full IPv4 packet; strip header if present - p := buf[:n] - if len(p) >= 20 && p[0]>>4 == 4 { - ihl := int(p[0]&0x0F) * 4 - if ihl < 20 || len(p) < ihl+8 { - continue - } - p = p[ihl:] - } - - rm, perr := icmp.ParseMessage(1, p) - if perr != nil || rm.Type != ipv4.ICMPTypeEchoReply { - continue - } - if echo, ok := rm.Body.(*icmp.Echo); ok && echo != nil && echo.ID == echoID && echo.Seq == seq { - if cmin == nil || cmin.IfIndex == 0 { - t.Skip("kernel did not provide PKTINFO; cannot verify interface confinement") - } - require.Equal(t, ifi.Index, cmin.IfIndex, "reply arrived on wrong iface") - return // success - } - } -} diff --git a/tools/uping/pkg/uping/privileges.go b/tools/uping/pkg/uping/privileges.go deleted file mode 100644 index 3c0192b4b..000000000 --- a/tools/uping/pkg/uping/privileges.go +++ /dev/null @@ -1,69 +0,0 @@ -package uping - -import ( - "bufio" - "errors" - "fmt" - "os" - "strconv" - "strings" -) - -const ( - CAP_NET_ADMIN = 12 - CAP_NET_RAW = 13 -) - -// RequirePrivileges checks: root OR CAP_NET_RAW (and CAP_NET_ADMIN if binding to device). 
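-// Capability bits are read from the CapEff mask in /proc/self/status, so the
-// check needs no libcap dependency; this is inherently Linux-specific.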
-func RequirePrivileges(bindingToIface bool) error { - if os.Geteuid() == 0 { - return nil - } - rawOK, err := hasCap(CAP_NET_RAW) - if err != nil { - return err - } - if !rawOK { - return fmt.Errorf("requires CAP_NET_RAW (or root). grant with: sudo setcap cap_net_raw+ep /path/to/uping-send (and /path/to/uping-recv)") - } - if bindingToIface { - adminOK, err := hasCap(CAP_NET_ADMIN) - if err != nil { - return err - } - if !adminOK { - return fmt.Errorf("SO_BINDTODEVICE typically requires CAP_NET_ADMIN. grant with: sudo setcap cap_net_admin+ep /path/to/uping-send (and /path/to/uping-recv)") - } - } - return nil -} - -func hasCap(bit int) (bool, error) { - f, err := os.Open("/proc/self/status") - if err != nil { - return false, err - } - defer f.Close() - - var capEffStr string - sc := bufio.NewScanner(f) - for sc.Scan() { - line := sc.Text() - if strings.HasPrefix(line, "CapEff:") { - fields := strings.Fields(line) - if len(fields) >= 2 { - capEffStr = fields[1] - break - } - } - } - if capEffStr == "" { - return false, errors.New("CapEff not found in /proc/self/status") - } - - val, err := strconv.ParseUint(capEffStr, 16, 64) - if err != nil { - return false, err - } - return (val & (1 << uint(bit))) != 0, nil -} diff --git a/tools/uping/pkg/uping/sender.go b/tools/uping/pkg/uping/sender.go deleted file mode 100644 index c57887c68..000000000 --- a/tools/uping/pkg/uping/sender.go +++ /dev/null @@ -1,516 +0,0 @@ -//go:build linux - -package uping - -import ( - "context" - "crypto/rand" - "encoding/binary" - "errors" - "fmt" - "log/slog" - "net" - "os" - "sync" - "time" - - "syscall" - - "golang.org/x/net/icmp" - "golang.org/x/net/ipv4" - "golang.org/x/sys/unix" -) - -// Defaults for a short, responsive probing loop. -const ( - defaultSenderCount = 3 - defaultSenderTimeout = 3 * time.Second - maxPollSlice = 200 * time.Millisecond // cap per-Recv block to avoid overshooting deadlines -) - -// SenderConfig configures the raw-ICMP sender. -// Both Interface and Source are REQUIRED and must be IPv4-capable. -type SenderConfig struct { - Logger *slog.Logger // optional - Interface string // required: interface name; used to resolve ifindex for PKTINFO - Source net.IP // required: IPv4 source address used as Spec_dst - NewEchoIDFunc func() (uint16, error) // optional: function to generate a unique echo ID (default: random) -} - -// Validate enforces required fields and IPv4-ness. -func (cfg *SenderConfig) Validate() error { - if cfg.Interface == "" { - return fmt.Errorf("interface is required") - } - if cfg.Source == nil || cfg.Source.To4() == nil { - return fmt.Errorf("source must be a valid IPv4 address") - } - if cfg.NewEchoIDFunc == nil { - cfg.NewEchoIDFunc = randomEchoID - } - return nil -} - -// SendConfig describes a single multi-probe operation. 
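-// Zero values are defaulted by Validate (Count=3, Timeout=3s); negative values
-// are rejected.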
-type SendConfig struct { - Target net.IP // required: IPv4 destination - Count int // number of probes; defaulted if zero - Timeout time.Duration // per-probe absolute timeout; defaulted if zero -} - -func (cfg *SendConfig) Validate() error { - if cfg.Count == 0 { - cfg.Count = defaultSenderCount - } - if cfg.Count <= 0 { - return fmt.Errorf("count must be greater than 0") - } - if cfg.Timeout == 0 { - cfg.Timeout = defaultSenderTimeout - } - if cfg.Timeout <= 0 { - return fmt.Errorf("timeout must be greater than 0") - } - return nil -} - -func randomEchoID() (uint16, error) { - var idb [2]byte - if _, err := rand.Read(idb[:]); err != nil { - return 0, fmt.Errorf("rand echo id: %w", err) - } - pid := binary.BigEndian.Uint16(idb[:]) - return pid, nil -} - -// SendResults is a bag of per-probe results; Failed() indicates any error occurred. -type SendResults struct{ Results []SendResult } - -func (rs *SendResults) Failed() bool { - for _, r := range rs.Results { - if r.Error != nil { - return true - } - } - return false -} - -// SendResult records the RTT (on success) or the error (on failure) for a single probe. -type SendResult struct { - RTT time.Duration - Error error -} - -// Sender exposes the echo send/wait API. -type Sender interface { - Send(ctx context.Context, cfg SendConfig) (*SendResults, error) - Close() error -} - -// sender owns the socket and addressing state. -// A mutex serializes Send and Close to the single underlying conn. -type sender struct { - log *slog.Logger - cfg SenderConfig - sip net.IP // IPv4 source (validated) - ifIndex int // ifindex derived from Interface - pid uint16 // echo identifier (random per instance) - ip4c *ipv4.PacketConn // ipv4 wrapper over ICMP datagram (“ping”) socket - mu sync.Mutex -} - -// NewSender opens an ICMP socket bound to Source, pins to device, sets TTL, -// validates IPv4 source, and resolves the interface index. Fails fast on misconfig. -func NewSender(cfg SenderConfig) (Sender, error) { - if err := cfg.Validate(); err != nil { - return nil, err - } - - sip := cfg.Source.To4() // safe: Validate() ensures IPv4 - - // Resolve interface index; fail if not present. - ifi, err := net.InterfaceByName(cfg.Interface) - if err != nil { - return nil, fmt.Errorf("lookup interface %q: %w", cfg.Interface, err) - } - - // Unique random Echo ID per instance (enables kernel demux with ping sockets). - var idb [2]byte - if _, err := rand.Read(idb[:]); err != nil { - return nil, fmt.Errorf("rand echo id: %w", err) - } - pid := binary.BigEndian.Uint16(idb[:]) - - // Create an ICMP datagram (“ping”) socket and bind it to the source IP + Echo ID (Linux demux key). - pconn, err := listenICMPDatagram(sip, pid) - if err != nil { - return nil, err - } - - // Wrap so we can use control messages and TTL helpers. - ip4c := ipv4.NewPacketConn(pconn) - _ = ip4c.SetTTL(64) - _ = ip4c.SetControlMessage(ipv4.FlagInterface|ipv4.FlagDst, true) - - // Pin the socket to the given interface for both RX and TX routing. - if err := bindToDevice(pconn, ifi.Name); err != nil { - _ = ip4c.Close() - return nil, fmt.Errorf("bind-to-device %q: %w", ifi.Name, err) - } - - return &sender{ - log: cfg.Logger, - cfg: cfg, - sip: sip, - ifIndex: ifi.Index, - pid: pid, - ip4c: ip4c, - }, nil -} - -// Close closes the underlying socket. Concurrency-safe with Send via s.mu. 
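-// Note that Send holds s.mu for its full duration, so Close blocks until any
-// in-flight Send returns.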
-func (s *sender) Close() error { - s.mu.Lock() - defer s.mu.Unlock() - if s.ip4c != nil { - return s.ip4c.Close() - } - return nil -} - -// Send transmits Count echo requests and waits up to Timeout for each reply. -// It steers egress by iface and source using an ipv4.ControlMessage and validates -// echo replies by id/seq/nonce. -func (s *sender) Send(ctx context.Context, cfg SendConfig) (*SendResults, error) { - if err := cfg.Validate(); err != nil { - return nil, err - } - dip := cfg.Target.To4() - if dip == nil { - return nil, fmt.Errorf("invalid target IP: %s", cfg.Target) - } - - // Serialize access to the single socket and protect against Close. - s.mu.Lock() - defer s.mu.Unlock() - - results := &SendResults{Results: make([]SendResult, 0, cfg.Count)} - seq := uint16(1) - dst := &net.IPAddr{IP: dip} - - // Per-Send() reusable buffers to avoid hot-path allocations. - buf := make([]byte, 8192) // RX buffer (ICMP payload when using PacketConn) - - // 8-byte nonce in ICMP echo data, initialized from crypto/rand. - var n8 [8]byte - if _, err := rand.Read(n8[:]); err != nil { - return nil, fmt.Errorf("rand nonce: %w", err) - } - nonce := binary.BigEndian.Uint64(n8[:]) - - payload := make([]byte, 8) // reused; overwritten each probe - - for i := 0; i < cfg.Count; i++ { - select { - case <-ctx.Done(): - return results, ctx.Err() - default: - } - - // Prepare ICMP Echo with (id, seq, nonce). - nonce++ - binary.BigEndian.PutUint64(payload, nonce) - wb, err := (&icmp.Message{ - Type: ipv4.ICMPTypeEcho, Code: 0, - Body: &icmp.Echo{ID: int(s.pid), Seq: int(seq), Data: payload}, - }).Marshal(nil) - if err != nil { - results.Results = append(results.Results, SendResult{RTT: -1, Error: err}) - seq++ - continue - } - - // Per-packet steering: IfIndex + Src emulate IP_PKTINFO (Spec_dst + ifindex). - cm := &ipv4.ControlMessage{IfIndex: s.ifIndex, Src: s.sip} - - t0 := time.Now() - if _, err := s.ip4c.WriteTo(wb, cm, dst); err != nil { - // Try a reopen on common transient send failures. - if transientSendRetryable(err) { - if s.log != nil { - s.log.Info("uping/sender: reopen after send err", "i", i+1, "seq", seq, "err", err) - } - if e := s.reopen(); e == nil { - cm = &ipv4.ControlMessage{IfIndex: s.ifIndex, Src: s.sip} - _, err = s.ip4c.WriteTo(wb, cm, dst) - } - } - // One-shot retry on EPERM after a tiny backoff - // This can happen sometimes especially on loopback interfaces. - if err != nil && errors.Is(err, syscall.EPERM) { - time.Sleep(5 * time.Millisecond) - _, err = s.ip4c.WriteTo(wb, nil, dst) - } - if err != nil { - if s.log != nil { - s.log.Error("uping/sender: send", "i", i+1, "seq", seq, "err", err) - } - results.Results = append(results.Results, SendResult{RTT: -1, Error: err}) - seq++ - continue - } - } - - got := false - deadline := t0.Add(cfg.Timeout) - - // Poll for a reply until the absolute deadline. - for { - if ctx.Err() != nil { - return results, ctx.Err() - } - remain := time.Until(deadline) - if remain <= 0 { - break - } - if remain > maxPollSlice { - remain = maxPollSlice - } - _ = s.ip4c.SetReadDeadline(time.Now().Add(remain)) - - n, rcm, raddr, err := s.ip4c.ReadFrom(buf) - if ne, ok := err.(net.Error); ok && ne.Timeout() { - continue - } - if err != nil { - // If the socket became invalid/transient, try reopening and continue waiting. 
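-				// reopen() keeps the same Echo ID, so a reply already in flight should
-				// still be demuxed to the fresh socket; the per-probe deadline continues
-				// to apply unchanged.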
- if transientSocketErr(err) { - if s.log != nil { - s.log.Info("uping/sender: reopen after recv err", "i", i+1, "seq", seq, "err", err) - } - if e := s.reopen(); e == nil { - _ = s.ip4c.SetReadDeadline(time.Now().Add(time.Until(deadline))) - continue - } - } - if s.log != nil { - s.log.Error("uping/sender: recv", "i", i+1, "seq", seq, "err", err) - } - continue - } - - // Optionally filter by ingress ifindex when available. - if rcm != nil && rcm.IfIndex != 0 && rcm.IfIndex != s.ifIndex { - continue - } - - // Parse and validate an echo reply. buf[:n] is ICMP payload with PacketConn, - // or full IPv4 if the stack delivers that; validateEchoReply handles both. - rtt := time.Since(t0) - ok, src, itype, icode := validateEchoReply(buf[:n], s.pid, seq, nonce) - if ok { - if s.log != nil { - ip := src - if ip == nil || ip.Equal(net.IPv4zero) { - if ipaddr, _ := raddr.(*net.IPAddr); ipaddr != nil { - ip = ipaddr.IP - } else if ipaddr, _ := raddr.(*net.UDPAddr); ipaddr != nil { - ip = ipaddr.IP - } - } - s.log.Info("uping/sender: reply", "i", i+1, "seq", seq, "src", ip.String(), "rtt", rtt, "len", n) - } - results.Results = append(results.Results, SendResult{RTT: rtt, Error: nil}) - got = true - break - } - if s.log != nil { - ip := src - if ip == nil || ip.Equal(net.IPv4zero) { - if ipaddr, _ := raddr.(*net.IPAddr); ipaddr != nil { - ip = ipaddr.IP - } - } - s.log.Debug("uping/sender: ignored", "i", i+1, "seq", seq, "src", ip.String(), "icmp_type", itype, "icmp_code", icode) - } - } - - if !got { - err := fmt.Errorf("timeout waiting for seq=%d", seq) - if s.log != nil { - s.log.Warn("uping/sender: timeout", "i", i+1, "seq", seq, "err", err) - } - results.Results = append(results.Results, SendResult{RTT: -1, Error: err}) - } - seq++ - } - - return results, nil -} - -// bindToDevice applies SO_BINDTODEVICE to c’s socket so traffic stays on ifname. -func bindToDevice(c any, ifname string) error { - sc, ok := c.(syscall.Conn) - if !ok { - return fmt.Errorf("no raw fd") - } - var setErr error - raw, err := sc.SyscallConn() - if err != nil { - return err - } - if err := raw.Control(func(fd uintptr) { - if e := unix.SetsockoptString(int(fd), unix.SOL_SOCKET, unix.SO_BINDTODEVICE, ifname); e != nil { - setErr = e - } - }); err != nil { - return err - } - return setErr -} - -// validateEchoReply parses a packet or ICMP message, verifies checksum, -// and returns true only for Echo Reply (type=0, code=0) matching (id, seq, nonce). -// Accepts either a full IPv4 packet or a bare ICMP payload. -func validateEchoReply(pkt []byte, wantID, wantSeq uint16, wantNonce uint64) (bool, net.IP, int, int) { - // Full IPv4? - if len(pkt) >= 20 && pkt[0]>>4 == 4 { - ihl := int(pkt[0]&0x0F) * 4 - if ihl < 20 || len(pkt) < ihl+8 { - return false, net.IPv4zero, -1, -1 - } - if pkt[9] != 1 { // not ICMP - return false, net.IP(pkt[12:16]), int(pkt[9]), -1 - } - src := net.IP(pkt[12:16]) - return validateICMPEcho(pkt[ihl:], wantID, wantSeq, wantNonce, src) - } - // Otherwise treat as bare ICMP payload from PacketConn. - return validateICMPEcho(pkt, wantID, wantSeq, wantNonce, net.IPv4zero) -} - -// validateICMPEcho verifies checksum, parses with icmp.ParseMessage, and matches id/seq/nonce. -// src is surfaced unchanged (IPv4zero for bare ICMP). 
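-// The two returned ints are the raw ICMP type and code (for logging), or -1/-1
-// when the message is too short to carry them.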
-func validateICMPEcho(icmpb []byte, wantID, wantSeq uint16, wantNonce uint64, src net.IP) (bool, net.IP, int, int) { - if len(icmpb) < 8 { - return false, src, -1, -1 - } - // Raw for logging/return - itype := int(icmpb[0]) - icode := int(icmpb[1]) - - // Verify Internet checksum over ICMP message. - if icmpChecksum(icmpb) != 0 { - return false, src, itype, icode - } - - m, err := icmp.ParseMessage(1, icmpb) - if err != nil { - return false, src, itype, icode - } - - // Only accept Echo Reply (type=0, code=0). Use m.Type for the predicate. - if m.Type != ipv4.ICMPTypeEchoReply { - return false, src, itype, icode - } - echo, ok := m.Body.(*icmp.Echo) - if !ok || echo == nil { - return false, src, itype, icode - } - if len(echo.Data) < 8 { - return false, src, itype, icode - } - gotNonce := binary.BigEndian.Uint64(echo.Data[:8]) - if uint16(echo.ID) == wantID && uint16(echo.Seq) == wantSeq && gotNonce == wantNonce { - return true, src, itype, icode - } - return false, src, itype, icode -} - -// icmpChecksum computes the standard Internet checksum over the ICMP message. -func icmpChecksum(b []byte) uint16 { - var s uint32 - for i := 0; i+1 < len(b); i += 2 { - s += uint32(binary.BigEndian.Uint16(b[i:])) - } - if len(b)%2 == 1 { - s += uint32(b[len(b)-1]) << 8 - } - for s>>16 != 0 { - s = (s & 0xffff) + (s >> 16) - } - return ^uint16(s) -} - -// reopen replaces the socket with a fresh ICMP datagram socket and reapplies base options. -// Used after transient errors (device down, address not ready, etc.). -func (s *sender) reopen() error { - if s.ip4c != nil { - _ = s.ip4c.Close() - } - pconn, err := listenICMPDatagram(s.sip, s.pid) // keep same Echo ID for kernel demux - if err != nil { - return err - } - ip4c := ipv4.NewPacketConn(pconn) - _ = ip4c.SetTTL(64) - _ = ip4c.SetControlMessage(ipv4.FlagInterface|ipv4.FlagDst, true) - - // Re-pin to device. - if err := bindToDevice(pconn, s.cfg.Interface); err != nil { - _ = ip4c.Close() - return fmt.Errorf("bind-to-device %q: %w", s.cfg.Interface, err) - } - - // Re-resolve ifindex defensively. - s.refreshIfIndex() - - s.ip4c = ip4c - return nil -} - -// refreshIfIndex re-resolves the interface index on demand (e.g., after a socket reopen). -func (s *sender) refreshIfIndex() { - ifi, err := net.InterfaceByName(s.cfg.Interface) - if err == nil { - s.ifIndex = ifi.Index - } -} - -// transientSocketErr classifies socket errors that are often recoverable with a reopen. -func transientSocketErr(err error) bool { - // net errors often wrap unix errors; keep the common set. - return errors.Is(err, net.ErrClosed) || - errors.Is(err, unix.EBADF) || errors.Is(err, unix.ENETDOWN) || errors.Is(err, unix.ENODEV) || - errors.Is(err, unix.EADDRNOTAVAIL) || errors.Is(err, unix.ENOBUFS) || errors.Is(err, unix.ENETRESET) || - errors.Is(err, unix.ENOMEM) -} - -// transientSendRetryable classifies send errors that are often recoverable with a reopen -func transientSendRetryable(err error) bool { - return errors.Is(err, net.ErrClosed) || - errors.Is(err, unix.EBADF) || errors.Is(err, unix.ENODEV) || errors.Is(err, unix.ENETDOWN) -} - -// listenICMPDatagram creates an ICMP “ping” datagram socket bound to sip with sin_port=echoID. 
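-// On Linux ping sockets the bound "port" doubles as the ICMP Echo identifier,
-// which is the key the kernel uses to demux replies between sockets; see
-// TestUping_KernelDemux_DifferentIDs_NoCrossDelivery.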
-func listenICMPDatagram(sip net.IP, echoID uint16) (net.PacketConn, error) { - fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, unix.IPPROTO_ICMP) - if err != nil { - return nil, err - } - sa := &unix.SockaddrInet4{Port: int(echoID)} - copy(sa.Addr[:], sip.To4()) - if err := unix.Bind(fd, sa); err != nil { - _ = unix.Close(fd) - return nil, err - } - f := os.NewFile(uintptr(fd), "icmp-dgram") - pc, err := net.FilePacketConn(f) - if err != nil { - _ = f.Close() // close on error only - return nil, err - } - // pc now owns the fd; safe to close the *os.File wrapper without closing the fd. - _ = f.Close() - return pc, nil -} diff --git a/tools/uping/pkg/uping/sender_test.go b/tools/uping/pkg/uping/sender_test.go deleted file mode 100644 index 42a7d73aa..000000000 --- a/tools/uping/pkg/uping/sender_test.go +++ /dev/null @@ -1,554 +0,0 @@ -//go:build linux - -package uping - -import ( - "context" - "encoding/binary" - "errors" - "fmt" - "net" - "testing" - "time" - - "github.com/stretchr/testify/require" - "golang.org/x/sys/unix" -) - -// Verifies ICMP echo packet construction and checksum correctness. -func TestUping_Sender_ChecksumAndICMPEcho(t *testing.T) { - t.Parallel() - id, seq := uint16(0x1234), uint16(0x9abc) - p := icmpEcho(id, seq, []byte{1, 2, 3, 4, 5}) - require.Equal(t, 13, len(p)) - require.Equal(t, byte(8), p[0]) - got := binary.BigEndian.Uint16(p[2:4]) - binary.BigEndian.PutUint16(p[2:4], 0) - require.Equal(t, icmpChecksum(p), got) -} - -// Confirms that validateEchoReply correctly detects valid ICMP echo replies. -func TestUping_Sender_ValidateEchoReply(t *testing.T) { - t.Parallel() - src := net.IPv4(10, 1, 2, 3).To4() - dst := net.IPv4(10, 9, 9, 9).To4() - id, seq, nonce := uint16(0x42), uint16(7), uint64(0xdeadbeefcafebabe) - req := icmpEcho(id, seq, func() []byte { b := make([]byte, 8); binary.BigEndian.PutUint64(b, nonce); return b }()) - rep := make([]byte, 20+len(req)) - rep[0] = 0x45 - copy(rep[12:16], src) - copy(rep[16:20], dst) - rep[9] = 1 - binary.BigEndian.PutUint16(rep[10:], icmpChecksum(rep[:20])) - icmp := rep[20:] - copy(icmp, req) - icmp[0] = 0 - binary.BigEndian.PutUint16(icmp[2:], 0) - binary.BigEndian.PutUint16(icmp[2:], icmpChecksum(icmp)) - ok, gotSrc, it, ic := validateEchoReply(rep, id, seq, nonce) - require.True(t, ok) - require.True(t, gotSrc.Equal(src)) - require.Equal(t, 0, it) - require.Equal(t, 0, ic) - ok, _, _, _ = validateEchoReply(rep, id, seq, nonce+1) - require.False(t, ok) -} - -// Verifies that a basic ping to localhost succeeds using loopback interface. -func TestUping_Sender_Localhost_Success(t *testing.T) { - t.Parallel() - requirePingSocket(t) - - s, err := NewSender(SenderConfig{Source: net.IPv4(127, 0, 0, 1), Interface: "lo"}) - require.NoError(t, err) - defer s.Close() - ctx, cancel := context.WithTimeout(t.Context(), 2*time.Second) - defer cancel() - res, err := s.Send(ctx, SendConfig{Target: net.IPv4(127, 0, 0, 1), Count: 2, Timeout: 800 * time.Millisecond}) - require.NoError(t, err) - require.Len(t, res.Results, 2) - for i, r := range res.Results { - require.NoErrorf(t, r.Error, "i=%d", i) - require.Greaterf(t, r.RTT, time.Duration(0), "i=%d", i) - require.LessOrEqualf(t, r.RTT, time.Second, "i=%d", i) - } -} - -// Confirms packets are correctly steered through a specific interface (loopback). 
-func TestUping_Sender_Interface_Steer_Loopback(t *testing.T) { - t.Parallel() - requirePingSocket(t) - - s, err := NewSender(SenderConfig{Source: net.IPv4(127, 0, 0, 1), Interface: "lo"}) - require.NoError(t, err) - defer s.Close() - ctx, cancel := context.WithTimeout(t.Context(), 2*time.Second) - defer cancel() - res, err := s.Send(ctx, SendConfig{Target: net.IPv4(127, 0, 0, 1), Count: 1, Timeout: 800 * time.Millisecond}) - require.NoError(t, err) - require.Len(t, res.Results, 1) - require.NoError(t, res.Results[0].Error) - require.Greater(t, res.Results[0].RTT, time.Duration(0)) -} - -// Ensures timeout behavior when sending to a nonresponsive (blackhole) address. -func TestUping_Sender_Timeout_Blackhole(t *testing.T) { - t.Parallel() - requirePingSocket(t) - - ip := pickLocalV4(t) - ifname := ifaceNameForIP(t, ip) - - s, err := NewSender(SenderConfig{Source: ip, Interface: ifname}) - require.NoError(t, err) - defer s.Close() - ctx, cancel := context.WithTimeout(t.Context(), 900*time.Millisecond) - defer cancel() - res, err := s.Send(ctx, SendConfig{Target: net.IPv4(203, 0, 113, 123), Count: 1, Timeout: 600 * time.Millisecond}) - require.NoError(t, err) - require.Len(t, res.Results, 1) - require.Error(t, res.Results[0].Error) -} - -// Tests SendConfig validation logic for defaults and invalid parameters. -func TestUping_SendConfig_Validate_DefaultsAndErrors(t *testing.T) { - t.Parallel() - c := SendConfig{} - err := c.Validate() - require.NoError(t, err) - require.Equal(t, defaultSenderCount, c.Count) - require.Equal(t, defaultSenderTimeout, c.Timeout) - require.Error(t, (&SendConfig{Count: -1, Timeout: time.Second}).Validate()) - require.Error(t, (&SendConfig{Count: 1, Timeout: -time.Second}).Validate()) -} - -// Rejects invalid IPv6 sources when creating a sender. -func TestUping_Sender_NewSender_InvalidSource(t *testing.T) { - t.Parallel() - _, err := NewSender(SenderConfig{Source: net.IPv6loopback, Interface: "lo"}) - require.Error(t, err) -} - -// Rejects creation with nonexistent network interface. -func TestUping_Sender_NewSender_BadInterfaceName(t *testing.T) { - t.Parallel() - requirePingSocket(t) - - _, err := NewSender(SenderConfig{Source: net.IPv4(127, 0, 0, 1), Interface: "does-not-exist-xyz"}) - require.Error(t, err) -} - -// Ensures Send() exits cleanly if the context is canceled before sending. -func TestUping_Sender_ContextCanceledEarly(t *testing.T) { - t.Parallel() - requirePingSocket(t) - - s, err := NewSender(SenderConfig{Source: net.IPv4(127, 0, 0, 1), Interface: "lo"}) - require.NoError(t, err) - defer s.Close() - - ctx, cancel := context.WithCancel(t.Context()) - cancel() - res, err := s.Send(ctx, SendConfig{Target: net.IPv4(127, 0, 0, 1), Count: 3, Timeout: 200 * time.Millisecond}) - require.ErrorIs(t, err, context.Canceled) - require.NotNil(t, res) - require.Len(t, res.Results, 0) -} - -// Validates that malformed or non-ICMP packets are rejected by the parser. 
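-// Cases: an IPv4 header carrying a non-ICMP protocol (TCP), a two-byte
-// truncated packet, and a well-formed non-echo ICMP type 3 message.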
-func TestUping_ValidateEchoReply_Negatives(t *testing.T) { - t.Parallel() - id, seq, nonce := uint16(1), uint16(2), uint64(3) - - ip := make([]byte, 20+16) - ip[0] = 0x45 - ip[9] = 6 - copy(ip[12:16], net.IPv4(1, 2, 3, 4).To4()) - copy(ip[16:20], net.IPv4(5, 6, 7, 8).To4()) - binary.BigEndian.PutUint16(ip[10:], icmpChecksum(ip[:20])) - ok, _, _, _ := validateEchoReply(ip, id, seq, nonce) - require.False(t, ok) - - ok, _, _, _ = validateEchoReply([]byte{0x45, 0x00}, id, seq, nonce) - require.False(t, ok) - - icmp := make([]byte, 8) - icmp[0] = 3 - binary.BigEndian.PutUint16(icmp[2:], icmpChecksum(icmp)) - pkt := buildIPv4Packet(net.IPv4(9, 9, 9, 9), net.IPv4(1, 1, 1, 1), 1, icmp) - ok, _, it, _ := validateEchoReply(pkt, id, seq, nonce) - require.False(t, ok) - require.Equal(t, 3, it) -} - -// Confirms partial timeouts still return full count of results. -func TestUping_Sender_PartialTimeoutsStillReturnCount(t *testing.T) { - t.Parallel() - requirePingSocket(t) - - sLo, err := NewSender(SenderConfig{Source: net.IPv4(127, 0, 0, 1), Interface: "lo"}) - require.NoError(t, err) - defer sLo.Close() - - ctx, cancel := context.WithTimeout(t.Context(), 2*time.Second) - defer cancel() - - okRes, err := sLo.Send(ctx, SendConfig{Target: net.IPv4(127, 0, 0, 1), Count: 1, Timeout: 500 * time.Millisecond}) - require.NoError(t, err) - require.Len(t, okRes.Results, 1) - require.NoError(t, okRes.Results[0].Error) - - ip := pickLocalV4(t) - ifname := ifaceNameForIP(t, ip) - sWAN, err := NewSender(SenderConfig{Source: ip, Interface: ifname}) - require.NoError(t, err) - defer sWAN.Close() - - toRes, err := sWAN.Send(ctx, SendConfig{Target: net.IPv4(203, 0, 113, 123), Count: 1, Timeout: 400 * time.Millisecond}) - require.NoError(t, err) - require.Len(t, toRes.Results, 1) - require.Error(t, toRes.Results[0].Error) -} - -// Checks SendResults.Failed() correctly identifies failures. -func TestUping_SendResults_Failed(t *testing.T) { - t.Parallel() - - rs := &SendResults{Results: []SendResult{ - {RTT: 10 * time.Millisecond, Error: nil}, - {RTT: -1, Error: errors.New("timeout")}, - }} - require.True(t, rs.Failed()) - - rs2 := &SendResults{Results: []SendResult{ - {RTT: 1 * time.Millisecond, Error: nil}, - }} - require.False(t, rs2.Failed()) -} - -// Rejects ICMP echo requests (type 8) as valid replies. -func TestUping_ValidateEchoReply_RejectsEchoRequest(t *testing.T) { - t.Parallel() - - src := net.IPv4(10, 0, 0, 1).To4() - dst := net.IPv4(10, 0, 0, 2).To4() - id, seq, nonce := uint16(11), uint16(22), uint64(33) - - payload := make([]byte, 8) - binary.BigEndian.PutUint64(payload, nonce) - req := icmpEcho(id, seq, payload) - ip := make([]byte, 20+len(req)) - ip[0] = 0x45 - ip[9] = 1 - copy(ip[12:16], src) - copy(ip[16:20], dst) - binary.BigEndian.PutUint16(ip[10:], icmpChecksum(ip[:20])) - copy(ip[20:], req) - - ok, _, it, _ := validateEchoReply(ip, id, seq, nonce) - require.False(t, ok) - require.Equal(t, 8, it) -} - -// Ensures socket reopen on send failure works and resumes successfully (PacketConn path). -func TestUping_Sender_ReopenOnSend_ReconnectAndSend(t *testing.T) { - t.Parallel() - requirePingSocket(t) - - sIface, err := NewSender(SenderConfig{Source: net.IPv4(127, 0, 0, 1), Interface: "lo"}) - require.NoError(t, err) - defer sIface.Close() - - s := sIface.(*sender) - - // Force a closed connection, then explicit reopen, then send should work. 
- _ = s.ip4c.Close() - require.NoError(t, s.reopen()) - - ctx, cancel := context.WithTimeout(t.Context(), 2*time.Second) - defer cancel() - res, err := s.Send(ctx, SendConfig{Target: net.IPv4(127, 0, 0, 1), Count: 1, Timeout: 800 * time.Millisecond}) - require.NoError(t, err) - require.Len(t, res.Results, 1) - require.NoError(t, res.Results[0].Error) - require.Greater(t, res.Results[0].RTT, time.Duration(0)) -} - -// Verifies recv path handles blackholed targets cleanly without races or crashes. -func TestUping_Sender_RecvTimeout_Blackhole_NoCrash(t *testing.T) { - t.Parallel() - requirePingSocket(t) - - ip := pickLocalV4(t) - ifname := ifaceNameForIP(t, ip) - - sIface, err := NewSender(SenderConfig{Source: ip, Interface: ifname}) - require.NoError(t, err) - defer sIface.Close() - - // Long enough overall context to let one probe time out cleanly on recv path. - ctx, cancel := context.WithTimeout(t.Context(), 1200*time.Millisecond) - defer cancel() - - // Blackhole target; expect probe-level timeout result, not a crash or top-level error. - res, err := sIface.Send(ctx, SendConfig{ - Target: net.IPv4(203, 0, 113, 200), // TEST-NET-3 - Count: 1, - Timeout: 900 * time.Millisecond, - }) - require.NoError(t, err) - require.Len(t, res.Results, 1) - require.Error(t, res.Results[0].Error) -} - -// Verifies transientSocketErr correctly classifies recoverable errors. -func TestUping_Sender_TransientSocketErr(t *testing.T) { - t.Parallel() - cases := []struct { - err error - want bool - }{ - {unix.EBADF, true}, - {unix.ENETDOWN, true}, - {unix.ENODEV, true}, - {unix.EADDRNOTAVAIL, true}, - {unix.ENOBUFS, true}, - {unix.ENETRESET, true}, - {unix.ENOMEM, true}, - {unix.EPERM, false}, - {unix.EINVAL, false}, - {fmt.Errorf("wrap: %w", unix.EBADF), true}, - {fmt.Errorf("wrap: %w", unix.ENOBUFS), true}, - {nil, false}, - {unix.EAGAIN, false}, - {errors.New("other"), false}, - } - for i, tc := range cases { - got := transientSocketErr(tc.err) - if got != tc.want { - t.Fatalf("case %d: err=%v got=%v want=%v", i, tc.err, got, tc.want) - } - } -} - -// Verifies transientSendRetryable correctly classifies retryable errors. -func TestUping_Sender_TransientSendRetryable(t *testing.T) { - t.Parallel() - cases := []struct { - err error - want bool - }{ - {unix.EBADF, true}, - {unix.ENODEV, true}, - {unix.ENETDOWN, true}, - {fmt.Errorf("wrap: %w", unix.EBADF), true}, - {fmt.Errorf("wrap: %w", unix.ENETDOWN), true}, - {nil, false}, - {unix.ENOBUFS, false}, - {unix.EADDRNOTAVAIL, false}, - {unix.ENETRESET, false}, - {unix.ENOMEM, false}, - {unix.EAGAIN, false}, - } - for i, tc := range cases { - got := transientSendRetryable(tc.err) - if got != tc.want { - t.Fatalf("case %d: err=%v got=%v want=%v", i, tc.err, got, tc.want) - } - } -} - -// Verifies that replies arrive when bound to the correct interface (loopback). -func TestUping_Sender_InterfaceBinding_AcceptsOnBoundInterface(t *testing.T) { - t.Parallel() - requirePingSocket(t) - - s, err := NewSender(SenderConfig{Source: net.IPv4(127, 0, 0, 1), Interface: "lo"}) - require.NoError(t, err) - defer s.Close() - - ctx, cancel := context.WithTimeout(t.Context(), 2*time.Second) - defer cancel() - - res, err := s.Send(ctx, SendConfig{Target: net.IPv4(127, 0, 0, 1), Count: 1, Timeout: 800 * time.Millisecond}) - require.NoError(t, err) - require.Len(t, res.Results, 1) - require.NoError(t, res.Results[0].Error) -} - -// Verifies that replies do NOT arrive if we bind to a different interface. -// We bind to a non-loopback iface and attempt to ping 127.0.0.1. 
-// With SO_BINDTODEVICE, TX must use that iface and RX must arrive on it;
-// loopback replies won't be delivered to this socket. We accept either a send error
-// (route/source mismatch) or a clean per-probe timeout as success.
-func TestUping_Sender_InterfaceBinding_RejectsFromOtherInterface(t *testing.T) {
-	t.Parallel()
-	requirePingSocket(t)
-
-	ip := pickNonLoopbackV4(t)
-	ifname := ifaceNameForIP(t, ip)
-
-	s, err := NewSender(SenderConfig{Source: ip, Interface: ifname})
-	require.NoError(t, err)
-	defer s.Close()
-
-	ctx, cancel := context.WithTimeout(t.Context(), 1200*time.Millisecond)
-	defer cancel()
-
-	res, err := s.Send(ctx, SendConfig{
-		Target:  net.IPv4(127, 0, 0, 1),
-		Count:   1,
-		Timeout: 900 * time.Millisecond,
-	})
-	require.NoError(t, err)
-	require.Len(t, res.Results, 1)
-	require.Error(t, res.Results[0].Error, "expected timeout when bound to %s (%s)", ifname, ip)
-}
-
-// Verifies demux across different Echo IDs: a socket bound with EchoID=B does not receive replies
-// for EchoID=A. We send only from sA; sA must succeed, sB must time out or error. This avoids
-// kernel-version-dependent behavior for same-ID fanout.
-func TestUping_KernelDemux_DifferentIDs_NoCrossDelivery(t *testing.T) {
-	t.Parallel()
-	requirePingSocket(t)
-
-	src := net.IPv4(127, 0, 0, 1)
-
-	newEchoID_A := func() (uint16, error) { return 0xA1A1, nil }
-	newEchoID_B := func() (uint16, error) { return 0xB2B2, nil }
-
-	sA, err := NewSender(SenderConfig{Source: src, Interface: "lo", NewEchoIDFunc: newEchoID_A})
-	require.NoError(t, err)
-	defer sA.Close()
-
-	sB, err := NewSender(SenderConfig{Source: src, Interface: "lo", NewEchoIDFunc: newEchoID_B})
-	require.NoError(t, err)
-	defer sB.Close()
-
-	// Fire a single probe from sA only.
-	ctxA, cancelA := context.WithTimeout(t.Context(), 1500*time.Millisecond)
-	defer cancelA()
-	resA, errA := sA.Send(ctxA, SendConfig{Target: src, Count: 1, Timeout: 900 * time.Millisecond})
-	require.NoError(t, errA)
-	require.Len(t, resA.Results, 1)
-	require.NoError(t, resA.Results[0].Error, "sA should receive its own reply")
-
-	// Then "listen" with sB by issuing a send to an unroutable target so it blocks on recv.
-	// If the kernel had cross-delivered sA's reply to sB's socket (it shouldn't), sB would
-	// drain it here and succeed; we expect a timeout or error instead.
-	ctxB, cancelB := context.WithTimeout(t.Context(), 1200*time.Millisecond)
-	defer cancelB()
-	resB, errB := sB.Send(ctxB, SendConfig{Target: net.IPv4(203, 0, 113, 200), Count: 1, Timeout: 900 * time.Millisecond})
-
-	if errB != nil {
-		// A transport error is also fine (it confirms no unexpected success).
-		t.Logf("sB got top-level error (acceptable): %v", errB)
-		return
-	}
-	require.Len(t, resB.Results, 1)
-	require.Error(t, resB.Results[0].Error, "sB should NOT receive sA's reply for a different EchoID")
-}
-
-// buildIPv4Packet builds a minimal IPv4 frame with the given protocol and payload and a
-// computed header checksum.
-func buildIPv4Packet(src, dst net.IP, proto byte, payload []byte) []byte {
-	ip := make([]byte, 20+len(payload))
-	ip[0] = 0x45
-	ip[9] = proto
-	copy(ip[12:16], src.To4())
-	copy(ip[16:20], dst.To4())
-	binary.BigEndian.PutUint16(ip[10:], icmpChecksum(ip[:20]))
-	copy(ip[20:], payload)
-	return ip
-}
-
-// requireRawSockets ensures the environment can open a Linux raw ICMP socket.
-func requireRawSockets(t *testing.T) {
-	fd, err := unix.Socket(unix.AF_INET, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.IPPROTO_ICMP)
-	require.NoError(t, err)
-	_ = unix.Close(fd)
-}
-
-// requirePingSocket ensures the environment can open a Linux ICMP "ping" datagram socket.
-func requirePingSocket(t *testing.T) {
-	fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, unix.IPPROTO_ICMP)
-	require.NoError(t, err)
-	_ = unix.Close(fd)
-}
-
-// pickLocalV4 picks a non-loopback IPv4 address if available, else falls back to loopback.
-func pickLocalV4(t *testing.T) net.IP {
-	ifs, err := net.Interfaces()
-	require.NoError(t, err)
-	for _, ifi := range ifs {
-		if (ifi.Flags & net.FlagUp) == 0 {
-			continue
-		}
-		addrs, _ := ifi.Addrs()
-		for _, a := range addrs {
-			if ipn, ok := a.(*net.IPNet); ok && ipn.IP.To4() != nil && !ipn.IP.IsLoopback() {
-				return ipn.IP.To4()
-			}
-		}
-	}
-	return net.IPv4(127, 0, 0, 1)
-}
-
-// ifaceNameForIP finds the interface name that owns the given IPv4 address (exact match
-// preferred, falling back to subnet containment). Fails the test if not found.
-func ifaceNameForIP(t *testing.T, ip net.IP) string {
-	ifs, err := net.Interfaces()
-	require.NoError(t, err)
-	for _, ifi := range ifs {
-		addrs, _ := ifi.Addrs()
-		for _, a := range addrs {
-			if ipn, ok := a.(*net.IPNet); ok && ipn.IP.To4() != nil {
-				if ipn.IP.To4().Equal(ip.To4()) {
-					return ifi.Name
-				}
-			}
-		}
-	}
-	for _, ifi := range ifs {
-		addrs, _ := ifi.Addrs()
-		for _, a := range addrs {
-			if ipn, ok := a.(*net.IPNet); ok && ipn.IP.To4() != nil {
-				if ipn.Contains(ip) {
-					return ifi.Name
-				}
-			}
-		}
-	}
-	t.Fatalf("could not find interface name for ip %v", ip)
-	return ""
-}
-
-// icmpEcho builds an ICMP echo request (type 8) with the given id, seq, and payload,
-// computing the checksum over the full message.
-func icmpEcho(id, seq uint16, payload []byte) []byte {
-	h := make([]byte, 8+len(payload))
-	h[0] = 8
-	binary.BigEndian.PutUint16(h[4:], id)
-	binary.BigEndian.PutUint16(h[6:], seq)
-	copy(h[8:], payload)
-	binary.BigEndian.PutUint16(h[2:], icmpChecksum(h))
-	return h
-}
-
-// pickNonLoopbackV4 returns a non-loopback IPv4 address, failing the test if none is up.
-func pickNonLoopbackV4(t *testing.T) net.IP {
-	ifs, err := net.Interfaces()
-	require.NoError(t, err)
-	for _, ifi := range ifs {
-		if (ifi.Flags&net.FlagUp) == 0 || (ifi.Flags&net.FlagLoopback) != 0 {
-			continue
-		}
-		addrs, _ := ifi.Addrs()
-		for _, a := range addrs {
-			if ipn, ok := a.(*net.IPNet); ok {
-				ip := ipn.IP.To4()
-				if ip != nil && !ip.IsLoopback() {
-					return ip
-				}
-			}
-		}
-	}
-	t.Fatalf("could not find non-loopback IPv4 address")
-	return nil
-}
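Reviewer note on the removed file: the deleted tests call an icmpChecksum helper that was defined elsewhere in the removed package, so it does not appear in these hunks. For reference while reading the diff, a minimal RFC 1071 internet checksum of the shape those call sites imply (a sketch under that assumption, not the deleted source; the package clause is hypothetical) looks like:

package uping // hypothetical; name assumed from the tools/uping path

// icmpChecksum computes the RFC 1071 internet checksum over b: sum the
// 16-bit big-endian words with end-around carry, then return the one's
// complement. An odd trailing byte is padded with a zero low byte.
// Hypothetical reconstruction; the deleted implementation may differ.
func icmpChecksum(b []byte) uint16 {
	var sum uint32
	for i := 0; i+1 < len(b); i += 2 {
		sum += uint32(b[i])<<8 | uint32(b[i+1])
	}
	if len(b)%2 == 1 {
		sum += uint32(b[len(b)-1]) << 8
	}
	for sum>>16 != 0 {
		sum = (sum & 0xffff) + (sum >> 16)
	}
	return ^uint16(sum)
}

With this shape, icmpEcho above computes its checksum last, over the full message while the checksum field is still zero, which is the standard ICMP procedure.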
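The negative tests pin down the observable behavior of validateEchoReply: reject non-ICMP protocols, truncated frames, and non-echo-reply ICMP types while reporting the ICMP type, and match id, seq, and the 8-byte nonce at the start of the echo payload. A self-contained sketch consistent with those expectations follows; the deleted function returns four values, so checkEchoReply here is a hypothetical stand-in rather than its real signature:

package uping // hypothetical package name

import "encoding/binary"

// checkEchoReply is a hypothetical stand-in for the deleted validateEchoReply:
// it parses an IPv4 frame, requires protocol 1 (ICMP) and type 0 (echo reply),
// and matches id, seq, and the 8-byte nonce at the start of the echo payload.
func checkEchoReply(pkt []byte, id, seq uint16, nonce uint64) (ok bool, icmpType int) {
	if len(pkt) < 20 || pkt[0]>>4 != 4 {
		return false, -1 // not a complete IPv4 header
	}
	ihl := int(pkt[0]&0x0f) * 4
	if ihl < 20 || len(pkt) < ihl+8 || pkt[9] != 1 {
		return false, -1 // bad header length, short ICMP header, or non-ICMP proto
	}
	icmp := pkt[ihl:]
	icmpType = int(icmp[0])
	if icmpType != 0 || len(icmp) < 16 {
		return false, icmpType // not an echo reply, or too short for the nonce
	}
	if binary.BigEndian.Uint16(icmp[4:]) != id || binary.BigEndian.Uint16(icmp[6:]) != seq {
		return false, icmpType
	}
	return binary.BigEndian.Uint64(icmp[8:]) == nonce, icmpType
}

Under this sketch, the deleted tests' expectations hold: a TCP-proto frame and a 2-byte frame fail validation, a destination-unreachable frame fails while reporting type 3, and an echo request fails while reporting type 8.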
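Similarly, the two table-driven tests fully determine which errnos the deleted transientSocketErr and transientSendRetryable treated as recoverable, including through fmt.Errorf %w wrapping. One implementation shape consistent with those tables (a sketch, not necessarily the deleted code) is:

package uping // hypothetical package name

import (
	"errors"

	"golang.org/x/sys/unix"
)

// transientSocketErr reports whether err indicates the socket itself has become
// unusable and is worth reopening. Sentinel set taken from the deleted test table;
// hypothetical reconstruction, the deleted implementation may differ.
func transientSocketErr(err error) bool {
	for _, e := range []error{
		unix.EBADF, unix.ENETDOWN, unix.ENODEV, unix.EADDRNOTAVAIL,
		unix.ENOBUFS, unix.ENETRESET, unix.ENOMEM,
	} {
		if errors.Is(err, e) {
			return true
		}
	}
	return false
}

// transientSendRetryable reports whether a failed send is worth retrying after a
// reopen; per the deleted table it is a strict subset of the socket-error set.
func transientSendRetryable(err error) bool {
	return errors.Is(err, unix.EBADF) || errors.Is(err, unix.ENODEV) || errors.Is(err, unix.ENETDOWN)
}

errors.Is unwraps %w chains and compares each sentinel errno by equality, which is why the wrapped cases in the tables above still classify correctly while nil and unrelated errors do not.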