Skip to content

Commit 2d66fd7

Browse files
authored
Fix BLAS_Order.RowMajor import and similar in test_cython_blas with Cython 3.1 (scikit-learn#31301)
1 parent f0c80e8 commit 2d66fd7

File tree

1 file changed

+27
-11
lines changed

1 file changed

+27
-11
lines changed

sklearn/utils/tests/test_cython_blas.py

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,8 @@
22
import pytest
33

44
from sklearn.utils._cython_blas import (
5-
ColMajor,
6-
NoTrans,
7-
RowMajor,
8-
Trans,
5+
BLAS_Order,
6+
BLAS_Trans,
97
_asum_memview,
108
_axpy_memview,
119
_copy_memview,
@@ -30,7 +28,7 @@ def _numpy_to_cython(dtype):
3028

3129

3230
RTOL = {np.float32: 1e-6, np.float64: 1e-12}
33-
ORDER = {RowMajor: "C", ColMajor: "F"}
31+
ORDER = {BLAS_Order.RowMajor: "C", BLAS_Order.ColMajor: "F"}
3432

3533

3634
def _no_op(x):
@@ -166,9 +164,15 @@ def test_rot(dtype):
166164

167165
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
168166
@pytest.mark.parametrize(
169-
"opA, transA", [(_no_op, NoTrans), (np.transpose, Trans)], ids=["NoTrans", "Trans"]
167+
"opA, transA",
168+
[(_no_op, BLAS_Trans.NoTrans), (np.transpose, BLAS_Trans.Trans)],
169+
ids=["NoTrans", "Trans"],
170+
)
171+
@pytest.mark.parametrize(
172+
"order",
173+
[BLAS_Order.RowMajor, BLAS_Order.ColMajor],
174+
ids=["RowMajor", "ColMajor"],
170175
)
171-
@pytest.mark.parametrize("order", [RowMajor, ColMajor], ids=["RowMajor", "ColMajor"])
172176
def test_gemv(dtype, opA, transA, order):
173177
gemv = _gemv_memview[_numpy_to_cython(dtype)]
174178

@@ -187,7 +191,11 @@ def test_gemv(dtype, opA, transA, order):
187191

188192

189193
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
190-
@pytest.mark.parametrize("order", [RowMajor, ColMajor], ids=["RowMajor", "ColMajor"])
194+
@pytest.mark.parametrize(
195+
"order",
196+
[BLAS_Order.RowMajor, BLAS_Order.ColMajor],
197+
ids=["BLAS_Order.RowMajor", "BLAS_Order.ColMajor"],
198+
)
191199
def test_ger(dtype, order):
192200
ger = _ger_memview[_numpy_to_cython(dtype)]
193201

@@ -207,12 +215,20 @@ def test_ger(dtype, order):
207215

208216
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
209217
@pytest.mark.parametrize(
210-
"opB, transB", [(_no_op, NoTrans), (np.transpose, Trans)], ids=["NoTrans", "Trans"]
218+
"opB, transB",
219+
[(_no_op, BLAS_Trans.NoTrans), (np.transpose, BLAS_Trans.Trans)],
220+
ids=["NoTrans", "Trans"],
221+
)
222+
@pytest.mark.parametrize(
223+
"opA, transA",
224+
[(_no_op, BLAS_Trans.NoTrans), (np.transpose, BLAS_Trans.Trans)],
225+
ids=["NoTrans", "Trans"],
211226
)
212227
@pytest.mark.parametrize(
213-
"opA, transA", [(_no_op, NoTrans), (np.transpose, Trans)], ids=["NoTrans", "Trans"]
228+
"order",
229+
[BLAS_Order.RowMajor, BLAS_Order.ColMajor],
230+
ids=["BLAS_Order.RowMajor", "BLAS_Order.ColMajor"],
214231
)
215-
@pytest.mark.parametrize("order", [RowMajor, ColMajor], ids=["RowMajor", "ColMajor"])
216232
def test_gemm(dtype, opA, transA, opB, transB, order):
217233
gemm = _gemm_memview[_numpy_to_cython(dtype)]
218234

0 commit comments

Comments
 (0)