Skip to content

Commit 8cf0f28

Browse files
committed
add default values to arange
1 parent 9506c28 commit 8cf0f28

File tree

1 file changed

+16
-4
lines changed

1 file changed

+16
-4
lines changed

sharpy/__init__.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,22 @@ def _validate_device(device):
9696
raise ValueError(f"Invalid device string: {device}")
9797

9898

99+
def arange(start, /, end=None, step=1, dtype=int64, device="", team=1):
100+
if end is None:
101+
end = start
102+
start = 0
103+
assert step != 0, "step cannot be zero"
104+
if (end - start) * step < 0:
105+
# invalid range, return empty array
106+
start = end = 0
107+
step = 1
108+
return ndarray(
109+
_csp.Creator.arange(
110+
start, end, step, dtype, _validate_device(device), team
111+
)
112+
)
113+
114+
99115
for func in api.api_categories["Creator"]:
100116
FUNC = func.upper()
101117
if func == "full":
@@ -114,10 +130,6 @@ def _validate_device(device):
114130
exec(
115131
f"{func} = lambda shape, dtype=float64, device='', team=1: ndarray(_csp.Creator.full(shape, 0, dtype, _validate_device(device), team))"
116132
)
117-
elif func == "arange":
118-
exec(
119-
f"{func} = lambda start, end, step, dtype=int64, device='', team=1: ndarray(_csp.Creator.arange(start, end, step, dtype, _validate_device(device), team))"
120-
)
121133
elif func == "linspace":
122134
exec(
123135
f"{func} = lambda start, end, step, endpoint, dtype=float64, device='', team=1: ndarray(_csp.Creator.linspace(start, end, step, endpoint, dtype, _validate_device(device), team))"

0 commit comments

Comments
 (0)