Skip to content

Commit 118a52a

Browse files
author
Alessandro Lucantonio
committed
Updated fit, score, predict argument names.
1 parent 8834cd0 commit 118a52a

File tree

4 files changed

+45
-42
lines changed

4 files changed

+45
-42
lines changed

environment.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,4 @@ dependencies:
2727
- tox
2828
- mygrad
2929
- tox-conda
30+
- pmlb

src/alpine/gp/regressor.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -414,9 +414,9 @@ def mapper(f, individuals, toolbox_ref):
414414

415415
def fit(self, X_train, y_train=None, X_val=None, y_val=None):
416416
"""Fits the training data using GP-based symbolic regression."""
417-
train_data = {"X_train": X_train, "y_train": y_train}
417+
train_data = {"X": X_train, "y": y_train}
418418
if self.validate and X_val is not None:
419-
val_data = {"X_val": X_val, "y_val": y_val}
419+
val_data = {"X": X_val, "y": y_val}
420420
datasets = {"train": train_data, "val": val_data}
421421
else:
422422
datasets = {"train": train_data}
@@ -427,7 +427,7 @@ def fit(self, X_train, y_train=None, X_val=None, y_val=None):
427427
self.__run()
428428

429429
def predict(self, X_test):
430-
test_data = {"X_test": X_test}
430+
test_data = {"X": X_test}
431431
datasets = {"test": test_data}
432432
self.__store_datasets(datasets)
433433
self.__register_predict_func()
@@ -438,7 +438,7 @@ def score(self, X_test, y_test):
438438
"""Computes the error metric (passed to the `GPSymbolicRegressor` constructor)
439439
on a given dataset.
440440
"""
441-
test_data = {"X_test": X_test, "y_test": y_test}
441+
test_data = {"X": X_test, "y": y_test}
442442
datasets = {"test": test_data}
443443
self.__store_datasets(datasets)
444444
self.__register_score_func()

tests/test_basic_sr.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -39,38 +39,38 @@ def eval_MSE_sol(individual, X, y):
3939

4040

4141
@ray.remote
42-
def predict(individuals_str, toolbox, X_test):
42+
def predict(individuals_str, toolbox, X):
4343

4444
callables = compile_individuals(toolbox, individuals_str)
4545

4646
u = [None] * len(individuals_str)
4747

4848
for i, ind in enumerate(callables):
49-
_, u[i] = eval_MSE_sol(ind, X_test, None)
49+
_, u[i] = eval_MSE_sol(ind, X, None)
5050

5151
return u
5252

5353

5454
@ray.remote
55-
def score(individuals_str, toolbox, X_test, y_test):
55+
def score(individuals_str, toolbox, X, y):
5656

5757
callables = compile_individuals(toolbox, individuals_str)
5858

5959
MSE = [None] * len(individuals_str)
6060

6161
for i, ind in enumerate(callables):
62-
MSE[i], _ = eval_MSE_sol(ind, X_test, y_test)
62+
MSE[i], _ = eval_MSE_sol(ind, X, y)
6363

6464
return MSE
6565

6666

6767
@ray.remote
68-
def fitness(individuals_str, toolbox, X_train, y_train):
68+
def fitness(individuals_str, toolbox, X, y):
6969
callables = compile_individuals(toolbox, individuals_str)
7070

7171
fitnesses = [None] * len(individuals_str)
7272
for i, ind in enumerate(callables):
73-
MSE, _ = eval_MSE_sol(ind, X_train, y_train)
73+
MSE, _ = eval_MSE_sol(ind, X, y)
7474

7575
fitnesses[i] = (MSE,)
7676

tests/test_poisson1d.py

+34-32
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def eval_MSE_sol(
3737
residual: Callable, X, y, S: SimplicialComplex, u_0: C.CochainP0
3838
) -> float:
3939

40-
num_nodes = X.shape[1]
40+
num_nodes = S.num_nodes
4141

4242
# need to call config again before using JAX in energy evaluations to make sure that
4343
# the current worker has initialized JAX
@@ -57,47 +57,48 @@ def obj(x, y):
5757

5858
MSE = 0.0
5959

60-
u = []
60+
us = []
6161

62-
for i, curr_y in enumerate(y):
62+
for i, curr_force in enumerate(X):
6363
# set additional arguments of the objective function
6464
# (apart from the vector of unknowns)
65-
args = {"y": curr_y}
65+
args = {"y": curr_force}
6666
prb.set_obj_args(args)
6767

