Skip to content

Commit 5dae564

Browse files
committed
nf2go: convert nftables rules to golang code
One of the biggest barriers to adopt the netlink format for nftables is the complexity of writing bytecode. This commits adds a tool that allows to take an nftables dump and generate the corresponding golang code and validating that the generated code produces the exact same output. Change-Id: I491b35e0d8062de33c67091dd4126d843b231838 Signed-off-by: Antonio Ojea <[email protected]>
1 parent 69f487d commit 5dae564

File tree

3 files changed

+857
-0
lines changed

3 files changed

+857
-0
lines changed

internal/nf2go/main.go

Lines changed: 290 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,290 @@
1+
package main
2+
3+
import (
4+
"bytes"
5+
"fmt"
6+
"io"
7+
"io/ioutil"
8+
"log"
9+
"os"
10+
"os/exec"
11+
"path/filepath"
12+
"regexp"
13+
"runtime"
14+
"strings"
15+
16+
"github.com/google/go-cmp/cmp"
17+
"github.com/google/nftables"
18+
"github.com/vishvananda/netns"
19+
)
20+
21+
func main() {
22+
args := os.Args[1:]
23+
if len(args) != 1 {
24+
log.Fatalf("need to specify the file to read the \"nft list ruleset\" dump")
25+
}
26+
27+
filename := args[0]
28+
29+
runtime.LockOSThread()
30+
defer runtime.UnlockOSThread()
31+
32+
// Create a new network namespace
33+
ns, err := netns.New()
34+
if err != nil {
35+
log.Fatalf("netns.New() failed: %v", err)
36+
}
37+
n, err := nftables.New(nftables.WithNetNSFd(int(ns)))
38+
if err != nil {
39+
log.Fatalf("nftables.New() failed: %v", err)
40+
}
41+
42+
scriptOutput, err := applyNFTRuleset(filename)
43+
if err != nil {
44+
log.Fatalf("Failed to apply nftables script: %v\noutput:%s", err, scriptOutput)
45+
}
46+
47+
var buf bytes.Buffer
48+
// Helper function to print to the file
49+
pf := func(format string, a ...interface{}) {
50+
_, err := fmt.Fprintf(&buf, format, a...)
51+
if err != nil {
52+
log.Fatal(err)
53+
}
54+
}
55+
56+
pf("// Code generated by nft2go. DO NOT EDIT.\n")
57+
pf("package main\n\n")
58+
pf("import (\n")
59+
pf("\t\"fmt\"\n")
60+
pf("\t\"log\"\n")
61+
pf("\t\"github.com/google/nftables\"\n")
62+
pf("\t\"github.com/google/nftables/expr\"\n")
63+
pf(")\n\n")
64+
pf("func main() {\n")
65+
pf("\tn, err:= nftables.New()\n")
66+
pf("\tif err!= nil {\n")
67+
pf("\t\tlog.Fatal(err)\n")
68+
pf("\t}\n\n")
69+
pf("\n")
70+
pf("\tvar expressions []expr.Any\n")
71+
pf("\tvar chain *nftables.Chain\n")
72+
73+
tables, err := n.ListTables()
74+
if err != nil {
75+
log.Fatalf("ListTables failed: %v", err)
76+
}
77+
78+
chains, err := n.ListChains()
79+
if err != nil {
80+
log.Fatal(err)
81+
}
82+
83+
for _, table := range tables {
84+
pf("\ttable:= n.AddTable(&nftables.Table{Family: %s,Name: \"%s\"})\n", TableFamilyString(table.Family), table.Name)
85+
for _, chain := range chains {
86+
if chain.Table.Name != table.Name {
87+
continue
88+
}
89+
90+
sets, err := n.GetSets(table)
91+
if err != nil {
92+
log.Fatal(err)
93+
}
94+
for _, set := range sets {
95+
// TODO datatype and the other options
96+
pf("\tn.AddSet(&nftables.Set{\n")
97+
pf("\t\tTable: table,\n")
98+
pf("\t\tName: \"%s\",\n", set.Name)
99+
pf("\t}, nil)\n")
100+
}
101+
102+
pf("\tchain = n.AddChain(&nftables.Chain{Name: \"%s\", Table: table, Type: %s, Hooknum: %s, Priority: %s})\n",
103+
chain.Name, ChainTypeString(chain.Type), ChainHookRef(chain.Hooknum), ChainPrioRef(chain.Priority))
104+
105+
rules, err := n.GetRules(table, chain)
106+
if err != nil {
107+
log.Fatal(err)
108+
}
109+
110+
for _, rule := range rules {
111+
pf("\texpressions = []expr.Any{\n")
112+
for _, exp := range rule.Exprs {
113+
pf("\t\t%#v,\n", exp)
114+
}
115+
pf("\t\t}\n")
116+
pf("\tn.AddRule(&nftables.Rule{\n")
117+
pf("\t\tTable: table,\n")
118+
pf("\t\tChain: chain,\n")
119+
pf("\t\tExprs: expressions,\n")
120+
pf("\t})\n")
121+
}
122+
}
123+
124+
pf("\n\tif err:= n.Flush(); err!= nil {\n")
125+
pf("\t\tlog.Fatalf(\"fail to flush rules: %v\", err)\n")
126+
pf("\t}\n\n")
127+
pf("\tfmt.Println(\"nft ruleset applied.\")\n")
128+
pf("}\n")
129+
130+
// Program nftables using your Go code
131+
if err := flushNFTRuleset(); err != nil {
132+
log.Fatalf("Failed to flush nftables ruleset: %v", err)
133+
}
134+
135+
// Create the output file
136+
// Create a temporary directory
137+
tempDir, err := ioutil.TempDir("", "nftables_gen")
138+
if err != nil {
139+
log.Fatal(err)
140+
}
141+
defer os.RemoveAll(tempDir) // Clean up the temporary directory
142+
143+
// Create the temporary Go file
144+
tempGoFile := filepath.Join(tempDir, "nftables_recreate.go")
145+
f, err := os.Create(tempGoFile)
146+
if err != nil {
147+
log.Fatal(err)
148+
}
149+
defer f.Close()
150+
151+
fmt.Println("Generated code:")
152+
153+
mw := io.MultiWriter(f, os.Stdout)
154+
buf.WriteTo(mw)
155+
156+
// Format the generated code
157+
cmd := exec.Command("gofmt", "-w", "-s", tempGoFile)
158+
output, err := cmd.CombinedOutput()
159+
if err != nil {
160+
log.Fatalf("gofmt error: %v\nOutput: %s", err, output)
161+
}
162+
163+
// Run the generated code
164+
cmd = exec.Command("go", "run", tempGoFile)
165+
output, err = cmd.CombinedOutput()
166+
if err != nil {
167+
log.Fatalf("Execution error: %v\nOutput: %s", err, output)
168+
}
169+
170+
// Retrieve nftables state using nft
171+
actualOutput, err := listNFTRuleset()
172+
if err != nil {
173+
log.Fatalf("Failed to list nftables ruleset: %v\noutput:%s", err, actualOutput)
174+
}
175+
176+
expectedOutput, err := os.ReadFile(filename)
177+
if err != nil {
178+
log.Fatalf("Failed to list nftables ruleset: %v\noutput:%s", err, actualOutput)
179+
}
180+
181+
if !compareMultilineStringsIgnoreIndentation(string(expectedOutput), actualOutput) {
182+
log.Printf("Expected output:\n%s", string(expectedOutput))
183+
log.Printf("Actual output:\n%s", actualOutput)
184+
185+
log.Fatalf("nftables ruleset mismatch:\n%s", cmp.Diff(string(expectedOutput), actualOutput))
186+
}
187+
188+
if err := flushNFTRuleset(); err != nil {
189+
log.Fatalf("Failed to flush nftables ruleset: %v", err)
190+
}
191+
}
192+
}
193+
194+
func applyNFTRuleset(scriptPath string) (string, error) {
195+
cmd := exec.Command("nft", "--debug=netlink", "-f", scriptPath)
196+
out, err := cmd.CombinedOutput()
197+
if err != nil {
198+
return string(out), err
199+
}
200+
return strings.TrimSpace(string(out)), nil
201+
}
202+
203+
func listNFTRuleset() (string, error) {
204+
cmd := exec.Command("nft", "list", "ruleset")
205+
out, err := cmd.CombinedOutput()
206+
if err != nil {
207+
return string(out), err
208+
}
209+
return strings.TrimSpace(string(out)), nil
210+
}
211+
212+
func flushNFTRuleset() error {
213+
cmd := exec.Command("nft", "flush", "ruleset")
214+
return cmd.Run()
215+
}
216+
217+
func ChainHookRef(hookNum *nftables.ChainHook) string {
218+
i := uint32(0)
219+
if hookNum != nil {
220+
i = uint32(*hookNum)
221+
}
222+
switch i {
223+
case 0:
224+
return "nftables.ChainHookPrerouting"
225+
case 1:
226+
return "nftables.ChainHookInput"
227+
case 2:
228+
return "nftables.ChainHookForward"
229+
case 3:
230+
return "nftables.ChainHookOutput"
231+
case 4:
232+
return "nftables.ChainHookPostrouting"
233+
case 5:
234+
return "nftables.ChainHookIngress"
235+
case 6:
236+
return "nftables.ChainHookEgress"
237+
}
238+
return ""
239+
}
240+
241+
func ChainPrioRef(priority *nftables.ChainPriority) string {
242+
i := int32(0)
243+
if priority != nil {
244+
i = int32(*priority)
245+
}
246+
return fmt.Sprintf("nftables.ChainPriorityRef(%d)", i)
247+
}
248+
249+
func ChainTypeString(chaintype nftables.ChainType) string {
250+
switch chaintype {
251+
case nftables.ChainTypeFilter:
252+
return "nftables.ChainTypeFilter"
253+
case nftables.ChainTypeRoute:
254+
return "nftables.ChainTypeRoute"
255+
case nftables.ChainTypeNAT:
256+
return "nftables.ChainTypeNAT"
257+
default:
258+
return "nftables.ChainTypeFilter"
259+
}
260+
}
261+
262+
func TableFamilyString(family nftables.TableFamily) string {
263+
switch family {
264+
case nftables.TableFamilyUnspecified:
265+
return "nftables.TableFamilyUnspecified"
266+
case nftables.TableFamilyINet:
267+
return "nftables.TableFamilyINet"
268+
case nftables.TableFamilyIPv4:
269+
return "nftables.TableFamilyIPv4"
270+
case nftables.TableFamilyIPv6:
271+
return "nftables.TableFamilyIPv6"
272+
case nftables.TableFamilyARP:
273+
return "nftables.TableFamilyARP"
274+
case nftables.TableFamilyNetdev:
275+
return "nftables.TableFamilyNetdev"
276+
case nftables.TableFamilyBridge:
277+
return "nftables.TableFamilyBridge"
278+
default:
279+
return "nftables.TableFamilyIPv4"
280+
}
281+
}
282+
283+
func compareMultilineStringsIgnoreIndentation(str1, str2 string) bool {
284+
// Remove all indentation from both strings
285+
re := regexp.MustCompile(`(?m)^\s+`)
286+
str1 = re.ReplaceAllString(str1, "")
287+
str2 = re.ReplaceAllString(str2, "")
288+
289+
return str1 == str2
290+
}

0 commit comments

Comments
 (0)