Help: How to prevent params of a hybrid module from becoming dynamic or traced arrays with optax? #707
Hi, I'm writing a class (the full code is below). Thank you for your patience reading my question!

import jax
import flax.linen as nn
from math import sqrt
from optax import apply_updates, adamw, cosine_decay_schedule, linear_schedule, adam, join_schedules
# The imports below were not in the pasted snippet and are inferred from usage further
# down (assumed). Helpers such as MLP, elu, count_array_shapes, NonAnalyticFunctionspace
# and visualize_tensors are project-specific and defined elsewhere.
import itertools
from functools import partial
import jax.numpy as np  # the code below uses np for jax.numpy
from jax import random, grad, vmap, jit
from jax.example_libraries import optimizers
from jax.flatten_util import ravel_pytree
from tqdm import trange
import deepxde as dde
# branch = tf.keras.Sequential(
#     [
#         tf.keras.layers.InputLayer(input_shape=(m,)),
#         tf.keras.layers.Reshape((config.get('resolution'), config.get('resolution'), 1)),
#         tf.keras.layers.Conv2D(4, (2, 2), strides=1, activation=activation),
#         tf.keras.layers.Conv2D(16, (3, 3), strides=2, activation=activation),
#         tf.keras.layers.Conv2D(16, (3, 3), strides=2, activation=activation),
#         tf.keras.layers.Flatten(),
#         tf.keras.layers.Dense(256, activation=activation),
#         tf.keras.layers.Dense(128),
#     ]
# )

def create_learning_rate_fn(config, base_learning_rate, steps_per_epoch):
    """Creates learning rate schedule."""
    warmup_fn = linear_schedule(
        init_value=0., end_value=base_learning_rate,
        transition_steps=config['warmup_epochs'] * steps_per_epoch)
    cosine_epochs = max(config['num_epochs'] - config['warmup_epochs'], 1)
    cosine_fn = cosine_decay_schedule(
        init_value=base_learning_rate,
        decay_steps=cosine_epochs * steps_per_epoch)
    schedule_fn = join_schedules(
        schedules=[warmup_fn, cosine_fn],
        boundaries=[config['warmup_epochs'] * steps_per_epoch])
    return schedule_fn

# branch.summary()
class CNN(nn.Module):
    CNNfeature: tuple[int, ...] = (4, 16, 64, 256, 128)
    # Densefeature: tuple[int, ...] = (256, 128)

    @nn.compact
    def __call__(self, x):
        # print(x.shape[0])
        x = x.reshape((1, int(sqrt(x.shape[0])), int(sqrt(x.shape[0])), 1))
        x = nn.Conv(features=4, kernel_size=(2, 2))(x)
        x = nn.relu(x)
        x = nn.Conv(features=16, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = x.reshape((x.shape[0], -1))  # flatten
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=128)(x)
        x = np.squeeze(x)
        return x
class ParaIndependentCNN(nn.Module):
    CNNfeature: tuple[int, ...] = (4, 16, 64, 256, 128)
    # Densefeature: tuple[int, ...] = (256, 128)

    @nn.compact
    def __call__(self, x):
        # print(x.shape[0])
        x = x.reshape((1, int(sqrt(x.shape[0])), int(sqrt(x.shape[0])), 1))
        x = nn.Conv(features=4, kernel_size=(2, 2))(x)
        x = nn.relu(x)
        x = nn.Conv(features=16, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = np.mean(x, axis=(1, 2), keepdims=False)  # avg over img
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=128)(x)
        x = np.squeeze(x)
        return x
class CNNold(nn.Module):
    '''
    deeponet branch net
    '''
    CNNfeature: tuple[int, ...] = (4, 16, 64, 256, 128)
    # Densefeature: tuple[int, ...] = (256, 128)

    @nn.compact
    def __call__(self, x):
        # print(x.shape[0])
        x = x.reshape((1, int(sqrt(x.shape[0])), int(sqrt(x.shape[0])), 1))
        x = nn.Conv(features=64, kernel_size=(5, 5))(x)
        x = nn.relu(x)
        x = nn.Conv(features=128, kernel_size=(5, 5))(x)
        x = nn.relu(x)
        x = x.reshape((x.shape[0], -1))  # flatten
        x = nn.Dense(features=128)(x)
        x = nn.relu(x)
        x = nn.Dense(features=128)(x)
        x = np.squeeze(x)
        return x
import jax.numpy as jnp
import math

product = math.prod


# Define the model
class PI_DeepONet:
    def __init__(self, branch_layers, trunk_layers, res=32, seednum=1234, opt=None, sched=None):
        # Network initialization and evaluation functions
        # self.branch_init, self.branch_apply = MLP(branch_layers, activation=elu)
        self.trunk_init, self.trunk_apply = MLP(trunk_layers, activation=elu)
        if isinstance(branch_layers, nn.Module):
            self.branch = branch_layers
        else:
            self.branch = CNN(branch_layers)
        branch_params = self.branch.init(random.PRNGKey(seednum), jnp.empty((res**2, )))
        num_params = count_array_shapes(branch_params)
        num_params = sum((product(m) for m in num_params))
        print(f"Number of branch parameters: {num_params}")

        # Initialize
        # branch_params = self.branch_init(rng_key=random.PRNGKey(1234))
        trunk_params = self.trunk_init(rng_key=random.PRNGKey(seednum))
        # num_params = count_array_shapes(trunk_params)
        num_params = sum((product(m[0].shape) + product(m[1].shape) for m in trunk_params))
        print(f"Number of trunk parameters: {num_params}")
        self.params = (branch_params, trunk_params)

        if opt is None:
            # Use optimizers to set optimizer initialization and update functions
            self.opt_init, self.opt_update, self.get_params = optimizers.adam(
                optimizers.exponential_decay(1e-3, decay_steps=12500, decay_rate=0.9))
            self.opt_state = self.opt_init(self.params)
        else:
            if opt == 'duiqiVOL':
                self.opt = adamw(sched)
                self.opt_state = self.opt.init(self.params)
                # cosine_decay_schedule(2e-2, decay_steps=12500)
        _, self.unravel_params = ravel_pytree(self.params)

        # Logger
        self.itercount = itertools.count()
        self.loss_log = []
        self.loss_bcs_log = []
        self.loss_res_log = []

    # Define DeepONet architecture
    def operator_net(self, params, u, y1, y2):
        y = np.stack([y1, y2])
        branch_params, trunk_params = params
        B = self.branch.apply(branch_params, u)
        T = self.trunk_apply(trunk_params, y)
        outputs = np.sum(B * T)
        outputs = outputs * 0.01 * np.sin(np.pi*y1) * np.sin(np.pi*y2)
        return outputs

    # Define DeepONet 2 architecture
    def operator_net_doaminwise(self, params, u, y1, y2):
        y = np.stack([y1, y2], axis=1)           # y (16) 1024,2; y1 and y2 (16) 1024
        branch_params, trunk_params = params
        B = self.branch.apply(branch_params, u)  # (16) 128
        T = self.trunk_apply(trunk_params, y)    # (16) 1024,128
        outputs = np.einsum('nl,l->n', T, B)     # (16) 1024
        outputs = outputs * 0.01 * np.sin(np.pi*y1) * np.sin(np.pi*y2)  # (16) 1024
        return outputs
    # Define PDE residual
    def residual_net(self, params, u, y1, y2, aux):
        dy1 = lambda params, u, y1, y2, aux: (grad(self.operator_net, argnums=2)(params, u, y1, y2)*aux)[0]
        dy2 = lambda params, u, y1, y2, aux: (grad(self.operator_net, argnums=3)(params, u, y1, y2)*aux)[0]
        ddy1 = grad(dy1, argnums=2)(params, u, y1, y2, aux)
        ddy2 = grad(dy2, argnums=3)(params, u, y1, y2, aux)
        # def one_order_gradient(params, u, y1, y2, aux):
        #     # Although JAX claims that argnums accepts sequences such as tuples here, in practice
        #     # this raises "TypeError: unsupported operand type(s) for *: 'tuple' and 'BatchTracer'"
        #     # (grad with a tuple of argnums returns a tuple of gradients, which cannot be
        #     # multiplied by the array aux directly).
        #     dy = grad(self.operator_net, argnums=[2, 3])(params, u, y1, y2)*aux
        #     return dy
        # grad_y = grad(self.operator_net, argnums=2)(params, u, y1, y2)*aux
        # grad_yy1 = grad(grad_y[0], ar)
        # grad_y = grad(self.operator_net, argnums=2)(params, u, y1, y2)*aux
        # grad_yy1 = grad(grad_y[0], ar)
        # ddy1 = grad(lambda params, u, y, aux: grad(self.operator_net, argnums=2)(params, u, y)*aux, argnums=2)(params, u, np.stack([y1, y2]), aux)
        # ddy2 = grad(dy2, argnums=3)(params, u, y1, y2, aux)
        # res = s_y1**2 + s_y2**2
        # here I want to write this in a "branch batch first" style, but I have not finished that plan
        return ddy1 + ddy2
    def residual_netdm(self, params, u, y1, y2, aux):
        dy1 = lambda params, u, y1, y2, aux: (grad(self.operator_net, argnums=2)(params, u, y1, y2)*aux)
        dy2 = lambda params, u, y1, y2, aux: (grad(self.operator_net, argnums=3)(params, u, y1, y2)*aux)
        ddy1 = grad(dy1, argnums=2)(params, u, y1, y2, aux)
        ddy2 = grad(dy2, argnums=3)(params, u, y1, y2, aux)
        # def one_order_gradient(params, u, y1, y2, aux):
        #     # Although JAX claims that argnums accepts sequences such as tuples here, in practice
        #     # this raises "TypeError: unsupported operand type(s) for *: 'tuple' and 'BatchTracer'".
        #     dy = grad(self.operator_net, argnums=[2, 3])(params, u, y1, y2)*aux
        #     return dy
        # grad_y = grad(self.operator_net, argnums=2)(params, u, y1, y2)*aux
        # grad_yy1 = grad(grad_y[0], ar)
        # grad_y = grad(self.operator_net, argnums=2)(params, u, y1, y2)*aux
        # grad_yy1 = grad(grad_y[0], ar)
        # ddy1 = grad(lambda params, u, y, aux: grad(self.operator_net, argnums=2)(params, u, y)*aux, argnums=2)(params, u, np.stack([y1, y2]), aux)
        # ddy2 = grad(dy2, argnums=3)(params, u, y1, y2, aux)
        # res = s_y1**2 + s_y2**2
        # here I want to write this in a "branch batch first" style, but I have not finished that plan
        return ddy1 + ddy2
    def energy_netdm(self, params, u, y1, y2, aux):
        kdy1 = lambda params, u, y1, y2, aux: (grad(self.operator_net, argnums=2)(params, u, y1, y2)**2*aux)
        kdy2 = lambda params, u, y1, y2, aux: (grad(self.operator_net, argnums=3)(params, u, y1, y2)**2*aux)
        # ddy1 = grad(dy1, argnums=2)(params, u, y1, y2, aux)
        # ddy2 = grad(dy2, argnums=3)(params, u, y1, y2, aux)
        # def one_order_gradient(params, u, y1, y2, aux):
        #     # Although JAX claims that argnums accepts sequences such as tuples here, in practice
        #     # this raises "TypeError: unsupported operand type(s) for *: 'tuple' and 'BatchTracer'".
        #     dy = grad(self.operator_net, argnums=[2, 3])(params, u, y1, y2)*aux
        #     return dy
        # grad_y = grad(self.operator_net, argnums=2)(params, u, y1, y2)*aux
        # grad_yy1 = grad(grad_y[0], ar)
        # grad_y = grad(self.operator_net, argnums=2)(params, u, y1, y2)*aux
        # grad_yy1 = grad(grad_y[0], ar)
        # ddy1 = grad(lambda params, u, y, aux: grad(self.operator_net, argnums=2)(params, u, y)*aux, argnums=2)(params, u, np.stack([y1, y2]), aux)
        # ddy2 = grad(dy2, argnums=3)(params, u, y1, y2, aux)
        # res = s_y1**2 + s_y2**2
        # here I want to write this in a "branch batch first" style, but I have not finished that plan
        return 0.5*(kdy1(params, u, y1, y2, aux) + kdy2(params, u, y1, y2, aux)) - self.operator_net(params, u, y1, y2)
    def residual_netdomainwise(self, params, u, y1, y2, aux):
        # y1 and y2 (16) 1024
        # u (16) 1024
        # aux (16) 1024
        # I want to vmap
        return np.mean((vmap(self.residual_netdm, in_axes=(None, None, 0, 0, 0))(params, u, y1, y2, aux) - 1.0)**2)

    def energy_netdomainwise(self, params, u, y1, y2, aux):
        # y1 and y2 (16) 1024
        # u (16) 1024
        # aux (16) 1024
        # I want to vmap
        return np.sum(vmap(self.energy_netdm, in_axes=(None, None, 0, 0, 0))(params, u, y1, y2, aux))

    def data_netdomainwise(self, params, u, y1, y2, aux):
        # y1 and y2 (16) 1024
        # u (16) 1024
        # aux (16) 1024
        # I want to vmap
        return np.mean(vmap(self.res_operator, in_axes=(None, None, 0, 0, 0))(params, u, y1, y2, aux)**2)

    def res_operator(self, params, u, y1, y2, aux):
        # Compute forward pass
        s_pred = self.operator_net(params, u, y1, y2)
        # Compute loss
        loss = aux - s_pred
        return loss
    # Define boundary loss
    def loss_bcs(self, params, batch):
        # Fetch data
        inputs, outputs = batch
        u, y = inputs
        # Compute forward pass
        pred = vmap(self.operator_net, (None, 0, 0, 0))(params, u, y[:, 0], y[:, 1])
        # Compute loss
        loss = np.mean((pred)**2)
        return loss

    # Define residual loss
    def loss_res(self, params, batch):
        # Fetch data
        inputs, auxs = batch
        u, y = inputs
        pred = vmap(self.residual_net, (None, 0, 0, 0, 0))(params, u, y[:, 0], y[:, 1], auxs)
        loss = np.mean((pred - 1.0)**2)
        return loss

    def loss_resdomainwise(self, params, batch):
        # Fetch data
        inputs, auxs = batch
        u, y = inputs
        pred = vmap(self.residual_netdomainwise, (None, 0, None, None, 0))(params, u, y[:, 0], y[:, 1], auxs)
        loss = np.mean(pred)
        return loss

    def loss_engdomainwise(self, params, batch):
        # Fetch data
        inputs, auxs = batch
        u, y = inputs
        pred = vmap(self.energy_netdomainwise, (None, 0, None, None, 0))(params, u, y[:, 0], y[:, 1], auxs)
        loss = np.mean(pred)
        return loss

    def loss_datadomainwise(self, params, batch):
        # Fetch data
        inputs, auxs = batch
        u, y = inputs
        pred = vmap(self.data_netdomainwise, (None, 0, None, None, 0))(params, u, y[:, 0], y[:, 1], auxs)
        loss = np.mean(pred)
        return loss

    # Define total loss
    def loss(self, params, bcs_batch, res_batch):
        # loss_bcs = self.loss_bcs(params, bcs_batch)
        loss_res = self.loss_res(params, res_batch)
        loss = loss_res
        return loss
    # Define a compiled update step
    @partial(jit, static_argnums=(0,))
    def step(self, i, opt_state, bcs_batch, res_batch):
        params = self.get_params(opt_state)
        g = grad(self.loss)(params, bcs_batch, res_batch)
        return self.opt_update(i, g, opt_state)

    @partial(jit, static_argnums=(0,))
    def stepdomainwise(self, i, opt_state, bcs_batch, res_batch):
        params = self.get_params(opt_state)
        # inputs, auxs = res_batch
        # u, y = inputs
        g = grad(self.loss_resdomainwise)(params, res_batch)
        return self.opt_update(i, g, opt_state)

    @partial(jit, static_argnums=(0,))
    def stepengdomainwise(self, i, opt_state, bcs_batch, res_batch):
        params = self.get_params(opt_state)
        # inputs, auxs = res_batch
        # u, y = inputs
        g = grad(self.loss_engdomainwise)(params, res_batch)
        return self.opt_update(i, g, opt_state)

    @partial(jit, static_argnums=(0,))
    def stepengdomainwise_optax(self, i, opt_state, bcs_batch, res_batch):
        # params = self.get_params(opt_state)
        # inputs, auxs = res_batch
        # u, y = inputs
        g = grad(self.loss_engdomainwise)(self.params, res_batch)
        updates, opt_state = self.opt.update(g, opt_state, params=self.params)
        # Note: this assignment runs inside a function compiled with jax.jit (self is a
        # static argument), so the value stored on self.params is a traced array.
        self.params = apply_updates(self.params, updates)
        return opt_state

    @partial(jit, static_argnums=(0,))
    def stepdomainwise_datadriven(self, i, opt_state, bcs_batch, res_batch):
        params = self.get_params(opt_state)
        # inputs, auxs = res_batch
        # u, y = inputs
        g = grad(self.loss_datadomainwise)(params, res_batch)
        return self.opt_update(i, g, opt_state)
    # Optimize parameters in a loop
    def train(self, bcs_dataset, res_dataset, nIter=10000, batch_size=16, domainwise=False, datadriven=False, eng=False):
        pbar = trange(nIter)
        # Main training loop
        for it in pbar:
            # Fetch data
            # bcs_batch = next(bcs_data)
            res_batch = res_dataset.train_next_batch(batch_size, domainwise=domainwise)
            if domainwise == False:
                self.opt_state = self.step(next(self.itercount), self.opt_state, None, res_batch)
            else:
                if datadriven == False:
                    if eng == True:
                        self.opt_state = self.stepengdomainwise_optax(next(self.itercount), self.opt_state, None, res_batch)
                    else:
                        self.opt_state = self.stepdomainwise(next(self.itercount), self.opt_state, None, res_batch)
                else:
                    self.opt_state = self.stepdomainwise_datadriven(next(self.itercount), self.opt_state, None, res_batch)

            if it % 100 == 0:
                # params = self.get_params(self.opt_state)
                # Compute losses
                # loss_value = self.loss(params, None, res_batch)
                # loss_bcs_value = self.loss_bcs(params, bcs_batch)
                if domainwise == False:
                    loss_res_value = self.loss_res(self.params, res_batch)
                else:
                    if datadriven == False:
                        if eng:
                            # loss_res_value = jax.jit(self.loss_engdomainwise)(np.array(self.params), res_batch)  # a dict cannot be converted to an array
                            loss_res_value = jax.jit(self.loss_engdomainwise)(self.params, res_batch)
                        else:
                            loss_res_value = self.loss_resdomainwise(self.params, res_batch)
                    else:
                        loss_res_value = self.loss_datadomainwise(self.params, res_batch)
                # Store losses
                # self.loss_log.append(loss_value)
                # self.loss_bcs_log.append(loss_bcs_value)
                self.loss_res_log.append(loss_res_value)
                # Print losses
                pbar.set_postfix({'Loss': loss_res_value,
                                  'loss_res': loss_res_value})
    # Evaluate predictions at test points
    @partial(jit, static_argnums=(0,))
    def predict_sdomainwise(self, params, U_star, Y_star):
        # U_star 16 1024
        # Y_star 1024 2
        # expect U 16 1024, Y 1024 2
        s_pred = vmap(self.operator_net_doaminwise, (None, 0, None, None))(params, U_star, Y_star[:, 0], Y_star[:, 1])
        return s_pred

def testpi(myconfig, model):
    nona_test_xspace = NonAnalyticFunctionspace(len_train_data=myconfig.get('testnum'), resolution=myconfig.get('resolution'), shift=10, shuffle=False)
    nona_test_yspace = NonAnalyticFunctionspace(len_train_data=myconfig.get('testnum'), resolution=myconfig.get('resolution'), shift=10, data_dir=r'/mnt/e/xtf/data/darcyflow/neuraloperator-master/data_generation/darcy/data/512_results/transpose', shuffle=False, feature_name="")
    test_err = 0
    params = model.get_params(model.opt_state)
    for i in range(myconfig.get('testnum')):
        evaluation_points = np.linspace(0, 1, myconfig.get('resolution'))
        evaluation_pointsx, evaluation_pointsxy = np.meshgrid(evaluation_points, evaluation_points, indexing='ij')
        evaluation_points = np.vstack((evaluation_pointsx.ravel(), evaluation_pointsxy.ravel())).T
        test_feature_x = nona_test_xspace.random(1)
        test_feature_y = nona_test_yspace.random(1)
        test_feature_x = nona_test_xspace.eval_batch(test_feature_x, evaluation_points)
        test_feature_y = nona_test_yspace.eval_batch(test_feature_y, evaluation_points)
        # for j in
        test_pred = model.predict_sdomainwise(params, test_feature_x, evaluation_points)
        test_pred = test_pred.reshape((1, myconfig.get('resolution'), myconfig.get('resolution')))
        test_feature_y = test_feature_y.reshape((1, myconfig.get('resolution'), myconfig.get('resolution')))
        test_err += dde.metrics.l2_relative_error(test_pred, test_feature_y)
        if i == 0:
            visualize_tensors(test_pred, '/home/dutxtf/codespace/Physics-informed-DeepONets/Eikonal/results', img_name='pred{}'.format(i), iftensorflow=False)
            visualize_tensors(test_feature_y, '/home/dutxtf/codespace/Physics-informed-DeepONets/Eikonal/results', img_name='true{}'.format(i), iftensorflow=False)
        if i % 1000 == 0:
            print('test error: ', dde.metrics.l2_relative_error(test_pred, test_feature_y))
    test_err = test_err / myconfig.get('testnum')
    print('Res {}: ave test error: '.format(myconfig.get('resolution')), test_err)
Replies: 1 comment · 2 replies
Hey @BraveDrXuTF. As far as I know, arrays become traced when one uses @jax.jit. If you remove all jitting from the code, you should be able to see the un-traced arrays. Alternatively, I've found jax.debug.print useful for printing traced arrays: https://jax.readthedocs.io/en/latest/debugging/print_breakpoint.html
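A minimal, self-contained sketch of the difference (a toy function, not code from this thread): an ordinary Python print inside a jitted function only shows a tracer at trace time, while jax.debug.print shows the concrete runtime values.

```python
import jax
import jax.numpy as jnp

@jax.jit
def double(x):
    print("python print sees:", x)               # prints an abstract tracer, only while tracing
    jax.debug.print("debug print sees: {}", x)   # prints the concrete values at run time
    return 2.0 * x

out = double(jnp.arange(3.0))
print(type(out))  # outside jit the result is a concrete jax.Array again
```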
Hi @fabianp. Thank you for your response!
Yeah, I can indeed use debug to print the traced arrays, but I strongly suspect that it is because I have not been using optax correctly in some way that the parameters are transformed into traced arrays after running apply_updates, while the parameters in another file in your gallery have NOT been transformed (I have tested them in Colab with the Python print function and confirmed that the parameters stay as jnp.array), even with @jax.jit:
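A minimal sketch of the purely functional pattern the optax examples use (a toy model of my own; loss_fn, the parameter pytree and the data are assumptions, and this is not the gallery file referred to above): the jitted step takes params and opt_state as arguments and returns the updated ones, so the params held outside the step remain concrete jnp arrays even though jit is used.

```python
import jax
import jax.numpy as jnp
import optax

def loss_fn(params, x, y):
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

params = {"w": jnp.zeros((3, 1)), "b": jnp.zeros((1,))}
tx = optax.adamw(1e-3)
opt_state = tx.init(params)

@jax.jit
def train_step(params, opt_state, x, y):
    grads = jax.grad(loss_fn)(params, x, y)
    updates, opt_state = tx.update(grads, opt_state, params=params)
    params = optax.apply_updates(params, updates)
    return params, opt_state  # return new values instead of assigning them to an object attribute

x, y = jnp.ones((8, 3)), jnp.ones((8, 1))
params, opt_state = train_step(params, opt_state, x, y)
print(jax.tree_util.tree_map(type, params))  # leaves are concrete jax.Array, not tracers
```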