Skip to content

Commit b2ccd7b

Browse files
authored
Replace jax.random.KeyArray with jax.Array to suppress deprecation warnings. (#166)
1 parent 32d439e commit b2ccd7b

27 files changed

81 additions and 88 deletions (lines changed)

Diff for: axlearn/common/adapter_flax.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -63,7 +63,7 @@ def _create_dummy_inputs(self):
6363
return cfg.create_dummy_input_fn(**cfg.create_dummy_input_kwargs)
6464

6565
def initialize_parameters_recursively(
66-
self, prng_key: jax.random.KeyArray, *, prebuilt: Optional[NestedTensor] = None
66+
self, prng_key: utils.Tensor, *, prebuilt: Optional[NestedTensor] = None
6767
) -> NestedTensor:
6868
if self._use_prebuilt_params(prebuilt):
6969
return prebuilt

Diff for: axlearn/common/attention.py

+4 -4
Original file line number | Diff line number | Diff line change
@@ -875,7 +875,7 @@ def transform_factorization_spec(
875875
)
876876

877877
def initialize_parameters_recursively(
878-
self, prng_key: jax.random.KeyArray, *, prebuilt: Optional[NestedTensor] = None
878+
self, prng_key: Tensor, *, prebuilt: Optional[NestedTensor] = None
879879
) -> NestedTensor:
880880
if self._use_prebuilt_params(prebuilt):
881881
return prebuilt
@@ -2735,7 +2735,7 @@ def __init__(self, cfg: Config, *, parent: Optional[Module]):
27352735
self._layers.append(self._add_child(f"layer{i}", layer_cfg))
27362736

27372737
def initialize_parameters_recursively(
2738-
self, prng_key: jax.random.KeyArray, *, prebuilt: Optional[NestedTensor] = None
2738+
self, prng_key: Tensor, *, prebuilt: Optional[NestedTensor] = None
27392739
) -> NestedTensor:
27402740
cfg = self.config # type: StackedTransformerLayer.Config
27412741
prng_key = split_prng_key(prng_key, cfg.num_layers)
@@ -3057,7 +3057,7 @@ def __init__(self, cfg: Config, *, parent: Optional[Module]):
30573057
self._add_child("repeat", repeat_cfg)
30583058

30593059
def initialize_parameters_recursively(
3060-
self, prng_key: jax.random.KeyArray, *, prebuilt: Optional[NestedTensor] = None
3060+
self, prng_key: Tensor, *, prebuilt: Optional[NestedTensor] = None
30613061
) -> NestedTensor:
30623062
# We need to call self.repeat.initialize_parameters_recursively() with the same prng_key
30633063
# to ensure initialization parity with StackedTransformerLayer.
@@ -3188,7 +3188,7 @@ def __init__(self, cfg: Config, *, parent: Optional[Module]):
31883188
self._add_child("pipeline", pipeline_cfg)
31893189

31903190
def initialize_parameters_recursively(
3191-
self, prng_key: jax.random.KeyArray, *, prebuilt: Optional[NestedTensor] = None
3191+
self, prng_key: Tensor, *, prebuilt: Optional[NestedTensor] = None
31923192
) -> NestedTensor:
31933193
cfg = self.config # type: PipelinedTransformerLayer.Config
31943194
# We pre-split all num_layers keys to ensure initialization parity with

Diff for: axlearn/common/base_layer.py

+4 -4
Original file line number | Diff line number | Diff line change
@@ -121,7 +121,7 @@ class RematSpec:
121121
class ParameterNoise(Configurable):
122122
"""An interface for applying parameter noise."""
123123

124-
def apply(self, prng_key: jax.random.KeyArray, params: NestedTensor) -> NestedTensor:
124+
def apply(self, prng_key: Tensor, params: NestedTensor) -> NestedTensor:
125125
"""To be implemented by subclasses."""
126126
raise NotImplementedError(self)
127127

@@ -275,7 +275,7 @@ def create_parameter_specs_recursively(self) -> NestedParameterSpec:
275275
return specs
276276

277277
def initialize_parameters_recursively(
278-
self, prng_key: jax.random.KeyArray, *, prebuilt: Optional[NestedTensor] = None
278+
self, prng_key: Tensor, *, prebuilt: Optional[NestedTensor] = None
279279
) -> NestedTensor:
280280
params = {}
281281
param_specs = self._create_layer_parameter_specs()
@@ -318,7 +318,7 @@ def _use_prebuilt_params(self, prebuilt: Optional[NestedTensor]) -> bool:
318318
return True
319319

320320
def _initialize_parameter(
321-
self, name: str, *, prng_key: jax.random.KeyArray, parameter_spec: ParameterSpec
321+
self, name: str, *, prng_key: Tensor, parameter_spec: ParameterSpec
322322
) -> Tensor:
323323
"""Adds a parameter with the given name and shape.
324324
@@ -345,7 +345,7 @@ def _initialize_parameter(
345345
return param
346346

347347
def apply_parameter_noise_recursively(
348-
self, prng_key: jax.random.KeyArray, params: NestedTensor
348+
self, prng_key: Tensor, params: NestedTensor
349349
) -> NestedTensor:
350350
"""Applies parameter noise recursively on `params`.
351351

Diff for: axlearn/common/base_layer_test.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -102,7 +102,7 @@ class ParameterScaler(ParameterNoise):
102102
class Config(ParameterNoise.Config):
103103
scale: float = 1.0
104104

105-
def apply(self, prng_key: jax.random.KeyArray, params: NestedTensor) -> NestedTensor:
105+
def apply(self, prng_key: utils.Tensor, params: NestedTensor) -> NestedTensor:
106106
cfg = self.config
107107
return jax.tree_util.tree_map(lambda x: x * cfg.scale, params)
108108

Diff for: axlearn/common/conformer.py

+1 -2
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,6 @@
1919

2020
from typing import Optional, Tuple, Union
2121

22-
import jax
2322
from jax import numpy as jnp
2423

2524
from axlearn.common.attention import (
@@ -328,7 +327,7 @@ def __init__(self, cfg: Config, *, parent: Optional[Module]):
328327

329328
def initialize_parameters_recursively(
330329
self,
331-
prng_key: jax.random.KeyArray,
330+
prng_key: Tensor,
332331
*,
333332
prebuilt: Optional[NestedTensor] = None,
334333
) -> NestedTensor:

Diff for: axlearn/common/decoding.py

+3 -3
Original file line number | Diff line number | Diff line change
@@ -729,7 +729,7 @@ class DecodingState(NamedTuple):
729729
# The current state of the autoregressive decoding caches.
730730
cache: NestedTensor
731731
# Random generator state.
732-
prng_key: jax.random.KeyArray
732+
prng_key: Tensor
733733

734734

735735
def _decode_init(
@@ -739,7 +739,7 @@ def _decode_init(
739739
num_decodes: int,
740740
max_decode_len: int,
741741
cache: NestedTensor,
742-
prng_key: jax.random.KeyArray,
742+
prng_key: Tensor,
743743
pad_id: int,
744744
token_scores: Optional[Tensor] = None,
745745
) -> DecodingState:
@@ -902,7 +902,7 @@ def sample_decode(
902902
tokens_to_scores: Callable[[Tensor, NestedTensor], Tuple[Tensor, NestedTensor]],
903903
stop_decoding_condition: StopDecodingCondition,
904904
num_decodes: int,
905-
prng_key: jax.random.KeyArray,
905+
prng_key: Tensor,
906906
max_decode_len: Optional[int] = None,
907907
loop: Literal["lax", "python"] = "lax",
908908
pad_id: int = 0,

Diff for: axlearn/common/evaler.py

+10 -16
Original file line number | Diff line number | Diff line change
@@ -103,9 +103,7 @@ def __init__(
103103
if mesh.empty:
104104
raise RuntimeError("MetricCalculator should be created within the context of a mesh")
105105

106-
def init_state(
107-
self, *, prng_key: jax.random.KeyArray, model_params: NestedTensor
108-
) -> NestedTensor:
106+
def init_state(self, *, prng_key: Tensor, model_params: NestedTensor) -> NestedTensor:
109107
"""Initializes the state.
110108
111109
Will be called at the beginning of an evaluation step.
@@ -212,7 +210,7 @@ def _call_model(
212210
self,
213211
*,
214212
method: str,
215-
prng_key: jax.random.KeyArray,
213+
prng_key: Tensor,
216214
model_params: NestedTensor,
217215
input_batch: NestedTensor,
218216
**kwargs,
@@ -285,9 +283,7 @@ def __init__(
285283
self._metric_accumulator = None
286284
self._jit_forward = self._pjit(self._forward_in_pjit)
287285

288-
def init_state(
289-
self, *, prng_key: jax.random.KeyArray, model_params: NestedTensor
290-
) -> NestedTensor:
286+
def init_state(self, *, prng_key: Tensor, model_params: NestedTensor) -> NestedTensor:
291287
cfg = self.config
292288
self._metric_accumulator = cfg.metric_accumulator.instantiate()
293289
return dict(prng_key=prng_key)
@@ -308,7 +304,7 @@ def forward(
308304
def _forward_in_pjit(
309305
self,
310306
model_params: NestedTensor,
311-
prng_key: jax.random.KeyArray,
307+
prng_key: Tensor,
312308
input_batch: NestedTensor,
313309
) -> Dict[str, NestedTensor]:
314310
"""Calls `self._model` and returns summaries."""
@@ -379,9 +375,7 @@ def __init__(
379375
model_param_partition_specs=model_param_partition_specs,
380376
)
381377

382-
def init_state(
383-
self, *, prng_key: jax.random.KeyArray, model_params: NestedTensor
384-
) -> NestedTensor:
378+
def init_state(self, *, prng_key: Tensor, model_params: NestedTensor) -> NestedTensor:
385379
states = {}
386380
for name, calculator in self.children.items():
387381
states[name] = calculator.init_state(prng_key=prng_key, model_params=model_params)
@@ -525,11 +519,11 @@ def eval_step(
525519
self,
526520
step: int,
527521
*,
528-
prng_key: jax.random.KeyArray,
522+
prng_key: Tensor,
529523
model_params: NestedTensor,
530524
return_aux: bool = False,
531525
train_summaries: Optional[NestedTensor] = None,
532-
) -> Tuple[jax.random.KeyArray, Optional[Dict[str, Any]], Optional[List[NestedTensor]]]:
526+
) -> Tuple[Tensor, Optional[Dict[str, Any]], Optional[List[NestedTensor]]]:
533527
"""Runs eval for the given step.
534528
535529
Args:
@@ -682,7 +676,7 @@ def __init__(
682676
self._metric_accumulator: MetricAccumulator = None
683677

684678
def init_state( # pylint: disable=duplicate-code
685-
self, *, prng_key: jax.random.KeyArray, model_params: NestedTensor
679+
self, *, prng_key: Tensor, model_params: NestedTensor
686680
) -> NestedTensor:
687681
self._metric_accumulator = MetricAccumulator.default_config().instantiate()
688682
return dict(prng_key=prng_key)
@@ -724,7 +718,7 @@ def forward(
724718
def _predict_in_pjit(
725719
self,
726720
model_params: NestedTensor,
727-
prng_key: jax.random.KeyArray,
721+
prng_key: Tensor,
728722
input_batch: NestedTensor,
729723
) -> Dict[str, NestedTensor]:
730724
"""Core function that calls model's predict() method for each batch and will be pjit-ed."""
@@ -759,7 +753,7 @@ def _calculate_metrics(self, outputs: PredictionOutputs) -> Dict[str, Tensor]:
759753
def _compute_metrics_in_pjit(
760754
self,
761755
model_params: NestedTensor,
762-
prng_key: jax.random.KeyArray,
756+
prng_key: Tensor,
763757
outputs: List[PredictionOutputs],
764758
) -> Dict[str, NestedTensor]:
765759
"""Computes metrics and returns them in "replicated"."""

Diff for: axlearn/common/evaler_test.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -152,7 +152,7 @@ def _call_model(
152152
self,
153153
*,
154154
method: str,
155-
prng_key: jax.random.KeyArray,
155+
prng_key: Tensor,
156156
model_params: NestedTensor,
157157
input_batch: NestedTensor,
158158
**kwargs,

Diff for: axlearn/common/inference.py

+9 -8
Original file line number | Diff line number | Diff line change
@@ -45,6 +45,7 @@
4545
NestedPartitionSpec,
4646
NestedTensor,
4747
PartitionSpec,
48+
Tensor,
4849
TensorSpec,
4950
)
5051

@@ -56,12 +57,12 @@ class MethodRunner:
5657
def __init__(
5758
self,
5859
*,
59-
prng_key: jax.random.KeyArray,
60+
prng_key: Tensor,
6061
mesh: jax.sharding.Mesh,
6162
input_batch_partition_spec: DataPartitionType,
6263
jit_run_on_batch: Callable[
63-
[jax.random.KeyArray, NestedTensor],
64-
Tuple[jax.random.KeyArray, NestedTensor, NestedTensor],
64+
[Tensor, NestedTensor],
65+
Tuple[Tensor, NestedTensor, NestedTensor],
6566
],
6667
):
6768
"""Initializes MethodRunner object.
@@ -141,7 +142,7 @@ def __call__(self, input_batch: NestedTensor) -> Output:
141142
class _InferenceRunnerState(NamedTuple):
142143
"""Contains inference runner {state | state-partition-specs}."""
143144

144-
prng_key: Union[jax.random.KeyArray, NestedPartitionSpec]
145+
prng_key: Union[Tensor, NestedPartitionSpec]
145146
model: Union[NestedTensor, NestedPartitionSpec]
146147
learner: Optional[Union[NestedTensor, NestedPartitionSpec]] = None
147148

@@ -255,7 +256,7 @@ def run(
255256
input_batches: Iterable[NestedTensor],
256257
*,
257258
method: str,
258-
prng_key: Optional[jax.random.KeyArray] = None,
259+
prng_key: Optional[Tensor] = None,
259260
**kwargs,
260261
) -> Generator[NestedTensor, None, None]:
261262
"""Runs inference on the provided input batches.
@@ -296,7 +297,7 @@ def create_method_runner(
296297
self,
297298
*,
298299
method: str,
299-
prng_key: Optional[jax.random.KeyArray] = None,
300+
prng_key: Optional[Tensor] = None,
300301
**kwargs,
301302
) -> MethodRunner:
302303
"""Creates MethodRunner for the specified method and arguments.
@@ -361,13 +362,13 @@ def inference_iter(model_params, prng_key, input_batch):
361362

362363
def _inference_iter(
363364
self,
364-
prng_key: jax.random.KeyArray,
365+
prng_key: Tensor,
365366
model_params: NestedTensor,
366367
input_batch: Dict[str, Any],
367368
*,
368369
method,
369370
**kwargs,
370-
) -> Tuple[jax.random.KeyArray, NestedTensor, NestedTensor]:
371+
) -> Tuple[Tensor, NestedTensor, NestedTensor]:
371372
"""Implements inference for a single input batch."""
372373
cfg = self.config
373374
new_prng_key, iter_key = jax.random.split(prng_key)

Diff for: axlearn/common/inference_test.py

+2 -2
Original file line number | Diff line number | Diff line change
@@ -106,7 +106,7 @@ def initialize(
106106
self,
107107
name: str,
108108
*,
109-
prng_key: jax.random.KeyArray,
109+
prng_key: Tensor,
110110
shape: Shape,
111111
dtype: jnp.dtype,
112112
axes: Optional[FanAxes] = None,
@@ -238,7 +238,7 @@ def _build_ckpt(
238238
root_dir: str,
239239
mesh_shape: Tuple[int, int],
240240
mesh_axis_names: Tuple[str, str],
241-
prng_key: jax.random.KeyArray,
241+
prng_key: Tensor,
242242
use_ema: bool = False,
243243
) -> Tuple[NestedTensor, str]:
244244
devices = mesh_utils.create_device_mesh(mesh_shape)

Diff for: axlearn/common/layers.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -1777,7 +1777,7 @@ class Config(ParameterNoise.Config):
17771777

17781778
vn_std: Required[float] = REQUIRED
17791779

1780-
def apply(self, prng_key: jax.random.KeyArray, params: NestedTensor) -> NestedTensor:
1780+
def apply(self, prng_key: Tensor, params: NestedTensor) -> NestedTensor:
17811781
cfg = self.config
17821782
if cfg.vn_std <= 0:
17831783
return params

Diff for: axlearn/common/metrics_glue.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -77,7 +77,7 @@ class Config(ModelSummaryAccumulator.Config):
7777
def _forward_in_pjit(
7878
self,
7979
model_params: NestedTensor,
80-
prng_key: jax.random.KeyArray,
80+
prng_key: Tensor,
8181
input_batch: NestedTensor,
8282
) -> Dict[str, NestedTensor]:
8383
"""Calls `self._model` and returns summaries."""

Diff for: axlearn/common/module.py

+3 -3
Original file line number | Diff line number | Diff line change
@@ -149,7 +149,7 @@ class InvocationContext: # pylint: disable=too-many-instance-attributes
149149
# The state of the module.
150150
state: NestedTensor
151151
is_training: bool
152-
prng_key: Optional[jax.random.KeyArray]
152+
prng_key: Optional[Tensor]
153153
output_collection: OutputCollection
154154

155155
def path(self):
@@ -670,7 +670,7 @@ def is_training(self) -> bool:
670670
return self.get_invocation_context().is_training
671671

672672
@property
673-
def prng_key(self) -> jax.random.KeyArray:
673+
def prng_key(self) -> Tensor:
674674
return self.get_invocation_context().prng_key
675675

676676
@property
@@ -724,7 +724,7 @@ def nullary():
724724

725725
def functional(
726726
module: Module,
727-
prng_key: Optional[jax.random.KeyArray],
727+
prng_key: Optional[Tensor],
728728
state: NestedTensor,
729729
inputs: Union[Sequence[Any], Dict[str, Any]],
730730
*,

0 commit comments

Comments
 (0)