Skip to content

Commit c97d272

Browse files
apaszkeGoogle-ML-Automation
authored andcommitted
[Pallas:SC] Enable load_gather_masked tests with TC tiling
PiperOrigin-RevId: 832340622
1 parent 0ec9dbf commit c97d272

File tree

1 file changed

+0
-5
lines changed

1 file changed

+0
-5
lines changed

tests/pallas/tpu_sparsecore_pallas_test.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -604,7 +604,6 @@ def kernel(x_ref, indices_ref, o_ref):
604604
)
605605

606606
def test_load_gather_1d(self):
607-
self.skip_if_tc_tiling()
608607
x = jnp.arange(8)
609608
indices = jax.random.permutation(jax.random.key(42), jnp.arange(8))
610609

@@ -615,7 +614,6 @@ def kernel(x_ref, indices_ref, o_ref):
615614
np.testing.assert_array_equal(kernel(x, indices), x[indices])
616615

617616
def test_load_gather_2d(self):
618-
self.skip_if_tc_tiling()
619617
x = jnp.arange(8 * 8).reshape(8, -1)
620618
indices0 = indices1 = jax.random.permutation(
621619
jax.random.key(42), jnp.arange(8)
@@ -649,7 +647,6 @@ def kernel(x_ref, indices_ref, o_ref):
649647

650648
@parameterized.parameters(*MASK_FNS)
651649
def test_load_gather_masked(self, mask_fn):
652-
self.skip_if_tc_tiling()
653650
x = jnp.arange(8)
654651
indices = jax.random.permutation(jax.random.key(42), jnp.arange(8))
655652

@@ -683,7 +680,6 @@ def kernel(x_ref, indices_ref, o_ref):
683680

684681
@parameterized.parameters(*MASK_FNS)
685682
def test_store_scatter_masked(self, mask_fn):
686-
self.skip_if_tc_tiling()
687683
x = jnp.arange(8)
688684
indices = jax.random.permutation(jax.random.key(42), jnp.arange(8))
689685

@@ -700,7 +696,6 @@ def kernel(x_ref, indices_ref, o_ref):
700696
)
701697

702698
def test_store_scatter_2d(self):
703-
self.skip_if_tc_tiling()
704699
if not jtu.if_cloud_tpu_at_least(2025, 10, 31):
705700
self.skipTest("Needs a newer libtpu")
706701

0 commit comments

Comments
 (0)