
Commit 3292265

Hilly12 authored and recml authors committed
Refactor a few APIs.
Notably this removes the `rng` argument from `JaxTrainer` to avoid implicitly passing it.

PiperOrigin-RevId: 789073073
1 parent 847628b · commit 3292265

File tree: 10 files changed (+136, −127 lines)

recml/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -38,3 +38,4 @@
 from recml.core.utils.types import Factory
 from recml.core.utils.types import FactoryProtocol
 from recml.core.utils.types import ObjectFactory
+from recml.layers.common import EmbeddingSpec
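
With this re-export in place, the spec is importable straight from the top-level package. A minimal sketch (the aliased second import is only there to show both paths resolve to the same class):

  # Both import paths now resolve to the same class.
  from recml import EmbeddingSpec
  from recml.layers.common import EmbeddingSpec as CommonEmbeddingSpec

  assert EmbeddingSpec is CommonEmbeddingSpec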

recml/core/data/tf_dataset_factory.py

Lines changed: 11 additions & 6 deletions
@@ -206,12 +206,13 @@ class TFDatasetFactory(types.Factory[tf.data.Dataset]):
       dataset. Defaults to `ShardingInfo(num_processes=jax.process_count(),
       process_index=jax.process_index())`. This is similar to `InputContext` in
       tensorflow.
+    cache_reading: Whether to cache the reading of the dataset. This is useful
+      for debugging and testing. Defaults to False.
     debug: An optional boolean indicating whether to debug input boundedness. If
       `True`, the dataset will consist of a single batch that's cached and
       infinitely repeated
   """

-  cache_reading: bool = False
   input_path: str | Sequence[str] = ""
   tfds_source: str | Sequence[str] = ""
   file_format: FileFormat = FileFormat.RECORDIO
@@ -246,6 +247,7 @@ class TFDatasetFactory(types.Factory[tf.data.Dataset]):
   sharding_info: DatasetShardingInfo = dataclasses.field(
       default_factory=DatasetShardingInfo
   )
+  cache_reading: bool = False
   debug: bool = False

   def __post_init__(self):
@@ -478,7 +480,7 @@ def _file_group_reader(file_group: str) -> tf.data.Dataset:
       )

     # Generate a tf.Example dataset by cycling through all uris in parallel.
-    return dataset.interleave(
+    dataset = dataset.interleave(
        map_func=reader,
        cycle_length=self.cycle_length,
        block_length=self.block_length,
@@ -490,6 +492,12 @@ def _file_group_reader(file_group: str) -> tf.data.Dataset:
        deterministic=self.deterministic,
     )

+    # Cache the reading of examples from files.
+    if self.cache_reading:
+      dataset = dataset.cache()
+
+    return dataset
+
   def _parse_dataset(self, dataset: tf.data.Dataset) -> tf.data.Dataset:
     """Batches and parses an examples dataset."""
     # Batch the dataset to the global or per replica batch size.
@@ -556,10 +564,7 @@ def _maybe_apply_tf_data_service(
   def make(self) -> tf.data.Dataset:
     """Creates a `tf.data.Dataset` instance with all dataset ops applied."""
     # Create an examples dataset.
-    if self.cache_reading:
-      dataset = self._create_dataset().cache()
-    else:
-      dataset = self._create_dataset()
+    dataset = self._create_dataset()
     # Shuffle and repeat the dataset.
     dataset = self._maybe_shuffle_and_repeat(dataset)
     # Batch and parse the examples dataset.
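
The net effect: when `cache_reading` is enabled, caching now happens immediately after the interleaved read of raw examples, instead of wrapping `_create_dataset()` inside `make()`. A standalone tf.data sketch of that ordering (not the factory itself; the file glob and parameter values are placeholders):

  import tensorflow as tf

  # Read raw records by interleaving over files, then cache before any
  # shuffling or batching so repeated epochs skip the file reads.
  files = tf.data.Dataset.list_files("/tmp/data/*.tfrecord", shuffle=False)  # placeholder glob
  dataset = files.interleave(
      tf.data.TFRecordDataset,
      cycle_length=4,
      num_parallel_calls=tf.data.AUTOTUNE,
      deterministic=True,
  )
  dataset = dataset.cache()  # corresponds to the new `cache_reading` behavior
  dataset = dataset.shuffle(1024).repeat().batch(32)  # downstream ops stay after the cache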

recml/core/ops/hstu_ops.py

Lines changed: 8 additions & 8 deletions
@@ -125,9 +125,9 @@ def _apply_mask(
   masks = []
   if mask_ref is not None:
     if k_in_lanes:
-      mask = pl.load(mask_ref, (slice(None), k_slice))
+      mask = mask_ref[:, k_slice]
     else:
-      mask = pl.load(mask_ref, (k_slice, slice(None)))
+      mask = mask_ref[k_slice, :]

     snm = jnp.where(should_not_mask, 1, 0)
     masks.append(jnp.bitwise_or(mask, jnp.broadcast_to(snm, mask.shape)) != 0)
@@ -156,7 +156,7 @@ def _apply_mask(
       k_sequence = k_offset + jax.lax.broadcasted_iota(
           jnp.int32, (k_slice.size, bq), 0
       )
-      q_sequence = pl.load(q_sequence_ref, (pl.ds(1), slice(None)))  # [1, bq]
+      q_sequence = q_sequence_ref[:1, :]  # [1, bq]
       q_sequence = jnp.broadcast_to(q_sequence, (k_slice.size, bq))

       assert q_sequence.shape == k_sequence.shape
@@ -170,7 +170,7 @@ def _apply_mask(

   if q_segment_ids_ref is not None:
     if k_in_lanes:
-      kv_ids = pl.load(kv_segment_ids_ref, (pl.ds(1), k_slice))  # [1, k_slice]
+      kv_ids = kv_segment_ids_ref[:1, k_slice]  # [1, k_slice]
       repeats, rem = divmod(kv_ids.shape[1], NUM_LANES)
       if rem:
         raise NotImplementedError(f"block_kv must be a multiple of {NUM_LANES}")
@@ -181,9 +181,9 @@ def _apply_mask(
       if rem:
         raise NotImplementedError(f"block_q must be a multiple of {NUM_LANES}")
       kv_ids = pltpu.repeat(
-          pl.load(kv_segment_ids_ref, (k_slice, slice(None))), repeats, axis=1
+          kv_segment_ids_ref[k_slice, :], repeats, axis=1
      )  # [k_slice, bq]
-      q_ids = pl.load(q_segment_ids_ref, (pl.ds(1), slice(None)))  # [1, bq]
+      q_ids = q_segment_ids_ref[:1, :]  # [1, bq]
    masks.append(q_ids == kv_ids)

   if masks:
@@ -228,7 +228,7 @@ def body(kv_compute_index, _):
    slice_k = pl.ds(kv_compute_index * bkv_compute, bkv_compute)

    q = q_ref[...]
-    k = pl.load(k_ref, (slice_k, slice(None)))
+    k = k_ref[slice_k, :]
    qk = jax.lax.dot_general(
        q, k, NT_DIM_NUMBERS, preferred_element_type=jnp.float32
    )
@@ -256,7 +256,7 @@ def body(kv_compute_index, _):
    )

    sv_dims = NN_DIM_NUMBERS
-    v = pl.load(v_ref, (slice_k, slice(None)))
+    v = v_ref[slice_k, :]

    to_float32 = lambda x: x.astype(jnp.float32)
    v = to_float32(v)
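
These edits replace `pl.load(ref, indexer)` with direct ref indexing, which Pallas supports with the same slicing semantics (`pl.ds` still works inside the subscript). A small standalone kernel sketch, not from this file, showing the two styles side by side; the kernel name, shapes, and values are illustrative:

  import jax
  import jax.numpy as jnp
  from jax.experimental import pallas as pl

  def copy_rows_kernel(x_ref, o_ref):
    rows = pl.ds(0, o_ref.shape[0])
    via_load = pl.load(x_ref, (rows, slice(None)))  # older explicit-load style
    via_index = x_ref[rows, :]                      # direct indexing, as adopted here
    # The two reads return the same block, so this is just a copy.
    o_ref[...] = jnp.minimum(via_load, via_index)

  x = jnp.arange(64.0).reshape(8, 8)
  out = pl.pallas_call(
      copy_rows_kernel,
      out_shape=jax.ShapeDtypeStruct((4, 8), x.dtype),
      interpret=True,  # interpret mode so the sketch runs off-TPU
  )(x)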

recml/core/training/core.py

Lines changed: 29 additions & 25 deletions
@@ -57,6 +57,14 @@
 class Trainer(abc.ABC, Generic[TaskT]):
   """A base trainer interface for training and evaluation."""

+  class Mode(enum.StrEnum):
+    """Mode to run an experiment."""
+
+    TRAIN = "train"
+    EVAL = "eval"
+    TRAIN_AND_EVAL = "train_and_eval"
+    CONTINUOUS_EVAL = "continuous_eval"
+
   @abc.abstractmethod
   def __init__(self, model_dir: str, *args, **kwargs):
     """Initializes the instance."""
@@ -77,6 +85,23 @@ def train_and_evaluate(self, task: TaskT, *args, **kwargs) -> Logs | None:
   def evaluate_continuously(self, task: TaskT, *args, **kwargs) -> Logs | None:
     """Performs continuous evaluation until a condition is met."""

+  def run(self, task: TaskT, mode: Any) -> Logs | None:
+    """Runs the experiment in the given mode."""
+    if mode == Trainer.Mode.TRAIN_AND_EVAL:
+      return self.train_and_evaluate(task)
+    elif mode == Trainer.Mode.TRAIN:
+      return self.train(task)
+    elif mode == Trainer.Mode.EVAL:
+      return self.evaluate(task)
+    elif mode == Trainer.Mode.CONTINUOUS_EVAL:
+      return self.evaluate_continuously(task)
+    else:
+      raise ValueError(f"The job mode provided is not supported: {mode}.")
+
+  @classmethod
+  def setup(cls):
+    """Sets up the trainer before it is instantiated."""
+

 @dataclasses.dataclass(frozen=True)
 class Experiment(Generic[TaskT]):
@@ -90,32 +115,13 @@ class Experiment(Generic[TaskT]):
     trainer: The trainer to use for the experiment.
   """

-  class Mode(enum.StrEnum):
-    """Mode to run an experiment."""
-
-    TRAIN = "train"
-    EVAL = "eval"
-    TRAIN_AND_EVAL = "train_and_eval"
-    CONTINUOUS_EVAL = "continuous_eval"
-
   task: TaskT
   trainer: Trainer[TaskT]


-def run_experiment(
-    experiment: Experiment, mode: Experiment.Mode
-) -> Logs | None:
+def run_experiment(experiment: Experiment, mode: Any) -> Logs | None:
   """Runs an experiment."""
-  if mode == Experiment.Mode.TRAIN_AND_EVAL:
-    return experiment.trainer.train_and_evaluate(experiment.task)
-  elif mode == Experiment.Mode.TRAIN:
-    return experiment.trainer.train(experiment.task)
-  elif mode == Experiment.Mode.EVAL:
-    return experiment.trainer.evaluate(experiment.task)
-  elif mode == Experiment.Mode.CONTINUOUS_EVAL:
-    return experiment.trainer.evaluate_continuously(experiment.task)
-  else:
-    raise ValueError(f"The job mode provided is not supported: {mode}.")
+  experiment.trainer.run(experiment.task, mode)


 def get_iterators(
@@ -161,9 +167,7 @@ def get_iterators(
        k: iterator.TFDatasetIterator(v) for k, v in eval_datasets.items()
    }

-  if not all(
-      isinstance(v, iterator.Iterator) for v in eval_datasets.values()
-  ):
+  if not all(isinstance(v, iterator.Iterator) for v in eval_datasets.values()):
    raise ValueError(
        "Expected all values in the evaluation datasets mapping to be either"
        " `tf.data.Dataset` instances or CLU `DatasetIterator` instances,"
@@ -179,7 +183,7 @@ def get_shape(
   """Gets the shape of a dense / sparse / ragged tensor or tensor spec."""
   if isinstance(x, tf.SparseTensor):
     return [x.shape[0]] + [None for _ in x.shape[1:]]
-  return x.shape.as_list()
+  return x.shape.as_list()  # pylint: disable=attribute-error


 def in_tracing_context() -> bool:
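
With the mode enum and dispatch moved onto `Trainer`, a job can call `trainer.run(task, mode)` directly, and `run_experiment` simply forwards to it. A usage sketch under assumptions: `ToyTrainer` and the string task are hypothetical stand-ins, the import path is inferred from this file's location, and the abstract interface is assumed to be exactly the methods shown in this diff:

  from recml.core.training import core

  class ToyTrainer(core.Trainer):
    """A hypothetical minimal trainer used only to illustrate the new dispatch."""

    def __init__(self, model_dir: str):
      self._model_dir = model_dir

    def train(self, task):
      print("train", task)

    def evaluate(self, task):
      print("evaluate", task)

    def train_and_evaluate(self, task):
      self.train(task)
      self.evaluate(task)

    def evaluate_continuously(self, task):
      self.evaluate(task)

  trainer = ToyTrainer(model_dir="/tmp/exp")
  # Dispatch through the trainer itself...
  trainer.run("toy-task", core.Trainer.Mode.TRAIN_AND_EVAL)
  # ...or through the module-level helper, which now delegates to Trainer.run.
  core.run_experiment(
      core.Experiment(task="toy-task", trainer=trainer), core.Trainer.Mode.EVAL
  )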

recml/core/training/jax_trainer.py

Lines changed: 14 additions & 36 deletions
@@ -45,7 +45,7 @@

 StateT = TypeVar("StateT")
 MetricsT = TypeVar("MetricsT", bound=Mapping[str, clu_metrics.Metric])
-MetaT = TypeVar("MetaT")
+ModelT = TypeVar("ModelT")
 PyTree = Any

@@ -61,7 +61,7 @@ def opt_state(self) -> optax.OptState:
     """Returns the optimizer state."""


-class JaxState(struct.PyTreeNode, Generic[MetaT]):
+class JaxState(struct.PyTreeNode, Generic[ModelT]):
   """A training state for a Jax model created using Flax / Haiku.

   Attributes:
@@ -77,7 +77,7 @@ class JaxState(struct.PyTreeNode, Generic[MetaT]):
     _apply: An optional function that can be used to apply the forward pass of
       the model. For Flax models this is usually set to `model.apply` while for
       Haiku models this is usually set to `transform.apply`.
-    _model: An optional reference to a stateless Flax model for convenience.
+    _model: An optional reference to a model for convenience.
     mutable: A pytree of mutable variables that are used by `apply`.
     meta: Arbitrary metadata that is recorded on the state. This can be useful
       for tracking additional references in the state.
@@ -88,14 +88,14 @@ class JaxState(struct.PyTreeNode, Generic[MetaT]):
   tx: optax.GradientTransformation = struct.field(pytree_node=False)
   opt_state: optax.OptState = struct.field(pytree_node=True)
   mutable: PyTree = struct.field(pytree_node=True, default_factory=dict)
-  meta: MetaT = struct.field(pytree_node=False, default_factory=dict)
+  meta: Any = struct.field(pytree_node=False, default_factory=dict)
   _apply: Callable[..., Any] | None = struct.field(
       pytree_node=False, default_factory=None
   )
-  _model: nn.Module | None = struct.field(pytree_node=False, default=None)
+  _model: ModelT | None = struct.field(pytree_node=False, default=None)

   @property
-  def model(self) -> nn.Module:
+  def model(self) -> ModelT:
     """Returns a reference to the model used to create the state."""
     if self._model is None:
       raise ValueError("No Flax `model` is set on the state.")
@@ -112,7 +112,7 @@ def create(
      cls,
      *,
      apply: Callable[..., Any] | None = None,
-      model: nn.Module | None = None,
+      model: ModelT | None = None,
      params: PyTree,
      tx: optax.GradientTransformation,
      **kwargs,
@@ -123,9 +123,8 @@ def create(
      apply: A function that can be used to apply the forward pass of the model.
        For Flax models this is usually set to `model.apply`. This cannot be set
        along with `model`.
-      model: A reference to a stateless Flax model. This cannot be set along
-        with `apply`. When set the `apply` attribute of the state will be set to
-        `model.apply`.
+      model: A reference to a model. This cannot be set along with `apply`. When
+        set the `apply` attribute of the state will be set to `model.apply`.
      params: A pytree of trainable variables that will be updated by `tx` and
        used in `apply`.
      tx: An optax gradient transformation that will be used to update the
@@ -137,7 +136,7 @@
    """
    if apply is not None and model is not None:
      raise ValueError("Only one of `apply` or `model` can be provided.")
-    elif model is not None:
+    elif model is not None and isinstance(model, nn.Module):
      apply = model.apply

    return cls(
@@ -311,30 +310,26 @@ def create_datasets(self) -> core.DatasetT:
    """

  @abc.abstractmethod
-  def create_state(self, batch: PyTree, rng: jax.Array) -> StateT:
+  def create_state(self, batch: PyTree) -> StateT:
    """Creates the training state.

    Args:
      batch: A pytree of arrays making up a dummy batch for state
        initialization.
-      rng: A prng key that is passed from the trainer to control randomness
-        during variable initialization.

    Returns:
      The state to use for training.
    """

  @abc.abstractmethod
  def train_step(
-      self, batch: PyTree, state: StateT, rng: jax.Array
+      self, batch: PyTree, state: StateT
  ) -> tuple[StateT, Mapping[str, clu_metrics.Metric]]:
    """Updates the training state and accumulates metrics.

    Args:
      batch: A pytree of arrays sampled from the training dataset.
      state: The training state created by `create_state`.
-      rng: A prng key that is passed from the trainer to control randomness
-        during training such as dropout.

    Returns:
      A tuple[state, metrics] where the state is the updated training state
@@ -396,8 +391,6 @@ def __init__(
      checkpoint_interval: int | None = None,
      max_checkpoints_to_keep: int = 5,
      continuous_eval_timeout: int = 30,
-      rng_seed: int = core.DEFAULT_RNG_SEED,
-      rng_impl: str | None = None,
  ):
    """Initializes the instance.

@@ -431,11 +424,6 @@ def __init__(
        checkpoint before timing out during continuous evaluation. When a
        timeout happens, the job will check for a marker file on disk and if it
        exists, it will terminate successfully. Defaults to 30 seconds.
-      rng_seed: The seed to use for the PRNG key. By default this is set to a
-        fixed constant.
-      rng_impl: The implementation of the PRNG key. By default this is set to
-        None which means that the default implementation (generally
-        partitionable threefry) will be used.
    """

    if not isinstance(steps_per_loop, int) or steps_per_loop < 1:
@@ -451,8 +439,6 @@ def __init__(
    self._continuous_eval_timeout = continuous_eval_timeout
    self._checkpoint_interval = checkpoint_interval or steps_per_loop
    self._max_checkpoints_to_keep = max_checkpoints_to_keep
-    self._rng_impl = rng_impl
-    self._rng_seed = rng_seed

  @functools.cached_property
  def checkpoint_manager(self) -> ocp.CheckpointManager:
@@ -610,18 +596,10 @@ def process_task(
  ]:
    """Initializes the objects required for training from the task."""

-    init_rng, step_rng = jax.random.split(
-        jax.random.key(self._rng_seed, impl=self._rng_impl)
-    )
-
-    def _create_state(inputs: PyTree) -> State:
-      return task.create_state(inputs, init_rng)
-
    def _train_step(
        inputs: PyTree, state: State
    ) -> tuple[State, Mapping[str, clu_metrics.Metric]]:
-      rng = jax.random.fold_in(step_rng, state.step)  # pytype: disable=attribute-error
-      state, metrics = task.train_step(inputs, state, rng)
+      state, metrics = task.train_step(inputs, state)
      return state, {**_state_metrics(state), **metrics}

    def _eval_step(
@@ -641,7 +619,7 @@

    sharded_abstract_batch = self._partitioner.shard_inputs(abstract_batch)
    init_fn = self._partitioner.partition_init(
-        _create_state, abstract_batch=sharded_abstract_batch
+        task.create_state, abstract_batch=sharded_abstract_batch
    )
    train_step = self._partitioner.partition_step(_train_step, training=True)
    eval_step = self._partitioner.partition_step(_eval_step)
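
The headline change here is that the trainer no longer constructs and threads a PRNG key into `create_state` and `train_step`; tasks own their randomness instead. A minimal standalone sketch, not recml code, of one way a task might recreate the old behavior by deriving keys from a fixed seed and the step counter (all names and values below are illustrative):

  import jax
  import jax.numpy as jnp

  def create_state(batch):
    init_rng = jax.random.key(0)                       # task-owned seed, formerly supplied by the trainer
    params = jax.random.normal(init_rng, batch.shape)  # stand-in for model.init(...)
    return {"params": params, "step": jnp.int32(0)}

  def train_step(batch, state):
    # Fold the step into a base key so each step sees distinct, reproducible
    # randomness, mirroring the fold_in the trainer used to perform.
    step_rng = jax.random.fold_in(jax.random.key(1), state["step"])
    noise = 0.01 * jax.random.normal(step_rng, batch.shape)
    grads = batch + noise                              # stand-in for a real gradient computation
    return {"params": state["params"] - 1e-3 * grads, "step": state["step"] + 1}

  batch = jnp.ones((4, 4))
  state = create_state(batch)
  state = train_step(batch, state)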
