Skip to content

Commit a2644bc

Browse files
committed
nf2go: convert nftables rules to golang code
One of the biggest barriers to adopt the netlink format for nftables is the complexity of writing bytecode. This commits adds a tool that allows to take an nftables dump and generate the corresponding golang code and validating that the generated code produces the exact same output. Change-Id: I491b35e0d8062de33c67091dd4126d843b231838 Signed-off-by: Antonio Ojea <[email protected]>
1 parent 69f487d commit a2644bc

File tree

3 files changed

+826
-0
lines changed

3 files changed

+826
-0
lines changed

internal/nf2go/main.go

+251
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,251 @@
1+
package main
2+
3+
import (
4+
"fmt"
5+
"log"
6+
"os"
7+
"os/exec"
8+
"runtime"
9+
"strings"
10+
11+
"github.com/google/go-cmp/cmp"
12+
"github.com/google/nftables"
13+
"github.com/vishvananda/netns"
14+
)
15+
16+
func main() {
17+
args := os.Args[1:]
18+
if len(args) != 1 {
19+
log.Fatalf("need to specify the file to read the \"nft list ruleset\" dump")
20+
}
21+
22+
filename := args[0]
23+
24+
runtime.LockOSThread()
25+
defer runtime.UnlockOSThread()
26+
27+
// Create a new network namespace
28+
ns, err := netns.New()
29+
if err != nil {
30+
log.Fatalf("netns.New() failed: %v", err)
31+
}
32+
n, err := nftables.New(nftables.WithNetNSFd(int(ns)))
33+
if err != nil {
34+
log.Fatalf("nftables.New() failed: %v", err)
35+
}
36+
37+
scriptOutput, err := applyNFTRuleset(filename)
38+
if err != nil {
39+
log.Fatalf("Failed to apply nftables script: %v\noutput:%s", err, scriptOutput)
40+
}
41+
if len(scriptOutput) > 0 {
42+
log.Printf("nft output:\n%s", scriptOutput)
43+
}
44+
45+
// Create the output file
46+
f, err := os.Create("nftables_recreate.go")
47+
if err != nil {
48+
log.Fatal(err)
49+
}
50+
defer f.Close()
51+
52+
// Helper function to print to the file
53+
pf := func(format string, a ...interface{}) {
54+
_, err := fmt.Fprintf(f, format, a...)
55+
if err != nil {
56+
log.Fatal(err)
57+
}
58+
}
59+
60+
pf("// Code generated by nft2go. DO NOT EDIT.\n")
61+
pf("package main\n\n")
62+
pf("import (\n")
63+
pf("\t\"fmt\"\n")
64+
pf("\t\"log\"\n")
65+
pf("\t\"github.com/google/nftables\"\n")
66+
pf("\t\"github.com/google/nftables/expr\"\n")
67+
pf(")\n\n")
68+
pf("func main() {\n")
69+
pf("\tn, err:= nftables.New()\n")
70+
pf("\tif err!= nil {\n")
71+
pf("\t\tlog.Fatal(err)\n")
72+
pf("\t}\n\n")
73+
pf("\n")
74+
pf("\tvar expressions []expr.Any\n")
75+
pf("\tvar chain *nftables.Chain\n")
76+
77+
tables, err := n.ListTables()
78+
if err != nil {
79+
log.Fatalf("ListTables failed: %v", err)
80+
}
81+
82+
chains, err := n.ListChains()
83+
if err != nil {
84+
log.Fatal(err)
85+
}
86+
87+
for _, table := range tables {
88+
pf("\ttable:= n.AddTable(&nftables.Table{Family: %s,Name: \"%s\"})\n", TableFamilyString(table.Family), table.Name)
89+
for _, chain := range chains {
90+
if chain.Table.Name != table.Name {
91+
continue
92+
}
93+
94+
pf("\tchain = n.AddChain(&nftables.Chain{Name: \"%s\", Table: table, Type: %s, Hooknum: %s, Priority: %s})\n",
95+
chain.Name, ChainTypeString(chain.Type), ChainHookRef(chain.Hooknum), ChainPrioRef(chain.Priority))
96+
97+
rules, err := n.GetRules(table, chain)
98+
if err != nil {
99+
log.Fatal(err)
100+
}
101+
102+
for _, rule := range rules {
103+
pf("\texpressions = []expr.Any{\n")
104+
for _, exp := range rule.Exprs {
105+
pf("\t\t%#v,\n", exp)
106+
}
107+
pf("\t\t}\n")
108+
pf("\tn.AddRule(&nftables.Rule{\n")
109+
pf("\t\tTable: table,\n")
110+
pf("\t\tChain: chain,\n")
111+
pf("\t\tExprs: expressions,\n")
112+
pf("\t})\n")
113+
}
114+
}
115+
116+
pf("\n\tif err:= n.Flush(); err!= nil {\n")
117+
pf("\t\tlog.Fatalf(\"fail to flush rules: %v\", err)\n")
118+
pf("\t}\n\n")
119+
pf("\tfmt.Println(\"nft ruleset applied.\")\n")
120+
pf("}\n")
121+
122+
// Program nftables using your Go code
123+
if err := flushNFTRuleset(); err != nil {
124+
log.Fatalf("Failed to flush nftables ruleset: %v", err)
125+
}
126+
127+
// Format the generated code
128+
cmd := exec.Command("gofmt", "-w", "-s", "nftables_recreate.go")
129+
output, err := cmd.CombinedOutput()
130+
if err != nil {
131+
log.Fatalf("gofmt error: %v\nOutput: %s", err, output)
132+
}
133+
134+
// Run the generated code
135+
cmd = exec.Command("go", "run", "nftables_recreate.go")
136+
output, err = cmd.CombinedOutput()
137+
if err != nil {
138+
log.Fatalf("Execution error: %v\nOutput: %s", err, output)
139+
}
140+
141+
// Retrieve nftables state using nft
142+
actualOutput, err := listNFTRuleset()
143+
if err != nil {
144+
log.Fatalf("Failed to list nftables ruleset: %v\noutput:%s", err, actualOutput)
145+
}
146+
147+
log.Printf("Actual output:\n%s", actualOutput)
148+
149+
expectedOutput, err := os.ReadFile(filename)
150+
if err != nil {
151+
log.Fatalf("Failed to list nftables ruleset: %v\noutput:%s", err, actualOutput)
152+
}
153+
154+
if string(expectedOutput) != actualOutput {
155+
log.Fatalf("nftables ruleset mismatch:\n%s", cmp.Diff(expectedOutput, actualOutput))
156+
}
157+
158+
if err := flushNFTRuleset(); err != nil {
159+
log.Fatalf("Failed to flush nftables ruleset: %v", err)
160+
}
161+
}
162+
}
163+
164+
func applyNFTRuleset(scriptPath string) (string, error) {
165+
cmd := exec.Command("nft", "--debug=netlink", "-f", scriptPath)
166+
out, err := cmd.CombinedOutput()
167+
if err != nil {
168+
return string(out), err
169+
}
170+
return strings.TrimSpace(string(out)), nil
171+
}
172+
173+
func listNFTRuleset() (string, error) {
174+
cmd := exec.Command("nft", "list", "ruleset")
175+
out, err := cmd.CombinedOutput()
176+
if err != nil {
177+
return string(out), err
178+
}
179+
return strings.TrimSpace(string(out)), nil
180+
}
181+
182+
func flushNFTRuleset() error {
183+
cmd := exec.Command("nft", "flush", "ruleset")
184+
return cmd.Run()
185+
}
186+
187+
func ChainHookRef(hookNum *nftables.ChainHook) string {
188+
i := uint32(0)
189+
if hookNum != nil {
190+
i = uint32(*hookNum)
191+
}
192+
switch i {
193+
case 0:
194+
return "nftables.ChainHookPrerouting"
195+
case 1:
196+
return "nftables.ChainHookInput"
197+
case 2:
198+
return "nftables.ChainHookForward"
199+
case 3:
200+
return "nftables.ChainHookOutput"
201+
case 4:
202+
return "nftables.ChainHookPostrouting"
203+
case 5:
204+
return "nftables.ChainHookIngress"
205+
case 6:
206+
return "nftables.ChainHookEgress"
207+
}
208+
return ""
209+
}
210+
211+
func ChainPrioRef(priority *nftables.ChainPriority) string {
212+
i := int32(0)
213+
if priority != nil {
214+
i = int32(*priority)
215+
}
216+
return fmt.Sprintf("nftables.ChainPriorityRef(%d)", i)
217+
}
218+
219+
func ChainTypeString(chaintype nftables.ChainType) string {
220+
switch chaintype {
221+
case nftables.ChainTypeFilter:
222+
return "nftables.ChainTypeFilter"
223+
case nftables.ChainTypeRoute:
224+
return "nftables.ChainTypeRoute"
225+
case nftables.ChainTypeNAT:
226+
return "nftables.ChainTypeNAT"
227+
default:
228+
return "nftables.ChainTypeFilter"
229+
}
230+
}
231+
232+
func TableFamilyString(family nftables.TableFamily) string {
233+
switch family {
234+
case nftables.TableFamilyUnspecified:
235+
return "nftables.TableFamilyUnspecified"
236+
case nftables.TableFamilyINet:
237+
return "nftables.TableFamilyINet"
238+
case nftables.TableFamilyIPv4:
239+
return "nftables.TableFamilyIPv4"
240+
case nftables.TableFamilyIPv6:
241+
return "nftables.TableFamilyIPv6"
242+
case nftables.TableFamilyARP:
243+
return "nftables.TableFamilyARP"
244+
case nftables.TableFamilyNetdev:
245+
return "nftables.TableFamilyNetdev"
246+
case nftables.TableFamilyBridge:
247+
return "nftables.TableFamilyBridge"
248+
default:
249+
return "nftables.TableFamilyIPv4"
250+
}
251+
}

0 commit comments

Comments
 (0)