Skip to content

Commit 002ca3e

Browse files
authored
ENH: MLIR backend POC (#755)
1 parent df5b27c commit 002ca3e

File tree

12 files changed

+602
-2
lines changed

12 files changed

+602
-2
lines changed

.github/workflows/ci.yml

+28
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,31 @@ jobs:
5454
with:
5555
files: ./**/coverage*.xml
5656

57+
test_mlir:
58+
runs-on: ubuntu-latest
59+
steps:
60+
- name: Checkout Repo
61+
uses: actions/checkout@v4
62+
- name: Setup Conda
63+
uses: conda-incubator/setup-miniconda@v3
64+
with:
65+
python-version: '3.10'
66+
channels: conda-forge
67+
activate-environment: sparse-dev
68+
miniforge-variant: Mambaforge
69+
miniforge-version: latest
70+
use-mamba: true
71+
- name: Update Conda Environment
72+
run: |
73+
mamba env update -n sparse-dev -f ci/environment.yml
74+
mamba run pip install '.[tests]'
75+
mamba install conda-forge::mlir-python-bindings
76+
- name: Build and run tests
77+
shell: bash -l {0}
78+
run: |
79+
conda activate sparse-dev
80+
SPARSE_BACKEND=MLIR pytest sparse/mlir_backend -v
81+
5782
examples:
5883
runs-on: ubuntu-latest
5984
steps:
@@ -71,6 +96,7 @@ jobs:
7196
- name: Run examples
7297
run: |
7398
source ci/test_examples.sh
99+
74100
notebooks:
75101
runs-on: ubuntu-latest
76102
steps:
@@ -87,6 +113,7 @@ jobs:
87113
- name: Run notebooks
88114
run: |
89115
source ci/test_notebooks.sh
116+
90117
array_api_tests:
91118
strategy:
92119
matrix:
@@ -121,6 +148,7 @@ jobs:
121148
run: |
122149
cd ${GITHUB_WORKSPACE}/array-api-tests
123150
pytest array_api_tests -v -c pytest.ini -n 4 --max-examples=2 --derandomize --disable-deadline -o xfail_strict=True --xfails-file ${GITHUB_WORKSPACE}/ci/${{ matrix.backend }}-array-api-xfails.txt --skips-file ${GITHUB_WORKSPACE}/ci/${{ matrix.backend }}-array-api-skips.txt
151+
124152
on:
125153
# Trigger the workflow on push or pull request,
126154
# but only for the main branch

conftest.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ def add_doctest_modules(doctest_namespace):
88
import numpy as np
99

1010
if sparse._BackendType.Numba != sparse._BACKEND:
11-
pytest.skip()
11+
pass # TODO: pytest.skip() skips Finch and MLIR tests
1212

1313
doctest_namespace["np"] = np
1414
doctest_namespace["sparse"] = sparse

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ tests = [
4646
"pre-commit",
4747
"scipy",
4848
"sparse[finch]",
49-
"pytest-codspeed"
49+
"pytest-codspeed",
5050
]
5151
tox = ["sparse[tests]", "tox"]
5252
notebooks = ["sparse[tests]", "nbmake", "matplotlib"]

sparse/__init__.py

+4
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
class _BackendType(Enum):
1111
Numba = "Numba"
1212
Finch = "Finch"
13+
MLIR = "MLIR"
1314

1415

1516
_ENV_VAR_NAME = "SPARSE_BACKEND"
@@ -40,6 +41,9 @@ class SparseFutureWarning(FutureWarning):
4041
if _BackendType.Finch == _BACKEND:
4142
from sparse.finch_backend import * # noqa: F403
4243
from sparse.finch_backend import __all__
44+
elif _BackendType.MLIR == _BACKEND:
45+
from sparse.mlir_backend import * # noqa: F403
46+
from sparse.mlir_backend import __all__
4347
else:
4448
from sparse.numba_backend import * # noqa: F403
4549
from sparse.numba_backend import ( # noqa: F401

sparse/mlir_backend/__init__.py

+20
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# The backend cannot work without the upstream MLIR Python bindings, so turn a
# missing `mlir` package into an ImportError that tells the user how to get it.
try:
    import mlir  # noqa: F401
except ModuleNotFoundError as exc:
    raise ImportError(
        "MLIR Python bindings not installed. Run "
        "`conda install conda-forge::mlir-python-bindings` "
        "to enable MLIR backend."
    ) from exc

from ._constructors import asarray
from ._ops import add

# Names exported by `from sparse.mlir_backend import *`.
__all__ = ["add", "asarray"]

sparse/mlir_backend/_constructors.py

+257
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,257 @@
1+
import ctypes
2+
import ctypes.util
3+
4+
import mlir.execution_engine
5+
import mlir.passmanager
6+
from mlir import ir
7+
from mlir.dialects import arith, bufferization, func, sparse_tensor, tensor
8+
9+
import numpy as np
10+
import scipy.sparse as sps
11+
12+
from ._core import DEBUG, MLIR_C_RUNNER_UTILS, SCRIPT_PATH, ctx
13+
from ._dtypes import DType, Float64, Index
14+
from ._memref import MemrefF64_1D, MemrefIdx_1D
15+
16+
17+
class Tensor:
    """
    Handle to a tensor that was assembled by a compiled module.

    Parameters
    ----------
    obj
        ctypes object referring to the assembled tensor. A pointer to it is
        passed to the module's ``free_tensor`` entry point on finalization.
    module
        Object exposing ``invoke(name, *args)``; it must provide a
        ``free_tensor`` entry point.
    tensor_type
        Type descriptor of the assembled tensor (stored, not used here).
    disassemble_fn
        Callable invoked as ``disassemble_fn(module, obj)`` to convert the
        tensor back into a host-side object.
    values_dtype, index_dtype
        Dtype descriptors of the stored values and indices (stored, not used
        here).
    """

    def __init__(self, obj, module, tensor_type, disassemble_fn, values_dtype, index_dtype):
        self.obj = obj
        self.module = module
        self.tensor_type = tensor_type
        self.disassemble_fn = disassemble_fn
        self.values_dtype = values_dtype
        self.index_dtype = index_dtype

    def __del__(self):
        # `__del__` also runs for instances whose `__init__` never ran or
        # failed before assigning attributes (e.g. a call with missing
        # arguments). There is nothing to free in that case, so bail out
        # instead of raising AttributeError from the finalizer.
        module = getattr(self, "module", None)
        obj = getattr(self, "obj", None)
        if module is None or obj is None:
            return
        module.invoke("free_tensor", ctypes.pointer(obj))

    def to_scipy_sparse(self):
        """
        Returns scipy.sparse or ndarray
        """
        return self.disassemble_fn(self.module, self.obj)
34+
35+
36+
class DenseFormat:
    """
    Helpers for moving 2-D NumPy arrays in and out of an MLIR ``sparse_tensor``
    whose two levels are both dense.
    """

    # Cache of the ``(engine, tensor_type)`` pairs returned by `get_module`;
    # filled in by `asarray`.
    modules = {}

    @staticmethod
    def get_module(shape: tuple[int, ...], values_dtype: DType, index_dtype: DType):
        """
        Build and compile the ``assemble``, ``disassemble`` and ``free_tensor``
        functions for a dense tensor of the given static shape.

        Parameters
        ----------
        shape
            Static tensor shape. Only ``shape[0]`` and ``shape[1]`` are read
            and the level ordering is a permutation of two dimensions, so a
            2-D shape is expected.
        values_dtype, index_dtype
            Dtype wrappers; ``.get()`` is called on each to obtain the MLIR
            type.

        Returns
        -------
        tuple
            ``(execution_engine, tensor_type)`` where ``tensor_type`` is the
            encoded ranked tensor type the functions were built for.
        """
        with ir.Location.unknown(ctx):
            module = ir.Module.create()
            values_dtype = values_dtype.get()
            index_dtype = index_dtype.get()
            # Falls back to 0 when the index type has no `width` attribute.
            index_width = getattr(index_dtype, "width", 0)
            levels = (sparse_tensor.LevelType.dense, sparse_tensor.LevelType.dense)
            ordering = ir.AffineMap.get_permutation([0, 1])
            encoding = sparse_tensor.EncodingAttr.get(levels, ordering, ordering, index_width, index_width)
            dense_shaped = ir.RankedTensorType.get(list(shape), values_dtype, encoding)
            tensor_1d = tensor.RankedTensorType.get([ir.ShapedType.get_dynamic_size()], values_dtype)

            with ir.InsertionPoint(module.body):

                @func.FuncOp.from_py_func(tensor_1d)
                def assemble(data):
                    return sparse_tensor.assemble(dense_shaped, data, [])

                @func.FuncOp.from_py_func(dense_shaped)
                def disassemble(tensor_shaped):
                    data = tensor.EmptyOp([arith.constant(ir.IndexType.get(), 0)], values_dtype)
                    data, data_len = sparse_tensor.disassemble(
                        tensor_1d,
                        [],
                        index_dtype,
                        [],
                        tensor_shaped,
                        data,
                        [],
                    )
                    # The static shape is returned alongside the buffer so the
                    # host side can reshape the flat data.
                    shape_x = arith.constant(index_dtype, shape[0])
                    shape_y = arith.constant(index_dtype, shape[1])
                    return data, data_len, shape_x, shape_y

                @func.FuncOp.from_py_func(dense_shaped)
                def free_tensor(tensor_shaped):
                    bufferization.dealloc_tensor(tensor_shaped)

            # The three functions are later called by name through `invoke`.
            assemble.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
            disassemble.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
            free_tensor.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
            if DEBUG:
                (SCRIPT_PATH / "dense_module.mlir").write_text(str(module))
            pm = mlir.passmanager.PassManager.parse("builtin.module(sparsifier{create-sparse-deallocs=1})")
            pm.run(module.operation)
            if DEBUG:
                (SCRIPT_PATH / "dense_module_opt.mlir").write_text(str(module))

        module = mlir.execution_engine.ExecutionEngine(module, opt_level=2, shared_libs=[MLIR_C_RUNNER_UTILS])
        return (module, dense_shaped)

    @classmethod
    def assemble(cls, module: mlir.execution_engine.ExecutionEngine, arr: np.ndarray) -> ctypes.c_void_p:
        """
        Invoke the compiled ``assemble`` function on the flattened contents of
        ``arr`` and return the resulting tensor handle.

        Parameters
        ----------
        module
            Execution engine returned by `get_module`.
        arr
            Array to hand over. It is flattened (copied) first.
            NOTE(review): the buffer is wrapped as an f64 memref; presumably
            ``arr`` must already be float64 - confirm against
            ``MemrefF64_1D.from_numpy``.
        """
        # Keep the flattened copy bound to a local for the duration of the
        # native call so its buffer cannot be collected while still in use.
        flat = arr.flatten()
        data = MemrefF64_1D.from_numpy(flat)
        out = ctypes.c_void_p()
        module.invoke(
            "assemble",
            ctypes.pointer(ctypes.pointer(data)),
            ctypes.pointer(out),
        )
        return out

    @classmethod
    def disassemble(cls, module: mlir.execution_engine.ExecutionEngine, ptr: ctypes.c_void_p) -> np.ndarray:
        """
        Invoke the compiled ``disassemble`` function on ``ptr`` and return the
        contents as a new 2-D NumPy array.

        Parameters
        ----------
        module
            Execution engine returned by `get_module`.
        ptr
            Tensor handle previously produced by `assemble`.
        """

        # Field order mirrors the return order of the compiled `disassemble`
        # function: data, data_len, shape_x, shape_y.
        class Dense(ctypes.Structure):
            _fields_ = [
                ("data", MemrefF64_1D),
                ("data_len", np.ctypeslib.c_intp),
                ("shape_x", np.ctypeslib.c_intp),
                ("shape_y", np.ctypeslib.c_intp),
            ]

            def to_np(self) -> np.ndarray:
                data = self.data.to_numpy()[: self.data_len]
                # Copy so the result does not alias the tensor's buffer.
                return data.copy().reshape((self.shape_x, self.shape_y))

        arr = Dense()
        module.invoke(
            "disassemble",
            ctypes.pointer(ctypes.pointer(arr)),
            ctypes.pointer(ptr),
        )
        return arr.to_np()
122+
123+
124+
class COOFormat:
    """Placeholder for the COO storage format."""

    # TODO: implement
    modules: dict = {}
127+
128+
129+
class CSRFormat:
    """
    Helpers for moving 2-D ``scipy.sparse`` CSR arrays in and out of an MLIR
    ``sparse_tensor`` with a dense first level and a compressed second level.
    """

    # Cache of the ``(engine, tensor_type)`` pairs returned by `get_module`;
    # filled in by `asarray`.
    modules = {}

    @staticmethod
    def get_module(shape: tuple[int, ...], values_dtype: DType, index_dtype: DType):
        """
        Build and compile the ``assemble``, ``disassemble`` and ``free_tensor``
        functions for a CSR tensor of the given static shape.

        Parameters
        ----------
        shape
            Static tensor shape. Only ``shape[0]`` and ``shape[1]`` are read
            and the level ordering is a permutation of two dimensions, so a
            2-D shape is expected.
        values_dtype, index_dtype
            Dtype wrappers; ``.get()`` is called on each to obtain the MLIR
            type.

        Returns
        -------
        tuple
            ``(execution_engine, tensor_type)`` where ``tensor_type`` is the
            encoded ranked tensor type the functions were built for.
        """
        with ir.Location.unknown(ctx):
            module = ir.Module.create()
            values_dtype = values_dtype.get()
            index_dtype = index_dtype.get()
            # Falls back to 0 when the index type has no `width` attribute.
            index_width = getattr(index_dtype, "width", 0)
            levels = (sparse_tensor.LevelType.dense, sparse_tensor.LevelType.compressed)
            ordering = ir.AffineMap.get_permutation([0, 1])
            encoding = sparse_tensor.EncodingAttr.get(levels, ordering, ordering, index_width, index_width)
            csr_shaped = ir.RankedTensorType.get(list(shape), values_dtype, encoding)

            tensor_1d_index = tensor.RankedTensorType.get([ir.ShapedType.get_dynamic_size()], index_dtype)
            tensor_1d_values = tensor.RankedTensorType.get([ir.ShapedType.get_dynamic_size()], values_dtype)

            with ir.InsertionPoint(module.body):

                @func.FuncOp.from_py_func(tensor_1d_index, tensor_1d_index, tensor_1d_values)
                def assemble(pos, crd, data):
                    return sparse_tensor.assemble(csr_shaped, data, (pos, crd))

                @func.FuncOp.from_py_func(csr_shaped)
                def disassemble(tensor_shaped):
                    pos = tensor.EmptyOp([arith.constant(ir.IndexType.get(), 0)], index_dtype)
                    crd = tensor.EmptyOp([arith.constant(ir.IndexType.get(), 0)], index_dtype)
                    data = tensor.EmptyOp([arith.constant(ir.IndexType.get(), 0)], values_dtype)
                    data, pos, crd, data_len, pos_len, crd_len = sparse_tensor.disassemble(
                        tensor_1d_values,
                        (tensor_1d_index, tensor_1d_index),
                        index_dtype,
                        (index_dtype, index_dtype),
                        tensor_shaped,
                        data,
                        (pos, crd),
                    )
                    # The static shape is returned alongside the buffers so
                    # the host side can rebuild the array.
                    shape_x = arith.constant(index_dtype, shape[0])
                    shape_y = arith.constant(index_dtype, shape[1])
                    return data, pos, crd, data_len, pos_len, crd_len, shape_x, shape_y

                @func.FuncOp.from_py_func(csr_shaped)
                def free_tensor(tensor_shaped):
                    bufferization.dealloc_tensor(tensor_shaped)

            # The three functions are later called by name through `invoke`.
            assemble.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
            disassemble.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
            free_tensor.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
            if DEBUG:
                # Named to pair with "csr_module_opt.mlir" written below.
                (SCRIPT_PATH / "csr_module.mlir").write_text(str(module))
            pm = mlir.passmanager.PassManager.parse("builtin.module(sparsifier{create-sparse-deallocs=1})")
            pm.run(module.operation)
            if DEBUG:
                (SCRIPT_PATH / "csr_module_opt.mlir").write_text(str(module))

        module = mlir.execution_engine.ExecutionEngine(module, opt_level=2, shared_libs=[MLIR_C_RUNNER_UTILS])
        return (module, csr_shaped)

    @classmethod
    def assemble(cls, module: mlir.execution_engine.ExecutionEngine, arr: sps.csr_array) -> ctypes.c_void_p:
        """
        Invoke the compiled ``assemble`` function on the ``indptr``,
        ``indices`` and ``data`` buffers of ``arr`` and return the resulting
        tensor handle.

        Parameters
        ----------
        module
            Execution engine returned by `get_module`.
        arr
            CSR array to hand over.
            NOTE(review): the buffers are wrapped as index/f64 memrefs as-is;
            presumably their dtypes must already match - confirm against
            ``MemrefIdx_1D.from_numpy`` / ``MemrefF64_1D.from_numpy``.
        """
        out = ctypes.c_void_p()
        module.invoke(
            "assemble",
            ctypes.pointer(ctypes.pointer(MemrefIdx_1D.from_numpy(arr.indptr))),
            ctypes.pointer(ctypes.pointer(MemrefIdx_1D.from_numpy(arr.indices))),
            ctypes.pointer(ctypes.pointer(MemrefF64_1D.from_numpy(arr.data))),
            ctypes.pointer(out),
        )
        return out

    @classmethod
    def disassemble(cls, module: mlir.execution_engine.ExecutionEngine, ptr: ctypes.c_void_p) -> sps.csr_array:
        """
        Invoke the compiled ``disassemble`` function on ``ptr`` and return the
        contents as a new ``scipy.sparse.csr_array``.

        Parameters
        ----------
        module
            Execution engine returned by `get_module`.
        ptr
            Tensor handle previously produced by `assemble`.
        """

        # Field order mirrors the return order of the compiled `disassemble`
        # function.
        class Csr(ctypes.Structure):
            _fields_ = [
                ("data", MemrefF64_1D),
                ("pos", MemrefIdx_1D),
                ("crd", MemrefIdx_1D),
                ("data_len", np.ctypeslib.c_intp),
                ("pos_len", np.ctypeslib.c_intp),
                ("crd_len", np.ctypeslib.c_intp),
                ("shape_x", np.ctypeslib.c_intp),
                ("shape_y", np.ctypeslib.c_intp),
            ]

            def to_sps(self) -> sps.csr_array:
                pos = self.pos.to_numpy()[: self.pos_len]
                crd = self.crd.to_numpy()[: self.crd_len]
                data = self.data.to_numpy()[: self.data_len]
                # Copies so the result does not alias the tensor's buffers.
                return sps.csr_array((data.copy(), crd.copy(), pos.copy()), shape=(self.shape_x, self.shape_y))

        arr = Csr()
        module.invoke(
            "disassemble",
            ctypes.pointer(ctypes.pointer(arr)),
            ctypes.pointer(ptr),
        )
        return arr.to_sps()
226+
227+
228+
def _is_scipy_sparse_obj(x) -> bool:
229+
return hasattr(x, "__module__") and x.__module__.startswith("scipy.sparse")
230+
231+
232+
def _is_numpy_obj(x) -> bool:
233+
return isinstance(x, np.ndarray)
234+
235+
236+
def asarray(obj) -> Tensor:
    """
    Assemble ``obj`` into a `Tensor` backed by a compiled MLIR module.

    Parameters
    ----------
    obj
        A ``scipy.sparse`` object (handled as CSR) or a ``numpy.ndarray``
        (handled as dense).

    Returns
    -------
    Tensor
        Handle owning the assembled tensor.

    Raises
    ------
    TypeError
        If ``obj`` is neither a scipy sparse object nor a NumPy array.
    """
    # TODO: discover obj's dtype
    values_dtype = Float64
    index_dtype = Index

    # TODO: support other scipy formats
    if _is_scipy_sparse_obj(obj):
        format_class = CSRFormat
    elif _is_numpy_obj(obj):
        format_class = DenseFormat
    else:
        raise TypeError(f"{type(obj)} not supported.")

    # TODO: support proper caching
    # Key on the shape itself rather than on `hash(shape)`: two different
    # shapes whose hashes collide must not share a module compiled for the
    # wrong static shape.
    cache_key = tuple(obj.shape)
    cached = format_class.modules.get(cache_key)
    if cached is None:
        cached = format_class.get_module(obj.shape, values_dtype, index_dtype)
        format_class.modules[cache_key] = cached
    module, tensor_type = cached

    assembled_obj = format_class.assemble(module, obj)
    return Tensor(assembled_obj, module, tensor_type, format_class.disassemble, values_dtype, index_dtype)

sparse/mlir_backend/_core.py

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import ctypes

# `ctypes.util` is a submodule that `import ctypes` does not load; import it
# explicitly so `find_library` below does not depend on some other module
# having imported it first.
import ctypes.util
import os
import pathlib

from mlir.ir import Context

# Set the DEBUG environment variable to a non-zero integer to enable debug
# behaviour in modules that import this flag.
DEBUG = bool(int(os.environ.get("DEBUG", "0")))
# Directory containing this file.
SCRIPT_PATH = pathlib.Path(__file__).parent

# Both lookups return None when the library cannot be located.
MLIR_C_RUNNER_UTILS = ctypes.util.find_library("mlir_c_runner_utils")
libc = ctypes.CDLL(ctypes.util.find_library("c"))
libc.free.argtypes = [ctypes.c_void_p]
libc.free.restype = None

# TODO: remove global state
ctx = Context()

0 commit comments

Comments
 (0)