@@ -604,7 +604,6 @@ def kernel(x_ref, indices_ref, o_ref):
604604 )
605605
606606 def test_load_gather_1d (self ):
607- self .skip_if_tc_tiling ()
608607 x = jnp .arange (8 )
609608 indices = jax .random .permutation (jax .random .key (42 ), jnp .arange (8 ))
610609
@@ -615,7 +614,6 @@ def kernel(x_ref, indices_ref, o_ref):
615614 np .testing .assert_array_equal (kernel (x , indices ), x [indices ])
616615
617616 def test_load_gather_2d (self ):
618- self .skip_if_tc_tiling ()
619617 x = jnp .arange (8 * 8 ).reshape (8 , - 1 )
620618 indices0 = indices1 = jax .random .permutation (
621619 jax .random .key (42 ), jnp .arange (8 )
@@ -649,7 +647,6 @@ def kernel(x_ref, indices_ref, o_ref):
649647
650648 @parameterized .parameters (* MASK_FNS )
651649 def test_load_gather_masked (self , mask_fn ):
652- self .skip_if_tc_tiling ()
653650 x = jnp .arange (8 )
654651 indices = jax .random .permutation (jax .random .key (42 ), jnp .arange (8 ))
655652
@@ -683,7 +680,6 @@ def kernel(x_ref, indices_ref, o_ref):
683680
684681 @parameterized .parameters (* MASK_FNS )
685682 def test_store_scatter_masked (self , mask_fn ):
686- self .skip_if_tc_tiling ()
687683 x = jnp .arange (8 )
688684 indices = jax .random .permutation (jax .random .key (42 ), jnp .arange (8 ))
689685
@@ -700,7 +696,6 @@ def kernel(x_ref, indices_ref, o_ref):
700696 )
701697
702698 def test_store_scatter_2d (self ):
703- self .skip_if_tc_tiling ()
704699 if not jtu .if_cloud_tpu_at_least (2025 , 10 , 31 ):
705700 self .skipTest ("Needs a newer libtpu" )
706701
0 commit comments