@@ -105,6 +105,8 @@ class SparsecoreConfig:
105105 sharding_strategy: The sharding strategy to use for the embedding table.
106106 Defaults to 'MOD' sharding. See the sparsecore documentation for more
107107 details.
108+ allow_id_dropping: Whether to allow dropping of IDs that do not fit within
109+ the XLA buffers allocated for each partition. Defaults to False.
108110 num_sc_per_device: The number of sparsecores per Jax device. By default, a
109111 fixed mapping is used to determine this based on device 0. This may fail
110112 on newer TPU architectures if the mapping is not updated of if device 0 is
@@ -163,6 +165,7 @@ def __call__(self, inputs: Mapping[str, jax.Array]) -> jax.Array:
163165 optimizer : OptimizerSpec
164166 sharding_axis : str | int = 0
165167 sharding_strategy : str = 'MOD'
168+ allow_id_dropping : bool = False
166169
167170 # TODO(aahil): Come up with better defaults / heuristics here.
168171 max_ids_per_partition_fn : Callable [[str , int ], int ] = dataclasses .field (
@@ -339,7 +342,7 @@ def _to_np(x: Any) -> np.ndarray:
339342 global_device_count = self .sparsecore_config .global_device_count ,
340343 num_sc_per_device = self .sparsecore_config .num_sc_per_device ,
341344 sharding_strategy = self .sparsecore_config .sharding_strategy ,
342- allow_id_dropping = False ,
345+ allow_id_dropping = self . sparsecore_config . allow_id_dropping ,
343346 batch_number = self ._batch_number ,
344347 )
345348
0 commit comments