@@ -37,7 +37,7 @@ def eval_MSE_sol(
37
37
residual : Callable , X , y , S : SimplicialComplex , u_0 : C .CochainP0
38
38
) -> float :
39
39
40
- num_nodes = X . shape [ 1 ]
40
+ num_nodes = S . num_nodes
41
41
42
42
# need to call config again before using JAX in energy evaluations to make sure that
43
43
# the current worker has initialized JAX
@@ -57,47 +57,48 @@ def obj(x, y):
57
57
58
58
MSE = 0.0
59
59
60
- u = []
60
+ us = []
61
61
62
- for i , curr_y in enumerate (y ):
62
+ for i , curr_force in enumerate (X ):
63
63
# set additional arguments of the objective function
64
64
# (apart from the vector of unknowns)
65
- args = {"y" : curr_y }
65
+ args = {"y" : curr_force }
66
66
prb .set_obj_args (args )
67
67
68
68
# minimize the objective
69
- x = prb .solve (
69
+ u = prb .solve (
70
70
x0 = u_0 .coeffs .flatten (), ftol_abs = 1e-12 , ftol_rel = 1e-12 , maxeval = 1000
71
71
)
72
72
73
- if (
74
- prb .last_opt_result == 1
75
- or prb .last_opt_result == 3
76
- or prb .last_opt_result == 4
77
- ):
73
+ if y is not None :
74
+ if (
75
+ prb .last_opt_result == 1
76
+ or prb .last_opt_result == 3
77
+ or prb .last_opt_result == 4
78
+ ):
78
79
79
- current_err = np .linalg .norm (x - X [i , :]) ** 2
80
- else :
81
- current_err = math .nan
80
+ current_err = np .linalg .norm (u - y [i , :]) ** 2
81
+ else :
82
+ current_err = math .nan
82
83
83
- if math .isnan (current_err ):
84
- MSE = 1e5
85
- break
84
+ if math .isnan (current_err ):
85
+ MSE = 1e5
86
+ break
86
87
87
- MSE += current_err
88
+ MSE += current_err
88
89
89
- u .append (x )
90
+ us .append (u )
90
91
91
- MSE *= 1 / X . shape [ 0 ]
92
+ MSE *= 1 / num_nodes
92
93
93
- return MSE , u
94
+ return MSE , us
94
95
95
96
96
97
@ray .remote
97
98
def predict (
98
99
individuals_str : list [str ],
99
100
toolbox ,
100
- X_test ,
101
+ X ,
101
102
S : SimplicialComplex ,
102
103
u_0 : C .CochainP0 ,
103
104
penalty : dict ,
@@ -108,7 +109,7 @@ def predict(
108
109
u = [None ] * len (individuals_str )
109
110
110
111
for i , ind in enumerate (callables ):
111
- _ , u [i ] = eval_MSE_sol (ind , X_test , None , S , u_0 )
112
+ _ , u [i ] = eval_MSE_sol (ind , X , None , S , u_0 )
112
113
113
114
return u
114
115
@@ -117,8 +118,8 @@ def predict(
117
118
def score (
118
119
individuals_str : list [str ],
119
120
toolbox ,
120
- X_test ,
121
- y_test ,
121
+ X ,
122
+ y ,
122
123
S : SimplicialComplex ,
123
124
u_0 : C .CochainP0 ,
124
125
penalty : dict ,
@@ -129,7 +130,7 @@ def score(
129
130
MSE = [None ] * len (individuals_str )
130
131
131
132
for i , ind in enumerate (callables ):
132
- MSE [i ], _ = eval_MSE_sol (ind , X_test , y_test , S , u_0 )
133
+ MSE [i ], _ = eval_MSE_sol (ind , X , y , S , u_0 )
133
134
134
135
return MSE
135
136
@@ -138,8 +139,8 @@ def score(
138
139
def fitness (
139
140
individuals_str : list [str ],
140
141
toolbox ,
141
- X_train ,
142
- y_train ,
142
+ X ,
143
+ y ,
143
144
S : SimplicialComplex ,
144
145
u_0 : C .CochainP0 ,
145
146
penalty : dict ,
@@ -150,7 +151,7 @@ def fitness(
150
151
151
152
fitnesses = [None ] * len (individuals_str )
152
153
for i , ind in enumerate (callables ):
153
- MSE , _ = eval_MSE_sol (ind , X_train , y_train , S , u_0 )
154
+ MSE , _ = eval_MSE_sol (ind , X , y , S , u_0 )
154
155
155
156
# add penalty on length of the tree to promote simpler solutions
156
157
fitnesses [i ] = (MSE + penalty ["reg_param" ] * indlen [i ],)
@@ -181,8 +182,9 @@ def test_poisson1d(set_test_dir, yamlfile):
181
182
# Delta u + f = 0, where Delta is the discrete Laplace-de Rham operator
182
183
f = C .laplacian (u )
183
184
f .coeffs *= - 1.0
184
- X_train = np .array ([u .coeffs .flatten ()], dtype = dctkit .float_dtype )
185
- y_train = np .array ([f .coeffs .flatten ()], dtype = dctkit .float_dtype )
185
+
186
+ X_train = np .array ([f .coeffs .flatten ()], dtype = dctkit .float_dtype )
187
+ y_train = np .array ([u .coeffs .flatten ()], dtype = dctkit .float_dtype )
186
188
187
189
# initial guess for the unknown of the Poisson problem (cochain of nodals values)
188
190
u_0_vec = np .zeros (num_nodes , dtype = dctkit .float_dtype )
@@ -219,15 +221,15 @@ def test_poisson1d(set_test_dir, yamlfile):
219
221
** regressor_params
220
222
)
221
223
222
- train_data = Dataset ("D" , X_train , y_train )
224
+ # train_data = Dataset("D", X_train, y_train)
223
225
224
226
gpsr .fit (X_train , y_train , X_val = X_train , y_val = y_train )
225
227
226
228
u_best = gpsr .predict (X_train )
227
229
228
230
fit_score = gpsr .score (X_train , y_train )
229
231
230
- gpsr .save_best_test_sols (train_data , "./" )
232
+ # gpsr.save_best_test_sols(train_data, "./")
231
233
232
234
ray .shutdown ()
233
235
assert np .allclose (u .coeffs .flatten (), np .ravel (u_best ))
0 commit comments