Skip to content

Commit 5bce71d

Browse files
committed
Update numpy.py
1 parent eb1f844 commit 5bce71d

File tree

1 file changed

+193
-17
lines changed

1 file changed

+193
-17
lines changed

keras/src/backend/openvino/numpy.py

+193-17
Original file line numberDiff line numberDiff line change
@@ -336,9 +336,36 @@ def argmin(x, axis=None, keepdims=False):
336336

337337

338338
def argsort(x, axis=-1):
339-
raise NotImplementedError(
340-
"`argsort` is not supported with openvino backend"
341-
)
339+
x = get_ov_output(x)
340+
x_shape = x.get_partial_shape()
341+
rank = x_shape.rank.get_length()
342+
if rank == 0:
343+
return OpenVINOKerasTensor(ov_opset.constant([0], Type.i32).output(0))
344+
if axis is None:
345+
flatten_shape = ov_opset.constant([-1], Type.i32).output(0)
346+
x = ov_opset.reshape(x, flatten_shape, False).output(0)
347+
x_shape_tensor = ov_opset.shape_of(x, Type.i32).output(0)
348+
k = ov_opset.reduce_prod(
349+
x_shape_tensor, ov_opset.constant([0], Type.i32), keep_dims=False
350+
)
351+
axis = 0
352+
else:
353+
if axis < 0:
354+
axis = rank + axis
355+
x_shape_tensor = ov_opset.shape_of(x, Type.i32).output(0)
356+
k = ov_opset.gather(
357+
x_shape_tensor,
358+
ov_opset.constant(axis, Type.i32).output(0),
359+
ov_opset.constant(0, Type.i32).output(0),
360+
).output(0)
361+
sorted_indices = ov_opset.topk(
362+
x,
363+
k=k,
364+
axis=axis,
365+
mode="min",
366+
sort="value",
367+
).output(1)
368+
return OpenVINOKerasTensor(sorted_indices)
342369

343370

344371
def array(x, dtype=None):
@@ -380,9 +407,48 @@ def average(x, axis=None, weights=None):
380407

381408

382409
def bincount(x, weights=None, minlength=0, sparse=False):
383-
raise NotImplementedError(
384-
"`bincount` is not supported with openvino backend"
385-
)
410+
if x is None:
411+
raise ValueError("input x is None")
412+
if sparse:
413+
raise ValueError("Unsupported value `sparse=True`")
414+
x = get_ov_output(x)
415+
x_type = x.get_element_type()
416+
shape_x = ov_opset.shape_of(x, "i64").output(0)
417+
rank_x = ov_opset.shape_of(shape_x, "i64").output(0)
418+
rank_x = ov_opset.convert(rank_x, x_type).output(0)
419+
scalar_shape = ov_opset.constant([], x_type).output(0)
420+
rank_x = ov_opset.reshape(rank_x, scalar_shape, False).output(0)
421+
const_minus_one = ov_opset.constant(-1, x_type).output(0)
422+
rank_minus_one = ov_opset.add(rank_x, const_minus_one).output(0)
423+
minlength = get_ov_output(minlength)
424+
minlength = ov_opset.convert(minlength, x_type).output(0)
425+
const_one = ov_opset.constant(1, x_type).output(0)
426+
const_zero = ov_opset.constant(0, x_type).output(0)
427+
max_element = ov_opset.reduce_max(x, const_zero, keep_dims=False).output(0)
428+
depth = ov_opset.add(max_element, const_one).output(0)
429+
depth = ov_opset.maximum(depth, minlength).output(0)
430+
depth_scalar = ov_opset.reduce_max(
431+
depth, const_zero, keep_dims=False
432+
).output(0)
433+
one_hot = ov_opset.one_hot(
434+
x, depth_scalar, const_one, const_zero, axis=-1
435+
).output(0)
436+
if weights is not None:
437+
weights = get_ov_output(weights)
438+
weights_type = weights.get_element_type()
439+
weights_new = ov_opset.reshape(weights, [-1, 1], False).output(0)
440+
one_hot = ov_opset.convert(one_hot, weights_type).output(0)
441+
final_one_hot = ov_opset.multiply(one_hot, weights_new).output(0)
442+
final_output = ov_opset.reduce_sum(
443+
final_one_hot, rank_minus_one, keep_dims=False
444+
).output(0)
445+
return OpenVINOKerasTensor(final_output)
446+
else:
447+
final_output = ov_opset.reduce_sum(
448+
one_hot, rank_minus_one, keep_dims=False
449+
).output(0)
450+
final_output = ov_opset.convert(final_output, Type.i32).output(0)
451+
return OpenVINOKerasTensor(final_output)
386452

