Skip to content

Commit

Permalink
Merge pull request #4592 from jakevdp:fix-shape
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 733430128
  • Loading branch information
Flax Authors committed Mar 4, 2025
2 parents a24d790 + 4e300d4 commit 45a8f84
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
4 changes: 2 additions & 2 deletions flax/core/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -956,9 +956,9 @@ def param(
# NOTE: We could check dtype consistency here as well but it's
# usefuleness is less obvious. We might intentionally change the dtype
# for inference to a half float type for example.
if jnp.shape(val) != jnp.shape(abs_val):
if np.shape(val) != np.shape(abs_val):
raise errors.ScopeParamShapeError(
name, self.path_text, jnp.shape(abs_val), jnp.shape(val)
name, self.path_text, np.shape(abs_val), np.shape(val)
)
else:
if not self.is_mutable_collection('params'):
Expand Down
10 changes: 5 additions & 5 deletions flax/nnx/bridge/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from flax.nnx.object import Object
from flax.nnx import variablelib
from flax.nnx.bridge import variables as bridge_variables
import jax.numpy as jnp
import numpy as np

A = tp.TypeVar('A')
M = tp.TypeVar('M', bound='Module')
Expand Down Expand Up @@ -231,9 +231,9 @@ def param( # type: ignore[invalid-annotation]
abs_value_flat = jax.tree_util.tree_leaves(abs_value)
value_flat = jax.tree_util.tree_leaves(value)
for val, abs_val in zip(value_flat, abs_value_flat):
if jnp.shape(val) != jnp.shape(abs_val):
if np.shape(val) != np.shape(abs_val):
raise errors.ScopeParamShapeError(
name, '', jnp.shape(abs_val), jnp.shape(val)
name, '', np.shape(abs_val), np.shape(val)
)

if isinstance(abs_value, variablelib.VariableMetadata):
Expand Down Expand Up @@ -282,9 +282,9 @@ def variable( # type: ignore[invalid-annotation]
abs_value_flat = jax.tree_util.tree_leaves(abs_value)
value_flat = jax.tree_util.tree_leaves(value)
for val, abs_val in zip(value_flat, abs_value_flat):
if jnp.shape(val) != jnp.shape(abs_val):
if np.shape(val) != np.shape(abs_val):
raise errors.ScopeParamShapeError(
name, '', jnp.shape(abs_val), jnp.shape(val)
name, '', np.shape(abs_val), np.shape(val)
)

if isinstance(abs_value, variablelib.VariableMetadata):
Expand Down

0 comments on commit 45a8f84

Please sign in to comment.