7
7
import jax .numpy as jnp
8
8
import ray
9
9
10
- # Define new functions
10
+ # sets CPU device for JAX at process level
11
+ os .environ ["JAX_PLATFORMS" ] = "cpu"
11
12
12
13
config ()
13
14
@@ -16,16 +17,19 @@ def compile_individuals(toolbox, individuals_str_batch):
16
17
return [toolbox .compile (expr = ind ) for ind in individuals_str_batch ]
17
18
18
19
19
- x = jnp .array ([x / 10. for x in range (- 10 , 10 )])
20
+ x = jnp .array ([x / 10.0 for x in range (- 10 , 10 )])
20
21
y = x ** 4 + x ** 3 + x ** 2 + x
21
22
22
23
23
24
def eval_MSE_sol (individual , true_data ):
25
+ import os
26
+
27
+ os .environ ["JAX_PLATFORMS" ] = "cpu"
24
28
config ()
25
29
# Evaluate the mean squared error between the expression
26
30
# and the real function : x**4 + x**3 + x**2 + x
27
31
y_pred = individual (true_data .X )
28
- MSE = jnp .sum (jnp .square (y_pred - true_data .y )) / len (true_data .X )
32
+ MSE = jnp .sum (jnp .square (y_pred - true_data .y )) / len (true_data .X )
29
33
if jnp .isnan (MSE ):
30
34
MSE = 1e5
31
35
return MSE , y_pred
@@ -36,7 +40,7 @@ def predict(individuals_str, toolbox, true_data):
36
40
37
41
callables = compile_individuals (toolbox , individuals_str )
38
42
39
- u = [None ]* len (individuals_str )
43
+ u = [None ] * len (individuals_str )
40
44
41
45
for i , ind in enumerate (callables ):
42
46
_ , u [i ] = eval_MSE_sol (ind , true_data )
@@ -49,7 +53,7 @@ def score(individuals_str, toolbox, true_data):
49
53
50
54
callables = compile_individuals (toolbox , individuals_str )
51
55
52
- MSE = [None ]* len (individuals_str )
56
+ MSE = [None ] * len (individuals_str )
53
57
54
58
for i , ind in enumerate (callables ):
55
59
MSE [i ], _ = eval_MSE_sol (ind , true_data )
@@ -61,7 +65,7 @@ def score(individuals_str, toolbox, true_data):
61
65
def fitness (individuals_str , toolbox , true_data ):
62
66
callables = compile_individuals (toolbox , individuals_str )
63
67
64
- fitnesses = [None ]* len (individuals_str )
68
+ fitnesses = [None ] * len (individuals_str )
65
69
for i , ind in enumerate (callables ):
66
70
MSE , _ = eval_MSE_sol (ind , true_data )
67
71
@@ -76,18 +80,30 @@ def test_basic_sr(set_test_dir):
76
80
with open (filename ) as config_file :
77
81
config_file_data = yaml .safe_load (config_file )
78
82
79
- pset = gp .PrimitiveSetTyped ("MAIN" , [float ,], float )
83
+ pset = gp .PrimitiveSetTyped (
84
+ "MAIN" ,
85
+ [
86
+ float ,
87
+ ],
88
+ float ,
89
+ )
80
90
pset .addPrimitive (jnp .add , [float , float ], float , "AddF" )
81
- pset .renameArguments (ARG0 = 'x' )
91
+ pset .renameArguments (ARG0 = "x" )
82
92
83
93
common_data = {}
84
94
seed = [
85
- "AddF(AddF(AddF(MulF(MulF(x, MulF(x, x)),x), MulF(x,MulF(x, x))), MulF(x, x)), x)" ] # noqa: E501
86
- gpsr = GPSymbolicRegressor (pset = pset , fitness = fitness .remote ,
87
- error_metric = score .remote , predict_func = predict .remote ,
88
- common_data = common_data ,
89
- config_file_data = config_file_data ,
90
- seed = seed , batch_size = 10 )
95
+ "AddF(AddF(AddF(MulF(MulF(x, MulF(x, x)),x), MulF(x,MulF(x, x))), MulF(x, x)), x)"
96
+ ] # noqa: E501
97
+ gpsr = GPSymbolicRegressor (
98
+ pset = pset ,
99
+ fitness = fitness .remote ,
100
+ error_metric = score .remote ,
101
+ predict_func = predict .remote ,
102
+ common_data = common_data ,
103
+ config_file_data = config_file_data ,
104
+ seed = seed ,
105
+ batch_size = 10 ,
106
+ )
91
107
92
108
train_data = Dataset ("true_data" , x , y )
93
109
gpsr .fit (train_data )
0 commit comments