diff --git a/ext/BUILD.bazel b/ext/BUILD.bazel index ef4f4ec3..596b30b4 100644 --- a/ext/BUILD.bazel +++ b/ext/BUILD.bazel @@ -17,6 +17,7 @@ go_library( "lists.go", "math.go", "native.go", + "network.go", "protos.go", "regex.go", "sets.go", @@ -60,6 +61,7 @@ go_test( "lists_test.go", "math_test.go", "native_test.go", + "network_test.go", "protos_test.go", "regex_test.go", "sets_test.go", diff --git a/ext/network.go b/ext/network.go new file mode 100644 index 00000000..3465f07f --- /dev/null +++ b/ext/network.go @@ -0,0 +1,488 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ext + +import ( + "fmt" + "net/netip" + "reflect" + + "github.com/google/cel-go/cel" + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" +) + +// Network returns a cel.EnvOption to configure extended functions for network +// address parsing, inspection, and CIDR range manipulation. +// +// Note: This library defines global functions `ip`, `cidr`, `isIP`, `isCIDR` +// and `ip.isCanonical`. If you are currently using variables named `ip` or +// `cidr`, these functions will likely work as intended, however there is a +// chance for collision. +// +// The library closely mirrors the behavior of the Kubernetes CEL network +// libraries, treating IP addresses and CIDR ranges as opaque types. It parses +// IPs strictly: IPv4-mapped IPv6 addresses and IP zones are not allowed. +// +// This library includes a TypeAdapter that allows `netip.Addr` and +// `netip.Prefix` Go types to be passed directly into the CEL environment. +// +// # IP Addresses +// +// The `ip` function converts a string to an IP address (IPv4 or IPv6). If the +// string is not a valid IP, an error is returned. The `isIP` function checks +// if a string is a valid IP address without throwing an error. +// +// ip(string) -> ip +// isIP(string) -> bool +// +// Examples: +// +// ip('127.0.0.1') +// ip('::1') +// isIP('1.2.3.4') // true +// isIP('invalid') // false +// +// # CIDR Ranges +// +// The `cidr` function converts a string to a Classless Inter-Domain Routing +// (CIDR) range. If the string is not valid, an error is returned. The `isCIDR` +// function checks if a string is a valid CIDR notation. +// +// cidr(string) -> cidr +// isCIDR(string) -> bool +// +// Examples: +// +// cidr('192.168.0.0/24') +// cidr('::1/128') +// isCIDR('10.0.0.0/8') // true +// +// # IP Inspection and Canonicalization +// +// IP objects support various inspection methods. +// +// .family() -> int +// .isLoopback() -> bool +// .isGlobalUnicast() -> bool +// .isLinkLocalMulticast() -> bool +// .isLinkLocalUnicast() -> bool +// .isUnspecified() -> bool +// +// The `ip.isCanonical` function takes a string and returns true if it matches +// the RFC 5952 canonical string representation of that address. +// +// ip.isCanonical(string) -> bool +// +// Examples: +// +// ip('127.0.0.1').family() == 4 +// ip('::1').family() == 6 +// ip('127.0.0.1').isLoopback() == true +// ip.isCanonical('2001:db8::1') == true // RFC 5952 format +// ip.isCanonical('2001:DB8::1') == false // Uppercase is not canonical +// ip.isCanonical('2001:db8:0:0:0:0:0:1') == false // Expanded is not canonical +// +// # CIDR Member Functions +// +// CIDR objects support containment checks and property extraction. +// +// .containsIP(ip|string) -> bool +// .containsCIDR(cidr|string) -> bool +// .ip() -> ip +// .masked() -> cidr +// .prefixLength() -> int +// +// Examples: +// +// cidr('10.0.0.0/8').containsIP(ip('10.0.0.1')) == true +// cidr('10.0.0.0/8').containsIP('10.0.0.1') == true +// cidr('10.0.0.0/8').containsCIDR('10.1.0.0/16') == true +// cidr('192.168.1.5/24').ip() == ip('192.168.1.5') +// cidr('192.168.1.5/24').masked() == cidr('192.168.1.0/24') +// cidr('192.168.1.0/24').prefixLength() == 24 +func Network() cel.EnvOption { + return func(e *cel.Env) (*cel.Env, error) { + // Install the library (Types and Functions) + e, err := cel.Lib(&networkLib{})(e) + if err != nil { + return nil, err + } + + // Install the Adapter (Wrapping the existing one) + adapter := &networkAdapter{Adapter: e.CELTypeAdapter()} + return cel.CustomTypeAdapter(adapter)(e) + } +} + +const ( + // Function names + isIPFunc = "isIP" + isCIDRFunc = "isCIDR" + ipFunc = "ip" + cidrFunc = "cidr" + familyFunc = "family" + isCanonicalFunc = "ip.isCanonical" + isLoopbackFunc = "isLoopback" + isGlobalUnicastFunc = "isGlobalUnicast" + isUnspecifiedFunc = "isUnspecified" + isLinkLocalMcastFunc = "isLinkLocalMulticast" + isLinkLocalUcastFunc = "isLinkLocalUnicast" + containsIPFunc = "containsIP" + containsCIDRFunc = "containsCIDR" + maskedFunc = "masked" + prefixLengthFunc = "prefixLength" + ipFromCIDRFunc = "ip" +) + +var ( + // Definitions for the Opaque Types + networkIPType = types.NewOpaqueType("network.IP") + networkCIDRType = types.NewOpaqueType("network.CIDR") +) + +type networkLib struct{} + +func (*networkLib) LibraryName() string { + return "cel.lib.ext.network" +} + +func (*networkLib) CompileOptions() []cel.EnvOption { + return []cel.EnvOption{ + // 1. Register Types + cel.Types( + networkIPType, + networkCIDRType, + ), + + // 2. Register Functions + cel.Function(isIPFunc, + cel.Overload("isIP_string", []*cel.Type{cel.StringType}, cel.BoolType, + cel.UnaryBinding(netIsIP)), + ), + cel.Function(isCIDRFunc, + cel.Overload("isCIDR_string", []*cel.Type{cel.StringType}, cel.BoolType, + cel.UnaryBinding(netIsCIDR)), + ), + cel.Function(ipFunc, + cel.Overload("ip_string", []*cel.Type{cel.StringType}, networkIPType, + cel.UnaryBinding(netIPString)), + ), + cel.Function(cidrFunc, + cel.Overload("cidr_string", []*cel.Type{cel.StringType}, networkCIDRType, + cel.UnaryBinding(netCIDRString)), + ), + cel.Function(familyFunc, + cel.MemberOverload("ip_family", []*cel.Type{networkIPType}, cel.IntType, + cel.UnaryBinding(netIPFamily)), + ), + cel.Function(isCanonicalFunc, + cel.Overload("ip_isCanonical_string", []*cel.Type{cel.StringType}, cel.BoolType, + cel.UnaryBinding(netIPIsCanonical)), + ), + cel.Function(isLoopbackFunc, + cel.MemberOverload("ip_isLoopback", []*cel.Type{networkIPType}, cel.BoolType, + cel.UnaryBinding(netIPIsLoopback)), + ), + cel.Function(isGlobalUnicastFunc, + cel.MemberOverload("ip_isGlobalUnicast", []*cel.Type{networkIPType}, cel.BoolType, + cel.UnaryBinding(netIPIsGlobalUnicast)), + ), + cel.Function(isUnspecifiedFunc, + cel.MemberOverload("ip_isUnspecified", []*cel.Type{networkIPType}, cel.BoolType, + cel.UnaryBinding(netIPIsUnspecified)), + ), + cel.Function(isLinkLocalMcastFunc, + cel.MemberOverload("ip_isLinkLocalMulticast", []*cel.Type{networkIPType}, cel.BoolType, + cel.UnaryBinding(netIPIsLinkLocalMulticast)), + ), + cel.Function(isLinkLocalUcastFunc, + cel.MemberOverload("ip_isLinkLocalUnicast", []*cel.Type{networkIPType}, cel.BoolType, + cel.UnaryBinding(netIPIsLinkLocalUnicast)), + ), + cel.Function(containsIPFunc, + cel.MemberOverload("cidr_containsIP_ip", []*cel.Type{networkCIDRType, networkIPType}, cel.BoolType, + cel.BinaryBinding(netCIDRContainsIP)), + cel.MemberOverload("cidr_containsIP_string", []*cel.Type{networkCIDRType, cel.StringType}, cel.BoolType, + cel.BinaryBinding(netCIDRContainsIPString)), + ), + cel.Function(containsCIDRFunc, + cel.MemberOverload("cidr_containsCIDR_cidr", []*cel.Type{networkCIDRType, networkCIDRType}, cel.BoolType, + cel.BinaryBinding(netCIDRContainsCIDR)), + cel.MemberOverload("cidr_containsCIDR_string", []*cel.Type{networkCIDRType, cel.StringType}, cel.BoolType, + cel.BinaryBinding(netCIDRContainsCIDRString)), + ), + cel.Function(maskedFunc, + cel.MemberOverload("cidr_masked", []*cel.Type{networkCIDRType}, networkCIDRType, + cel.UnaryBinding(netCIDRMasked)), + ), + cel.Function(prefixLengthFunc, + cel.MemberOverload("cidr_prefixLength", []*cel.Type{networkCIDRType}, cel.IntType, + cel.UnaryBinding(netCIDRPrefixLength)), + ), + cel.Function(ipFromCIDRFunc, + cel.MemberOverload("cidr_ip", []*cel.Type{networkCIDRType}, networkIPType, + cel.UnaryBinding(netCIDRIP)), + ), + } +} + +func (*networkLib) ProgramOptions() []cel.ProgramOption { + return []cel.ProgramOption{} +} + +// networkAdapter adapts netip types while preserving existing adapters. +type networkAdapter struct { + types.Adapter +} + +func (a *networkAdapter) NativeToValue(value any) ref.Val { + switch v := value.(type) { + case netip.Addr: + return IP{Addr: v} + case netip.Prefix: + return CIDR{Prefix: v} + } + // Delegate to the wrapped adapter (e.g., Protobuf adapter) + return a.Adapter.NativeToValue(value) +} + +// --- Implementation Logic --- + +func parseIPAddr(raw string) (netip.Addr, error) { + addr, err := netip.ParseAddr(raw) + if err != nil { + return netip.Addr{}, fmt.Errorf("IP Address %q parse error: %v", raw, err) + } + if addr.Zone() != "" { + return netip.Addr{}, fmt.Errorf("IP address %q with zone value is not allowed", raw) + } + if addr.Is4In6() { + return netip.Addr{}, fmt.Errorf("IPv4-mapped IPv6 address %q is not allowed", raw) + } + return addr, nil +} + +func netIsIP(val ref.Val) ref.Val { + s := val.(types.String) + _, err := parseIPAddr(string(s)) + return types.Bool(err == nil) +} + +func netIsCIDR(val ref.Val) ref.Val { + s := val.(types.String) + _, err := netip.ParsePrefix(string(s)) + return types.Bool(err == nil) +} + +func netIPString(val ref.Val) ref.Val { + s := val.(types.String) + str := string(s) + addr, err := parseIPAddr(str) + if err != nil { + return types.NewErr("%v", err) + } + return IP{Addr: addr} +} + +func netCIDRString(val ref.Val) ref.Val { + s := val.(types.String) + str := string(s) + prefix, err := netip.ParsePrefix(str) + if err != nil { + return types.NewErr("invalid cidr range: %s", str) + } + return CIDR{Prefix: prefix} +} + +func netIPFamily(val ref.Val) ref.Val { + ip := val.(IP) + if ip.Addr.Is4() { + return types.Int(4) + } + return types.Int(6) +} + +func netIPIsCanonical(val ref.Val) ref.Val { + s := val.(types.String) + str := string(s) + addr, err := parseIPAddr(str) + if err != nil { + return types.NewErr("%v", err) + } + return types.Bool(addr.String() == str) +} + +func netIPIsLoopback(val ref.Val) ref.Val { + ip := val.(IP) + return types.Bool(ip.Addr.IsLoopback()) +} + +func netIPIsGlobalUnicast(val ref.Val) ref.Val { + ip := val.(IP) + return types.Bool(ip.Addr.IsGlobalUnicast()) +} + +func netIPIsUnspecified(val ref.Val) ref.Val { + ip := val.(IP) + return types.Bool(ip.Addr.IsUnspecified()) +} + +func netIPIsLinkLocalMulticast(val ref.Val) ref.Val { + ip := val.(IP) + return types.Bool(ip.Addr.IsLinkLocalMulticast()) +} + +func netIPIsLinkLocalUnicast(val ref.Val) ref.Val { + ip := val.(IP) + return types.Bool(ip.Addr.IsLinkLocalUnicast()) +} + +func netCIDRContainsIP(lhs, rhs ref.Val) ref.Val { + cidr := lhs.(CIDR) + ip := rhs.(IP) + return types.Bool(cidr.Prefix.Contains(ip.Addr)) +} + +func netCIDRContainsIPString(lhs, rhs ref.Val) ref.Val { + cidr := lhs.(CIDR) + s := rhs.(types.String) + addr, err := parseIPAddr(string(s)) + if err != nil { + return types.NewErr("%v", err) + } + return types.Bool(cidr.Prefix.Contains(addr)) +} + +func netCIDRContainsCIDR(lhs, rhs ref.Val) ref.Val { + parent := lhs.(CIDR) + child := rhs.(CIDR) + return types.Bool(parent.Prefix.Overlaps(child.Prefix) && parent.Prefix.Bits() <= child.Prefix.Bits()) +} + +func netCIDRContainsCIDRString(lhs, rhs ref.Val) ref.Val { + parent := lhs.(CIDR) + s := rhs.(types.String) + childPrefix, err := netip.ParsePrefix(string(s)) + if err != nil { + return types.NewErr("invalid cidr range: %s", s) + } + return types.Bool(parent.Prefix.Overlaps(childPrefix) && parent.Prefix.Bits() <= childPrefix.Bits()) +} + +func netCIDRMasked(val ref.Val) ref.Val { + cidr := val.(CIDR) + return CIDR{Prefix: cidr.Prefix.Masked()} +} + +func netCIDRPrefixLength(val ref.Val) ref.Val { + cidr := val.(CIDR) + return types.Int(cidr.Prefix.Bits()) +} + +func netCIDRIP(val ref.Val) ref.Val { + cidr := val.(CIDR) + return IP{Addr: cidr.Prefix.Addr()} +} + +// --- Opaque Type Wrappers --- + +// IP is an exported CEL value that wraps netip.Addr. +type IP struct { + netip.Addr +} + +func (i IP) ConvertToNative(typeDesc reflect.Type) (any, error) { + // Use reflect.TypeFor to avoid instantiating netip.Addr{} + if typeDesc == reflect.TypeFor[netip.Addr]() { + return i.Addr, nil + } + if typeDesc.Kind() == reflect.String { + return i.Addr.String(), nil + } + return nil, fmt.Errorf("unsupported type conversion to '%v'", typeDesc) +} + +func (i IP) ConvertToType(typeValue ref.Type) ref.Val { + switch typeValue { + case types.StringType: + return types.String(i.Addr.String()) + case networkIPType: + return i + case types.TypeType: + return networkIPType + } + return types.NewErr("type conversion error from '%s' to '%s'", networkIPType, typeValue) +} + +func (i IP) Equal(other ref.Val) ref.Val { + o, ok := other.(IP) + if !ok { + return types.ValOrErr(other, "no such overload") + } + return types.Bool(i.Addr == o.Addr) +} + +func (i IP) Type() ref.Type { + return networkIPType +} + +func (i IP) Value() any { + return i.Addr +} + +// CIDR is an exported CEL value that wraps netip.Prefix. +type CIDR struct { + netip.Prefix +} + +func (c CIDR) ConvertToNative(typeDesc reflect.Type) (any, error) { + // Use reflect.TypeFor to avoid instantiating netip.Prefix{} + if typeDesc == reflect.TypeFor[netip.Prefix]() { + return c.Prefix, nil + } + if typeDesc.Kind() == reflect.String { + return c.Prefix.String(), nil + } + return nil, fmt.Errorf("unsupported type conversion to '%v'", typeDesc) +} + +func (c CIDR) ConvertToType(typeValue ref.Type) ref.Val { + switch typeValue { + case types.StringType: + return types.String(c.Prefix.String()) + case networkCIDRType: + return c + case types.TypeType: + return networkCIDRType + } + return types.NewErr("type conversion error from '%s' to '%s'", networkCIDRType, typeValue) +} + +func (c CIDR) Equal(other ref.Val) ref.Val { + o, ok := other.(CIDR) + if !ok { + return types.ValOrErr(other, "no such overload") + } + return types.Bool(c.Prefix == o.Prefix) +} + +func (c CIDR) Type() ref.Type { + return networkCIDRType +} + +func (c CIDR) Value() any { + return c.Prefix +} diff --git a/ext/network_test.go b/ext/network_test.go new file mode 100644 index 00000000..017f57f6 --- /dev/null +++ b/ext/network_test.go @@ -0,0 +1,359 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ext + +import ( + "reflect" + "testing" + + "github.com/google/cel-go/cel" + "github.com/google/cel-go/common/types" +) + +func TestNetwork_Success(t *testing.T) { + // These test cases are ported from kubernetes/staging/src/k8s.io/apiserver/pkg/cel/library + // to ensure 1-to-1 parity with the Kubernetes implementation. + tests := []struct { + name string + expr string + out any + }{ + // --- Global Checks (isIP, isCIDR) --- + { + name: "isIP valid IPv4", + expr: "isIP('1.2.3.4')", + out: true, + }, + { + name: "isIP valid IPv6", + expr: "isIP('2001:db8::1')", + out: true, + }, + { + name: "isIP invalid", + expr: "isIP('not.an.ip')", + out: false, + }, + { + name: "isIP with port (invalid)", + expr: "isIP('127.0.0.1:80')", + out: false, + }, + { + name: "isCIDR valid", + expr: "isCIDR('10.0.0.0/8')", + out: true, + }, + { + name: "isCIDR invalid mask", + expr: "isCIDR('10.0.0.0/999')", + out: false, + }, + + // --- IP Constructors & Equality --- + { + name: "ip equality IPv4", + expr: "ip('127.0.0.1') == ip('127.0.0.1')", + out: true, + }, + { + name: "ip inequality", + expr: "ip('127.0.0.1') == ip('1.2.3.4')", + out: false, + }, + { + name: "ip equality IPv6 mixed case inputs", + // Logic check: The value is equal even if string rep was different + expr: "ip('2001:db8::1') == ip('2001:DB8::1')", + out: true, + }, + + // --- Family --- + { + name: "family IPv4", + expr: "ip('127.0.0.1').family()", + out: int64(4), + }, + { + name: "family IPv6", + expr: "ip('::1').family()", + out: int64(6), + }, + + // --- Canonicalization (Critical Feature) --- + { + name: "isCanonical IPv4 simple", + expr: "ip.isCanonical('127.0.0.1')", + out: true, + }, + { + name: "isCanonical IPv6 standard", + expr: "ip.isCanonical('2001:db8::1')", + out: true, + }, + { + name: "isCanonical IPv6 uppercase (invalid)", + expr: "ip.isCanonical('2001:DB8::1')", + out: false, + }, + { + name: "isCanonical IPv6 expanded (invalid)", + expr: "ip.isCanonical('2001:db8:0:0:0:0:0:1')", + out: false, + }, + + // --- IP Types (Loopback, Unspecified, etc) --- + { + name: "isLoopback IPv4", + expr: "ip('127.0.0.1').isLoopback()", + out: true, + }, + { + name: "isLoopback IPv6", + expr: "ip('::1').isLoopback()", + out: true, + }, + { + name: "isUnspecified IPv4", + expr: "ip('0.0.0.0').isUnspecified()", + out: true, + }, + { + name: "isUnspecified IPv6", + expr: "ip('::').isUnspecified()", + out: true, + }, + { + name: "isGlobalUnicast 8.8.8.8", + expr: "ip('8.8.8.8').isGlobalUnicast()", + out: true, + }, + { + name: "isLinkLocalMulticast", + expr: "ip('ff02::1').isLinkLocalMulticast()", + out: true, + }, + + // --- CIDR Accessors --- + { + name: "cidr prefixLength", + expr: "cidr('192.168.0.0/24').prefixLength()", + out: int64(24), + }, + { + name: "cidr ip extraction", + expr: "cidr('192.168.0.0/24').ip() == ip('192.168.0.0')", + out: true, + }, + { + name: "cidr ip extraction (host bits set)", + // K8s behavior: cidr('1.2.3.4/24').ip() returns 1.2.3.4, not 1.2.3.0 + expr: "cidr('192.168.1.5/24').ip() == ip('192.168.1.5')", + out: true, + }, + { + name: "cidr masked", + // masked() zeroes out the host bits + expr: "cidr('192.168.1.5/24').masked() == cidr('192.168.1.0/24')", + out: true, + }, + { + name: "cidr masked identity", + expr: "cidr('192.168.1.0/24').masked() == cidr('192.168.1.0/24')", + out: true, + }, + + // --- Containment (IP in CIDR) --- + { + name: "containsIP simple", + expr: "cidr('10.0.0.0/8').containsIP(ip('10.1.2.3'))", + out: true, + }, + { + name: "containsIP string overload", + expr: "cidr('10.0.0.0/8').containsIP('10.1.2.3')", + out: true, + }, + { + name: "containsIP edge case (network address)", + expr: "cidr('10.0.0.0/8').containsIP(ip('10.0.0.0'))", + out: true, + }, + { + name: "containsIP edge case (broadcast)", + expr: "cidr('10.0.0.0/8').containsIP(ip('10.255.255.255'))", + out: true, + }, + { + name: "containsIP false", + expr: "cidr('10.0.0.0/8').containsIP(ip('11.0.0.0'))", + out: false, + }, + + // --- Containment (CIDR in CIDR) --- + { + name: "containsCIDR exact match", + expr: "cidr('10.0.0.0/8').containsCIDR(cidr('10.0.0.0/8'))", + out: true, + }, + { + name: "containsCIDR subnet", + expr: "cidr('10.0.0.0/8').containsCIDR(cidr('10.1.0.0/16'))", + out: true, + }, + { + name: "containsCIDR string overload", + expr: "cidr('10.0.0.0/8').containsCIDR('10.1.0.0/16')", + out: true, + }, + { + name: "containsCIDR larger prefix (false)", + // /8 does not contain /4 + expr: "cidr('10.0.0.0/8').containsCIDR(cidr('0.0.0.0/4'))", + out: false, + }, + { + name: "containsCIDR disjoint", + expr: "cidr('10.0.0.0/8').containsCIDR(cidr('11.0.0.0/8'))", + out: false, + }, + { + name: "containsCIDR different family", + expr: "cidr('10.0.0.0/8').containsCIDR(cidr('::1/128'))", + out: false, + }, + } + + // Initialize the environment with the Network extension + env, err := cel.NewEnv(Network()) + if err != nil { + t.Fatalf("cel.NewEnv(Network()) failed: %v", err) + } + + for _, tst := range tests { + t.Run(tst.name, func(t *testing.T) { + ast, iss := env.Compile(tst.expr) + if iss.Err() != nil { + t.Fatalf("Compile(%q) failed: %v", tst.expr, iss.Err()) + } + + prg, err := env.Program(ast) + if err != nil { + t.Fatalf("Program(%q) failed: %v", tst.expr, err) + } + + out, _, err := prg.Eval(cel.NoVars()) + if err != nil { + t.Fatalf("Eval(%q) failed: %v", tst.expr, err) + } + + // Convert the CEL result to a native Go value for comparison + got, err := out.ConvertToNative(reflect.TypeOf(tst.out)) + if err != nil { + t.Fatalf("ConvertToNative failed for expr %q: %v", tst.expr, err) + } + + if !reflect.DeepEqual(got, tst.out) { + t.Errorf("Expr %q result got %v, wanted %v", tst.expr, got, tst.out) + } + }) + } +} + +func TestNetwork_RuntimeErrors(t *testing.T) { + tests := []struct { + name string + expr string + errContains string + }{ + { + name: "ip constructor invalid", + expr: "ip('999.999.999.999')", + errContains: "parse error", + }, + { + name: "cidr constructor invalid", + expr: "cidr('1.2.3.4')", + errContains: "invalid cidr range", + }, + { + name: "cidr constructor invalid mask", + expr: "cidr('10.0.0.0/999')", + errContains: "invalid cidr range", + }, + { + name: "containsIP string overload invalid", + expr: "cidr('10.0.0.0/8').containsIP('not-an-ip')", + errContains: "parse error", + }, + { + name: "containsCIDR string overload invalid", + expr: "cidr('10.0.0.0/8').containsCIDR('not-a-cidr')", + errContains: "invalid cidr range", + }, + } + + env, err := cel.NewEnv(Network()) + if err != nil { + t.Fatalf("cel.NewEnv(Network()) failed: %v", err) + } + + for _, tst := range tests { + t.Run(tst.name, func(t *testing.T) { + ast, iss := env.Compile(tst.expr) + if iss.Err() != nil { + // Note: We only check runtime errors here. Compile errors are unexpected + // because these functions accept strings, so type-check passes. + t.Fatalf("Compile(%q) failed unexpectedly: %v", tst.expr, iss.Err()) + } + + prg, err := env.Program(ast) + if err != nil { + t.Fatalf("Program(%q) failed: %v", tst.expr, err) + } + + _, _, err = prg.Eval(cel.NoVars()) + if err == nil { + t.Errorf("Expected runtime error for %q, got nil", tst.expr) + return + } + + // CEL errors are sometimes wrapped, so we check substring + if !types.IsError(types.NewErr(err.Error())) { + // Just a sanity check that it is indeed a CEL-compatible error structure + // Not strictly necessary but good practice + } + + // Standard substring check + gotErr := err.Error() + // We just check if the message contains the specific error text we return in network.go + found := false + // Note: The actual error might be wrapped in "evaluation error: ..." + if len(tst.errContains) > 0 { + // Simple string contains check + for i := 0; i < len(gotErr)-len(tst.errContains)+1; i++ { + if gotErr[i:i+len(tst.errContains)] == tst.errContains { + found = true + break + } + } + } + + if !found { + t.Errorf("Expected error containing %q, got %q", tst.errContains, gotErr) + } + }) + } +}