Skip to content

Commit a8dcfb8

Browse files
committed
test: add property-based tests for COG and bisector functions
1 parent fc9c10c commit a8dcfb8

File tree

1 file changed

+182
-0
lines changed

1 file changed

+182
-0
lines changed

tests/test_defuzz.py

+182
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
from __future__ import annotations
2+
3+
import math
4+
import timeit
5+
6+
import numpy as np
7+
import pytest
8+
from hypothesis import assume, given
9+
from hypothesis import strategies as st
10+
11+
from fuzzylogic.defuzz import (
12+
_get_max_points,
13+
bisector,
14+
cog,
15+
lom,
16+
mom,
17+
som,
18+
)
19+
from fuzzylogic.functions import Membership
20+
21+
# ---------------------------------------------------------------------------
22+
# Core Testing Infrastructure
23+
# ---------------------------------------------------------------------------
24+
25+
26+
class DummyDomain:
27+
"""Mock domain for testing fuzzy operations"""
28+
29+
def __init__(self, low: float, high: float, n_points: int = 101):
30+
assert low < high, "Invalid domain bounds"
31+
self._low = low
32+
self._high = high
33+
self._n_points = n_points
34+
35+
@property
36+
def range(self) -> list[float]:
37+
return np.linspace(self._low, self._high, self._n_points).tolist()
38+
39+
40+
class DummySet:
41+
"""Mock fuzzy set with configurable properties"""
42+
43+
def __init__(self, cog_value: float, membership_func: Membership | None = None):
44+
self._cog = cog_value
45+
self.membership_func = membership_func or (lambda x: 1.0)
46+
self.domain = None
47+
48+
def center_of_gravity(self) -> float:
49+
return self._cog
50+
51+
def __call__(self, x: float) -> float:
52+
return self.membership_func(x)
53+
54+
55+
# ---------------------------------------------------------------------------
56+
# Property-Based Tests
57+
# ---------------------------------------------------------------------------
58+
59+
60+
@given(
61+
cogs=st.lists(st.floats(min_value=-1e3, max_value=1e3), min_size=1, max_size=10),
62+
weights=st.lists(st.floats(min_value=0.1, max_value=1e3), min_size=1, max_size=10),
63+
domain=st.tuples(st.floats(min_value=-1e3), st.floats(min_value=-1e3)).filter(lambda x: x[0] < x[1]),
64+
)
65+
def test_cog_weighted_average_property(cogs: list[float], weights: list[float], domain: tuple[float, float]):
66+
"""Verify COG is proper weighted average of centroids"""
67+
assume(len(cogs) == len(weights))
68+
low, high = domain
69+
domain_obj = DummyDomain(low, high)
70+
71+
sets = [DummySet(cog) for cog in cogs]
72+
for s in sets:
73+
s.domain = domain_obj
74+
75+
target_weights = list(zip(sets, weights))
76+
result = cog(target_weights)
77+
78+
total_weight = sum(weights)
79+
expected = sum(c * w for c, w in zip(cogs, weights)) / total_weight
80+
assert math.isclose(result, expected, rel_tol=1e-5, abs_tol=1e-5)
81+
82+
83+
@given(
84+
peak=st.floats(allow_nan=False, allow_infinity=False),
85+
width=st.floats(min_value=0.1, max_value=100),
86+
domain=st.tuples(st.floats(), st.floats()).filter(lambda x: x[0] < x[1]),
87+
)
88+
def test_bisector_triangular_property(peak: float, width: float, domain: tuple[float, float]):
89+
"""Test bisector with generated triangular functions"""
90+
low, high = domain
91+
a = peak - width / 2
92+
b = peak
93+
c = peak + width / 2
94+
assume(low <= a < c <= high)
95+
96+
domain_obj = DummyDomain(low, high)
97+
points = domain_obj.range
98+
step = (high - low) / (len(points) - 1)
99+
100+
from fuzzylogic import functions
101+
102+
f = functions.triangular(a, c, c=b)
103+
104+
result = bisector(f, points, step)
105+
assert math.isclose(result, peak, rel_tol=0.01), f"Expected {peak}, got {result}"
106+
107+
108+
# ---------------------------------------------------------------------------
109+
# Edge Cases
110+
# ---------------------------------------------------------------------------
111+
112+
113+
@pytest.mark.parametrize("dtype, tol", [(np.float32, 1e-6), (np.float64, 1e-12), (np.longdouble, 1e-15)])
114+
def test_cog_precision(dtype, tol):
115+
"""Test numerical precision across data types"""
116+
domain = DummyDomain(0, 1, 1001)
117+
exact_val = dtype(0.5)
118+
fuzzy_set = DummySet(float(exact_val))
119+
fuzzy_set.domain = domain
120+
121+
result = cog([(fuzzy_set, 1.0)])
122+
assert abs(result - exact_val) < tol
123+
124+
125+
# ---------------------------------------------------------------------------
126+
# Performance
127+
# ---------------------------------------------------------------------------
128+
129+
130+
def test_cog_linear_scaling():
131+
"""Verify O(n) time complexity"""
132+
sizes = [100, 1000, 10000]
133+
times = []
134+
135+
# sourcery skip: no-loop-in-tests
136+
for _ in sizes:
137+
sets = [DummySet(0.5) for _ in range(10)]
138+
weights = [(s, 1.0) for s in sets]
139+
140+
t = timeit.timeit(lambda: cog(weights), number=10)
141+
times.append(t)
142+
143+
# Check linear correlation
144+
log_sizes = np.log(sizes)
145+
log_times = np.log(times)
146+
corr = np.corrcoef(log_sizes, log_times)[0, 1]
147+
assert corr > 0.95, f"Unexpected complexity (corr={corr:.2f})"
148+
149+
150+
# ---------------------------------------------------------------------------
151+
# Core Functionality
152+
# ---------------------------------------------------------------------------
153+
154+
155+
def test_mom_constant_membership():
156+
"""Test MOM with uniform maximum"""
157+
domain = DummyDomain(0, 10)
158+
points = domain.range
159+
result = mom(lambda _: 1.0, points)
160+
expected = (0 + 10) / 2
161+
assert math.isclose(result, expected)
162+
163+
164+
def test_som_lom_plateau():
165+
"""Test SOM/LOM with plateaued maximum"""
166+
domain = DummyDomain(0, 10)
167+
points = domain.range
168+
agg_mf = lambda x: 1.0 if 3 <= x <= 7 else 0.0
169+
170+
assert math.isclose(som(agg_mf, points), 3.0)
171+
assert math.isclose(lom(agg_mf, points), 7.0)
172+
173+
174+
def test_get_max_points():
175+
"""Test maximum point detection"""
176+
points = [0, 1, 2, 3, 4]
177+
agg_mf = lambda x: 1.0 if x == 2 else 0.5
178+
assert _get_max_points(agg_mf, points) == [2]
179+
180+
181+
if __name__ == "__main__":
182+
pytest.main([__file__, "-v", "--hypothesis-show-statistics"])

0 commit comments

Comments
 (0)