Skip to content

Commit f43a3a9

Browse files
committed
User-specified connections
Signed-off-by: Ziv Nevo <[email protected]>
1 parent d64f5f1 commit f43a3a9

File tree

5 files changed

+343
-20
lines changed

5 files changed

+343
-20
lines changed

pkg/analyzer/connections.go

+246-10
Original file line numberDiff line numberDiff line change
@@ -7,21 +7,34 @@ SPDX-License-Identifier: Apache-2.0
77
package analyzer
88

99
import (
10+
"bufio"
1011
"fmt"
12+
"os"
13+
"slices"
14+
"strings"
15+
16+
core "k8s.io/api/core/v1"
17+
"k8s.io/apimachinery/pkg/util/intstr"
1118
)
1219

20+
type connectionExtractor struct {
21+
workloads []*Resource
22+
services []*Service
23+
logger Logger
24+
}
25+
1326
// This function is at the core of the topology analysis
1427
// For each resource, it finds other resources that may use it and compiles a list of connections holding these dependencies
15-
func discoverConnections(resources []*Resource, links []*Service, logger Logger) []*Connections {
28+
func (ce *connectionExtractor) discoverConnections() []*Connections {
1629
connections := []*Connections{}
17-
for _, destRes := range resources {
18-
deploymentServices := findServices(destRes, links)
19-
logger.Debugf("services matched to %v: %v", destRes.Resource.Name, deploymentServices)
30+
for _, destRes := range ce.workloads {
31+
deploymentServices := ce.findServices(destRes)
32+
ce.logger.Debugf("services matched to %v: %v", destRes.Resource.Name, deploymentServices)
2033
for _, svc := range deploymentServices {
21-
srcRes := findSource(resources, svc)
34+
srcRes := ce.findSource(svc)
2235
for _, r := range srcRes {
2336
if !r.equals(destRes) {
24-
logger.Debugf("source: %s target: %s link: %s", r.Resource.Name, destRes.Resource.Name, svc.Resource.Name)
37+
ce.logger.Debugf("source: %s target: %s link: %s", r.Resource.Name, destRes.Resource.Name, svc.Resource.Name)
2538
connections = append(connections, &Connections{Source: r, Target: destRes, Link: svc})
2639
}
2740
}
@@ -62,9 +75,9 @@ func areSelectorsContained(selectors1 map[string]string, selectors2 []string) bo
6275
}
6376

6477
// findServices returns a list of services that may be in front of a given workload resource
65-
func findServices(resource *Resource, links []*Service) []*Service {
78+
func (ce *connectionExtractor) findServices(resource *Resource) []*Service {
6679
var matchedSvc []*Service
67-
for _, link := range links {
80+
for _, link := range ce.services {
6881
if link.Resource.Namespace != resource.Resource.Namespace {
6982
continue
7083
}
@@ -79,9 +92,9 @@ func findServices(resource *Resource, links []*Service) []*Service {
7992
}
8093

8194
// findSource returns a list of resources that are likely trying to connect to the given service
82-
func findSource(resources []*Resource, service *Service) []*Resource {
95+
func (ce *connectionExtractor) findSource(service *Service) []*Resource {
8396
tRes := []*Resource{}
84-
for _, resource := range resources {
97+
for _, resource := range ce.workloads {
8598
serviceAddresses := getPossibleServiceAddresses(service, resource)
8699
foundSrc := *resource // We copy the resource so we can specify the ports used by the source found
87100
matched := false
@@ -133,3 +146,226 @@ func envValueMatchesService(envVal string, service *Service, serviceAddresses []
133146
}
134147
return false, SvcNetworkAttr{}
135148
}
149+
150+
const (
151+
srcDstDelim = "=>"
152+
endpointsPortDelim = ":"
153+
commentToken = "#"
154+
wildcardToken = "_"
155+
strongWildcardToken = "*"
156+
endpointParts = 3
157+
)
158+
159+
func (ce *connectionExtractor) connectionsFromFile(filename string) ([]*Connections, error) {
160+
file, err := os.Open(filename)
161+
if err != nil {
162+
return nil, err
163+
}
164+
defer file.Close()
165+
166+
conns := []*Connections{}
167+
168+
scanner := bufio.NewScanner(file)
169+
lineNum := 0
170+
for scanner.Scan() {
171+
line := strings.TrimSpace(scanner.Text())
172+
lineNum += 1
173+
if line == "" || strings.HasPrefix(line, commentToken) {
174+
continue
175+
}
176+
lineConns, err := ce.parseConnectionLine(line, lineNum)
177+
if err != nil {
178+
return nil, err
179+
}
180+
conns = slices.Concat(conns, lineConns)
181+
}
182+
183+
if err := scanner.Err(); err != nil {
184+
return nil, err
185+
}
186+
187+
return conns, nil
188+
}
189+
190+
func (ce *connectionExtractor) parseConnectionLine(line string, lineNum int) ([]*Connections, error) {
191+
// Take only the part before # starts a comment
192+
parts := strings.Split(line, commentToken)
193+
if len(parts) == 0 {
194+
return nil, syntaxError("unexpected comment", lineNum)
195+
}
196+
197+
line = parts[0]
198+
199+
parts = strings.Split(line, srcDstDelim)
200+
if len(parts) != 2 {
201+
return nil, syntaxError("connection line must have exactly one => separator", lineNum)
202+
}
203+
204+
src := strings.TrimSpace(parts[0])
205+
srcWorkloads, err := ce.parseEndpoints(src, lineNum)
206+
if err != nil {
207+
return nil, err
208+
}
209+
210+
parts = strings.Split(parts[1], endpointsPortDelim)
211+
if len(parts) == 0 {
212+
return nil, syntaxError("missing destination", lineNum)
213+
}
214+
if len(parts) > 2 {
215+
return nil, syntaxError("connection line must have at most one | separator", lineNum)
216+
}
217+
dst := strings.TrimSpace(parts[0])
218+
dstWorkloads, err := ce.parseEndpoints(dst, lineNum)
219+
if err != nil {
220+
return nil, err
221+
}
222+
223+
protAndPort := &SvcNetworkAttr{Protocol: core.ProtocolTCP}
224+
if len(parts) == 2 {
225+
protAndPort, err = parsePort(parts[1], lineNum)
226+
if err != nil {
227+
return nil, err
228+
}
229+
}
230+
231+
svc := Service{}
232+
svc.Resource.Network = []SvcNetworkAttr{*protAndPort}
233+
234+
conns := []*Connections{}
235+
for _, srcWl := range srcWorkloads {
236+
for _, dstWl := range dstWorkloads {
237+
if srcWl.equals(dstWl) {
238+
continue
239+
}
240+
conns = append(conns, &Connections{
241+
Source: srcWl,
242+
Target: dstWl,
243+
Link: &svc,
244+
})
245+
ce.logger.Infof("Added connection: src: %v, dst: %v, link: %v", srcWl.Resource.Name, dstWl.Resource.Name, svc)
246+
}
247+
}
248+
return conns, nil
249+
}
250+
251+
func (ce *connectionExtractor) parseEndpoints(endpoint string, lineNum int) ([]*Resource, error) {
252+
parts := strings.Split(endpoint, "/")
253+
if len(parts) != endpointParts {
254+
return nil, syntaxError("source and destination must be of the form namespace/kind/name", lineNum)
255+
}
256+
ns, kind, name := parts[0], parts[1], parts[2]
257+
kind = strings.ToUpper(kind[:1]) + kind[1:] // Capitalize kind's first letter
258+
259+
if ns == strongWildcardToken || kind == strongWildcardToken || name == strongWildcardToken {
260+
return ce.parseEndpointWithStrongWildcard(ns, kind, name)
261+
}
262+
263+
var res []*Resource
264+
switch kind {
265+
case service:
266+
res = ce.getWorkloadsBehindMatchingServices(ns, name)
267+
case wildcardToken:
268+
res = slices.Concat(ce.getWorkloadsBehindMatchingServices(ns, name), ce.getMatchingWorkloads(ns, kind, name))
269+
default:
270+
res = ce.getMatchingWorkloads(ns, kind, name)
271+
}
272+
if len(res) == 0 {
273+
return nil, fmt.Errorf("no matching endpoints for %s in the provided manifests", endpoint)
274+
}
275+
return res, nil
276+
}
277+
278+
func (ce *connectionExtractor) parseEndpointWithStrongWildcard(ns, kind, name string) ([]*Resource, error) {
279+
if kind != strongWildcardToken || name != strongWildcardToken {
280+
return nil, fmt.Errorf("bad endpoint pattern %s/%s/%s. Patterns with '*' should either equal '*/*/*' "+
281+
"or have the form '<namespace>/*/*'", ns, kind, name)
282+
}
283+
284+
return nil, fmt.Errorf("endpoints containing '*' are not yet supported")
285+
286+
/*res := Resource{}
287+
if ns != strongWildcardToken {
288+
if len(validation.IsDNS1123Subdomain(ns)) != 0 {
289+
return nil, fmt.Errorf("%s is not a proper namespace name", ns)
290+
}
291+
res.Resource.Namespace = ns
292+
}
293+
return []*Resource{&res}, nil*/
294+
}
295+
296+
func (ce *connectionExtractor) getWorkloadsBehindMatchingServices(ns, svcName string) []*Resource {
297+
workloads := []*Resource{}
298+
for _, svc := range ce.services {
299+
if strMatch(svc.Resource.Namespace, ns) && strMatch(svc.Resource.Name, svcName) {
300+
workloads = slices.Concat(workloads, ce.workloadsOfSvc(svc))
301+
}
302+
}
303+
return workloads
304+
}
305+
306+
func (ce *connectionExtractor) workloadsOfSvc(svc *Service) []*Resource {
307+
svcWorkloads := []*Resource{}
308+
for _, workload := range ce.workloads {
309+
if workload.Resource.Namespace == svc.Resource.Namespace &&
310+
areSelectorsContained(workload.Resource.Labels, svc.Resource.Selectors) {
311+
svcWorkloads = append(svcWorkloads, workload)
312+
}
313+
}
314+
return svcWorkloads
315+
}
316+
317+
func (ce *connectionExtractor) getMatchingWorkloads(ns, kind, name string) []*Resource {
318+
workloads := []*Resource{}
319+
for _, workload := range ce.workloads {
320+
if strMatch(workload.Resource.Namespace, ns) && strMatch(workload.Resource.Kind, kind) &&
321+
strMatch(workload.Resource.Name, name) {
322+
workloads = append(workloads, workload)
323+
}
324+
}
325+
return workloads
326+
}
327+
328+
func parsePort(spec string, lineNum int) (*SvcNetworkAttr, error) {
329+
protocol := core.ProtocolTCP
330+
var port *intstr.IntOrString
331+
332+
parts := strings.Fields(spec)
333+
switch len(parts) {
334+
case 0:
335+
case 2:
336+
parsedPort := intstr.Parse(parts[1])
337+
port = &parsedPort
338+
fallthrough
339+
case 1:
340+
var err error
341+
protocol, err = parseProtocol(parts[0], lineNum)
342+
if err != nil {
343+
return nil, err
344+
}
345+
default:
346+
return nil, syntaxError("port definition should have the form \"<protocol> [<port>]\"", lineNum)
347+
}
348+
349+
ret := &SvcNetworkAttr{Protocol: protocol}
350+
if port != nil {
351+
ret.TargetPort = *port
352+
}
353+
354+
return ret, nil
355+
}
356+
357+
func parseProtocol(protocol string, lineNum int) (core.Protocol, error) {
358+
protocols := []string{string(core.ProtocolTCP), string(core.ProtocolUDP), string(core.ProtocolSCTP)}
359+
if !slices.Contains(protocols, protocol) {
360+
return "", syntaxError("protocol must be one of TCP, UDP, SCTP", lineNum)
361+
}
362+
return core.Protocol(protocol), nil
363+
}
364+
365+
func strMatch(str, pattern string) bool {
366+
return pattern == wildcardToken || str == pattern
367+
}
368+
369+
func syntaxError(errorStr string, lineNum int) error {
370+
return fmt.Errorf("syntax error in line %d: %s", lineNum, errorStr)
371+
}

pkg/analyzer/connections_test.go

+49
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
package analyzer
2+
3+
import (
4+
"path/filepath"
5+
"testing"
6+
7+
"github.com/stretchr/testify/require"
8+
"k8s.io/apimachinery/pkg/labels"
9+
"k8s.io/apimachinery/pkg/selection"
10+
)
11+
12+
func TestSelector(t *testing.T) {
13+
testStr := "key1=val1, key2=val2"
14+
reqs, err := labels.ParseToRequirements(testStr)
15+
if err != nil {
16+
t.Fatalf("Conversion error: %v", err)
17+
}
18+
19+
res := map[string]string{}
20+
for _, req := range reqs {
21+
if req.Operator() != selection.Equals {
22+
t.Fatalf("Wrong operator: %s", req.Operator())
23+
}
24+
res[req.Key()] = req.Values().List()[0]
25+
}
26+
27+
t.Logf("labels: %v", res)
28+
}
29+
30+
func TestConnectionsFile(t *testing.T) {
31+
logger := NewDefaultLogger()
32+
sockshopDir := filepath.Join(getTestsDir(), "sockshop")
33+
manifestsDir := filepath.Join(sockshopDir, "manifests")
34+
mf := manifestFinder{logger, false, filepath.WalkDir}
35+
manifestFiles, fileErrors := mf.searchForManifestsInDirs([]string{manifestsDir})
36+
require.Empty(t, fileErrors)
37+
38+
resAcc := newResourceAccumulator(logger, false)
39+
parseErrors := resAcc.parseK8sYamls(manifestFiles)
40+
require.Empty(t, parseErrors)
41+
42+
ce := connectionExtractor{workloads: resAcc.workloads, services: resAcc.services, logger: logger}
43+
connections := ce.discoverConnections()
44+
require.NotEmpty(t, connections)
45+
connFilePath := filepath.Join(sockshopDir, "connections.txt")
46+
fileConns, err := ce.connectionsFromFile(connFilePath)
47+
require.Nil(t, err)
48+
require.Len(t, fileConns, 15)
49+
}

pkg/analyzer/policies_synthesizer.go

+13-1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ package analyzer
1313
import (
1414
"io/fs"
1515
"path/filepath"
16+
"slices"
1617

1718
networking "k8s.io/api/networking/v1"
1819
"k8s.io/apimachinery/pkg/util/intstr"
@@ -242,7 +243,18 @@ func (ps *PoliciesSynthesizer) extractConnections(resAcc *resourceAccumulator) (
242243
resAcc.exposeServices()
243244

244245
// Discover all connections between resources
245-
connections := discoverConnections(resAcc.workloads, resAcc.services, ps.logger)
246+
ce := connectionExtractor{workloads: resAcc.workloads, services: resAcc.services, logger: ps.logger}
247+
connections := ce.discoverConnections()
248+
249+
// If user specified a file with extra connections, add them too
250+
if ps.connectionsFile != "" {
251+
fileConns, err := ce.connectionsFromFile(ps.connectionsFile)
252+
if err != nil {
253+
fpErr := failedReadingFile(ps.connectionsFile, err)
254+
return nil, nil, appendAndLogNewError(fileErrors, fpErr, ps.logger)
255+
}
256+
connections = slices.Concat(connections, fileConns)
257+
}
246258
return resAcc.workloads, connections, fileErrors
247259
}
248260

0 commit comments

Comments
 (0)