From f9941219b9b17dfffddd7c8084d38a857ed6d64f Mon Sep 17 00:00:00 2001 From: Thomas Desrosiers Date: Wed, 19 Nov 2025 18:37:50 -0500 Subject: [PATCH 1/4] upstreaming CIDR and IP support to cel-go --- ext/BUILD.bazel | 2 + ext/network.go | 531 ++++++++++++++++++++++++++++++++++++++++++++ ext/network_test.go | 365 ++++++++++++++++++++++++++++++ 3 files changed, 898 insertions(+) create mode 100644 ext/network.go create mode 100644 ext/network_test.go diff --git a/ext/BUILD.bazel b/ext/BUILD.bazel index ef4f4ec3..596b30b4 100644 --- a/ext/BUILD.bazel +++ b/ext/BUILD.bazel @@ -17,6 +17,7 @@ go_library( "lists.go", "math.go", "native.go", + "network.go", "protos.go", "regex.go", "sets.go", @@ -60,6 +61,7 @@ go_test( "lists_test.go", "math_test.go", "native_test.go", + "network_test.go", "protos_test.go", "regex_test.go", "sets_test.go", diff --git a/ext/network.go b/ext/network.go new file mode 100644 index 00000000..32a5b091 --- /dev/null +++ b/ext/network.go @@ -0,0 +1,531 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ext + +import ( + "fmt" + "net" + "reflect" + + "github.com/google/cel-go/cel" + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" + "github.com/google/cel-go/common/types/traits" +) + +// Network returns a cel.EnvOption to configure extended functions for network +// address parsing, inspection, and CIDR range manipulation. +// +// Note: This library defines global functions `ip`, `cidr`, `isIP`, and +// `isCIDR`. If you are currently using variables named `ip` or `cidr`, these +// functions will likely work as intended, however there is a chance for +// collision. +// +// The library closely mirrors the behavior of the Kubernetes CEL network +// libraries, treating IP addresses and CIDR ranges as opaque types with +// specific member functions. +// +// # IP Addresses +// +// The `ip` function converts a string to an IP address (IPv4 or IPv6). If the +// string is not a valid IP, an error is returned. The `isIP` function checks +// if a string is a valid IP address without throwing an error. +// +// ip(string) -> ip +// isIP(string) -> bool +// +// Examples: +// +// ip('127.0.0.1') +// ip('::1') +// isIP('1.2.3.4') // true +// isIP('invalid') // false +// +// # CIDR Ranges +// +// The `cidr` function converts a string to a Classless Inter-Domain Routing +// (CIDR) range. If the string is not valid, an error is returned. The `isCIDR` +// function checks if a string is a valid CIDR notation. +// +// cidr(string) -> cidr +// isCIDR(string) -> bool +// +// Examples: +// +// cidr('192.168.0.0/24') +// cidr('::1/128') +// isCIDR('10.0.0.0/8') // true +// +// # IP Member Functions +// +// IP objects support various inspection methods. +// +// .family() -> int +// .isCanonical() -> bool +// .isLoopback() -> bool +// .isGlobalUnicast() -> bool +// .isLinkLocalMulticast() -> bool +// .isLinkLocalUnicast() -> bool +// .isUnspecified() -> bool +// +// Note on Canonicalization: `isCanonical()` returns true if the input string +// used to construct the IP matches the RFC 5952 canonical string representation +// of that address. +// +// Examples: +// +// ip('127.0.0.1').family() == 4 +// ip('::1').family() == 6 +// ip('127.0.0.1').isLoopback() == true +// ip('2001:db8::1').isCanonical() == true // RFC 5952 format +// ip('2001:DB8::1').isCanonical() == false // Uppercase is not canonical +// ip('2001:db8:0:0:0:0:0:1').isCanonical() == false // Expanded is not canonical +// +// # CIDR Member Functions +// +// CIDR objects support containment checks and property extraction. +// +// .containsIP(ip|string) -> bool +// .containsCIDR(cidr|string) -> bool +// .ip() -> ip +// .masked() -> cidr +// .prefixLength() -> int +// +// Examples: +// +// cidr('10.0.0.0/8').containsIP(ip('10.0.0.1')) == true +// cidr('10.0.0.0/8').containsIP('10.0.0.1') == true +// cidr('10.0.0.0/8').containsCIDR('10.1.0.0/16') == true +// cidr('192.168.1.5/24').ip() == ip('192.168.1.5') +// cidr('192.168.1.5/24').masked() == cidr('192.168.1.0/24') +// cidr('192.168.1.0/24').prefixLength() == 24 +func Network() cel.EnvOption { + return cel.Lib(&networkLib{}) +} + +const ( + // Function names + isIPFunc = "isIP" + isCIDRFunc = "isCIDR" + ipFunc = "ip" + cidrFunc = "cidr" + familyFunc = "family" + isCanonicalFunc = "isCanonical" + isLoopbackFunc = "isLoopback" + isGlobalUnicastFunc = "isGlobalUnicast" + isUnspecifiedFunc = "isUnspecified" + isLinkLocalMcastFunc = "isLinkLocalMulticast" + isLinkLocalUcastFunc = "isLinkLocalUnicast" + containsIPFunc = "containsIP" + containsCIDRFunc = "containsCIDR" + maskedFunc = "masked" + prefixLengthFunc = "prefixLength" + ipFromCIDRFunc = "ip" +) + +var ( + // Definitions for the Opaque Types + networkIPType = cel.ObjectType("network.IP", traits.ReceiverType) + networkCIDRType = cel.ObjectType("network.CIDR", traits.ReceiverType) +) + +type networkLib struct{} + +func (*networkLib) LibraryName() string { + return "cel.lib.ext.network" +} + +// CompileOptions implements the Library interface method. +func (*networkLib) CompileOptions() []cel.EnvOption { + return []cel.EnvOption{ + cel.Types( + networkIPType, + networkCIDRType, + ), + // Global Checkers + cel.Function(isIPFunc, + cel.Overload("isIP_string", []*cel.Type{cel.StringType}, cel.BoolType, + cel.UnaryBinding(netIsIP)), + ), + cel.Function(isCIDRFunc, + cel.Overload("isCIDR_string", []*cel.Type{cel.StringType}, cel.BoolType, + cel.UnaryBinding(netIsCIDR)), + ), + + // Constructors + cel.Function(ipFunc, + cel.Overload("ip_string", []*cel.Type{cel.StringType}, networkIPType, + cel.UnaryBinding(netIPString)), + ), + cel.Function(cidrFunc, + cel.Overload("cidr_string", []*cel.Type{cel.StringType}, networkCIDRType, + cel.UnaryBinding(netCIDRString)), + ), + + // IP Member Functions + cel.Function(familyFunc, + cel.MemberOverload("ip_family", []*cel.Type{networkIPType}, cel.IntType, + cel.UnaryBinding(netIPFamily)), + ), + cel.Function(isCanonicalFunc, + cel.MemberOverload("ip_isCanonical", []*cel.Type{networkIPType}, cel.BoolType, + cel.UnaryBinding(netIPIsCanonical)), + ), + cel.Function(isLoopbackFunc, + cel.MemberOverload("ip_isLoopback", []*cel.Type{networkIPType}, cel.BoolType, + cel.UnaryBinding(netIPIsLoopback)), + ), + cel.Function(isGlobalUnicastFunc, + cel.MemberOverload("ip_isGlobalUnicast", []*cel.Type{networkIPType}, cel.BoolType, + cel.UnaryBinding(netIPIsGlobalUnicast)), + ), + cel.Function(isUnspecifiedFunc, + cel.MemberOverload("ip_isUnspecified", []*cel.Type{networkIPType}, cel.BoolType, + cel.UnaryBinding(netIPIsUnspecified)), + ), + cel.Function(isLinkLocalMcastFunc, + cel.MemberOverload("ip_isLinkLocalMulticast", []*cel.Type{networkIPType}, cel.BoolType, + cel.UnaryBinding(netIPIsLinkLocalMulticast)), + ), + cel.Function(isLinkLocalUcastFunc, + cel.MemberOverload("ip_isLinkLocalUnicast", []*cel.Type{networkIPType}, cel.BoolType, + cel.UnaryBinding(netIPIsLinkLocalUnicast)), + ), + + // CIDR Member Functions + cel.Function(containsIPFunc, + cel.MemberOverload("cidr_containsIP_ip", []*cel.Type{networkCIDRType, networkIPType}, cel.BoolType, + cel.BinaryBinding(netCIDRContainsIP)), + cel.MemberOverload("cidr_containsIP_string", []*cel.Type{networkCIDRType, cel.StringType}, cel.BoolType, + cel.BinaryBinding(netCIDRContainsIPString)), + ), + cel.Function(containsCIDRFunc, + cel.MemberOverload("cidr_containsCIDR_cidr", []*cel.Type{networkCIDRType, networkCIDRType}, cel.BoolType, + cel.BinaryBinding(netCIDRContainsCIDR)), + cel.MemberOverload("cidr_containsCIDR_string", []*cel.Type{networkCIDRType, cel.StringType}, cel.BoolType, + cel.BinaryBinding(netCIDRContainsCIDRString)), + ), + cel.Function(maskedFunc, + cel.MemberOverload("cidr_masked", []*cel.Type{networkCIDRType}, networkCIDRType, + cel.UnaryBinding(netCIDRMasked)), + ), + cel.Function(prefixLengthFunc, + cel.MemberOverload("cidr_prefixLength", []*cel.Type{networkCIDRType}, cel.IntType, + cel.UnaryBinding(netCIDRPrefixLength)), + ), + cel.Function(ipFromCIDRFunc, + cel.MemberOverload("cidr_ip", []*cel.Type{networkCIDRType}, networkIPType, + cel.UnaryBinding(netCIDRIP)), + ), + } +} + +func (*networkLib) ProgramOptions() []cel.ProgramOption { + return []cel.ProgramOption{} +} + +// --- Implementation Logic --- + +func netIsIP(val ref.Val) ref.Val { + s, ok := val.(types.String) + if !ok { + return types.MaybeNoSuchOverloadErr(val) + } + return types.Bool(net.ParseIP(string(s)) != nil) +} + +func netIsCIDR(val ref.Val) ref.Val { + s, ok := val.(types.String) + if !ok { + return types.MaybeNoSuchOverloadErr(val) + } + _, _, err := net.ParseCIDR(string(s)) + return types.Bool(err == nil) +} + +func netIPString(val ref.Val) ref.Val { + s, ok := val.(types.String) + if !ok { + return types.MaybeNoSuchOverloadErr(val) + } + str := string(s) + ip := net.ParseIP(str) + if ip == nil { + return types.NewErr("invalid ip address: %s", str) + } + // Store both the logic (IP) and the representation (str) + return ipValue{IP: ip, str: str} +} + +func netCIDRString(val ref.Val) ref.Val { + s, ok := val.(types.String) + if !ok { + return types.MaybeNoSuchOverloadErr(val) + } + str := string(s) + // CHANGE: Capture 'ip' (the specific address) alongside 'ipNet' + ip, ipNet, err := net.ParseCIDR(str) + if err != nil { + return types.NewErr("invalid cidr range: %s", str) + } + return cidrValue{IPNet: ipNet, ExtraIP: ip, str: str} +} + +func netIPFamily(val ref.Val) ref.Val { + ip, ok := val.(ipValue) + if !ok { + return types.MaybeNoSuchOverloadErr(val) + } + if ip.IP.To4() != nil { + return types.Int(4) + } + return types.Int(6) +} + +func netIPIsCanonical(val ref.Val) ref.Val { + ip, ok := val.(ipValue) + if !ok { + return types.MaybeNoSuchOverloadErr(val) + } + // 1-to-1 Parity: Check if the input string matches the standard library output + return types.Bool(ip.str == ip.IP.String()) +} + +func netIPIsLoopback(val ref.Val) ref.Val { + ip, ok := val.(ipValue) + if !ok { + return types.MaybeNoSuchOverloadErr(val) + } + return types.Bool(ip.IP.IsLoopback()) +} + +func netIPIsGlobalUnicast(val ref.Val) ref.Val { + ip, ok := val.(ipValue) + if !ok { + return types.MaybeNoSuchOverloadErr(val) + } + return types.Bool(ip.IP.IsGlobalUnicast()) +} + +func netIPIsUnspecified(val ref.Val) ref.Val { + ip, ok := val.(ipValue) + if !ok { + return types.MaybeNoSuchOverloadErr(val) + } + return types.Bool(ip.IP.IsUnspecified()) +} + +func netIPIsLinkLocalMulticast(val ref.Val) ref.Val { + ip, ok := val.(ipValue) + if !ok { + return types.MaybeNoSuchOverloadErr(val) + } + return types.Bool(ip.IP.IsLinkLocalMulticast()) +} + +func netIPIsLinkLocalUnicast(val ref.Val) ref.Val { + ip, ok := val.(ipValue) + if !ok { + return types.MaybeNoSuchOverloadErr(val) + } + return types.Bool(ip.IP.IsLinkLocalUnicast()) +} + +func netCIDRContainsIP(lhs, rhs ref.Val) ref.Val { + cidr, ok := lhs.(cidrValue) + if !ok { + return types.MaybeNoSuchOverloadErr(lhs) + } + ip, ok := rhs.(ipValue) + if !ok { + return types.MaybeNoSuchOverloadErr(rhs) + } + return types.Bool(cidr.IPNet.Contains(ip.IP)) +} + +func netCIDRContainsIPString(lhs, rhs ref.Val) ref.Val { + cidr, ok := lhs.(cidrValue) + if !ok { + return types.MaybeNoSuchOverloadErr(lhs) + } + s, ok := rhs.(types.String) + if !ok { + return types.MaybeNoSuchOverloadErr(rhs) + } + ip := net.ParseIP(string(s)) + if ip == nil { + return types.NewErr("invalid ip address: %s", s) + } + return types.Bool(cidr.IPNet.Contains(ip)) +} + +func netCIDRContainsCIDR(lhs, rhs ref.Val) ref.Val { + parent, ok := lhs.(cidrValue) + if !ok { + return types.MaybeNoSuchOverloadErr(lhs) + } + child, ok := rhs.(cidrValue) + if !ok { + return types.MaybeNoSuchOverloadErr(rhs) + } + ones1, _ := parent.IPNet.Mask.Size() + ones2, _ := child.IPNet.Mask.Size() + return types.Bool(parent.IPNet.Contains(child.IPNet.IP) && ones1 <= ones2) +} + +func netCIDRContainsCIDRString(lhs, rhs ref.Val) ref.Val { + parent, ok := lhs.(cidrValue) + if !ok { + return types.MaybeNoSuchOverloadErr(lhs) + } + s, ok := rhs.(types.String) + if !ok { + return types.MaybeNoSuchOverloadErr(rhs) + } + _, childNet, err := net.ParseCIDR(string(s)) + if err != nil { + return types.NewErr("invalid cidr range: %s", s) + } + ones1, _ := parent.IPNet.Mask.Size() + ones2, _ := childNet.Mask.Size() + return types.Bool(parent.IPNet.Contains(childNet.IP) && ones1 <= ones2) +} + +func netCIDRMasked(val ref.Val) ref.Val { + cidr, ok := val.(cidrValue) + if !ok { + return types.MaybeNoSuchOverloadErr(val) + } + // Mask both the network struct AND our extra IP + maskedIP := cidr.IPNet.IP.Mask(cidr.IPNet.Mask) + newNet := &net.IPNet{IP: maskedIP, Mask: cidr.IPNet.Mask} + + // In a masked CIDR, the specific IP is the same as the network IP + return cidrValue{IPNet: newNet, ExtraIP: maskedIP, str: newNet.String()} +} + +func netCIDRPrefixLength(val ref.Val) ref.Val { + cidr, ok := val.(cidrValue) + if !ok { + return types.MaybeNoSuchOverloadErr(val) + } + ones, _ := cidr.IPNet.Mask.Size() + return types.Int(ones) +} + +func netCIDRIP(val ref.Val) ref.Val { + cidr, ok := val.(cidrValue) + if !ok { + return types.MaybeNoSuchOverloadErr(val) + } + // Extract IP. Use String() to ensure we have a valid canonical string representation + // for the new IP object. Return the specific host IP, not the network address + return ipValue{IP: cidr.ExtraIP, str: cidr.ExtraIP.String()} +} + +// --- Opaque Type Wrappers --- + +// ipValue implements ref.Val +type ipValue struct { + net.IP + str string // Kept for isCanonical checks +} + +func (i ipValue) ConvertToNative(typeDesc reflect.Type) (any, error) { + if typeDesc == reflect.TypeOf(net.IP{}) { + return i.IP, nil + } + if typeDesc.Kind() == reflect.String { + return i.IP.String(), nil + } + return nil, fmt.Errorf("unsupported type conversion to '%v'", typeDesc) +} + +func (i ipValue) ConvertToType(typeValue ref.Type) ref.Val { + switch typeValue { + case types.StringType: + return types.String(i.IP.String()) + case networkIPType: + return i + case types.TypeType: + return networkIPType + } + return types.NewErr("type conversion error from '%s' to '%s'", networkIPType, typeValue) +} + +func (i ipValue) Equal(other ref.Val) ref.Val { + o, ok := other.(ipValue) + if !ok { + return types.ValOrErr(other, "no such overload") + } + // Correctness: Equality is based on the actual IP bytes, + // NOT the string representation. + // ip("127.0.0.1") == ip("127.000.000.001") -> True + return types.Bool(i.IP.Equal(o.IP)) +} + +func (i ipValue) Type() ref.Type { + return networkIPType +} + +func (i ipValue) Value() any { + return i.IP +} + +// cidrValue implements ref.Val +type cidrValue struct { + *net.IPNet + ExtraIP net.IP // IP Address with host bits set + str string +} + +func (c cidrValue) ConvertToNative(typeDesc reflect.Type) (any, error) { + if typeDesc == reflect.TypeOf(&net.IPNet{}) { + return c.IPNet, nil + } + if typeDesc.Kind() == reflect.String { + return c.IPNet.String(), nil + } + return nil, fmt.Errorf("unsupported type conversion to '%v'", typeDesc) +} + +func (c cidrValue) ConvertToType(typeValue ref.Type) ref.Val { + switch typeValue { + case types.StringType: + return types.String(c.IPNet.String()) + case networkCIDRType: + return c + case types.TypeType: + return networkCIDRType + } + return types.NewErr("type conversion error from '%s' to '%s'", networkCIDRType, typeValue) +} + +func (c cidrValue) Equal(other ref.Val) ref.Val { + o, ok := other.(cidrValue) + if !ok { + return types.ValOrErr(other, "no such overload") + } + // Correctness: Equality is based on the IP bytes AND the mask bytes. + return types.Bool(c.IPNet.IP.Equal(o.IPNet.IP) && c.IPNet.Mask.String() == o.IPNet.Mask.String()) +} + +func (c cidrValue) Type() ref.Type { + return networkCIDRType +} + +func (c cidrValue) Value() any { + return c.IPNet +} diff --git a/ext/network_test.go b/ext/network_test.go new file mode 100644 index 00000000..8b2ab757 --- /dev/null +++ b/ext/network_test.go @@ -0,0 +1,365 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ext + +import ( + "reflect" + "testing" + + "github.com/google/cel-go/cel" + "github.com/google/cel-go/common/types" +) + +func TestNetwork_Success(t *testing.T) { + // These test cases are ported from kubernetes/staging/src/k8s.io/apiserver/pkg/cel/library + // to ensure 1-to-1 parity with the Kubernetes implementation. + tests := []struct { + name string + expr string + out any + }{ + // --- Global Checks (isIP, isCIDR) --- + { + name: "isIP valid IPv4", + expr: "isIP('1.2.3.4')", + out: true, + }, + { + name: "isIP valid IPv6", + expr: "isIP('2001:db8::1')", + out: true, + }, + { + name: "isIP invalid", + expr: "isIP('not.an.ip')", + out: false, + }, + { + name: "isIP with port (invalid)", + expr: "isIP('127.0.0.1:80')", + out: false, + }, + { + name: "isCIDR valid", + expr: "isCIDR('10.0.0.0/8')", + out: true, + }, + { + name: "isCIDR invalid mask", + expr: "isCIDR('10.0.0.0/999')", + out: false, + }, + + // --- IP Constructors & Equality --- + { + name: "ip equality IPv4", + expr: "ip('127.0.0.1') == ip('127.0.0.1')", + out: true, + }, + { + name: "ip inequality", + expr: "ip('127.0.0.1') == ip('1.2.3.4')", + out: false, + }, + { + name: "ip equality IPv6 mixed case inputs", + // Logic check: The value is equal even if string rep was different + expr: "ip('2001:db8::1') == ip('2001:DB8::1')", + out: true, + }, + + // --- Family --- + { + name: "family IPv4", + expr: "ip('127.0.0.1').family()", + out: int64(4), + }, + { + name: "family IPv6", + expr: "ip('::1').family()", + out: int64(6), + }, + + // --- Canonicalization (Critical Feature) --- + { + name: "isCanonical IPv4 simple", + expr: "ip('127.0.0.1').isCanonical()", + out: true, + }, + { + // ::ffff:127.0.0.1 is valid, but canonical form is 127.0.0.1 + name: "isCanonical IPv4-mapped IPv6 (valid but non-canonical)", + expr: "ip('::ffff:127.0.0.1').isCanonical()", + out: false, + }, + { + name: "isCanonical IPv6 standard", + expr: "ip('2001:db8::1').isCanonical()", + out: true, + }, + { + name: "isCanonical IPv6 uppercase (invalid)", + expr: "ip('2001:DB8::1').isCanonical()", + out: false, + }, + { + name: "isCanonical IPv6 expanded (invalid)", + expr: "ip('2001:db8:0:0:0:0:0:1').isCanonical()", + out: false, + }, + + // --- IP Types (Loopback, Unspecified, etc) --- + { + name: "isLoopback IPv4", + expr: "ip('127.0.0.1').isLoopback()", + out: true, + }, + { + name: "isLoopback IPv6", + expr: "ip('::1').isLoopback()", + out: true, + }, + { + name: "isUnspecified IPv4", + expr: "ip('0.0.0.0').isUnspecified()", + out: true, + }, + { + name: "isUnspecified IPv6", + expr: "ip('::').isUnspecified()", + out: true, + }, + { + name: "isGlobalUnicast 8.8.8.8", + expr: "ip('8.8.8.8').isGlobalUnicast()", + out: true, + }, + { + name: "isLinkLocalMulticast", + expr: "ip('ff02::1').isLinkLocalMulticast()", + out: true, + }, + + // --- CIDR Accessors --- + { + name: "cidr prefixLength", + expr: "cidr('192.168.0.0/24').prefixLength()", + out: int64(24), + }, + { + name: "cidr ip extraction", + expr: "cidr('192.168.0.0/24').ip() == ip('192.168.0.0')", + out: true, + }, + { + name: "cidr ip extraction (host bits set)", + // K8s behavior: cidr('1.2.3.4/24').ip() returns 1.2.3.4, not 1.2.3.0 + expr: "cidr('192.168.1.5/24').ip() == ip('192.168.1.5')", + out: true, + }, + { + name: "cidr masked", + // masked() zeroes out the host bits + expr: "cidr('192.168.1.5/24').masked() == cidr('192.168.1.0/24')", + out: true, + }, + { + name: "cidr masked identity", + expr: "cidr('192.168.1.0/24').masked() == cidr('192.168.1.0/24')", + out: true, + }, + + // --- Containment (IP in CIDR) --- + { + name: "containsIP simple", + expr: "cidr('10.0.0.0/8').containsIP(ip('10.1.2.3'))", + out: true, + }, + { + name: "containsIP string overload", + expr: "cidr('10.0.0.0/8').containsIP('10.1.2.3')", + out: true, + }, + { + name: "containsIP edge case (network address)", + expr: "cidr('10.0.0.0/8').containsIP(ip('10.0.0.0'))", + out: true, + }, + { + name: "containsIP edge case (broadcast)", + expr: "cidr('10.0.0.0/8').containsIP(ip('10.255.255.255'))", + out: true, + }, + { + name: "containsIP false", + expr: "cidr('10.0.0.0/8').containsIP(ip('11.0.0.0'))", + out: false, + }, + + // --- Containment (CIDR in CIDR) --- + { + name: "containsCIDR exact match", + expr: "cidr('10.0.0.0/8').containsCIDR(cidr('10.0.0.0/8'))", + out: true, + }, + { + name: "containsCIDR subnet", + expr: "cidr('10.0.0.0/8').containsCIDR(cidr('10.1.0.0/16'))", + out: true, + }, + { + name: "containsCIDR string overload", + expr: "cidr('10.0.0.0/8').containsCIDR('10.1.0.0/16')", + out: true, + }, + { + name: "containsCIDR larger prefix (false)", + // /8 does not contain /4 + expr: "cidr('10.0.0.0/8').containsCIDR(cidr('0.0.0.0/4'))", + out: false, + }, + { + name: "containsCIDR disjoint", + expr: "cidr('10.0.0.0/8').containsCIDR(cidr('11.0.0.0/8'))", + out: false, + }, + { + name: "containsCIDR different family", + expr: "cidr('10.0.0.0/8').containsCIDR(cidr('::1/128'))", + out: false, + }, + } + + // Initialize the environment with the Network extension + env, err := cel.NewEnv(Network()) + if err != nil { + t.Fatalf("cel.NewEnv(Network()) failed: %v", err) + } + + for _, tst := range tests { + t.Run(tst.name, func(t *testing.T) { + ast, iss := env.Compile(tst.expr) + if iss.Err() != nil { + t.Fatalf("Compile(%q) failed: %v", tst.expr, iss.Err()) + } + + prg, err := env.Program(ast) + if err != nil { + t.Fatalf("Program(%q) failed: %v", tst.expr, err) + } + + out, _, err := prg.Eval(cel.NoVars()) + if err != nil { + t.Fatalf("Eval(%q) failed: %v", tst.expr, err) + } + + // Convert the CEL result to a native Go value for comparison + got, err := out.ConvertToNative(reflect.TypeOf(tst.out)) + if err != nil { + t.Fatalf("ConvertToNative failed for expr %q: %v", tst.expr, err) + } + + if !reflect.DeepEqual(got, tst.out) { + t.Errorf("Expr %q result got %v, wanted %v", tst.expr, got, tst.out) + } + }) + } +} + +func TestNetwork_RuntimeErrors(t *testing.T) { + tests := []struct { + name string + expr string + errContains string + }{ + { + name: "ip constructor invalid", + expr: "ip('999.999.999.999')", + errContains: "invalid ip address", + }, + { + name: "cidr constructor invalid", + expr: "cidr('1.2.3.4')", + errContains: "invalid cidr range", + }, + { + name: "cidr constructor invalid mask", + expr: "cidr('10.0.0.0/999')", + errContains: "invalid cidr range", + }, + { + name: "containsIP string overload invalid", + expr: "cidr('10.0.0.0/8').containsIP('not-an-ip')", + errContains: "invalid ip address", + }, + { + name: "containsCIDR string overload invalid", + expr: "cidr('10.0.0.0/8').containsCIDR('not-a-cidr')", + errContains: "invalid cidr range", + }, + } + + env, err := cel.NewEnv(Network()) + if err != nil { + t.Fatalf("cel.NewEnv(Network()) failed: %v", err) + } + + for _, tst := range tests { + t.Run(tst.name, func(t *testing.T) { + ast, iss := env.Compile(tst.expr) + if iss.Err() != nil { + // Note: We only check runtime errors here. Compile errors are unexpected + // because these functions accept strings, so type-check passes. + t.Fatalf("Compile(%q) failed unexpectedly: %v", tst.expr, iss.Err()) + } + + prg, err := env.Program(ast) + if err != nil { + t.Fatalf("Program(%q) failed: %v", tst.expr, err) + } + + _, _, err = prg.Eval(cel.NoVars()) + if err == nil { + t.Errorf("Expected runtime error for %q, got nil", tst.expr) + return + } + + // CEL errors are sometimes wrapped, so we check substring + if !types.IsError(types.NewErr(err.Error())) { + // Just a sanity check that it is indeed a CEL-compatible error structure + // Not strictly necessary but good practice + } + + // Standard substring check + gotErr := err.Error() + // We just check if the message contains the specific error text we return in network.go + found := false + // Note: The actual error might be wrapped in "evaluation error: ..." + if len(tst.errContains) > 0 { + // Simple string contains check + for i := 0; i < len(gotErr)-len(tst.errContains)+1; i++ { + if gotErr[i:i+len(tst.errContains)] == tst.errContains { + found = true + break + } + } + } + + if !found { + t.Errorf("Expected error containing %q, got %q", tst.errContains, gotErr) + } + }) + } +} From 008905d96f711980b7b140209717eb6e86fd0db0 Mon Sep 17 00:00:00 2001 From: Thomas Desrosiers Date: Thu, 20 Nov 2025 11:25:08 -0500 Subject: [PATCH 2/4] adding type adapter and exporting types IP and CIDR --- ext/network.go | 125 ++++++++++++++++++++++++++----------------------- 1 file changed, 67 insertions(+), 58 deletions(-) diff --git a/ext/network.go b/ext/network.go index 32a5b091..273d5a56 100644 --- a/ext/network.go +++ b/ext/network.go @@ -137,8 +137,8 @@ const ( var ( // Definitions for the Opaque Types - networkIPType = cel.ObjectType("network.IP", traits.ReceiverType) - networkCIDRType = cel.ObjectType("network.CIDR", traits.ReceiverType) + networkIPType = types.NewTypeValue("network.IP", traits.ReceiverType) + networkCIDRType = types.NewTypeValue("network.CIDR", traits.ReceiverType) ) type networkLib struct{} @@ -147,13 +147,18 @@ func (*networkLib) LibraryName() string { return "cel.lib.ext.network" } -// CompileOptions implements the Library interface method. func (*networkLib) CompileOptions() []cel.EnvOption { return []cel.EnvOption{ + // 1. Register the types cel.Types( networkIPType, networkCIDRType, ), + // 2. Register the Adapter (Correctly placed here) + cel.CustomTypeAdapter(&networkAdapter{ + Adapter: types.DefaultTypeAdapter, + }), + // 3. Register the Functions // Global Checkers cel.Function(isIPFunc, cel.Overload("isIP_string", []*cel.Type{cel.StringType}, cel.BoolType, @@ -163,7 +168,6 @@ func (*networkLib) CompileOptions() []cel.EnvOption { cel.Overload("isCIDR_string", []*cel.Type{cel.StringType}, cel.BoolType, cel.UnaryBinding(netIsCIDR)), ), - // Constructors cel.Function(ipFunc, cel.Overload("ip_string", []*cel.Type{cel.StringType}, networkIPType, @@ -173,7 +177,6 @@ func (*networkLib) CompileOptions() []cel.EnvOption { cel.Overload("cidr_string", []*cel.Type{cel.StringType}, networkCIDRType, cel.UnaryBinding(netCIDRString)), ), - // IP Member Functions cel.Function(familyFunc, cel.MemberOverload("ip_family", []*cel.Type{networkIPType}, cel.IntType, @@ -203,7 +206,6 @@ func (*networkLib) CompileOptions() []cel.EnvOption { cel.MemberOverload("ip_isLinkLocalUnicast", []*cel.Type{networkIPType}, cel.BoolType, cel.UnaryBinding(netIPIsLinkLocalUnicast)), ), - // CIDR Member Functions cel.Function(containsIPFunc, cel.MemberOverload("cidr_containsIP_ip", []*cel.Type{networkCIDRType, networkIPType}, cel.BoolType, @@ -236,6 +238,22 @@ func (*networkLib) ProgramOptions() []cel.ProgramOption { return []cel.ProgramOption{} } +// networkAdapter implements types.Adapter to handle net.IP and *net.IPNet conversion. +type networkAdapter struct { + types.Adapter +} + +func (a *networkAdapter) NativeToValue(value any) ref.Val { + switch v := value.(type) { + case net.IP: + // If passing a raw net.IP, we assume standard string representation + return IP{IP: v, Str: v.String()} + case *net.IPNet: + return CIDR{IPNet: v, Str: v.String()} + } + return a.Adapter.NativeToValue(value) +} + // --- Implementation Logic --- func netIsIP(val ref.Val) ref.Val { @@ -265,8 +283,7 @@ func netIPString(val ref.Val) ref.Val { if ip == nil { return types.NewErr("invalid ip address: %s", str) } - // Store both the logic (IP) and the representation (str) - return ipValue{IP: ip, str: str} + return IP{IP: ip, Str: str} } func netCIDRString(val ref.Val) ref.Val { @@ -275,16 +292,16 @@ func netCIDRString(val ref.Val) ref.Val { return types.MaybeNoSuchOverloadErr(val) } str := string(s) - // CHANGE: Capture 'ip' (the specific address) alongside 'ipNet' ip, ipNet, err := net.ParseCIDR(str) if err != nil { return types.NewErr("invalid cidr range: %s", str) } - return cidrValue{IPNet: ipNet, ExtraIP: ip, str: str} + // Store the specific IP (which might have host bits set) alongside the network + return CIDR{IPNet: ipNet, ExtraIP: ip, Str: str} } func netIPFamily(val ref.Val) ref.Val { - ip, ok := val.(ipValue) + ip, ok := val.(IP) if !ok { return types.MaybeNoSuchOverloadErr(val) } @@ -295,16 +312,15 @@ func netIPFamily(val ref.Val) ref.Val { } func netIPIsCanonical(val ref.Val) ref.Val { - ip, ok := val.(ipValue) + ip, ok := val.(IP) if !ok { return types.MaybeNoSuchOverloadErr(val) } - // 1-to-1 Parity: Check if the input string matches the standard library output - return types.Bool(ip.str == ip.IP.String()) + return types.Bool(ip.Str == ip.IP.String()) } func netIPIsLoopback(val ref.Val) ref.Val { - ip, ok := val.(ipValue) + ip, ok := val.(IP) if !ok { return types.MaybeNoSuchOverloadErr(val) } @@ -312,7 +328,7 @@ func netIPIsLoopback(val ref.Val) ref.Val { } func netIPIsGlobalUnicast(val ref.Val) ref.Val { - ip, ok := val.(ipValue) + ip, ok := val.(IP) if !ok { return types.MaybeNoSuchOverloadErr(val) } @@ -320,7 +336,7 @@ func netIPIsGlobalUnicast(val ref.Val) ref.Val { } func netIPIsUnspecified(val ref.Val) ref.Val { - ip, ok := val.(ipValue) + ip, ok := val.(IP) if !ok { return types.MaybeNoSuchOverloadErr(val) } @@ -328,7 +344,7 @@ func netIPIsUnspecified(val ref.Val) ref.Val { } func netIPIsLinkLocalMulticast(val ref.Val) ref.Val { - ip, ok := val.(ipValue) + ip, ok := val.(IP) if !ok { return types.MaybeNoSuchOverloadErr(val) } @@ -336,7 +352,7 @@ func netIPIsLinkLocalMulticast(val ref.Val) ref.Val { } func netIPIsLinkLocalUnicast(val ref.Val) ref.Val { - ip, ok := val.(ipValue) + ip, ok := val.(IP) if !ok { return types.MaybeNoSuchOverloadErr(val) } @@ -344,11 +360,11 @@ func netIPIsLinkLocalUnicast(val ref.Val) ref.Val { } func netCIDRContainsIP(lhs, rhs ref.Val) ref.Val { - cidr, ok := lhs.(cidrValue) + cidr, ok := lhs.(CIDR) if !ok { return types.MaybeNoSuchOverloadErr(lhs) } - ip, ok := rhs.(ipValue) + ip, ok := rhs.(IP) if !ok { return types.MaybeNoSuchOverloadErr(rhs) } @@ -356,7 +372,7 @@ func netCIDRContainsIP(lhs, rhs ref.Val) ref.Val { } func netCIDRContainsIPString(lhs, rhs ref.Val) ref.Val { - cidr, ok := lhs.(cidrValue) + cidr, ok := lhs.(CIDR) if !ok { return types.MaybeNoSuchOverloadErr(lhs) } @@ -372,11 +388,11 @@ func netCIDRContainsIPString(lhs, rhs ref.Val) ref.Val { } func netCIDRContainsCIDR(lhs, rhs ref.Val) ref.Val { - parent, ok := lhs.(cidrValue) + parent, ok := lhs.(CIDR) if !ok { return types.MaybeNoSuchOverloadErr(lhs) } - child, ok := rhs.(cidrValue) + child, ok := rhs.(CIDR) if !ok { return types.MaybeNoSuchOverloadErr(rhs) } @@ -386,7 +402,7 @@ func netCIDRContainsCIDR(lhs, rhs ref.Val) ref.Val { } func netCIDRContainsCIDRString(lhs, rhs ref.Val) ref.Val { - parent, ok := lhs.(cidrValue) + parent, ok := lhs.(CIDR) if !ok { return types.MaybeNoSuchOverloadErr(lhs) } @@ -404,20 +420,17 @@ func netCIDRContainsCIDRString(lhs, rhs ref.Val) ref.Val { } func netCIDRMasked(val ref.Val) ref.Val { - cidr, ok := val.(cidrValue) + cidr, ok := val.(CIDR) if !ok { return types.MaybeNoSuchOverloadErr(val) } - // Mask both the network struct AND our extra IP maskedIP := cidr.IPNet.IP.Mask(cidr.IPNet.Mask) newNet := &net.IPNet{IP: maskedIP, Mask: cidr.IPNet.Mask} - - // In a masked CIDR, the specific IP is the same as the network IP - return cidrValue{IPNet: newNet, ExtraIP: maskedIP, str: newNet.String()} + return CIDR{IPNet: newNet, ExtraIP: maskedIP, Str: newNet.String()} } func netCIDRPrefixLength(val ref.Val) ref.Val { - cidr, ok := val.(cidrValue) + cidr, ok := val.(CIDR) if !ok { return types.MaybeNoSuchOverloadErr(val) } @@ -426,24 +439,23 @@ func netCIDRPrefixLength(val ref.Val) ref.Val { } func netCIDRIP(val ref.Val) ref.Val { - cidr, ok := val.(cidrValue) + cidr, ok := val.(CIDR) if !ok { return types.MaybeNoSuchOverloadErr(val) } - // Extract IP. Use String() to ensure we have a valid canonical string representation - // for the new IP object. Return the specific host IP, not the network address - return ipValue{IP: cidr.ExtraIP, str: cidr.ExtraIP.String()} + return IP{IP: cidr.ExtraIP, Str: cidr.ExtraIP.String()} } // --- Opaque Type Wrappers --- -// ipValue implements ref.Val -type ipValue struct { +// IP is an exported CEL value that wraps net.IP. +// It implements ref.Val. +type IP struct { net.IP - str string // Kept for isCanonical checks + Str string } -func (i ipValue) ConvertToNative(typeDesc reflect.Type) (any, error) { +func (i IP) ConvertToNative(typeDesc reflect.Type) (any, error) { if typeDesc == reflect.TypeOf(net.IP{}) { return i.IP, nil } @@ -453,7 +465,7 @@ func (i ipValue) ConvertToNative(typeDesc reflect.Type) (any, error) { return nil, fmt.Errorf("unsupported type conversion to '%v'", typeDesc) } -func (i ipValue) ConvertToType(typeValue ref.Type) ref.Val { +func (i IP) ConvertToType(typeValue ref.Type) ref.Val { switch typeValue { case types.StringType: return types.String(i.IP.String()) @@ -465,33 +477,31 @@ func (i ipValue) ConvertToType(typeValue ref.Type) ref.Val { return types.NewErr("type conversion error from '%s' to '%s'", networkIPType, typeValue) } -func (i ipValue) Equal(other ref.Val) ref.Val { - o, ok := other.(ipValue) +func (i IP) Equal(other ref.Val) ref.Val { + o, ok := other.(IP) if !ok { return types.ValOrErr(other, "no such overload") } - // Correctness: Equality is based on the actual IP bytes, - // NOT the string representation. - // ip("127.0.0.1") == ip("127.000.000.001") -> True return types.Bool(i.IP.Equal(o.IP)) } -func (i ipValue) Type() ref.Type { +func (i IP) Type() ref.Type { return networkIPType } -func (i ipValue) Value() any { +func (i IP) Value() any { return i.IP } -// cidrValue implements ref.Val -type cidrValue struct { +// CIDR is an exported CEL value that wraps *net.IPNet. +// It implements ref.Val. +type CIDR struct { *net.IPNet - ExtraIP net.IP // IP Address with host bits set - str string + ExtraIP net.IP + Str string } -func (c cidrValue) ConvertToNative(typeDesc reflect.Type) (any, error) { +func (c CIDR) ConvertToNative(typeDesc reflect.Type) (any, error) { if typeDesc == reflect.TypeOf(&net.IPNet{}) { return c.IPNet, nil } @@ -501,7 +511,7 @@ func (c cidrValue) ConvertToNative(typeDesc reflect.Type) (any, error) { return nil, fmt.Errorf("unsupported type conversion to '%v'", typeDesc) } -func (c cidrValue) ConvertToType(typeValue ref.Type) ref.Val { +func (c CIDR) ConvertToType(typeValue ref.Type) ref.Val { switch typeValue { case types.StringType: return types.String(c.IPNet.String()) @@ -513,19 +523,18 @@ func (c cidrValue) ConvertToType(typeValue ref.Type) ref.Val { return types.NewErr("type conversion error from '%s' to '%s'", networkCIDRType, typeValue) } -func (c cidrValue) Equal(other ref.Val) ref.Val { - o, ok := other.(cidrValue) +func (c CIDR) Equal(other ref.Val) ref.Val { + o, ok := other.(CIDR) if !ok { return types.ValOrErr(other, "no such overload") } - // Correctness: Equality is based on the IP bytes AND the mask bytes. return types.Bool(c.IPNet.IP.Equal(o.IPNet.IP) && c.IPNet.Mask.String() == o.IPNet.Mask.String()) } -func (c cidrValue) Type() ref.Type { +func (c CIDR) Type() ref.Type { return networkCIDRType } -func (c cidrValue) Value() any { +func (c CIDR) Value() any { return c.IPNet } From 7976e9fa20def06f794f47e713164e683f8cbaed Mon Sep 17 00:00:00 2001 From: Thomas Desrosiers Date: Thu, 20 Nov 2025 13:14:59 -0500 Subject: [PATCH 3/4] upgrade network to use net/netip to more appropriately mirror k8s/apiserver implementation --- ext/network.go | 285 ++++++++++++++++++-------------------------- ext/network_test.go | 18 +-- 2 files changed, 119 insertions(+), 184 deletions(-) diff --git a/ext/network.go b/ext/network.go index 273d5a56..98b32e38 100644 --- a/ext/network.go +++ b/ext/network.go @@ -16,7 +16,7 @@ package ext import ( "fmt" - "net" + "net/netip" "reflect" "github.com/google/cel-go/cel" @@ -28,14 +28,17 @@ import ( // Network returns a cel.EnvOption to configure extended functions for network // address parsing, inspection, and CIDR range manipulation. // -// Note: This library defines global functions `ip`, `cidr`, `isIP`, and -// `isCIDR`. If you are currently using variables named `ip` or `cidr`, these -// functions will likely work as intended, however there is a chance for -// collision. +// Note: This library defines global functions `ip`, `cidr`, `isIP`, `isCIDR` +// and `ip.isCanonical`. If you are currently using variables named `ip` or +// `cidr`, these functions will likely work as intended, however there is a +// chance for collision. // // The library closely mirrors the behavior of the Kubernetes CEL network -// libraries, treating IP addresses and CIDR ranges as opaque types with -// specific member functions. +// libraries, treating IP addresses and CIDR ranges as opaque types. It parses +// IPs strictly: IPv4-mapped IPv6 addresses and IP zones are not allowed. +// +// This library includes a TypeAdapter that allows `netip.Addr` and +// `netip.Prefix` Go types to be passed directly into the CEL environment. // // # IP Addresses // @@ -68,30 +71,30 @@ import ( // cidr('::1/128') // isCIDR('10.0.0.0/8') // true // -// # IP Member Functions +// # IP Inspection and Canonicalization // // IP objects support various inspection methods. // // .family() -> int -// .isCanonical() -> bool // .isLoopback() -> bool // .isGlobalUnicast() -> bool // .isLinkLocalMulticast() -> bool // .isLinkLocalUnicast() -> bool // .isUnspecified() -> bool // -// Note on Canonicalization: `isCanonical()` returns true if the input string -// used to construct the IP matches the RFC 5952 canonical string representation -// of that address. +// The `ip.isCanonical` function takes a string and returns true if it matches +// the RFC 5952 canonical string representation of that address. +// +// ip.isCanonical(string) -> bool // // Examples: // // ip('127.0.0.1').family() == 4 // ip('::1').family() == 6 // ip('127.0.0.1').isLoopback() == true -// ip('2001:db8::1').isCanonical() == true // RFC 5952 format -// ip('2001:DB8::1').isCanonical() == false // Uppercase is not canonical -// ip('2001:db8:0:0:0:0:0:1').isCanonical() == false // Expanded is not canonical +// ip.isCanonical('2001:db8::1') == true // RFC 5952 format +// ip.isCanonical('2001:DB8::1') == false // Uppercase is not canonical +// ip.isCanonical('2001:db8:0:0:0:0:0:1') == false // Expanded is not canonical // // # CIDR Member Functions // @@ -122,7 +125,7 @@ const ( ipFunc = "ip" cidrFunc = "cidr" familyFunc = "family" - isCanonicalFunc = "isCanonical" + isCanonicalFunc = "ip.isCanonical" isLoopbackFunc = "isLoopback" isGlobalUnicastFunc = "isGlobalUnicast" isUnspecifiedFunc = "isUnspecified" @@ -149,17 +152,16 @@ func (*networkLib) LibraryName() string { func (*networkLib) CompileOptions() []cel.EnvOption { return []cel.EnvOption{ - // 1. Register the types + // 1. Register Types cel.Types( networkIPType, networkCIDRType, ), - // 2. Register the Adapter (Correctly placed here) + // 2. Register Adapter (Bundled here so it applies automatically) cel.CustomTypeAdapter(&networkAdapter{ Adapter: types.DefaultTypeAdapter, }), - // 3. Register the Functions - // Global Checkers + // 3. Register Functions cel.Function(isIPFunc, cel.Overload("isIP_string", []*cel.Type{cel.StringType}, cel.BoolType, cel.UnaryBinding(netIsIP)), @@ -168,7 +170,6 @@ func (*networkLib) CompileOptions() []cel.EnvOption { cel.Overload("isCIDR_string", []*cel.Type{cel.StringType}, cel.BoolType, cel.UnaryBinding(netIsCIDR)), ), - // Constructors cel.Function(ipFunc, cel.Overload("ip_string", []*cel.Type{cel.StringType}, networkIPType, cel.UnaryBinding(netIPString)), @@ -177,13 +178,12 @@ func (*networkLib) CompileOptions() []cel.EnvOption { cel.Overload("cidr_string", []*cel.Type{cel.StringType}, networkCIDRType, cel.UnaryBinding(netCIDRString)), ), - // IP Member Functions cel.Function(familyFunc, cel.MemberOverload("ip_family", []*cel.Type{networkIPType}, cel.IntType, cel.UnaryBinding(netIPFamily)), ), cel.Function(isCanonicalFunc, - cel.MemberOverload("ip_isCanonical", []*cel.Type{networkIPType}, cel.BoolType, + cel.Overload("ip_isCanonical_string", []*cel.Type{cel.StringType}, cel.BoolType, cel.UnaryBinding(netIPIsCanonical)), ), cel.Function(isLoopbackFunc, @@ -206,7 +206,6 @@ func (*networkLib) CompileOptions() []cel.EnvOption { cel.MemberOverload("ip_isLinkLocalUnicast", []*cel.Type{networkIPType}, cel.BoolType, cel.UnaryBinding(netIPIsLinkLocalUnicast)), ), - // CIDR Member Functions cel.Function(containsIPFunc, cel.MemberOverload("cidr_containsIP_ip", []*cel.Type{networkCIDRType, networkIPType}, cel.BoolType, cel.BinaryBinding(netCIDRContainsIP)), @@ -238,229 +237,173 @@ func (*networkLib) ProgramOptions() []cel.ProgramOption { return []cel.ProgramOption{} } -// networkAdapter implements types.Adapter to handle net.IP and *net.IPNet conversion. +// networkAdapter adapts netip types. type networkAdapter struct { types.Adapter } func (a *networkAdapter) NativeToValue(value any) ref.Val { switch v := value.(type) { - case net.IP: - // If passing a raw net.IP, we assume standard string representation - return IP{IP: v, Str: v.String()} - case *net.IPNet: - return CIDR{IPNet: v, Str: v.String()} + case netip.Addr: + return IP{Addr: v} + case netip.Prefix: + return CIDR{Prefix: v} } return a.Adapter.NativeToValue(value) } // --- Implementation Logic --- -func netIsIP(val ref.Val) ref.Val { - s, ok := val.(types.String) - if !ok { - return types.MaybeNoSuchOverloadErr(val) +func parseIPAddr(raw string) (netip.Addr, error) { + addr, err := netip.ParseAddr(raw) + if err != nil { + return netip.Addr{}, fmt.Errorf("IP Address %q parse error: %v", raw, err) + } + if addr.Zone() != "" { + return netip.Addr{}, fmt.Errorf("IP address %q with zone value is not allowed", raw) + } + if addr.Is4In6() { + return netip.Addr{}, fmt.Errorf("IPv4-mapped IPv6 address %q is not allowed", raw) } - return types.Bool(net.ParseIP(string(s)) != nil) + return addr, nil +} + +func netIsIP(val ref.Val) ref.Val { + s := val.(types.String) + _, err := parseIPAddr(string(s)) + return types.Bool(err == nil) } func netIsCIDR(val ref.Val) ref.Val { - s, ok := val.(types.String) - if !ok { - return types.MaybeNoSuchOverloadErr(val) - } - _, _, err := net.ParseCIDR(string(s)) + s := val.(types.String) + _, err := netip.ParsePrefix(string(s)) return types.Bool(err == nil) } func netIPString(val ref.Val) ref.Val { - s, ok := val.(types.String) - if !ok { - return types.MaybeNoSuchOverloadErr(val) - } + s := val.(types.String) str := string(s) - ip := net.ParseIP(str) - if ip == nil { - return types.NewErr("invalid ip address: %s", str) + addr, err := parseIPAddr(str) + if err != nil { + return types.NewErr("%v", err) } - return IP{IP: ip, Str: str} + return IP{Addr: addr} } func netCIDRString(val ref.Val) ref.Val { - s, ok := val.(types.String) - if !ok { - return types.MaybeNoSuchOverloadErr(val) - } + s := val.(types.String) str := string(s) - ip, ipNet, err := net.ParseCIDR(str) + prefix, err := netip.ParsePrefix(str) if err != nil { return types.NewErr("invalid cidr range: %s", str) } - // Store the specific IP (which might have host bits set) alongside the network - return CIDR{IPNet: ipNet, ExtraIP: ip, Str: str} + return CIDR{Prefix: prefix} } func netIPFamily(val ref.Val) ref.Val { - ip, ok := val.(IP) - if !ok { - return types.MaybeNoSuchOverloadErr(val) - } - if ip.IP.To4() != nil { + ip := val.(IP) + if ip.Addr.Is4() { return types.Int(4) } return types.Int(6) } func netIPIsCanonical(val ref.Val) ref.Val { - ip, ok := val.(IP) - if !ok { - return types.MaybeNoSuchOverloadErr(val) + s := val.(types.String) + str := string(s) + addr, err := parseIPAddr(str) + if err != nil { + return types.NewErr("%v", err) } - return types.Bool(ip.Str == ip.IP.String()) + return types.Bool(addr.String() == str) } func netIPIsLoopback(val ref.Val) ref.Val { - ip, ok := val.(IP) - if !ok { - return types.MaybeNoSuchOverloadErr(val) - } - return types.Bool(ip.IP.IsLoopback()) + ip := val.(IP) + return types.Bool(ip.Addr.IsLoopback()) } func netIPIsGlobalUnicast(val ref.Val) ref.Val { - ip, ok := val.(IP) - if !ok { - return types.MaybeNoSuchOverloadErr(val) - } - return types.Bool(ip.IP.IsGlobalUnicast()) + ip := val.(IP) + return types.Bool(ip.Addr.IsGlobalUnicast()) } func netIPIsUnspecified(val ref.Val) ref.Val { - ip, ok := val.(IP) - if !ok { - return types.MaybeNoSuchOverloadErr(val) - } - return types.Bool(ip.IP.IsUnspecified()) + ip := val.(IP) + return types.Bool(ip.Addr.IsUnspecified()) } func netIPIsLinkLocalMulticast(val ref.Val) ref.Val { - ip, ok := val.(IP) - if !ok { - return types.MaybeNoSuchOverloadErr(val) - } - return types.Bool(ip.IP.IsLinkLocalMulticast()) + ip := val.(IP) + return types.Bool(ip.Addr.IsLinkLocalMulticast()) } func netIPIsLinkLocalUnicast(val ref.Val) ref.Val { - ip, ok := val.(IP) - if !ok { - return types.MaybeNoSuchOverloadErr(val) - } - return types.Bool(ip.IP.IsLinkLocalUnicast()) + ip := val.(IP) + return types.Bool(ip.Addr.IsLinkLocalUnicast()) } func netCIDRContainsIP(lhs, rhs ref.Val) ref.Val { - cidr, ok := lhs.(CIDR) - if !ok { - return types.MaybeNoSuchOverloadErr(lhs) - } - ip, ok := rhs.(IP) - if !ok { - return types.MaybeNoSuchOverloadErr(rhs) - } - return types.Bool(cidr.IPNet.Contains(ip.IP)) + cidr := lhs.(CIDR) + ip := rhs.(IP) + return types.Bool(cidr.Prefix.Contains(ip.Addr)) } func netCIDRContainsIPString(lhs, rhs ref.Val) ref.Val { - cidr, ok := lhs.(CIDR) - if !ok { - return types.MaybeNoSuchOverloadErr(lhs) - } - s, ok := rhs.(types.String) - if !ok { - return types.MaybeNoSuchOverloadErr(rhs) - } - ip := net.ParseIP(string(s)) - if ip == nil { - return types.NewErr("invalid ip address: %s", s) + cidr := lhs.(CIDR) + s := rhs.(types.String) + addr, err := parseIPAddr(string(s)) + if err != nil { + return types.NewErr("%v", err) } - return types.Bool(cidr.IPNet.Contains(ip)) + return types.Bool(cidr.Prefix.Contains(addr)) } func netCIDRContainsCIDR(lhs, rhs ref.Val) ref.Val { - parent, ok := lhs.(CIDR) - if !ok { - return types.MaybeNoSuchOverloadErr(lhs) - } - child, ok := rhs.(CIDR) - if !ok { - return types.MaybeNoSuchOverloadErr(rhs) - } - ones1, _ := parent.IPNet.Mask.Size() - ones2, _ := child.IPNet.Mask.Size() - return types.Bool(parent.IPNet.Contains(child.IPNet.IP) && ones1 <= ones2) + parent := lhs.(CIDR) + child := rhs.(CIDR) + return types.Bool(parent.Prefix.Overlaps(child.Prefix) && parent.Prefix.Bits() <= child.Prefix.Bits()) } func netCIDRContainsCIDRString(lhs, rhs ref.Val) ref.Val { - parent, ok := lhs.(CIDR) - if !ok { - return types.MaybeNoSuchOverloadErr(lhs) - } - s, ok := rhs.(types.String) - if !ok { - return types.MaybeNoSuchOverloadErr(rhs) - } - _, childNet, err := net.ParseCIDR(string(s)) + parent := lhs.(CIDR) + s := rhs.(types.String) + childPrefix, err := netip.ParsePrefix(string(s)) if err != nil { return types.NewErr("invalid cidr range: %s", s) } - ones1, _ := parent.IPNet.Mask.Size() - ones2, _ := childNet.Mask.Size() - return types.Bool(parent.IPNet.Contains(childNet.IP) && ones1 <= ones2) + return types.Bool(parent.Prefix.Overlaps(childPrefix) && parent.Prefix.Bits() <= childPrefix.Bits()) } func netCIDRMasked(val ref.Val) ref.Val { - cidr, ok := val.(CIDR) - if !ok { - return types.MaybeNoSuchOverloadErr(val) - } - maskedIP := cidr.IPNet.IP.Mask(cidr.IPNet.Mask) - newNet := &net.IPNet{IP: maskedIP, Mask: cidr.IPNet.Mask} - return CIDR{IPNet: newNet, ExtraIP: maskedIP, Str: newNet.String()} + cidr := val.(CIDR) + return CIDR{Prefix: cidr.Prefix.Masked()} } func netCIDRPrefixLength(val ref.Val) ref.Val { - cidr, ok := val.(CIDR) - if !ok { - return types.MaybeNoSuchOverloadErr(val) - } - ones, _ := cidr.IPNet.Mask.Size() - return types.Int(ones) + cidr := val.(CIDR) + return types.Int(cidr.Prefix.Bits()) } func netCIDRIP(val ref.Val) ref.Val { - cidr, ok := val.(CIDR) - if !ok { - return types.MaybeNoSuchOverloadErr(val) - } - return IP{IP: cidr.ExtraIP, Str: cidr.ExtraIP.String()} + cidr := val.(CIDR) + return IP{Addr: cidr.Prefix.Addr()} } // --- Opaque Type Wrappers --- -// IP is an exported CEL value that wraps net.IP. -// It implements ref.Val. +// IP is an exported CEL value that wraps netip.Addr. type IP struct { - net.IP - Str string + netip.Addr } func (i IP) ConvertToNative(typeDesc reflect.Type) (any, error) { - if typeDesc == reflect.TypeOf(net.IP{}) { - return i.IP, nil + // Use reflect.TypeFor to avoid instantiating netip.Addr{} + if typeDesc == reflect.TypeFor[netip.Addr]() { + return i.Addr, nil } if typeDesc.Kind() == reflect.String { - return i.IP.String(), nil + return i.Addr.String(), nil } return nil, fmt.Errorf("unsupported type conversion to '%v'", typeDesc) } @@ -468,7 +411,7 @@ func (i IP) ConvertToNative(typeDesc reflect.Type) (any, error) { func (i IP) ConvertToType(typeValue ref.Type) ref.Val { switch typeValue { case types.StringType: - return types.String(i.IP.String()) + return types.String(i.Addr.String()) case networkIPType: return i case types.TypeType: @@ -482,7 +425,7 @@ func (i IP) Equal(other ref.Val) ref.Val { if !ok { return types.ValOrErr(other, "no such overload") } - return types.Bool(i.IP.Equal(o.IP)) + return types.Bool(i.Addr == o.Addr) } func (i IP) Type() ref.Type { @@ -490,23 +433,21 @@ func (i IP) Type() ref.Type { } func (i IP) Value() any { - return i.IP + return i.Addr } -// CIDR is an exported CEL value that wraps *net.IPNet. -// It implements ref.Val. +// CIDR is an exported CEL value that wraps netip.Prefix. type CIDR struct { - *net.IPNet - ExtraIP net.IP - Str string + netip.Prefix } func (c CIDR) ConvertToNative(typeDesc reflect.Type) (any, error) { - if typeDesc == reflect.TypeOf(&net.IPNet{}) { - return c.IPNet, nil + // Use reflect.TypeFor to avoid instantiating netip.Prefix{} + if typeDesc == reflect.TypeFor[netip.Prefix]() { + return c.Prefix, nil } if typeDesc.Kind() == reflect.String { - return c.IPNet.String(), nil + return c.Prefix.String(), nil } return nil, fmt.Errorf("unsupported type conversion to '%v'", typeDesc) } @@ -514,7 +455,7 @@ func (c CIDR) ConvertToNative(typeDesc reflect.Type) (any, error) { func (c CIDR) ConvertToType(typeValue ref.Type) ref.Val { switch typeValue { case types.StringType: - return types.String(c.IPNet.String()) + return types.String(c.Prefix.String()) case networkCIDRType: return c case types.TypeType: @@ -528,7 +469,7 @@ func (c CIDR) Equal(other ref.Val) ref.Val { if !ok { return types.ValOrErr(other, "no such overload") } - return types.Bool(c.IPNet.IP.Equal(o.IPNet.IP) && c.IPNet.Mask.String() == o.IPNet.Mask.String()) + return types.Bool(c.Prefix == o.Prefix) } func (c CIDR) Type() ref.Type { @@ -536,5 +477,5 @@ func (c CIDR) Type() ref.Type { } func (c CIDR) Value() any { - return c.IPNet + return c.Prefix } diff --git a/ext/network_test.go b/ext/network_test.go index 8b2ab757..017f57f6 100644 --- a/ext/network_test.go +++ b/ext/network_test.go @@ -95,28 +95,22 @@ func TestNetwork_Success(t *testing.T) { // --- Canonicalization (Critical Feature) --- { name: "isCanonical IPv4 simple", - expr: "ip('127.0.0.1').isCanonical()", + expr: "ip.isCanonical('127.0.0.1')", out: true, }, - { - // ::ffff:127.0.0.1 is valid, but canonical form is 127.0.0.1 - name: "isCanonical IPv4-mapped IPv6 (valid but non-canonical)", - expr: "ip('::ffff:127.0.0.1').isCanonical()", - out: false, - }, { name: "isCanonical IPv6 standard", - expr: "ip('2001:db8::1').isCanonical()", + expr: "ip.isCanonical('2001:db8::1')", out: true, }, { name: "isCanonical IPv6 uppercase (invalid)", - expr: "ip('2001:DB8::1').isCanonical()", + expr: "ip.isCanonical('2001:DB8::1')", out: false, }, { name: "isCanonical IPv6 expanded (invalid)", - expr: "ip('2001:db8:0:0:0:0:0:1').isCanonical()", + expr: "ip.isCanonical('2001:db8:0:0:0:0:0:1')", out: false, }, @@ -287,7 +281,7 @@ func TestNetwork_RuntimeErrors(t *testing.T) { { name: "ip constructor invalid", expr: "ip('999.999.999.999')", - errContains: "invalid ip address", + errContains: "parse error", }, { name: "cidr constructor invalid", @@ -302,7 +296,7 @@ func TestNetwork_RuntimeErrors(t *testing.T) { { name: "containsIP string overload invalid", expr: "cidr('10.0.0.0/8').containsIP('not-an-ip')", - errContains: "invalid ip address", + errContains: "parse error", }, { name: "containsCIDR string overload invalid", From 8544fd52c38e763e93fdf98fcbdf3894438d7d58 Mon Sep 17 00:00:00 2001 From: Thomas Desrosiers Date: Mon, 24 Nov 2025 14:02:53 -0500 Subject: [PATCH 4/4] preserve chain of type adapters with wrapping, use opaque types for IP and CIDR --- ext/network.go | 27 +++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/ext/network.go b/ext/network.go index 98b32e38..3465f07f 100644 --- a/ext/network.go +++ b/ext/network.go @@ -22,7 +22,6 @@ import ( "github.com/google/cel-go/cel" "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" - "github.com/google/cel-go/common/types/traits" ) // Network returns a cel.EnvOption to configure extended functions for network @@ -115,7 +114,17 @@ import ( // cidr('192.168.1.5/24').masked() == cidr('192.168.1.0/24') // cidr('192.168.1.0/24').prefixLength() == 24 func Network() cel.EnvOption { - return cel.Lib(&networkLib{}) + return func(e *cel.Env) (*cel.Env, error) { + // Install the library (Types and Functions) + e, err := cel.Lib(&networkLib{})(e) + if err != nil { + return nil, err + } + + // Install the Adapter (Wrapping the existing one) + adapter := &networkAdapter{Adapter: e.CELTypeAdapter()} + return cel.CustomTypeAdapter(adapter)(e) + } } const ( @@ -140,8 +149,8 @@ const ( var ( // Definitions for the Opaque Types - networkIPType = types.NewTypeValue("network.IP", traits.ReceiverType) - networkCIDRType = types.NewTypeValue("network.CIDR", traits.ReceiverType) + networkIPType = types.NewOpaqueType("network.IP") + networkCIDRType = types.NewOpaqueType("network.CIDR") ) type networkLib struct{} @@ -157,11 +166,8 @@ func (*networkLib) CompileOptions() []cel.EnvOption { networkIPType, networkCIDRType, ), - // 2. Register Adapter (Bundled here so it applies automatically) - cel.CustomTypeAdapter(&networkAdapter{ - Adapter: types.DefaultTypeAdapter, - }), - // 3. Register Functions + + // 2. Register Functions cel.Function(isIPFunc, cel.Overload("isIP_string", []*cel.Type{cel.StringType}, cel.BoolType, cel.UnaryBinding(netIsIP)), @@ -237,7 +243,7 @@ func (*networkLib) ProgramOptions() []cel.ProgramOption { return []cel.ProgramOption{} } -// networkAdapter adapts netip types. +// networkAdapter adapts netip types while preserving existing adapters. type networkAdapter struct { types.Adapter } @@ -249,6 +255,7 @@ func (a *networkAdapter) NativeToValue(value any) ref.Val { case netip.Prefix: return CIDR{Prefix: v} } + // Delegate to the wrapped adapter (e.g., Protobuf adapter) return a.Adapter.NativeToValue(value) }