Arithmetic operations accept numpy arrays #102

Closed
@ev-br

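Supposedly, mixing array-api-strict arrays with arrays from other libraries should not be allowed.

Alternatively, all of them could be allowed, but then we'd need to specify something like __array_priority__, and that opens quite a Pandora's box, so I guess not?

Currently the behavior is inconsistent: adding a NumPy array succeeds and returns a NumPy array, while adding a torch or JAX array raises TypeError: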

In [5]: import numpy as np

In [6]: import array_api_strict as xp

In [7]: xp.arange(5, dtype=xp.int8) + np.arange(5, dtype=np.complex64)
Out[7]: array([0.+0.j, 2.+0.j, 4.+0.j, 6.+0.j, 8.+0.j], dtype=complex64)           # xp + np -> np !

In [8]: import torch

In [10]: xp.arange(5, dtype=xp.int8) + torch.arange(5)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[10], line 1
----> 1 xp.arange(5, dtype=xp.int8) + torch.arange(5)

TypeError: unsupported operand type(s) for +: 'Array' and 'Tensor'

In [11]: import jax.numpy as jnp

In [12]: xp.arange(5, dtype=xp.int8) + jnp.arange(5, dtype=jnp.complex64)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[12], line 1
----> 1 xp.arange(5, dtype=xp.int8) + jnp.arange(5, dtype=jnp.complex64)

TypeError: unsupported operand type(s) for +: 'Array' and 'jaxlib.xla_extension.ArrayImpl'
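
A possible explanation for the asymmetry, offered as an assumption rather than something confirmed in this issue: if the left operand's `__add__` returns `NotImplemented` for a foreign type, Python falls back to the right operand's `__radd__`, and a permissive right operand can then coerce the left one and return its own type. The sketch below reproduces that mechanism in pure Python. `StrictArray`, `PermissiveArray`, and `OpaqueArray` are hypothetical stand-ins, not the real array_api_strict, NumPy, or torch classes.

```python
class StrictArray:
    """Hypothetical stand-in for a strict array type."""

    def __init__(self, data):
        self.data = list(data)

    def __iter__(self):
        # Being iterable is what lets a permissive library coerce us.
        return iter(self.data)

    def __add__(self, other):
        if not isinstance(other, StrictArray):
            # Deferring lets Python try the other operand's __radd__.
            return NotImplemented
        return StrictArray(a + b for a, b in zip(self.data, other.data))


class PermissiveArray:
    """Hypothetical stand-in for a library that coerces any iterable."""

    def __init__(self, data):
        self.data = list(data)

    def __radd__(self, other):
        # Accepts the foreign left operand and returns its own type.
        return PermissiveArray(a + b for a, b in zip(other, self.data))


class OpaqueArray:
    """Hypothetical stand-in for a library with no reflected fallback."""

    def __init__(self, data):
        self.data = list(data)


mixed = StrictArray([0, 1, 2]) + PermissiveArray([0, 1, 2])
print(type(mixed).__name__, mixed.data)  # PermissiveArray [0, 2, 4]

try:
    StrictArray([0, 1, 2]) + OpaqueArray([0, 1, 2])
except TypeError as exc:
    print(exc)  # unsupported operand type(s) for +: 'StrictArray' and 'OpaqueArray'
```

Under this assumption, rejecting all foreign arrays consistently would mean raising `TypeError` directly in `__add__` for unrecognized array types instead of returning `NotImplemented`, so the reflected method never gets a chance to run.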
