Skip to content

Commit 1087220

Browse files
author
Alessandro Lucantonio
committed
Fixed issues with setting CPU device for JAX-based fitness evals.
1 parent a7f95ce commit 1087220

File tree

3 files changed

+33
-272
lines changed

3 files changed

+33
-272
lines changed

environment.yaml

+3-2
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@ channels:
44
- defaults
55
dependencies:
66
- gmsh
7-
- jax
7+
- jax==0.5.0
88
- jaxopt
99
- numpy
1010
- pygmo
11-
- python
11+
- python==3.12
1212
- python-gmsh
1313
- trame
1414
- ipywidgets
@@ -26,3 +26,4 @@ dependencies:
2626
- pygmsh
2727
- tox
2828
- mygrad
29+
- tox-conda

examples/poisson.ipynb

-256
This file was deleted.

tests/test_basic_sr.py

+30-14
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
import jax.numpy as jnp
88
import ray
99

10-
# Define new functions
10+
# sets CPU device for JAX at process level
11+
os.environ["JAX_PLATFORMS"] = "cpu"
1112

1213
config()
1314

@@ -16,16 +17,19 @@ def compile_individuals(toolbox, individuals_str_batch):
1617
return [toolbox.compile(expr=ind) for ind in individuals_str_batch]
1718

1819

19-
x = jnp.array([x/10. for x in range(-10, 10)])
20+
x = jnp.array([x / 10.0 for x in range(-10, 10)])
2021
y = x**4 + x**3 + x**2 + x
2122

2223

2324
def eval_MSE_sol(individual, true_data):
25+
import os
26+
27+
os.environ["JAX_PLATFORMS"] = "cpu"
2428
config()
2529
# Evaluate the mean squared error between the expression
2630
# and the real function : x**4 + x**3 + x**2 + x
2731
y_pred = individual(true_data.X)
28-
MSE = jnp.sum(jnp.square(y_pred-true_data.y)) / len(true_data.X)
32+
MSE = jnp.sum(jnp.square(y_pred - true_data.y)) / len(true_data.X)
2933
if jnp.isnan(MSE):
3034
MSE = 1e5
3135
return MSE, y_pred
@@ -36,7 +40,7 @@ def predict(individuals_str, toolbox, true_data):
3640

3741
callables = compile_individuals(toolbox, individuals_str)
3842

39-
u = [None]*len(individuals_str)
43+
u = [None] * len(individuals_str)
4044

4145
for i, ind in enumerate(callables):
4246
_, u[i] = eval_MSE_sol(ind, true_data)
@@ -49,7 +53,7 @@ def score(individuals_str, toolbox, true_data):
4953

5054
callables = compile_individuals(toolbox, individuals_str)
5155

52-
MSE = [None]*len(individuals_str)
56+
MSE = [None] * len(individuals_str)
5357

5458
for i, ind in enumerate(callables):
5559
MSE[i], _ = eval_MSE_sol(ind, true_data)
@@ -61,7 +65,7 @@ def score(individuals_str, toolbox, true_data):
6165
def fitness(individuals_str, toolbox, true_data):
6266
callables = compile_individuals(toolbox, individuals_str)
6367

64-
fitnesses = [None]*len(individuals_str)
68+
fitnesses = [None] * len(individuals_str)
6569
for i, ind in enumerate(callables):
6670
MSE, _ = eval_MSE_sol(ind, true_data)
6771

@@ -76,18 +80,30 @@ def test_basic_sr(set_test_dir):
7680
with open(filename) as config_file:
7781
config_file_data = yaml.safe_load(config_file)
7882

79-
pset = gp.PrimitiveSetTyped("MAIN", [float,], float)
83+
pset = gp.PrimitiveSetTyped(
84+
"MAIN",
85+
[
86+
float,
87+
],
88+
float,
89+
)
8090
pset.addPrimitive(jnp.add, [float, float], float, "AddF")
81-
pset.renameArguments(ARG0='x')
91+
pset.renameArguments(ARG0="x")
8292

8393
common_data = {}
8494
seed = [
85-
"AddF(AddF(AddF(MulF(MulF(x, MulF(x, x)),x), MulF(x,MulF(x, x))), MulF(x, x)), x)"] # noqa: E501
86-
gpsr = GPSymbolicRegressor(pset=pset, fitness=fitness.remote,
87-
error_metric=score.remote, predict_func=predict.remote,
88-
common_data=common_data,
89-
config_file_data=config_file_data,
90-
seed=seed, batch_size=10)
95+
"AddF(AddF(AddF(MulF(MulF(x, MulF(x, x)),x), MulF(x,MulF(x, x))), MulF(x, x)), x)"
96+
] # noqa: E501
97+
gpsr = GPSymbolicRegressor(
98+
pset=pset,
99+
fitness=fitness.remote,
100+
error_metric=score.remote,
101+
predict_func=predict.remote,
102+
common_data=common_data,
103+
config_file_data=config_file_data,
104+
seed=seed,
105+
batch_size=10,
106+
)
91107

92108
train_data = Dataset("true_data", x, y)
93109
gpsr.fit(train_data)

0 commit comments

Comments
 (0)