-
Notifications
You must be signed in to change notification settings - Fork 63
[IR] Create tensor adaptor for mlx #2060
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -30,6 +30,7 @@ | |
|
||
__all__ = [ | ||
"TorchTensor", | ||
"MlxTensor", | ||
] | ||
|
||
import ctypes | ||
|
@@ -39,14 +40,20 @@ | |
|
||
from onnxscript import ir | ||
from onnxscript.ir import _core | ||
import ml_dtypes | ||
Check warningCode scanning / lintrunner PYLINT/W0611 Warning
Unused import ml_dtypes (unused-import)
See unused-import. To disable, use # pylint: disable=unused-import Check warningCode scanning / lintrunner RUFF/F401 Warning
ml_dtypes imported but unused.
See https://docs.astral.sh/ruff/rules/unused-import |
||
import numpy as np | ||
Check noticeCode scanning / lintrunner PYLINT/C0411 Note
third party import "numpy" should be placed before first party imports "onnxscript.ir", "onnxscript.ir._core" (wrong-import-order)
See wrong-import-order. To disable, use # pylint: disable=wrong-import-order Check noticeCode scanning / lintrunner PYLINT/C0412 Note
Imports from package numpy are not grouped (ungrouped-imports)
See ungrouped-imports. To disable, use # pylint: disable=ungrouped-imports |
||
|
||
if TYPE_CHECKING: | ||
import torch | ||
import mlx.core as mx | ||
Check failureCode scanning / lintrunner MYPY/import-not-found Error
Cannot find implementation or library stub for module named "mlx.core"
To disable, use # type: ignore[import-not-found]
Check failureCode scanning / lintrunner MYPY/import-not-found Error
Cannot find implementation or library stub for module named "mlx"
To disable, use # type: ignore[import-not-found]
|
||
|
||
|
||
class TorchTensor(_core.Tensor): | ||
def __init__( | ||
self, tensor: torch.Tensor, name: str | None = None, doc_string: str | None = None | ||
self, | ||
tensor: torch.Tensor, | ||
name: str | None = None, | ||
doc_string: str | None = None, | ||
): | ||
# Pass the tensor as the raw data to ir.Tensor's constructor | ||
import torch | ||
|
@@ -73,23 +80,33 @@ | |
torch.uint64: ir.DataType.UINT64, | ||
} | ||
super().__init__( | ||
tensor, dtype=_TORCH_DTYPE_TO_ONNX[tensor.dtype], name=name, doc_string=doc_string | ||
tensor, | ||
dtype=_TORCH_DTYPE_TO_ONNX[tensor.dtype], | ||
name=name, | ||
doc_string=doc_string, | ||
) | ||
|
||
def numpy(self) -> npt.NDArray:
    """Return the wrapped torch tensor as a NumPy array.

    ``force=True`` detaches the tensor and moves it to CPU if needed.
    Dtypes NumPy has no native representation for (bfloat16 and the
    float8 family) are read out through an integer view of the same bit
    width and then relabeled with the matching ``ml_dtypes``-backed
    dtype obtained from ``self.dtype.numpy()``.
    """
    import torch

    self.raw: torch.Tensor
    if self.dtype == ir.DataType.BFLOAT16:
        # Reinterpret the 16-bit payload, then relabel as bfloat16.
        bits = self.raw.view(torch.uint16).numpy(force=True)
        return bits.view(dtype=self.dtype.numpy())
    if self.dtype in (
        ir.DataType.FLOAT8E4M3FN,
        ir.DataType.FLOAT8E4M3FNUZ,
        ir.DataType.FLOAT8E5M2,
        ir.DataType.FLOAT8E5M2FNUZ,
    ):
        # 8-bit float variants: same trick via a uint8 view.
        bits = self.raw.view(torch.uint8).numpy(force=True)
        return bits.view(dtype=self.dtype.numpy())
    # Everything else converts directly.
    return self.raw.numpy(force=True)
|
||
def __array__(self, dtype: Any = None, copy: bool | None = None) -> npt.NDArray: | ||
|
@@ -120,3 +137,47 @@ | |
tensor.data_ptr() | ||
) | ||
) | ||
|
||
|
||
class MlxTensor(_core.Tensor):
    """An ``ir.Tensor`` adapter around an ``mlx.core.array``.

    The MLX array is passed straight through to :class:`_core.Tensor`
    as the raw backing data; conversion to NumPy is deferred to
    :meth:`numpy` / :meth:`__array__`.
    """

    def __init__(
        self, tensor: mx.array, name: str | None = None, doc_string: str | None = None
    ):
        # Lazy import: mlx is only required when this adapter is used.
        import mlx.core as mx

        # MLX dtype -> ONNX IR dtype. Any dtype missing here (e.g. bool)
        # raises KeyError below — presumably intentional; confirm coverage
        # against the MLX dtype list.
        dtype_map: dict[mx.Dtype, ir.DataType] = {
            mx.bfloat16: ir.DataType.BFLOAT16,
            mx.float16: ir.DataType.FLOAT16,
            mx.float32: ir.DataType.FLOAT,
            mx.complex64: ir.DataType.COMPLEX64,
            mx.int8: ir.DataType.INT8,
            mx.int16: ir.DataType.INT16,
            mx.int32: ir.DataType.INT32,
            mx.int64: ir.DataType.INT64,
            mx.uint8: ir.DataType.UINT8,
            mx.uint16: ir.DataType.UINT16,
            mx.uint32: ir.DataType.UINT32,
            mx.uint64: ir.DataType.UINT64,
        }
        super().__init__(
            tensor,
            dtype=dtype_map[tensor.dtype],
            name=name,
            doc_string=doc_string,
        )

    def numpy(self) -> npt.NDArray:
        """Return the wrapped MLX array as a NumPy array.

        bfloat16 has no native NumPy dtype, so its bits are read out
        through a uint16 view and relabeled with the ``ml_dtypes``-backed
        dtype from ``self.dtype.numpy()``.
        """
        import mlx.core as mx

        self.raw: mx.array
        if self.dtype != ir.DataType.BFLOAT16:
            return np.array(self.raw, copy=False)
        as_uint16 = np.array(self.raw.view(mx.uint16), copy=False)
        return as_uint16.view(dtype=self.dtype.numpy())

    def __array__(self, dtype: Any = None, copy: bool | None = None) -> npt.NDArray:
        """NumPy conversion protocol; ``copy`` is accepted for signature
        compatibility but ignored."""
        del copy  # Unused, but needed for the signature
        array = self.numpy()
        return array if dtype is None else array.__array__(dtype)
Check notice
Code scanning / lintrunner
PYLINT/C0411 Note