387453

388454
def broadcast_to(x, shape):
@@ -502,7 +568,76 @@ def diagonal(x, offset=0, axis1=0, axis2=1):
502568

503569

504570
def diff(a, n=1, axis=-1):
505-
raise NotImplementedError("`diff` is not supported with openvino backend")
571+
if n == 0:
572+
return OpenVINOKerasTensor(get_ov_output(a))
573+
if n < 0:
574+
raise ValueError("order must be non-negative but got " + repr(n))
575+
a = get_ov_output(a)
576+
a_type = a.get_element_type()
577+
if isinstance(a, np.ndarray):
578+
rank = a.ndim
579+
else:
580+
rank = a.get_partial_shape().rank.get_length()
581+
if axis < 0:
582+
axis = axis + rank
583+
result = a
584+
for _ in range(n):
585+
rank = result.get_partial_shape().rank.get_length()
586+
strides = ov_opset.constant(
587+
np.array([1] * rank, dtype=np.int64), Type.i64
588+
).output(0)
589+
590+
begin_upper_list = [0] * rank
591+
begin_upper_list[axis] = 1
592+
begin_upper = ov_opset.constant(
593+
np.array(begin_upper_list, dtype=np.int64), Type.i64
594+
).output(0)
595+
end_upper = ov_opset.constant(
596+
np.array([0] * rank, dtype=np.int64), Type.i64
597+
).output(0)
598+
begin_mask_upper = [1] * rank
599+
begin_mask_upper[axis] = 0
600+
end_mask_upper = [1] * rank
601+
upper = ov_opset.strided_slice(
602+
data=result,
603+
begin=begin_upper,
604+
end=end_upper,
605+
strides=strides,
606+
begin_mask=begin_mask_upper,
607+
end_mask=end_mask_upper,
608+
new_axis_mask=[],
609+
shrink_axis_mask=[],
610+
ellipsis_mask=[],
611+
).output(0)
612+
613+
begin_lower = ov_opset.constant(
614+
np.array([0] * rank, dtype=np.int64), Type.i64
615+
).output(0)
616+
end_lower_list = [0] * rank
617+
end_lower_list[axis] = -1
618+
end_lower = ov_opset.constant(
619+
np.array(end_lower_list, dtype=np.int64), Type.i64
620+
).output(0)
621+
begin_mask_lower = [1] * rank
622+
end_mask_lower = [1] * rank
623+
end_mask_lower[axis] = 0
624+
lower = ov_opset.strided_slice(
625+
data=result,
626+
begin=begin_lower,
627+
end=end_lower,
628+
strides=strides,
629+
begin_mask=begin_mask_lower,
630+
end_mask=end_mask_lower,
631+
new_axis_mask=[],
632+
shrink_axis_mask=[],
633+
ellipsis_mask=[],
634+
).output(0)
635+
636+
if a_type == Type.boolean:
637+
result = ov_opset.not_equal(upper, lower).output(0)
638+
else:
639+
result = ov_opset.subtract(upper, lower).output(0)
640+
return OpenVINOKerasTensor(result)
506641

507642

508643
def digitize(x, bins):
@@ -512,11 +647,30 @@ def digitize(x, bins):
512647

513648

