|
4 | 4 | import pytest |
5 | 5 |
|
6 | 6 | from tpu_inference.models.jax.utils.multi_modal_utils import ( |
7 | | - MultiModalEmbeddings, NestedTensors, _flatten_embeddings, |
| 7 | + MultiModalEmbeddings, NestedTensors, flatten_embeddings, |
8 | 8 | merge_multimodal_embeddings, sanity_check_mm_encoder_outputs) |
9 | 9 |
|
10 | 10 | # --- Tests for sanity_check_mm_encoder_outputs --- |
@@ -65,45 +65,45 @@ def test_sanity_check_wrong_dimensions_in_list(): |
65 | 65 | sanity_check_mm_encoder_outputs(embeddings, 1) |
66 | 66 |
|
67 | 67 |
|
68 | | -# --- Tests for _flatten_embeddings --- |
| 68 | +# --- Tests for flatten_embeddings --- |
69 | 69 |
|
70 | 70 |
|
71 | 71 | def test_flatten_single_array(): |
72 | | - """Tests _flatten_embeddings with a single 2D array.""" |
| 72 | + """Tests flatten_embeddings with a single 2D array.""" |
73 | 73 | emb: NestedTensors = jnp.arange(12).reshape((3, 4)) |
74 | | - result = _flatten_embeddings(emb) |
| 74 | + result = flatten_embeddings(emb) |
75 | 75 | np.testing.assert_array_equal(result, emb) |
76 | 76 |
|
77 | 77 |
|
78 | 78 | def test_flatten_single_3d_array(): |
79 | | - """Tests _flatten_embeddings with a single 3D array.""" |
| 79 | + """Tests flatten_embeddings with a single 3D array.""" |
80 | 80 | emb: NestedTensors = jnp.arange(24).reshape((2, 3, 4)) |
81 | | - result = _flatten_embeddings(emb) |
| 81 | + result = flatten_embeddings(emb) |
82 | 82 | expected = jnp.arange(24).reshape((6, 4)) |
83 | 83 | np.testing.assert_array_equal(result, expected) |
84 | 84 |
|
85 | 85 |
|
86 | 86 | def test_flatten_list_of_arrays(): |
87 | | - """Tests _flatten_embeddings with a list of 2D arrays.""" |
| 87 | + """Tests flatten_embeddings with a list of 2D arrays.""" |
88 | 88 | emb: NestedTensors = [ |
89 | 89 | jnp.arange(12).reshape((3, 4)), |
90 | 90 | jnp.arange(12, 20).reshape((2, 4)) |
91 | 91 | ] |
92 | | - result = _flatten_embeddings(emb) |
| 92 | + result = flatten_embeddings(emb) |
93 | 93 | expected = jnp.arange(20).reshape((5, 4)) |
94 | 94 | np.testing.assert_array_equal(result, expected) |
95 | 95 |
|
96 | 96 |
|
97 | 97 | def test_flatten_nested_list(): |
98 | | - """Tests _flatten_embeddings with a nested list of arrays.""" |
| 98 | + """Tests flatten_embeddings with a nested list of arrays.""" |
99 | 99 | emb: NestedTensors = [ |
100 | 100 | jnp.arange(6).reshape((2, 3)), |
101 | 101 | [ |
102 | 102 | jnp.arange(6, 12).reshape((2, 3)), |
103 | 103 | jnp.arange(12, 15).reshape((1, 3)) |
104 | 104 | ] |
105 | 105 | ] |
106 | | - result = _flatten_embeddings(emb) |
| 106 | + result = flatten_embeddings(emb) |
107 | 107 | expected = jnp.arange(15).reshape((5, 3)) |
108 | 108 | np.testing.assert_array_equal(result, expected) |
109 | 109 |
|
@@ -191,7 +191,7 @@ def test_merge_mm_embeds_count_too_many_no_raise(placeholder_id, base_embeds): |
191 | 191 | # Check that the first 2 embeddings from mm_embeds_too_many were used. |
192 | 192 | expected = np.array(inputs_embeds) |
193 | 193 | is_mm = np.isin(input_ids, np.array(placeholder_id)) |
194 | | - expected[is_mm] = _flatten_embeddings(mm_embeds_too_many)[:2] |
| 194 | + expected[is_mm] = flatten_embeddings(mm_embeds_too_many)[:2] |
195 | 195 | np.testing.assert_array_equal(result, expected) |
196 | 196 | except Exception as e: |
197 | 197 | pytest.fail( |
|
0 commit comments