 
 StateT = TypeVar("StateT")
 MetricsT = TypeVar("MetricsT", bound=Mapping[str, clu_metrics.Metric])
-MetaT = TypeVar("MetaT")
+ModelT = TypeVar("ModelT")
 PyTree = Any
 
 
@@ -61,7 +61,7 @@ def opt_state(self) -> optax.OptState:
     """Returns the optimizer state."""
 
 
-class JaxState(struct.PyTreeNode, Generic[MetaT]):
+class JaxState(struct.PyTreeNode, Generic[ModelT]):
   """A training state for a Jax model created using Flax / Haiku.
 
   Attributes:
@@ -77,7 +77,7 @@ class JaxState(struct.PyTreeNode, Generic[MetaT]):
     _apply: An optional function that can be used to apply the forward pass of
       the model. For Flax models this is usually set to `model.apply` while for
       Haiku models this is usually set to `transform.apply`.
-    _model: An optional reference to a stateless Flax model for convenience.
+    _model: An optional reference to a model for convenience.
     mutable: A pytree of mutable variables that are used by `apply`.
     meta: Arbitrary metadata that is recorded on the state. This can be useful
       for tracking additional references in the state.
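
# A minimal sketch, not part of this commit, of what the new type parameter
# enables: a state can be annotated with a concrete model type instead of
# nn.Module. `MLP` is a hypothetical Flax module used only for illustration.
import flax.linen as nn
import jax

class MLP(nn.Module):
  features: int = 8

  @nn.compact
  def __call__(self, x):
    return nn.Dense(self.features)(x)

def model_width(state: "JaxState[MLP]") -> int:
  # `state.model` is typed as MLP here, so MLP-specific fields type-check.
  return state.model.features
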
@@ -88,14 +88,14 @@ class JaxState(struct.PyTreeNode, Generic[MetaT]):
   tx: optax.GradientTransformation = struct.field(pytree_node=False)
   opt_state: optax.OptState = struct.field(pytree_node=True)
   mutable: PyTree = struct.field(pytree_node=True, default_factory=dict)
-  meta: MetaT = struct.field(pytree_node=False, default_factory=dict)
+  meta: Any = struct.field(pytree_node=False, default_factory=dict)
   _apply: Callable[..., Any] | None = struct.field(
       pytree_node=False, default_factory=None
   )
-  _model: nn.Module | None = struct.field(pytree_node=False, default=None)
+  _model: ModelT | None = struct.field(pytree_node=False, default=None)
 
   @property
-  def model(self) -> nn.Module:
+  def model(self) -> ModelT:
     """Returns a reference to the model used to create the state."""
     if self._model is None:
       raise ValueError("No Flax `model` is set on the state.")
@@ -112,7 +112,7 @@ def create(
       cls,
       *,
       apply: Callable[..., Any] | None = None,
-      model: nn.Module | None = None,
+      model: ModelT | None = None,
       params: PyTree,
       tx: optax.GradientTransformation,
       **kwargs,
@@ -123,9 +123,8 @@ def create(
       apply: A function that can be used to apply the forward pass of the model.
         For Flax models this is usually set to `model.apply`. This cannot be set
         along with `model`.
-      model: A reference to a stateless Flax model. This cannot be set along
-        with `apply`. When set the `apply` attribute of the state will be set to
-        `model.apply`.
+      model: A reference to a model. This cannot be set along with `apply`. When
+        set, the `apply` attribute of the state will be set to `model.apply`.
       params: A pytree of trainable variables that will be updated by `tx` and
         used in `apply`.
       tx: An optax gradient transformation that will be used to update the
@@ -137,7 +136,7 @@ def create(
     """
     if apply is not None and model is not None:
       raise ValueError("Only one of `apply` or `model` can be provided.")
-    elif model is not None:
+    elif model is not None and isinstance(model, nn.Module):
       apply = model.apply
     return cls(
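
# A usage sketch for the create() path above, assuming a Flax module (the
# hypothetical MLP from the earlier sketch). Note that with the new isinstance
# guard, only nn.Module instances get `apply` derived automatically.
import jax.numpy as jnp
import optax

model = MLP(features=8)
variables = model.init(jax.random.key(0), jnp.ones((1, 4)))
state = JaxState.create(
    model=model, params=variables["params"], tx=optax.adamw(1e-3)
)
# The stored model can run the forward pass against the state's parameters.
preds = state.model.apply({"params": state.params}, jnp.ones((1, 4)))
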
@@ -311,30 +310,26 @@ def create_datasets(self) -> core.DatasetT:
     """
 
   @abc.abstractmethod
-  def create_state(self, batch: PyTree, rng: jax.Array) -> StateT:
+  def create_state(self, batch: PyTree) -> StateT:
     """Creates the training state.
 
     Args:
       batch: A pytree of arrays making up a dummy batch for state
         initialization.
-      rng: A prng key that is passed from the trainer to control randomness
-        during variable initialization.
 
     Returns:
       The state to use for training.
     """
 
   @abc.abstractmethod
   def train_step(
-      self, batch: PyTree, state: StateT, rng: jax.Array
+      self, batch: PyTree, state: StateT
   ) -> tuple[StateT, Mapping[str, clu_metrics.Metric]]:
     """Updates the training state and accumulates metrics.
 
     Args:
       batch: A pytree of arrays sampled from the training dataset.
       state: The training state created by `create_state`.
-      rng: A prng key that is passed from the trainer to control randomness
-        during training such as dropout.
 
     Returns:
       A tuple[state, metrics] where the state is the updated training state
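
# A sketch, continuing the hypothetical MLP/JaxState examples above, of how a
# task might own randomness now that `rng` is gone from both signatures. The
# `JaxTask` base-class name, the `_loss` helper, and the seed/fold-in scheme
# are assumptions about intended usage, not something this commit prescribes.
class MyTask(JaxTask):

  def __init__(self, seed: int = 0):
    self._rng = jax.random.key(seed)

  def create_state(self, batch: PyTree) -> JaxState:
    model = MLP(features=8)
    variables = model.init(self._rng, batch)
    return JaxState.create(
        model=model, params=variables["params"], tx=optax.adamw(1e-3)
    )

  def train_step(
      self, batch: PyTree, state: JaxState
  ) -> tuple[JaxState, Mapping[str, clu_metrics.Metric]]:
    # Fold the step counter into the task-owned key to keep per-step
    # randomness deterministic, as the trainer previously did on the
    # task's behalf.
    rng = jax.random.fold_in(self._rng, state.step)
    grads = jax.grad(self._loss)(state.params, batch, rng)  # hypothetical loss
    updates, opt_state = state.tx.update(grads, state.opt_state, state.params)
    params = optax.apply_updates(state.params, updates)
    state = state.replace(step=state.step + 1, params=params, opt_state=opt_state)
    return state, {}
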
@@ -396,8 +391,6 @@ def __init__(
       checkpoint_interval: int | None = None,
       max_checkpoints_to_keep: int = 5,
       continuous_eval_timeout: int = 30,
-      rng_seed: int = core.DEFAULT_RNG_SEED,
-      rng_impl: str | None = None,
   ):
     """Initializes the instance.
 
@@ -431,11 +424,6 @@ def __init__(
         checkpoint before timing out during continuous evaluation. When a
         timeout happens, the job will check for a marker file on disk and if it
         exists, it will terminate successfully. Defaults to 30 seconds.
-      rng_seed: The seed to use for the PRNG key. By default this is set to a
-        fixed constant.
-      rng_impl: The implementation of the PRNG key. By default this is set to
-        None which means that the default implementation (generally
-        partitionable threefry) will be used.
     """
 
     if not isinstance(steps_per_loop, int) or steps_per_loop < 1:
@@ -451,8 +439,6 @@ def __init__(
     self._continuous_eval_timeout = continuous_eval_timeout
     self._checkpoint_interval = checkpoint_interval or steps_per_loop
     self._max_checkpoints_to_keep = max_checkpoints_to_keep
-    self._rng_impl = rng_impl
-    self._rng_seed = rng_seed
 
   @functools.cached_property
   def checkpoint_manager(self) -> ocp.CheckpointManager:
@@ -610,18 +596,10 @@ def process_task(
   ]:
     """Initializes the objects required for training from the task."""
 
-    init_rng, step_rng = jax.random.split(
-        jax.random.key(self._rng_seed, impl=self._rng_impl)
-    )
-
-    def _create_state(inputs: PyTree) -> State:
-      return task.create_state(inputs, init_rng)
-
     def _train_step(
         inputs: PyTree, state: State
     ) -> tuple[State, Mapping[str, clu_metrics.Metric]]:
-      rng = jax.random.fold_in(step_rng, state.step)  # pytype: disable=attribute-error
-      state, metrics = task.train_step(inputs, state, rng)
+      state, metrics = task.train_step(inputs, state)
       return state, {**_state_metrics(state), **metrics}
 
     def _eval_step(
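
# Note on the hunk above: with the trainer-side `fold_in` removed, per-step
# determinism becomes the task's responsibility; the
# `jax.random.fold_in(self._rng, state.step)` pattern sketched earlier is one
# assumed way to recover the old behavior inside a task.
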
@@ -641,7 +619,7 @@ def _eval_step(
 
     sharded_abstract_batch = self._partitioner.shard_inputs(abstract_batch)
     init_fn = self._partitioner.partition_init(
-        _create_state, abstract_batch=sharded_abstract_batch
+        task.create_state, abstract_batch=sharded_abstract_batch
     )
     train_step = self._partitioner.partition_step(_train_step, training=True)
     eval_step = self._partitioner.partition_step(_eval_step)