Skip to content

Commit 41d8731

Browse files
committed
refactor: annotations in estimating stuff
1 parent a2e24a0 commit 41d8731

File tree

3 files changed

+37
-27
lines changed

3 files changed

+37
-27
lines changed

src/fuzzylogic/estimate.py

+34-24
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,11 @@
1515
import contextlib
1616
import inspect
1717
import sys
18+
from collections.abc import Callable
1819
from itertools import permutations
1920
from random import choice, randint
2021
from statistics import median
22+
from typing import Any
2123

2224
import numpy as np
2325

@@ -36,6 +38,8 @@
3638
triangular,
3739
)
3840

41+
type MembershipSetup = Callable[[Any], Membership]
42+
3943
np.seterr(all="raise")
4044
functions = [step, rectangular]
4145

@@ -61,7 +65,7 @@ def normalize(target: Array, output_length: int = 16) -> Array:
6165
return normalized_array
6266

6367

64-
def guess_function(target: Array) -> Membership:
68+
def guess_function(target: Array) -> MembershipSetup:
6569
normalized = normalize(target)
6670
return constant if np.all(normalized == 1) else singleton
6771

@@ -77,11 +81,11 @@ def fitness(func: Membership, target: Array, certainty: int | None = None) -> fl
7781
return result if certainty is None else round(result, certainty)
7882

7983

80-
def seed_population(func: Membership, target: Array) -> dict[tuple, float]:
84+
def seed_population(func: MembershipSetup, target: Array) -> dict[tuple[float, ...], float]:
8185
# create a random population of parameters
8286
params = [p for p in inspect.signature(func).parameters.values() if p.kind == p.POSITIONAL_OR_KEYWORD]
83-
seed_population = {}
84-
seed_numbers = [
87+
seed_population: dict[tuple[float, ...], float] = {}
88+
seed_numbers: list[float] = [
8589
sys.float_info.min,
8690
sys.float_info.max,
8791
0,
@@ -91,7 +95,7 @@ def seed_population(func: Membership, target: Array) -> dict[tuple, float]:
9195
-0.5,
9296
min(target),
9397
max(target),
94-
np.argmax(target),
98+
float(np.argmax(target)),
9599
]
96100
# seed population
97101
for combination in permutations(seed_numbers, len(params)):
@@ -101,35 +105,41 @@ def seed_population(func: Membership, target: Array) -> dict[tuple, float]:
101105
return seed_population
102106

103107

104-
def reproduce(parent1: tuple, parent2: tuple) -> tuple:
105-
child = []
108+
def reproduce(parent1: tuple[float, ...], parent2: tuple[float, ...]) -> tuple[float, ...]:
109+
child: list[float] = []
106110
for p1, p2 in zip(parent1, parent2):
107111
# mix the parts of the floats by randomness within the range of the parents
108112
# adding a random jitter should avoid issues when p1 == p2
109113
a1, a2 = np.frexp(p1)
110114
b1, b2 = np.frexp(p2)
115+
a1 = float(a1)
116+
b1 = float(b1)
117+
a2 = float(a2)
118+
b2 = float(b2)
111119
a1 += randint(-1, 1)
112120
a2 += randint(-1, 1)
113121
b1 += randint(-1, 1)
114122
b2 += randint(-1, 1)
115-
child.append(((a1 + b1) / 2) * 2 ** np.random.uniform(a2, b2))
123+
child.append(float((a1 + b1) / 2) * 2 ** np.random.uniform(a2, b2))
116124
return tuple(child)
117125

118126

119127
def guess_parameters(
120-
func: Membership, target: Array, precision: int | None = None, certainty: int | None = None
121-
) -> tuple:
122-
"""Find the best fitting parameters for a function, targetting an array.
128+
func: MembershipSetup, target: Array, precision: int | None = None, certainty: int | None = None
129+
) -> tuple[float, ...]:
130+
"""Find the best fitting parameters for a function, targeting an array.
123131
124132
Args:
125-
func (Callable): A possibly matching membership function, such as `fuzzylogic.functions.triangular`.
126-
array (np.ndarray): The target array to fit the function to.
133+
func (MembershipSetup): A possibly matching membership function.
134+
target (Array): The target array to fit the function to.
135+
precision (int | None): The precision of the parameters.
136+
certainty (int | None): The certainty of the fitness score.
127137
128138
Returns:
129-
tuple: The best fitting parameters for the function.
139+
tuple[float, ...]: The best fitting parameters for the function.
130140
"""
131141

132-
def best() -> tuple:
142+
def best() -> tuple[float, ...]:
133143
return sorted(population.items(), key=lambda x: x[1])[0][0]
134144

135145
seed_pop = seed_population(func, target)
@@ -141,9 +151,9 @@ def best() -> tuple:
141151
last_pop = {}
142152
for generation in range(12):
143153
# sort the population by fitness
144-
pop: list[tuple[tuple, float]] = sorted(population.items(), key=lambda x: x[1], reverse=True)[
145-
:pop_size
146-
]
154+
pop: list[tuple[tuple[float, ...], float]] = sorted(
155+
population.items(), key=lambda x: x[1], reverse=True
156+
)[:pop_size]
147157
if not pop:
148158
population = last_pop
149159
return best()
@@ -153,7 +163,7 @@ def best() -> tuple:
153163
print("Lucky!")
154164
return best()
155165
# the next generation
156-
new_population = {}
166+
new_population: dict[tuple[float, ...], float] = {}
157167
killed = 0
158168
for parent1 in pop:
159169
while True:
@@ -195,14 +205,14 @@ def best() -> tuple:
195205
pressure **= 0.999
196206
population |= seed_pop
197207
else:
198-
pressure = median([x[1] for x in population.items()])
208+
pressure: float = median([x[1] for x in population.items()])
199209
return best()
200210

201211

202-
def shave(target: Array, components: dict[Membership, tuple]) -> Array:
212+
def shave(target: Array, components: dict[Membership, tuple[float, ...]]) -> Array:
203213
"""Remove the membership functions from the target array."""
204-
result = np.zeros_like(target)
214+
result: Array = np.zeros_like(target, dtype=float)
205215
for func, params in components.items():
206216
f = func(*params)
207-
result += np.fromiter([f(x) for x in np.arange(*target.shape)], float)
208-
return target - result
217+
result += np.fromiter((f(x) for x in np.arange(*target.shape)), dtype=float) # type: ignore
218+
return np.asarray(target - result, dtype=target.dtype) # type: ignore

src/fuzzylogic/functions.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -509,7 +509,7 @@ def f(x: float) -> float:
509509
try:
510510
return limit - limit / exp(k * x)
511511
except OverflowError:
512-
return float(limit)
512+
return limit
513513

514514
return f
515515

src/fuzzylogic/neural_network.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515

1616
def generate_examples() -> dict[str, list[Array]]:
17-
examples = defaultdict(lambda: [])
17+
examples: dict[str, list[Array]] = defaultdict(lambda: [])
1818
examples["constant"] = [np.ones(16)]
1919
for x in range(16):
2020
A = np.zeros(16)
@@ -23,5 +23,5 @@ def generate_examples() -> dict[str, list[Array]]:
2323

2424
for x in range(1, 16):
2525
func = R(0, x)
26-
examples["R"].append(func(np.linspace(0, 1, 16)))
26+
examples["R"].append(func(np.linspace(0, 1, 16))) # type: ignore
2727
return examples

0 commit comments

Comments
 (0)