I don't know how to write `loss_fn`. It has two inputs and two labels, and it should look like this:
encode1(input1) -> s1 -> decoder(s1) -> predict1 -> (predict1,label1)
encode2(input2) -> s2 -> decoder(s2) -> predict2 -> (predict2,label2)
How should this `train_step` be written? Please give me an example, thanks.
For reference, here is an example of the single-input version I am starting from:
```python
import jax
import jax.numpy as jnp
import optax


@jax.jit
def train_step(state, batch):
    """Train for a single step."""
    def loss_fn(params):
        logits, updates = state.apply_fn(
            {'params': params, 'batch_stats': state.batch_stats},
            x=batch['image'], train=True, mutable=['batch_stats'])
        # value_and_grad needs a scalar, so average the per-example losses.
        loss = optax.softmax_cross_entropy_with_integer_labels(
            logits=logits, labels=batch['label']).mean()
        return loss, (logits, updates)

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, (logits, updates)), grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    # Store the updated BatchNorm running statistics.
    state = state.replace(batch_stats=updates['batch_stats'])
    metrics = {
        'loss': loss,
        'accuracy': jnp.mean(jnp.argmax(logits, -1) == batch['label']),
    }
    return state, metrics
```
Assume the batch contains `batch['input1']`, `batch['input2']`, `batch['label1']`, and `batch['label2']`, and the model computes:
encode1(input1) -> s1 -> decoder(s1) -> predict1 -> (predict1,label1)
encode2(input2) -> s2 -> decoder(s2) -> predict2 -> (predict2,label2)
This is the loss:
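```python
# optax.l2_loss is elementwise (0.5 * squared error), so reduce each term
# to a scalar with .mean() before summing; value_and_grad needs a scalar.
mse_loss1 = optax.l2_loss(predict1, batch['label1']).mean()
mse_loss2 = optax.l2_loss(predict2, batch['label2']).mean()
loss = mse_loss1 + mse_loss2
```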
I need the rest of the full `train_step`.
This is the model. It has two encoders and one decoder; `__call__` is only used for initialization and is not used for inference:
This is how I create the `TrainState`; if it has any issues, please tell me: