Skip to content

Commit df5b27c

Browse files
authored
test: Converting elemwise_example.py to CodSpeed benchmark (#749)
1 parent 09adc98 commit df5b27c

File tree

1 file changed

+70
-0
lines changed

1 file changed

+70
-0
lines changed

benchmarks/test_elemwise.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
import importlib
2+
import itertools
3+
import operator
4+
import os
5+
6+
import sparse
7+
8+
import pytest
9+
10+
import numpy as np
11+
import scipy.sparse as sps
12+
13+
DENSITY = 0.001
14+
15+
16+
def get_test_id(side):
17+
return f"{side=}"
18+
19+
20+
@pytest.fixture(params=[100, 500, 1000], ids=get_test_id)
21+
def elemwise_args(request, seed, max_size):
22+
side = request.param
23+
if side**2 >= max_size:
24+
pytest.skip()
25+
rng = np.random.default_rng(seed=seed)
26+
s1_sps = sps.random(side, side, format="csr", density=DENSITY, random_state=rng) * 10
27+
s1_sps.sum_duplicates()
28+
s2_sps = sps.random(side, side, format="csr", density=DENSITY, random_state=rng) * 10
29+
s2_sps.sum_duplicates()
30+
return s1_sps, s2_sps
31+
32+
33+
def get_elemwise_id(param):
34+
f, backend = param
35+
return f"{f=}-{backend=}"
36+
37+
38+
@pytest.fixture(
39+
params=itertools.product([operator.add, operator.mul, operator.gt], ["SciPy", "Numba", "Finch"]),
40+
scope="function",
41+
ids=get_elemwise_id,
42+
)
43+
def backend(request):
44+
f, backend = request.param
45+
os.environ[sparse._ENV_VAR_NAME] = backend
46+
importlib.reload(sparse)
47+
yield f, sparse, backend
48+
del os.environ[sparse._ENV_VAR_NAME]
49+
importlib.reload(sparse)
50+
51+
52+
def test_elemwise(benchmark, backend, elemwise_args):
53+
s1_sps, s2_sps = elemwise_args
54+
f, sparse, backend = backend
55+
56+
if backend == "SciPy":
57+
s1 = s1_sps
58+
s2 = s2_sps
59+
elif backend == "Numba":
60+
s1 = sparse.asarray(s1_sps)
61+
s2 = sparse.asarray(s2_sps)
62+
elif backend == "Finch":
63+
s1 = sparse.asarray(s1_sps.asformat("csc"), format="csc")
64+
s2 = sparse.asarray(s2_sps.asformat("csc"), format="csc")
65+
66+
f(s1, s2)
67+
68+
@benchmark
69+
def bench():
70+
f(s1, s2)

0 commit comments

Comments
 (0)