514649
def dot(x, y):
515-
raise NotImplementedError("`dot` is not supported with openvino backend")
650+
element_type = None
651+
if isinstance(x, OpenVINOKerasTensor):
652+
element_type = x.output.get_element_type()
653+
if isinstance(y, OpenVINOKerasTensor):
654+
element_type = y.output.get_element_type()
655+
x = get_ov_output(x, element_type)
656+
y = get_ov_output(y, element_type)
657+
x, y = _align_operand_types(x, y, "dot()")
658+
if x.get_partial_shape().rank == 0 or y.get_partial_shape().rank == 0:
659+
return OpenVINOKerasTensor(ov_opset.multiply(x, y).output(0))
660+
return OpenVINOKerasTensor(ov_opset.matmul(x, y, False, False).output(0))
516661

517662

518663
def empty(shape, dtype=None):
519-
raise NotImplementedError("`empty` is not supported with openvino backend")
664+
dtype = standardize_dtype(dtype) or config.floatx()
665+
ov_type = OPENVINO_DTYPES[dtype]
666+
if isinstance(shape, tuple):
667+
shape = list(shape)
668+
elif isinstance(shape, int):
669+
shape = [shape]
670+
shape_node = ov_opset.constant(shape, Type.i32).output(0)
671+
const_zero = ov_opset.constant(0, dtype=ov_type).output(0)
672+
empty_tensor = ov_opset.broadcast(const_zero, shape_node).output(0)
673+
return OpenVINOKerasTensor(empty_tensor)
520674

521675

522676
def equal(x1, x2):
@@ -533,14 +687,17 @@ def equal(x1, x2):
533687

534688
def exp(x):
535689
x = get_ov_output(x)
690+
x_type = x.get_element_type()
691+
if x_type.is_integral():
692+
ov_type = OPENVINO_DTYPES[config.floatx()]
693+
x = ov_opset.convert(x, ov_type)
536694
return OpenVINOKerasTensor(ov_opset.exp(x).output(0))
537695

538696

539697
def expand_dims(x, axis):
540-
if isinstance(x, OpenVINOKerasTensor):
541-
x = x.output
542-
else:
543-
assert False
698+
x = get_ov_output(x)
699+
if isinstance(axis, tuple):
700+
axis = list(axis)
544701
axis = ov_opset.constant(axis, Type.i32).output(0)
545702
return OpenVINOKerasTensor(ov_opset.unsqueeze(x, axis).output(0))
546703

@@ -571,9 +728,15 @@ def full(shape, fill_value, dtype=None):
571728

572729

573730
def full_like(x, fill_value, dtype=None):
574-
raise NotImplementedError(
575-
"`full_like` is not supported with openvino backend"
576-
)
731+
x = get_ov_output(x)
732+
shape_x = ov_opset.shape_of(x)
733+
if dtype is not None:
734+
ov_type = OPENVINO_DTYPES[standardize_dtype(dtype)]
735+
else:
736+
ov_type = x.get_element_type()
737+
const_value = ov_opset.constant(fill_value, ov_type).output(0)
738+
res = ov_opset.broadcast(const_value, shape_x).output(0)
739+
return OpenVINOKerasTensor(res)
577740

578741

579742
def greater(x1, x2):
@@ -601,7 +764,20 @@ def greater_equal(x1, x2):
601764

602765

603766
def hstack(xs):
604-
raise NotImplementedError("`hstack` is not supported with openvino backend")
767+
if not isinstance(xs, (list, tuple)):
768+
raise TypeError("Input to `hstack` must be a list or tuple of tensors.")
769+
if len(xs) == 0:
770+
raise ValueError("Input list to `hstack` cannot be empty.")
771+
element_type = None
772+
for x in xs:
773+
if isinstance(x, OpenVINOKerasTensor):
774+
element_type = x.output.get_element_type()
775+
break
776+
xs = [get_ov_output(x, element_type) for x in xs]
777+
xs = _align_operand_types(xs[0], xs[1], "hstack()")
778+
rank = len(xs[0].get_partial_shape())
779+
axis = 1 if rank > 1 else 0
780+
return OpenVINOKerasTensor(ov_opset.concat(xs, axis=axis).output(0))
605781

606782

607783
def identity(n, dtype=None):

0 commit comments

Comments
 (0)