Skip to content

Commit 600c3d0

Browse files
committed
init
0 parents  commit 600c3d0

File tree

5 files changed

+770
-0
lines changed

5 files changed

+770
-0
lines changed

.gitignore

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
/sockfwd
2+

cmd/root.go

+161
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
package cmd
2+
3+
import (
4+
"fmt"
5+
"io"
6+
"log"
7+
"net"
8+
"os"
9+
"os/signal"
10+
"runtime"
11+
"strings"
12+
"sync"
13+
"sync/atomic"
14+
"syscall"
15+
"time"
16+
17+
"github.com/sirupsen/logrus"
18+
"github.com/spf13/cobra"
19+
)
20+
21+
var rootCmd = &cobra.Command{
22+
Use: "sockfwd",
23+
Short: "Forward between sockets",
24+
RunE: runAction,
25+
}
26+
27+
func init() {
28+
rootCmd.Flags().BoolP("quiet", "q", false, "Quiet mode")
29+
rootCmd.Flags().StringP("source", "s", "", "Source address")
30+
rootCmd.Flags().StringP("destination", "d", "", "Destination address")
31+
}
32+
33+
func Execute() {
34+
if err := rootCmd.Execute(); err != nil {
35+
fmt.Println(err)
36+
os.Exit(1)
37+
}
38+
}
39+
40+
var (
41+
SIGINT = os.Interrupt
42+
SIGKILL = os.Kill
43+
//allows compilation under windows,
44+
//even though it cannot send USR signals
45+
SIGUSR1 = syscall.Signal(0xa)
46+
SIGUSR2 = syscall.Signal(0xc)
47+
SIGTERM = syscall.Signal(0xf)
48+
)
49+
50+
func listen(addr string) (net.Listener, error) {
51+
//listen
52+
network := "tcp"
53+
if strings.HasPrefix(addr, "unix:") {
54+
network = "unix"
55+
addr = strings.TrimPrefix(addr, "unix:")
56+
}
57+
return net.Listen(network, addr)
58+
}
59+
60+
func runAction(cmd *cobra.Command, args []string) error {
61+
source := cmd.Flag("source").Value.String()
62+
destination := cmd.Flag("destination").Value.String()
63+
quiet, err := cmd.Flags().GetBool("quiet")
64+
if err != nil {
65+
return err
66+
}
67+
l, err := listen(source)
68+
if err != nil {
69+
return err
70+
}
71+
//cleanup before shutdown
72+
go func() {
73+
c := make(chan os.Signal)
74+
signal.Notify(c)
75+
for sig := range c {
76+
switch sig {
77+
case SIGINT, SIGTERM, SIGKILL:
78+
l.Close()
79+
//os.Remove(config.SocketAddr)
80+
logrus.Info("closed listener and removed socket")
81+
os.Exit(0)
82+
case SIGUSR1:
83+
mem := runtime.MemStats{}
84+
runtime.ReadMemStats(&mem)
85+
logrus.Info("stats:\n"+
86+
" %s, uptime: %s\n"+
87+
" goroutines: %d, mem-alloc: %d\n"+
88+
" connections open: %d total: %d",
89+
runtime.Version(), time.Now().Sub(uptime),
90+
runtime.NumGoroutine(), mem.Alloc,
91+
atomic.LoadInt64(&current), atomic.LoadUint64(&total))
92+
case SIGUSR2:
93+
//toggle logging with USR2 signal
94+
// config.Quiet = !config.Quiet
95+
// logf("connection logging: %v", config.Quiet)
96+
}
97+
}
98+
}()
99+
//accept connections
100+
logrus.Info("listening on " + source + " and forwarding to " + destination)
101+
for {
102+
uconn, err := l.Accept()
103+
if err != nil {
104+
logrus.Info("accept failed: %s", err)
105+
continue
106+
}
107+
go fwd(uconn, destination, quiet)
108+
}
109+
}
110+
111+
//detailed statistics
112+
var uptime = time.Now()
113+
var total uint64
114+
var current int64
115+
116+
//pool of buffers (default to io.Copy buffer size)
117+
var pool = sync.Pool{
118+
New: func() interface{} {
119+
return make([]byte, 32*1024)
120+
},
121+
}
122+
123+
func dial(destination string) (net.Conn, error) {
124+
network := "tcp"
125+
if strings.HasPrefix(destination, "unix:") {
126+
network = "unix"
127+
destination = strings.TrimPrefix(destination, "unix:")
128+
}
129+
return net.Dial(network, destination)
130+
}
131+
132+
func fwd(uconn net.Conn, destination string, quiet bool) {
133+
tconn, err := dial(destination)
134+
if err != nil {
135+
log.Printf("tcp dial failed: %s", err)
136+
uconn.Close()
137+
return
138+
}
139+
//stats
140+
atomic.AddUint64(&total, 1)
141+
atomic.AddInt64(&current, 1)
142+
//optional log
143+
if !quiet {
144+
logrus.Info("connection #%d (%d open)", atomic.LoadUint64(&total), atomic.LoadInt64(&current))
145+
}
146+
//pipe!
147+
go func() {
148+
ubuff := pool.Get().([]byte)
149+
io.CopyBuffer(uconn, tconn, ubuff)
150+
pool.Put(ubuff)
151+
uconn.Close()
152+
//stats
153+
atomic.AddInt64(&current, -1)
154+
}()
155+
go func() {
156+
tbuff := pool.Get().([]byte)
157+
io.CopyBuffer(tconn, uconn, tbuff)
158+
pool.Put(tbuff)
159+
tconn.Close()
160+
}()
161+
}

go.mod

+18
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
module github.com/robberphex/sockfwd
2+
3+
go 1.17
4+
5+
require (
6+
github.com/jpillora/opts v1.2.0
7+
github.com/sirupsen/logrus v1.8.1
8+
github.com/spf13/cobra v1.2.1
9+
)
10+
11+
require (
12+
github.com/hashicorp/errwrap v1.0.0 // indirect
13+
github.com/hashicorp/go-multierror v1.0.0 // indirect
14+
github.com/inconshreveable/mousetrap v1.0.0 // indirect
15+
github.com/posener/complete v1.2.2-0.20190308074557-af07aa5181b3 // indirect
16+
github.com/spf13/pflag v1.0.5 // indirect
17+
golang.org/x/sys v0.0.0-20210510120138-977fb7262007 // indirect
18+
)

0 commit comments

Comments
 (0)