From 5bce71db3efd260ad656b7b849806af9f4e22c0e Mon Sep 17 00:00:00 2001 From: Sonal kumari <165447633+Hmm-1224@users.noreply.github.com> Date: Thu, 13 Mar 2025 22:43:26 +0530 Subject: [PATCH 1/6] Update numpy.py --- keras/src/backend/openvino/numpy.py | 210 +++++++++++++++++++++++++--- 1 file changed, 193 insertions(+), 17 deletions(-) diff --git a/keras/src/backend/openvino/numpy.py b/keras/src/backend/openvino/numpy.py index e539d649baa1..25ffe866ea63 100644 --- a/keras/src/backend/openvino/numpy.py +++ b/keras/src/backend/openvino/numpy.py @@ -336,9 +336,36 @@ def argmin(x, axis=None, keepdims=False): def argsort(x, axis=-1): - raise NotImplementedError( - "`argsort` is not supported with openvino backend" - ) + x = get_ov_output(x) + x_shape = x.get_partial_shape() + rank = x_shape.rank.get_length() + if rank == 0: + return OpenVINOKerasTensor(ov_opset.constant([0], Type.i32).output(0)) + if axis is None: + flatten_shape = ov_opset.constant([-1], Type.i32).output(0) + x = ov_opset.reshape(x, flatten_shape, False).output(0) + x_shape_tensor = ov_opset.shape_of(x, Type.i32).output(0) + k = ov_opset.reduce_prod( + x_shape_tensor, ov_opset.constant([0], Type.i32), keep_dims=False + ) + axis = 0 + else: + if axis < 0: + axis = rank + axis + x_shape_tensor = ov_opset.shape_of(x, Type.i32).output(0) + k = ov_opset.gather( + x_shape_tensor, + ov_opset.constant(axis, Type.i32).output(0), + ov_opset.constant(0, Type.i32).output(0), + ).output(0) + sorted_indices = ov_opset.topk( + x, + k=k, + axis=axis, + mode="min", + sort="value", + ).output(1) + return OpenVINOKerasTensor(sorted_indices) def array(x, dtype=None): @@ -380,9 +407,48 @@ def average(x, axis=None, weights=None): def bincount(x, weights=None, minlength=0, sparse=False): - raise NotImplementedError( - "`bincount` is not supported with openvino backend" - ) + if x is None: + raise ValueError("input x is None") + if sparse: + raise ValueError("Unsupported value `sparse=True`") + x = get_ov_output(x) + x_type = x.get_element_type() + shape_x = ov_opset.shape_of(x, "i64").output(0) + rank_x = ov_opset.shape_of(shape_x, "i64").output(0) + rank_x = ov_opset.convert(rank_x, x_type).output(0) + scalar_shape = ov_opset.constant([], x_type).output(0) + rank_x = ov_opset.reshape(rank_x, scalar_shape, False).output(0) + const_minus_one = ov_opset.constant(-1, x_type).output(0) + rank_minus_one = ov_opset.add(rank_x, const_minus_one).output(0) + minlength = get_ov_output(minlength) + minlength = ov_opset.convert(minlength, x_type).output(0) + const_one = ov_opset.constant(1, x_type).output(0) + const_zero = ov_opset.constant(0, x_type).output(0) + max_element = ov_opset.reduce_max(x, const_zero, keep_dims=False).output(0) + depth = ov_opset.add(max_element, const_one).output(0) + depth = ov_opset.maximum(depth, minlength).output(0) + depth_scalar = ov_opset.reduce_max( + depth, const_zero, keep_dims=False + ).output(0) + one_hot = ov_opset.one_hot( + x, depth_scalar, const_one, const_zero, axis=-1 + ).output(0) + if weights is not None: + weights = get_ov_output(weights) + weights_type = weights.get_element_type() + weights_new = ov_opset.reshape(weights, [-1, 1], False).output(0) + one_hot = ov_opset.convert(one_hot, weights_type).output(0) + final_one_hot = ov_opset.multiply(one_hot, weights_new).output(0) + final_output = ov_opset.reduce_sum( + final_one_hot, rank_minus_one, keep_dims=False + ).output(0) + return OpenVINOKerasTensor(final_output) + else: + final_output = ov_opset.reduce_sum( + one_hot, rank_minus_one, keep_dims=False + ).output(0) + final_output = ov_opset.convert(final_output, Type.i32).output(0) + return OpenVINOKerasTensor(final_output) def broadcast_to(x, shape): @@ -502,7 +568,76 @@ def diagonal(x, offset=0, axis1=0, axis2=1): def diff(a, n=1, axis=-1): - raise NotImplementedError("`diff` is not supported with openvino backend") + if n == 0: + return OpenVINOKerasTensor(get_ov_output(a)) + if n < 0: + raise ValueError("order must be non-negative but got " + repr(n)) + a = get_ov_output(a) + a_type = a.get_element_type() + if isinstance(a, np.ndarray): + rank = a.ndim + else: + rank = a.get_partial_shape().rank.get_length() + if axis < 0: + axis = axis + rank + result = a + for _ in range(n): + rank = result.get_partial_shape().rank.get_length() + strides = ov_opset.constant( + np.array([1] * rank, dtype=np.int64), Type.i64 + ).output(0) + + begin_upper_list = [0] * rank + begin_upper_list[axis] = 1 + begin_upper = ov_opset.constant( + np.array(begin_upper_list, dtype=np.int64), Type.i64 + ).output(0) + end_upper = ov_opset.constant( + np.array([0] * rank, dtype=np.int64), Type.i64 + ).output(0) + begin_mask_upper = [1] * rank + begin_mask_upper[axis] = 0 + end_mask_upper = [1] * rank + upper = ov_opset.strided_slice( + data=result, + begin=begin_upper, + end=end_upper, + strides=strides, + begin_mask=begin_mask_upper, + end_mask=end_mask_upper, + new_axis_mask=[], + shrink_axis_mask=[], + ellipsis_mask=[], + ).output(0) + + begin_lower = ov_opset.constant( + np.array([0] * rank, dtype=np.int64), Type.i64 + ).output(0) + end_lower_list = [0] * rank + end_lower_list[axis] = -1 + end_lower = ov_opset.constant( + np.array(end_lower_list, dtype=np.int64), Type.i64 + ).output(0) + begin_mask_lower = [1] * rank + end_mask_lower = [1] * rank + end_mask_lower[axis] = 0 + lower = ov_opset.strided_slice( + data=result, + begin=begin_lower, + end=end_lower, + strides=strides, + begin_mask=begin_mask_lower, + end_mask=end_mask_lower, + new_axis_mask=[], + shrink_axis_mask=[], + ellipsis_mask=[], + ).output(0) + + if a_type == Type.boolean: + result = ov_opset.not_equal(upper, lower).output(0) + else: + result = ov_opset.subtract(upper, lower).output(0) + return OpenVINOKerasTensor(result) def digitize(x, bins): @@ -512,11 +647,30 @@ def digitize(x, bins): def dot(x, y): - raise NotImplementedError("`dot` is not supported with openvino backend") + element_type = None + if isinstance(x, OpenVINOKerasTensor): + element_type = x.output.get_element_type() + if isinstance(y, OpenVINOKerasTensor): + element_type = y.output.get_element_type() + x = get_ov_output(x, element_type) + y = get_ov_output(y, element_type) + x, y = _align_operand_types(x, y, "dot()") + if x.get_partial_shape().rank == 0 or y.get_partial_shape().rank == 0: + return OpenVINOKerasTensor(ov_opset.multiply(x, y).output(0)) + return OpenVINOKerasTensor(ov_opset.matmul(x, y, False, False).output(0)) def empty(shape, dtype=None): - raise NotImplementedError("`empty` is not supported with openvino backend") + dtype = standardize_dtype(dtype) or config.floatx() + ov_type = OPENVINO_DTYPES[dtype] + if isinstance(shape, tuple): + shape = list(shape) + elif isinstance(shape, int): + shape = [shape] + shape_node = ov_opset.constant(shape, Type.i32).output(0) + const_zero = ov_opset.constant(0, dtype=ov_type).output(0) + empty_tensor = ov_opset.broadcast(const_zero, shape_node).output(0) + return OpenVINOKerasTensor(empty_tensor) def equal(x1, x2): @@ -533,14 +687,17 @@ def equal(x1, x2): def exp(x): x = get_ov_output(x) + x_type = x.get_element_type() + if x_type.is_integral(): + ov_type = OPENVINO_DTYPES[config.floatx()] + x = ov_opset.convert(x, ov_type) return OpenVINOKerasTensor(ov_opset.exp(x).output(0)) def expand_dims(x, axis): - if isinstance(x, OpenVINOKerasTensor): - x = x.output - else: - assert False + x = get_ov_output(x) + if isinstance(axis, tuple): + axis = list(axis) axis = ov_opset.constant(axis, Type.i32).output(0) return OpenVINOKerasTensor(ov_opset.unsqueeze(x, axis).output(0)) @@ -571,9 +728,15 @@ def full(shape, fill_value, dtype=None): def full_like(x, fill_value, dtype=None): - raise NotImplementedError( - "`full_like` is not supported with openvino backend" - ) + x = get_ov_output(x) + shape_x = ov_opset.shape_of(x) + if dtype is not None: + ov_type = OPENVINO_DTYPES[standardize_dtype(dtype)] + else: + ov_type = x.get_element_type() + const_value = ov_opset.constant(fill_value, ov_type).output(0) + res = ov_opset.broadcast(const_value, shape_x).output(0) + return OpenVINOKerasTensor(res) def greater(x1, x2): @@ -601,7 +764,20 @@ def greater_equal(x1, x2): def hstack(xs): - raise NotImplementedError("`hstack` is not supported with openvino backend") + if not isinstance(xs, (list, tuple)): + raise TypeError("Input to `hstack` must be a list or tuple of tensors.") + if len(xs) == 0: + raise ValueError("Input list to `hstack` cannot be empty.") + element_type = None + for x in xs: + if isinstance(x, OpenVINOKerasTensor): + element_type = x.output.get_element_type() + break + xs = [get_ov_output(x, element_type) for x in xs] + xs = _align_operand_types(xs[0], xs[1], "hstack()") + rank = len(xs[0].get_partial_shape()) + axis = 1 if rank > 1 else 0 + return OpenVINOKerasTensor(ov_opset.concat(xs, axis=axis).output(0)) def identity(n, dtype=None): From 6fb3be25a836b2db18b73ef5c5cb380c46b206eb Mon Sep 17 00:00:00 2001 From: Sonal kumari <165447633+Hmm-1224@users.noreply.github.com> Date: Thu, 13 Mar 2025 22:45:31 +0530 Subject: [PATCH 2/6] Update excluded_concrete_tests.txt --- .../openvino/excluded_concrete_tests.txt | 18 +++--------------- 1 file changed, 3 insertions(+), 15 deletions(-) diff --git a/keras/src/backend/openvino/excluded_concrete_tests.txt b/keras/src/backend/openvino/excluded_concrete_tests.txt index dea57a326241..208f598b4f55 100644 --- a/keras/src/backend/openvino/excluded_concrete_tests.txt +++ b/keras/src/backend/openvino/excluded_concrete_tests.txt @@ -9,9 +9,7 @@ NumpyDtypeTest::test_any NumpyDtypeTest::test_argmax NumpyDtypeTest::test_argmin NumpyDtypeTest::test_argpartition -NumpyDtypeTest::test_argsort NumpyDtypeTest::test_array -NumpyDtypeTest::test_bincount NumpyDtypeTest::test_bitwise NumpyDtypeTest::test_ceil NumpyDtypeTest::test_concatenate @@ -20,19 +18,13 @@ NumpyDtypeTest::test_cross NumpyDtypeTest::test_cumprod NumpyDtypeTest::test_cumsum_bool NumpyDtypeTest::test_diag -NumpyDtypeTest::test_diff NumpyDtypeTest::test_digitize -NumpyDtypeTest::test_dot NumpyDtypeTest::test_einsum -NumpyDtypeTest::test_empty NumpyDtypeTest::test_exp2 -NumpyDtypeTest::test_exp NumpyDtypeTest::test_expm1 NumpyDtypeTest::test_eye NumpyDtypeTest::test_flip NumpyDtypeTest::test_floor -NumpyDtypeTest::test_full_like -NumpyDtypeTest::test_hstack NumpyDtypeTest::test_identity NumpyDtypeTest::test_inner NumpyDtypeTest::test_isclose @@ -91,22 +83,19 @@ NumpyOneInputOpsCorrectnessTest::test_any NumpyOneInputOpsCorrectnessTest::test_argmax NumpyOneInputOpsCorrectnessTest::test_argmin NumpyOneInputOpsCorrectnessTest::test_argpartition -NumpyOneInputOpsCorrectnessTest::test_argsort NumpyOneInputOpsCorrectnessTest::test_array -NumpyOneInputOpsCorrectnessTest::test_bincount NumpyOneInputOpsCorrectnessTest::test_bitwise_invert NumpyOneInputOpsCorrectnessTest::test_conj NumpyOneInputOpsCorrectnessTest::test_correlate NumpyOneInputOpsCorrectnessTest::test_cumprod NumpyOneInputOpsCorrectnessTest::test_diag NumpyOneInputOpsCorrectnessTest::test_diagonal -NumpyOneInputOpsCorrectnessTest::test_diff -NumpyOneInputOpsCorrectnessTest::test_dot NumpyOneInputOpsCorrectnessTest::test_exp NumpyOneInputOpsCorrectnessTest::test_expand_dims +NumpyOneInputOpsCorrectnessTest::test_exp2 +NumpyOneInputOpsCorrectnessTest::test_expm1 NumpyOneInputOpsCorrectnessTest::test_flip NumpyOneInputOpsCorrectnessTest::test_floor_divide -NumpyOneInputOpsCorrectnessTest::test_hstack NumpyOneInputOpsCorrectnessTest::test_imag NumpyOneInputOpsCorrectnessTest::test_isfinite NumpyOneInputOpsCorrectnessTest::test_isinf @@ -167,7 +156,6 @@ NumpyTwoInputOpsCorrectnessTest::test_cross NumpyTwoInputOpsCorrectnessTest::test_digitize NumpyTwoInputOpsCorrectnessTest::test_divide_no_nan NumpyTwoInputOpsCorrectnessTest::test_einsum -NumpyTwoInputOpsCorrectnessTest::test_full_like NumpyTwoInputOpsCorrectnessTest::test_inner NumpyTwoInputOpsCorrectnessTest::test_isclose NumpyTwoInputOpsCorrectnessTest::test_linspace @@ -177,4 +165,4 @@ NumpyTwoInputOpsCorrectnessTest::test_quantile NumpyTwoInputOpsCorrectnessTest::test_take_along_axis NumpyTwoInputOpsCorrectnessTest::test_tensordot NumpyTwoInputOpsCorrectnessTest::test_vdot -NumpyTwoInputOpsCorrectnessTest::test_where \ No newline at end of file +NumpyTwoInputOpsCorrectnessTest::test_where From 4957517ea687da9aa332882f991da1c1cb2cb454 Mon Sep 17 00:00:00 2001 From: Sonal kumari <165447633+Hmm-1224@users.noreply.github.com> Date: Fri, 14 Mar 2025 11:34:09 +0530 Subject: [PATCH 3/6] Update excluded_concrete_tests.txt --- keras/src/backend/openvino/excluded_concrete_tests.txt | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/keras/src/backend/openvino/excluded_concrete_tests.txt b/keras/src/backend/openvino/excluded_concrete_tests.txt index 7fb40bd261bf..208f598b4f55 100644 --- a/keras/src/backend/openvino/excluded_concrete_tests.txt +++ b/keras/src/backend/openvino/excluded_concrete_tests.txt @@ -25,7 +25,6 @@ NumpyDtypeTest::test_expm1 NumpyDtypeTest::test_eye NumpyDtypeTest::test_flip NumpyDtypeTest::test_floor -NumpyDtypeTest::test_hstack NumpyDtypeTest::test_identity NumpyDtypeTest::test_inner NumpyDtypeTest::test_isclose @@ -166,4 +165,4 @@ NumpyTwoInputOpsCorrectnessTest::test_quantile NumpyTwoInputOpsCorrectnessTest::test_take_along_axis NumpyTwoInputOpsCorrectnessTest::test_tensordot NumpyTwoInputOpsCorrectnessTest::test_vdot -NumpyTwoInputOpsCorrectnessTest::test_where \ No newline at end of file +NumpyTwoInputOpsCorrectnessTest::test_where From c7420cdead2fb0e6c0e6bca2069db681b989b356 Mon Sep 17 00:00:00 2001 From: Sonal kumari <165447633+Hmm-1224@users.noreply.github.com> Date: Tue, 29 Apr 2025 12:13:48 +0530 Subject: [PATCH 4/6] Update numpy.py --- keras/src/backend/openvino/numpy.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/keras/src/backend/openvino/numpy.py b/keras/src/backend/openvino/numpy.py index 98857a665e76..7b634559bd8e 100644 --- a/keras/src/backend/openvino/numpy.py +++ b/keras/src/backend/openvino/numpy.py @@ -830,17 +830,14 @@ def greater_equal(x1, x2): def hstack(xs): - if not isinstance(xs, (list, tuple)): - raise TypeError("Input to `hstack` must be a list or tuple of tensors.") - if len(xs) == 0: - raise ValueError("Input list to `hstack` cannot be empty.") element_type = None for x in xs: if isinstance(x, OpenVINOKerasTensor): element_type = x.output.get_element_type() break xs = [get_ov_output(x, element_type) for x in xs] - xs = _align_operand_types(xs[0], xs[1], "hstack()") + for i in range(1, len(xs)): + xs[0], xs[i] = _align_operand_types(xs[0], xs[i], "hstack()") rank = len(xs[0].get_partial_shape()) axis = 1 if rank > 1 else 0 return OpenVINOKerasTensor(ov_opset.concat(xs, axis=axis).output(0)) From 7af8c84fd6480f5ea39b611be476f520a2772030 Mon Sep 17 00:00:00 2001 From: Sonal kumari <165447633+Hmm-1224@users.noreply.github.com> Date: Tue, 29 Apr 2025 19:43:03 +0530 Subject: [PATCH 5/6] Update numpy.py --- keras/src/backend/openvino/numpy.py | 1 - 1 file changed, 1 deletion(-) diff --git a/keras/src/backend/openvino/numpy.py b/keras/src/backend/openvino/numpy.py index 7b634559bd8e..822884c1d1e1 100644 --- a/keras/src/backend/openvino/numpy.py +++ b/keras/src/backend/openvino/numpy.py @@ -1,7 +1,6 @@ import numpy as np import openvino.runtime.opset14 as ov_opset from openvino import Type - from keras.src.backend import config from keras.src.backend.common import dtypes from keras.src.backend.common.variables import standardize_dtype From 87d3f3eaa5bc9da100ba53683db66446c19cf290 Mon Sep 17 00:00:00 2001 From: sonal Date: Tue, 29 Apr 2025 14:30:07 +0000 Subject: [PATCH 6/6] Apply api-gen pre-commit fixes --- keras/src/backend/openvino/numpy.py | 1 + 1 file changed, 1 insertion(+) diff --git a/keras/src/backend/openvino/numpy.py b/keras/src/backend/openvino/numpy.py index 822884c1d1e1..7b634559bd8e 100644 --- a/keras/src/backend/openvino/numpy.py +++ b/keras/src/backend/openvino/numpy.py @@ -1,6 +1,7 @@ import numpy as np import openvino.runtime.opset14 as ov_opset from openvino import Type + from keras.src.backend import config from keras.src.backend.common import dtypes from keras.src.backend.common.variables import standardize_dtype