Skip to content

[IR] Create tensor adaptor for mlx #2060

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 66 additions & 5 deletions onnxscript/ir/tensor_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

__all__ = [
"TorchTensor",
"MlxTensor",
]

import ctypes
Expand All @@ -39,14 +40,20 @@

from onnxscript import ir
from onnxscript.ir import _core
import ml_dtypes

Check notice

Code scanning / lintrunner

PYLINT/C0411 Note

third party import "ml_dtypes" should be placed before first party imports "onnxscript.ir", "onnxscript.ir._core" (wrong-import-order)
See wrong-import-order. To disable, use # pylint: disable=wrong-import-order

Check warning

Code scanning / lintrunner

PYLINT/W0611 Warning

Unused import ml_dtypes (unused-import)
See unused-import. To disable, use # pylint: disable=unused-import

Check warning

Code scanning / lintrunner

RUFF/F401 Warning

ml_dtypes imported but unused.
See https://docs.astral.sh/ruff/rules/unused-import
import numpy as np

Check notice

Code scanning / lintrunner

PYLINT/C0411 Note

third party import "numpy" should be placed before first party imports "onnxscript.ir", "onnxscript.ir._core" (wrong-import-order)
See wrong-import-order. To disable, use # pylint: disable=wrong-import-order

Check notice

Code scanning / lintrunner

PYLINT/C0412 Note

Imports from package numpy are not grouped (ungrouped-imports)
See ungrouped-imports. To disable, use # pylint: disable=ungrouped-imports

if TYPE_CHECKING:
import torch
import mlx.core as mx

Check warning on line 48 in onnxscript/ir/tensor_adapters.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/tensor_adapters.py#L48

Added line #L48 was not covered by tests

Check failure

Code scanning / lintrunner

MYPY/import-not-found Error

Cannot find implementation or library stub for module named "mlx.core" To disable, use # type: ignore[import-not-found]

Check failure

Code scanning / lintrunner

MYPY/import-not-found Error

Cannot find implementation or library stub for module named "mlx" To disable, use # type: ignore[import-not-found]


class TorchTensor(_core.Tensor):
def __init__(
self, tensor: torch.Tensor, name: str | None = None, doc_string: str | None = None
self,
tensor: torch.Tensor,
name: str | None = None,
doc_string: str | None = None,
):
# Pass the tensor as the raw data to ir.Tensor's constructor
import torch
Expand All @@ -73,23 +80,33 @@
torch.uint64: ir.DataType.UINT64,
}
super().__init__(
tensor, dtype=_TORCH_DTYPE_TO_ONNX[tensor.dtype], name=name, doc_string=doc_string
tensor,
dtype=_TORCH_DTYPE_TO_ONNX[tensor.dtype],
name=name,
doc_string=doc_string,
)

def numpy(self) -> npt.NDArray:
import torch

self.raw: torch.Tensor
if self.dtype == ir.DataType.BFLOAT16:
return self.raw.view(torch.uint16).numpy(force=True)
return (
self.raw.view(torch.uint16)
.numpy(force=True)
.view(dtype=self.dtype.numpy())
)
if self.dtype in {
ir.DataType.FLOAT8E4M3FN,
ir.DataType.FLOAT8E4M3FNUZ,
ir.DataType.FLOAT8E5M2,
ir.DataType.FLOAT8E5M2FNUZ,
}:
# TODO: Use ml_dtypes
return self.raw.view(torch.uint8).numpy(force=True)
return (
self.raw.view(torch.uint8)
.numpy(force=True)
.view(dtype=self.dtype.numpy())
)
return self.raw.numpy(force=True)

def __array__(self, dtype: Any = None, copy: bool | None = None) -> npt.NDArray:
Expand Down Expand Up @@ -120,3 +137,47 @@
tensor.data_ptr()
)
)


class MlxTensor(_core.Tensor):
    """An ``ir.Tensor`` that wraps an ``mlx.core.array`` without eagerly copying it.

    The MLX array is stored as the tensor's raw data; conversion to a NumPy
    array is deferred until :meth:`numpy` (or ``__array__``) is called.
    """

    def __init__(
        self, tensor: mx.array, name: str | None = None, doc_string: str | None = None
    ):
        """Initialize the tensor from an MLX array.

        Args:
            tensor: The MLX array to wrap.
            name: Optional name for the tensor.
            doc_string: Optional documentation string.

        Raises:
            KeyError: If the MLX dtype has no ONNX equivalent.
        """
        # Imported lazily so this module stays importable when mlx is not installed.
        import mlx.core as mx

        # Every MLX dtype that has a direct ONNX counterpart. MLX has no
        # float64 or float8 types, so those ONNX dtypes are absent here.
        _MLX_DTYPE_TO_ONNX: dict[mx.Dtype, ir.DataType] = {
            mx.bfloat16: ir.DataType.BFLOAT16,
            mx.complex64: ir.DataType.COMPLEX64,
            mx.float16: ir.DataType.FLOAT16,
            mx.float32: ir.DataType.FLOAT,
            mx.int16: ir.DataType.INT16,
            mx.int32: ir.DataType.INT32,
            mx.int64: ir.DataType.INT64,
            mx.int8: ir.DataType.INT8,
            mx.uint8: ir.DataType.UINT8,
            mx.uint16: ir.DataType.UINT16,
            mx.uint32: ir.DataType.UINT32,
            mx.uint64: ir.DataType.UINT64,
        }
        # Pass the MLX array as the raw data to ir.Tensor's constructor.
        super().__init__(
            tensor,
            dtype=_MLX_DTYPE_TO_ONNX[tensor.dtype],
            name=name,
            doc_string=doc_string,
        )

    def numpy(self) -> npt.NDArray:
        """Return the tensor's content as a NumPy array.

        bfloat16 has no native NumPy representation, so its bits are
        reinterpreted as uint16 and then viewed with the IR's numpy dtype
        (backed by ml_dtypes).
        """
        import mlx.core as mx

        self.raw: mx.array
        if self.dtype == ir.DataType.BFLOAT16:
            # Bit-level reinterpretation: no numeric conversion takes place.
            return np.asarray(self.raw.view(mx.uint16)).view(dtype=self.dtype.numpy())
        # np.asarray avoids a copy when the underlying buffer is shareable,
        # but — unlike np.array(..., copy=False) under NumPy 2, which raises
        # ValueError when a copy is required — it still succeeds by copying.
        return np.asarray(self.raw)

    def __array__(self, dtype: Any = None, copy: bool | None = None) -> npt.NDArray:
        """Support ``np.array(tensor)`` and ``np.asarray(tensor, dtype=...)``."""
        del copy  # Unused, but needed for the signature
        if dtype is None:
            return self.numpy()
        return self.numpy().__array__(dtype)
Loading