6868
# minimize the objective
69-
x = prb.solve(
69+
u = prb.solve(
7070
x0=u_0.coeffs.flatten(), ftol_abs=1e-12, ftol_rel=1e-12, maxeval=1000
7171
)
7272

73-
if (
74-
prb.last_opt_result == 1
75-
or prb.last_opt_result == 3
76-
or prb.last_opt_result == 4
77-
):
73+
if y is not None:
74+
if (
75+
prb.last_opt_result == 1
76+
or prb.last_opt_result == 3
77+
or prb.last_opt_result == 4
78+
):
7879

79-
current_err = np.linalg.norm(x - X[i, :]) ** 2
80-
else:
81-
current_err = math.nan
80+
current_err = np.linalg.norm(u - y[i, :]) ** 2
81+
else:
82+
current_err = math.nan
8283

83-
if math.isnan(current_err):
84-
MSE = 1e5
85-
break
84+
if math.isnan(current_err):
85+
MSE = 1e5
86+
break
8687

87-
MSE += current_err
88+
MSE += current_err
8889

89-
u.append(x)
90+
us.append(u)
9091

91-
MSE *= 1 / X.shape[0]
92+
MSE *= 1 / num_nodes
9293

93-
return MSE, u
94+
return MSE, us
9495

9596

9697
@ray.remote
9798
def predict(
9899
individuals_str: list[str],
99100
toolbox,
100-
X_test,
101+
X,
101102
S: SimplicialComplex,
102103
u_0: C.CochainP0,
103104
penalty: dict,
@@ -108,7 +109,7 @@ def predict(
108109
u = [None] * len(individuals_str)
109110

110111
for i, ind in enumerate(callables):
111-
_, u[i] = eval_MSE_sol(ind, X_test, None, S, u_0)
112+
_, u[i] = eval_MSE_sol(ind, X, None, S, u_0)
112113

113114
return u
114115

@@ -117,8 +118,8 @@ def predict(
117118
def score(
118119
individuals_str: list[str],
119120
toolbox,
120-
X_test,
121-
y_test,
121+
X,
122+
y,
122123
S: SimplicialComplex,
123124
u_0: C.CochainP0,
124125
penalty: dict,
@@ -129,7 +130,7 @@ def score(
129130
MSE = [None] * len(individuals_str)
130131

131132
for i, ind in enumerate(callables):
132-
MSE[i], _ = eval_MSE_sol(ind, X_test, y_test, S, u_0)
133+
MSE[i], _ = eval_MSE_sol(ind, X, y, S, u_0)
133134

134135
return MSE
135136

@@ -138,8 +139,8 @@ def score(
138139
def fitness(
139140
individuals_str: list[str],
140141
toolbox,
141-
X_train,
142-
y_train,
142+
X,
143+
y,
143144
S: SimplicialComplex,
144145
u_0: C.CochainP0,
145146
penalty: dict,
@@ -150,7 +151,7 @@ def fitness(
150151

151152
fitnesses = [None] * len(individuals_str)
152153
for i, ind in enumerate(callables):
153-
MSE, _ = eval_MSE_sol(ind, X_train, y_train, S, u_0)
154+
MSE, _ = eval_MSE_sol(ind, X, y, S, u_0)
154155

155156
# add penalty on length of the tree to promote simpler solutions
156157
fitnesses[i] = (MSE + penalty["reg_param"] * indlen[i],)
@@ -181,8 +182,9 @@ def test_poisson1d(set_test_dir, yamlfile):
181182
# Delta u + f = 0, where Delta is the discrete Laplace-de Rham operator
182183
f = C.laplacian(u)
183184
f.coeffs *= -1.0
184-
X_train = np.array([u.coeffs.flatten()], dtype=dctkit.float_dtype)
185-
y_train = np.array([f.coeffs.flatten()], dtype=dctkit.float_dtype)
185+
186+
X_train = np.array([f.coeffs.flatten()], dtype=dctkit.float_dtype)
187+
y_train = np.array([u.coeffs.flatten()], dtype=dctkit.float_dtype)
186188

187189
# initial guess for the unknown of the Poisson problem (cochain of nodals values)
188190
u_0_vec = np.zeros(num_nodes, dtype=dctkit.float_dtype)
@@ -219,15 +221,15 @@ def test_poisson1d(set_test_dir, yamlfile):
219221
**regressor_params
220222
)
221223

222-
train_data = Dataset("D", X_train, y_train)
224+
# train_data = Dataset("D", X_train, y_train)
223225

224226
gpsr.fit(X_train, y_train, X_val=X_train, y_val=y_train)
225227

226228
u_best = gpsr.predict(X_train)
227229

228230
fit_score = gpsr.score(X_train, y_train)
229231

230-
gpsr.save_best_test_sols(train_data, "./")
232+
# gpsr.save_best_test_sols(train_data, "./")
231233

232234
ray.shutdown()
233235
assert np.allclose(u.coeffs.flatten(), np.ravel(u_best))

0 commit comments

Comments
 (0)