-
Thanks for posting this update. I remembered you commenting on this performance consideration somewhere months back but couldn't find it. I've been doing this split/merge with standard JAX transforms in my code just in case (and to stay as pure-JAX as possible). If a PR goes through to address this, I'll try switching to the NNX transforms! 👍
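For concreteness, the pattern I mean is roughly this (a minimal sketch with arbitrary layer sizes, not my actual code):

```python
import jax
import jax.numpy as jnp
from flax import nnx

model = nnx.Linear(4, 8, rngs=nnx.Rngs(0))
graphdef, state = nnx.split(model)

@jax.jit
def forward(graphdef, state, x):
  # Rebuild the module from its pure pytree state inside the traced function.
  model = nnx.merge(graphdef, state)
  return model(x)

y = forward(graphdef, state, jnp.ones((2, 4)))
```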
-
Thanks for posting this! Is there a way to do the update with metrics? Or only

```python
graphdef, state = nnx.split((model, optimizer, metrics))
...
nnx.update((model, optimizer, metrics), state)
```

because right now it raises an error:

Also, it would be great to create a page with speed-up tips for the NNX API!
-
@cgarciae As I mentioned above, I've been sticking to split/merge + JAX transforms to future-proof against any performance hits. However, I would consider switching to NNX transforms for my current dev work if the expectation is that the Rust extension will definitively close the performance gap. Can you comment on the expected gains with flaxlib?
-
@cgarciae in your example, at the end,
-
Big fan of NNX! I personally think there are reasons other than performance to use split/merge and standard JAX transforms. It's "closer to the metal," if you will -- once you understand the split/merge API and JAX's core APIs, you're empowered to do pretty much anything, with a little more boilerplate (holding on to the graphdef), which is not too bad in my opinion (especially since y'all have done such a great job with the static typing!). You can mix NNX's mutable reference semantics with JAX's pure functional semantics to write both convenient and bug-free code.

I worry that encouraging NNX transforms only, while sweeping split/merge under the rug, would be especially bad for newer JAX users. NNX transforms add a layer of abstraction that completely hides the underlying JAX abstractions, which may make it harder to pick up important concepts like tracing/staging out, PyTrees, sharding, etc. As a more experienced JAX user, I've definitely been finding split/merge with explicit state management more comfortable and legible. Another argument for encouraging this pattern is that, at least right now, you must understand split/merge and explicit state management to save and load checkpoints.

I realize not everyone will agree with me! My vote would be to document both split/merge and NNX transforms side-by-side as equivalent ways of doing things, even after flaxlib is complete. That way, even if people do want to use NNX transforms to save on boilerplate, they can still acquire a mental model of what is happening under the hood.
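To make the checkpointing point concrete: saving and restoring today goes through the explicit state, roughly like this (a sketch assuming Orbax's `StandardCheckpointer`; the path and `create_model` constructor are placeholders):

```python
import orbax.checkpoint as ocp
from flax import nnx

def create_model():
  # Placeholder constructor for whatever model you actually use.
  return nnx.Linear(4, 8, rngs=nnx.Rngs(0))

model = create_model()
_, state = nnx.split(model)

checkpointer = ocp.StandardCheckpointer()
checkpointer.save('/tmp/ckpt/state', state)  # save the pure pytree state

# Restoring: build an abstract model to get the state structure, then merge.
abstract_model = nnx.eval_shape(create_model)
graphdef, abstract_state = nnx.split(abstract_model)
restored_state = checkpointer.restore('/tmp/ckpt/state', abstract_state)
model = nnx.merge(graphdef, restored_state)
```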
-
FYI: pinning this discussion to google/flax/discussions/ @cgarciae #nnx
-
Hello, I'm using the following model:

```python
from typing import Any, Tuple

import jax
import jax.numpy as jnp
from flax import nnx


class GaussianFourierProjection(nnx.Module):
  """Gaussian random features for encoding time steps."""
  # embed_dim: int
  # scale: float = 30.

  def __init__(self, embed_dim: int, scale: float, *, rngs: nnx.Rngs):
    key = rngs.params()
    dout = embed_dim // 2
    self.W = nnx.Variable(jax.random.normal(key, (dout,)) * scale)

  def __call__(self, x):
    x_proj = x[:, None] * self.W[None, :] * 2 * jnp.pi
    return jnp.concatenate([jnp.sin(x_proj), jnp.cos(x_proj)], axis=-1)


class ScoreNet(nnx.Module):
  channels: Tuple[int, ...] = (32, 64, 128, 256)
  embed_dim: int = 256
  scale: float = 30.

  def __init__(self, marginal_prob_std: Any, din_t: int, rngs: nnx.Rngs):
    self.act = nnx.swish
    self.marginal_prob_std = marginal_prob_std
    # time embedding
    self.embed = GaussianFourierProjection(embed_dim=self.embed_dim,
                                           scale=self.scale,
                                           rngs=rngs)
    self.LayerEmb = nnx.Linear(self.embed_dim, self.embed_dim, rngs=rngs)
    # encoding part
    # `args` is a config dict (with 'x_dim') defined elsewhere in my script
    self.Layer1 = nnx.Linear(args['x_dim'], self.channels[0], use_bias=False, rngs=rngs)
    self.Layer1e = nnx.Linear(self.embed_dim, self.channels[0], rngs=rngs)
    self.Norm1 = nnx.GroupNorm(self.channels[0], num_groups=4, rngs=rngs)
    self.Layer2 = nnx.Linear(self.channels[0], self.channels[1], use_bias=False, rngs=rngs)
    self.Layer2e = nnx.Linear(self.embed_dim, self.channels[1], rngs=rngs)
    self.Norm2 = nnx.GroupNorm(self.channels[1], rngs=rngs)  # num_groups=32 by default
    self.Layer3 = nnx.Linear(self.channels[1], self.channels[2], use_bias=False, rngs=rngs)
    self.Layer3e = nnx.Linear(self.embed_dim, self.channels[2], rngs=rngs)
    self.Norm3 = nnx.GroupNorm(self.channels[2], rngs=rngs)  # num_groups=32 by default
    self.Layer4 = nnx.Linear(self.channels[2], self.channels[3], use_bias=False, rngs=rngs)
    self.Layer4e = nnx.Linear(self.embed_dim, self.channels[3], rngs=rngs)
    self.Norm4 = nnx.GroupNorm(self.channels[3], rngs=rngs)  # num_groups=32 by default
    # decoding part
    self.Layer5 = nnx.Linear(self.channels[3], self.channels[2], use_bias=False, rngs=rngs)
    self.Layer5e = nnx.Linear(self.embed_dim, self.channels[2], rngs=rngs)
    self.Norm5 = nnx.GroupNorm(self.channels[2], rngs=rngs)  # num_groups=32 by default
    self.Layer6 = nnx.Linear(2 * self.channels[2], self.channels[1], use_bias=False, rngs=rngs)
    self.Layer6e = nnx.Linear(self.embed_dim, self.channels[1], rngs=rngs)
    self.Norm6 = nnx.GroupNorm(self.channels[1], rngs=rngs)  # num_groups=32 by default
    self.Layer7 = nnx.Linear(2 * self.channels[1], self.channels[0], use_bias=False, rngs=rngs)
    self.Layer7e = nnx.Linear(self.embed_dim, self.channels[0], rngs=rngs)
    self.Norm7 = nnx.GroupNorm(self.channels[0], rngs=rngs)  # num_groups=32 by default
    self.Layer8 = nnx.Linear(2 * self.channels[0], args['x_dim'], rngs=rngs)

  def __call__(self, x, t):
    # time embedding
    embed = self.act(self.LayerEmb(self.embed(t)))
    # encoding
    h1 = self.Layer1(x)
    h1 += self.Layer1e(embed)
    h1 = self.Norm1(h1)
    h1 = self.act(h1)
    h2 = self.Layer2(h1)
    h2 += self.Layer2e(embed)
    h2 = self.Norm2(h2)
    h2 = self.act(h2)
    h3 = self.Layer3(h2)
    h3 += self.Layer3e(embed)
    h3 = self.Norm3(h3)
    h3 = self.act(h3)
    h4 = self.Layer4(h3)
    h4 += self.Layer4e(embed)
    h4 = self.Norm4(h4)
    h4 = self.act(h4)
    # decoding
    h = self.Layer5(h4)
    h += self.Layer5e(embed)
    h = self.Norm5(h)
    h = self.act(h)
    h = self.Layer6(jnp.concatenate([h, h3], axis=-1))
    h += self.Layer6e(embed)
    h = self.Norm6(h)
    h = self.act(h)
    h = self.Layer7(jnp.concatenate([h, h2], axis=-1))
    h += self.Layer7e(embed)
    h = self.Norm7(h)
    h = self.act(h)
    h = self.Layer8(jnp.concatenate([h, h1], axis=-1))
    # normalisation
    h = h / self.marginal_prob_std(t)[:, None]
    return h
```

I was wondering whether the performance difference is still valid, or whether my code is just not perfect (btw I'm not an expert...)?
-
Currently `nnx.jit` traverses the object graph in Python. This is slow and primarily affects the small-model regime, as the Python overhead starts to disappear as the model's width grows. To solve this in general, we will be developing a Rust extension called `flaxlib` (see first steps in #4196) to speed up some of the traversal logic in `graph.py`, similar to how JAX solved the same issue with `jaxlib` for standard pytrees.

UPDATE: see the full Performance Considerations guide.
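In the meantime, the pattern several commenters above describe keeps the graph traversal out of the jitted hot path: split once before the loop, run a pure `jax.jit`-ed step on the state, and write back with `nnx.update` at the end. A minimal sketch (assuming the `nnx.Optimizer(model, tx)` / `optimizer.update(grads)` API current at the time of this thread, with a made-up MSE loss and toy data):

```python
import jax
import jax.numpy as jnp
import optax
from flax import nnx

model = nnx.Linear(4, 1, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adam(1e-3))

# Split once, outside the training loop.
graphdef, state = nnx.split((model, optimizer))

@jax.jit
def train_step(graphdef, state, x, y):
  model, optimizer = nnx.merge(graphdef, state)

  def loss_fn(model):
    return jnp.mean((model(x) - y) ** 2)

  loss, grads = nnx.value_and_grad(loss_fn)(model)
  optimizer.update(grads)
  # Re-split so the updated state flows out of the pure function.
  _, state = nnx.split((model, optimizer))
  return state, loss

x = jnp.ones((8, 4))
y = jnp.zeros((8, 1))
for _ in range(100):
  state, loss = train_step(graphdef, state, x, y)

# Sync the mutable Python objects with the final state.
nnx.update((model, optimizer), state)
```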